Commit 6683676d by Ting PAN

GroupNormalization Support

1 parent 18b664b1
@@ -16,15 +16,16 @@ class ReshapeOp final : public Operator<Context> {
 public:
     ReshapeOp(const OperatorDef& op_def, Workspace* ws)
         : Operator<Context>(op_def, ws),
-          shape(OperatorBase::GetRepeatedArg<int>("shape")) {
-        new_shape.resize(shape.size());
+          shape_like_desc(OperatorBase::GetSingleArg<string>("shape_like", "")) {
+        GET_ARGUMENTS_WITH_DESC(int, shape);
     }
     void RunOnDevice() override;

 protected:
-    vector<int> shape;
-    vector<TIndex> new_shape;
+    DECLARE_ARGUMENTS_WITH_DESC(int, shape);
+    string shape_like_desc;
+    vector<TIndex> require_shape, new_shape;
 };

 template <class Context>
@@ -38,6 +39,8 @@ class ReshapeGradientOp final : public Operator<Context> {
     void RunOnDevice() override;
 };

+DEFINE_ARGUMENTS_WITH_DESC(int, ReshapeOp, shape);
+
 } // namespace dragon

 #endif // DRAGON_OPERATORS_NDARRAY_RESHAPE_OP_H_
\ No newline at end of file
@@ -105,7 +105,7 @@ class FusedBatchNormGradientOp : public Operator<Context> {
         : Operator<Context>(op_def, ws),
           axis(OperatorBase::GetSingleArg<int>("axis", -1)),
           eps(OperatorBase::GetSingleArg<float>("eps", float(1e-3))),
-          use_stats(OperatorBase::GetSingleArg<int>("use_stats", -1)) { }
+          use_stats(OperatorBase::GetSingleArg<int>("use_stats", -1)) {}
     void Setup();
......
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
#ifndef DRAGON_OPERATORS_NORM_GROUP_NORM_OP_H_
#define DRAGON_OPERATORS_NORM_GROUP_NORM_OP_H_
#include "core/operator.h"
namespace dragon {
template <class Context>
class GroupNormOp : public Operator<Context> {
public:
GroupNormOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
group(OperatorBase::GetSingleArg<int>("group", 32)),
axis(OperatorBase::GetSingleArg<int>("axis", -1)),
momentum(OperatorBase::GetSingleArg<float>("momentum", float(0.9))),
eps(OperatorBase::GetSingleArg<float>("eps", float(1e-3))),
use_stats(OperatorBase::GetSingleArg<int>("use_stats", -1)),
mode(OperatorBase::GetSingleArg<string>("mode", "DEFAULT")) {
if (axis != -1)
CHECK_EQ(axis, 1)
<< "\nThe axis can only be set to 1.";
}
void Setup();
void RunOnDevice() override;
template <typename T> void TrainingRunWithType();
template <typename T> void InferenceRunWithType();
protected:
float momentum, eps;
Tensor mean, num_by_chans;
Tensor* multiplier, *num_multiplier, *spatial_multiplier, *cgs_multiplier;
Tensor* stddev, *var;
TIndex group, axis, N, C, S, NG, NC, NS, CGS;
string data_format, mode;
int use_stats;
bool use_global_stats, is_recomputing;
};
template <class Context>
class GroupNormGradientOp final : public Operator<Context> {
public:
GroupNormGradientOp(const OperatorDef& op_def, Workspace *ws)
: Operator<Context>(op_def, ws),
group(OperatorBase::GetSingleArg<int>("group", 32)),
axis(OperatorBase::GetSingleArg<int>("axis", -1)),
use_stats(OperatorBase::GetSingleArg<int>("use_stats", -1)) {
if (axis != -1)
CHECK_EQ(axis, 1)
<< "\nThe axis can only be set to 1.";
}
void Setup();
void RunOnDevice() override;
template <typename T> void TrainingRunWithType();
template <typename T> void InferenceRunWithType();
protected:
Tensor num_by_chans;
Tensor* multiplier, *num_multiplier, *spatial_multiplier, *cgs_multiplier;
Tensor* stddev, *var;
TIndex group, axis, N, C, S, NG, NC, NS, CGS;
string data_format;
int use_stats;
bool use_global_stats;
};
template <class Context>
class FusedGroupNormOp : public Operator<Context> {
public:
FusedGroupNormOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
group(OperatorBase::GetSingleArg<int>("group", 32)),
axis(OperatorBase::GetSingleArg<int>("axis", -1)),
momentum(OperatorBase::GetSingleArg<float>("momentum", float(0.9))),
eps(OperatorBase::GetSingleArg<float>("eps", float(1e-3))),
use_stats(OperatorBase::GetSingleArg<int>("use_stats", -1)) {}
void Setup();
void RunOnDevice() override;
template <typename T> void TrainingRunWithType();
template <typename T> void InferenceRunWithType();
protected:
float momentum, eps;
Tensor num_by_chans;
Tensor* multiplier, *num_multiplier, *spatial_multiplier, *cgs_multiplier;
Tensor* mean, *var, *stddev, *x_norm;
TIndex group, axis, N, C, S, NG, NC, NS, CGS;
string data_format;
int use_stats;
bool use_global_stats, is_recomputing;
};
template <class Context>
class FusedGroupNormGradientOp : public Operator<Context> {
public:
FusedGroupNormGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
group(OperatorBase::GetSingleArg<int>("group", 32)),
axis(OperatorBase::GetSingleArg<int>("axis", -1)),
eps(OperatorBase::GetSingleArg<float>("eps", float(1e-3))),
use_stats(OperatorBase::GetSingleArg<int>("use_stats", -1)) {}
void Setup();
void ShareGradient() override;
void RunOnDevice() override;
template <typename T> void TrainingRunWithType();
template <typename T> void InferenceRunWithType();
protected:
float eps;
Tensor num_by_chans;
Tensor* multiplier, *num_multiplier, *spatial_multiplier, *cgs_multiplier;
Tensor* mean, *var, *stddev, *x_norm;
TIndex group, axis, N, C, S, NG, NC, NS, CGS;
string data_format;
int use_stats;
bool use_global_stats;
};
} // namespace dragon
#endif // DRAGON_OPERATORS_NORM_GROUP_NORM_OP_H_
\ No newline at end of file
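The shape fields in these classes (N, C, S, NG, NC, NS, CGS) cache the reduced sizes the kernels iterate over. A hedged sketch of how they presumably relate for an NCHW input, with illustrative numbers (reading CGS as channels-per-group times spatial size is an assumption based on the cgs_multiplier buffer):

    # Presumed index-size bookkeeping for GroupNormOp, NCHW layout.
    N, C, H, W, group = 8, 64, 14, 14, 32
    S = H * W               # spatial size
    NG = N * group          # number of (sample, group) statistics
    NC = N * C
    NS = N * S
    CGS = (C // group) * S  # elements averaged per group (assumption)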
@@ -16,6 +16,14 @@
     \sigma_{B}^{2} = \frac{1}{m} \sum_{i=1}^{m}(x_{i} - \mu_{B})^{2} \\
     \hat{x}_{i} = \frac{x_{i} - \mu_{B}}{\sqrt{\sigma_{B}^{2} + \epsilon}} \\ y_{i} = \gamma\hat{x}_{i} + \beta \\ \,

+.. |groupnorm_function| mathmacro:: \\ \, \\ \mu_{G} = \frac{1}{m} \sum_{i=1}^{m}x_{i} \\
+    \sigma_{G}^{2} = \frac{1}{m} \sum_{i=1}^{m}(x_{i} - \mu_{G})^{2} \\
+    \hat{x}_{i} = \frac{x_{i} - \mu_{G}}{\sqrt{\sigma_{G}^{2} + \epsilon}} \\ \,
+
+.. |groupnorm_scale_function| mathmacro:: \\ \, \\ \mu_{G} = \frac{1}{m} \sum_{i=1}^{m}x_{i} \\
+    \sigma_{G}^{2} = \frac{1}{m} \sum_{i=1}^{m}(x_{i} - \mu_{G})^{2} \\
+    \hat{x}_{i} = \frac{x_{i} - \mu_{G}}{\sqrt{\sigma_{G}^{2} + \epsilon}} \\ y_{i} = \gamma\hat{x}_{i} + \beta \\ \,
+
 .. |batchrenorm_function| mathmacro:: \\ \, \\ \mu_{B} = \frac{1}{m} \sum_{i=1}^{m}x_{i} \\
     \sigma_{B}^{2} = \frac{1}{m} \sum_{i=1}^{m}(x_{i} - \mu_{B})^{2} \\
     \hat{x}_{i} = \frac{x_{i} - \mu_{B}}{\sqrt{\sigma_{B}^{2} + \epsilon}} \cdot r + d \\ \,
......
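Concretely, |groupnorm_function| takes the mean and variance over each group of channels (together with the spatial positions) instead of over the batch. A minimal NumPy sketch, assuming NCHW layout (the function name and shapes are illustrative):

    import numpy as np

    def group_norm_ref(x, group=32, eps=1e-3):
        # Split C into `group` groups and normalize each (C/group, H, W)
        # slice with its own mean and variance.
        N, C, H, W = x.shape
        assert C % group == 0, 'group must divide C'
        xg = x.reshape(N, group, C // group, H, W)
        mu = xg.mean(axis=(2, 3, 4), keepdims=True)
        var = xg.var(axis=(2, 3, 4), keepdims=True)
        x_hat = (xg - mu) / np.sqrt(var + eps)
        return x_hat.reshape(N, C, H, W)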
@@ -113,6 +113,8 @@ List Brief
 `BatchNorm`_ Batch Normalization, introduced by `[Ioffe & Szegedy, 2015] <https://arxiv.org/abs/1502.03167>`_.
 `BatchRenorm`_ Batch Renormalization, introduced by `[Ioffe, 2017] <https://arxiv.org/abs/1702.03275>`_.
 `FusedBatchNorm`_ Batch Normalization, with the scale procedure applied after normalization.
+`GroupNorm`_ Group Normalization, introduced by `[Wu & He, 2018] <https://arxiv.org/abs/1803.08494>`_.
+`FusedGroupNorm`_ Group Normalization, with the scale procedure applied after normalization.
 `InstanceNorm`_ Instance Normalization, introduced by `[Ulyanov et al., 2016] <https://arxiv.org/abs/1607.08022>`_.
 `L2Norm`_ L2 Normalization, introduced by `[Liu et al., 2015] <https://arxiv.org/abs/1506.04579>`_.
 ================== ======================================================================
@@ -253,6 +255,8 @@ List Brief
 .. _BatchNorm: operators/norm.html#dragon.operators.norm.BatchNorm
 .. _BatchRenorm: operators/norm.html#dragon.operators.norm.BatchRenorm
 .. _FusedBatchNorm: operators/norm.html#dragon.operators.norm.FusedBatchNorm
+.. _GroupNorm: operators/norm.html#dragon.operators.norm.GroupNorm
+.. _FusedGroupNorm: operators/norm.html#dragon.operators.norm.FusedGroupNorm
 .. _InstanceNorm: operators/norm.html#dragon.operators.norm.InstanceNorm
 .. _L2Norm: operators/norm.html#dragon.operators.norm.L2Norm
......
@@ -73,9 +73,11 @@ List Brief
 `ArgMaxLayer`_ The implementation of ``ArgMaxLayer``.
 `BatchNormLayer`_ The implementation of ``BatchNormLayer``.
 `BatchRenormLayer`_ The implementation of ``BatchRenormLayer``.
+`GroupNormLayer`_ The implementation of ``GroupNormLayer``.
 `InstanceNormLayer`_ The implementation of ``InstanceNormLayer``.
 `ScaleLayer`_ The implementation of ``ScaleLayer``.
 `BNLayer`_ The implementation of ``BNLayer``.
+`GNLayer`_ The implementation of ``GNLayer``.
 `NormalizeLayer`_ The implementation of ``NormalizeLayer``.
 `TileLayer`_ The extended implementation of ``TileLayer``.
 `ExpandDimsLayer`_ The implementation of ``ExpandDimsLayer``.
@@ -181,9 +183,11 @@ API Reference
 .. _ArgMaxLayer: #dragon.vm.caffe.layers.common.ArgMaxLayer
 .. _BatchNormLayer: #dragon.vm.caffe.layers.common.BatchNormLayer
 .. _BatchRenormLayer: #dragon.vm.caffe.layers.common.BatchRenormLayer
+.. _GroupNormLayer: #dragon.vm.caffe.layers.common.GroupNormLayer
 .. _InstanceNormLayer: #dragon.vm.caffe.layers.common.InstanceNormLayer
 .. _ScaleLayer: #dragon.vm.caffe.layers.common.ScaleLayer
 .. _BNLayer: #dragon.vm.caffe.layers.common.BNLayer
+.. _GNLayer: #dragon.vm.caffe.layers.common.GNLayer
 .. _NormalizeLayer: #dragon.vm.caffe.layers.common.NormalizeLayer
 .. _TileLayer: #dragon.vm.caffe.layers.common.TileLayer
 .. _ExpandDimsLayer: #dragon.vm.caffe.layers.common.ExpandDimsLayer
......
@@ -648,15 +648,21 @@ def Flatten(inputs, axis=0, num_axes=-1, keep_axes=None, **kwargs):
     return output

-def Reshape(inputs, shape, **kwargs):
+def Reshape(inputs, shape, shape_like=None, **kwargs):
     """Reshape the dimensions of input.

+    ``shape`` can be a list of numbers or Tensors.
+
+    Set ``shape`` to ``None`` if you want to use ``shape_like``.
+
     Parameters
     ----------
     inputs : Tensor
         The input tensor.
-    shape : list or tuple
+    shape : list, tuple or None
         The new shape.
+    shape_like : Tensor or None
+        The tensor whose shape indicates the output shape.

     Returns
     -------
@@ -677,17 +683,29 @@ def Reshape(inputs, shape, **kwargs):
     CheckInputs(inputs, 1)
     arguments = ParseArguments(locals())

-    if not isinstance(shape, tuple) and not isinstance(shape, list):
-        raise TypeError('The type of dims must be a tuple or list.')
+    if shape is not None:
+        AddArgumentsWithDesc(arguments, shape, 'shape', 'int32', as_target=True)
+    elif shape_like is not None:
+        if not isinstance(shape_like, Tensor):
+            raise TypeError('The shape_like should be a Tensor.')
+        arguments['shape_like'] = shape_like.name

     output = Tensor.CreateOperator(nout=1, op_type='Reshape', **arguments)

     if inputs.shape is not None:
-        output.shape = [1] * len(shape)
-        for i, s in enumerate(shape):
-            if s == -1: output.shape[i] = 1
-            elif s == 0: output.shape[i] = inputs.shape[i]
-            else: output.shape[i] = s
+        possible_to_infer_shape = True
+        if shape is not None:
+            for dim in shape:
+                if isinstance(dim, Tensor):
+                    possible_to_infer_shape = False
+        if shape_like is not None:
+            possible_to_infer_shape = False
+        if possible_to_infer_shape:
+            output.shape = [1] * len(shape)
+            for i, s in enumerate(shape):
+                if s == -1: output.shape[i] = 1
+                elif s == 0: output.shape[i] = inputs.shape[i]
+                else: output.shape[i] = s

     return output
......
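With the extended interface above, a reshape can mix literal dims, runtime Tensors, and a reference tensor. A hedged usage sketch (``x``, ``n_cols`` and ``ref`` are illustrative pre-defined Tensors; the import path is an assumption):

    import dragon.ops as ops

    y1 = ops.Reshape(x, shape=[0, -1])               # 0 copies the input dim, -1 is inferred
    y2 = ops.Reshape(x, shape=[n_cols, 256])         # a dim may itself be a Tensor
    y3 = ops.Reshape(x, shape=None, shape_like=ref)  # borrow the shape of ``ref``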
@@ -165,6 +165,103 @@ def FusedBatchNorm(inputs, axis=-1, momentum=0.9, eps=1e-3, use_stats=-1, **kwar
    return output
def GroupNorm(inputs, group=32, axis=-1, momentum=0.9, eps=1e-3,
use_stats=-1, mode='DEFAULT', **kwargs):
"""Group Normalization, introduced by `[Wu & He, 2018] <https://arxiv.org/abs/1803.08494>`_.
It follows the implementation of `Caffe`_, in which the scale procedure is moved to `ops.Scale(*args, **kwargs)`_.
The number of inputs varies from ``3`` to ``4`` (``DEFAULT`` or ``CAFFE`` mode).
Parameters
----------
inputs : list of Tensor
The inputs, representing [input, mean, var] or [input, mean, var, factor].
group : int
The group size.
axis : int
The channel axis.
momentum : float
The momentum of moving average.
eps : float
The epsilon value.
use_stats : int
Whether to use global stats. Default is ``-1`` (Auto).
mode : str
The moving average mode. ``DEFAULT`` or ``CAFFE``.
Returns
-------
Tensor
The output tensor, calculated as:
|groupnorm_function|
The ``DEFAULT`` moving average of mean/var, calculated as:
|default_moving_average_function|
The ``CAFFE`` moving average of mean/var, calculated as:
|caffe_moving_average_function|
"""
CheckInputs(inputs, 3, 4)
arguments = ParseArguments(locals())
if len(inputs) > 3:
if mode != 'CAFFE':
raise ValueError('Only the CAFFE mode will take 4 inputs.')
output = Tensor.CreateOperator(nout=1, op_type='GroupNorm', **arguments)
if inputs[0].shape is not None:
output.shape = inputs[0].shape[:]
return output
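A hedged sketch of calling GroupNorm directly, mirroring the blob convention that ``GroupNormLayer`` uses below (``x`` is an illustrative NCHW Tensor):

    mean = Tensor('gn/mean').Constant(value=0.0)
    var = Tensor('gn/var').Constant(value=0.0)
    y = ops.GroupNorm([x, mean, var], group=32, axis=1)   # DEFAULT mode, 3 inputs

    factor = Tensor('gn/factor').Constant(value=0.0)      # CAFFE mode adds a factor
    y = ops.GroupNorm([x, mean, var, factor], group=32, axis=1, mode='CAFFE')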
def FusedGroupNorm(inputs, group=32, axis=-1, momentum=0.9, eps=1e-3, use_stats=-1, **kwargs):
"""Group Normalization, with scale procedure after normalization.
Parameters
----------
inputs : list of Tensor
The inputs, representing [input, mean, var, scale, bias].
group : int
The group size.
axis : int
The channel axis.
momentum : float
The momentum of moving average.
eps : float
The epsilon value.
use_stats : int
Whether to use global stats. Default is ``-1`` (Auto).
Returns
-------
Tensor
The output tensor, calculated as:
|groupnorm_scale_function|
The moving average of mean/var, calculated as:
|default_moving_average_function|
"""
CheckInputs(inputs, 5)
arguments = ParseArguments(locals())
output = Tensor.CreateOperator(nout=1, op_type='FusedGroupNorm', **arguments)
if inputs[0].shape is not None:
output.shape = inputs[0].shape[:]
return output
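Since FusedGroupNorm folds the affine transform back in, gamma/beta ride along as the 4th and 5th inputs. A hedged sketch continuing the example above (``scale``/``bias`` would normally come from fillers, as in ``GNLayer`` below):

    scale = Tensor('gn/scale').Constant(value=1.0)
    bias = Tensor('gn/bias').Constant(value=0.0)
    y = ops.FusedGroupNorm([x, mean, var, scale, bias], group=32, axis=1)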
def InstanceNorm(inputs, axis=-1, eps=1e-3, **kwargs):
    """Instance Normalization, introduced by `[Ulyanov et al., 2016] <https://arxiv.org/abs/1607.08022>`_
......
@@ -92,7 +92,9 @@ GramMatrix = math.GramMatrix
 # normalization
 BatchNorm = norm.BatchNorm
 BatchRenorm = norm.BatchRenorm
+GroupNorm = norm.GroupNorm
 FusedBatchNorm = norm.FusedBatchNorm
+FusedGroupNorm = norm.FusedGroupNorm
 InstanceNorm = norm.InstanceNorm
 L2Norm = norm.L2Norm
......
@@ -39,6 +39,8 @@ from .common import InnerProductLayer, \
 BatchNormLayer, \
 BatchRenormLayer, \
 BNLayer, \
+GroupNormLayer, \
+GNLayer, \
 ConcatLayer, \
 CropLayer, \
 PythonLayer, \
......
@@ -412,6 +412,47 @@ class BatchRenormLayer(Layer):
        return ops.BatchRenorm(bottom + [blob['data'] for blob in self._blobs], **self._param)
class GroupNormLayer(Layer):
"""The implementation of ``GroupNormLayer``.
Parameters
----------
group : int
Refer to ``GroupNormParameter.group``.
use_global_stats : boolean
Refer to ``GroupNormParameter.use_global_stats``.
moving_average_fraction : float
Refer to ``GroupNormParameter.moving_average_fraction``.
eps : float
Refer to ``GroupNormParameter.eps``.
"""
def __init__(self, LayerParameter):
super(GroupNormLayer, self).__init__(LayerParameter)
param = LayerParameter.group_norm_param
self._param = {'group': int(param.group),
'use_stats': int(param.use_global_stats)
if param.HasField('use_global_stats') else -1,
'momentum': param.moving_average_fraction,
'eps': param.eps,
'axis': 1,
'mode': 'CAFFE'}
scope = LayerParameter.name
# mean, var and factor are initialized to 0 for accumulating the statistics
mean = Tensor(scope + '/param:0').Constant(value=0.0)
var = Tensor(scope + '/param:1').Constant(value=0.0)
factor = Tensor(scope + '/param:2').Constant(value=0.0)
# In Dragon, setting diff to None skips the gradient computation automatically,
# while in BVLC Caffe you must set lr_mult = 0 manually
self._blobs.append({'data': mean, 'diff': None})
self._blobs.append({'data': var, 'diff': None})
self._blobs.append({'data': factor, 'diff': None})
def Setup(self, bottom):
super(GroupNormLayer, self).Setup(bottom)
return ops.GroupNorm(bottom + [blob['data'] for blob in self._blobs], **self._param)
class InstanceNormLayer(Layer):
"""
The implementation of ``InstanceNormLayer``.
@@ -518,6 +559,59 @@ class BNLayer(Layer):
        return ops.FusedBatchNorm(bottom + [blob['data'] for blob in self._blobs], **self._param)
class GNLayer(Layer):
"""The implementation of ``GNLayer``.
Parameters
----------
group : int
Refer to ``GroupNormParameter.group``.
use_global_stats : boolean
Refer to ``GroupNormParameter.use_global_stats``.
moving_average_fraction : float
Refer to ``GroupNormParameter.moving_average_fraction``.
eps : float
Refer to ``GroupNormParameter.eps``.
filler : FillerParameter
The filler of the scale parameter. Refer to `ScaleParameter.filler`_.
bias_filler : FillerParameter
The filler of the bias parameter. Refer to `ScaleParameter.bias_filler`_.
"""
def __init__(self, LayerParameter):
super(GNLayer, self).__init__(LayerParameter)
gn_param = LayerParameter.group_norm_param
scale_param = LayerParameter.scale_param
self._param = {'group': int(gn_param.group),
'use_stats': int(gn_param.use_global_stats)
if gn_param.HasField('use_global_stats') else -1,
'momentum': gn_param.moving_average_fraction,
'eps': gn_param.eps,
'axis': 1}
scope = LayerParameter.name
mean = Tensor(scope + '/param:0').Constant(value=0.0)
var = Tensor(scope + '/param:1').Constant(value=0.0)
scale = Tensor(scope + '/param:2')
scale_diff = Tensor(scope + '/param:2_grad')
bias = Tensor(scope + '/param:3')
bias_diff = Tensor(scope + '/param:3_grad')
if scale_param.HasField('filler'):
self.Fill(scale, scale_param, 'filler')
else: scale.Constant(value=1.0)
self.Fill(bias, scale_param, 'bias_filler')
self.norm_blobs = [{'data': mean, 'diff': None},
{'data': var, 'diff': None}]
self.scale_blobs = [{'data': scale, 'diff': scale_diff},
{'data': bias, 'diff': bias_diff}]
self._blobs.extend(self.norm_blobs)
self._blobs.extend(self.scale_blobs)
def Setup(self, bottom):
super(GNLayer, self).Setup(bottom)
return ops.FusedGroupNorm(bottom + [blob['data'] for blob in self._blobs], **self._param)
class NormalizeLayer(Layer):
"""The implementation of ``NormalizeLayer``.
......
@@ -423,6 +423,7 @@ message LayerParameter {
 optional DenseConcatParameter dense_concat_param = 163;
 optional FocalLossParameter focal_loss_param = 164;
 optional GatherParameter gather_param = 165;
+optional GroupNormParameter group_norm_param = 166;
 }

// Message that stores parameters used to apply transformation

@@ -1512,3 +1513,17 @@ message GatherParameter {
 optional int32 axis = 1 [default = 0];
 }
message GroupNormParameter {
// If false, accumulate global mean/variance values via a moving average. If
// true, use those accumulated values instead of computing mean/variance
// across the batch.
optional bool use_global_stats = 1;
// How much does the moving average decay each iteration?
optional float moving_average_fraction = 2 [default = 0.9];
// Small value to add to the variance estimate so that we don't divide by
// zero.
optional float eps = 3 [default = 1e-3];
optional uint32 group = 5 [default = 32]; // The group size
}
@@ -19,7 +19,7 @@ _sym_db = _symbol_database.Default()
 DESCRIPTOR = _descriptor.FileDescriptor(
   name='caffe.proto',
   package='caffe',
-  serialized_pb=_b('\n\x0b\x63\x61\x66\x66\x65.proto\x12\x05\x63\x61\x66\x66\x65...\x95\x19\n\x0eLayerParameter...')
+  serialized_pb=_b('\n\x0b\x63\x61\x66\x66\x65.proto\x12\x05\x63\x61\x66\x66\x65...\xcb\x19\n\x0eLayerParameter...')
......
\x01(\x0b\x32\x16.caffe.ArgMaxParameter\x12\x34\n\x10\x62\x61tch_norm_param\x18\x8b\x01 \x01(\x0b\x32\x19.caffe.BatchNormParameter\x12)\n\nbias_param\x18\x8d\x01 \x01(\x0b\x32\x14.caffe.BiasParameter\x12,\n\x0c\x63oncat_param\x18h \x01(\x0b\x32\x16.caffe.ConcatParameter\x12?\n\x16\x63ontrastive_loss_param\x18i \x01(\x0b\x32\x1f.caffe.ContrastiveLossParameter\x12\x36\n\x11\x63onvolution_param\x18j \x01(\x0b\x32\x1b.caffe.ConvolutionParameter\x12)\n\ncrop_param\x18\x90\x01 \x01(\x0b\x32\x14.caffe.CropParameter\x12(\n\ndata_param\x18k \x01(\x0b\x32\x14.caffe.DataParameter\x12.\n\rdropout_param\x18l \x01(\x0b\x32\x17.caffe.DropoutParameter\x12\x33\n\x10\x64ummy_data_param\x18m \x01(\x0b\x32\x19.caffe.DummyDataParameter\x12.\n\reltwise_param\x18n \x01(\x0b\x32\x17.caffe.EltwiseParameter\x12\'\n\telu_param\x18\x8c\x01 \x01(\x0b\x32\x13.caffe.ELUParameter\x12+\n\x0b\x65mbed_param\x18\x89\x01 \x01(\x0b\x32\x15.caffe.EmbedParameter\x12&\n\texp_param\x18o \x01(\x0b\x32\x13.caffe.ExpParameter\x12/\n\rflatten_param\x18\x87\x01 \x01(\x0b\x32\x17.caffe.FlattenParameter\x12\x31\n\x0fhdf5_data_param\x18p \x01(\x0b\x32\x18.caffe.HDF5DataParameter\x12\x35\n\x11hdf5_output_param\x18q \x01(\x0b\x32\x1a.caffe.HDF5OutputParameter\x12\x33\n\x10hinge_loss_param\x18r \x01(\x0b\x32\x19.caffe.HingeLossParameter\x12\x33\n\x10image_data_param\x18s \x01(\x0b\x32\x19.caffe.ImageDataParameter\x12\x39\n\x13infogain_loss_param\x18t \x01(\x0b\x32\x1c.caffe.InfogainLossParameter\x12\x39\n\x13inner_product_param\x18u \x01(\x0b\x32\x1c.caffe.InnerProductParameter\x12+\n\x0binput_param\x18\x8f\x01 \x01(\x0b\x32\x15.caffe.InputParameter\x12\'\n\tlog_param\x18\x86\x01 \x01(\x0b\x32\x13.caffe.LogParameter\x12&\n\tlrn_param\x18v \x01(\x0b\x32\x13.caffe.LRNParameter\x12\x35\n\x11memory_data_param\x18w \x01(\x0b\x32\x1a.caffe.MemoryDataParameter\x12&\n\tmvn_param\x18x \x01(\x0b\x32\x13.caffe.MVNParameter\x12\x33\n\x0fparameter_param\x18\x91\x01 \x01(\x0b\x32\x19.caffe.ParameterParameter\x12.\n\rpooling_param\x18y \x01(\x0b\x32\x17.caffe.PoolingParameter\x12*\n\x0bpower_param\x18z \x01(\x0b\x32\x15.caffe.PowerParameter\x12+\n\x0bprelu_param\x18\x83\x01 \x01(\x0b\x32\x15.caffe.PReLUParameter\x12-\n\x0cpython_param\x18\x82\x01 \x01(\x0b\x32\x16.caffe.PythonParameter\x12\x33\n\x0freduction_param\x18\x88\x01 \x01(\x0b\x32\x19.caffe.ReductionParameter\x12(\n\nrelu_param\x18{ \x01(\x0b\x32\x14.caffe.ReLUParameter\x12/\n\rreshape_param\x18\x85\x01 \x01(\x0b\x32\x17.caffe.ReshapeParameter\x12+\n\x0bscale_param\x18\x8e\x01 \x01(\x0b\x32\x15.caffe.ScaleParameter\x12.\n\rsigmoid_param\x18| \x01(\x0b\x32\x17.caffe.SigmoidParameter\x12.\n\rsoftmax_param\x18} \x01(\x0b\x32\x17.caffe.SoftmaxParameter\x12\'\n\tspp_param\x18\x84\x01 \x01(\x0b\x32\x13.caffe.SPPParameter\x12*\n\x0bslice_param\x18~ \x01(\x0b\x32\x15.caffe.SliceParameter\x12(\n\ntanh_param\x18\x7f \x01(\x0b\x32\x14.caffe.TanHParameter\x12\x33\n\x0fthreshold_param\x18\x80\x01 \x01(\x0b\x32\x19.caffe.ThresholdParameter\x12)\n\ntile_param\x18\x8a\x01 \x01(\x0b\x32\x14.caffe.TileParameter\x12\x36\n\x11window_data_param\x18\x81\x01 \x01(\x0b\x32\x1a.caffe.WindowDataParameter\x12\x36\n\x11roi_pooling_param\x18\x97\x01 \x01(\x0b\x32\x1a.caffe.ROIPoolingParameter\x12;\n\x14smooth_l1_loss_param\x18\x98\x01 \x01(\x0b\x32\x1c.caffe.SmoothL1LossParameter\x12\'\n\tmpi_param\x18\x99\x01 \x01(\x0b\x32\x13.caffe.MPIParameter\x12/\n\rpermute_param\x18\x9a\x01 \x01(\x0b\x32\x17.caffe.PermuteParameter\x12\x33\n\x0fnormalize_param\x18\x9b\x01 
\x01(\x0b\x32\x19.caffe.NormalizeParameter\x12\x31\n\x0eparallel_param\x18\x9d\x01 \x01(\x0b\x32\x18.caffe.ParallelParameter\x12-\n\x0cresize_param\x18\x9e\x01 \x01(\x0b\x32\x16.caffe.ResizeParameter\x12\x36\n\x11\x65xpand_dims_param\x18\x9f\x01 \x01(\x0b\x32\x1a.caffe.ExpandDimsParameter\x12\x31\n\x0eproposal_param\x18\xa0\x01 \x01(\x0b\x32\x18.caffe.ProposalParameter\x12\x38\n\x12\x62\x61tch_renorm_param\x18\xa1\x01 \x01(\x0b\x32\x1b.caffe.BatchRenormParameter\x12\x38\n\x12\x64\x65nse_concat_param\x18\xa3\x01 \x01(\x0b\x32\x1b.caffe.DenseConcatParameter\x12\x34\n\x10\x66ocal_loss_param\x18\xa4\x01 \x01(\x0b\x32\x19.caffe.FocalLossParameter\x12-\n\x0cgather_param\x18\xa5\x01 \x01(\x0b\x32\x16.caffe.GatherParameter\x12\x34\n\x10group_norm_param\x18\xa6\x01 \x01(\x0b\x32\x19.caffe.GroupNormParameter\"\xa7\x02\n\x17TransformationParameter\x12\x10\n\x05scale\x18\x01 \x01(\x02:\x01\x31\x12\x15\n\x06mirror\x18\x02 \x01(\x08:\x05\x66\x61lse\x12\x14\n\tcrop_size\x18\x03 \x01(\r:\x01\x30\x12\x12\n\x07padding\x18\x0b \x01(\r:\x01\x30\x12\x11\n\tmean_file\x18\x04 \x01(\t\x12\x12\n\nmean_value\x18\x05 \x03(\x02\x12\x1a\n\x0b\x66orce_color\x18\x06 \x01(\x08:\x05\x66\x61lse\x12\x19\n\nforce_gray\x18\x07 \x01(\x08:\x05\x66\x61lse\x12!\n\x12\x63olor_augmentation\x18\x08 \x01(\x08:\x05\x66\x61lse\x12\x1b\n\x10min_random_scale\x18\t \x01(\x02:\x01\x31\x12\x1b\n\x10max_random_scale\x18\n \x01(\x02:\x01\x31\"\xf5\x01\n\rLossParameter\x12\x14\n\x0cignore_label\x18\x01 \x01(\x05\x12\x44\n\rnormalization\x18\x03 \x01(\x0e\x32&.caffe.LossParameter.NormalizationMode:\x05VALID\x12\x11\n\tnormalize\x18\x02 \x01(\x08\x1a\'\n\x13\x45xpandDimsParameter\x12\x10\n\x04\x61xis\x18\x01 \x01(\x05:\x02-1\"L\n\x11NormalizationMode\x12\x08\n\x04\x46ULL\x10\x00\x12\t\n\x05VALID\x10\x01\x12\x0e\n\nBATCH_SIZE\x10\x02\x12\x08\n\x04NONE\x10\x03\x12\x08\n\x04UNIT\x10\x04\"L\n\x11\x41\x63\x63uracyParameter\x12\x10\n\x05top_k\x18\x01 \x01(\r:\x01\x31\x12\x0f\n\x04\x61xis\x18\x02 \x01(\x05:\x01\x31\x12\x14\n\x0cignore_label\x18\x03 \x01(\x05\"M\n\x0f\x41rgMaxParameter\x12\x1a\n\x0bout_max_val\x18\x01 \x01(\x08:\x05\x66\x61lse\x12\x10\n\x05top_k\x18\x02 \x01(\r:\x01\x31\x12\x0c\n\x04\x61xis\x18\x03 \x01(\x05\"9\n\x0f\x43oncatParameter\x12\x0f\n\x04\x61xis\x18\x02 \x01(\x05:\x01\x31\x12\x15\n\nconcat_dim\x18\x01 \x01(\r:\x01\x31\"h\n\x12\x42\x61tchNormParameter\x12\x18\n\x10use_global_stats\x18\x01 \x01(\x08\x12$\n\x17moving_average_fraction\x18\x02 \x01(\x02:\x03\x30.9\x12\x12\n\x03\x65ps\x18\x03 \x01(\x02:\x05\x30.001\"]\n\rBiasParameter\x12\x0f\n\x04\x61xis\x18\x01 \x01(\x05:\x01\x31\x12\x13\n\x08num_axes\x18\x02 \x01(\x05:\x01\x31\x12&\n\x06\x66iller\x18\x03 \x01(\x0b\x32\x16.caffe.FillerParameter\"L\n\x18\x43ontrastiveLossParameter\x12\x11\n\x06margin\x18\x01 \x01(\x02:\x01\x31\x12\x1d\n\x0elegacy_version\x18\x02 \x01(\x08:\x05\x66\x61lse\"\xfc\x03\n\x14\x43onvolutionParameter\x12\x12\n\nnum_output\x18\x01 \x01(\r\x12\x17\n\tbias_term\x18\x02 \x01(\x08:\x04true\x12\x0b\n\x03pad\x18\x03 \x03(\r\x12\x13\n\x0bkernel_size\x18\x04 \x03(\r\x12\x0e\n\x06stride\x18\x06 \x03(\r\x12\x10\n\x08\x64ilation\x18\x12 \x03(\r\x12\x10\n\x05pad_h\x18\t \x01(\r:\x01\x30\x12\x10\n\x05pad_w\x18\n \x01(\r:\x01\x30\x12\x10\n\x08kernel_h\x18\x0b \x01(\r\x12\x10\n\x08kernel_w\x18\x0c \x01(\r\x12\x10\n\x08stride_h\x18\r \x01(\r\x12\x10\n\x08stride_w\x18\x0e \x01(\r\x12\x10\n\x05group\x18\x05 \x01(\r:\x01\x31\x12-\n\rweight_filler\x18\x07 \x01(\x0b\x32\x16.caffe.FillerParameter\x12+\n\x0b\x62ias_filler\x18\x08 
\x01(\x0b\x32\x16.caffe.FillerParameter\x12;\n\x06\x65ngine\x18\x0f \x01(\x0e\x32\".caffe.ConvolutionParameter.Engine:\x07\x44\x45\x46\x41ULT\x12\x0f\n\x04\x61xis\x18\x10 \x01(\x05:\x01\x31\x12\x1e\n\x0f\x66orce_nd_im2col\x18\x11 \x01(\x08:\x05\x66\x61lse\"+\n\x06\x45ngine\x12\x0b\n\x07\x44\x45\x46\x41ULT\x10\x00\x12\t\n\x05\x43\x41\x46\x46\x45\x10\x01\x12\t\n\x05\x43UDNN\x10\x02\"0\n\rCropParameter\x12\x0f\n\x04\x61xis\x18\x01 \x01(\x05:\x01\x32\x12\x0e\n\x06offset\x18\x02 \x03(\r\"\xa4\x02\n\rDataParameter\x12\x0e\n\x06source\x18\x01 \x01(\t\x12\x12\n\nbatch_size\x18\x04 \x01(\r\x12\x14\n\trand_skip\x18\x07 \x01(\r:\x01\x30\x12\x31\n\x07\x62\x61\x63kend\x18\x08 \x01(\x0e\x32\x17.caffe.DataParameter.DB:\x07LEVELDB\x12\x10\n\x05scale\x18\x02 \x01(\x02:\x01\x31\x12\x11\n\tmean_file\x18\x03 \x01(\t\x12\x14\n\tcrop_size\x18\x05 \x01(\r:\x01\x30\x12\x15\n\x06mirror\x18\x06 \x01(\x08:\x05\x66\x61lse\x12\"\n\x13\x66orce_encoded_color\x18\t \x01(\x08:\x05\x66\x61lse\x12\x13\n\x08prefetch\x18\n \x01(\r:\x01\x35\"\x1b\n\x02\x44\x42\x12\x0b\n\x07LEVELDB\x10\x00\x12\x08\n\x04LMDB\x10\x01\"I\n\x10\x44ropoutParameter\x12\x1a\n\rdropout_ratio\x18\x01 \x01(\x02:\x03\x30.5\x12\x19\n\x0bscale_train\x18\x02 \x01(\x08:\x04true\"\xa0\x01\n\x12\x44ummyDataParameter\x12+\n\x0b\x64\x61ta_filler\x18\x01 \x03(\x0b\x32\x16.caffe.FillerParameter\x12\x1f\n\x05shape\x18\x06 \x03(\x0b\x32\x10.caffe.BlobShape\x12\x0b\n\x03num\x18\x02 \x03(\r\x12\x10\n\x08\x63hannels\x18\x03 \x03(\r\x12\x0e\n\x06height\x18\x04 \x03(\r\x12\r\n\x05width\x18\x05 \x03(\r\"\xa5\x01\n\x10\x45ltwiseParameter\x12\x39\n\toperation\x18\x01 \x01(\x0e\x32!.caffe.EltwiseParameter.EltwiseOp:\x03SUM\x12\r\n\x05\x63oeff\x18\x02 \x03(\x02\x12\x1e\n\x10stable_prod_grad\x18\x03 \x01(\x08:\x04true\"\'\n\tEltwiseOp\x12\x08\n\x04PROD\x10\x00\x12\x07\n\x03SUM\x10\x01\x12\x07\n\x03MAX\x10\x02\" \n\x0c\x45LUParameter\x12\x10\n\x05\x61lpha\x18\x01 \x01(\x02:\x01\x31\"\xac\x01\n\x0e\x45mbedParameter\x12\x12\n\nnum_output\x18\x01 \x01(\r\x12\x11\n\tinput_dim\x18\x02 \x01(\r\x12\x17\n\tbias_term\x18\x03 \x01(\x08:\x04true\x12-\n\rweight_filler\x18\x04 \x01(\x0b\x32\x16.caffe.FillerParameter\x12+\n\x0b\x62ias_filler\x18\x05 \x01(\x0b\x32\x16.caffe.FillerParameter\"D\n\x0c\x45xpParameter\x12\x10\n\x04\x62\x61se\x18\x01 \x01(\x02:\x02-1\x12\x10\n\x05scale\x18\x02 \x01(\x02:\x01\x31\x12\x10\n\x05shift\x18\x03 \x01(\x02:\x01\x30\"9\n\x10\x46lattenParameter\x12\x0f\n\x04\x61xis\x18\x01 \x01(\x05:\x01\x31\x12\x14\n\x08\x65nd_axis\x18\x02 \x01(\x05:\x02-1\"O\n\x11HDF5DataParameter\x12\x0e\n\x06source\x18\x01 \x01(\t\x12\x12\n\nbatch_size\x18\x02 \x01(\r\x12\x16\n\x07shuffle\x18\x03 \x01(\x08:\x05\x66\x61lse\"(\n\x13HDF5OutputParameter\x12\x11\n\tfile_name\x18\x01 \x01(\t\"^\n\x12HingeLossParameter\x12\x30\n\x04norm\x18\x01 \x01(\x0e\x32\x1e.caffe.HingeLossParameter.Norm:\x02L1\"\x16\n\x04Norm\x12\x06\n\x02L1\x10\x01\x12\x06\n\x02L2\x10\x02\"\x97\x02\n\x12ImageDataParameter\x12\x0e\n\x06source\x18\x01 \x01(\t\x12\x15\n\nbatch_size\x18\x04 \x01(\r:\x01\x31\x12\x14\n\trand_skip\x18\x07 \x01(\r:\x01\x30\x12\x16\n\x07shuffle\x18\x08 \x01(\x08:\x05\x66\x61lse\x12\x15\n\nnew_height\x18\t \x01(\r:\x01\x30\x12\x14\n\tnew_width\x18\n \x01(\r:\x01\x30\x12\x16\n\x08is_color\x18\x0b \x01(\x08:\x04true\x12\x10\n\x05scale\x18\x02 \x01(\x02:\x01\x31\x12\x11\n\tmean_file\x18\x03 \x01(\t\x12\x14\n\tcrop_size\x18\x05 \x01(\r:\x01\x30\x12\x15\n\x06mirror\x18\x06 \x01(\x08:\x05\x66\x61lse\x12\x15\n\x0broot_folder\x18\x0c \x01(\t:\x00\"\'\n\x15InfogainLossParameter\x12\x0e\n\x06source\x18\x01 
\x01(\t\"\xcb\x01\n\x15InnerProductParameter\x12\x12\n\nnum_output\x18\x01 \x01(\r\x12\x17\n\tbias_term\x18\x02 \x01(\x08:\x04true\x12-\n\rweight_filler\x18\x03 \x01(\x0b\x32\x16.caffe.FillerParameter\x12+\n\x0b\x62ias_filler\x18\x04 \x01(\x0b\x32\x16.caffe.FillerParameter\x12\x0f\n\x04\x61xis\x18\x05 \x01(\x05:\x01\x31\x12\x18\n\ttranspose\x18\x06 \x01(\x08:\x05\x66\x61lse\"1\n\x0eInputParameter\x12\x1f\n\x05shape\x18\x01 \x03(\x0b\x32\x10.caffe.BlobShape\"D\n\x0cLogParameter\x12\x10\n\x04\x62\x61se\x18\x01 \x01(\x02:\x02-1\x12\x10\n\x05scale\x18\x02 \x01(\x02:\x01\x31\x12\x10\n\x05shift\x18\x03 \x01(\x02:\x01\x30\"\xb8\x02\n\x0cLRNParameter\x12\x15\n\nlocal_size\x18\x01 \x01(\r:\x01\x35\x12\x10\n\x05\x61lpha\x18\x02 \x01(\x02:\x01\x31\x12\x12\n\x04\x62\x65ta\x18\x03 \x01(\x02:\x04\x30.75\x12\x44\n\x0bnorm_region\x18\x04 \x01(\x0e\x32\x1e.caffe.LRNParameter.NormRegion:\x0f\x41\x43ROSS_CHANNELS\x12\x0c\n\x01k\x18\x05 \x01(\x02:\x01\x31\x12\x33\n\x06\x65ngine\x18\x06 \x01(\x0e\x32\x1a.caffe.LRNParameter.Engine:\x07\x44\x45\x46\x41ULT\"5\n\nNormRegion\x12\x13\n\x0f\x41\x43ROSS_CHANNELS\x10\x00\x12\x12\n\x0eWITHIN_CHANNEL\x10\x01\"+\n\x06\x45ngine\x12\x0b\n\x07\x44\x45\x46\x41ULT\x10\x00\x12\t\n\x05\x43\x41\x46\x46\x45\x10\x01\x12\t\n\x05\x43UDNN\x10\x02\"\xbd\x01\n\x13MemoryDataParameter\x12\x12\n\nbatch_size\x18\x01 \x01(\r\x12\x10\n\x08\x63hannels\x18\x02 \x01(\r\x12\x0e\n\x06height\x18\x03 \x01(\r\x12\r\n\x05width\x18\x04 \x01(\r\x12;\n\x05\x64type\x18\x05 \x01(\x0e\x32#.caffe.MemoryDataParameter.DataType:\x07\x46LOAT32\"$\n\x08\x44\x61taType\x12\x0b\n\x07\x46LOAT32\x10\x00\x12\x0b\n\x07\x46LOAT16\x10\x01\"e\n\x0cMVNParameter\x12 \n\x12normalize_variance\x18\x01 \x01(\x08:\x04true\x12\x1e\n\x0f\x61\x63ross_channels\x18\x02 \x01(\x08:\x05\x66\x61lse\x12\x13\n\x03\x65ps\x18\x03 \x01(\x02:\x06\x31\x65-009\"5\n\x12ParameterParameter\x12\x1f\n\x05shape\x18\x01 \x01(\x0b\x32\x10.caffe.BlobShape\"\xa2\x03\n\x10PoolingParameter\x12\x35\n\x04pool\x18\x01 \x01(\x0e\x32\".caffe.PoolingParameter.PoolMethod:\x03MAX\x12\x0e\n\x03pad\x18\x04 \x01(\r:\x01\x30\x12\x10\n\x05pad_h\x18\t \x01(\r:\x01\x30\x12\x10\n\x05pad_w\x18\n \x01(\r:\x01\x30\x12\x13\n\x0bkernel_size\x18\x02 \x01(\r\x12\x10\n\x08kernel_h\x18\x05 \x01(\r\x12\x10\n\x08kernel_w\x18\x06 \x01(\r\x12\x11\n\x06stride\x18\x03 \x01(\r:\x01\x31\x12\x10\n\x08stride_h\x18\x07 \x01(\r\x12\x10\n\x08stride_w\x18\x08 \x01(\r\x12\x37\n\x06\x65ngine\x18\x0b \x01(\x0e\x32\x1e.caffe.PoolingParameter.Engine:\x07\x44\x45\x46\x41ULT\x12\x1d\n\x0eglobal_pooling\x18\x0c \x01(\x08:\x05\x66\x61lse\".\n\nPoolMethod\x12\x07\n\x03MAX\x10\x00\x12\x07\n\x03\x41VE\x10\x01\x12\x0e\n\nSTOCHASTIC\x10\x02\"+\n\x06\x45ngine\x12\x0b\n\x07\x44\x45\x46\x41ULT\x10\x00\x12\t\n\x05\x43\x41\x46\x46\x45\x10\x01\x12\t\n\x05\x43UDNN\x10\x02\"Y\n\x13ROIPoolingParameter\x12\x13\n\x08pooled_h\x18\x01 \x01(\r:\x01\x30\x12\x13\n\x08pooled_w\x18\x02 \x01(\r:\x01\x30\x12\x18\n\rspatial_scale\x18\x03 \x01(\x02:\x01\x31\"F\n\x0ePowerParameter\x12\x10\n\x05power\x18\x01 \x01(\x02:\x01\x31\x12\x10\n\x05scale\x18\x02 \x01(\x02:\x01\x31\x12\x10\n\x05shift\x18\x03 \x01(\x02:\x01\x30\"g\n\x0fPythonParameter\x12\x0e\n\x06module\x18\x01 \x01(\t\x12\r\n\x05layer\x18\x02 \x01(\t\x12\x13\n\tparam_str\x18\x03 \x01(\t:\x00\x12 \n\x11share_in_parallel\x18\x04 \x01(\x08:\x05\x66\x61lse\"\xad\x01\n\x12ReductionParameter\x12=\n\toperation\x18\x01 \x01(\x0e\x32%.caffe.ReductionParameter.ReductionOp:\x03SUM\x12\x0f\n\x04\x61xis\x18\x02 \x01(\x05:\x01\x30\x12\x10\n\x05\x63oeff\x18\x03 
\x01(\x02:\x01\x31\"5\n\x0bReductionOp\x12\x07\n\x03SUM\x10\x01\x12\x08\n\x04\x41SUM\x10\x02\x12\t\n\x05SUMSQ\x10\x03\x12\x08\n\x04MEAN\x10\x04\"\x8d\x01\n\rReLUParameter\x12\x19\n\x0enegative_slope\x18\x01 \x01(\x02:\x01\x30\x12\x34\n\x06\x65ngine\x18\x02 \x01(\x0e\x32\x1b.caffe.ReLUParameter.Engine:\x07\x44\x45\x46\x41ULT\"+\n\x06\x45ngine\x12\x0b\n\x07\x44\x45\x46\x41ULT\x10\x00\x12\t\n\x05\x43\x41\x46\x46\x45\x10\x01\x12\t\n\x05\x43UDNN\x10\x02\"Z\n\x10ReshapeParameter\x12\x1f\n\x05shape\x18\x01 \x01(\x0b\x32\x10.caffe.BlobShape\x12\x0f\n\x04\x61xis\x18\x02 \x01(\x05:\x01\x30\x12\x14\n\x08num_axes\x18\x03 \x01(\x05:\x02-1\"\xa5\x01\n\x0eScaleParameter\x12\x0f\n\x04\x61xis\x18\x01 \x01(\x05:\x01\x31\x12\x13\n\x08num_axes\x18\x02 \x01(\x05:\x01\x31\x12&\n\x06\x66iller\x18\x03 \x01(\x0b\x32\x16.caffe.FillerParameter\x12\x18\n\tbias_term\x18\x04 \x01(\x08:\x05\x66\x61lse\x12+\n\x0b\x62ias_filler\x18\x05 \x01(\x0b\x32\x16.caffe.FillerParameter\"x\n\x10SigmoidParameter\x12\x37\n\x06\x65ngine\x18\x01 \x01(\x0e\x32\x1e.caffe.SigmoidParameter.Engine:\x07\x44\x45\x46\x41ULT\"+\n\x06\x45ngine\x12\x0b\n\x07\x44\x45\x46\x41ULT\x10\x00\x12\t\n\x05\x43\x41\x46\x46\x45\x10\x01\x12\t\n\x05\x43UDNN\x10\x02\"L\n\x0eSliceParameter\x12\x0f\n\x04\x61xis\x18\x03 \x01(\x05:\x01\x31\x12\x13\n\x0bslice_point\x18\x02 \x03(\r\x12\x14\n\tslice_dim\x18\x01 \x01(\r:\x01\x31\"\x89\x01\n\x10SoftmaxParameter\x12\x37\n\x06\x65ngine\x18\x01 \x01(\x0e\x32\x1e.caffe.SoftmaxParameter.Engine:\x07\x44\x45\x46\x41ULT\x12\x0f\n\x04\x61xis\x18\x02 \x01(\x05:\x01\x31\"+\n\x06\x45ngine\x12\x0b\n\x07\x44\x45\x46\x41ULT\x10\x00\x12\t\n\x05\x43\x41\x46\x46\x45\x10\x01\x12\t\n\x05\x43UDNN\x10\x02\"r\n\rTanHParameter\x12\x34\n\x06\x65ngine\x18\x01 \x01(\x0e\x32\x1b.caffe.TanHParameter.Engine:\x07\x44\x45\x46\x41ULT\"+\n\x06\x45ngine\x12\x0b\n\x07\x44\x45\x46\x41ULT\x10\x00\x12\t\n\x05\x43\x41\x46\x46\x45\x10\x01\x12\t\n\x05\x43UDNN\x10\x02\"T\n\rTileParameter\x12\x0f\n\x04\x61xis\x18\x01 \x01(\x05:\x01\x31\x12\r\n\x05tiles\x18\x02 \x01(\x05\x12#\n\tmultiples\x18\x03 \x01(\x0b\x32\x10.caffe.BlobShape\"*\n\x12ThresholdParameter\x12\x14\n\tthreshold\x18\x01 \x01(\x02:\x01\x30\"\xc1\x02\n\x13WindowDataParameter\x12\x0e\n\x06source\x18\x01 \x01(\t\x12\x10\n\x05scale\x18\x02 \x01(\x02:\x01\x31\x12\x11\n\tmean_file\x18\x03 \x01(\t\x12\x12\n\nbatch_size\x18\x04 \x01(\r\x12\x14\n\tcrop_size\x18\x05 \x01(\r:\x01\x30\x12\x15\n\x06mirror\x18\x06 \x01(\x08:\x05\x66\x61lse\x12\x19\n\x0c\x66g_threshold\x18\x07 \x01(\x02:\x03\x30.5\x12\x19\n\x0c\x62g_threshold\x18\x08 \x01(\x02:\x03\x30.5\x12\x19\n\x0b\x66g_fraction\x18\t \x01(\x02:\x04\x30.25\x12\x16\n\x0b\x63ontext_pad\x18\n \x01(\r:\x01\x30\x12\x17\n\tcrop_mode\x18\x0b \x01(\t:\x04warp\x12\x1b\n\x0c\x63\x61\x63he_images\x18\x0c \x01(\x08:\x05\x66\x61lse\x12\x15\n\x0broot_folder\x18\r \x01(\t:\x00\"\xeb\x01\n\x0cSPPParameter\x12\x16\n\x0epyramid_height\x18\x01 \x01(\r\x12\x31\n\x04pool\x18\x02 \x01(\x0e\x32\x1e.caffe.SPPParameter.PoolMethod:\x03MAX\x12\x33\n\x06\x65ngine\x18\x06 \x01(\x0e\x32\x1a.caffe.SPPParameter.Engine:\x07\x44\x45\x46\x41ULT\".\n\nPoolMethod\x12\x07\n\x03MAX\x10\x00\x12\x07\n\x03\x41VE\x10\x01\x12\x0e\n\nSTOCHASTIC\x10\x02\"+\n\x06\x45ngine\x12\x0b\n\x07\x44\x45\x46\x41ULT\x10\x00\x12\t\n\x05\x43\x41\x46\x46\x45\x10\x01\x12\t\n\x05\x43UDNN\x10\x02\"\xe0\x13\n\x10V1LayerParameter\x12\x0e\n\x06\x62ottom\x18\x02 \x03(\t\x12\x0b\n\x03top\x18\x03 \x03(\t\x12\x0c\n\x04name\x18\x04 \x01(\t\x12$\n\x07include\x18 \x03(\x0b\x32\x13.caffe.NetStateRule\x12$\n\x07\x65xclude\x18! 
\x03(\x0b\x32\x13.caffe.NetStateRule\x12/\n\x04type\x18\x05 \x01(\x0e\x32!.caffe.V1LayerParameter.LayerType\x12\x1f\n\x05\x62lobs\x18\x06 \x03(\x0b\x32\x10.caffe.BlobProto\x12\x0e\n\x05param\x18\xe9\x07 \x03(\t\x12>\n\x0f\x62lob_share_mode\x18\xea\x07 \x03(\x0e\x32$.caffe.V1LayerParameter.DimCheckMode\x12\x10\n\x08\x62lobs_lr\x18\x07 \x03(\x02\x12\x14\n\x0cweight_decay\x18\x08 \x03(\x02\x12\x13\n\x0bloss_weight\x18# \x03(\x02\x12\x30\n\x0e\x61\x63\x63uracy_param\x18\x1b \x01(\x0b\x32\x18.caffe.AccuracyParameter\x12,\n\x0c\x61rgmax_param\x18\x17 \x01(\x0b\x32\x16.caffe.ArgMaxParameter\x12,\n\x0c\x63oncat_param\x18\t \x01(\x0b\x32\x16.caffe.ConcatParameter\x12?\n\x16\x63ontrastive_loss_param\x18( \x01(\x0b\x32\x1f.caffe.ContrastiveLossParameter\x12\x36\n\x11\x63onvolution_param\x18\n \x01(\x0b\x32\x1b.caffe.ConvolutionParameter\x12(\n\ndata_param\x18\x0b \x01(\x0b\x32\x14.caffe.DataParameter\x12.\n\rdropout_param\x18\x0c \x01(\x0b\x32\x17.caffe.DropoutParameter\x12\x33\n\x10\x64ummy_data_param\x18\x1a \x01(\x0b\x32\x19.caffe.DummyDataParameter\x12.\n\reltwise_param\x18\x18 \x01(\x0b\x32\x17.caffe.EltwiseParameter\x12&\n\texp_param\x18) \x01(\x0b\x32\x13.caffe.ExpParameter\x12\x31\n\x0fhdf5_data_param\x18\r \x01(\x0b\x32\x18.caffe.HDF5DataParameter\x12\x35\n\x11hdf5_output_param\x18\x0e \x01(\x0b\x32\x1a.caffe.HDF5OutputParameter\x12\x33\n\x10hinge_loss_param\x18\x1d \x01(\x0b\x32\x19.caffe.HingeLossParameter\x12\x33\n\x10image_data_param\x18\x0f \x01(\x0b\x32\x19.caffe.ImageDataParameter\x12\x39\n\x13infogain_loss_param\x18\x10 \x01(\x0b\x32\x1c.caffe.InfogainLossParameter\x12\x39\n\x13inner_product_param\x18\x11 \x01(\x0b\x32\x1c.caffe.InnerProductParameter\x12&\n\tlrn_param\x18\x12 \x01(\x0b\x32\x13.caffe.LRNParameter\x12\x35\n\x11memory_data_param\x18\x16 \x01(\x0b\x32\x1a.caffe.MemoryDataParameter\x12&\n\tmvn_param\x18\" \x01(\x0b\x32\x13.caffe.MVNParameter\x12.\n\rpooling_param\x18\x13 \x01(\x0b\x32\x17.caffe.PoolingParameter\x12*\n\x0bpower_param\x18\x15 \x01(\x0b\x32\x15.caffe.PowerParameter\x12(\n\nrelu_param\x18\x1e \x01(\x0b\x32\x14.caffe.ReLUParameter\x12.\n\rsigmoid_param\x18& \x01(\x0b\x32\x17.caffe.SigmoidParameter\x12.\n\rsoftmax_param\x18\' \x01(\x0b\x32\x17.caffe.SoftmaxParameter\x12*\n\x0bslice_param\x18\x1f \x01(\x0b\x32\x15.caffe.SliceParameter\x12(\n\ntanh_param\x18% \x01(\x0b\x32\x14.caffe.TanHParameter\x12\x32\n\x0fthreshold_param\x18\x19 \x01(\x0b\x32\x19.caffe.ThresholdParameter\x12\x35\n\x11window_data_param\x18\x14 \x01(\x0b\x32\x1a.caffe.WindowDataParameter\x12\x37\n\x0ftransform_param\x18$ \x01(\x0b\x32\x1e.caffe.TransformationParameter\x12(\n\nloss_param\x18* \x01(\x0b\x32\x14.caffe.LossParameter\x12&\n\x05layer\x18\x01 \x01(\x0b\x32\x17.caffe.V0LayerParameter\"\xd8\x04\n\tLayerType\x12\x08\n\x04NONE\x10\x00\x12\n\n\x06\x41\x42SVAL\x10#\x12\x0c\n\x08\x41\x43\x43URACY\x10\x01\x12\n\n\x06\x41RGMAX\x10\x1e\x12\x08\n\x04\x42NLL\x10\x02\x12\n\n\x06\x43ONCAT\x10\x03\x12\x14\n\x10\x43ONTRASTIVE_LOSS\x10%\x12\x0f\n\x0b\x43ONVOLUTION\x10\x04\x12\x08\n\x04\x44\x41TA\x10\x05\x12\x11\n\rDECONVOLUTION\x10\'\x12\x0b\n\x07\x44ROPOUT\x10\x06\x12\x0e\n\nDUMMY_DATA\x10 
\x12\x12\n\x0e\x45UCLIDEAN_LOSS\x10\x07\x12\x0b\n\x07\x45LTWISE\x10\x19\x12\x07\n\x03\x45XP\x10&\x12\x0b\n\x07\x46LATTEN\x10\x08\x12\r\n\tHDF5_DATA\x10\t\x12\x0f\n\x0bHDF5_OUTPUT\x10\n\x12\x0e\n\nHINGE_LOSS\x10\x1c\x12\n\n\x06IM2COL\x10\x0b\x12\x0e\n\nIMAGE_DATA\x10\x0c\x12\x11\n\rINFOGAIN_LOSS\x10\r\x12\x11\n\rINNER_PRODUCT\x10\x0e\x12\x07\n\x03LRN\x10\x0f\x12\x0f\n\x0bMEMORY_DATA\x10\x1d\x12\x1d\n\x19MULTINOMIAL_LOGISTIC_LOSS\x10\x10\x12\x07\n\x03MVN\x10\"\x12\x0b\n\x07POOLING\x10\x11\x12\t\n\x05POWER\x10\x1a\x12\x08\n\x04RELU\x10\x12\x12\x0b\n\x07SIGMOID\x10\x13\x12\x1e\n\x1aSIGMOID_CROSS_ENTROPY_LOSS\x10\x1b\x12\x0b\n\x07SILENCE\x10$\x12\x0b\n\x07SOFTMAX\x10\x14\x12\x10\n\x0cSOFTMAX_LOSS\x10\x15\x12\t\n\x05SPLIT\x10\x16\x12\t\n\x05SLICE\x10!\x12\x08\n\x04TANH\x10\x17\x12\x0f\n\x0bWINDOW_DATA\x10\x18\x12\r\n\tTHRESHOLD\x10\x1f\"*\n\x0c\x44imCheckMode\x12\n\n\x06STRICT\x10\x00\x12\x0e\n\nPERMISSIVE\x10\x01\"\xfd\x07\n\x10V0LayerParameter\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0c\n\x04type\x18\x02 \x01(\t\x12\x12\n\nnum_output\x18\x03 \x01(\r\x12\x16\n\x08\x62iasterm\x18\x04 \x01(\x08:\x04true\x12-\n\rweight_filler\x18\x05 \x01(\x0b\x32\x16.caffe.FillerParameter\x12+\n\x0b\x62ias_filler\x18\x06 \x01(\x0b\x32\x16.caffe.FillerParameter\x12\x0e\n\x03pad\x18\x07 \x01(\r:\x01\x30\x12\x12\n\nkernelsize\x18\x08 \x01(\r\x12\x10\n\x05group\x18\t \x01(\r:\x01\x31\x12\x11\n\x06stride\x18\n \x01(\r:\x01\x31\x12\x35\n\x04pool\x18\x0b \x01(\x0e\x32\".caffe.V0LayerParameter.PoolMethod:\x03MAX\x12\x1a\n\rdropout_ratio\x18\x0c \x01(\x02:\x03\x30.5\x12\x15\n\nlocal_size\x18\r \x01(\r:\x01\x35\x12\x10\n\x05\x61lpha\x18\x0e \x01(\x02:\x01\x31\x12\x12\n\x04\x62\x65ta\x18\x0f \x01(\x02:\x04\x30.75\x12\x0c\n\x01k\x18\x16 \x01(\x02:\x01\x31\x12\x0e\n\x06source\x18\x10 \x01(\t\x12\x10\n\x05scale\x18\x11 \x01(\x02:\x01\x31\x12\x10\n\x08meanfile\x18\x12 \x01(\t\x12\x11\n\tbatchsize\x18\x13 \x01(\r\x12\x13\n\x08\x63ropsize\x18\x14 \x01(\r:\x01\x30\x12\x15\n\x06mirror\x18\x15 \x01(\x08:\x05\x66\x61lse\x12\x1f\n\x05\x62lobs\x18\x32 \x03(\x0b\x32\x10.caffe.BlobProto\x12\x10\n\x08\x62lobs_lr\x18\x33 \x03(\x02\x12\x14\n\x0cweight_decay\x18\x34 \x03(\x02\x12\x14\n\trand_skip\x18\x35 \x01(\r:\x01\x30\x12\x1d\n\x10\x64\x65t_fg_threshold\x18\x36 \x01(\x02:\x03\x30.5\x12\x1d\n\x10\x64\x65t_bg_threshold\x18\x37 \x01(\x02:\x03\x30.5\x12\x1d\n\x0f\x64\x65t_fg_fraction\x18\x38 \x01(\x02:\x04\x30.25\x12\x1a\n\x0f\x64\x65t_context_pad\x18: \x01(\r:\x01\x30\x12\x1b\n\rdet_crop_mode\x18; \x01(\t:\x04warp\x12\x12\n\x07new_num\x18< \x01(\x05:\x01\x30\x12\x17\n\x0cnew_channels\x18= \x01(\x05:\x01\x30\x12\x15\n\nnew_height\x18> \x01(\x05:\x01\x30\x12\x14\n\tnew_width\x18? 
\x01(\x05:\x01\x30\x12\x1d\n\x0eshuffle_images\x18@ \x01(\x08:\x05\x66\x61lse\x12\x15\n\nconcat_dim\x18\x41 \x01(\r:\x01\x31\x12\x36\n\x11hdf5_output_param\x18\xe9\x07 \x01(\x0b\x32\x1a.caffe.HDF5OutputParameter\".\n\nPoolMethod\x12\x07\n\x03MAX\x10\x00\x12\x07\n\x03\x41VE\x10\x01\x12\x0e\n\nSTOCHASTIC\x10\x02\"W\n\x0ePReLUParameter\x12&\n\x06\x66iller\x18\x01 \x01(\x0b\x32\x16.caffe.FillerParameter\x12\x1d\n\x0e\x63hannel_shared\x18\x02 \x01(\x08:\x05\x66\x61lse\")\n\x15SmoothL1LossParameter\x12\x10\n\x05sigma\x18\x01 \x01(\x02:\x01\x31\"H\n\x0cMPIParameter\x12\x0f\n\x04root\x18\x01 \x01(\r:\x01\x30\x12\x12\n\x07\x63omm_id\x18\x02 \x01(\x04:\x01\x30\x12\x13\n\x08group_id\x18\x03 \x01(\x04:\x01\x30\"!\n\x10PermuteParameter\x12\r\n\x05order\x18\x01 \x03(\r\"\x93\x01\n\x12NormalizeParameter\x12\x1c\n\x0e\x61\x63ross_spatial\x18\x01 \x01(\x08:\x04true\x12,\n\x0cscale_filler\x18\x02 \x01(\x0b\x32\x16.caffe.FillerParameter\x12\x1c\n\x0e\x63hannel_shared\x18\x03 \x01(\x08:\x04true\x12\x13\n\x03\x65ps\x18\x04 \x01(\x02:\x06\x31\x65-010\"_\n\x11ParallelParameter\x12\x16\n\x07shuffle\x18\x01 \x01(\x08:\x05\x66\x61lse\x12\x18\n\tnode_step\x18\x02 \x01(\x08:\x05\x66\x61lse\x12\x18\n\tpartition\x18\x03 \x01(\x08:\x05\x66\x61lse\"R\n\x0fResizeParameter\x12\x1f\n\x05shape\x18\x01 \x01(\x0b\x32\x10.caffe.BlobShape\x12\x0e\n\x02\x66x\x18\x02 \x01(\x02:\x02-1\x12\x0e\n\x02\x66y\x18\x03 \x01(\x02:\x02-1\"\'\n\x13\x45xpandDimsParameter\x12\x10\n\x04\x61xis\x18\x01 \x01(\x05:\x02-1\"\x90\x02\n\x11ProposalParameter\x12\x0e\n\x06stride\x18\x01 \x03(\x05\x12\r\n\x05ratio\x18\x02 \x03(\x02\x12\r\n\x05scale\x18\x03 \x03(\x02\x12\x1b\n\rpre_nms_top_n\x18\x04 \x01(\r:\x04\x36\x30\x30\x30\x12\x1b\n\x0epost_nms_top_n\x18\x05 \x01(\r:\x03\x33\x30\x30\x12\x17\n\nnms_thresh\x18\x06 \x01(\x02:\x03\x30.7\x12\x14\n\x08min_size\x18\x07 \x01(\r:\x02\x31\x36\x12\x14\n\tmin_level\x18\x08 \x01(\x05:\x01\x32\x12\x14\n\tmax_level\x18\t \x01(\x05:\x01\x35\x12\x1c\n\x0f\x63\x61nonical_scale\x18\n \x01(\x05:\x03\x32\x32\x34\x12\x1a\n\x0f\x63\x61nonical_level\x18\x0b \x01(\x05:\x01\x34\"\xa6\x01\n\x14\x42\x61tchRenormParameter\x12\x18\n\x10use_global_stats\x18\x01 \x01(\x08\x12$\n\x17moving_average_fraction\x18\x02 \x01(\x02:\x03\x30.9\x12\x12\n\x03\x65ps\x18\x03 \x01(\x02:\x05\x30.001\x12\x10\n\x05r_max\x18\x04 \x01(\x02:\x01\x33\x12\x10\n\x05\x64_max\x18\x05 \x01(\x02:\x01\x35\x12\x16\n\x07t_delta\x18\x06 \x01(\x02:\x05\x30.001\"?\n\x14\x44\x65nseConcatParameter\x12\x0f\n\x04\x61xis\x18\x01 \x01(\x05:\x01\x31\x12\x16\n\x0bgrowth_rate\x18\x02 \x01(\x05:\x01\x30\"c\n\x12\x46ocalLossParameter\x12\x12\n\x05\x61lpha\x18\x01 \x01(\x02:\x03\x30.5\x12\x10\n\x05gamma\x18\x02 \x01(\x02:\x01\x30\x12\x13\n\x03\x65ps\x18\x03 \x01(\x02:\x06\x31\x65-010\x12\x12\n\x06neg_id\x18\x04 \x01(\x05:\x02-1\"\"\n\x0fGatherParameter\x12\x0f\n\x04\x61xis\x18\x01 \x01(\x05:\x01\x30\"{\n\x12GroupNormParameter\x12\x18\n\x10use_global_stats\x18\x01 \x01(\x08\x12$\n\x17moving_average_fraction\x18\x02 \x01(\x02:\x03\x30.9\x12\x12\n\x03\x65ps\x18\x03 \x01(\x02:\x05\x30.001\x12\x11\n\x05group\x18\x05 \x01(\r:\x02\x33\x32*\x1c\n\x05Phase\x12\t\n\x05TRAIN\x10\x00\x12\x08\n\x04TEST\x10\x01')
) )
_sym_db.RegisterFileDescriptor(DESCRIPTOR) _sym_db.RegisterFileDescriptor(DESCRIPTOR)
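
The regenerated serialized_pb above now carries a GroupNormParameter message at the end of the file descriptor. Reading the field entries off the blob, it declares use_global_stats (bool, field 1), moving_average_fraction (float, field 2, default 0.9), eps (float, field 3, default 0.001), and group (uint32, field 5, default 32), matching the defaults used by the C++ GroupNormOp. A minimal sketch of the regenerated module in use, assuming it is importable as caffe_pb2:

    # A minimal sketch, assuming the regenerated module is importable as caffe_pb2.
    # Defaults are read off the serialized descriptor above:
    #   use_global_stats        : optional bool   (field 1)
    #   moving_average_fraction : optional float  (field 2, default 0.9)
    #   eps                     : optional float  (field 3, default 0.001)
    #   group                   : optional uint32 (field 5, default 32)
    import caffe_pb2

    gn = caffe_pb2.GroupNormParameter()
    print(gn.group)                              # 32
    print(round(gn.eps, 3))                      # 0.001
    print(round(gn.moving_average_fraction, 1))  # 0.9
    print(gn.HasField('use_global_stats'))       # False until explicitly set
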
...@@ -40,8 +40,8 @@ _PHASE = _descriptor.EnumDescriptor( ...@@ -40,8 +40,8 @@ _PHASE = _descriptor.EnumDescriptor(
], ],
containing_type=None, containing_type=None,
options=None, options=None,
serialized_start=17463, serialized_start=17642,
serialized_end=17491, serialized_end=17670,
) )
_sym_db.RegisterEnumDescriptor(_PHASE) _sym_db.RegisterEnumDescriptor(_PHASE)
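
The offset rewrites in this and the following hunks are mechanical: every descriptor that serializes after the insertion points moves down. Descriptors sitting after LayerParameter's field list but before the new message shift by 54 bytes (the encoded group_norm_param field entry), while descriptors after the trailing GroupNormParameter message, such as _PHASE above, shift by 179. A quick check of those deltas, with the 54/125 split decoded from the blob rather than stated anywhere in the diff:

    # Sanity check on the serialized-offset deltas in these hunks.
    # 54  = encoded size of the group_norm_param field entry in LayerParameter
    # 125 = encoded size of the GroupNormParameter message definition
    # (both lengths are read off the serialized_pb blob above, so treat them
    # as decoded assumptions rather than authoritative numbers)
    field_entry, message_def = 54, 125
    assert 6579 - 6525 == field_entry                  # _LOSSPARAMETER_NORMALIZATIONMODE
    assert 17642 - 17463 == field_entry + message_def  # _PHASE, at the end of the file
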
...@@ -209,8 +209,8 @@ _LOSSPARAMETER_NORMALIZATIONMODE = _descriptor.EnumDescriptor( ...@@ -209,8 +209,8 @@ _LOSSPARAMETER_NORMALIZATIONMODE = _descriptor.EnumDescriptor(
], ],
containing_type=None, containing_type=None,
options=None, options=None,
serialized_start=6525, serialized_start=6579,
serialized_end=6601, serialized_end=6655,
) )
_sym_db.RegisterEnumDescriptor(_LOSSPARAMETER_NORMALIZATIONMODE) _sym_db.RegisterEnumDescriptor(_LOSSPARAMETER_NORMALIZATIONMODE)
...@@ -235,8 +235,8 @@ _CONVOLUTIONPARAMETER_ENGINE = _descriptor.EnumDescriptor( ...@@ -235,8 +235,8 @@ _CONVOLUTIONPARAMETER_ENGINE = _descriptor.EnumDescriptor(
], ],
containing_type=None, containing_type=None,
options=None, options=None,
serialized_start=7564, serialized_start=7618,
serialized_end=7607, serialized_end=7661,
) )
_sym_db.RegisterEnumDescriptor(_CONVOLUTIONPARAMETER_ENGINE) _sym_db.RegisterEnumDescriptor(_CONVOLUTIONPARAMETER_ENGINE)
...@@ -257,8 +257,8 @@ _DATAPARAMETER_DB = _descriptor.EnumDescriptor( ...@@ -257,8 +257,8 @@ _DATAPARAMETER_DB = _descriptor.EnumDescriptor(
], ],
containing_type=None, containing_type=None,
options=None, options=None,
serialized_start=7925, serialized_start=7979,
serialized_end=7952, serialized_end=8006,
) )
_sym_db.RegisterEnumDescriptor(_DATAPARAMETER_DB) _sym_db.RegisterEnumDescriptor(_DATAPARAMETER_DB)
...@@ -283,8 +283,8 @@ _ELTWISEPARAMETER_ELTWISEOP = _descriptor.EnumDescriptor( ...@@ -283,8 +283,8 @@ _ELTWISEPARAMETER_ELTWISEOP = _descriptor.EnumDescriptor(
], ],
containing_type=None, containing_type=None,
options=None, options=None,
serialized_start=8319, serialized_start=8373,
serialized_end=8358, serialized_end=8412,
) )
_sym_db.RegisterEnumDescriptor(_ELTWISEPARAMETER_ELTWISEOP) _sym_db.RegisterEnumDescriptor(_ELTWISEPARAMETER_ELTWISEOP)
...@@ -305,8 +305,8 @@ _HINGELOSSPARAMETER_NORM = _descriptor.EnumDescriptor( ...@@ -305,8 +305,8 @@ _HINGELOSSPARAMETER_NORM = _descriptor.EnumDescriptor(
], ],
containing_type=None, containing_type=None,
options=None, options=None,
serialized_start=8893, serialized_start=8947,
serialized_end=8915, serialized_end=8969,
) )
_sym_db.RegisterEnumDescriptor(_HINGELOSSPARAMETER_NORM) _sym_db.RegisterEnumDescriptor(_HINGELOSSPARAMETER_NORM)
...@@ -327,8 +327,8 @@ _LRNPARAMETER_NORMREGION = _descriptor.EnumDescriptor( ...@@ -327,8 +327,8 @@ _LRNPARAMETER_NORMREGION = _descriptor.EnumDescriptor(
], ],
containing_type=None, containing_type=None,
options=None, options=None,
serialized_start=9782, serialized_start=9836,
serialized_end=9835, serialized_end=9889,
) )
_sym_db.RegisterEnumDescriptor(_LRNPARAMETER_NORMREGION) _sym_db.RegisterEnumDescriptor(_LRNPARAMETER_NORMREGION)
...@@ -353,8 +353,8 @@ _LRNPARAMETER_ENGINE = _descriptor.EnumDescriptor( ...@@ -353,8 +353,8 @@ _LRNPARAMETER_ENGINE = _descriptor.EnumDescriptor(
], ],
containing_type=None, containing_type=None,
options=None, options=None,
serialized_start=7564, serialized_start=7618,
serialized_end=7607, serialized_end=7661,
) )
_sym_db.RegisterEnumDescriptor(_LRNPARAMETER_ENGINE) _sym_db.RegisterEnumDescriptor(_LRNPARAMETER_ENGINE)
...@@ -375,8 +375,8 @@ _MEMORYDATAPARAMETER_DATATYPE = _descriptor.EnumDescriptor( ...@@ -375,8 +375,8 @@ _MEMORYDATAPARAMETER_DATATYPE = _descriptor.EnumDescriptor(
], ],
containing_type=None, containing_type=None,
options=None, options=None,
serialized_start=10036, serialized_start=10090,
serialized_end=10072, serialized_end=10126,
) )
_sym_db.RegisterEnumDescriptor(_MEMORYDATAPARAMETER_DATATYPE) _sym_db.RegisterEnumDescriptor(_MEMORYDATAPARAMETER_DATATYPE)
...@@ -401,8 +401,8 @@ _POOLINGPARAMETER_POOLMETHOD = _descriptor.EnumDescriptor( ...@@ -401,8 +401,8 @@ _POOLINGPARAMETER_POOLMETHOD = _descriptor.EnumDescriptor(
], ],
containing_type=None, containing_type=None,
options=None, options=None,
serialized_start=10560, serialized_start=10614,
serialized_end=10606, serialized_end=10660,
) )
_sym_db.RegisterEnumDescriptor(_POOLINGPARAMETER_POOLMETHOD) _sym_db.RegisterEnumDescriptor(_POOLINGPARAMETER_POOLMETHOD)
...@@ -427,8 +427,8 @@ _POOLINGPARAMETER_ENGINE = _descriptor.EnumDescriptor( ...@@ -427,8 +427,8 @@ _POOLINGPARAMETER_ENGINE = _descriptor.EnumDescriptor(
], ],
containing_type=None, containing_type=None,
options=None, options=None,
serialized_start=7564, serialized_start=7618,
serialized_end=7607, serialized_end=7661,
) )
_sym_db.RegisterEnumDescriptor(_POOLINGPARAMETER_ENGINE) _sym_db.RegisterEnumDescriptor(_POOLINGPARAMETER_ENGINE)
...@@ -457,8 +457,8 @@ _REDUCTIONPARAMETER_REDUCTIONOP = _descriptor.EnumDescriptor( ...@@ -457,8 +457,8 @@ _REDUCTIONPARAMETER_REDUCTIONOP = _descriptor.EnumDescriptor(
], ],
containing_type=None, containing_type=None,
options=None, options=None,
serialized_start=11042, serialized_start=11096,
serialized_end=11095, serialized_end=11149,
) )
_sym_db.RegisterEnumDescriptor(_REDUCTIONPARAMETER_REDUCTIONOP) _sym_db.RegisterEnumDescriptor(_REDUCTIONPARAMETER_REDUCTIONOP)
...@@ -483,8 +483,8 @@ _RELUPARAMETER_ENGINE = _descriptor.EnumDescriptor( ...@@ -483,8 +483,8 @@ _RELUPARAMETER_ENGINE = _descriptor.EnumDescriptor(
], ],
containing_type=None, containing_type=None,
options=None, options=None,
serialized_start=7564, serialized_start=7618,
serialized_end=7607, serialized_end=7661,
) )
_sym_db.RegisterEnumDescriptor(_RELUPARAMETER_ENGINE) _sym_db.RegisterEnumDescriptor(_RELUPARAMETER_ENGINE)
...@@ -509,8 +509,8 @@ _SIGMOIDPARAMETER_ENGINE = _descriptor.EnumDescriptor( ...@@ -509,8 +509,8 @@ _SIGMOIDPARAMETER_ENGINE = _descriptor.EnumDescriptor(
], ],
containing_type=None, containing_type=None,
options=None, options=None,
serialized_start=7564, serialized_start=7618,
serialized_end=7607, serialized_end=7661,
) )
_sym_db.RegisterEnumDescriptor(_SIGMOIDPARAMETER_ENGINE) _sym_db.RegisterEnumDescriptor(_SIGMOIDPARAMETER_ENGINE)
...@@ -535,8 +535,8 @@ _SOFTMAXPARAMETER_ENGINE = _descriptor.EnumDescriptor( ...@@ -535,8 +535,8 @@ _SOFTMAXPARAMETER_ENGINE = _descriptor.EnumDescriptor(
], ],
containing_type=None, containing_type=None,
options=None, options=None,
serialized_start=7564, serialized_start=7618,
serialized_end=7607, serialized_end=7661,
) )
_sym_db.RegisterEnumDescriptor(_SOFTMAXPARAMETER_ENGINE) _sym_db.RegisterEnumDescriptor(_SOFTMAXPARAMETER_ENGINE)
...@@ -561,8 +561,8 @@ _TANHPARAMETER_ENGINE = _descriptor.EnumDescriptor( ...@@ -561,8 +561,8 @@ _TANHPARAMETER_ENGINE = _descriptor.EnumDescriptor(
], ],
containing_type=None, containing_type=None,
options=None, options=None,
serialized_start=7564, serialized_start=7618,
serialized_end=7607, serialized_end=7661,
) )
_sym_db.RegisterEnumDescriptor(_TANHPARAMETER_ENGINE) _sym_db.RegisterEnumDescriptor(_TANHPARAMETER_ENGINE)
...@@ -587,8 +587,8 @@ _SPPPARAMETER_POOLMETHOD = _descriptor.EnumDescriptor( ...@@ -587,8 +587,8 @@ _SPPPARAMETER_POOLMETHOD = _descriptor.EnumDescriptor(
], ],
containing_type=None, containing_type=None,
options=None, options=None,
serialized_start=10560, serialized_start=10614,
serialized_end=10606, serialized_end=10660,
) )
_sym_db.RegisterEnumDescriptor(_SPPPARAMETER_POOLMETHOD) _sym_db.RegisterEnumDescriptor(_SPPPARAMETER_POOLMETHOD)
...@@ -613,8 +613,8 @@ _SPPPARAMETER_ENGINE = _descriptor.EnumDescriptor( ...@@ -613,8 +613,8 @@ _SPPPARAMETER_ENGINE = _descriptor.EnumDescriptor(
], ],
containing_type=None, containing_type=None,
options=None, options=None,
serialized_start=7564, serialized_start=7618,
serialized_end=7607, serialized_end=7661,
) )
_sym_db.RegisterEnumDescriptor(_SPPPARAMETER_ENGINE) _sym_db.RegisterEnumDescriptor(_SPPPARAMETER_ENGINE)
...@@ -787,8 +787,8 @@ _V1LAYERPARAMETER_LAYERTYPE = _descriptor.EnumDescriptor( ...@@ -787,8 +787,8 @@ _V1LAYERPARAMETER_LAYERTYPE = _descriptor.EnumDescriptor(
], ],
containing_type=None, containing_type=None,
options=None, options=None,
serialized_start=14534, serialized_start=14588,
serialized_end=15134, serialized_end=15188,
) )
_sym_db.RegisterEnumDescriptor(_V1LAYERPARAMETER_LAYERTYPE) _sym_db.RegisterEnumDescriptor(_V1LAYERPARAMETER_LAYERTYPE)
...@@ -835,8 +835,8 @@ _V0LAYERPARAMETER_POOLMETHOD = _descriptor.EnumDescriptor( ...@@ -835,8 +835,8 @@ _V0LAYERPARAMETER_POOLMETHOD = _descriptor.EnumDescriptor(
], ],
containing_type=None, containing_type=None,
options=None, options=None,
serialized_start=10560, serialized_start=10614,
serialized_end=10606, serialized_end=10660,
) )
_sym_db.RegisterEnumDescriptor(_V0LAYERPARAMETER_POOLMETHOD) _sym_db.RegisterEnumDescriptor(_V0LAYERPARAMETER_POOLMETHOD)
...@@ -2261,6 +2261,13 @@ _LAYERPARAMETER = _descriptor.Descriptor( ...@@ -2261,6 +2261,13 @@ _LAYERPARAMETER = _descriptor.Descriptor(
message_type=None, enum_type=None, containing_type=None, message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None, is_extension=False, extension_scope=None,
options=None), options=None),
_descriptor.FieldDescriptor(
name='group_norm_param', full_name='caffe.LayerParameter.group_norm_param', index=71,
number=166, type=11, cpp_type=10, label=1,
has_default_value=False, default_value=None,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
], ],
extensions=[ extensions=[
], ],
...@@ -2273,7 +2280,7 @@ _LAYERPARAMETER = _descriptor.Descriptor( ...@@ -2273,7 +2280,7 @@ _LAYERPARAMETER = _descriptor.Descriptor(
oneofs=[ oneofs=[
], ],
serialized_start=2834, serialized_start=2834,
serialized_end=6055, serialized_end=6109,
) )
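
With group_norm_param registered above as field number 166 at index 71 of LayerParameter, the parameter is reachable from Python like any other layer sub-message. A hedged usage sketch: only the field and message names come from the descriptors in this diff, and the 'GroupNorm' type string is assumed for illustration:

    # A usage sketch against the regenerated caffe_pb2 module. The layer type
    # string 'GroupNorm' is an assumption; group_norm_param and
    # GroupNormParameter come from the descriptors in this diff.
    from google.protobuf import text_format
    import caffe_pb2

    layer = caffe_pb2.LayerParameter(name='gn1', type='GroupNorm')
    layer.group_norm_param.group = 32
    layer.group_norm_param.eps = 1e-3
    print(text_format.MessageToString(layer))
    # name: "gn1"
    # type: "GroupNorm"
    # group_norm_param {
    #   eps: 0.001
    #   group: 32
    # }
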
...@@ -2372,8 +2379,8 @@ _TRANSFORMATIONPARAMETER = _descriptor.Descriptor( ...@@ -2372,8 +2379,8 @@ _TRANSFORMATIONPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[ oneofs=[
], ],
serialized_start=6058, serialized_start=6112,
serialized_end=6353, serialized_end=6407,
) )
...@@ -2402,8 +2409,8 @@ _LOSSPARAMETER_EXPANDDIMSPARAMETER = _descriptor.Descriptor( ...@@ -2402,8 +2409,8 @@ _LOSSPARAMETER_EXPANDDIMSPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[ oneofs=[
], ],
serialized_start=6484, serialized_start=6538,
serialized_end=6523, serialized_end=6577,
) )
_LOSSPARAMETER = _descriptor.Descriptor( _LOSSPARAMETER = _descriptor.Descriptor(
...@@ -2446,8 +2453,8 @@ _LOSSPARAMETER = _descriptor.Descriptor( ...@@ -2446,8 +2453,8 @@ _LOSSPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[ oneofs=[
], ],
serialized_start=6356, serialized_start=6410,
serialized_end=6601, serialized_end=6655,
) )
...@@ -2490,8 +2497,8 @@ _ACCURACYPARAMETER = _descriptor.Descriptor( ...@@ -2490,8 +2497,8 @@ _ACCURACYPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[ oneofs=[
], ],
serialized_start=6603, serialized_start=6657,
serialized_end=6679, serialized_end=6733,
) )
...@@ -2534,8 +2541,8 @@ _ARGMAXPARAMETER = _descriptor.Descriptor( ...@@ -2534,8 +2541,8 @@ _ARGMAXPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[ oneofs=[
], ],
serialized_start=6681, serialized_start=6735,
serialized_end=6758, serialized_end=6812,
) )
...@@ -2571,8 +2578,8 @@ _CONCATPARAMETER = _descriptor.Descriptor( ...@@ -2571,8 +2578,8 @@ _CONCATPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[ oneofs=[
], ],
serialized_start=6760, serialized_start=6814,
serialized_end=6817, serialized_end=6871,
) )
...@@ -2615,8 +2622,8 @@ _BATCHNORMPARAMETER = _descriptor.Descriptor( ...@@ -2615,8 +2622,8 @@ _BATCHNORMPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[ oneofs=[
], ],
serialized_start=6819, serialized_start=6873,
serialized_end=6923, serialized_end=6977,
) )
...@@ -2659,8 +2666,8 @@ _BIASPARAMETER = _descriptor.Descriptor( ...@@ -2659,8 +2666,8 @@ _BIASPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[ oneofs=[
], ],
serialized_start=6925, serialized_start=6979,
serialized_end=7018, serialized_end=7072,
) )
...@@ -2696,8 +2703,8 @@ _CONTRASTIVELOSSPARAMETER = _descriptor.Descriptor( ...@@ -2696,8 +2703,8 @@ _CONTRASTIVELOSSPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[ oneofs=[
], ],
serialized_start=7020, serialized_start=7074,
serialized_end=7096, serialized_end=7150,
) )
...@@ -2846,8 +2853,8 @@ _CONVOLUTIONPARAMETER = _descriptor.Descriptor( ...@@ -2846,8 +2853,8 @@ _CONVOLUTIONPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[ oneofs=[
], ],
serialized_start=7099, serialized_start=7153,
serialized_end=7607, serialized_end=7661,
) )
...@@ -2883,8 +2890,8 @@ _CROPPARAMETER = _descriptor.Descriptor( ...@@ -2883,8 +2890,8 @@ _CROPPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[ oneofs=[
], ],
serialized_start=7609, serialized_start=7663,
serialized_end=7657, serialized_end=7711,
) )
...@@ -2977,8 +2984,8 @@ _DATAPARAMETER = _descriptor.Descriptor( ...@@ -2977,8 +2984,8 @@ _DATAPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[ oneofs=[
], ],
serialized_start=7660, serialized_start=7714,
serialized_end=7952, serialized_end=8006,
) )
...@@ -3014,8 +3021,8 @@ _DROPOUTPARAMETER = _descriptor.Descriptor( ...@@ -3014,8 +3021,8 @@ _DROPOUTPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[ oneofs=[
], ],
serialized_start=7954, serialized_start=8008,
serialized_end=8027, serialized_end=8081,
) )
...@@ -3079,8 +3086,8 @@ _DUMMYDATAPARAMETER = _descriptor.Descriptor( ...@@ -3079,8 +3086,8 @@ _DUMMYDATAPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[ oneofs=[
], ],
serialized_start=8030, serialized_start=8084,
serialized_end=8190, serialized_end=8244,
) )
...@@ -3124,8 +3131,8 @@ _ELTWISEPARAMETER = _descriptor.Descriptor( ...@@ -3124,8 +3131,8 @@ _ELTWISEPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[ oneofs=[
], ],
serialized_start=8193, serialized_start=8247,
serialized_end=8358, serialized_end=8412,
) )
...@@ -3154,8 +3161,8 @@ _ELUPARAMETER = _descriptor.Descriptor( ...@@ -3154,8 +3161,8 @@ _ELUPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[ oneofs=[
], ],
serialized_start=8360, serialized_start=8414,
serialized_end=8392, serialized_end=8446,
) )
...@@ -3212,8 +3219,8 @@ _EMBEDPARAMETER = _descriptor.Descriptor( ...@@ -3212,8 +3219,8 @@ _EMBEDPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[ oneofs=[
], ],
serialized_start=8395, serialized_start=8449,
serialized_end=8567, serialized_end=8621,
) )
...@@ -3256,8 +3263,8 @@ _EXPPARAMETER = _descriptor.Descriptor( ...@@ -3256,8 +3263,8 @@ _EXPPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[ oneofs=[
], ],
serialized_start=8569, serialized_start=8623,
serialized_end=8637, serialized_end=8691,
) )
...@@ -3293,8 +3300,8 @@ _FLATTENPARAMETER = _descriptor.Descriptor( ...@@ -3293,8 +3300,8 @@ _FLATTENPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[ oneofs=[
], ],
serialized_start=8639, serialized_start=8693,
serialized_end=8696, serialized_end=8750,
) )
...@@ -3337,8 +3344,8 @@ _HDF5DATAPARAMETER = _descriptor.Descriptor( ...@@ -3337,8 +3344,8 @@ _HDF5DATAPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[ oneofs=[
], ],
serialized_start=8698, serialized_start=8752,
serialized_end=8777, serialized_end=8831,
) )
...@@ -3367,8 +3374,8 @@ _HDF5OUTPUTPARAMETER = _descriptor.Descriptor( ...@@ -3367,8 +3374,8 @@ _HDF5OUTPUTPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[ oneofs=[
], ],
serialized_start=8779, serialized_start=8833,
serialized_end=8819, serialized_end=8873,
) )
...@@ -3398,8 +3405,8 @@ _HINGELOSSPARAMETER = _descriptor.Descriptor( ...@@ -3398,8 +3405,8 @@ _HINGELOSSPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[ oneofs=[
], ],
serialized_start=8821, serialized_start=8875,
serialized_end=8915, serialized_end=8969,
) )
...@@ -3505,8 +3512,8 @@ _IMAGEDATAPARAMETER = _descriptor.Descriptor( ...@@ -3505,8 +3512,8 @@ _IMAGEDATAPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[ oneofs=[
], ],
serialized_start=8918, serialized_start=8972,
serialized_end=9197, serialized_end=9251,
) )
...@@ -3535,8 +3542,8 @@ _INFOGAINLOSSPARAMETER = _descriptor.Descriptor( ...@@ -3535,8 +3542,8 @@ _INFOGAINLOSSPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[ oneofs=[
], ],
serialized_start=9199, serialized_start=9253,
serialized_end=9238, serialized_end=9292,
) )
...@@ -3600,8 +3607,8 @@ _INNERPRODUCTPARAMETER = _descriptor.Descriptor( ...@@ -3600,8 +3607,8 @@ _INNERPRODUCTPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[ oneofs=[
], ],
serialized_start=9241, serialized_start=9295,
serialized_end=9444, serialized_end=9498,
) )
...@@ -3630,8 +3637,8 @@ _INPUTPARAMETER = _descriptor.Descriptor( ...@@ -3630,8 +3637,8 @@ _INPUTPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[ oneofs=[
], ],
serialized_start=9446, serialized_start=9500,
serialized_end=9495, serialized_end=9549,
) )
...@@ -3674,8 +3681,8 @@ _LOGPARAMETER = _descriptor.Descriptor( ...@@ -3674,8 +3681,8 @@ _LOGPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[ oneofs=[
], ],
serialized_start=9497, serialized_start=9551,
serialized_end=9565, serialized_end=9619,
) )
...@@ -3741,8 +3748,8 @@ _LRNPARAMETER = _descriptor.Descriptor( ...@@ -3741,8 +3748,8 @@ _LRNPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[ oneofs=[
], ],
serialized_start=9568, serialized_start=9622,
serialized_end=9880, serialized_end=9934,
) )
...@@ -3800,8 +3807,8 @@ _MEMORYDATAPARAMETER = _descriptor.Descriptor( ...@@ -3800,8 +3807,8 @@ _MEMORYDATAPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[ oneofs=[
], ],
serialized_start=9883, serialized_start=9937,
serialized_end=10072, serialized_end=10126,
) )
...@@ -3844,8 +3851,8 @@ _MVNPARAMETER = _descriptor.Descriptor( ...@@ -3844,8 +3851,8 @@ _MVNPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[ oneofs=[
], ],
serialized_start=10074, serialized_start=10128,
serialized_end=10175, serialized_end=10229,
) )
...@@ -3874,8 +3881,8 @@ _PARAMETERPARAMETER = _descriptor.Descriptor( ...@@ -3874,8 +3881,8 @@ _PARAMETERPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[ oneofs=[
], ],
serialized_start=10177, serialized_start=10231,
serialized_end=10230, serialized_end=10284,
) )
...@@ -3983,8 +3990,8 @@ _POOLINGPARAMETER = _descriptor.Descriptor( ...@@ -3983,8 +3990,8 @@ _POOLINGPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[ oneofs=[
], ],
serialized_start=10233, serialized_start=10287,
serialized_end=10651, serialized_end=10705,
) )
...@@ -4027,8 +4034,8 @@ _ROIPOOLINGPARAMETER = _descriptor.Descriptor( ...@@ -4027,8 +4034,8 @@ _ROIPOOLINGPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[ oneofs=[
], ],
serialized_start=10653, serialized_start=10707,
serialized_end=10742, serialized_end=10796,
) )
...@@ -4071,8 +4078,8 @@ _POWERPARAMETER = _descriptor.Descriptor( ...@@ -4071,8 +4078,8 @@ _POWERPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[ oneofs=[
], ],
serialized_start=10744, serialized_start=10798,
serialized_end=10814, serialized_end=10868,
) )
...@@ -4122,8 +4129,8 @@ _PYTHONPARAMETER = _descriptor.Descriptor( ...@@ -4122,8 +4129,8 @@ _PYTHONPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[ oneofs=[
], ],
serialized_start=10816, serialized_start=10870,
serialized_end=10919, serialized_end=10973,
) )
...@@ -4167,8 +4174,8 @@ _REDUCTIONPARAMETER = _descriptor.Descriptor( ...@@ -4167,8 +4174,8 @@ _REDUCTIONPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[ oneofs=[
], ],
serialized_start=10922, serialized_start=10976,
serialized_end=11095, serialized_end=11149,
) )
...@@ -4205,8 +4212,8 @@ _RELUPARAMETER = _descriptor.Descriptor( ...@@ -4205,8 +4212,8 @@ _RELUPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[ oneofs=[
], ],
serialized_start=11098, serialized_start=11152,
serialized_end=11239, serialized_end=11293,
) )
...@@ -4249,8 +4256,8 @@ _RESHAPEPARAMETER = _descriptor.Descriptor( ...@@ -4249,8 +4256,8 @@ _RESHAPEPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[ oneofs=[
], ],
serialized_start=11241, serialized_start=11295,
serialized_end=11331, serialized_end=11385,
) )
...@@ -4307,8 +4314,8 @@ _SCALEPARAMETER = _descriptor.Descriptor( ...@@ -4307,8 +4314,8 @@ _SCALEPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[ oneofs=[
], ],
serialized_start=11334, serialized_start=11388,
serialized_end=11499, serialized_end=11553,
) )
...@@ -4338,8 +4345,8 @@ _SIGMOIDPARAMETER = _descriptor.Descriptor( ...@@ -4338,8 +4345,8 @@ _SIGMOIDPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[ oneofs=[
], ],
serialized_start=11501, serialized_start=11555,
serialized_end=11621, serialized_end=11675,
) )
...@@ -4382,8 +4389,8 @@ _SLICEPARAMETER = _descriptor.Descriptor( ...@@ -4382,8 +4389,8 @@ _SLICEPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[ oneofs=[
], ],
serialized_start=11623, serialized_start=11677,
serialized_end=11699, serialized_end=11753,
) )
...@@ -4420,8 +4427,8 @@ _SOFTMAXPARAMETER = _descriptor.Descriptor( ...@@ -4420,8 +4427,8 @@ _SOFTMAXPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[ oneofs=[
], ],
serialized_start=11702, serialized_start=11756,
serialized_end=11839, serialized_end=11893,
) )
...@@ -4451,8 +4458,8 @@ _TANHPARAMETER = _descriptor.Descriptor( ...@@ -4451,8 +4458,8 @@ _TANHPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[ oneofs=[
], ],
serialized_start=11841, serialized_start=11895,
serialized_end=11955, serialized_end=12009,
) )
...@@ -4495,8 +4502,8 @@ _TILEPARAMETER = _descriptor.Descriptor( ...@@ -4495,8 +4502,8 @@ _TILEPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[ oneofs=[
], ],
serialized_start=11957, serialized_start=12011,
serialized_end=12041, serialized_end=12095,
) )
...@@ -4525,8 +4532,8 @@ _THRESHOLDPARAMETER = _descriptor.Descriptor( ...@@ -4525,8 +4532,8 @@ _THRESHOLDPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[ oneofs=[
], ],
serialized_start=12043, serialized_start=12097,
serialized_end=12085, serialized_end=12139,
) )
...@@ -4639,8 +4646,8 @@ _WINDOWDATAPARAMETER = _descriptor.Descriptor( ...@@ -4639,8 +4646,8 @@ _WINDOWDATAPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[ oneofs=[
], ],
serialized_start=12088, serialized_start=12142,
serialized_end=12409, serialized_end=12463,
) )
...@@ -4685,8 +4692,8 @@ _SPPPARAMETER = _descriptor.Descriptor( ...@@ -4685,8 +4692,8 @@ _SPPPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[ oneofs=[
], ],
serialized_start=12412, serialized_start=12466,
serialized_end=12647, serialized_end=12701,
) )
...@@ -5011,8 +5018,8 @@ _V1LAYERPARAMETER = _descriptor.Descriptor( ...@@ -5011,8 +5018,8 @@ _V1LAYERPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[ oneofs=[
], ],
serialized_start=12650, serialized_start=12704,
serialized_end=15178, serialized_end=15232,
) )
...@@ -5301,8 +5308,8 @@ _V0LAYERPARAMETER = _descriptor.Descriptor( ...@@ -5301,8 +5308,8 @@ _V0LAYERPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[ oneofs=[
], ],
serialized_start=15181, serialized_start=15235,
serialized_end=16202, serialized_end=16256,
) )
...@@ -5338,8 +5345,8 @@ _PRELUPARAMETER = _descriptor.Descriptor( ...@@ -5338,8 +5345,8 @@ _PRELUPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[ oneofs=[
], ],
serialized_start=16204, serialized_start=16258,
serialized_end=16291, serialized_end=16345,
) )
...@@ -5368,8 +5375,8 @@ _SMOOTHL1LOSSPARAMETER = _descriptor.Descriptor( ...@@ -5368,8 +5375,8 @@ _SMOOTHL1LOSSPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[ oneofs=[
], ],
serialized_start=16293, serialized_start=16347,
serialized_end=16334, serialized_end=16388,
) )
...@@ -5412,8 +5419,8 @@ _MPIPARAMETER = _descriptor.Descriptor( ...@@ -5412,8 +5419,8 @@ _MPIPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[ oneofs=[
], ],
serialized_start=16336, serialized_start=16390,
serialized_end=16408, serialized_end=16462,
) )
...@@ -5442,8 +5449,8 @@ _PERMUTEPARAMETER = _descriptor.Descriptor( ...@@ -5442,8 +5449,8 @@ _PERMUTEPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[ oneofs=[
], ],
serialized_start=16410, serialized_start=16464,
serialized_end=16443, serialized_end=16497,
) )
...@@ -5493,8 +5500,8 @@ _NORMALIZEPARAMETER = _descriptor.Descriptor( ...@@ -5493,8 +5500,8 @@ _NORMALIZEPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[ oneofs=[
], ],
serialized_start=16446, serialized_start=16500,
serialized_end=16593, serialized_end=16647,
) )
...@@ -5537,8 +5544,8 @@ _PARALLELPARAMETER = _descriptor.Descriptor( ...@@ -5537,8 +5544,8 @@ _PARALLELPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[ oneofs=[
], ],
serialized_start=16595, serialized_start=16649,
serialized_end=16690, serialized_end=16744,
) )
...@@ -5581,8 +5588,8 @@ _RESIZEPARAMETER = _descriptor.Descriptor( ...@@ -5581,8 +5588,8 @@ _RESIZEPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[ oneofs=[
], ],
serialized_start=16692, serialized_start=16746,
serialized_end=16774, serialized_end=16828,
) )
...@@ -5611,8 +5618,8 @@ _EXPANDDIMSPARAMETER = _descriptor.Descriptor( ...@@ -5611,8 +5618,8 @@ _EXPANDDIMSPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[ oneofs=[
], ],
serialized_start=6484, serialized_start=6538,
serialized_end=6523, serialized_end=6577,
) )
...@@ -5711,8 +5718,8 @@ _PROPOSALPARAMETER = _descriptor.Descriptor( ...@@ -5711,8 +5718,8 @@ _PROPOSALPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[ oneofs=[
], ],
serialized_start=16818, serialized_start=16872,
serialized_end=17090, serialized_end=17144,
) )
...@@ -5776,8 +5783,8 @@ _BATCHRENORMPARAMETER = _descriptor.Descriptor( ...@@ -5776,8 +5783,8 @@ _BATCHRENORMPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[ oneofs=[
], ],
serialized_start=17093, serialized_start=17147,
serialized_end=17259, serialized_end=17313,
) )
...@@ -5813,8 +5820,8 @@ _DENSECONCATPARAMETER = _descriptor.Descriptor( ...@@ -5813,8 +5820,8 @@ _DENSECONCATPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[ oneofs=[
], ],
serialized_start=17261, serialized_start=17315,
serialized_end=17324, serialized_end=17378,
) )
...@@ -5864,8 +5871,8 @@ _FOCALLOSSPARAMETER = _descriptor.Descriptor( ...@@ -5864,8 +5871,8 @@ _FOCALLOSSPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[ oneofs=[
], ],
serialized_start=17326, serialized_start=17380,
serialized_end=17425, serialized_end=17479,
) )
...@@ -5894,8 +5901,59 @@ _GATHERPARAMETER = _descriptor.Descriptor( ...@@ -5894,8 +5901,59 @@ _GATHERPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[ oneofs=[
], ],
serialized_start=17427, serialized_start=17481,
serialized_end=17461, serialized_end=17515,
)
_GROUPNORMPARAMETER = _descriptor.Descriptor(
name='GroupNormParameter',
full_name='caffe.GroupNormParameter',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='use_global_stats', full_name='caffe.GroupNormParameter.use_global_stats', index=0,
number=1, type=8, cpp_type=7, label=1,
has_default_value=False, default_value=False,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='moving_average_fraction', full_name='caffe.GroupNormParameter.moving_average_fraction', index=1,
number=2, type=2, cpp_type=6, label=1,
has_default_value=True, default_value=0.9,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='eps', full_name='caffe.GroupNormParameter.eps', index=2,
number=3, type=2, cpp_type=6, label=1,
has_default_value=True, default_value=0.001,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='group', full_name='caffe.GroupNormParameter.group', index=3,
number=5, type=13, cpp_type=3, label=1,
has_default_value=True, default_value=32,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
],
extensions=[
],
nested_types=[],
enum_types=[
],
options=None,
is_extendable=False,
extension_ranges=[],
oneofs=[
],
serialized_start=17517,
serialized_end=17640,
) )
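For readers who don't want to decode the generated descriptor by hand: the message carries four optional fields — use_global_stats (bool, tag 1), moving_average_fraction (float, tag 2, default 0.9), eps (float, tag 3, default 0.001), and group (uint32, tag 5, default 32). A minimal sketch, assuming the matching protoc C++ codegen is built alongside (the caffe.pb.h include name is hypothetical):
#include "caffe.pb.h"  // hypothetical include for caffe.proto's C++ codegen
#include <cassert>
void group_norm_param_defaults() {
    caffe::GroupNormParameter param;
    assert(param.group() == 32u);                     // tag 5, default 32
    assert(param.moving_average_fraction() == 0.9f);  // tag 2, default 0.9
    assert(param.eps() == 0.001f);                    // tag 3, default 0.001
    assert(!param.has_use_global_stats());            // tag 1, no default value
    param.set_use_global_stats(true);                 // opt in to global stats
}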
_BLOBPROTO.fields_by_name['shape'].message_type = _BLOBSHAPE _BLOBPROTO.fields_by_name['shape'].message_type = _BLOBSHAPE
...@@ -5986,6 +6044,7 @@ _LAYERPARAMETER.fields_by_name['batch_renorm_param'].message_type = _BATCHRENORM ...@@ -5986,6 +6044,7 @@ _LAYERPARAMETER.fields_by_name['batch_renorm_param'].message_type = _BATCHRENORM
_LAYERPARAMETER.fields_by_name['dense_concat_param'].message_type = _DENSECONCATPARAMETER _LAYERPARAMETER.fields_by_name['dense_concat_param'].message_type = _DENSECONCATPARAMETER
_LAYERPARAMETER.fields_by_name['focal_loss_param'].message_type = _FOCALLOSSPARAMETER _LAYERPARAMETER.fields_by_name['focal_loss_param'].message_type = _FOCALLOSSPARAMETER
_LAYERPARAMETER.fields_by_name['gather_param'].message_type = _GATHERPARAMETER _LAYERPARAMETER.fields_by_name['gather_param'].message_type = _GATHERPARAMETER
_LAYERPARAMETER.fields_by_name['group_norm_param'].message_type = _GROUPNORMPARAMETER
_LOSSPARAMETER_EXPANDDIMSPARAMETER.containing_type = _LOSSPARAMETER _LOSSPARAMETER_EXPANDDIMSPARAMETER.containing_type = _LOSSPARAMETER
_LOSSPARAMETER.fields_by_name['normalization'].enum_type = _LOSSPARAMETER_NORMALIZATIONMODE _LOSSPARAMETER.fields_by_name['normalization'].enum_type = _LOSSPARAMETER_NORMALIZATIONMODE
_LOSSPARAMETER_NORMALIZATIONMODE.containing_type = _LOSSPARAMETER _LOSSPARAMETER_NORMALIZATIONMODE.containing_type = _LOSSPARAMETER
...@@ -6156,6 +6215,7 @@ DESCRIPTOR.message_types_by_name['BatchRenormParameter'] = _BATCHRENORMPARAMETER ...@@ -6156,6 +6215,7 @@ DESCRIPTOR.message_types_by_name['BatchRenormParameter'] = _BATCHRENORMPARAMETER
DESCRIPTOR.message_types_by_name['DenseConcatParameter'] = _DENSECONCATPARAMETER DESCRIPTOR.message_types_by_name['DenseConcatParameter'] = _DENSECONCATPARAMETER
DESCRIPTOR.message_types_by_name['FocalLossParameter'] = _FOCALLOSSPARAMETER DESCRIPTOR.message_types_by_name['FocalLossParameter'] = _FOCALLOSSPARAMETER
DESCRIPTOR.message_types_by_name['GatherParameter'] = _GATHERPARAMETER DESCRIPTOR.message_types_by_name['GatherParameter'] = _GATHERPARAMETER
DESCRIPTOR.message_types_by_name['GroupNormParameter'] = _GROUPNORMPARAMETER
DESCRIPTOR.enum_types_by_name['Phase'] = _PHASE DESCRIPTOR.enum_types_by_name['Phase'] = _PHASE
BlobShape = _reflection.GeneratedProtocolMessageType('BlobShape', (_message.Message,), dict( BlobShape = _reflection.GeneratedProtocolMessageType('BlobShape', (_message.Message,), dict(
...@@ -6677,6 +6737,13 @@ GatherParameter = _reflection.GeneratedProtocolMessageType('GatherParameter', (_ ...@@ -6677,6 +6737,13 @@ GatherParameter = _reflection.GeneratedProtocolMessageType('GatherParameter', (_
)) ))
_sym_db.RegisterMessage(GatherParameter) _sym_db.RegisterMessage(GatherParameter)
GroupNormParameter = _reflection.GeneratedProtocolMessageType('GroupNormParameter', (_message.Message,), dict(
DESCRIPTOR = _GROUPNORMPARAMETER,
__module__ = 'caffe_pb2'
# @@protoc_insertion_point(class_scope:caffe.GroupNormParameter)
))
_sym_db.RegisterMessage(GroupNormParameter)
_BLOBSHAPE.fields_by_name['dim'].has_options = True _BLOBSHAPE.fields_by_name['dim'].has_options = True
_BLOBSHAPE.fields_by_name['dim']._options = _descriptor._ParseOptions(descriptor_pb2.FieldOptions(), _b('\020\001')) _BLOBSHAPE.fields_by_name['dim']._options = _descriptor._ParseOptions(descriptor_pb2.FieldOptions(), _b('\020\001'))
......
...@@ -36,7 +36,7 @@ find_packages('dragon') ...@@ -36,7 +36,7 @@ find_packages('dragon')
find_modules() find_modules()
setup(name = 'dragon', setup(name = 'dragon',
version='0.2.1.12', version='0.2.1.13',
description = 'Dragon: A Computation Graph Virtual Machine Based Deep Learning Framework', description = 'Dragon: A Computation Graph Virtual Machine Based Deep Learning Framework',
url='https://github.com/neopenx/Dragon', url='https://github.com/neopenx/Dragon',
author='Ting Pan', author='Ting Pan',
......
#include "operators/ndarray/reshape_op.h" #include "operators/ndarray/reshape_op.h"
#include "core/workspace.h"
namespace dragon { namespace dragon {
string dim_string(const vector<TIndex>& shape) {
std::stringstream ss;
ss << "(";
for (int i = 0; i + 1 < (int)shape.size(); i++) ss << shape[i] << ",";
if (!shape.empty()) ss << shape[shape.size() - 1];
ss << ")";
return ss.str();
}
template <class Context> template <class Context>
void ReshapeOp<Context>::RunOnDevice() { void ReshapeOp<Context>::RunOnDevice() {
if (shape_desc.size() > 0 || shape_value.size() > 0) {
require_shape.resize(std::max(shape_desc.size(), shape_value.size()));
for (int i = 0; i < require_shape.size(); i++)
require_shape[i] = shape(i);
} else if (shape_like_desc.size() > 0) {
Tensor* shape_like_tensor = ws()->GetTensor(shape_like_desc);
require_shape.resize(shape_like_tensor->ndim());
for (int i = 0; i < require_shape.size(); i++)
require_shape[i] = shape_like_tensor->dim(i);
} else { LOG(FATAL) << "Missing the required shape."; }
vector<TIndex> Xdims = input(0).dims(); vector<TIndex> Xdims = input(0).dims();
new_shape.resize(require_shape.size());
int infer_dim = -1; int infer_dim = -1;
TIndex total_count = 1; TIndex total_count = 1;
for (int i = 0; i < shape.size(); i++) { for (int i = 0; i < require_shape.size(); i++) {
// handle unchanged dim if (require_shape[i] == 0) {
if (shape[i] == 0) { // handle unchanged dim
CHECK_LT(i, (int)Xdims.size()) CHECK_LT(i, (int)Xdims.size())
<< "\nDim(" << i << ") is out of the Xdims range of (0, " << "\nDim(" << i << ") is out of the Xdims range of (0, "
<< Xdims.size() << ")."; << Xdims.size() << ").";
new_shape[i] = Xdims[i]; new_shape[i] = Xdims[i];
} } else if (require_shape[i] > 0) {
// handle reset dim // handle reset dim
else if (shape[i] > 0) { new_shape[i] = require_shape[i];
new_shape[i] = shape[i]; } else {
} // handle inferred dim
// handle inferred dim
else {
CHECK_EQ(infer_dim, -1) CHECK_EQ(infer_dim, -1)
<< "\nDim(" << infer_dim << ") required infer before" << "\nDim(" << infer_dim << ") required infer before"
<< "\ncould not infer for dim(" << i << ") both."; << "\ncould not infer for dim(" << i << ") both.";
...@@ -35,7 +54,8 @@ void ReshapeOp<Context>::RunOnDevice() { ...@@ -35,7 +54,8 @@ void ReshapeOp<Context>::RunOnDevice() {
for (int i = 0; i < new_shape.size(); i++) { for (int i = 0; i < new_shape.size(); i++) {
if (new_shape[i] == -1) { if (new_shape[i] == -1) {
CHECK_EQ(input(0).count() % total_count, 0) CHECK_EQ(input(0).count() % total_count, 0)
<< "\nCan not change the total size."; << "\nCan not change the total size: "
<< input(0).dim_string() << " -> " << dim_string(new_shape);
new_shape[i] = input(0).count() / total_count; new_shape[i] = input(0).count() / total_count;
total_count *= new_shape[i]; total_count *= new_shape[i];
break; break;
...@@ -43,7 +63,8 @@ void ReshapeOp<Context>::RunOnDevice() { ...@@ -43,7 +63,8 @@ void ReshapeOp<Context>::RunOnDevice() {
} }
} }
CHECK_EQ(total_count, input(0).count()) CHECK_EQ(total_count, input(0).count())
<< "\nCan not change the total size."; << "\nCan not change the total size."
<< input(0).dim_string() << " -> " << dim_string(new_shape);
output(0)->Reshape(new_shape); output(0)->Reshape(new_shape);
output(0)->Share(input(0)); output(0)->Share(input(0));
} }
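As a worked example of the resolution rules above (shapes are hypothetical, not from the commit):
// Suppose input(0) has dims (2, 3, 4), so count = 24, and shape = (0, -1):
//   require_shape[0] == 0  -> keep Xdims[0]        -> new_shape = (2, ?)
//   require_shape[1] == -1 -> infer 24 / 2 = 12    -> new_shape = (2, 12)
// A second -1 would trip CHECK_EQ(infer_dim, -1); a shape whose fixed part
// does not divide the count (e.g. (5, -1)) trips the total-size check.
// With shape_like set instead, require_shape is copied from that tensor's dims.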
......
...@@ -19,9 +19,9 @@ void BatchNormOp<Context>::TrainingRunWithType() { ...@@ -19,9 +19,9 @@ void BatchNormOp<Context>::TrainingRunWithType() {
auto* tVar_data = var->template mutable_data<T, Context>(); auto* tVar_data = var->template mutable_data<T, Context>();
auto* Xdata = input(0).template data<T, Context>(); auto* Xdata = input(0).template data<T, Context>();
auto* Ydata = output(0)->template mutable_data<T, Context>(); auto* Ydata = output(0)->template mutable_data<T, Context>();
auto* NSMul_data = multiplier->template data<T, Context>();
auto* SMul_data = spatial_multiplier->template data<T, Context>();
auto* NMul_data = num_multiplier->template data<T, Context>(); auto* NMul_data = num_multiplier->template data<T, Context>();
auto* SMul_data = spatial_multiplier->template data<T, Context>();
auto* NSMul_data = multiplier->template data<T, Context>();
auto* NC_data = num_by_chans.template mutable_data<T, Context>(); auto* NC_data = num_by_chans.template mutable_data<T, Context>();
auto* Std_data = stddev->template mutable_data<T, Context>(); auto* Std_data = stddev->template mutable_data<T, Context>();
ctx().template Copy<T, Context, Context>(output(0)->count(), Ydata, Xdata); ctx().template Copy<T, Context, Context>(output(0)->count(), Ydata, Xdata);
...@@ -127,9 +127,9 @@ void BatchNormOp<Context>::InferenceRunWithType() { ...@@ -127,9 +127,9 @@ void BatchNormOp<Context>::InferenceRunWithType() {
auto* tVar_data = var->template mutable_data<T, Context>(); auto* tVar_data = var->template mutable_data<T, Context>();
auto* Xdata = input(0).template data<T, Context>(); auto* Xdata = input(0).template data<T, Context>();
auto* Ydata = output(0)->template mutable_data<T, Context>(); auto* Ydata = output(0)->template mutable_data<T, Context>();
auto* NSMul_data = multiplier->template data<T, Context>();
auto* SMul_data = spatial_multiplier->template data<T, Context>();
auto* NMul_data = num_multiplier->template data<T, Context>(); auto* NMul_data = num_multiplier->template data<T, Context>();
auto* SMul_data = spatial_multiplier->template data<T, Context>();
auto* NSMul_data = multiplier->template data<T, Context>();
auto* NC_data = num_by_chans.template mutable_data<T, Context>(); auto* NC_data = num_by_chans.template mutable_data<T, Context>();
auto* Std_data = stddev->template mutable_data<T, Context>(); auto* Std_data = stddev->template mutable_data<T, Context>();
ctx().template Copy<T, Context, Context>(input(0).count(), Ydata, Xdata); ctx().template Copy<T, Context, Context>(input(0).count(), Ydata, Xdata);
...@@ -169,12 +169,12 @@ void BatchNormOp<Context>::InferenceRunWithType() { ...@@ -169,12 +169,12 @@ void BatchNormOp<Context>::InferenceRunWithType() {
// divide by stddev // divide by stddev
if (data_format == "NCHW") { if (data_format == "NCHW") {
math::Gemm<T, Context>(CblasNoTrans, CblasNoTrans, N, C, 1, math::Gemm<T, Context>(CblasNoTrans, CblasNoTrans, N, C, 1,
1.0, NMul_data, tVar_data, 1.0, NMul_data, tVar_data,
0.0, NC_data); 0.0, NC_data);
math::Gemm<T, Context>(CblasNoTrans, CblasNoTrans, NC, S, 1, math::Gemm<T, Context>(CblasNoTrans, CblasNoTrans, NC, S, 1,
1.0, NC_data, SMul_data, 1.0, NC_data, SMul_data,
0.0, Std_data); 0.0, Std_data);
} else if (data_format == "NHWC") { } else if (data_format == "NHWC") {
math::Gemm<T, Context>(CblasNoTrans, CblasNoTrans, NS, C, 1, math::Gemm<T, Context>(CblasNoTrans, CblasNoTrans, NS, C, 1,
1.0, NSMul_data, tVar_data, 1.0, NSMul_data, tVar_data,
...@@ -248,9 +248,9 @@ void BatchNormGradientOp<Context>::TrainingRunWithType() { ...@@ -248,9 +248,9 @@ void BatchNormGradientOp<Context>::TrainingRunWithType() {
auto* dXdata = output(0)->template mutable_data<T, Context>(); auto* dXdata = output(0)->template mutable_data<T, Context>();
auto* Std_data = stddev->template mutable_data<T, Context>(); auto* Std_data = stddev->template mutable_data<T, Context>();
auto* tVar_data = var->template mutable_data<T, Context>(); auto* tVar_data = var->template mutable_data<T, Context>();
auto* NSMul_data = multiplier->template data<T, Context>();
auto* SMul_data = spatial_multiplier->template data<T, Context>();
auto* NMul_data = num_multiplier->template data<T, Context>(); auto* NMul_data = num_multiplier->template data<T, Context>();
auto* SMul_data = spatial_multiplier->template data<T, Context>();
auto* NSMul_data = multiplier->template data<T, Context>();
auto* NC_data = num_by_chans.template mutable_data<T, Context>(); auto* NC_data = num_by_chans.template mutable_data<T, Context>();
if (data_format == "NCHW") { if (data_format == "NCHW") {
...@@ -337,9 +337,9 @@ void BatchNormGradientOp<Context>::InferenceRunWithType() { ...@@ -337,9 +337,9 @@ void BatchNormGradientOp<Context>::InferenceRunWithType() {
auto* dXdata = output(0)->template mutable_data<T, Context>(); auto* dXdata = output(0)->template mutable_data<T, Context>();
auto* Std_data = stddev->template mutable_data<T, Context>(); auto* Std_data = stddev->template mutable_data<T, Context>();
auto* tVar_data = var->template mutable_data<T, Context>(); auto* tVar_data = var->template mutable_data<T, Context>();
auto* NSMul_data = multiplier->template data<T, Context>();
auto* SMul_data = spatial_multiplier->template data<T, Context>();
auto* NMul_data = num_multiplier->template data<T, Context>(); auto* NMul_data = num_multiplier->template data<T, Context>();
auto* SMul_data = spatial_multiplier->template data<T, Context>();
auto* NSMul_data = multiplier->template data<T, Context>();
auto* NC_data = num_by_chans.template mutable_data<T, Context>(); auto* NC_data = num_by_chans.template mutable_data<T, Context>();
if (data_format == "NCHW") { if (data_format == "NCHW") {
......
...@@ -23,9 +23,9 @@ void FusedBatchNormOp<Context>::TrainingRunWithType() { ...@@ -23,9 +23,9 @@ void FusedBatchNormOp<Context>::TrainingRunWithType() {
auto* tVar_data = var->template mutable_data<T, Context>(); auto* tVar_data = var->template mutable_data<T, Context>();
auto* Xdata = input(0).template data<T, Context>(); auto* Xdata = input(0).template data<T, Context>();
auto* Ydata = output(0)->template mutable_data<T, Context>(); auto* Ydata = output(0)->template mutable_data<T, Context>();
auto* NSMul_data = multiplier->template data<T, Context>();
auto* SMul_data = spatial_multiplier->template data<T, Context>();
auto* NMul_data = num_multiplier->template data<T, Context>(); auto* NMul_data = num_multiplier->template data<T, Context>();
auto* SMul_data = spatial_multiplier->template data<T, Context>();
auto* NSMul_data = multiplier->template data<T, Context>();
auto* NC_data = num_by_chans.template mutable_data<T, Context>(); auto* NC_data = num_by_chans.template mutable_data<T, Context>();
auto* Std_data = stddev->template mutable_data<T, Context>(); auto* Std_data = stddev->template mutable_data<T, Context>();
ctx().template Copy<T, Context, Context>(output(0)->count(), Ydata, Xdata); ctx().template Copy<T, Context, Context>(output(0)->count(), Ydata, Xdata);
...@@ -153,9 +153,9 @@ void FusedBatchNormOp<Context>::InferenceRunWithType() { ...@@ -153,9 +153,9 @@ void FusedBatchNormOp<Context>::InferenceRunWithType() {
auto* tVar_data = var->template mutable_data<T, Context>(); auto* tVar_data = var->template mutable_data<T, Context>();
auto* Xdata = input(0).template data<T, Context>(); auto* Xdata = input(0).template data<T, Context>();
auto* Ydata = output(0)->template mutable_data<T, Context>(); auto* Ydata = output(0)->template mutable_data<T, Context>();
auto* NSMul_data = multiplier->template data<T, Context>();
auto* SMul_data = spatial_multiplier->template data<T, Context>();
auto* NMul_data = num_multiplier->template data<T, Context>(); auto* NMul_data = num_multiplier->template data<T, Context>();
auto* SMul_data = spatial_multiplier->template data<T, Context>();
auto* NSMul_data = multiplier->template data<T, Context>();
auto* NC_data = num_by_chans.template mutable_data<T, Context>(); auto* NC_data = num_by_chans.template mutable_data<T, Context>();
auto* Std_data = stddev->template mutable_data<T, Context>(); auto* Std_data = stddev->template mutable_data<T, Context>();
ctx().template Copy<T, Context, Context>(input(0).count(), Ydata, Xdata); ctx().template Copy<T, Context, Context>(input(0).count(), Ydata, Xdata);
...@@ -296,9 +296,9 @@ void FusedBatchNormGradientOp<Context>::TrainingRunWithType() { ...@@ -296,9 +296,9 @@ void FusedBatchNormGradientOp<Context>::TrainingRunWithType() {
auto* Std_data = stddev->template mutable_data<T, Context>(); auto* Std_data = stddev->template mutable_data<T, Context>();
auto* tMean_data = mean->template mutable_data<T, Context>(); auto* tMean_data = mean->template mutable_data<T, Context>();
auto* tVar_data = var->template mutable_data<T, Context>(); auto* tVar_data = var->template mutable_data<T, Context>();
auto* NSMul_data = multiplier->template data<T, Context>();
auto* SMul_data = spatial_multiplier->template data<T, Context>();
auto* NMul_data = num_multiplier->template data<T, Context>(); auto* NMul_data = num_multiplier->template data<T, Context>();
auto* SMul_data = spatial_multiplier->template data<T, Context>();
auto* NSMul_data = multiplier->template data<T, Context>();
auto* NC_data = num_by_chans.template mutable_data<T, Context>(); auto* NC_data = num_by_chans.template mutable_data<T, Context>();
auto* XNorm_data = x_norm->template data<T, Context>(); auto* XNorm_data = x_norm->template data<T, Context>();
...@@ -436,9 +436,9 @@ void FusedBatchNormGradientOp<Context>::InferenceRunWithType() { ...@@ -436,9 +436,9 @@ void FusedBatchNormGradientOp<Context>::InferenceRunWithType() {
auto* dYdata = input(-1).template data<T, Context>(); auto* dYdata = input(-1).template data<T, Context>();
auto* Sdata = input(3).template data<T, Context>(); auto* Sdata = input(3).template data<T, Context>();
auto* tVar_data = var->template mutable_data<T, Context>(); auto* tVar_data = var->template mutable_data<T, Context>();
auto* NSMul_data = multiplier->template data<T, Context>();
auto* SMul_data = spatial_multiplier->template data<T, Context>();
auto* NMul_data = num_multiplier->template data<T, Context>(); auto* NMul_data = num_multiplier->template data<T, Context>();
auto* SMul_data = spatial_multiplier->template data<T, Context>();
auto* NSMul_data = multiplier->template data<T, Context>();
auto* NC_data = num_by_chans.template mutable_data<T, Context>(); auto* NC_data = num_by_chans.template mutable_data<T, Context>();
// gradient w.r.t. scale // gradient w.r.t. scale
......
#include "operators/norm/group_norm_op.h"
#include "core/workspace.h"
#include "utils/math_functions.h"
#include "utils/filler.h"
namespace dragon {
template <class Context> template <typename T>
void FusedGroupNormOp<Context>::TrainingRunWithType() {
INIT_MULTIPLIER(multiplier, NS);
INIT_MULTIPLIER(num_multiplier, N);
INIT_MULTIPLIER(spatial_multiplier, S);
INIT_MULTIPLIER(cgs_multiplier, CGS);
TENSOR_FILL(input(1), vector<TIndex>(1, NG)); // history_mean
TENSOR_FILL(input(2), vector<TIndex>(1, NG)); // history_var
TENSOR_FILL(input(3), vector<TIndex>(1, C)); // scale
TENSOR_FILL(input(4), vector<TIndex>(1, C)); // bias
auto* hMean_data = input(1).template mutable_data<T, Context>();
auto* hVar_data = input(2).template mutable_data<T, Context>();
auto* Sdata = input(3).template data<T, Context>();
auto* Bdata = input(4).template data<T, Context>();
auto* tMean_data = mean->template mutable_data<T, Context>();
auto* tVar_data = var->template mutable_data<T, Context>();
auto* Xdata = input(0).template data<T, Context>();
auto* Ydata = output(0)->template mutable_data<T, Context>();
auto* NMul_data = num_multiplier->template data<T, Context>();
auto* SMul_data = spatial_multiplier->template data<T, Context>();
auto* NSMul_data = multiplier->template data<T, Context>();
auto* CGSMul_data = cgs_multiplier->template data<T, Context>();
auto* NC_data = num_by_chans.template mutable_data<T, Context>();
auto* Std_data = stddev->template mutable_data<T, Context>();
ctx().template Copy<T, Context, Context>(output(0)->count(), Ydata, Xdata);
// compute mean
if (data_format == "NCHW") {
math::Gemv<T, Context>(CblasNoTrans, NG, CGS,
1.0 / CGS, Xdata, CGSMul_data,
0, tMean_data);
} else if (data_format == "NHWC") {
NOT_IMPLEMENTED;
}
// subtract mean
if (data_format == "NCHW") {
math::Gemm<T, Context>(CblasNoTrans, CblasNoTrans, NG, CGS, 1,
-1.0, tMean_data, CGSMul_data,
1.0, Ydata);
} else if (data_format == "NHWC") {
NOT_IMPLEMENTED;
}
// compute variance
// note that we use VAR(X) = E((X - EX) ^ 2)
math::Square<T, Context>(output(0)->count(), Ydata, Std_data);
if (data_format == "NCHW") {
math::Gemv<T, Context>(CblasNoTrans, NG, CGS,
1.0 / CGS, Std_data, CGSMul_data,
0.0, tVar_data);
} else if (data_format == "NHWC") {
NOT_IMPLEMENTED;
}
// compute moving average
if (!is_recomputing) {
// History(X) = (1 - momentum) * Cur(X) + momentum * History(X)
math::Axpby<T, Context>(mean->count(), 1.0 - momentum, tMean_data, momentum, hMean_data);
math::Axpby<T, Context>(var->count(), 1.0 - momentum, tVar_data, momentum, hVar_data);
}
// compute stddev
math::AddScalar<T, Context>(var->count(), eps, tVar_data);
math::Sqrt<T, Context>(var->count(), tVar_data, tVar_data);
// divide by stddev
if (data_format == "NCHW") {
math::Gemm<T, Context>(CblasNoTrans, CblasNoTrans, NG, CGS, 1,
1.0, tVar_data, CGSMul_data,
0.0, Std_data);
} else if (data_format == "NHWC") {
NOT_IMPLEMENTED;
}
math::Div<T, Context>(output(0)->count(), Ydata, Std_data, Ydata);
// store x_norm for backward
auto* XNorm_data = x_norm->template mutable_data<T, Context>();
ctx().template Copy<T, Context, Context>(output(0)->count(), XNorm_data, Ydata);
// scale
if (data_format == "NCHW") {
math::Gemm<T, Context>(CblasNoTrans, CblasNoTrans, N, C, 1,
1.0, NMul_data, Sdata,
0.0, NC_data);
math::Gemm<T, Context>(CblasNoTrans, CblasNoTrans, NC, S, 1,
1.0, NC_data, SMul_data,
0.0, Std_data);
} else if (data_format == "NHWC") {
math::Gemm<T, Context>(CblasNoTrans, CblasNoTrans, NS, C, 1,
1.0, NSMul_data, Sdata,
0.0, Std_data);
}
math::Mul<T, Context>(output(0)->count(), Ydata, Std_data, Ydata);
// shift
if (data_format == "NCHW") {
math::Gemm<T, Context>(CblasNoTrans, CblasNoTrans, N, C, 1,
1.0, NMul_data, Bdata,
0.0, NC_data);
math::Gemm<T, Context>(CblasNoTrans, CblasNoTrans, NC, S, 1,
1.0, NC_data, SMul_data,
1.0, Ydata);
} else if (data_format == "NHWC") {
math::Gemm<T, Context>(CblasNoTrans, CblasNoTrans, NS, C, 1,
1.0, NSMul_data, Bdata,
1.0, Ydata);
}
ws()->ReleaseBuffer(stddev);
}
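The reductions above lean on a BLAS idiom: treating X as an (NG x CGS) matrix, a Gemv against the ones vector cgs_multiplier with alpha = 1.0 / CGS yields one mean per group. A minimal plain-C++ sketch of that step, assuming a row-major (NG, CGS) layout:
#include <cstddef>
#include <vector>
// Per-group means of x viewed as an (NG x CGS) row-major matrix; mirrors
// math::Gemv(CblasNoTrans, NG, CGS, 1.0 / CGS, Xdata, CGSMul_data, 0, tMean_data).
std::vector<float> group_means(const std::vector<float>& x,
                               std::size_t NG, std::size_t CGS) {
    std::vector<float> mean(NG, 0.f);
    for (std::size_t g = 0; g < NG; ++g) {
        for (std::size_t j = 0; j < CGS; ++j) mean[g] += x[g * CGS + j];
        mean[g] /= static_cast<float>(CGS);  // the 1.0 / CGS alpha
    }
    return mean;
}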
template <class Context> template <typename T>
void FusedGroupNormOp<Context>::InferenceRunWithType() {
INIT_MULTIPLIER(multiplier, NS);
INIT_MULTIPLIER(num_multiplier, N);
INIT_MULTIPLIER(spatial_multiplier, S);
INIT_MULTIPLIER(cgs_multiplier, CGS);
TENSOR_FILL(input(1), vector<TIndex>(1, NG)); // history_mean
TENSOR_FILL(input(2), vector<TIndex>(1, NG)); // history_var
TENSOR_FILL(input(3), vector<TIndex>(1, C)); // scale
TENSOR_FILL(input(4), vector<TIndex>(1, C)); // bias
auto* hMean_data = input(1).template mutable_data<T, Context>();
auto* hVar_data = input(2).template mutable_data<T, Context>();
auto* Sdata = input(3).template data<T, Context>();
auto* Bdata = input(4).template data<T, Context>();
auto* tMean_data = mean->template mutable_data<T, Context>();
auto* tVar_data = var->template mutable_data<T, Context>();
auto* Xdata = input(0).template data<T, Context>();
auto* Ydata = output(0)->template mutable_data<T, Context>();
auto* NMul_data = num_multiplier->template data<T, Context>();
auto* SMul_data = spatial_multiplier->template data<T, Context>();
auto* NSMul_data = multiplier->template data<T, Context>();
auto* CGSMul_data = cgs_multiplier->template data<T, Context>();
auto* NC_data = num_by_chans.template mutable_data<T, Context>();
auto* Std_data = stddev->template mutable_data<T, Context>();
ctx().template Copy<T, Context, Context>(input(0).count(), Ydata, Xdata);
ctx().template Copy<T, Context, Context>(mean->count(), tMean_data, hMean_data);
ctx().template Copy<T, Context, Context>(var->count(), tVar_data, hVar_data);
// subtract mean
if (data_format == "NCHW") {
math::Gemm<T, Context>(CblasNoTrans, CblasNoTrans, NG, CGS, 1,
-1.0, tMean_data, CGSMul_data,
1.0, Ydata);
} else if (data_format == "NHWC") {
NOT_IMPLEMENTED;
}
// compute stddev
math::AddScalar<T, Context>(var->count(), eps, tVar_data);
math::Sqrt<T, Context>(var->count(), tVar_data, tVar_data);
// divide by stddev
if (data_format == "NCHW") {
math::Gemm<T, Context>(CblasNoTrans, CblasNoTrans, NG, CGS, 1,
1.0, tVar_data, CGSMul_data,
0.0, Std_data);
} else if (data_format == "NHWC") {
NOT_IMPLEMENTED;
}
math::Div<T, Context>(output(0)->count(), Ydata, Std_data, Ydata);
// scale
if (data_format == "NCHW") {
math::Gemm<T, Context>(CblasNoTrans, CblasNoTrans, N, C, 1,
1.0, NMul_data, Sdata,
0.0, NC_data);
math::Gemm<T, Context>(CblasNoTrans, CblasNoTrans, NC, S, 1,
1.0, NC_data, SMul_data,
0.0, Std_data);
} else if (data_format == "NHWC") {
math::Gemm<T, Context>(CblasNoTrans, CblasNoTrans, NS, C, 1,
1.0, NSMul_data, Sdata,
0.0, Std_data);
}
math::Mul<T, Context>(output(0)->count(), Ydata, Std_data, Ydata);
// shift
if (data_format == "NCHW") {
math::Gemm<T, Context>(CblasNoTrans, CblasNoTrans, N, C, 1,
1.0, NMul_data, Bdata,
0.0, NC_data);
math::Gemm<T, Context>(CblasNoTrans, CblasNoTrans, NC, S, 1,
1.0, NC_data, SMul_data,
1.0, Ydata);
} else if (data_format == "NHWC") {
math::Gemm<T, Context>(CblasNoTrans, CblasNoTrans, NS, C, 1,
1.0, NSMul_data, Bdata,
1.0, Ydata);
}
ws()->ReleaseBuffer(stddev);
}
template <class Context>
void FusedGroupNormOp<Context>::Setup() {
// determine the mode
if (use_stats == -1) use_global_stats = phase() == "TEST" ? true : false;
else use_global_stats = use_stats == 1 ? true : false;
is_recomputing = ws()->GetTensor("/opt/mirror_stage/recompute_flag")
->template data<bool, CPUContext>()[0];
// determine the data format
TIndex channel_axis = axis;
data_format = "NCHW";
if (channel_axis == -1) channel_axis += (int)input(0).ndim();
if (channel_axis + 1 == (int)input(0).ndim()) data_format = "NHWC";
if (input(0).ndim() == 2) data_format = "NCHW";
N = input(0).dim(0);
C = input(0).dim(channel_axis);
CHECK_EQ(C % group, 0) << "\nThe " << C << " channels "
<< "can not be split into " << group << " groups.";
if (group == C && input(0).ndim() == 2) // InstanceNorm
LOG(WARNING) << "The 2d input will output all zeros.";
NC = N * C;
NG = N * group;
S = input(0).count() / NC;
CGS = (C / group) * S;
NS = N * S;
// make resource
mean = ws()->CreateTensor("/mnt/" + anchor() + "/gn_mean");
var = ws()->CreateTensor("/mnt/" + anchor() + "/gn_var");
x_norm = ws()->CreateTensor("/mnt/" + anchor() + "/gn_x_norm");
stddev = ws()->GetBuffer();
stddev->ReshapeLike(input(0));
// reshape
mean->Reshape(vector<TIndex>(1, NG));
var->Reshape(vector<TIndex>(1, NG));
num_by_chans.Reshape(vector<TIndex>(1, NC));
x_norm->ReshapeLike(input(0));
output(0)->ReshapeLike(input(0));
}
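To make the derived sizes concrete, a hypothetical NCHW example (numbers are illustrative only):
// X = (N=2, C=64, H=7, W=7), group=32:
//   NC  = 2 * 64          = 128
//   NG  = 2 * 32          = 64   // one mean/var per (sample, group)
//   S   = count / NC      = 49   // spatial size H * W
//   CGS = (64 / 32) * 49  = 98   // elements reduced per group
//   NS  = 2 * 49          = 98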
template <class Context>
void FusedGroupNormOp<Context>::RunOnDevice() {
Setup();
if (input(0).template IsType<float>()) {
if (use_global_stats) InferenceRunWithType<float>();
else TrainingRunWithType<float>();
}
#ifdef WITH_CUDA_FP16
else if (input(0).template IsType<float16>()) {
if (use_global_stats) InferenceRunWithType<float16>();
else TrainingRunWithType<float16>();
}
#endif
else LOG(FATAL) << "Unsupported input types.";
}
DEPLOY_CPU(FusedGroupNorm);
#ifdef WITH_CUDA
DEPLOY_CUDA(FusedGroupNorm);
#endif
OPERATOR_SCHEMA(FusedGroupNorm).NumInputs(5).NumOutputs(1);
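From the TENSOR_FILL calls above, the five inputs of this schema line up as follows (a summary of the code, not a new interface):
// FusedGroupNorm I/O:
//   input(0): X            (N, C, ...)
//   input(1): history_mean (NG)
//   input(2): history_var  (NG)
//   input(3): scale, gamma (C)
//   input(4): bias, beta   (C)
//   output(0): Y, same shape as X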
template <class Context> template <typename T>
void FusedGroupNormGradientOp<Context>::TrainingRunWithType() {
INIT_MULTIPLIER(multiplier, NS);
INIT_MULTIPLIER(num_multiplier, N);
INIT_MULTIPLIER(spatial_multiplier, S);
INIT_MULTIPLIER(cgs_multiplier, CGS);
auto* dYdata = input(-1).template data<T, Context>();
auto* dXdata = output(0)->template mutable_data<T, Context>();
auto* Sdata = input(3).template data<T, Context>();
auto* Std_data = stddev->template mutable_data<T, Context>();
auto* tMean_data = mean->template mutable_data<T, Context>();
auto* tVar_data = var->template mutable_data<T, Context>();
auto* NMul_data = num_multiplier->template data<T, Context>();
auto* SMul_data = spatial_multiplier->template data<T, Context>();
auto* NSMul_data = multiplier->template data<T, Context>();
auto* CGSMul_data = cgs_multiplier->template data<T, Context>();
auto* NC_data = num_by_chans.template mutable_data<T, Context>();
auto* XNorm_data = x_norm->template data<T, Context>();
// gradient w.r.t. scale
if (output(1)->name() != "ignore") {
auto* dSdata = output(1)->template mutable_data<T, Context>();
math::Mul<T, Context>(stddev->count(), XNorm_data, dYdata, Std_data);
if (data_format == "NCHW") {
math::Gemv<T, Context>(CblasNoTrans, NC, S,
1.0, Std_data, SMul_data,
0.0, NC_data);
math::Gemv<T, Context>(CblasTrans, N, C,
1.0, NC_data, NMul_data,
1.0, dSdata);
} else if (data_format == "NHWC") {
math::Gemv<T, Context>(CblasTrans, NS, C,
1.0, Std_data, NSMul_data,
1.0, dSdata);
}
}
// gradient w.r.t. bias
if (output(2)->name() != "ignore") {
auto* dBdata = output(2)->template mutable_data<T, Context>();
if (data_format == "NCHW") {
math::Gemv<T, Context>(CblasNoTrans, NC, S,
1.0, dYdata, SMul_data,
0.0, NC_data);
math::Gemv<T, Context>(CblasTrans, N, C,
1.0, NC_data, NMul_data,
1.0, dBdata);
} else if (data_format == "NHWC") {
math::Gemv<T, Context>(CblasTrans, NS, C,
1.0, dYdata, NSMul_data,
1.0, dBdata);
}
}
// gradient w.r.t. x
if (output(0)->name() != "ignore") {
// scale * dY
if (data_format == "NCHW") {
math::Gemm<T, Context>(CblasNoTrans, CblasNoTrans, N, C, 1,
1.0, NMul_data, Sdata,
0.0, NC_data);
math::Gemm<T, Context>(CblasNoTrans, CblasNoTrans, NC, S, 1,
1.0, NC_data, SMul_data,
0.0, Std_data);
} else if (data_format == "NHWC") {
math::Gemm<T, Context>(CblasNoTrans, CblasNoTrans, NS, C, 1,
1.0, NSMul_data, Sdata,
0.0, Std_data);
}
math::Mul<T, Context>(stddev->count(), Std_data, dYdata, Std_data);
// sum of x_hat * (dl / dx_hat)
math::Mul<T, Context>(stddev->count(), XNorm_data, Std_data, dXdata);
if (data_format == "NCHW") {
math::Gemv<T, Context>(CblasNoTrans, NG, CGS,
1.0, dXdata, CGSMul_data,
0.0, tMean_data);
} else if (data_format == "NHWC") {
NOT_IMPLEMENTED;
}
// x_hat times the sum
if (data_format == "NCHW") {
math::Gemm<T, Context>(CblasNoTrans, CblasNoTrans, NG, CGS, 1,
1.0, tMean_data, CGSMul_data,
0.0, dXdata);
} else if (data_format == "NHWC") {
NOT_IMPLEMENTED;
}
math::Mul<T, Context>(stddev->count(), XNorm_data, dXdata, dXdata);
// subtract the average of x_hat times the sum
if (data_format == "NCHW") {
math::Gemv<T, Context>(CblasNoTrans, NG, CGS,
1.0, Std_data, CGSMul_data,
0.0, tMean_data);
math::Gemm<T, Context>(CblasNoTrans, CblasNoTrans, NG, CGS, 1,
1.0, tMean_data, CGSMul_data,
1.0, dXdata);
} else if (data_format == "NHWC") {
NOT_IMPLEMENTED;
}
math::Axpby<T, Context>(stddev->count(), 1.0, Std_data, -1.0 / CGS, dXdata);
// multiply with the inverse std
if (data_format == "NCHW") {
math::Gemm<T, Context>(CblasNoTrans, CblasNoTrans, NG, CGS, 1,
1.0, tVar_data, CGSMul_data,
0.0, Std_data);
} else if (data_format == "NHWC") {
NOT_IMPLEMENTED;
}
// divide by stddev
math::Div<T, Context>(output(0)->count(), dXdata, Std_data, dXdata);
}
ws()->ReleaseBuffer(stddev);
}
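Read together, the Mul/Gemv/Gemm/Axpby sequence above is the usual normalization backward. With x_hat the stored x_norm, gamma the scale, sigma_g the per-group stddev, and mean_g averaging over the CGS elements of a group, it computes, in LaTeX:
\frac{\partial L}{\partial x} = \frac{1}{\sigma_g} \left( \gamma \frac{\partial L}{\partial y} - \operatorname{mean}_g\!\left( \gamma \frac{\partial L}{\partial y} \right) - \hat{x} \odot \operatorname{mean}_g\!\left( \hat{x} \odot \gamma \frac{\partial L}{\partial y} \right) \right)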
template <class Context> template <typename T>
void FusedGroupNormGradientOp<Context>::InferenceRunWithType() {
INIT_MULTIPLIER(multiplier, NS);
INIT_MULTIPLIER(num_multiplier, N);
INIT_MULTIPLIER(spatial_multiplier, S);
INIT_MULTIPLIER(cgs_multiplier, CGS);
auto* dYdata = input(-1).template data<T, Context>();
auto* Sdata = input(3).template data<T, Context>();
auto* tVar_data = var->template mutable_data<T, Context>();
auto* NMul_data = num_multiplier->template data<T, Context>();
auto* SMul_data = spatial_multiplier->template data<T, Context>();
auto* NSMul_data = multiplier->template data<T, Context>();
auto* CGSMul_data = cgs_multiplier->template data<T, Context>();
auto* NC_data = num_by_chans.template mutable_data<T, Context>();
// gradient w.r.t. scale
if (output(1)->name() != "ignore")
LOG(FATAL) << "The gamma should be fixed if using global stats.";
// gradient w.r.t. bias
if (output(2)->name() != "ignore") {
auto* dBdata = output(2)->template mutable_data<T, Context>();
if (data_format == "NCHW") {
math::Gemv<T, Context>(CblasNoTrans, NC, S,
1.0, dYdata, SMul_data,
0.0, NC_data);
math::Gemv<T, Context>(CblasTrans, N, C,
1.0, NC_data, NMul_data,
1.0, dBdata);
} else if (data_format == "NHWC") {
math::Gemv<T, Context>(CblasTrans, NS, C,
1.0, dYdata, NSMul_data,
1.0, dBdata);
}
}
// gradient w.r.t. x
if (output(0)->name() != "ignore") {
auto* dXdata = output(0)->template mutable_data<T, Context>();
auto* Std_data = stddev->template mutable_data<T, Context>();
// divide scale by stddev
math::Div<T, Context>(var->count(), Sdata, tVar_data, tVar_data);
// compute dE/dY \cdot (scale / std(X))
if (data_format == "NCHW") {
math::Gemm<T, Context>(CblasNoTrans, CblasNoTrans, NG, CGS, 1,
1.0, tVar_data, CGSMul_data,
0.0, Std_data);
} else if (data_format == "NHWC") {
NOT_IMPLEMENTED;
}
math::Mul<T, Context>(output(0)->count(), dYdata, Std_data, dXdata);
}
ws()->ReleaseBuffer(stddev);
}
template <class Context>
void FusedGroupNormGradientOp<Context>::Setup() {
// determine the mode
if (use_stats == -1) use_global_stats = phase() == "TEST" ? true : false;
else use_global_stats = use_stats == 1 ? true : false;
// determine the data format
TIndex channel_axis = axis;
data_format = "NCHW";
if (channel_axis == -1) channel_axis += (int)input(0).ndim();
if (channel_axis + 1 == (int)input(0).ndim()) data_format = "NHWC";
if (input(0).ndim() == 2) data_format = "NCHW";
N = input(0).dim(0);
C = input(0).dim(channel_axis);
CHECK_EQ(C % group, 0) << "\nThe " << C << " channels "
<< "can not be split into " << group << " groups.";
if (group == C && input(0).ndim() == 2) // InstanceNorm
LOG(WARNING) << "The 2d input will output all zeros.";
NC = N * C;
NG = N * group;
S = input(0).count() / NC;
CGS = (C / group) * S;
NS = N * S;
// make resource
mean = ws()->GetTensor("/mnt/" + anchor() + "/gn_mean");
var = ws()->GetTensor("/mnt/" + anchor() + "/gn_var");
x_norm = ws()->GetTensor("/mnt/" + anchor() + "/gn_x_norm");
stddev = ws()->GetBuffer();
stddev->ReshapeLike(input(0));
// reshape
num_by_chans.Reshape(vector<TIndex>(1, NC));
output(0)->ReshapeLike(input(0)); // dX
output(1)->ReshapeLike(input(3)); // dScale
output(2)->ReshapeLike(input(3)); // dBias
}
template <class Context>
void FusedGroupNormGradientOp<Context>::RunOnDevice() {
Setup();
if (input(0).template IsType<float>()) {
if (use_global_stats) InferenceRunWithType<float>();
else TrainingRunWithType<float>();
}
#ifdef WITH_CUDA_FP16
else if (input(0).template IsType<float16>()) {
if (use_global_stats) InferenceRunWithType<float16>();
else TrainingRunWithType<float16>();
}
#endif
else LOG(FATAL) << "Unsupported input types.";
}
template <class Context>
void FusedGroupNormGradientOp<Context>::ShareGradient() {
if (use_global_stats) {
if (output(0)->name() != "ignore") {
Tensor* dX = ws()->GetBuffer("Grad");
ws()->CreateAvatar(output(0), dX);
}
} else {
if (output(0)->name() != "ignore" ||
output(1)->name() != "ignore" ||
output(2)->name() != "ignore") {
Tensor* dX = ws()->GetBuffer("Grad");
ws()->CreateAvatar(output(0), dX);
}
}
}
DEPLOY_CPU(FusedGroupNormGradient);
#ifdef WITH_CUDA
DEPLOY_CUDA(FusedGroupNormGradient);
#endif
OPERATOR_SCHEMA(FusedGroupNormGradient).NumInputs(5).NumOutputs(3);
class GetFusedGroupNormGradient final : public GradientMakerBase {
public:
GRADIENT_MAKER_CTOR(GetFusedGroupNormGradient);
vector<OperatorDef> MakeDefs() override {
return SingleDef(def.type() + "Gradient", "",
vector<string> {I(0), I(1), I(2), I(3), GO(0)},
vector<string> {GI(0), GI(3), GI(4)});
}
};
REGISTER_GRADIENT(FusedGroupNorm, GetFusedGroupNormGradient);
} // namespace dragon
\ No newline at end of file
#include "operators/norm/group_norm_op.h"
#include "core/workspace.h"
#include "utils/math_functions.h"
#include "utils/filler.h"
namespace dragon {
template <class Context> template <typename T>
void GroupNormOp<Context>::TrainingRunWithType() {
INIT_MULTIPLIER(multiplier, NS);
INIT_MULTIPLIER(num_multiplier, N);
INIT_MULTIPLIER(spatial_multiplier, S);
INIT_MULTIPLIER(cgs_multiplier, CGS);
TENSOR_FILL(input(1), vector<TIndex>(1, NG)); // history_mean
TENSOR_FILL(input(2), vector<TIndex>(1, NG)); // history_var
auto* hMean_data = input(1).template mutable_data<T, Context>();
auto* hVar_data = input(2).template mutable_data<T, Context>();
auto* tMean_data = mean.template mutable_data<T, Context>();
auto* tVar_data = var->template mutable_data<T, Context>();
auto* Xdata = input(0).template data<T, Context>();
auto* Ydata = output(0)->template mutable_data<T, Context>();
auto* NMul_data = num_multiplier->template data<T, Context>();
auto* SMul_data = spatial_multiplier->template data<T, Context>();
auto* NSMul_data = multiplier->template data<T, Context>();
auto* CGSMul_data = cgs_multiplier->template data<T, Context>();
auto* NC_data = num_by_chans.template mutable_data<T, Context>();
auto* Std_data = stddev->template mutable_data<T, Context>();
ctx().template Copy<T, Context, Context>(output(0)->count(), Ydata, Xdata);
// compute mean
if (data_format == "NCHW") {
math::Gemv<T, Context>(CblasNoTrans, NG, CGS,
1.0 / CGS, Xdata, CGSMul_data,
0, tMean_data);
} else if (data_format == "NHWC") {
NOT_IMPLEMENTED;
}
// subtract mean
if (data_format == "NCHW") {
math::Gemm<T, Context>(CblasNoTrans, CblasNoTrans, NG, CGS, 1,
-1.0, tMean_data, CGSMul_data,
1.0, Ydata);
} else if (data_format == "NHWC") {
NOT_IMPLEMENTED;
}
// compute variance
// note that we use VAR(X) = E((X - EX) ^ 2)
math::Square<T, Context>(output(0)->count(), Ydata, Std_data);
if (data_format == "NCHW") {
math::Gemv<T, Context>(CblasNoTrans, NG, CGS,
1.0 / CGS, Std_data, CGSMul_data,
0.0, tVar_data);
} else if (data_format == "NHWC") {
NOT_IMPLEMENTED;
}
// compute moving average
if (!is_recomputing) {
if (mode == "CAFFE") {
CHECK_EQ(InputSize(), 4)
<< "\nThe number of inputs should be 4 if use CAFFE mode.";
TENSOR_FILL(input(3), vector<TIndex>(1, 1));
auto* hFact_data = input(3).template mutable_data<T, CPUContext>();
float factor = dragon_cast<float, T>(hFact_data[0]);
factor *= momentum; factor += 1;
hFact_data[0] = dragon_cast<T, float>(factor);
int m = input(0).count() / C;
float coeff = m > 1 ? float(m) / (m - 1) : 1;
// History(X) = Cur(X) + momentum * History(X)
math::Axpby<T, Context>(mean.count(), 1.0, tMean_data, momentum, hMean_data);
math::Axpby<T, Context>(var->count(), coeff, tVar_data, momentum, hVar_data);
} else {
// History(X) = (1 - momentum) * Cur(X) + momentum * History(X)
math::Axpby<T, Context>(mean.count(), 1.0 - momentum, tMean_data, momentum, hMean_data);
math::Axpby<T, Context>(var->count(), 1.0 - momentum, tVar_data, momentum, hVar_data);
}
}
// compute stddev
math::AddScalar<T, Context>(var->count(), eps, tVar_data);
math::Sqrt<T, Context>(var->count(), tVar_data, tVar_data);
// divide by stddev
if (data_format == "NCHW") {
math::Gemm<T, Context>(CblasNoTrans, CblasNoTrans, NG, CGS, 1,
1.0, tVar_data, CGSMul_data,
0.0, Std_data);
} else if (data_format == "NHWC") {
NOT_IMPLEMENTED;
}
math::Div<T, Context>(output(0)->count(), Ydata, Std_data, Ydata);
ws()->ReleaseBuffer(stddev);
}
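The CAFFE branch keeps un-normalized statistics plus a running factor (and applies the m / (m - 1) Bessel correction to the variance); the (history, factor) pair is read back as history / factor, which is exactly the scale = 1 / factor step in the inference path below. A standalone sketch of that accumulation scheme (hypothetical helper, not part of the commit):
// CAFFE-style running average: history <- cur + momentum * history,
// factor <- 1 + momentum * factor; the estimate is history / factor.
struct RunningStat {
    float history = 0.f, factor = 0.f;
    void Update(float cur, float momentum) {
        history = cur + momentum * history;
        factor = 1.f + momentum * factor;
    }
    float Value() const { return factor == 0.f ? 0.f : history / factor; }
};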
template <class Context> template <typename T>
void GroupNormOp<Context>::InferenceRunWithType() {
INIT_MULTIPLIER(multiplier, NS);
INIT_MULTIPLIER(num_multiplier, N);
INIT_MULTIPLIER(spatial_multiplier, S);
INIT_MULTIPLIER(cgs_multiplier, CGS);
TENSOR_FILL(input(1), vector<TIndex>(1, NG)); // history_mean
TENSOR_FILL(input(2), vector<TIndex>(1, NG)); // history_var
auto* hMean_data = input(1).template mutable_data<T, Context>();
auto* hVar_data = input(2).template mutable_data<T, Context>();
auto* tMean_data = mean.template mutable_data<T, Context>();
auto* tVar_data = var->template mutable_data<T, Context>();
auto* Xdata = input(0).template data<T, Context>();
auto* Ydata = output(0)->template mutable_data<T, Context>();
auto* NMul_data = num_multiplier->template data<T, Context>();
auto* SMul_data = spatial_multiplier->template data<T, Context>();
auto* NSMul_data = multiplier->template data<T, Context>();
auto* CGSMul_data = cgs_multiplier->template data<T, Context>();
auto* NC_data = num_by_chans.template mutable_data<T, Context>();
auto* Std_data = stddev->template mutable_data<T, Context>();
ctx().template Copy<T, Context, Context>(input(0).count(), Ydata, Xdata);
// scale the mean and variance if necessary
if (mode == "CAFFE") {
CHECK_EQ(InputSize(), 4)
<< "\nThe number of inputs should be 4 if use CAFFE mode.";
TENSOR_FILL(input(3), vector<TIndex>(1, 1));
auto* hFact_data = input(3).template mutable_data<T, CPUContext>();
const float factor = dragon_cast<float, T>(hFact_data[0]);
const float scale = factor == 0 ? 0 : 1.0 / factor;
math::Scale<T, Context>(mean.count(), scale, hMean_data, tMean_data);
math::Scale<T, Context>(var->count(), scale, hVar_data, tVar_data);
} else {
ctx().template Copy<T, Context, Context>(mean.count(), tMean_data, hMean_data);
ctx().template Copy<T, Context, Context>(var->count(), tVar_data, hVar_data);
}
// subtract mean
if (data_format == "NCHW") {
math::Gemm<T, Context>(CblasNoTrans, CblasNoTrans, NG, CGS, 1,
-1.0, tMean_data, CGSMul_data,
1.0, Ydata);
} else if (data_format == "NHWC") {
NOT_IMPLEMENTED;
}
// compute stddev
math::AddScalar<T, Context>(var->count(), eps, tVar_data);
math::Sqrt<T, Context>(var->count(), tVar_data, tVar_data);
// divide by stddev
if (data_format == "NCHW") {
math::Gemm<T, Context>(CblasNoTrans, CblasNoTrans, NG, CGS, 1,
1.0, tVar_data, CGSMul_data,
0.0, Std_data);
} else if (data_format == "NHWC") {
NOT_IMPLEMENTED;
}
math::Div<T, Context>(output(0)->count(), Ydata, Std_data, Ydata);
ws()->ReleaseBuffer(stddev);
}
template <class Context>
void GroupNormOp<Context>::Setup() {
// determine the mode
if (use_stats == -1) use_global_stats = phase() == "TEST" ? true : false;
else use_global_stats = use_stats == 1 ? true : false;
is_recomputing = ws()->GetTensor("/opt/mirror_stage/recompute_flag")
->template data<bool, CPUContext>()[0];
// determine the data format
TIndex channel_axis = axis;
data_format = "NCHW";
if (channel_axis == -1) channel_axis += (int)input(0).ndim();
if (channel_axis + 1 == (int)input(0).ndim()) data_format = "NHWC";
if (input(0).ndim() == 2) data_format = "NCHW";
N = input(0).dim(0);
C = input(0).dim(channel_axis);
CHECK_EQ(C % group, 0) << "\nThe " << C << " channels "
<< "can not be split into " << group << " groups.";
if (group == C && input(0).ndim() == 2) // InstanceNorm
LOG(WARNING) << "The 2d input will output all zeros.";
NC = N * C;
NG = N * group;
S = input(0).count() / NC;
CGS = (C / group) * S;
NS = N * S;
// make resource
var = ws()->CreateTensor("/mnt/" + anchor() + "/gn_var");
stddev = ws()->GetBuffer();
stddev->ReshapeLike(input(0));
// reshape
mean.Reshape(vector<TIndex>(1, NG));
var->Reshape(vector<TIndex>(1, NG));
num_by_chans.Reshape(vector<TIndex>(1, NC));
output(0)->ReshapeLike(input(0));
}
template <class Context>
void GroupNormOp<Context>::RunOnDevice() {
Setup();
if (input(0).template IsType<float>()) {
if (use_global_stats) InferenceRunWithType<float>();
else TrainingRunWithType<float>();
}
#ifdef WITH_CUDA_FP16
else if (input(0).template IsType<float16>()) {
if (use_global_stats) InferenceRunWithType<float16>();
else TrainingRunWithType<float16>();
}
#endif
else LOG(FATAL) << "Unsupported input types.";
}
DEPLOY_CPU(GroupNorm);
#ifdef WITH_CUDA
DEPLOY_CUDA(GroupNorm);
#endif
OPERATOR_SCHEMA(GroupNorm).NumInputs(3, 4).NumOutputs(1);
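For this affine-free variant, the inputs line up as (a summary of the code above):
// GroupNorm I/O:
//   input(0): X            (N, C, ...)
//   input(1): history_mean (NG)
//   input(2): history_var  (NG)
//   input(3): factor       (1)   // only in "CAFFE" mode, hence NumInputs(3, 4)
//   output(0): Y, same shape as X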
template <class Context> template <typename T>
void GroupNormGradientOp<Context>::TrainingRunWithType() {
INIT_MULTIPLIER(multiplier, NS);
INIT_MULTIPLIER(num_multiplier, N);
INIT_MULTIPLIER(spatial_multiplier, S);
INIT_MULTIPLIER(cgs_multiplier, CGS);
auto* dYdata = input(-1).template data<T, Context>();
auto* dXdata = output(0)->template mutable_data<T, Context>();
auto* Std_data = stddev->template mutable_data<T, Context>();
auto* tVar_data = var->template mutable_data<T, Context>();
auto* NMul_data = num_multiplier->template data<T, Context>();
auto* SMul_data = spatial_multiplier->template data<T, Context>();
auto* NSMul_data = multiplier->template data<T, Context>();
auto* CGSMul_data = cgs_multiplier->template data<T, Context>();
auto* NC_data = num_by_chans.template mutable_data<T, Context>();
if (data_format == "NCHW") {
math::Gemm<T, Context>(CblasNoTrans, CblasNoTrans, NG, CGS, 1,
1.0, tVar_data, CGSMul_data,
0.0, Std_data);
} else if (data_format == "NHWC") {
NOT_IMPLEMENTED;
}
auto* Ydata = input(1).template data<T, Context>();
math::Mul<T, Context>(output(0)->count(), Ydata, dYdata, dXdata);
// sum(dE/dY \cdot Y)
if (data_format == "NCHW") {
math::Gemv<T, Context>(CblasNoTrans, NG, CGS,
1.0, dXdata, CGSMul_data,
0.0, tVar_data);
math::Gemm<T, Context>(CblasNoTrans, CblasNoTrans, NG, CGS, 1,
1.0, tVar_data, CGSMul_data,
0.0, dXdata);
} else if (data_format == "NHWC") {
NOT_IMPLEMENTED;
}
// sum(dE/dY \cdot Y) \cdot Y
math::Mul<T, Context>(output(0)->count(), Ydata, dXdata, dXdata);
// sum(dE/dY) + sum(dE/dY \cdot Y) \cdot Y
if (data_format == "NCHW") {
math::Gemv<T, Context>(CblasNoTrans, NG, CGS,
1.0, dYdata, CGSMul_data,
0.0, tVar_data);
math::Gemm<T, Context>(CblasNoTrans, CblasNoTrans, NG, CGS, 1,
1.0, tVar_data, CGSMul_data,
1.0, dXdata);
} else if (data_format == "NHWC") {
NOT_IMPLEMENTED;
}
// dE/dY - mean(dE/dY) - mean(dE/dY \cdot Y) \cdot Y
// = dE/dY - (1 / CGS) * (sum(dE/dY) + sum(dE/dY \cdot Y) \cdot Y)
math::Axpby<T, Context>(output(0)->count(), 1.0, dYdata, -1.0 / CGS, dXdata);
// divide by stddev
math::Div<T, Context>(output(0)->count(), dXdata, Std_data, dXdata);
ws()->ReleaseBuffer(stddev);
}
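Since this variant has no scale/bias, the backward is the fused formula with gamma = 1, and the saved output y (= x_norm) taken from input(1):
\frac{\partial L}{\partial x} = \frac{1}{\sigma_g} \left( \frac{\partial L}{\partial y} - \operatorname{mean}_g\!\left( \frac{\partial L}{\partial y} \right) - y \odot \operatorname{mean}_g\!\left( y \odot \frac{\partial L}{\partial y} \right) \right)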
template <class Context> template <typename T>
void GroupNormGradientOp<Context>::InferenceRunWithType() {
INIT_MULTIPLIER(multiplier, NS);
INIT_MULTIPLIER(num_multiplier, N);
INIT_MULTIPLIER(spatial_multiplier, S);
INIT_MULTIPLIER(cgs_multiplier, CGS);
auto* dYdata = input(-1).template data<T, Context>();
auto* dXdata = output(0)->template mutable_data<T, Context>();
auto* Std_data = stddev->template mutable_data<T, Context>();
auto* tVar_data = var->template mutable_data<T, Context>();
auto* NMul_data = num_multiplier->template data<T, Context>();
auto* SMul_data = spatial_multiplier->template data<T, Context>();
auto* NSMul_data = multiplier->template data<T, Context>();
auto* CGSMul_data = cgs_multiplier->template data<T, Context>();
auto* NC_data = num_by_chans.template mutable_data<T, Context>();
if (data_format == "NCHW") {
math::Gemm<T, Context>(CblasNoTrans, CblasNoTrans, NG, CGS, 1,
1.0, tVar_data, CGSMul_data,
0.0, Std_data);
} else if (data_format == "NHWC") {
NOT_IMPLEMENTED;
}
math::Div<T, Context>(output(0)->count(), dYdata, Std_data, dXdata);
ws()->ReleaseBuffer(stddev);
}
template <class Context>
void GroupNormGradientOp<Context>::Setup() {
// determine the mode
if (use_stats == -1) use_global_stats = phase() == "TEST" ? true : false;
else use_global_stats = use_stats == 1 ? true : false;
// determine the data format
TIndex channel_axis = axis;
data_format = "NCHW";
if (channel_axis == -1) channel_axis += (int)input(0).ndim();
if (channel_axis + 1 == (int)input(0).ndim()) data_format = "NHWC";
if (input(0).ndim() == 2) data_format = "NCHW";
N = input(0).dim(0);
C = input(0).dim(channel_axis);
CHECK_EQ(C % group, 0) << "\nThe " << C << " channels "
<< "can not be split into " << group << " groups.";
if (group == C && input(0).ndim() == 2) // InstanceNorm
LOG(WARNING) << "The 2d input will output all zeros.";
NC = N * C;
NG = N * group;
S = input(0).count() / NC;
CGS = (C / group) * S;
NS = N * S;
// make resource
var = ws()->GetTensor("/mnt/" + anchor() + "/gn_var");
stddev = ws()->GetBuffer();
stddev->ReshapeLike(input(0));
// reshape
num_by_chans.Reshape(vector<TIndex>(1, NC));
output(0)->ReshapeLike(input(0));
}
template <class Context>
void GroupNormGradientOp<Context>::RunOnDevice() {
Setup();
if (input(0).template IsType<float>()) {
if (use_global_stats) InferenceRunWithType<float>();
else TrainingRunWithType<float>();
}
#ifdef WITH_CUDA_FP16
else if (input(0).template IsType<float16>()) {
if (use_global_stats) InferenceRunWithType<float16>();
else TrainingRunWithType<float16>();
}
#endif
else LOG(FATAL) << "Unsupported input types.";
}
DEPLOY_CPU(GroupNormGradient);
#ifdef WITH_CUDA
DEPLOY_CUDA(GroupNormGradient);
#endif
OPERATOR_SCHEMA(GroupNormGradient).NumInputs(3).NumOutputs(1);
class GetGroupNormGradient final : public GradientMakerBase {
public:
GRADIENT_MAKER_CTOR(GetGroupNormGradient);
vector<OperatorDef> MakeDefs() override {
return SingleDef(def.type() + "Gradient", "",
vector<string> {I(0), O(0), GO(0)},
vector<string> {GI(0)});
}
};
REGISTER_GRADIENT(GroupNorm, GetGroupNormGradient);
} // namespace dragon
\ No newline at end of file