Commit 36f27485 by Ting PAN

Refer the RoIAlign@Caffe2

1 parent 51056e19
Showing with 858 additions and 540 deletions
......@@ -16,14 +16,13 @@ class SigmoidCrossEntropyOp final : public Operator<Context> {
public:
SigmoidCrossEntropyOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
normalization(OperatorBase::GetSingleArg<string>("normalization", "FULL")) {}
normalization(OperatorBase::GetSingleArg<string>("normalization", "VALID")) {}
void RunOnDevice() override;
template <typename T> void RunWithType();
protected:
Tensor losses;
Tensor* prob;
Tensor valid, losses;
string normalization;
};
......@@ -32,13 +31,13 @@ class SigmoidCrossEntropyGradientOp final : public Operator<Context> {
public:
SigmoidCrossEntropyGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
normalization(OperatorBase::GetSingleArg<string>("normalization", "FULL")) {}
normalization(OperatorBase::GetSingleArg<string>("normalization", "VALID")) {}
void RunOnDevice() override;
template <typename T> void RunWithType();
protected:
Tensor* prob;
Tensor valid;
string normalization;
};
......
......@@ -4,17 +4,17 @@
// Written by Ting Pan
// --------------------------------------------------------
#ifndef DRAGON_OPERATORS_NDARRAY_AT_OP_H_
#define DRAGON_OPERATORS_NDARRAY_AT_OP_H_
#ifndef DRAGON_OPERATORS_NDARRAY_GATHER_OP_H_
#define DRAGON_OPERATORS_NDARRAY_GATHER_OP_H_
#include "core/operator.h"
namespace dragon {
template <class Context>
class AtOp final : public Operator<Context> {
class GatherOp final : public Operator<Context> {
public:
AtOp(const OperatorDef& op_def, Workspace* ws)
GatherOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 0)) {}
......@@ -27,9 +27,9 @@ class AtOp final : public Operator<Context> {
};
template <class Context>
class AtGradientOp final : public Operator<Context> {
class GatherGradientOp final : public Operator<Context> {
public:
AtGradientOp(const OperatorDef& op_def, Workspace* ws)
GatherGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 0)),
acc_grad(OperatorBase::GetSingleArg<bool>("acc_gradient", false)) {}
......@@ -44,4 +44,4 @@ class AtGradientOp final : public Operator<Context> {
} // namespace dragon
#endif // DRAGON_OPERATORS_NDARRAY_AT_OP_H_
\ No newline at end of file
#endif // DRAGON_OPERATORS_NDARRAY_GATHER_OP_H_
\ No newline at end of file
......@@ -18,7 +18,8 @@ class ROIAlignOp : public Operator<Context> {
: Operator<Context>(op_def, ws),
pool_h(OperatorBase::GetSingleArg<int>("pool_h", 0)),
pool_w(OperatorBase::GetSingleArg<int>("pool_w", 0)),
spatial_scale(OperatorBase::GetSingleArg<float>("spatial_scale", 1.0)) {
spatial_scale(OperatorBase::GetSingleArg<float>("spatial_scale", 1.0)),
sampling_ratio(OperatorBase::GetSingleArg<int>("sampling_ratio", 2)) {
CHECK_GT(pool_h, 0) << "\npool_h must > 0";
CHECK_GT(pool_w, 0) << "\npool_w must > 0";
}
......@@ -27,9 +28,8 @@ class ROIAlignOp : public Operator<Context> {
template <typename T> void RunWithType();
protected:
int pool_h, pool_w;
int pool_h, pool_w, sampling_ratio;
float spatial_scale;
Tensor* mask_h, *mask_w;
};
template <class Context>
......@@ -39,7 +39,8 @@ class ROIAlignGradientOp : public Operator<Context> {
: Operator<Context>(op_def, ws),
pool_h(OperatorBase::GetSingleArg<int>("pool_h", 0)),
pool_w(OperatorBase::GetSingleArg<int>("pool_w", 0)),
spatial_scale(OperatorBase::GetSingleArg<float>("spatial_scale", 1.0)) {
spatial_scale(OperatorBase::GetSingleArg<float>("spatial_scale", 1.0)),
sampling_ratio(OperatorBase::GetSingleArg<int>("sampling_ratio", 2)) {
CHECK_GT(pool_h, 0) << "\npool_h must > 0";
CHECK_GT(pool_w, 0) << "\npool_w must > 0";
}
......@@ -48,9 +49,8 @@ class ROIAlignGradientOp : public Operator<Context> {
template <typename T> void RunWithType();
protected:
int pool_h, pool_w;
int pool_h, pool_w, sampling_ratio;
float spatial_scale;
Tensor* mask_h, *mask_w;
};
} // namespace dragon
......
......@@ -198,7 +198,18 @@ void AbsGrad(const int count, const T* dy, T* dx);
/******************** loss.sigmoid_cross_entropy ********************/
template <typename T, class Context>
void SigmoidCrossEntropy(const int count, const T* x, const T* target, T* loss);
void SigmoidCrossEntropy(const int count,
const T* x,
const T* target,
T* loss,
T* valid);
template <typename T, class Context>
void SigmoidCrossEntropyGrad(const int count,
const T* x,
const T* target,
T* dx,
T* valid);
/******************** loss.smooth_l1_loss ********************/
......@@ -312,13 +323,13 @@ void Argmin(const int count,
const T* x,
T* y);
/******************** ndarray.at ********************/
/******************** ndarray.gather ********************/
template <typename T, class Context>
void CanonicalAxis(const int count, const int dim, T* y);
template <typename T, class Context>
void At(const int count,
void Gather(const int count,
const int outer_dim,
const int inner_dim,
const int x_slice_dim,
......@@ -329,7 +340,7 @@ void At(const int count,
Context* ctx);
template <typename T, class Context>
void AtGrad(const int count,
void GatherGrad(const int count,
const int outer_dim,
const int inner_dim,
const int x_slice_dim,
......@@ -791,7 +802,7 @@ void ROIPooling(const float spatial_scale,
const int pool_h,
const int pool_w,
Tensor* x,
Tensor* roi,
Tensor* rois,
Tensor* mask,
Tensor* y);
......@@ -800,7 +811,7 @@ void ROIPoolingGrad(const float spatial_scale,
const int pool_h,
const int pool_w,
Tensor* dy,
Tensor* roi,
Tensor* rois,
Tensor* mask,
Tensor* dx);
......@@ -810,20 +821,18 @@ template <typename T, class Context>
void ROIAlign(const float spatial_scale,
const int pool_h,
const int pool_w,
const int sampling_ratio,
Tensor* x,
Tensor* roi,
Tensor* mask_h,
Tensor* mask_w,
Tensor* rois,
Tensor* y);
template <typename T, class Context>
void ROIAlignGrad(const float spatial_scale,
const int pool_h,
const int pool_w,
const int sampling_ratio,
Tensor* dy,
Tensor* roi,
Tensor* mask_h,
Tensor* mask_w,
Tensor* rois,
Tensor* dx);
} // namespace kernel
......
......@@ -4,16 +4,22 @@
# Written by Ting Pan
# --------------------------------------------------------
import logging
import sys
# core
from dragon.core.tensor import Tensor
import dragon.core.workspace as workspace
try:
from dragon.libdragon import *
except ImportError as e:
logging.critical(
'cannot load dragon. Error: {0}'.format(str(e)))
sys.exit(1)
# ops
from dragon.ops import *
# updaters
from dragon.updaters import *
# theano utilities
from dragon.vm.theano.compile.function import function as function
from dragon.vm.theano.tensor import grad as grad
# scope
from dragon.core.scope import TensorScope as name_scope
from dragon.core.scope import PhaseScope as phase_scope
from dragon.core.scope import DeviceScope as device_scope
......@@ -4,13 +4,18 @@
# Written by Ting Pan
# --------------------------------------------------------
from dragon.__init__ import *
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import sys
import logging
logger = logging.getLogger('dragon')
logger.setLevel(logging.INFO)
logger.addHandler(logging.StreamHandler(sys.stdout))
from dragon.import_c_apis import *
option = {}
REGISTERED_OPERATORS = set(s for s in RegisteredOperatorsCC())
......
......@@ -4,14 +4,20 @@
# Written by Ting Pan
# --------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from collections import defaultdict
from dragon.import_c_apis import *
import dragon.config as config
import dragon.protos.dragon_pb2 as pb
from collections import defaultdict
from dragon.core.utils import MakeOperatorDef
from dragon.__init__ import *
from .scope import GetOperatorName
class GraphGradientMaker(object):
"""
GraphGradientMaker is deigned to generate gradient operators automatically.
......
......@@ -4,11 +4,14 @@
# Written by Ting Pan
# --------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from six.moves import range as xrange
from dragon import MPIInitCC, MPIRankCC, MPISizeCC, \
MPICreateGroupCC, MPIFinalizeCC
from dragon.import_c_apis import *
_is_init = False
_snapshot_ranks = []
......
......@@ -4,7 +4,9 @@
# Written by Ting Pan
# --------------------------------------------------------
from collections import defaultdict
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
_TENSOR_SCOPE = ''
_PHASE_SCOPE = ''
......
......@@ -4,13 +4,18 @@
# Written by Ting Pan
# --------------------------------------------------------
import dragon.core.workspace as ws
import dragon.protos.dragon_pb2 as pb
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from collections import OrderedDict
from six.moves import range as xrange
import dragon.core.workspace as ws
import dragon.protos.dragon_pb2 as pb
from dragon.core.utils import MakeOperatorDef
from dragon.core.scope import GetOperatorName, GetTensorName
from six.moves import range as xrange
class Tensor(object):
......@@ -416,7 +421,7 @@ class Tensor(object):
if not isinstance(item, tuple):
# 1D At
if isinstance(item, int):
output = self.CreateOperator(inputs=[self, wrapper_indices([item])], nout=1, op_type='At')
output = self.CreateOperator(inputs=[self, wrapper_indices([item])], nout=1, op_type='Gather')
if self.shape is not None:
output.shape = self.shape[:]
output.shape[0] = 1
......
......@@ -4,11 +4,16 @@
# Written by Ting Pan
# --------------------------------------------------------
import sys
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import sys
import numpy as np
from google.protobuf.message import Message
from dragon.protos import dragon_pb2 as pb
import numpy as np
if sys.version_info >= (3,0):
def MakeArgument(key, value):
......
......@@ -4,19 +4,25 @@
# Written by Ting Pan
# --------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
try:
import cPickle
except:
import pickle as cPickle
import dragon.core.utils as utils
import dragon.core.mpi as mpi
import dragon.protos.dragon_pb2 as pb
import numpy as np
import os
from dragon import *
import numpy as np
from google.protobuf.message import Message
from six.moves import range as xrange
from dragon.import_c_apis import *
import dragon.core.utils as utils
import dragon.core.mpi as mpi
import dragon.protos.dragon_pb2 as pb
CURRENT_GRAPH_IDX = 0
__all__ = [
......@@ -44,6 +50,7 @@ _DATA_TYPES = {
'float64': np.float64,
}
def _stringify_proto(obj):
"""
Stringify a protobuf structure.
......
......@@ -32,9 +32,9 @@ Vision
=================== ======================================================================
List Brief
=================== ======================================================================
`Conv2D`_ 2D Convolution.
`Deconv2D`_ 2D Deconvolution.
`Pool2D`_ 2D Pooling, MAX or AVG.
`Conv2d`_ 2d Convolution.
`Conv2dTranspose`_ 2d Deconvolution.
`Pool2d`_ 2d Pooling, MAX or AVG.
`ROIPooling`_ ROIPoolin(MAX), introduced by `[Girshick, 2015] <https://arxiv.org/abs/1504.08083>`_.
`ROIAlign`_ ROIAlign(MAX), introduced by `[He et.al, 2017] <https://arxiv.org/abs/1703.06870>`_.
`LRN`_ Local Response Normalization, introduced by `[Krizhevsky et.al, 2012] <http://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networks>`_.
......@@ -122,8 +122,8 @@ NDArray
=============== ======================================================================
List Brief
=============== ======================================================================
`At`_ 1D At interface of NDArray.
`RandomPick`_ 1D RandomPick interface of NDArray.
`Gather`_ Gather the input according to the indices along the given axis.
`RandomPick`_ Randomly pick the input along the given axis.
`Reduce`_ The general reduce operator.
`Sum`_ Compute the sum along the given axis.
`Mean`_ Compute the mean along the given axis.
......@@ -195,9 +195,9 @@ List Brief
.. _GlorotUniform: operators/initializer.html#dragon.operators.initializer.GlorotUniform
.. _GlorotNormal: operators/initializer.html#dragon.operators.initializer.GlorotNormal
.. _Conv2D: operators/vision.html#dragon.operators.vision.Conv2D
.. _Deconv2D: operators/vision.html#dragon.operators.vision.Deconv2D
.. _Pool2D: operators/vision.html#dragon.operators.vision.Pool2D
.. _Conv2d: operators/vision.html#dragon.operators.vision.Conv2d
.. _Conv2dTranspose: operators/vision.html#dragon.operators.vision.Conv2dTranspose
.. _Pool2d: operators/vision.html#dragon.operators.vision.Pool2d
.. _ROIPooling: operators/vision.html#dragon.operators.vision.ROIPooling
.. _ROIAlign: operators/vision.html#dragon.operators.vision.ROIAlign
.. _LRN: operators/vision.html#dragon.operators.vision.LRN
......@@ -249,7 +249,7 @@ List Brief
.. _InstanceNorm: operators/norm.html#dragon.operators.norm.InstanceNorm
.. _L2Norm: operators/norm.html#dragon.operators.norm.L2Norm
.. _At: operators/ndarray.html#dragon.operators.ndarray.At
.. _Gather: operators/ndarray.html#dragon.operators.ndarray.Gather
.. _RandomPick: operators/ndarray.html#dragon.operators.ndarray.RandomPick
.. _Crop: operators/ndarray.html#dragon.operators.ndarray.Crop
.. _Reduce: operators/ndarray.html#dragon.operators.ndarray.Reduce
......
......@@ -68,6 +68,7 @@ List Brief
`ReshapeLayer`_ The implementation of ``ReshapeLayer``.
`PermuteLayer`_ The implementation of ``PermuteLayer``.
`FlattenLayer`_ The implementation of ``FlattenLayer``.
`GatherLayer`_ The extended implementation for ``GatherOp``.
`SoftmaxLayer`_ The implementation of ``SoftmaxLayer``.
`ArgMaxLayer`_ The implementation of ``ArgMaxLayer``.
`BatchNormLayer`_ The implementation of ``BatchNormLayer``.
......@@ -174,6 +175,7 @@ API Reference
.. _ReshapeLayer: #dragon.vm.caffe.layers.common.ReshapeLayer
.. _PermuteLayer: #dragon.vm.caffe.layers.common.PermuteLayer
.. _FlattenLayer: #dragon.vm.caffe.layers.common.FlattenLayer
.. _GatherLayer: #dragon.vm.caffe.layers.common.GatherLayer
.. _SoftmaxLayer: #dragon.vm.caffe.layers.common.SoftmaxLayer
.. _ArgMaxLayer: #dragon.vm.caffe.layers.common.ArgMaxLayer
.. _BatchNormLayer: #dragon.vm.caffe.layers.common.BatchNormLayer
......
# --------------------------------------------------------
# Dragon
# Copyright(c) 2017 SeetaTech
# Written by Ting Pan
# --------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import sys
import logging
try:
from dragon.libdragon import *
except ImportError as e:
logging.critical(
'cannot load dragon. Error: {0}'.format(str(e)))
sys.exit(1)
\ No newline at end of file
......@@ -4,8 +4,13 @@
# Written by Ting Pan
# --------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from . import *
def Relu(inputs, **kwargs):
"""Rectified Linear Unit function, introduces by `[Nair & Hinton, 2010] <http://www.csri.utoronto.ca/~hinton/absps/reluICML.pdf>`_.
......
......@@ -4,8 +4,13 @@
# Written by Ting Pan
# --------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from . import *
def Add(inputs, **kwargs):
"""Calculate A + B.
......
......@@ -4,8 +4,13 @@
# Written by Ting Pan
# --------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from . import *
def FloatToHalf(inputs, **kwargs):
"""Cast the type of tensor from ``float32`` to ``float16``.
......
......@@ -4,8 +4,13 @@
# Written by Ting Pan
# --------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from . import *
def Copy(inputs, **kwargs):
"""Copy A to B.
......
......@@ -4,11 +4,15 @@
# Written by Ting Pan
# --------------------------------------------------------
import numpy as np
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from dragon.operators.misc import Run
from . import *
def LMDBData(**kwargs):
"""Prefetch Image data with `LMDB`_ database.
......
......@@ -4,6 +4,10 @@
# Written by Ting Pan
# --------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from . import *
......
......@@ -4,10 +4,15 @@
# Written by Ting Pan
# --------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from . import *
def SparseSoftmaxCrossEntropy(inputs, axis=1, normalization='VALID', ignore_labels=(), **kwargs):
"""SoftmaxCrossEntropy with sparse labels.
......@@ -48,7 +53,7 @@ def SparseSoftmaxCrossEntropy(inputs, axis=1, normalization='VALID', ignore_labe
return output
def SigmoidCrossEntropy(inputs, normalization='FULL', **kwargs):
def SigmoidCrossEntropy(inputs, normalization='VALID', **kwargs):
"""SigmoidCrossEntropy with binary labels.
Parameters
......@@ -56,7 +61,7 @@ def SigmoidCrossEntropy(inputs, normalization='FULL', **kwargs):
inputs : list of Tensor
The inputs, represent [input, labels].
normalization : str
The normalization, ``UNIT``, ``FULL``, ``BATCH_SIZE`` or ``NONE``.
The normalization, ``UNIT``, ``FULL``, ``VALID``, ``BATCH_SIZE`` or ``NONE``.
Returns
-------
......
......@@ -4,8 +4,13 @@
# Written by Ting Pan
# --------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from . import *
def Run(inputs, module, op, param_str='', nout=1, **kwargs):
"""Run a custom operator. (Without GradientFlow)
......
......@@ -4,11 +4,17 @@
# Written by Ting Pan
# --------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from six.moves import range as xrange
import dragon.core.mpi as mpi
from . import *
def MPIBroadcast(inputs, root, mpi_ranks=None, **kwargs):
"""Broadcast a tensor to all nodes in the ``MPIGroup``.
......
......@@ -4,21 +4,24 @@
# Written by Ting Pan
# --------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from six.moves import range as xrange
from dragon.core.tensor import GetTensorName
import dragon.core.workspace as ws
from . import *
def At(inputs, indices, axis=0, acc_gradient=False, **kwargs):
"""1D At interface of NDArray.
def Gather(inputs, indices, axis=0, acc_gradient=False, **kwargs):
"""Gather the input according to the indices along the given axis.
Parameters
----------
inputs : Tensor
The input tensor.
indices : list or Tensor
indices : int, list or Tensor
The indices to form output tensor.
axis : int
The start axis.
......@@ -31,29 +34,28 @@ def At(inputs, indices, axis=0, acc_gradient=False, **kwargs):
The output tensor.
"""
CheckInputs(inputs, 1)
arguments = ParseArguments(locals())
arguments['inputs'] = [arguments['inputs'],
Tensor.Convert(indices, dtype='int32')]
arguments['indices'] = None
if isinstance(inputs, list): CheckInputs(inputs, 2)
elif isinstance(inputs, Tensor):
if not isinstance(indices, list):
raise ValueError('The type of indices should be list.')
indices = np.array(indices, dtype=np.float32)
tensor = GetTensorName()
ws.FeedTensor(tensor, indices)
arguments['inputs'] = [arguments['inputs'], Tensor(tensor)]
output = Tensor.CreateOperator(op_type='At', nout=1, **arguments)
output = Tensor.CreateOperator(op_type='Gather', nout=1, **arguments)
if isinstance(inputs, Tensor):
if inputs.shape is not None:
output.shape = inputs.shape[:]
if not isinstance(indices, Tensor):
if not isinstance(indices, (list, tuple)):
indices = [indices]
output.shape[axis] = len(indices)
else:
output.shape[axis] = None
return output
def RandomPick(inputs, max_samples=1, axis=0, **kwargs):
"""1D RandomPick interface of NDArray.
"""Randomly pick the input along the given axis.
Parameters
----------
......@@ -541,8 +543,6 @@ def Pad(inputs, paddings, mode='CONSTANT', value=0, **kwargs):
output = Tensor.CreateOperator(nout=1, op_type='Pad', **arguments)
return output
......
......@@ -4,8 +4,13 @@
# Written by Ting Pan
# --------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from . import *
def BatchNorm(inputs, axis=-1, momentum=0.9, eps=1e-3,
use_stats=-1, mode='DEFAULT', **kwargs):
"""Batch Normalization, introduced by `[Ioffe & Szegedy, 2015] <https://arxiv.org/abs/1502.03167>`_.
......
......@@ -4,8 +4,13 @@
# Written by Ting Pan
# --------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from . import *
def LSTMUnit(c_t_1, gate_input, cont_t=None, **kwargs):
"""Simple LSTMCell module.
......@@ -31,13 +36,3 @@ def LSTMUnit(c_t_1, gate_input, cont_t=None, **kwargs):
arguments['cont_t'] = cont_t.name
return Tensor.CreateOperator(inputs=[c_t_1, gate_input], nout=2,
op_type='LSTMUnit', **arguments)
\ No newline at end of file
......@@ -13,6 +13,7 @@ from six.moves import range as xrange
from . import *
def Conv2d(inputs, num_output, kernel_size,
stride=1, pad=0, dilation=1, group=1,
padding='VALID', data_format='NCHW', **kwargs):
......@@ -327,7 +328,7 @@ def ROIPooling(inputs, pool_h, pool_w, spatial_scale, **kwargs):
return Tensor.CreateOperator(nout=1, op_type='ROIPooling', **arguments)
def ROIAlign(inputs, pool_h=0, pool_w=0, spatial_scale=1.0, **kwargs):
def ROIAlign(inputs, pool_h=0, pool_w=0, spatial_scale=1.0, sampling_ratio=2, **kwargs):
"""Max ROIAlign, introduced by `[He et.al, 2017] <https://arxiv.org/abs/1703.06870>`_.
The first dimension of input must be ``1``.
......@@ -342,6 +343,8 @@ def ROIAlign(inputs, pool_h=0, pool_w=0, spatial_scale=1.0, **kwargs):
The width of pooled tensor.
spatial_scale : float
The ``inverse`` of total down-sampling multiples on input tensor.
sampling_ratio : int
The number of sampling grids for each RoI bin.
Returns
-------
......
......@@ -4,6 +4,10 @@
# Written by Ting Pan
# --------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from .operators import initializer as init
from .operators import vision
from .operators import loss
......@@ -92,7 +96,7 @@ InstanceNorm = norm.InstanceNorm
L2Norm = norm.L2Norm
# ndarray
At = ndarray.At
Gather = ndarray.Gather
RandomPick = ndarray.RandomPick
Crop = ndarray.Crop
Reduce = ndarray.Reduce
......
......@@ -4,11 +4,17 @@
# Written by Ting Pan
# --------------------------------------------------------
import numpy as np
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import pprint
import numpy as np
import dragon.core.workspace as ws
from dragon.core.tensor import Tensor
class BaseUpdater(object):
"""
BaseUpdater is designed to preprocess the gradients.
......
......@@ -50,6 +50,7 @@ from .common import InnerProductLayer, \
ArgMaxLayer, \
PermuteLayer, \
FlattenLayer, \
GatherLayer, \
ConcatLayer, \
NormalizeLayer, \
InstanceNormLayer, \
......
......@@ -266,6 +266,25 @@ class FlattenLayer(Layer):
return ops.Flatten(input, **self._param)
class GatherLayer(Layer):
"""The extended implementation of ``GatherOp``.
Parameters
----------
axis : int
The axis for gathering. Refer `GatherParameter.axis`_.
"""
def __init__(self, LayerParameter):
super(GatherLayer, self).__init__(LayerParameter)
param = LayerParameter.gather_param
self._param = {'axis': param.axis}
def Setup(self, bottom):
super(GatherLayer, self).Setup(bottom)
return ops.Gather(bottom[0], indices=bottom[1], **self._param)
class SoftmaxLayer(Layer):
"""The implementation of ``SoftmaxLayer``.
......
......@@ -57,12 +57,12 @@ class SigmoidCrossEntropyLossLayer(Layer):
def __init__(self, LayerParameter):
super(SigmoidCrossEntropyLossLayer, self).__init__(LayerParameter)
param = LayerParameter.loss_param
norm_mode = {0: 'FULL', 1: 'BATCH_SIZE', 2: 'BATCH_SIZE', 3: 'NONE', 4: 'UNIT'}
normalization = 'BATCH_SIZE'
norm_mode = {0: 'FULL', 1: 'VALID', 2: 'BATCH_SIZE', 3: 'NONE', 4: 'UNIT'}
normalization = 'VALID'
if param.HasField('normalize'):
if param.normalize: normalization = 'FULL'
if not param.normalize: normalization = 'BATCH_SIZE'
else: normalization = norm_mode[param.normalization]
self._param = { 'normalization': normalization }
self._param = {'normalization': normalization}
def Setup(self, bottom):
super(SigmoidCrossEntropyLossLayer, self).Setup(bottom)
......
......@@ -422,6 +422,7 @@ message LayerParameter {
optional BatchRenormParameter batch_renorm_param = 161;
optional DenseConcatParameter dense_concat_param = 163;
optional FocalLossParameter focal_loss_param = 164;
optional GatherParameter gather_param = 165;
}
// Message that stores parameters used to apply transformation
......@@ -1504,3 +1505,7 @@ message FocalLossParameter {
optional int32 neg_id = 4 [default = -1];
}
message GatherParameter {
optional int32 axis = 1 [default = 0];
}
......@@ -19,7 +19,7 @@ _sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor.FileDescriptor(
name='caffe.proto',
package='caffe',
serialized_pb=_b('\n\x0b\x63\x61\x66\x66\x65.proto\x12\x05\x63\x61\x66\x66\x65\"\x1c\n\tBlobShape\x12\x0f\n\x03\x64im\x18\x01 \x03(\x03\x42\x02\x10\x01\"\xcc\x01\n\tBlobProto\x12\x1f\n\x05shape\x18\x07 \x01(\x0b\x32\x10.caffe.BlobShape\x12\x10\n\x04\x64\x61ta\x18\x05 \x03(\x02\x42\x02\x10\x01\x12\x10\n\x04\x64iff\x18\x06 \x03(\x02\x42\x02\x10\x01\x12\x17\n\x0b\x64ouble_data\x18\x08 \x03(\x01\x42\x02\x10\x01\x12\x17\n\x0b\x64ouble_diff\x18\t \x03(\x01\x42\x02\x10\x01\x12\x0e\n\x03num\x18\x01 \x01(\x05:\x01\x30\x12\x13\n\x08\x63hannels\x18\x02 \x01(\x05:\x01\x30\x12\x11\n\x06height\x18\x03 \x01(\x05:\x01\x30\x12\x10\n\x05width\x18\x04 \x01(\x05:\x01\x30\"2\n\x0f\x42lobProtoVector\x12\x1f\n\x05\x62lobs\x18\x01 \x03(\x0b\x32\x10.caffe.BlobProto\"\x81\x01\n\x05\x44\x61tum\x12\x10\n\x08\x63hannels\x18\x01 \x01(\x05\x12\x0e\n\x06height\x18\x02 \x01(\x05\x12\r\n\x05width\x18\x03 \x01(\x05\x12\x0c\n\x04\x64\x61ta\x18\x04 \x01(\x0c\x12\r\n\x05label\x18\x05 \x01(\x05\x12\x12\n\nfloat_data\x18\x06 \x03(\x02\x12\x16\n\x07\x65ncoded\x18\x07 \x01(\x08:\x05\x66\x61lse\"\x8a\x02\n\x0f\x46illerParameter\x12\x16\n\x04type\x18\x01 \x01(\t:\x08\x63onstant\x12\x10\n\x05value\x18\x02 \x01(\x02:\x01\x30\x12\x0e\n\x03min\x18\x03 \x01(\x02:\x01\x30\x12\x0e\n\x03max\x18\x04 \x01(\x02:\x01\x31\x12\x0f\n\x04mean\x18\x05 \x01(\x02:\x01\x30\x12\x0e\n\x03std\x18\x06 \x01(\x02:\x01\x31\x12\x12\n\x06sparse\x18\x07 \x01(\x05:\x02-1\x12\x42\n\rvariance_norm\x18\x08 \x01(\x0e\x32#.caffe.FillerParameter.VarianceNorm:\x06\x46\x41N_IN\"4\n\x0cVarianceNorm\x12\n\n\x06\x46\x41N_IN\x10\x00\x12\x0b\n\x07\x46\x41N_OUT\x10\x01\x12\x0b\n\x07\x41VERAGE\x10\x02\"\x8e\x02\n\x0cNetParameter\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\r\n\x05input\x18\x03 \x03(\t\x12%\n\x0binput_shape\x18\x08 \x03(\x0b\x32\x10.caffe.BlobShape\x12\x11\n\tinput_dim\x18\x04 \x03(\x05\x12\x1d\n\x0e\x66orce_backward\x18\x05 \x01(\x08:\x05\x66\x61lse\x12\x1e\n\x05state\x18\x06 \x01(\x0b\x32\x0f.caffe.NetState\x12\x19\n\ndebug_info\x18\x07 \x01(\x08:\x05\x66\x61lse\x12$\n\x05layer\x18\x64 \x03(\x0b\x32\x15.caffe.LayerParameter\x12\'\n\x06layers\x18\x02 \x03(\x0b\x32\x17.caffe.V1LayerParameter\"\xc9\n\n\x0fSolverParameter\x12\x0b\n\x03net\x18\x18 \x01(\t\x12&\n\tnet_param\x18\x19 \x01(\x0b\x32\x13.caffe.NetParameter\x12\x11\n\ttrain_net\x18\x01 \x01(\t\x12\x10\n\x08test_net\x18\x02 \x03(\t\x12,\n\x0ftrain_net_param\x18\x15 \x01(\x0b\x32\x13.caffe.NetParameter\x12+\n\x0etest_net_param\x18\x16 \x03(\x0b\x32\x13.caffe.NetParameter\x12$\n\x0btrain_state\x18\x1a \x01(\x0b\x32\x0f.caffe.NetState\x12#\n\ntest_state\x18\x1b \x03(\x0b\x32\x0f.caffe.NetState\x12\x11\n\ttest_iter\x18\x03 \x03(\x05\x12\x18\n\rtest_interval\x18\x04 \x01(\x05:\x01\x30\x12 \n\x11test_compute_loss\x18\x13 \x01(\x08:\x05\x66\x61lse\x12!\n\x13test_initialization\x18 \x01(\x08:\x04true\x12\x0f\n\x07\x62\x61se_lr\x18\x05 \x01(\x02\x12\x10\n\x08stage_lr\x18\x32 \x03(\x02\x12\x12\n\nstage_iter\x18\x33 \x03(\x05\x12\x0f\n\x07\x64isplay\x18\x06 \x01(\x05\x12\x17\n\x0c\x61verage_loss\x18! \x01(\x05:\x01\x31\x12\x10\n\x08max_iter\x18\x07 \x01(\x05\x12\x14\n\titer_size\x18$ \x01(\x05:\x01\x31\x12\x11\n\tlr_policy\x18\x08 \x01(\t\x12\r\n\x05gamma\x18\t \x01(\x02\x12\r\n\x05power\x18\n \x01(\x02\x12\x10\n\x08momentum\x18\x0b \x01(\x02\x12\x14\n\x0cweight_decay\x18\x0c \x01(\x02\x12\x1f\n\x13regularization_type\x18\x1d \x01(\t:\x02L2\x12\x10\n\x08stepsize\x18\r \x01(\x05\x12\x11\n\tstepvalue\x18\" \x03(\x05\x12\x1a\n\x0e\x63lip_gradients\x18# \x01(\x02:\x02-1\x12\x13\n\x08snapshot\x18\x0e \x01(\x05:\x01\x30\x12\x17\n\x0fsnapshot_prefix\x18\x0f \x01(\t\x12\x1c\n\rsnapshot_diff\x18\x10 \x01(\x08:\x05\x66\x61lse\x12K\n\x0fsnapshot_format\x18% \x01(\x0e\x32%.caffe.SolverParameter.SnapshotFormat:\x0b\x42INARYPROTO\x12;\n\x0bsolver_mode\x18\x11 \x01(\x0e\x32!.caffe.SolverParameter.SolverMode:\x03GPU\x12\x14\n\tdevice_id\x18\x12 \x01(\x05:\x01\x30\x12\x17\n\x0brandom_seed\x18\x14 \x01(\x03:\x02-1\x12\x11\n\x04type\x18( \x01(\t:\x03SGD\x12\x15\n\x05\x64\x65lta\x18\x1f \x01(\x02:\x06\x31\x65-008\x12\x18\n\tmomentum2\x18\' \x01(\x02:\x05\x30.999\x12\x17\n\trms_decay\x18& \x01(\x02:\x04\x30.99\x12\x19\n\ndebug_info\x18\x17 \x01(\x08:\x05\x66\x61lse\x12\"\n\x14snapshot_after_train\x18\x1c \x01(\x08:\x04true\x12;\n\x0bsolver_type\x18\x1e \x01(\x0e\x32!.caffe.SolverParameter.SolverType:\x03SGD\"+\n\x0eSnapshotFormat\x12\x08\n\x04HDF5\x10\x00\x12\x0f\n\x0b\x42INARYPROTO\x10\x01\"\x1e\n\nSolverMode\x12\x07\n\x03\x43PU\x10\x00\x12\x07\n\x03GPU\x10\x01\"U\n\nSolverType\x12\x07\n\x03SGD\x10\x00\x12\x0c\n\x08NESTEROV\x10\x01\x12\x0b\n\x07\x41\x44\x41GRAD\x10\x02\x12\x0b\n\x07RMSPROP\x10\x03\x12\x0c\n\x08\x41\x44\x41\x44\x45LTA\x10\x04\x12\x08\n\x04\x41\x44\x41M\x10\x05\"l\n\x0bSolverState\x12\x0c\n\x04iter\x18\x01 \x01(\x05\x12\x13\n\x0blearned_net\x18\x02 \x01(\t\x12!\n\x07history\x18\x03 \x03(\x0b\x32\x10.caffe.BlobProto\x12\x17\n\x0c\x63urrent_step\x18\x04 \x01(\x05:\x01\x30\"N\n\x08NetState\x12!\n\x05phase\x18\x01 \x01(\x0e\x32\x0c.caffe.Phase:\x04TEST\x12\x10\n\x05level\x18\x02 \x01(\x05:\x01\x30\x12\r\n\x05stage\x18\x03 \x03(\t\"\x85\x01\n\x0cNetStateRule\x12\x1b\n\x05phase\x18\x01 \x01(\x0e\x32\x0c.caffe.Phase\x12\x11\n\tmin_level\x18\x02 \x01(\x05\x12\x11\n\tmax_level\x18\x03 \x01(\x05\x12\r\n\x05stage\x18\x04 \x03(\t\x12\x11\n\tnot_stage\x18\x05 \x03(\t\x12\x10\n\x08mpi_rank\x18\x06 \x03(\r\"\xa3\x01\n\tParamSpec\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x31\n\nshare_mode\x18\x02 \x01(\x0e\x32\x1d.caffe.ParamSpec.DimCheckMode\x12\x12\n\x07lr_mult\x18\x03 \x01(\x02:\x01\x31\x12\x15\n\ndecay_mult\x18\x04 \x01(\x02:\x01\x31\"*\n\x0c\x44imCheckMode\x12\n\n\x06STRICT\x10\x00\x12\x0e\n\nPERMISSIVE\x10\x01\"\xe6\x18\n\x0eLayerParameter\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0c\n\x04type\x18\x02 \x01(\t\x12\x0e\n\x06\x62ottom\x18\x03 \x03(\t\x12\x0b\n\x03top\x18\x04 \x03(\t\x12\x1c\n\x0cmirror_stage\x18\xa2\x01 \x01(\x08:\x05\x66\x61lse\x12\x1b\n\x05phase\x18\n \x01(\x0e\x32\x0c.caffe.Phase\x12\x13\n\x0bloss_weight\x18\x05 \x03(\x02\x12\x1f\n\x05param\x18\x06 \x03(\x0b\x32\x10.caffe.ParamSpec\x12\x1f\n\x05\x62lobs\x18\x07 \x03(\x0b\x32\x10.caffe.BlobProto\x12\x16\n\x0epropagate_down\x18\x0b \x03(\x08\x12$\n\x07include\x18\x08 \x03(\x0b\x32\x13.caffe.NetStateRule\x12$\n\x07\x65xclude\x18\t \x03(\x0b\x32\x13.caffe.NetStateRule\x12\x37\n\x0ftransform_param\x18\x64 \x01(\x0b\x32\x1e.caffe.TransformationParameter\x12(\n\nloss_param\x18\x65 \x01(\x0b\x32\x14.caffe.LossParameter\x12\x30\n\x0e\x61\x63\x63uracy_param\x18\x66 \x01(\x0b\x32\x18.caffe.AccuracyParameter\x12,\n\x0c\x61rgmax_param\x18g \x01(\x0b\x32\x16.caffe.ArgMaxParameter\x12\x34\n\x10\x62\x61tch_norm_param\x18\x8b\x01 \x01(\x0b\x32\x19.caffe.BatchNormParameter\x12)\n\nbias_param\x18\x8d\x01 \x01(\x0b\x32\x14.caffe.BiasParameter\x12,\n\x0c\x63oncat_param\x18h \x01(\x0b\x32\x16.caffe.ConcatParameter\x12?\n\x16\x63ontrastive_loss_param\x18i \x01(\x0b\x32\x1f.caffe.ContrastiveLossParameter\x12\x36\n\x11\x63onvolution_param\x18j \x01(\x0b\x32\x1b.caffe.ConvolutionParameter\x12)\n\ncrop_param\x18\x90\x01 \x01(\x0b\x32\x14.caffe.CropParameter\x12(\n\ndata_param\x18k \x01(\x0b\x32\x14.caffe.DataParameter\x12.\n\rdropout_param\x18l \x01(\x0b\x32\x17.caffe.DropoutParameter\x12\x33\n\x10\x64ummy_data_param\x18m \x01(\x0b\x32\x19.caffe.DummyDataParameter\x12.\n\reltwise_param\x18n \x01(\x0b\x32\x17.caffe.EltwiseParameter\x12\'\n\telu_param\x18\x8c\x01 \x01(\x0b\x32\x13.caffe.ELUParameter\x12+\n\x0b\x65mbed_param\x18\x89\x01 \x01(\x0b\x32\x15.caffe.EmbedParameter\x12&\n\texp_param\x18o \x01(\x0b\x32\x13.caffe.ExpParameter\x12/\n\rflatten_param\x18\x87\x01 \x01(\x0b\x32\x17.caffe.FlattenParameter\x12\x31\n\x0fhdf5_data_param\x18p \x01(\x0b\x32\x18.caffe.HDF5DataParameter\x12\x35\n\x11hdf5_output_param\x18q \x01(\x0b\x32\x1a.caffe.HDF5OutputParameter\x12\x33\n\x10hinge_loss_param\x18r \x01(\x0b\x32\x19.caffe.HingeLossParameter\x12\x33\n\x10image_data_param\x18s \x01(\x0b\x32\x19.caffe.ImageDataParameter\x12\x39\n\x13infogain_loss_param\x18t \x01(\x0b\x32\x1c.caffe.InfogainLossParameter\x12\x39\n\x13inner_product_param\x18u \x01(\x0b\x32\x1c.caffe.InnerProductParameter\x12+\n\x0binput_param\x18\x8f\x01 \x01(\x0b\x32\x15.caffe.InputParameter\x12\'\n\tlog_param\x18\x86\x01 \x01(\x0b\x32\x13.caffe.LogParameter\x12&\n\tlrn_param\x18v \x01(\x0b\x32\x13.caffe.LRNParameter\x12\x35\n\x11memory_data_param\x18w \x01(\x0b\x32\x1a.caffe.MemoryDataParameter\x12&\n\tmvn_param\x18x \x01(\x0b\x32\x13.caffe.MVNParameter\x12\x33\n\x0fparameter_param\x18\x91\x01 \x01(\x0b\x32\x19.caffe.ParameterParameter\x12.\n\rpooling_param\x18y \x01(\x0b\x32\x17.caffe.PoolingParameter\x12*\n\x0bpower_param\x18z \x01(\x0b\x32\x15.caffe.PowerParameter\x12+\n\x0bprelu_param\x18\x83\x01 \x01(\x0b\x32\x15.caffe.PReLUParameter\x12-\n\x0cpython_param\x18\x82\x01 \x01(\x0b\x32\x16.caffe.PythonParameter\x12\x33\n\x0freduction_param\x18\x88\x01 \x01(\x0b\x32\x19.caffe.ReductionParameter\x12(\n\nrelu_param\x18{ \x01(\x0b\x32\x14.caffe.ReLUParameter\x12/\n\rreshape_param\x18\x85\x01 \x01(\x0b\x32\x17.caffe.ReshapeParameter\x12+\n\x0bscale_param\x18\x8e\x01 \x01(\x0b\x32\x15.caffe.ScaleParameter\x12.\n\rsigmoid_param\x18| \x01(\x0b\x32\x17.caffe.SigmoidParameter\x12.\n\rsoftmax_param\x18} \x01(\x0b\x32\x17.caffe.SoftmaxParameter\x12\'\n\tspp_param\x18\x84\x01 \x01(\x0b\x32\x13.caffe.SPPParameter\x12*\n\x0bslice_param\x18~ \x01(\x0b\x32\x15.caffe.SliceParameter\x12(\n\ntanh_param\x18\x7f \x01(\x0b\x32\x14.caffe.TanHParameter\x12\x33\n\x0fthreshold_param\x18\x80\x01 \x01(\x0b\x32\x19.caffe.ThresholdParameter\x12)\n\ntile_param\x18\x8a\x01 \x01(\x0b\x32\x14.caffe.TileParameter\x12\x36\n\x11window_data_param\x18\x81\x01 \x01(\x0b\x32\x1a.caffe.WindowDataParameter\x12\x36\n\x11roi_pooling_param\x18\x97\x01 \x01(\x0b\x32\x1a.caffe.ROIPoolingParameter\x12;\n\x14smooth_l1_loss_param\x18\x98\x01 \x01(\x0b\x32\x1c.caffe.SmoothL1LossParameter\x12\'\n\tmpi_param\x18\x99\x01 \x01(\x0b\x32\x13.caffe.MPIParameter\x12/\n\rpermute_param\x18\x9a\x01 \x01(\x0b\x32\x17.caffe.PermuteParameter\x12\x33\n\x0fnormalize_param\x18\x9b\x01 \x01(\x0b\x32\x19.caffe.NormalizeParameter\x12\x31\n\x0eparallel_param\x18\x9d\x01 \x01(\x0b\x32\x18.caffe.ParallelParameter\x12-\n\x0cresize_param\x18\x9e\x01 \x01(\x0b\x32\x16.caffe.ResizeParameter\x12\x36\n\x11\x65xpand_dims_param\x18\x9f\x01 \x01(\x0b\x32\x1a.caffe.ExpandDimsParameter\x12\x31\n\x0eproposal_param\x18\xa0\x01 \x01(\x0b\x32\x18.caffe.ProposalParameter\x12\x38\n\x12\x62\x61tch_renorm_param\x18\xa1\x01 \x01(\x0b\x32\x1b.caffe.BatchRenormParameter\x12\x38\n\x12\x64\x65nse_concat_param\x18\xa3\x01 \x01(\x0b\x32\x1b.caffe.DenseConcatParameter\x12\x34\n\x10\x66ocal_loss_param\x18\xa4\x01 \x01(\x0b\x32\x19.caffe.FocalLossParameter\"\xa7\x02\n\x17TransformationParameter\x12\x10\n\x05scale\x18\x01 \x01(\x02:\x01\x31\x12\x15\n\x06mirror\x18\x02 \x01(\x08:\x05\x66\x61lse\x12\x14\n\tcrop_size\x18\x03 \x01(\r:\x01\x30\x12\x12\n\x07padding\x18\x0b \x01(\r:\x01\x30\x12\x11\n\tmean_file\x18\x04 \x01(\t\x12\x12\n\nmean_value\x18\x05 \x03(\x02\x12\x1a\n\x0b\x66orce_color\x18\x06 \x01(\x08:\x05\x66\x61lse\x12\x19\n\nforce_gray\x18\x07 \x01(\x08:\x05\x66\x61lse\x12!\n\x12\x63olor_augmentation\x18\x08 \x01(\x08:\x05\x66\x61lse\x12\x1b\n\x10min_random_scale\x18\t \x01(\x02:\x01\x31\x12\x1b\n\x10max_random_scale\x18\n \x01(\x02:\x01\x31\"\xf5\x01\n\rLossParameter\x12\x14\n\x0cignore_label\x18\x01 \x01(\x05\x12\x44\n\rnormalization\x18\x03 \x01(\x0e\x32&.caffe.LossParameter.NormalizationMode:\x05VALID\x12\x11\n\tnormalize\x18\x02 \x01(\x08\x1a\'\n\x13\x45xpandDimsParameter\x12\x10\n\x04\x61xis\x18\x01 \x01(\x05:\x02-1\"L\n\x11NormalizationMode\x12\x08\n\x04\x46ULL\x10\x00\x12\t\n\x05VALID\x10\x01\x12\x0e\n\nBATCH_SIZE\x10\x02\x12\x08\n\x04NONE\x10\x03\x12\x08\n\x04UNIT\x10\x04\"L\n\x11\x41\x63\x63uracyParameter\x12\x10\n\x05top_k\x18\x01 \x01(\r:\x01\x31\x12\x0f\n\x04\x61xis\x18\x02 \x01(\x05:\x01\x31\x12\x14\n\x0cignore_label\x18\x03 \x01(\x05\"M\n\x0f\x41rgMaxParameter\x12\x1a\n\x0bout_max_val\x18\x01 \x01(\x08:\x05\x66\x61lse\x12\x10\n\x05top_k\x18\x02 \x01(\r:\x01\x31\x12\x0c\n\x04\x61xis\x18\x03 \x01(\x05\"9\n\x0f\x43oncatParameter\x12\x0f\n\x04\x61xis\x18\x02 \x01(\x05:\x01\x31\x12\x15\n\nconcat_dim\x18\x01 \x01(\r:\x01\x31\"h\n\x12\x42\x61tchNormParameter\x12\x18\n\x10use_global_stats\x18\x01 \x01(\x08\x12$\n\x17moving_average_fraction\x18\x02 \x01(\x02:\x03\x30.9\x12\x12\n\x03\x65ps\x18\x03 \x01(\x02:\x05\x30.001\"]\n\rBiasParameter\x12\x0f\n\x04\x61xis\x18\x01 \x01(\x05:\x01\x31\x12\x13\n\x08num_axes\x18\x02 \x01(\x05:\x01\x31\x12&\n\x06\x66iller\x18\x03 \x01(\x0b\x32\x16.caffe.FillerParameter\"L\n\x18\x43ontrastiveLossParameter\x12\x11\n\x06margin\x18\x01 \x01(\x02:\x01\x31\x12\x1d\n\x0elegacy_version\x18\x02 \x01(\x08:\x05\x66\x61lse\"\xfc\x03\n\x14\x43onvolutionParameter\x12\x12\n\nnum_output\x18\x01 \x01(\r\x12\x17\n\tbias_term\x18\x02 \x01(\x08:\x04true\x12\x0b\n\x03pad\x18\x03 \x03(\r\x12\x13\n\x0bkernel_size\x18\x04 \x03(\r\x12\x0e\n\x06stride\x18\x06 \x03(\r\x12\x10\n\x08\x64ilation\x18\x12 \x03(\r\x12\x10\n\x05pad_h\x18\t \x01(\r:\x01\x30\x12\x10\n\x05pad_w\x18\n \x01(\r:\x01\x30\x12\x10\n\x08kernel_h\x18\x0b \x01(\r\x12\x10\n\x08kernel_w\x18\x0c \x01(\r\x12\x10\n\x08stride_h\x18\r \x01(\r\x12\x10\n\x08stride_w\x18\x0e \x01(\r\x12\x10\n\x05group\x18\x05 \x01(\r:\x01\x31\x12-\n\rweight_filler\x18\x07 \x01(\x0b\x32\x16.caffe.FillerParameter\x12+\n\x0b\x62ias_filler\x18\x08 \x01(\x0b\x32\x16.caffe.FillerParameter\x12;\n\x06\x65ngine\x18\x0f \x01(\x0e\x32\".caffe.ConvolutionParameter.Engine:\x07\x44\x45\x46\x41ULT\x12\x0f\n\x04\x61xis\x18\x10 \x01(\x05:\x01\x31\x12\x1e\n\x0f\x66orce_nd_im2col\x18\x11 \x01(\x08:\x05\x66\x61lse\"+\n\x06\x45ngine\x12\x0b\n\x07\x44\x45\x46\x41ULT\x10\x00\x12\t\n\x05\x43\x41\x46\x46\x45\x10\x01\x12\t\n\x05\x43UDNN\x10\x02\"0\n\rCropParameter\x12\x0f\n\x04\x61xis\x18\x01 \x01(\x05:\x01\x32\x12\x0e\n\x06offset\x18\x02 \x03(\r\"\xa4\x02\n\rDataParameter\x12\x0e\n\x06source\x18\x01 \x01(\t\x12\x12\n\nbatch_size\x18\x04 \x01(\r\x12\x14\n\trand_skip\x18\x07 \x01(\r:\x01\x30\x12\x31\n\x07\x62\x61\x63kend\x18\x08 \x01(\x0e\x32\x17.caffe.DataParameter.DB:\x07LEVELDB\x12\x10\n\x05scale\x18\x02 \x01(\x02:\x01\x31\x12\x11\n\tmean_file\x18\x03 \x01(\t\x12\x14\n\tcrop_size\x18\x05 \x01(\r:\x01\x30\x12\x15\n\x06mirror\x18\x06 \x01(\x08:\x05\x66\x61lse\x12\"\n\x13\x66orce_encoded_color\x18\t \x01(\x08:\x05\x66\x61lse\x12\x13\n\x08prefetch\x18\n \x01(\r:\x01\x35\"\x1b\n\x02\x44\x42\x12\x0b\n\x07LEVELDB\x10\x00\x12\x08\n\x04LMDB\x10\x01\"I\n\x10\x44ropoutParameter\x12\x1a\n\rdropout_ratio\x18\x01 \x01(\x02:\x03\x30.5\x12\x19\n\x0bscale_train\x18\x02 \x01(\x08:\x04true\"\xa0\x01\n\x12\x44ummyDataParameter\x12+\n\x0b\x64\x61ta_filler\x18\x01 \x03(\x0b\x32\x16.caffe.FillerParameter\x12\x1f\n\x05shape\x18\x06 \x03(\x0b\x32\x10.caffe.BlobShape\x12\x0b\n\x03num\x18\x02 \x03(\r\x12\x10\n\x08\x63hannels\x18\x03 \x03(\r\x12\x0e\n\x06height\x18\x04 \x03(\r\x12\r\n\x05width\x18\x05 \x03(\r\"\xa5\x01\n\x10\x45ltwiseParameter\x12\x39\n\toperation\x18\x01 \x01(\x0e\x32!.caffe.EltwiseParameter.EltwiseOp:\x03SUM\x12\r\n\x05\x63oeff\x18\x02 \x03(\x02\x12\x1e\n\x10stable_prod_grad\x18\x03 \x01(\x08:\x04true\"\'\n\tEltwiseOp\x12\x08\n\x04PROD\x10\x00\x12\x07\n\x03SUM\x10\x01\x12\x07\n\x03MAX\x10\x02\" \n\x0c\x45LUParameter\x12\x10\n\x05\x61lpha\x18\x01 \x01(\x02:\x01\x31\"\xac\x01\n\x0e\x45mbedParameter\x12\x12\n\nnum_output\x18\x01 \x01(\r\x12\x11\n\tinput_dim\x18\x02 \x01(\r\x12\x17\n\tbias_term\x18\x03 \x01(\x08:\x04true\x12-\n\rweight_filler\x18\x04 \x01(\x0b\x32\x16.caffe.FillerParameter\x12+\n\x0b\x62ias_filler\x18\x05 \x01(\x0b\x32\x16.caffe.FillerParameter\"D\n\x0c\x45xpParameter\x12\x10\n\x04\x62\x61se\x18\x01 \x01(\x02:\x02-1\x12\x10\n\x05scale\x18\x02 \x01(\x02:\x01\x31\x12\x10\n\x05shift\x18\x03 \x01(\x02:\x01\x30\"9\n\x10\x46lattenParameter\x12\x0f\n\x04\x61xis\x18\x01 \x01(\x05:\x01\x31\x12\x14\n\x08\x65nd_axis\x18\x02 \x01(\x05:\x02-1\"O\n\x11HDF5DataParameter\x12\x0e\n\x06source\x18\x01 \x01(\t\x12\x12\n\nbatch_size\x18\x02 \x01(\r\x12\x16\n\x07shuffle\x18\x03 \x01(\x08:\x05\x66\x61lse\"(\n\x13HDF5OutputParameter\x12\x11\n\tfile_name\x18\x01 \x01(\t\"^\n\x12HingeLossParameter\x12\x30\n\x04norm\x18\x01 \x01(\x0e\x32\x1e.caffe.HingeLossParameter.Norm:\x02L1\"\x16\n\x04Norm\x12\x06\n\x02L1\x10\x01\x12\x06\n\x02L2\x10\x02\"\x97\x02\n\x12ImageDataParameter\x12\x0e\n\x06source\x18\x01 \x01(\t\x12\x15\n\nbatch_size\x18\x04 \x01(\r:\x01\x31\x12\x14\n\trand_skip\x18\x07 \x01(\r:\x01\x30\x12\x16\n\x07shuffle\x18\x08 \x01(\x08:\x05\x66\x61lse\x12\x15\n\nnew_height\x18\t \x01(\r:\x01\x30\x12\x14\n\tnew_width\x18\n \x01(\r:\x01\x30\x12\x16\n\x08is_color\x18\x0b \x01(\x08:\x04true\x12\x10\n\x05scale\x18\x02 \x01(\x02:\x01\x31\x12\x11\n\tmean_file\x18\x03 \x01(\t\x12\x14\n\tcrop_size\x18\x05 \x01(\r:\x01\x30\x12\x15\n\x06mirror\x18\x06 \x01(\x08:\x05\x66\x61lse\x12\x15\n\x0broot_folder\x18\x0c \x01(\t:\x00\"\'\n\x15InfogainLossParameter\x12\x0e\n\x06source\x18\x01 \x01(\t\"\xcb\x01\n\x15InnerProductParameter\x12\x12\n\nnum_output\x18\x01 \x01(\r\x12\x17\n\tbias_term\x18\x02 \x01(\x08:\x04true\x12-\n\rweight_filler\x18\x03 \x01(\x0b\x32\x16.caffe.FillerParameter\x12+\n\x0b\x62ias_filler\x18\x04 \x01(\x0b\x32\x16.caffe.FillerParameter\x12\x0f\n\x04\x61xis\x18\x05 \x01(\x05:\x01\x31\x12\x18\n\ttranspose\x18\x06 \x01(\x08:\x05\x66\x61lse\"1\n\x0eInputParameter\x12\x1f\n\x05shape\x18\x01 \x03(\x0b\x32\x10.caffe.BlobShape\"D\n\x0cLogParameter\x12\x10\n\x04\x62\x61se\x18\x01 \x01(\x02:\x02-1\x12\x10\n\x05scale\x18\x02 \x01(\x02:\x01\x31\x12\x10\n\x05shift\x18\x03 \x01(\x02:\x01\x30\"\xb8\x02\n\x0cLRNParameter\x12\x15\n\nlocal_size\x18\x01 \x01(\r:\x01\x35\x12\x10\n\x05\x61lpha\x18\x02 \x01(\x02:\x01\x31\x12\x12\n\x04\x62\x65ta\x18\x03 \x01(\x02:\x04\x30.75\x12\x44\n\x0bnorm_region\x18\x04 \x01(\x0e\x32\x1e.caffe.LRNParameter.NormRegion:\x0f\x41\x43ROSS_CHANNELS\x12\x0c\n\x01k\x18\x05 \x01(\x02:\x01\x31\x12\x33\n\x06\x65ngine\x18\x06 \x01(\x0e\x32\x1a.caffe.LRNParameter.Engine:\x07\x44\x45\x46\x41ULT\"5\n\nNormRegion\x12\x13\n\x0f\x41\x43ROSS_CHANNELS\x10\x00\x12\x12\n\x0eWITHIN_CHANNEL\x10\x01\"+\n\x06\x45ngine\x12\x0b\n\x07\x44\x45\x46\x41ULT\x10\x00\x12\t\n\x05\x43\x41\x46\x46\x45\x10\x01\x12\t\n\x05\x43UDNN\x10\x02\"\xbd\x01\n\x13MemoryDataParameter\x12\x12\n\nbatch_size\x18\x01 \x01(\r\x12\x10\n\x08\x63hannels\x18\x02 \x01(\r\x12\x0e\n\x06height\x18\x03 \x01(\r\x12\r\n\x05width\x18\x04 \x01(\r\x12;\n\x05\x64type\x18\x05 \x01(\x0e\x32#.caffe.MemoryDataParameter.DataType:\x07\x46LOAT32\"$\n\x08\x44\x61taType\x12\x0b\n\x07\x46LOAT32\x10\x00\x12\x0b\n\x07\x46LOAT16\x10\x01\"e\n\x0cMVNParameter\x12 \n\x12normalize_variance\x18\x01 \x01(\x08:\x04true\x12\x1e\n\x0f\x61\x63ross_channels\x18\x02 \x01(\x08:\x05\x66\x61lse\x12\x13\n\x03\x65ps\x18\x03 \x01(\x02:\x06\x31\x65-009\"5\n\x12ParameterParameter\x12\x1f\n\x05shape\x18\x01 \x01(\x0b\x32\x10.caffe.BlobShape\"\xa2\x03\n\x10PoolingParameter\x12\x35\n\x04pool\x18\x01 \x01(\x0e\x32\".caffe.PoolingParameter.PoolMethod:\x03MAX\x12\x0e\n\x03pad\x18\x04 \x01(\r:\x01\x30\x12\x10\n\x05pad_h\x18\t \x01(\r:\x01\x30\x12\x10\n\x05pad_w\x18\n \x01(\r:\x01\x30\x12\x13\n\x0bkernel_size\x18\x02 \x01(\r\x12\x10\n\x08kernel_h\x18\x05 \x01(\r\x12\x10\n\x08kernel_w\x18\x06 \x01(\r\x12\x11\n\x06stride\x18\x03 \x01(\r:\x01\x31\x12\x10\n\x08stride_h\x18\x07 \x01(\r\x12\x10\n\x08stride_w\x18\x08 \x01(\r\x12\x37\n\x06\x65ngine\x18\x0b \x01(\x0e\x32\x1e.caffe.PoolingParameter.Engine:\x07\x44\x45\x46\x41ULT\x12\x1d\n\x0eglobal_pooling\x18\x0c \x01(\x08:\x05\x66\x61lse\".\n\nPoolMethod\x12\x07\n\x03MAX\x10\x00\x12\x07\n\x03\x41VE\x10\x01\x12\x0e\n\nSTOCHASTIC\x10\x02\"+\n\x06\x45ngine\x12\x0b\n\x07\x44\x45\x46\x41ULT\x10\x00\x12\t\n\x05\x43\x41\x46\x46\x45\x10\x01\x12\t\n\x05\x43UDNN\x10\x02\"Y\n\x13ROIPoolingParameter\x12\x13\n\x08pooled_h\x18\x01 \x01(\r:\x01\x30\x12\x13\n\x08pooled_w\x18\x02 \x01(\r:\x01\x30\x12\x18\n\rspatial_scale\x18\x03 \x01(\x02:\x01\x31\"F\n\x0ePowerParameter\x12\x10\n\x05power\x18\x01 \x01(\x02:\x01\x31\x12\x10\n\x05scale\x18\x02 \x01(\x02:\x01\x31\x12\x10\n\x05shift\x18\x03 \x01(\x02:\x01\x30\"g\n\x0fPythonParameter\x12\x0e\n\x06module\x18\x01 \x01(\t\x12\r\n\x05layer\x18\x02 \x01(\t\x12\x13\n\tparam_str\x18\x03 \x01(\t:\x00\x12 \n\x11share_in_parallel\x18\x04 \x01(\x08:\x05\x66\x61lse\"\xad\x01\n\x12ReductionParameter\x12=\n\toperation\x18\x01 \x01(\x0e\x32%.caffe.ReductionParameter.ReductionOp:\x03SUM\x12\x0f\n\x04\x61xis\x18\x02 \x01(\x05:\x01\x30\x12\x10\n\x05\x63oeff\x18\x03 \x01(\x02:\x01\x31\"5\n\x0bReductionOp\x12\x07\n\x03SUM\x10\x01\x12\x08\n\x04\x41SUM\x10\x02\x12\t\n\x05SUMSQ\x10\x03\x12\x08\n\x04MEAN\x10\x04\"\x8d\x01\n\rReLUParameter\x12\x19\n\x0enegative_slope\x18\x01 \x01(\x02:\x01\x30\x12\x34\n\x06\x65ngine\x18\x02 \x01(\x0e\x32\x1b.caffe.ReLUParameter.Engine:\x07\x44\x45\x46\x41ULT\"+\n\x06\x45ngine\x12\x0b\n\x07\x44\x45\x46\x41ULT\x10\x00\x12\t\n\x05\x43\x41\x46\x46\x45\x10\x01\x12\t\n\x05\x43UDNN\x10\x02\"Z\n\x10ReshapeParameter\x12\x1f\n\x05shape\x18\x01 \x01(\x0b\x32\x10.caffe.BlobShape\x12\x0f\n\x04\x61xis\x18\x02 \x01(\x05:\x01\x30\x12\x14\n\x08num_axes\x18\x03 \x01(\x05:\x02-1\"\xa5\x01\n\x0eScaleParameter\x12\x0f\n\x04\x61xis\x18\x01 \x01(\x05:\x01\x31\x12\x13\n\x08num_axes\x18\x02 \x01(\x05:\x01\x31\x12&\n\x06\x66iller\x18\x03 \x01(\x0b\x32\x16.caffe.FillerParameter\x12\x18\n\tbias_term\x18\x04 \x01(\x08:\x05\x66\x61lse\x12+\n\x0b\x62ias_filler\x18\x05 \x01(\x0b\x32\x16.caffe.FillerParameter\"x\n\x10SigmoidParameter\x12\x37\n\x06\x65ngine\x18\x01 \x01(\x0e\x32\x1e.caffe.SigmoidParameter.Engine:\x07\x44\x45\x46\x41ULT\"+\n\x06\x45ngine\x12\x0b\n\x07\x44\x45\x46\x41ULT\x10\x00\x12\t\n\x05\x43\x41\x46\x46\x45\x10\x01\x12\t\n\x05\x43UDNN\x10\x02\"L\n\x0eSliceParameter\x12\x0f\n\x04\x61xis\x18\x03 \x01(\x05:\x01\x31\x12\x13\n\x0bslice_point\x18\x02 \x03(\r\x12\x14\n\tslice_dim\x18\x01 \x01(\r:\x01\x31\"\x89\x01\n\x10SoftmaxParameter\x12\x37\n\x06\x65ngine\x18\x01 \x01(\x0e\x32\x1e.caffe.SoftmaxParameter.Engine:\x07\x44\x45\x46\x41ULT\x12\x0f\n\x04\x61xis\x18\x02 \x01(\x05:\x01\x31\"+\n\x06\x45ngine\x12\x0b\n\x07\x44\x45\x46\x41ULT\x10\x00\x12\t\n\x05\x43\x41\x46\x46\x45\x10\x01\x12\t\n\x05\x43UDNN\x10\x02\"r\n\rTanHParameter\x12\x34\n\x06\x65ngine\x18\x01 \x01(\x0e\x32\x1b.caffe.TanHParameter.Engine:\x07\x44\x45\x46\x41ULT\"+\n\x06\x45ngine\x12\x0b\n\x07\x44\x45\x46\x41ULT\x10\x00\x12\t\n\x05\x43\x41\x46\x46\x45\x10\x01\x12\t\n\x05\x43UDNN\x10\x02\"T\n\rTileParameter\x12\x0f\n\x04\x61xis\x18\x01 \x01(\x05:\x01\x31\x12\r\n\x05tiles\x18\x02 \x01(\x05\x12#\n\tmultiples\x18\x03 \x01(\x0b\x32\x10.caffe.BlobShape\"*\n\x12ThresholdParameter\x12\x14\n\tthreshold\x18\x01 \x01(\x02:\x01\x30\"\xc1\x02\n\x13WindowDataParameter\x12\x0e\n\x06source\x18\x01 \x01(\t\x12\x10\n\x05scale\x18\x02 \x01(\x02:\x01\x31\x12\x11\n\tmean_file\x18\x03 \x01(\t\x12\x12\n\nbatch_size\x18\x04 \x01(\r\x12\x14\n\tcrop_size\x18\x05 \x01(\r:\x01\x30\x12\x15\n\x06mirror\x18\x06 \x01(\x08:\x05\x66\x61lse\x12\x19\n\x0c\x66g_threshold\x18\x07 \x01(\x02:\x03\x30.5\x12\x19\n\x0c\x62g_threshold\x18\x08 \x01(\x02:\x03\x30.5\x12\x19\n\x0b\x66g_fraction\x18\t \x01(\x02:\x04\x30.25\x12\x16\n\x0b\x63ontext_pad\x18\n \x01(\r:\x01\x30\x12\x17\n\tcrop_mode\x18\x0b \x01(\t:\x04warp\x12\x1b\n\x0c\x63\x61\x63he_images\x18\x0c \x01(\x08:\x05\x66\x61lse\x12\x15\n\x0broot_folder\x18\r \x01(\t:\x00\"\xeb\x01\n\x0cSPPParameter\x12\x16\n\x0epyramid_height\x18\x01 \x01(\r\x12\x31\n\x04pool\x18\x02 \x01(\x0e\x32\x1e.caffe.SPPParameter.PoolMethod:\x03MAX\x12\x33\n\x06\x65ngine\x18\x06 \x01(\x0e\x32\x1a.caffe.SPPParameter.Engine:\x07\x44\x45\x46\x41ULT\".\n\nPoolMethod\x12\x07\n\x03MAX\x10\x00\x12\x07\n\x03\x41VE\x10\x01\x12\x0e\n\nSTOCHASTIC\x10\x02\"+\n\x06\x45ngine\x12\x0b\n\x07\x44\x45\x46\x41ULT\x10\x00\x12\t\n\x05\x43\x41\x46\x46\x45\x10\x01\x12\t\n\x05\x43UDNN\x10\x02\"\xe0\x13\n\x10V1LayerParameter\x12\x0e\n\x06\x62ottom\x18\x02 \x03(\t\x12\x0b\n\x03top\x18\x03 \x03(\t\x12\x0c\n\x04name\x18\x04 \x01(\t\x12$\n\x07include\x18 \x03(\x0b\x32\x13.caffe.NetStateRule\x12$\n\x07\x65xclude\x18! \x03(\x0b\x32\x13.caffe.NetStateRule\x12/\n\x04type\x18\x05 \x01(\x0e\x32!.caffe.V1LayerParameter.LayerType\x12\x1f\n\x05\x62lobs\x18\x06 \x03(\x0b\x32\x10.caffe.BlobProto\x12\x0e\n\x05param\x18\xe9\x07 \x03(\t\x12>\n\x0f\x62lob_share_mode\x18\xea\x07 \x03(\x0e\x32$.caffe.V1LayerParameter.DimCheckMode\x12\x10\n\x08\x62lobs_lr\x18\x07 \x03(\x02\x12\x14\n\x0cweight_decay\x18\x08 \x03(\x02\x12\x13\n\x0bloss_weight\x18# \x03(\x02\x12\x30\n\x0e\x61\x63\x63uracy_param\x18\x1b \x01(\x0b\x32\x18.caffe.AccuracyParameter\x12,\n\x0c\x61rgmax_param\x18\x17 \x01(\x0b\x32\x16.caffe.ArgMaxParameter\x12,\n\x0c\x63oncat_param\x18\t \x01(\x0b\x32\x16.caffe.ConcatParameter\x12?\n\x16\x63ontrastive_loss_param\x18( \x01(\x0b\x32\x1f.caffe.ContrastiveLossParameter\x12\x36\n\x11\x63onvolution_param\x18\n \x01(\x0b\x32\x1b.caffe.ConvolutionParameter\x12(\n\ndata_param\x18\x0b \x01(\x0b\x32\x14.caffe.DataParameter\x12.\n\rdropout_param\x18\x0c \x01(\x0b\x32\x17.caffe.DropoutParameter\x12\x33\n\x10\x64ummy_data_param\x18\x1a \x01(\x0b\x32\x19.caffe.DummyDataParameter\x12.\n\reltwise_param\x18\x18 \x01(\x0b\x32\x17.caffe.EltwiseParameter\x12&\n\texp_param\x18) \x01(\x0b\x32\x13.caffe.ExpParameter\x12\x31\n\x0fhdf5_data_param\x18\r \x01(\x0b\x32\x18.caffe.HDF5DataParameter\x12\x35\n\x11hdf5_output_param\x18\x0e \x01(\x0b\x32\x1a.caffe.HDF5OutputParameter\x12\x33\n\x10hinge_loss_param\x18\x1d \x01(\x0b\x32\x19.caffe.HingeLossParameter\x12\x33\n\x10image_data_param\x18\x0f \x01(\x0b\x32\x19.caffe.ImageDataParameter\x12\x39\n\x13infogain_loss_param\x18\x10 \x01(\x0b\x32\x1c.caffe.InfogainLossParameter\x12\x39\n\x13inner_product_param\x18\x11 \x01(\x0b\x32\x1c.caffe.InnerProductParameter\x12&\n\tlrn_param\x18\x12 \x01(\x0b\x32\x13.caffe.LRNParameter\x12\x35\n\x11memory_data_param\x18\x16 \x01(\x0b\x32\x1a.caffe.MemoryDataParameter\x12&\n\tmvn_param\x18\" \x01(\x0b\x32\x13.caffe.MVNParameter\x12.\n\rpooling_param\x18\x13 \x01(\x0b\x32\x17.caffe.PoolingParameter\x12*\n\x0bpower_param\x18\x15 \x01(\x0b\x32\x15.caffe.PowerParameter\x12(\n\nrelu_param\x18\x1e \x01(\x0b\x32\x14.caffe.ReLUParameter\x12.\n\rsigmoid_param\x18& \x01(\x0b\x32\x17.caffe.SigmoidParameter\x12.\n\rsoftmax_param\x18\' \x01(\x0b\x32\x17.caffe.SoftmaxParameter\x12*\n\x0bslice_param\x18\x1f \x01(\x0b\x32\x15.caffe.SliceParameter\x12(\n\ntanh_param\x18% \x01(\x0b\x32\x14.caffe.TanHParameter\x12\x32\n\x0fthreshold_param\x18\x19 \x01(\x0b\x32\x19.caffe.ThresholdParameter\x12\x35\n\x11window_data_param\x18\x14 \x01(\x0b\x32\x1a.caffe.WindowDataParameter\x12\x37\n\x0ftransform_param\x18$ \x01(\x0b\x32\x1e.caffe.TransformationParameter\x12(\n\nloss_param\x18* \x01(\x0b\x32\x14.caffe.LossParameter\x12&\n\x05layer\x18\x01 \x01(\x0b\x32\x17.caffe.V0LayerParameter\"\xd8\x04\n\tLayerType\x12\x08\n\x04NONE\x10\x00\x12\n\n\x06\x41\x42SVAL\x10#\x12\x0c\n\x08\x41\x43\x43URACY\x10\x01\x12\n\n\x06\x41RGMAX\x10\x1e\x12\x08\n\x04\x42NLL\x10\x02\x12\n\n\x06\x43ONCAT\x10\x03\x12\x14\n\x10\x43ONTRASTIVE_LOSS\x10%\x12\x0f\n\x0b\x43ONVOLUTION\x10\x04\x12\x08\n\x04\x44\x41TA\x10\x05\x12\x11\n\rDECONVOLUTION\x10\'\x12\x0b\n\x07\x44ROPOUT\x10\x06\x12\x0e\n\nDUMMY_DATA\x10 \x12\x12\n\x0e\x45UCLIDEAN_LOSS\x10\x07\x12\x0b\n\x07\x45LTWISE\x10\x19\x12\x07\n\x03\x45XP\x10&\x12\x0b\n\x07\x46LATTEN\x10\x08\x12\r\n\tHDF5_DATA\x10\t\x12\x0f\n\x0bHDF5_OUTPUT\x10\n\x12\x0e\n\nHINGE_LOSS\x10\x1c\x12\n\n\x06IM2COL\x10\x0b\x12\x0e\n\nIMAGE_DATA\x10\x0c\x12\x11\n\rINFOGAIN_LOSS\x10\r\x12\x11\n\rINNER_PRODUCT\x10\x0e\x12\x07\n\x03LRN\x10\x0f\x12\x0f\n\x0bMEMORY_DATA\x10\x1d\x12\x1d\n\x19MULTINOMIAL_LOGISTIC_LOSS\x10\x10\x12\x07\n\x03MVN\x10\"\x12\x0b\n\x07POOLING\x10\x11\x12\t\n\x05POWER\x10\x1a\x12\x08\n\x04RELU\x10\x12\x12\x0b\n\x07SIGMOID\x10\x13\x12\x1e\n\x1aSIGMOID_CROSS_ENTROPY_LOSS\x10\x1b\x12\x0b\n\x07SILENCE\x10$\x12\x0b\n\x07SOFTMAX\x10\x14\x12\x10\n\x0cSOFTMAX_LOSS\x10\x15\x12\t\n\x05SPLIT\x10\x16\x12\t\n\x05SLICE\x10!\x12\x08\n\x04TANH\x10\x17\x12\x0f\n\x0bWINDOW_DATA\x10\x18\x12\r\n\tTHRESHOLD\x10\x1f\"*\n\x0c\x44imCheckMode\x12\n\n\x06STRICT\x10\x00\x12\x0e\n\nPERMISSIVE\x10\x01\"\xfd\x07\n\x10V0LayerParameter\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0c\n\x04type\x18\x02 \x01(\t\x12\x12\n\nnum_output\x18\x03 \x01(\r\x12\x16\n\x08\x62iasterm\x18\x04 \x01(\x08:\x04true\x12-\n\rweight_filler\x18\x05 \x01(\x0b\x32\x16.caffe.FillerParameter\x12+\n\x0b\x62ias_filler\x18\x06 \x01(\x0b\x32\x16.caffe.FillerParameter\x12\x0e\n\x03pad\x18\x07 \x01(\r:\x01\x30\x12\x12\n\nkernelsize\x18\x08 \x01(\r\x12\x10\n\x05group\x18\t \x01(\r:\x01\x31\x12\x11\n\x06stride\x18\n \x01(\r:\x01\x31\x12\x35\n\x04pool\x18\x0b \x01(\x0e\x32\".caffe.V0LayerParameter.PoolMethod:\x03MAX\x12\x1a\n\rdropout_ratio\x18\x0c \x01(\x02:\x03\x30.5\x12\x15\n\nlocal_size\x18\r \x01(\r:\x01\x35\x12\x10\n\x05\x61lpha\x18\x0e \x01(\x02:\x01\x31\x12\x12\n\x04\x62\x65ta\x18\x0f \x01(\x02:\x04\x30.75\x12\x0c\n\x01k\x18\x16 \x01(\x02:\x01\x31\x12\x0e\n\x06source\x18\x10 \x01(\t\x12\x10\n\x05scale\x18\x11 \x01(\x02:\x01\x31\x12\x10\n\x08meanfile\x18\x12 \x01(\t\x12\x11\n\tbatchsize\x18\x13 \x01(\r\x12\x13\n\x08\x63ropsize\x18\x14 \x01(\r:\x01\x30\x12\x15\n\x06mirror\x18\x15 \x01(\x08:\x05\x66\x61lse\x12\x1f\n\x05\x62lobs\x18\x32 \x03(\x0b\x32\x10.caffe.BlobProto\x12\x10\n\x08\x62lobs_lr\x18\x33 \x03(\x02\x12\x14\n\x0cweight_decay\x18\x34 \x03(\x02\x12\x14\n\trand_skip\x18\x35 \x01(\r:\x01\x30\x12\x1d\n\x10\x64\x65t_fg_threshold\x18\x36 \x01(\x02:\x03\x30.5\x12\x1d\n\x10\x64\x65t_bg_threshold\x18\x37 \x01(\x02:\x03\x30.5\x12\x1d\n\x0f\x64\x65t_fg_fraction\x18\x38 \x01(\x02:\x04\x30.25\x12\x1a\n\x0f\x64\x65t_context_pad\x18: \x01(\r:\x01\x30\x12\x1b\n\rdet_crop_mode\x18; \x01(\t:\x04warp\x12\x12\n\x07new_num\x18< \x01(\x05:\x01\x30\x12\x17\n\x0cnew_channels\x18= \x01(\x05:\x01\x30\x12\x15\n\nnew_height\x18> \x01(\x05:\x01\x30\x12\x14\n\tnew_width\x18? \x01(\x05:\x01\x30\x12\x1d\n\x0eshuffle_images\x18@ \x01(\x08:\x05\x66\x61lse\x12\x15\n\nconcat_dim\x18\x41 \x01(\r:\x01\x31\x12\x36\n\x11hdf5_output_param\x18\xe9\x07 \x01(\x0b\x32\x1a.caffe.HDF5OutputParameter\".\n\nPoolMethod\x12\x07\n\x03MAX\x10\x00\x12\x07\n\x03\x41VE\x10\x01\x12\x0e\n\nSTOCHASTIC\x10\x02\"W\n\x0ePReLUParameter\x12&\n\x06\x66iller\x18\x01 \x01(\x0b\x32\x16.caffe.FillerParameter\x12\x1d\n\x0e\x63hannel_shared\x18\x02 \x01(\x08:\x05\x66\x61lse\")\n\x15SmoothL1LossParameter\x12\x10\n\x05sigma\x18\x01 \x01(\x02:\x01\x31\"H\n\x0cMPIParameter\x12\x0f\n\x04root\x18\x01 \x01(\r:\x01\x30\x12\x12\n\x07\x63omm_id\x18\x02 \x01(\x04:\x01\x30\x12\x13\n\x08group_id\x18\x03 \x01(\x04:\x01\x30\"!\n\x10PermuteParameter\x12\r\n\x05order\x18\x01 \x03(\r\"\x93\x01\n\x12NormalizeParameter\x12\x1c\n\x0e\x61\x63ross_spatial\x18\x01 \x01(\x08:\x04true\x12,\n\x0cscale_filler\x18\x02 \x01(\x0b\x32\x16.caffe.FillerParameter\x12\x1c\n\x0e\x63hannel_shared\x18\x03 \x01(\x08:\x04true\x12\x13\n\x03\x65ps\x18\x04 \x01(\x02:\x06\x31\x65-010\"_\n\x11ParallelParameter\x12\x16\n\x07shuffle\x18\x01 \x01(\x08:\x05\x66\x61lse\x12\x18\n\tnode_step\x18\x02 \x01(\x08:\x05\x66\x61lse\x12\x18\n\tpartition\x18\x03 \x01(\x08:\x05\x66\x61lse\"R\n\x0fResizeParameter\x12\x1f\n\x05shape\x18\x01 \x01(\x0b\x32\x10.caffe.BlobShape\x12\x0e\n\x02\x66x\x18\x02 \x01(\x02:\x02-1\x12\x0e\n\x02\x66y\x18\x03 \x01(\x02:\x02-1\"\'\n\x13\x45xpandDimsParameter\x12\x10\n\x04\x61xis\x18\x01 \x01(\x05:\x02-1\"\xc8\x01\n\x11ProposalParameter\x12\x17\n\x0b\x66\x65\x61t_stride\x18\x01 \x01(\r:\x02\x31\x36\x12\x15\n\tbase_size\x18\x02 \x01(\r:\x02\x31\x36\x12\x14\n\x08min_size\x18\x03 \x01(\r:\x02\x31\x36\x12\r\n\x05ratio\x18\x04 \x03(\x02\x12\r\n\x05scale\x18\x05 \x03(\x02\x12\x1a\n\x0cpre_nms_topn\x18\x06 \x01(\r:\x04\x36\x30\x30\x30\x12\x1a\n\rpost_nms_topn\x18\x07 \x01(\r:\x03\x33\x30\x30\x12\x17\n\nnms_thresh\x18\x08 \x01(\x02:\x03\x30.7\"\xa6\x01\n\x14\x42\x61tchRenormParameter\x12\x18\n\x10use_global_stats\x18\x01 \x01(\x08\x12$\n\x17moving_average_fraction\x18\x02 \x01(\x02:\x03\x30.9\x12\x12\n\x03\x65ps\x18\x03 \x01(\x02:\x05\x30.001\x12\x10\n\x05r_max\x18\x04 \x01(\x02:\x01\x33\x12\x10\n\x05\x64_max\x18\x05 \x01(\x02:\x01\x35\x12\x16\n\x07t_delta\x18\x06 \x01(\x02:\x05\x30.001\"?\n\x14\x44\x65nseConcatParameter\x12\x0f\n\x04\x61xis\x18\x01 \x01(\x05:\x01\x31\x12\x16\n\x0bgrowth_rate\x18\x02 \x01(\x05:\x01\x30\"c\n\x12\x46ocalLossParameter\x12\x12\n\x05\x61lpha\x18\x01 \x01(\x02:\x03\x30.5\x12\x10\n\x05gamma\x18\x02 \x01(\x02:\x01\x30\x12\x13\n\x03\x65ps\x18\x03 \x01(\x02:\x06\x31\x65-010\x12\x12\n\x06neg_id\x18\x04 \x01(\x05:\x02-1*\x1c\n\x05Phase\x12\t\n\x05TRAIN\x10\x00\x12\x08\n\x04TEST\x10\x01')
serialized_pb=_b('\n\x0b\x63\x61\x66\x66\x65.proto\x12\x05\x63\x61\x66\x66\x65\"\x1c\n\tBlobShape\x12\x0f\n\x03\x64im\x18\x01 \x03(\x03\x42\x02\x10\x01\"\xcc\x01\n\tBlobProto\x12\x1f\n\x05shape\x18\x07 \x01(\x0b\x32\x10.caffe.BlobShape\x12\x10\n\x04\x64\x61ta\x18\x05 \x03(\x02\x42\x02\x10\x01\x12\x10\n\x04\x64iff\x18\x06 \x03(\x02\x42\x02\x10\x01\x12\x17\n\x0b\x64ouble_data\x18\x08 \x03(\x01\x42\x02\x10\x01\x12\x17\n\x0b\x64ouble_diff\x18\t \x03(\x01\x42\x02\x10\x01\x12\x0e\n\x03num\x18\x01 \x01(\x05:\x01\x30\x12\x13\n\x08\x63hannels\x18\x02 \x01(\x05:\x01\x30\x12\x11\n\x06height\x18\x03 \x01(\x05:\x01\x30\x12\x10\n\x05width\x18\x04 \x01(\x05:\x01\x30\"2\n\x0f\x42lobProtoVector\x12\x1f\n\x05\x62lobs\x18\x01 \x03(\x0b\x32\x10.caffe.BlobProto\"\x81\x01\n\x05\x44\x61tum\x12\x10\n\x08\x63hannels\x18\x01 \x01(\x05\x12\x0e\n\x06height\x18\x02 \x01(\x05\x12\r\n\x05width\x18\x03 \x01(\x05\x12\x0c\n\x04\x64\x61ta\x18\x04 \x01(\x0c\x12\r\n\x05label\x18\x05 \x01(\x05\x12\x12\n\nfloat_data\x18\x06 \x03(\x02\x12\x16\n\x07\x65ncoded\x18\x07 \x01(\x08:\x05\x66\x61lse\"\x8a\x02\n\x0f\x46illerParameter\x12\x16\n\x04type\x18\x01 \x01(\t:\x08\x63onstant\x12\x10\n\x05value\x18\x02 \x01(\x02:\x01\x30\x12\x0e\n\x03min\x18\x03 \x01(\x02:\x01\x30\x12\x0e\n\x03max\x18\x04 \x01(\x02:\x01\x31\x12\x0f\n\x04mean\x18\x05 \x01(\x02:\x01\x30\x12\x0e\n\x03std\x18\x06 \x01(\x02:\x01\x31\x12\x12\n\x06sparse\x18\x07 \x01(\x05:\x02-1\x12\x42\n\rvariance_norm\x18\x08 \x01(\x0e\x32#.caffe.FillerParameter.VarianceNorm:\x06\x46\x41N_IN\"4\n\x0cVarianceNorm\x12\n\n\x06\x46\x41N_IN\x10\x00\x12\x0b\n\x07\x46\x41N_OUT\x10\x01\x12\x0b\n\x07\x41VERAGE\x10\x02\"\x8e\x02\n\x0cNetParameter\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\r\n\x05input\x18\x03 \x03(\t\x12%\n\x0binput_shape\x18\x08 \x03(\x0b\x32\x10.caffe.BlobShape\x12\x11\n\tinput_dim\x18\x04 \x03(\x05\x12\x1d\n\x0e\x66orce_backward\x18\x05 \x01(\x08:\x05\x66\x61lse\x12\x1e\n\x05state\x18\x06 \x01(\x0b\x32\x0f.caffe.NetState\x12\x19\n\ndebug_info\x18\x07 \x01(\x08:\x05\x66\x61lse\x12$\n\x05layer\x18\x64 \x03(\x0b\x32\x15.caffe.LayerParameter\x12\'\n\x06layers\x18\x02 \x03(\x0b\x32\x17.caffe.V1LayerParameter\"\xc9\n\n\x0fSolverParameter\x12\x0b\n\x03net\x18\x18 \x01(\t\x12&\n\tnet_param\x18\x19 \x01(\x0b\x32\x13.caffe.NetParameter\x12\x11\n\ttrain_net\x18\x01 \x01(\t\x12\x10\n\x08test_net\x18\x02 \x03(\t\x12,\n\x0ftrain_net_param\x18\x15 \x01(\x0b\x32\x13.caffe.NetParameter\x12+\n\x0etest_net_param\x18\x16 \x03(\x0b\x32\x13.caffe.NetParameter\x12$\n\x0btrain_state\x18\x1a \x01(\x0b\x32\x0f.caffe.NetState\x12#\n\ntest_state\x18\x1b \x03(\x0b\x32\x0f.caffe.NetState\x12\x11\n\ttest_iter\x18\x03 \x03(\x05\x12\x18\n\rtest_interval\x18\x04 \x01(\x05:\x01\x30\x12 \n\x11test_compute_loss\x18\x13 \x01(\x08:\x05\x66\x61lse\x12!\n\x13test_initialization\x18 \x01(\x08:\x04true\x12\x0f\n\x07\x62\x61se_lr\x18\x05 \x01(\x02\x12\x10\n\x08stage_lr\x18\x32 \x03(\x02\x12\x12\n\nstage_iter\x18\x33 \x03(\x05\x12\x0f\n\x07\x64isplay\x18\x06 \x01(\x05\x12\x17\n\x0c\x61verage_loss\x18! \x01(\x05:\x01\x31\x12\x10\n\x08max_iter\x18\x07 \x01(\x05\x12\x14\n\titer_size\x18$ \x01(\x05:\x01\x31\x12\x11\n\tlr_policy\x18\x08 \x01(\t\x12\r\n\x05gamma\x18\t \x01(\x02\x12\r\n\x05power\x18\n \x01(\x02\x12\x10\n\x08momentum\x18\x0b \x01(\x02\x12\x14\n\x0cweight_decay\x18\x0c \x01(\x02\x12\x1f\n\x13regularization_type\x18\x1d \x01(\t:\x02L2\x12\x10\n\x08stepsize\x18\r \x01(\x05\x12\x11\n\tstepvalue\x18\" \x03(\x05\x12\x1a\n\x0e\x63lip_gradients\x18# \x01(\x02:\x02-1\x12\x13\n\x08snapshot\x18\x0e \x01(\x05:\x01\x30\x12\x17\n\x0fsnapshot_prefix\x18\x0f \x01(\t\x12\x1c\n\rsnapshot_diff\x18\x10 \x01(\x08:\x05\x66\x61lse\x12K\n\x0fsnapshot_format\x18% \x01(\x0e\x32%.caffe.SolverParameter.SnapshotFormat:\x0b\x42INARYPROTO\x12;\n\x0bsolver_mode\x18\x11 \x01(\x0e\x32!.caffe.SolverParameter.SolverMode:\x03GPU\x12\x14\n\tdevice_id\x18\x12 \x01(\x05:\x01\x30\x12\x17\n\x0brandom_seed\x18\x14 \x01(\x03:\x02-1\x12\x11\n\x04type\x18( \x01(\t:\x03SGD\x12\x15\n\x05\x64\x65lta\x18\x1f \x01(\x02:\x06\x31\x65-008\x12\x18\n\tmomentum2\x18\' \x01(\x02:\x05\x30.999\x12\x17\n\trms_decay\x18& \x01(\x02:\x04\x30.99\x12\x19\n\ndebug_info\x18\x17 \x01(\x08:\x05\x66\x61lse\x12\"\n\x14snapshot_after_train\x18\x1c \x01(\x08:\x04true\x12;\n\x0bsolver_type\x18\x1e \x01(\x0e\x32!.caffe.SolverParameter.SolverType:\x03SGD\"+\n\x0eSnapshotFormat\x12\x08\n\x04HDF5\x10\x00\x12\x0f\n\x0b\x42INARYPROTO\x10\x01\"\x1e\n\nSolverMode\x12\x07\n\x03\x43PU\x10\x00\x12\x07\n\x03GPU\x10\x01\"U\n\nSolverType\x12\x07\n\x03SGD\x10\x00\x12\x0c\n\x08NESTEROV\x10\x01\x12\x0b\n\x07\x41\x44\x41GRAD\x10\x02\x12\x0b\n\x07RMSPROP\x10\x03\x12\x0c\n\x08\x41\x44\x41\x44\x45LTA\x10\x04\x12\x08\n\x04\x41\x44\x41M\x10\x05\"l\n\x0bSolverState\x12\x0c\n\x04iter\x18\x01 \x01(\x05\x12\x13\n\x0blearned_net\x18\x02 \x01(\t\x12!\n\x07history\x18\x03 \x03(\x0b\x32\x10.caffe.BlobProto\x12\x17\n\x0c\x63urrent_step\x18\x04 \x01(\x05:\x01\x30\"N\n\x08NetState\x12!\n\x05phase\x18\x01 \x01(\x0e\x32\x0c.caffe.Phase:\x04TEST\x12\x10\n\x05level\x18\x02 \x01(\x05:\x01\x30\x12\r\n\x05stage\x18\x03 \x03(\t\"\x85\x01\n\x0cNetStateRule\x12\x1b\n\x05phase\x18\x01 \x01(\x0e\x32\x0c.caffe.Phase\x12\x11\n\tmin_level\x18\x02 \x01(\x05\x12\x11\n\tmax_level\x18\x03 \x01(\x05\x12\r\n\x05stage\x18\x04 \x03(\t\x12\x11\n\tnot_stage\x18\x05 \x03(\t\x12\x10\n\x08mpi_rank\x18\x06 \x03(\r\"\xa3\x01\n\tParamSpec\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x31\n\nshare_mode\x18\x02 \x01(\x0e\x32\x1d.caffe.ParamSpec.DimCheckMode\x12\x12\n\x07lr_mult\x18\x03 \x01(\x02:\x01\x31\x12\x15\n\ndecay_mult\x18\x04 \x01(\x02:\x01\x31\"*\n\x0c\x44imCheckMode\x12\n\n\x06STRICT\x10\x00\x12\x0e\n\nPERMISSIVE\x10\x01\"\x95\x19\n\x0eLayerParameter\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0c\n\x04type\x18\x02 \x01(\t\x12\x0e\n\x06\x62ottom\x18\x03 \x03(\t\x12\x0b\n\x03top\x18\x04 \x03(\t\x12\x1c\n\x0cmirror_stage\x18\xa2\x01 \x01(\x08:\x05\x66\x61lse\x12\x1b\n\x05phase\x18\n \x01(\x0e\x32\x0c.caffe.Phase\x12\x13\n\x0bloss_weight\x18\x05 \x03(\x02\x12\x1f\n\x05param\x18\x06 \x03(\x0b\x32\x10.caffe.ParamSpec\x12\x1f\n\x05\x62lobs\x18\x07 \x03(\x0b\x32\x10.caffe.BlobProto\x12\x16\n\x0epropagate_down\x18\x0b \x03(\x08\x12$\n\x07include\x18\x08 \x03(\x0b\x32\x13.caffe.NetStateRule\x12$\n\x07\x65xclude\x18\t \x03(\x0b\x32\x13.caffe.NetStateRule\x12\x37\n\x0ftransform_param\x18\x64 \x01(\x0b\x32\x1e.caffe.TransformationParameter\x12(\n\nloss_param\x18\x65 \x01(\x0b\x32\x14.caffe.LossParameter\x12\x30\n\x0e\x61\x63\x63uracy_param\x18\x66 \x01(\x0b\x32\x18.caffe.AccuracyParameter\x12,\n\x0c\x61rgmax_param\x18g \x01(\x0b\x32\x16.caffe.ArgMaxParameter\x12\x34\n\x10\x62\x61tch_norm_param\x18\x8b\x01 \x01(\x0b\x32\x19.caffe.BatchNormParameter\x12)\n\nbias_param\x18\x8d\x01 \x01(\x0b\x32\x14.caffe.BiasParameter\x12,\n\x0c\x63oncat_param\x18h \x01(\x0b\x32\x16.caffe.ConcatParameter\x12?\n\x16\x63ontrastive_loss_param\x18i \x01(\x0b\x32\x1f.caffe.ContrastiveLossParameter\x12\x36\n\x11\x63onvolution_param\x18j \x01(\x0b\x32\x1b.caffe.ConvolutionParameter\x12)\n\ncrop_param\x18\x90\x01 \x01(\x0b\x32\x14.caffe.CropParameter\x12(\n\ndata_param\x18k \x01(\x0b\x32\x14.caffe.DataParameter\x12.\n\rdropout_param\x18l \x01(\x0b\x32\x17.caffe.DropoutParameter\x12\x33\n\x10\x64ummy_data_param\x18m \x01(\x0b\x32\x19.caffe.DummyDataParameter\x12.\n\reltwise_param\x18n \x01(\x0b\x32\x17.caffe.EltwiseParameter\x12\'\n\telu_param\x18\x8c\x01 \x01(\x0b\x32\x13.caffe.ELUParameter\x12+\n\x0b\x65mbed_param\x18\x89\x01 \x01(\x0b\x32\x15.caffe.EmbedParameter\x12&\n\texp_param\x18o \x01(\x0b\x32\x13.caffe.ExpParameter\x12/\n\rflatten_param\x18\x87\x01 \x01(\x0b\x32\x17.caffe.FlattenParameter\x12\x31\n\x0fhdf5_data_param\x18p \x01(\x0b\x32\x18.caffe.HDF5DataParameter\x12\x35\n\x11hdf5_output_param\x18q \x01(\x0b\x32\x1a.caffe.HDF5OutputParameter\x12\x33\n\x10hinge_loss_param\x18r \x01(\x0b\x32\x19.caffe.HingeLossParameter\x12\x33\n\x10image_data_param\x18s \x01(\x0b\x32\x19.caffe.ImageDataParameter\x12\x39\n\x13infogain_loss_param\x18t \x01(\x0b\x32\x1c.caffe.InfogainLossParameter\x12\x39\n\x13inner_product_param\x18u \x01(\x0b\x32\x1c.caffe.InnerProductParameter\x12+\n\x0binput_param\x18\x8f\x01 \x01(\x0b\x32\x15.caffe.InputParameter\x12\'\n\tlog_param\x18\x86\x01 \x01(\x0b\x32\x13.caffe.LogParameter\x12&\n\tlrn_param\x18v \x01(\x0b\x32\x13.caffe.LRNParameter\x12\x35\n\x11memory_data_param\x18w \x01(\x0b\x32\x1a.caffe.MemoryDataParameter\x12&\n\tmvn_param\x18x \x01(\x0b\x32\x13.caffe.MVNParameter\x12\x33\n\x0fparameter_param\x18\x91\x01 \x01(\x0b\x32\x19.caffe.ParameterParameter\x12.\n\rpooling_param\x18y \x01(\x0b\x32\x17.caffe.PoolingParameter\x12*\n\x0bpower_param\x18z \x01(\x0b\x32\x15.caffe.PowerParameter\x12+\n\x0bprelu_param\x18\x83\x01 \x01(\x0b\x32\x15.caffe.PReLUParameter\x12-\n\x0cpython_param\x18\x82\x01 \x01(\x0b\x32\x16.caffe.PythonParameter\x12\x33\n\x0freduction_param\x18\x88\x01 \x01(\x0b\x32\x19.caffe.ReductionParameter\x12(\n\nrelu_param\x18{ \x01(\x0b\x32\x14.caffe.ReLUParameter\x12/\n\rreshape_param\x18\x85\x01 \x01(\x0b\x32\x17.caffe.ReshapeParameter\x12+\n\x0bscale_param\x18\x8e\x01 \x01(\x0b\x32\x15.caffe.ScaleParameter\x12.\n\rsigmoid_param\x18| \x01(\x0b\x32\x17.caffe.SigmoidParameter\x12.\n\rsoftmax_param\x18} \x01(\x0b\x32\x17.caffe.SoftmaxParameter\x12\'\n\tspp_param\x18\x84\x01 \x01(\x0b\x32\x13.caffe.SPPParameter\x12*\n\x0bslice_param\x18~ \x01(\x0b\x32\x15.caffe.SliceParameter\x12(\n\ntanh_param\x18\x7f \x01(\x0b\x32\x14.caffe.TanHParameter\x12\x33\n\x0fthreshold_param\x18\x80\x01 \x01(\x0b\x32\x19.caffe.ThresholdParameter\x12)\n\ntile_param\x18\x8a\x01 \x01(\x0b\x32\x14.caffe.TileParameter\x12\x36\n\x11window_data_param\x18\x81\x01 \x01(\x0b\x32\x1a.caffe.WindowDataParameter\x12\x36\n\x11roi_pooling_param\x18\x97\x01 \x01(\x0b\x32\x1a.caffe.ROIPoolingParameter\x12;\n\x14smooth_l1_loss_param\x18\x98\x01 \x01(\x0b\x32\x1c.caffe.SmoothL1LossParameter\x12\'\n\tmpi_param\x18\x99\x01 \x01(\x0b\x32\x13.caffe.MPIParameter\x12/\n\rpermute_param\x18\x9a\x01 \x01(\x0b\x32\x17.caffe.PermuteParameter\x12\x33\n\x0fnormalize_param\x18\x9b\x01 \x01(\x0b\x32\x19.caffe.NormalizeParameter\x12\x31\n\x0eparallel_param\x18\x9d\x01 \x01(\x0b\x32\x18.caffe.ParallelParameter\x12-\n\x0cresize_param\x18\x9e\x01 \x01(\x0b\x32\x16.caffe.ResizeParameter\x12\x36\n\x11\x65xpand_dims_param\x18\x9f\x01 \x01(\x0b\x32\x1a.caffe.ExpandDimsParameter\x12\x31\n\x0eproposal_param\x18\xa0\x01 \x01(\x0b\x32\x18.caffe.ProposalParameter\x12\x38\n\x12\x62\x61tch_renorm_param\x18\xa1\x01 \x01(\x0b\x32\x1b.caffe.BatchRenormParameter\x12\x38\n\x12\x64\x65nse_concat_param\x18\xa3\x01 \x01(\x0b\x32\x1b.caffe.DenseConcatParameter\x12\x34\n\x10\x66ocal_loss_param\x18\xa4\x01 \x01(\x0b\x32\x19.caffe.FocalLossParameter\x12-\n\x0cgather_param\x18\xa5\x01 \x01(\x0b\x32\x16.caffe.GatherParameter\"\xa7\x02\n\x17TransformationParameter\x12\x10\n\x05scale\x18\x01 \x01(\x02:\x01\x31\x12\x15\n\x06mirror\x18\x02 \x01(\x08:\x05\x66\x61lse\x12\x14\n\tcrop_size\x18\x03 \x01(\r:\x01\x30\x12\x12\n\x07padding\x18\x0b \x01(\r:\x01\x30\x12\x11\n\tmean_file\x18\x04 \x01(\t\x12\x12\n\nmean_value\x18\x05 \x03(\x02\x12\x1a\n\x0b\x66orce_color\x18\x06 \x01(\x08:\x05\x66\x61lse\x12\x19\n\nforce_gray\x18\x07 \x01(\x08:\x05\x66\x61lse\x12!\n\x12\x63olor_augmentation\x18\x08 \x01(\x08:\x05\x66\x61lse\x12\x1b\n\x10min_random_scale\x18\t \x01(\x02:\x01\x31\x12\x1b\n\x10max_random_scale\x18\n \x01(\x02:\x01\x31\"\xf5\x01\n\rLossParameter\x12\x14\n\x0cignore_label\x18\x01 \x01(\x05\x12\x44\n\rnormalization\x18\x03 \x01(\x0e\x32&.caffe.LossParameter.NormalizationMode:\x05VALID\x12\x11\n\tnormalize\x18\x02 \x01(\x08\x1a\'\n\x13\x45xpandDimsParameter\x12\x10\n\x04\x61xis\x18\x01 \x01(\x05:\x02-1\"L\n\x11NormalizationMode\x12\x08\n\x04\x46ULL\x10\x00\x12\t\n\x05VALID\x10\x01\x12\x0e\n\nBATCH_SIZE\x10\x02\x12\x08\n\x04NONE\x10\x03\x12\x08\n\x04UNIT\x10\x04\"L\n\x11\x41\x63\x63uracyParameter\x12\x10\n\x05top_k\x18\x01 \x01(\r:\x01\x31\x12\x0f\n\x04\x61xis\x18\x02 \x01(\x05:\x01\x31\x12\x14\n\x0cignore_label\x18\x03 \x01(\x05\"M\n\x0f\x41rgMaxParameter\x12\x1a\n\x0bout_max_val\x18\x01 \x01(\x08:\x05\x66\x61lse\x12\x10\n\x05top_k\x18\x02 \x01(\r:\x01\x31\x12\x0c\n\x04\x61xis\x18\x03 \x01(\x05\"9\n\x0f\x43oncatParameter\x12\x0f\n\x04\x61xis\x18\x02 \x01(\x05:\x01\x31\x12\x15\n\nconcat_dim\x18\x01 \x01(\r:\x01\x31\"h\n\x12\x42\x61tchNormParameter\x12\x18\n\x10use_global_stats\x18\x01 \x01(\x08\x12$\n\x17moving_average_fraction\x18\x02 \x01(\x02:\x03\x30.9\x12\x12\n\x03\x65ps\x18\x03 \x01(\x02:\x05\x30.001\"]\n\rBiasParameter\x12\x0f\n\x04\x61xis\x18\x01 \x01(\x05:\x01\x31\x12\x13\n\x08num_axes\x18\x02 \x01(\x05:\x01\x31\x12&\n\x06\x66iller\x18\x03 \x01(\x0b\x32\x16.caffe.FillerParameter\"L\n\x18\x43ontrastiveLossParameter\x12\x11\n\x06margin\x18\x01 \x01(\x02:\x01\x31\x12\x1d\n\x0elegacy_version\x18\x02 \x01(\x08:\x05\x66\x61lse\"\xfc\x03\n\x14\x43onvolutionParameter\x12\x12\n\nnum_output\x18\x01 \x01(\r\x12\x17\n\tbias_term\x18\x02 \x01(\x08:\x04true\x12\x0b\n\x03pad\x18\x03 \x03(\r\x12\x13\n\x0bkernel_size\x18\x04 \x03(\r\x12\x0e\n\x06stride\x18\x06 \x03(\r\x12\x10\n\x08\x64ilation\x18\x12 \x03(\r\x12\x10\n\x05pad_h\x18\t \x01(\r:\x01\x30\x12\x10\n\x05pad_w\x18\n \x01(\r:\x01\x30\x12\x10\n\x08kernel_h\x18\x0b \x01(\r\x12\x10\n\x08kernel_w\x18\x0c \x01(\r\x12\x10\n\x08stride_h\x18\r \x01(\r\x12\x10\n\x08stride_w\x18\x0e \x01(\r\x12\x10\n\x05group\x18\x05 \x01(\r:\x01\x31\x12-\n\rweight_filler\x18\x07 \x01(\x0b\x32\x16.caffe.FillerParameter\x12+\n\x0b\x62ias_filler\x18\x08 \x01(\x0b\x32\x16.caffe.FillerParameter\x12;\n\x06\x65ngine\x18\x0f \x01(\x0e\x32\".caffe.ConvolutionParameter.Engine:\x07\x44\x45\x46\x41ULT\x12\x0f\n\x04\x61xis\x18\x10 \x01(\x05:\x01\x31\x12\x1e\n\x0f\x66orce_nd_im2col\x18\x11 \x01(\x08:\x05\x66\x61lse\"+\n\x06\x45ngine\x12\x0b\n\x07\x44\x45\x46\x41ULT\x10\x00\x12\t\n\x05\x43\x41\x46\x46\x45\x10\x01\x12\t\n\x05\x43UDNN\x10\x02\"0\n\rCropParameter\x12\x0f\n\x04\x61xis\x18\x01 \x01(\x05:\x01\x32\x12\x0e\n\x06offset\x18\x02 \x03(\r\"\xa4\x02\n\rDataParameter\x12\x0e\n\x06source\x18\x01 \x01(\t\x12\x12\n\nbatch_size\x18\x04 \x01(\r\x12\x14\n\trand_skip\x18\x07 \x01(\r:\x01\x30\x12\x31\n\x07\x62\x61\x63kend\x18\x08 \x01(\x0e\x32\x17.caffe.DataParameter.DB:\x07LEVELDB\x12\x10\n\x05scale\x18\x02 \x01(\x02:\x01\x31\x12\x11\n\tmean_file\x18\x03 \x01(\t\x12\x14\n\tcrop_size\x18\x05 \x01(\r:\x01\x30\x12\x15\n\x06mirror\x18\x06 \x01(\x08:\x05\x66\x61lse\x12\"\n\x13\x66orce_encoded_color\x18\t \x01(\x08:\x05\x66\x61lse\x12\x13\n\x08prefetch\x18\n \x01(\r:\x01\x35\"\x1b\n\x02\x44\x42\x12\x0b\n\x07LEVELDB\x10\x00\x12\x08\n\x04LMDB\x10\x01\"I\n\x10\x44ropoutParameter\x12\x1a\n\rdropout_ratio\x18\x01 \x01(\x02:\x03\x30.5\x12\x19\n\x0bscale_train\x18\x02 \x01(\x08:\x04true\"\xa0\x01\n\x12\x44ummyDataParameter\x12+\n\x0b\x64\x61ta_filler\x18\x01 \x03(\x0b\x32\x16.caffe.FillerParameter\x12\x1f\n\x05shape\x18\x06 \x03(\x0b\x32\x10.caffe.BlobShape\x12\x0b\n\x03num\x18\x02 \x03(\r\x12\x10\n\x08\x63hannels\x18\x03 \x03(\r\x12\x0e\n\x06height\x18\x04 \x03(\r\x12\r\n\x05width\x18\x05 \x03(\r\"\xa5\x01\n\x10\x45ltwiseParameter\x12\x39\n\toperation\x18\x01 \x01(\x0e\x32!.caffe.EltwiseParameter.EltwiseOp:\x03SUM\x12\r\n\x05\x63oeff\x18\x02 \x03(\x02\x12\x1e\n\x10stable_prod_grad\x18\x03 \x01(\x08:\x04true\"\'\n\tEltwiseOp\x12\x08\n\x04PROD\x10\x00\x12\x07\n\x03SUM\x10\x01\x12\x07\n\x03MAX\x10\x02\" \n\x0c\x45LUParameter\x12\x10\n\x05\x61lpha\x18\x01 \x01(\x02:\x01\x31\"\xac\x01\n\x0e\x45mbedParameter\x12\x12\n\nnum_output\x18\x01 \x01(\r\x12\x11\n\tinput_dim\x18\x02 \x01(\r\x12\x17\n\tbias_term\x18\x03 \x01(\x08:\x04true\x12-\n\rweight_filler\x18\x04 \x01(\x0b\x32\x16.caffe.FillerParameter\x12+\n\x0b\x62ias_filler\x18\x05 \x01(\x0b\x32\x16.caffe.FillerParameter\"D\n\x0c\x45xpParameter\x12\x10\n\x04\x62\x61se\x18\x01 \x01(\x02:\x02-1\x12\x10\n\x05scale\x18\x02 \x01(\x02:\x01\x31\x12\x10\n\x05shift\x18\x03 \x01(\x02:\x01\x30\"9\n\x10\x46lattenParameter\x12\x0f\n\x04\x61xis\x18\x01 \x01(\x05:\x01\x31\x12\x14\n\x08\x65nd_axis\x18\x02 \x01(\x05:\x02-1\"O\n\x11HDF5DataParameter\x12\x0e\n\x06source\x18\x01 \x01(\t\x12\x12\n\nbatch_size\x18\x02 \x01(\r\x12\x16\n\x07shuffle\x18\x03 \x01(\x08:\x05\x66\x61lse\"(\n\x13HDF5OutputParameter\x12\x11\n\tfile_name\x18\x01 \x01(\t\"^\n\x12HingeLossParameter\x12\x30\n\x04norm\x18\x01 \x01(\x0e\x32\x1e.caffe.HingeLossParameter.Norm:\x02L1\"\x16\n\x04Norm\x12\x06\n\x02L1\x10\x01\x12\x06\n\x02L2\x10\x02\"\x97\x02\n\x12ImageDataParameter\x12\x0e\n\x06source\x18\x01 \x01(\t\x12\x15\n\nbatch_size\x18\x04 \x01(\r:\x01\x31\x12\x14\n\trand_skip\x18\x07 \x01(\r:\x01\x30\x12\x16\n\x07shuffle\x18\x08 \x01(\x08:\x05\x66\x61lse\x12\x15\n\nnew_height\x18\t \x01(\r:\x01\x30\x12\x14\n\tnew_width\x18\n \x01(\r:\x01\x30\x12\x16\n\x08is_color\x18\x0b \x01(\x08:\x04true\x12\x10\n\x05scale\x18\x02 \x01(\x02:\x01\x31\x12\x11\n\tmean_file\x18\x03 \x01(\t\x12\x14\n\tcrop_size\x18\x05 \x01(\r:\x01\x30\x12\x15\n\x06mirror\x18\x06 \x01(\x08:\x05\x66\x61lse\x12\x15\n\x0broot_folder\x18\x0c \x01(\t:\x00\"\'\n\x15InfogainLossParameter\x12\x0e\n\x06source\x18\x01 \x01(\t\"\xcb\x01\n\x15InnerProductParameter\x12\x12\n\nnum_output\x18\x01 \x01(\r\x12\x17\n\tbias_term\x18\x02 \x01(\x08:\x04true\x12-\n\rweight_filler\x18\x03 \x01(\x0b\x32\x16.caffe.FillerParameter\x12+\n\x0b\x62ias_filler\x18\x04 \x01(\x0b\x32\x16.caffe.FillerParameter\x12\x0f\n\x04\x61xis\x18\x05 \x01(\x05:\x01\x31\x12\x18\n\ttranspose\x18\x06 \x01(\x08:\x05\x66\x61lse\"1\n\x0eInputParameter\x12\x1f\n\x05shape\x18\x01 \x03(\x0b\x32\x10.caffe.BlobShape\"D\n\x0cLogParameter\x12\x10\n\x04\x62\x61se\x18\x01 \x01(\x02:\x02-1\x12\x10\n\x05scale\x18\x02 \x01(\x02:\x01\x31\x12\x10\n\x05shift\x18\x03 \x01(\x02:\x01\x30\"\xb8\x02\n\x0cLRNParameter\x12\x15\n\nlocal_size\x18\x01 \x01(\r:\x01\x35\x12\x10\n\x05\x61lpha\x18\x02 \x01(\x02:\x01\x31\x12\x12\n\x04\x62\x65ta\x18\x03 \x01(\x02:\x04\x30.75\x12\x44\n\x0bnorm_region\x18\x04 \x01(\x0e\x32\x1e.caffe.LRNParameter.NormRegion:\x0f\x41\x43ROSS_CHANNELS\x12\x0c\n\x01k\x18\x05 \x01(\x02:\x01\x31\x12\x33\n\x06\x65ngine\x18\x06 \x01(\x0e\x32\x1a.caffe.LRNParameter.Engine:\x07\x44\x45\x46\x41ULT\"5\n\nNormRegion\x12\x13\n\x0f\x41\x43ROSS_CHANNELS\x10\x00\x12\x12\n\x0eWITHIN_CHANNEL\x10\x01\"+\n\x06\x45ngine\x12\x0b\n\x07\x44\x45\x46\x41ULT\x10\x00\x12\t\n\x05\x43\x41\x46\x46\x45\x10\x01\x12\t\n\x05\x43UDNN\x10\x02\"\xbd\x01\n\x13MemoryDataParameter\x12\x12\n\nbatch_size\x18\x01 \x01(\r\x12\x10\n\x08\x63hannels\x18\x02 \x01(\r\x12\x0e\n\x06height\x18\x03 \x01(\r\x12\r\n\x05width\x18\x04 \x01(\r\x12;\n\x05\x64type\x18\x05 \x01(\x0e\x32#.caffe.MemoryDataParameter.DataType:\x07\x46LOAT32\"$\n\x08\x44\x61taType\x12\x0b\n\x07\x46LOAT32\x10\x00\x12\x0b\n\x07\x46LOAT16\x10\x01\"e\n\x0cMVNParameter\x12 \n\x12normalize_variance\x18\x01 \x01(\x08:\x04true\x12\x1e\n\x0f\x61\x63ross_channels\x18\x02 \x01(\x08:\x05\x66\x61lse\x12\x13\n\x03\x65ps\x18\x03 \x01(\x02:\x06\x31\x65-009\"5\n\x12ParameterParameter\x12\x1f\n\x05shape\x18\x01 \x01(\x0b\x32\x10.caffe.BlobShape\"\xa2\x03\n\x10PoolingParameter\x12\x35\n\x04pool\x18\x01 \x01(\x0e\x32\".caffe.PoolingParameter.PoolMethod:\x03MAX\x12\x0e\n\x03pad\x18\x04 \x01(\r:\x01\x30\x12\x10\n\x05pad_h\x18\t \x01(\r:\x01\x30\x12\x10\n\x05pad_w\x18\n \x01(\r:\x01\x30\x12\x13\n\x0bkernel_size\x18\x02 \x01(\r\x12\x10\n\x08kernel_h\x18\x05 \x01(\r\x12\x10\n\x08kernel_w\x18\x06 \x01(\r\x12\x11\n\x06stride\x18\x03 \x01(\r:\x01\x31\x12\x10\n\x08stride_h\x18\x07 \x01(\r\x12\x10\n\x08stride_w\x18\x08 \x01(\r\x12\x37\n\x06\x65ngine\x18\x0b \x01(\x0e\x32\x1e.caffe.PoolingParameter.Engine:\x07\x44\x45\x46\x41ULT\x12\x1d\n\x0eglobal_pooling\x18\x0c \x01(\x08:\x05\x66\x61lse\".\n\nPoolMethod\x12\x07\n\x03MAX\x10\x00\x12\x07\n\x03\x41VE\x10\x01\x12\x0e\n\nSTOCHASTIC\x10\x02\"+\n\x06\x45ngine\x12\x0b\n\x07\x44\x45\x46\x41ULT\x10\x00\x12\t\n\x05\x43\x41\x46\x46\x45\x10\x01\x12\t\n\x05\x43UDNN\x10\x02\"Y\n\x13ROIPoolingParameter\x12\x13\n\x08pooled_h\x18\x01 \x01(\r:\x01\x30\x12\x13\n\x08pooled_w\x18\x02 \x01(\r:\x01\x30\x12\x18\n\rspatial_scale\x18\x03 \x01(\x02:\x01\x31\"F\n\x0ePowerParameter\x12\x10\n\x05power\x18\x01 \x01(\x02:\x01\x31\x12\x10\n\x05scale\x18\x02 \x01(\x02:\x01\x31\x12\x10\n\x05shift\x18\x03 \x01(\x02:\x01\x30\"g\n\x0fPythonParameter\x12\x0e\n\x06module\x18\x01 \x01(\t\x12\r\n\x05layer\x18\x02 \x01(\t\x12\x13\n\tparam_str\x18\x03 \x01(\t:\x00\x12 \n\x11share_in_parallel\x18\x04 \x01(\x08:\x05\x66\x61lse\"\xad\x01\n\x12ReductionParameter\x12=\n\toperation\x18\x01 \x01(\x0e\x32%.caffe.ReductionParameter.ReductionOp:\x03SUM\x12\x0f\n\x04\x61xis\x18\x02 \x01(\x05:\x01\x30\x12\x10\n\x05\x63oeff\x18\x03 \x01(\x02:\x01\x31\"5\n\x0bReductionOp\x12\x07\n\x03SUM\x10\x01\x12\x08\n\x04\x41SUM\x10\x02\x12\t\n\x05SUMSQ\x10\x03\x12\x08\n\x04MEAN\x10\x04\"\x8d\x01\n\rReLUParameter\x12\x19\n\x0enegative_slope\x18\x01 \x01(\x02:\x01\x30\x12\x34\n\x06\x65ngine\x18\x02 \x01(\x0e\x32\x1b.caffe.ReLUParameter.Engine:\x07\x44\x45\x46\x41ULT\"+\n\x06\x45ngine\x12\x0b\n\x07\x44\x45\x46\x41ULT\x10\x00\x12\t\n\x05\x43\x41\x46\x46\x45\x10\x01\x12\t\n\x05\x43UDNN\x10\x02\"Z\n\x10ReshapeParameter\x12\x1f\n\x05shape\x18\x01 \x01(\x0b\x32\x10.caffe.BlobShape\x12\x0f\n\x04\x61xis\x18\x02 \x01(\x05:\x01\x30\x12\x14\n\x08num_axes\x18\x03 \x01(\x05:\x02-1\"\xa5\x01\n\x0eScaleParameter\x12\x0f\n\x04\x61xis\x18\x01 \x01(\x05:\x01\x31\x12\x13\n\x08num_axes\x18\x02 \x01(\x05:\x01\x31\x12&\n\x06\x66iller\x18\x03 \x01(\x0b\x32\x16.caffe.FillerParameter\x12\x18\n\tbias_term\x18\x04 \x01(\x08:\x05\x66\x61lse\x12+\n\x0b\x62ias_filler\x18\x05 \x01(\x0b\x32\x16.caffe.FillerParameter\"x\n\x10SigmoidParameter\x12\x37\n\x06\x65ngine\x18\x01 \x01(\x0e\x32\x1e.caffe.SigmoidParameter.Engine:\x07\x44\x45\x46\x41ULT\"+\n\x06\x45ngine\x12\x0b\n\x07\x44\x45\x46\x41ULT\x10\x00\x12\t\n\x05\x43\x41\x46\x46\x45\x10\x01\x12\t\n\x05\x43UDNN\x10\x02\"L\n\x0eSliceParameter\x12\x0f\n\x04\x61xis\x18\x03 \x01(\x05:\x01\x31\x12\x13\n\x0bslice_point\x18\x02 \x03(\r\x12\x14\n\tslice_dim\x18\x01 \x01(\r:\x01\x31\"\x89\x01\n\x10SoftmaxParameter\x12\x37\n\x06\x65ngine\x18\x01 \x01(\x0e\x32\x1e.caffe.SoftmaxParameter.Engine:\x07\x44\x45\x46\x41ULT\x12\x0f\n\x04\x61xis\x18\x02 \x01(\x05:\x01\x31\"+\n\x06\x45ngine\x12\x0b\n\x07\x44\x45\x46\x41ULT\x10\x00\x12\t\n\x05\x43\x41\x46\x46\x45\x10\x01\x12\t\n\x05\x43UDNN\x10\x02\"r\n\rTanHParameter\x12\x34\n\x06\x65ngine\x18\x01 \x01(\x0e\x32\x1b.caffe.TanHParameter.Engine:\x07\x44\x45\x46\x41ULT\"+\n\x06\x45ngine\x12\x0b\n\x07\x44\x45\x46\x41ULT\x10\x00\x12\t\n\x05\x43\x41\x46\x46\x45\x10\x01\x12\t\n\x05\x43UDNN\x10\x02\"T\n\rTileParameter\x12\x0f\n\x04\x61xis\x18\x01 \x01(\x05:\x01\x31\x12\r\n\x05tiles\x18\x02 \x01(\x05\x12#\n\tmultiples\x18\x03 \x01(\x0b\x32\x10.caffe.BlobShape\"*\n\x12ThresholdParameter\x12\x14\n\tthreshold\x18\x01 \x01(\x02:\x01\x30\"\xc1\x02\n\x13WindowDataParameter\x12\x0e\n\x06source\x18\x01 \x01(\t\x12\x10\n\x05scale\x18\x02 \x01(\x02:\x01\x31\x12\x11\n\tmean_file\x18\x03 \x01(\t\x12\x12\n\nbatch_size\x18\x04 \x01(\r\x12\x14\n\tcrop_size\x18\x05 \x01(\r:\x01\x30\x12\x15\n\x06mirror\x18\x06 \x01(\x08:\x05\x66\x61lse\x12\x19\n\x0c\x66g_threshold\x18\x07 \x01(\x02:\x03\x30.5\x12\x19\n\x0c\x62g_threshold\x18\x08 \x01(\x02:\x03\x30.5\x12\x19\n\x0b\x66g_fraction\x18\t \x01(\x02:\x04\x30.25\x12\x16\n\x0b\x63ontext_pad\x18\n \x01(\r:\x01\x30\x12\x17\n\tcrop_mode\x18\x0b \x01(\t:\x04warp\x12\x1b\n\x0c\x63\x61\x63he_images\x18\x0c \x01(\x08:\x05\x66\x61lse\x12\x15\n\x0broot_folder\x18\r \x01(\t:\x00\"\xeb\x01\n\x0cSPPParameter\x12\x16\n\x0epyramid_height\x18\x01 \x01(\r\x12\x31\n\x04pool\x18\x02 \x01(\x0e\x32\x1e.caffe.SPPParameter.PoolMethod:\x03MAX\x12\x33\n\x06\x65ngine\x18\x06 \x01(\x0e\x32\x1a.caffe.SPPParameter.Engine:\x07\x44\x45\x46\x41ULT\".\n\nPoolMethod\x12\x07\n\x03MAX\x10\x00\x12\x07\n\x03\x41VE\x10\x01\x12\x0e\n\nSTOCHASTIC\x10\x02\"+\n\x06\x45ngine\x12\x0b\n\x07\x44\x45\x46\x41ULT\x10\x00\x12\t\n\x05\x43\x41\x46\x46\x45\x10\x01\x12\t\n\x05\x43UDNN\x10\x02\"\xe0\x13\n\x10V1LayerParameter\x12\x0e\n\x06\x62ottom\x18\x02 \x03(\t\x12\x0b\n\x03top\x18\x03 \x03(\t\x12\x0c\n\x04name\x18\x04 \x01(\t\x12$\n\x07include\x18 \x03(\x0b\x32\x13.caffe.NetStateRule\x12$\n\x07\x65xclude\x18! \x03(\x0b\x32\x13.caffe.NetStateRule\x12/\n\x04type\x18\x05 \x01(\x0e\x32!.caffe.V1LayerParameter.LayerType\x12\x1f\n\x05\x62lobs\x18\x06 \x03(\x0b\x32\x10.caffe.BlobProto\x12\x0e\n\x05param\x18\xe9\x07 \x03(\t\x12>\n\x0f\x62lob_share_mode\x18\xea\x07 \x03(\x0e\x32$.caffe.V1LayerParameter.DimCheckMode\x12\x10\n\x08\x62lobs_lr\x18\x07 \x03(\x02\x12\x14\n\x0cweight_decay\x18\x08 \x03(\x02\x12\x13\n\x0bloss_weight\x18# \x03(\x02\x12\x30\n\x0e\x61\x63\x63uracy_param\x18\x1b \x01(\x0b\x32\x18.caffe.AccuracyParameter\x12,\n\x0c\x61rgmax_param\x18\x17 \x01(\x0b\x32\x16.caffe.ArgMaxParameter\x12,\n\x0c\x63oncat_param\x18\t \x01(\x0b\x32\x16.caffe.ConcatParameter\x12?\n\x16\x63ontrastive_loss_param\x18( \x01(\x0b\x32\x1f.caffe.ContrastiveLossParameter\x12\x36\n\x11\x63onvolution_param\x18\n \x01(\x0b\x32\x1b.caffe.ConvolutionParameter\x12(\n\ndata_param\x18\x0b \x01(\x0b\x32\x14.caffe.DataParameter\x12.\n\rdropout_param\x18\x0c \x01(\x0b\x32\x17.caffe.DropoutParameter\x12\x33\n\x10\x64ummy_data_param\x18\x1a \x01(\x0b\x32\x19.caffe.DummyDataParameter\x12.\n\reltwise_param\x18\x18 \x01(\x0b\x32\x17.caffe.EltwiseParameter\x12&\n\texp_param\x18) \x01(\x0b\x32\x13.caffe.ExpParameter\x12\x31\n\x0fhdf5_data_param\x18\r \x01(\x0b\x32\x18.caffe.HDF5DataParameter\x12\x35\n\x11hdf5_output_param\x18\x0e \x01(\x0b\x32\x1a.caffe.HDF5OutputParameter\x12\x33\n\x10hinge_loss_param\x18\x1d \x01(\x0b\x32\x19.caffe.HingeLossParameter\x12\x33\n\x10image_data_param\x18\x0f \x01(\x0b\x32\x19.caffe.ImageDataParameter\x12\x39\n\x13infogain_loss_param\x18\x10 \x01(\x0b\x32\x1c.caffe.InfogainLossParameter\x12\x39\n\x13inner_product_param\x18\x11 \x01(\x0b\x32\x1c.caffe.InnerProductParameter\x12&\n\tlrn_param\x18\x12 \x01(\x0b\x32\x13.caffe.LRNParameter\x12\x35\n\x11memory_data_param\x18\x16 \x01(\x0b\x32\x1a.caffe.MemoryDataParameter\x12&\n\tmvn_param\x18\" \x01(\x0b\x32\x13.caffe.MVNParameter\x12.\n\rpooling_param\x18\x13 \x01(\x0b\x32\x17.caffe.PoolingParameter\x12*\n\x0bpower_param\x18\x15 \x01(\x0b\x32\x15.caffe.PowerParameter\x12(\n\nrelu_param\x18\x1e \x01(\x0b\x32\x14.caffe.ReLUParameter\x12.\n\rsigmoid_param\x18& \x01(\x0b\x32\x17.caffe.SigmoidParameter\x12.\n\rsoftmax_param\x18\' \x01(\x0b\x32\x17.caffe.SoftmaxParameter\x12*\n\x0bslice_param\x18\x1f \x01(\x0b\x32\x15.caffe.SliceParameter\x12(\n\ntanh_param\x18% \x01(\x0b\x32\x14.caffe.TanHParameter\x12\x32\n\x0fthreshold_param\x18\x19 \x01(\x0b\x32\x19.caffe.ThresholdParameter\x12\x35\n\x11window_data_param\x18\x14 \x01(\x0b\x32\x1a.caffe.WindowDataParameter\x12\x37\n\x0ftransform_param\x18$ \x01(\x0b\x32\x1e.caffe.TransformationParameter\x12(\n\nloss_param\x18* \x01(\x0b\x32\x14.caffe.LossParameter\x12&\n\x05layer\x18\x01 \x01(\x0b\x32\x17.caffe.V0LayerParameter\"\xd8\x04\n\tLayerType\x12\x08\n\x04NONE\x10\x00\x12\n\n\x06\x41\x42SVAL\x10#\x12\x0c\n\x08\x41\x43\x43URACY\x10\x01\x12\n\n\x06\x41RGMAX\x10\x1e\x12\x08\n\x04\x42NLL\x10\x02\x12\n\n\x06\x43ONCAT\x10\x03\x12\x14\n\x10\x43ONTRASTIVE_LOSS\x10%\x12\x0f\n\x0b\x43ONVOLUTION\x10\x04\x12\x08\n\x04\x44\x41TA\x10\x05\x12\x11\n\rDECONVOLUTION\x10\'\x12\x0b\n\x07\x44ROPOUT\x10\x06\x12\x0e\n\nDUMMY_DATA\x10 \x12\x12\n\x0e\x45UCLIDEAN_LOSS\x10\x07\x12\x0b\n\x07\x45LTWISE\x10\x19\x12\x07\n\x03\x45XP\x10&\x12\x0b\n\x07\x46LATTEN\x10\x08\x12\r\n\tHDF5_DATA\x10\t\x12\x0f\n\x0bHDF5_OUTPUT\x10\n\x12\x0e\n\nHINGE_LOSS\x10\x1c\x12\n\n\x06IM2COL\x10\x0b\x12\x0e\n\nIMAGE_DATA\x10\x0c\x12\x11\n\rINFOGAIN_LOSS\x10\r\x12\x11\n\rINNER_PRODUCT\x10\x0e\x12\x07\n\x03LRN\x10\x0f\x12\x0f\n\x0bMEMORY_DATA\x10\x1d\x12\x1d\n\x19MULTINOMIAL_LOGISTIC_LOSS\x10\x10\x12\x07\n\x03MVN\x10\"\x12\x0b\n\x07POOLING\x10\x11\x12\t\n\x05POWER\x10\x1a\x12\x08\n\x04RELU\x10\x12\x12\x0b\n\x07SIGMOID\x10\x13\x12\x1e\n\x1aSIGMOID_CROSS_ENTROPY_LOSS\x10\x1b\x12\x0b\n\x07SILENCE\x10$\x12\x0b\n\x07SOFTMAX\x10\x14\x12\x10\n\x0cSOFTMAX_LOSS\x10\x15\x12\t\n\x05SPLIT\x10\x16\x12\t\n\x05SLICE\x10!\x12\x08\n\x04TANH\x10\x17\x12\x0f\n\x0bWINDOW_DATA\x10\x18\x12\r\n\tTHRESHOLD\x10\x1f\"*\n\x0c\x44imCheckMode\x12\n\n\x06STRICT\x10\x00\x12\x0e\n\nPERMISSIVE\x10\x01\"\xfd\x07\n\x10V0LayerParameter\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0c\n\x04type\x18\x02 \x01(\t\x12\x12\n\nnum_output\x18\x03 \x01(\r\x12\x16\n\x08\x62iasterm\x18\x04 \x01(\x08:\x04true\x12-\n\rweight_filler\x18\x05 \x01(\x0b\x32\x16.caffe.FillerParameter\x12+\n\x0b\x62ias_filler\x18\x06 \x01(\x0b\x32\x16.caffe.FillerParameter\x12\x0e\n\x03pad\x18\x07 \x01(\r:\x01\x30\x12\x12\n\nkernelsize\x18\x08 \x01(\r\x12\x10\n\x05group\x18\t \x01(\r:\x01\x31\x12\x11\n\x06stride\x18\n \x01(\r:\x01\x31\x12\x35\n\x04pool\x18\x0b \x01(\x0e\x32\".caffe.V0LayerParameter.PoolMethod:\x03MAX\x12\x1a\n\rdropout_ratio\x18\x0c \x01(\x02:\x03\x30.5\x12\x15\n\nlocal_size\x18\r \x01(\r:\x01\x35\x12\x10\n\x05\x61lpha\x18\x0e \x01(\x02:\x01\x31\x12\x12\n\x04\x62\x65ta\x18\x0f \x01(\x02:\x04\x30.75\x12\x0c\n\x01k\x18\x16 \x01(\x02:\x01\x31\x12\x0e\n\x06source\x18\x10 \x01(\t\x12\x10\n\x05scale\x18\x11 \x01(\x02:\x01\x31\x12\x10\n\x08meanfile\x18\x12 \x01(\t\x12\x11\n\tbatchsize\x18\x13 \x01(\r\x12\x13\n\x08\x63ropsize\x18\x14 \x01(\r:\x01\x30\x12\x15\n\x06mirror\x18\x15 \x01(\x08:\x05\x66\x61lse\x12\x1f\n\x05\x62lobs\x18\x32 \x03(\x0b\x32\x10.caffe.BlobProto\x12\x10\n\x08\x62lobs_lr\x18\x33 \x03(\x02\x12\x14\n\x0cweight_decay\x18\x34 \x03(\x02\x12\x14\n\trand_skip\x18\x35 \x01(\r:\x01\x30\x12\x1d\n\x10\x64\x65t_fg_threshold\x18\x36 \x01(\x02:\x03\x30.5\x12\x1d\n\x10\x64\x65t_bg_threshold\x18\x37 \x01(\x02:\x03\x30.5\x12\x1d\n\x0f\x64\x65t_fg_fraction\x18\x38 \x01(\x02:\x04\x30.25\x12\x1a\n\x0f\x64\x65t_context_pad\x18: \x01(\r:\x01\x30\x12\x1b\n\rdet_crop_mode\x18; \x01(\t:\x04warp\x12\x12\n\x07new_num\x18< \x01(\x05:\x01\x30\x12\x17\n\x0cnew_channels\x18= \x01(\x05:\x01\x30\x12\x15\n\nnew_height\x18> \x01(\x05:\x01\x30\x12\x14\n\tnew_width\x18? \x01(\x05:\x01\x30\x12\x1d\n\x0eshuffle_images\x18@ \x01(\x08:\x05\x66\x61lse\x12\x15\n\nconcat_dim\x18\x41 \x01(\r:\x01\x31\x12\x36\n\x11hdf5_output_param\x18\xe9\x07 \x01(\x0b\x32\x1a.caffe.HDF5OutputParameter\".\n\nPoolMethod\x12\x07\n\x03MAX\x10\x00\x12\x07\n\x03\x41VE\x10\x01\x12\x0e\n\nSTOCHASTIC\x10\x02\"W\n\x0ePReLUParameter\x12&\n\x06\x66iller\x18\x01 \x01(\x0b\x32\x16.caffe.FillerParameter\x12\x1d\n\x0e\x63hannel_shared\x18\x02 \x01(\x08:\x05\x66\x61lse\")\n\x15SmoothL1LossParameter\x12\x10\n\x05sigma\x18\x01 \x01(\x02:\x01\x31\"H\n\x0cMPIParameter\x12\x0f\n\x04root\x18\x01 \x01(\r:\x01\x30\x12\x12\n\x07\x63omm_id\x18\x02 \x01(\x04:\x01\x30\x12\x13\n\x08group_id\x18\x03 \x01(\x04:\x01\x30\"!\n\x10PermuteParameter\x12\r\n\x05order\x18\x01 \x03(\r\"\x93\x01\n\x12NormalizeParameter\x12\x1c\n\x0e\x61\x63ross_spatial\x18\x01 \x01(\x08:\x04true\x12,\n\x0cscale_filler\x18\x02 \x01(\x0b\x32\x16.caffe.FillerParameter\x12\x1c\n\x0e\x63hannel_shared\x18\x03 \x01(\x08:\x04true\x12\x13\n\x03\x65ps\x18\x04 \x01(\x02:\x06\x31\x65-010\"_\n\x11ParallelParameter\x12\x16\n\x07shuffle\x18\x01 \x01(\x08:\x05\x66\x61lse\x12\x18\n\tnode_step\x18\x02 \x01(\x08:\x05\x66\x61lse\x12\x18\n\tpartition\x18\x03 \x01(\x08:\x05\x66\x61lse\"R\n\x0fResizeParameter\x12\x1f\n\x05shape\x18\x01 \x01(\x0b\x32\x10.caffe.BlobShape\x12\x0e\n\x02\x66x\x18\x02 \x01(\x02:\x02-1\x12\x0e\n\x02\x66y\x18\x03 \x01(\x02:\x02-1\"\'\n\x13\x45xpandDimsParameter\x12\x10\n\x04\x61xis\x18\x01 \x01(\x05:\x02-1\"\xc8\x01\n\x11ProposalParameter\x12\x17\n\x0b\x66\x65\x61t_stride\x18\x01 \x01(\r:\x02\x31\x36\x12\x15\n\tbase_size\x18\x02 \x01(\r:\x02\x31\x36\x12\x14\n\x08min_size\x18\x03 \x01(\r:\x02\x31\x36\x12\r\n\x05ratio\x18\x04 \x03(\x02\x12\r\n\x05scale\x18\x05 \x03(\x02\x12\x1a\n\x0cpre_nms_topn\x18\x06 \x01(\r:\x04\x36\x30\x30\x30\x12\x1a\n\rpost_nms_topn\x18\x07 \x01(\r:\x03\x33\x30\x30\x12\x17\n\nnms_thresh\x18\x08 \x01(\x02:\x03\x30.7\"\xa6\x01\n\x14\x42\x61tchRenormParameter\x12\x18\n\x10use_global_stats\x18\x01 \x01(\x08\x12$\n\x17moving_average_fraction\x18\x02 \x01(\x02:\x03\x30.9\x12\x12\n\x03\x65ps\x18\x03 \x01(\x02:\x05\x30.001\x12\x10\n\x05r_max\x18\x04 \x01(\x02:\x01\x33\x12\x10\n\x05\x64_max\x18\x05 \x01(\x02:\x01\x35\x12\x16\n\x07t_delta\x18\x06 \x01(\x02:\x05\x30.001\"?\n\x14\x44\x65nseConcatParameter\x12\x0f\n\x04\x61xis\x18\x01 \x01(\x05:\x01\x31\x12\x16\n\x0bgrowth_rate\x18\x02 \x01(\x05:\x01\x30\"c\n\x12\x46ocalLossParameter\x12\x12\n\x05\x61lpha\x18\x01 \x01(\x02:\x03\x30.5\x12\x10\n\x05gamma\x18\x02 \x01(\x02:\x01\x30\x12\x13\n\x03\x65ps\x18\x03 \x01(\x02:\x06\x31\x65-010\x12\x12\n\x06neg_id\x18\x04 \x01(\x05:\x02-1\"\"\n\x0fGatherParameter\x12\x0f\n\x04\x61xis\x18\x01 \x01(\x05:\x01\x30*\x1c\n\x05Phase\x12\t\n\x05TRAIN\x10\x00\x12\x08\n\x04TEST\x10\x01')
)
_sym_db.RegisterFileDescriptor(DESCRIPTOR)
......@@ -40,8 +40,8 @@ _PHASE = _descriptor.EnumDescriptor(
],
containing_type=None,
options=None,
serialized_start=17308,
serialized_end=17336,
serialized_start=17391,
serialized_end=17419,
)
_sym_db.RegisterEnumDescriptor(_PHASE)
......@@ -209,8 +209,8 @@ _LOSSPARAMETER_NORMALIZATIONMODE = _descriptor.EnumDescriptor(
],
containing_type=None,
options=None,
serialized_start=6478,
serialized_end=6554,
serialized_start=6525,
serialized_end=6601,
)
_sym_db.RegisterEnumDescriptor(_LOSSPARAMETER_NORMALIZATIONMODE)
......@@ -235,8 +235,8 @@ _CONVOLUTIONPARAMETER_ENGINE = _descriptor.EnumDescriptor(
],
containing_type=None,
options=None,
serialized_start=7517,
serialized_end=7560,
serialized_start=7564,
serialized_end=7607,
)
_sym_db.RegisterEnumDescriptor(_CONVOLUTIONPARAMETER_ENGINE)
......@@ -257,8 +257,8 @@ _DATAPARAMETER_DB = _descriptor.EnumDescriptor(
],
containing_type=None,
options=None,
serialized_start=7878,
serialized_end=7905,
serialized_start=7925,
serialized_end=7952,
)
_sym_db.RegisterEnumDescriptor(_DATAPARAMETER_DB)
......@@ -283,8 +283,8 @@ _ELTWISEPARAMETER_ELTWISEOP = _descriptor.EnumDescriptor(
],
containing_type=None,
options=None,
serialized_start=8272,
serialized_end=8311,
serialized_start=8319,
serialized_end=8358,
)
_sym_db.RegisterEnumDescriptor(_ELTWISEPARAMETER_ELTWISEOP)
......@@ -305,8 +305,8 @@ _HINGELOSSPARAMETER_NORM = _descriptor.EnumDescriptor(
],
containing_type=None,
options=None,
serialized_start=8846,
serialized_end=8868,
serialized_start=8893,
serialized_end=8915,
)
_sym_db.RegisterEnumDescriptor(_HINGELOSSPARAMETER_NORM)
......@@ -327,8 +327,8 @@ _LRNPARAMETER_NORMREGION = _descriptor.EnumDescriptor(
],
containing_type=None,
options=None,
serialized_start=9735,
serialized_end=9788,
serialized_start=9782,
serialized_end=9835,
)
_sym_db.RegisterEnumDescriptor(_LRNPARAMETER_NORMREGION)
......@@ -353,8 +353,8 @@ _LRNPARAMETER_ENGINE = _descriptor.EnumDescriptor(
],
containing_type=None,
options=None,
serialized_start=7517,
serialized_end=7560,
serialized_start=7564,
serialized_end=7607,
)
_sym_db.RegisterEnumDescriptor(_LRNPARAMETER_ENGINE)
......@@ -375,8 +375,8 @@ _MEMORYDATAPARAMETER_DATATYPE = _descriptor.EnumDescriptor(
],
containing_type=None,
options=None,
serialized_start=9989,
serialized_end=10025,
serialized_start=10036,
serialized_end=10072,
)
_sym_db.RegisterEnumDescriptor(_MEMORYDATAPARAMETER_DATATYPE)
......@@ -401,8 +401,8 @@ _POOLINGPARAMETER_POOLMETHOD = _descriptor.EnumDescriptor(
],
containing_type=None,
options=None,
serialized_start=10513,
serialized_end=10559,
serialized_start=10560,
serialized_end=10606,
)
_sym_db.RegisterEnumDescriptor(_POOLINGPARAMETER_POOLMETHOD)
......@@ -427,8 +427,8 @@ _POOLINGPARAMETER_ENGINE = _descriptor.EnumDescriptor(
],
containing_type=None,
options=None,
serialized_start=7517,
serialized_end=7560,
serialized_start=7564,
serialized_end=7607,
)
_sym_db.RegisterEnumDescriptor(_POOLINGPARAMETER_ENGINE)
......@@ -457,8 +457,8 @@ _REDUCTIONPARAMETER_REDUCTIONOP = _descriptor.EnumDescriptor(
],
containing_type=None,
options=None,
serialized_start=10995,
serialized_end=11048,
serialized_start=11042,
serialized_end=11095,
)
_sym_db.RegisterEnumDescriptor(_REDUCTIONPARAMETER_REDUCTIONOP)
......@@ -483,8 +483,8 @@ _RELUPARAMETER_ENGINE = _descriptor.EnumDescriptor(
],
containing_type=None,
options=None,
serialized_start=7517,
serialized_end=7560,
serialized_start=7564,
serialized_end=7607,
)
_sym_db.RegisterEnumDescriptor(_RELUPARAMETER_ENGINE)
......@@ -509,8 +509,8 @@ _SIGMOIDPARAMETER_ENGINE = _descriptor.EnumDescriptor(
],
containing_type=None,
options=None,
serialized_start=7517,
serialized_end=7560,
serialized_start=7564,
serialized_end=7607,
)
_sym_db.RegisterEnumDescriptor(_SIGMOIDPARAMETER_ENGINE)
......@@ -535,8 +535,8 @@ _SOFTMAXPARAMETER_ENGINE = _descriptor.EnumDescriptor(
],
containing_type=None,
options=None,
serialized_start=7517,
serialized_end=7560,
serialized_start=7564,
serialized_end=7607,
)
_sym_db.RegisterEnumDescriptor(_SOFTMAXPARAMETER_ENGINE)
......@@ -561,8 +561,8 @@ _TANHPARAMETER_ENGINE = _descriptor.EnumDescriptor(
],
containing_type=None,
options=None,
serialized_start=7517,
serialized_end=7560,
serialized_start=7564,
serialized_end=7607,
)
_sym_db.RegisterEnumDescriptor(_TANHPARAMETER_ENGINE)
......@@ -587,8 +587,8 @@ _SPPPARAMETER_POOLMETHOD = _descriptor.EnumDescriptor(
],
containing_type=None,
options=None,
serialized_start=10513,
serialized_end=10559,
serialized_start=10560,
serialized_end=10606,
)
_sym_db.RegisterEnumDescriptor(_SPPPARAMETER_POOLMETHOD)
......@@ -613,8 +613,8 @@ _SPPPARAMETER_ENGINE = _descriptor.EnumDescriptor(
],
containing_type=None,
options=None,
serialized_start=7517,
serialized_end=7560,
serialized_start=7564,
serialized_end=7607,
)
_sym_db.RegisterEnumDescriptor(_SPPPARAMETER_ENGINE)
......@@ -787,8 +787,8 @@ _V1LAYERPARAMETER_LAYERTYPE = _descriptor.EnumDescriptor(
],
containing_type=None,
options=None,
serialized_start=14487,
serialized_end=15087,
serialized_start=14534,
serialized_end=15134,
)
_sym_db.RegisterEnumDescriptor(_V1LAYERPARAMETER_LAYERTYPE)
......@@ -835,8 +835,8 @@ _V0LAYERPARAMETER_POOLMETHOD = _descriptor.EnumDescriptor(
],
containing_type=None,
options=None,
serialized_start=10513,
serialized_end=10559,
serialized_start=10560,
serialized_end=10606,
)
_sym_db.RegisterEnumDescriptor(_V0LAYERPARAMETER_POOLMETHOD)
......@@ -2254,6 +2254,13 @@ _LAYERPARAMETER = _descriptor.Descriptor(
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='gather_param', full_name='caffe.LayerParameter.gather_param', index=70,
number=165, type=11, cpp_type=10, label=1,
has_default_value=False, default_value=None,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
],
extensions=[
],
......@@ -2266,7 +2273,7 @@ _LAYERPARAMETER = _descriptor.Descriptor(
oneofs=[
],
serialized_start=2834,
serialized_end=6008,
serialized_end=6055,
)
......@@ -2365,8 +2372,8 @@ _TRANSFORMATIONPARAMETER = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
serialized_start=6011,
serialized_end=6306,
serialized_start=6058,
serialized_end=6353,
)
......@@ -2395,8 +2402,8 @@ _LOSSPARAMETER_EXPANDDIMSPARAMETER = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
serialized_start=6437,
serialized_end=6476,
serialized_start=6484,
serialized_end=6523,
)
_LOSSPARAMETER = _descriptor.Descriptor(
......@@ -2439,8 +2446,8 @@ _LOSSPARAMETER = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
serialized_start=6309,
serialized_end=6554,
serialized_start=6356,
serialized_end=6601,
)
......@@ -2483,8 +2490,8 @@ _ACCURACYPARAMETER = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
serialized_start=6556,
serialized_end=6632,
serialized_start=6603,
serialized_end=6679,
)
......@@ -2527,8 +2534,8 @@ _ARGMAXPARAMETER = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
serialized_start=6634,
serialized_end=6711,
serialized_start=6681,
serialized_end=6758,
)
......@@ -2564,8 +2571,8 @@ _CONCATPARAMETER = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
serialized_start=6713,
serialized_end=6770,
serialized_start=6760,
serialized_end=6817,
)
......@@ -2608,8 +2615,8 @@ _BATCHNORMPARAMETER = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
serialized_start=6772,
serialized_end=6876,
serialized_start=6819,
serialized_end=6923,
)
......@@ -2652,8 +2659,8 @@ _BIASPARAMETER = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
serialized_start=6878,
serialized_end=6971,
serialized_start=6925,
serialized_end=7018,
)
......@@ -2689,8 +2696,8 @@ _CONTRASTIVELOSSPARAMETER = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
serialized_start=6973,
serialized_end=7049,
serialized_start=7020,
serialized_end=7096,
)
......@@ -2839,8 +2846,8 @@ _CONVOLUTIONPARAMETER = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
serialized_start=7052,
serialized_end=7560,
serialized_start=7099,
serialized_end=7607,
)
......@@ -2876,8 +2883,8 @@ _CROPPARAMETER = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
serialized_start=7562,
serialized_end=7610,
serialized_start=7609,
serialized_end=7657,
)
......@@ -2970,8 +2977,8 @@ _DATAPARAMETER = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
serialized_start=7613,
serialized_end=7905,
serialized_start=7660,
serialized_end=7952,
)
......@@ -3007,8 +3014,8 @@ _DROPOUTPARAMETER = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
serialized_start=7907,
serialized_end=7980,
serialized_start=7954,
serialized_end=8027,
)
......@@ -3072,8 +3079,8 @@ _DUMMYDATAPARAMETER = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
serialized_start=7983,
serialized_end=8143,
serialized_start=8030,
serialized_end=8190,
)
......@@ -3117,8 +3124,8 @@ _ELTWISEPARAMETER = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
serialized_start=8146,
serialized_end=8311,
serialized_start=8193,
serialized_end=8358,
)
......@@ -3147,8 +3154,8 @@ _ELUPARAMETER = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
serialized_start=8313,
serialized_end=8345,
serialized_start=8360,
serialized_end=8392,
)
......@@ -3205,8 +3212,8 @@ _EMBEDPARAMETER = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
serialized_start=8348,
serialized_end=8520,
serialized_start=8395,
serialized_end=8567,
)
......@@ -3249,8 +3256,8 @@ _EXPPARAMETER = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
serialized_start=8522,
serialized_end=8590,
serialized_start=8569,
serialized_end=8637,
)
......@@ -3286,8 +3293,8 @@ _FLATTENPARAMETER = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
serialized_start=8592,
serialized_end=8649,
serialized_start=8639,
serialized_end=8696,
)
......@@ -3330,8 +3337,8 @@ _HDF5DATAPARAMETER = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
serialized_start=8651,
serialized_end=8730,
serialized_start=8698,
serialized_end=8777,
)
......@@ -3360,8 +3367,8 @@ _HDF5OUTPUTPARAMETER = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
serialized_start=8732,
serialized_end=8772,
serialized_start=8779,
serialized_end=8819,
)
......@@ -3391,8 +3398,8 @@ _HINGELOSSPARAMETER = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
serialized_start=8774,
serialized_end=8868,
serialized_start=8821,
serialized_end=8915,
)
......@@ -3498,8 +3505,8 @@ _IMAGEDATAPARAMETER = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
serialized_start=8871,
serialized_end=9150,
serialized_start=8918,
serialized_end=9197,
)
......@@ -3528,8 +3535,8 @@ _INFOGAINLOSSPARAMETER = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
serialized_start=9152,
serialized_end=9191,
serialized_start=9199,
serialized_end=9238,
)
......@@ -3593,8 +3600,8 @@ _INNERPRODUCTPARAMETER = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
serialized_start=9194,
serialized_end=9397,
serialized_start=9241,
serialized_end=9444,
)
......@@ -3623,8 +3630,8 @@ _INPUTPARAMETER = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
serialized_start=9399,
serialized_end=9448,
serialized_start=9446,
serialized_end=9495,
)
......@@ -3667,8 +3674,8 @@ _LOGPARAMETER = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
serialized_start=9450,
serialized_end=9518,
serialized_start=9497,
serialized_end=9565,
)
......@@ -3734,8 +3741,8 @@ _LRNPARAMETER = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
serialized_start=9521,
serialized_end=9833,
serialized_start=9568,
serialized_end=9880,
)
......@@ -3793,8 +3800,8 @@ _MEMORYDATAPARAMETER = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
serialized_start=9836,
serialized_end=10025,
serialized_start=9883,
serialized_end=10072,
)
......@@ -3837,8 +3844,8 @@ _MVNPARAMETER = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
serialized_start=10027,
serialized_end=10128,
serialized_start=10074,
serialized_end=10175,
)
......@@ -3867,8 +3874,8 @@ _PARAMETERPARAMETER = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
serialized_start=10130,
serialized_end=10183,
serialized_start=10177,
serialized_end=10230,
)
......@@ -3976,8 +3983,8 @@ _POOLINGPARAMETER = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
serialized_start=10186,
serialized_end=10604,
serialized_start=10233,
serialized_end=10651,
)
......@@ -4020,8 +4027,8 @@ _ROIPOOLINGPARAMETER = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
serialized_start=10606,
serialized_end=10695,
serialized_start=10653,
serialized_end=10742,
)
......@@ -4064,8 +4071,8 @@ _POWERPARAMETER = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
serialized_start=10697,
serialized_end=10767,
serialized_start=10744,
serialized_end=10814,
)
......@@ -4115,8 +4122,8 @@ _PYTHONPARAMETER = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
serialized_start=10769,
serialized_end=10872,
serialized_start=10816,
serialized_end=10919,
)
......@@ -4160,8 +4167,8 @@ _REDUCTIONPARAMETER = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
serialized_start=10875,
serialized_end=11048,
serialized_start=10922,
serialized_end=11095,
)
......@@ -4198,8 +4205,8 @@ _RELUPARAMETER = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
serialized_start=11051,
serialized_end=11192,
serialized_start=11098,
serialized_end=11239,
)
......@@ -4242,8 +4249,8 @@ _RESHAPEPARAMETER = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
serialized_start=11194,
serialized_end=11284,
serialized_start=11241,
serialized_end=11331,
)
......@@ -4300,8 +4307,8 @@ _SCALEPARAMETER = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
serialized_start=11287,
serialized_end=11452,
serialized_start=11334,
serialized_end=11499,
)
......@@ -4331,8 +4338,8 @@ _SIGMOIDPARAMETER = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
serialized_start=11454,
serialized_end=11574,
serialized_start=11501,
serialized_end=11621,
)
......@@ -4375,8 +4382,8 @@ _SLICEPARAMETER = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
serialized_start=11576,
serialized_end=11652,
serialized_start=11623,
serialized_end=11699,
)
......@@ -4413,8 +4420,8 @@ _SOFTMAXPARAMETER = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
serialized_start=11655,
serialized_end=11792,
serialized_start=11702,
serialized_end=11839,
)
......@@ -4444,8 +4451,8 @@ _TANHPARAMETER = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
serialized_start=11794,
serialized_end=11908,
serialized_start=11841,
serialized_end=11955,
)
......@@ -4488,8 +4495,8 @@ _TILEPARAMETER = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
serialized_start=11910,
serialized_end=11994,
serialized_start=11957,
serialized_end=12041,
)
......@@ -4518,8 +4525,8 @@ _THRESHOLDPARAMETER = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
serialized_start=11996,
serialized_end=12038,
serialized_start=12043,
serialized_end=12085,
)
......@@ -4632,8 +4639,8 @@ _WINDOWDATAPARAMETER = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
serialized_start=12041,
serialized_end=12362,
serialized_start=12088,
serialized_end=12409,
)
......@@ -4678,8 +4685,8 @@ _SPPPARAMETER = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
serialized_start=12365,
serialized_end=12600,
serialized_start=12412,
serialized_end=12647,
)
......@@ -5004,8 +5011,8 @@ _V1LAYERPARAMETER = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
serialized_start=12603,
serialized_end=15131,
serialized_start=12650,
serialized_end=15178,
)
......@@ -5294,8 +5301,8 @@ _V0LAYERPARAMETER = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
serialized_start=15134,
serialized_end=16155,
serialized_start=15181,
serialized_end=16202,
)
......@@ -5331,8 +5338,8 @@ _PRELUPARAMETER = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
serialized_start=16157,
serialized_end=16244,
serialized_start=16204,
serialized_end=16291,
)
......@@ -5361,8 +5368,8 @@ _SMOOTHL1LOSSPARAMETER = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
serialized_start=16246,
serialized_end=16287,
serialized_start=16293,
serialized_end=16334,
)
......@@ -5405,8 +5412,8 @@ _MPIPARAMETER = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
serialized_start=16289,
serialized_end=16361,
serialized_start=16336,
serialized_end=16408,
)
......@@ -5435,8 +5442,8 @@ _PERMUTEPARAMETER = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
serialized_start=16363,
serialized_end=16396,
serialized_start=16410,
serialized_end=16443,
)
......@@ -5486,8 +5493,8 @@ _NORMALIZEPARAMETER = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
serialized_start=16399,
serialized_end=16546,
serialized_start=16446,
serialized_end=16593,
)
......@@ -5530,8 +5537,8 @@ _PARALLELPARAMETER = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
serialized_start=16548,
serialized_end=16643,
serialized_start=16595,
serialized_end=16690,
)
......@@ -5574,8 +5581,8 @@ _RESIZEPARAMETER = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
serialized_start=16645,
serialized_end=16727,
serialized_start=16692,
serialized_end=16774,
)
......@@ -5604,8 +5611,8 @@ _EXPANDDIMSPARAMETER = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
serialized_start=6437,
serialized_end=6476,
serialized_start=6484,
serialized_end=6523,
)
......@@ -5683,8 +5690,8 @@ _PROPOSALPARAMETER = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
serialized_start=16771,
serialized_end=16971,
serialized_start=16818,
serialized_end=17018,
)
......@@ -5748,8 +5755,8 @@ _BATCHRENORMPARAMETER = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
serialized_start=16974,
serialized_end=17140,
serialized_start=17021,
serialized_end=17187,
)
......@@ -5785,8 +5792,8 @@ _DENSECONCATPARAMETER = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
serialized_start=17142,
serialized_end=17205,
serialized_start=17189,
serialized_end=17252,
)
......@@ -5836,8 +5843,38 @@ _FOCALLOSSPARAMETER = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
serialized_start=17207,
serialized_end=17306,
serialized_start=17254,
serialized_end=17353,
)
_GATHERPARAMETER = _descriptor.Descriptor(
name='GatherParameter',
full_name='caffe.GatherParameter',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='axis', full_name='caffe.GatherParameter.axis', index=0,
number=1, type=5, cpp_type=1, label=1,
has_default_value=True, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
],
extensions=[
],
nested_types=[],
enum_types=[
],
options=None,
is_extendable=False,
extension_ranges=[],
oneofs=[
],
serialized_start=17355,
serialized_end=17389,
)
_BLOBPROTO.fields_by_name['shape'].message_type = _BLOBSHAPE
......@@ -5927,6 +5964,7 @@ _LAYERPARAMETER.fields_by_name['proposal_param'].message_type = _PROPOSALPARAMET
_LAYERPARAMETER.fields_by_name['batch_renorm_param'].message_type = _BATCHRENORMPARAMETER
_LAYERPARAMETER.fields_by_name['dense_concat_param'].message_type = _DENSECONCATPARAMETER
_LAYERPARAMETER.fields_by_name['focal_loss_param'].message_type = _FOCALLOSSPARAMETER
_LAYERPARAMETER.fields_by_name['gather_param'].message_type = _GATHERPARAMETER
_LOSSPARAMETER_EXPANDDIMSPARAMETER.containing_type = _LOSSPARAMETER
_LOSSPARAMETER.fields_by_name['normalization'].enum_type = _LOSSPARAMETER_NORMALIZATIONMODE
_LOSSPARAMETER_NORMALIZATIONMODE.containing_type = _LOSSPARAMETER
......@@ -6096,6 +6134,7 @@ DESCRIPTOR.message_types_by_name['ProposalParameter'] = _PROPOSALPARAMETER
DESCRIPTOR.message_types_by_name['BatchRenormParameter'] = _BATCHRENORMPARAMETER
DESCRIPTOR.message_types_by_name['DenseConcatParameter'] = _DENSECONCATPARAMETER
DESCRIPTOR.message_types_by_name['FocalLossParameter'] = _FOCALLOSSPARAMETER
DESCRIPTOR.message_types_by_name['GatherParameter'] = _GATHERPARAMETER
DESCRIPTOR.enum_types_by_name['Phase'] = _PHASE
BlobShape = _reflection.GeneratedProtocolMessageType('BlobShape', (_message.Message,), dict(
......@@ -6610,6 +6649,13 @@ FocalLossParameter = _reflection.GeneratedProtocolMessageType('FocalLossParamete
))
_sym_db.RegisterMessage(FocalLossParameter)
GatherParameter = _reflection.GeneratedProtocolMessageType('GatherParameter', (_message.Message,), dict(
DESCRIPTOR = _GATHERPARAMETER,
__module__ = 'caffe_pb2'
# @@protoc_insertion_point(class_scope:caffe.GatherParameter)
))
_sym_db.RegisterMessage(GatherParameter)
_BLOBSHAPE.fields_by_name['dim'].has_options = True
_BLOBSHAPE.fields_by_name['dim']._options = _descriptor._ParseOptions(descriptor_pb2.FieldOptions(), _b('\020\001'))
......
......@@ -4,10 +4,11 @@
# Written by Ting Pan
# --------------------------------------------------------
import sys
import copy
from collections import OrderedDict
import numpy as np
import sys
from collections import OrderedDict
from six.moves import xrange
import dragon.core.mpi as mpi
import dragon.core.workspace as ws
......
......@@ -36,7 +36,7 @@ find_packages('dragon')
find_modules()
setup(name = 'dragon',
version='0.2.1.2',
version='0.2.1.3',
description = 'Dragon: A Computation Graph Virtual Machine Based Deep Learning Framework',
url='https://github.com/neopenx/Dragon',
author='Ting Pan',
......
......@@ -8,12 +8,15 @@ namespace dragon {
template <class Context> template <typename T>
void SigmoidCrossEntropyOp<Context>::RunWithType() {
auto* Xdata = input(0).template data<T, Context>();
auto* Pdata = prob->template mutable_data<T, Context>();
kernel::Sigmoid<T, Context>(prob->count(), Xdata, Pdata);
auto* Tdata = input(1).template data<T, Context>();
auto* Ldata = losses.template mutable_data<T, Context>();
kernel::SigmoidCrossEntropy<T, Context>(input(0).count(), Xdata, Tdata, Ldata);
auto* Vdata = valid.template mutable_data<T, Context>();
kernel::SigmoidCrossEntropy<T, Context>(input(0).count(),
Xdata,
Tdata,
Ldata,
Vdata);
if (normalization == "UNIT") {
output(0)->ReshapeLike(losses);
......@@ -22,7 +25,9 @@ void SigmoidCrossEntropyOp<Context>::RunWithType() {
}
T normalizer;
if (normalization == "BATCH_SIZE") normalizer = input(0).dim(0);
if (normalization == "VALID")
normalizer = math::ASum<T, Context>(valid.count(), Vdata);
else if (normalization == "BATCH_SIZE") normalizer = input(0).dim(0);
else if (normalization == "FULL") normalizer = input(0).count();
else if (normalization == "NONE") normalizer = 1;
T loss = math::ASum<T, Context>(losses.count(), Ldata);
......@@ -35,9 +40,8 @@ template <class Context>
void SigmoidCrossEntropyOp<Context>::RunOnDevice() {
CHECK_EQ(input(0).count(), input(1).count())
<< "\nNumber of predictions must match the number of labels.";
prob = ws()->CreateTensor("/mnt/" + anchor() + "/sigmoid_prob");
prob->ReshapeLike(input(0));
losses.ReshapeLike(input(0));
valid.ReshapeLike(input(0));
if (input(0).template IsType<float>()) RunWithType<float>();
else LOG(FATAL) << "Unsupported input types.";
......@@ -51,11 +55,16 @@ OPERATOR_SCHEMA(SigmoidCrossEntropy).NumInputs(2).NumOutputs(1);
template <class Context> template <typename T>
void SigmoidCrossEntropyGradientOp<Context>::RunWithType() {
auto* Pdata = prob->template data<T, Context>();
auto* Xdata = input(0).template data<T, Context>();
auto* Tdata = input(1).template data<T, Context>();
auto* Vdata = valid.template mutable_data<T, Context>();
auto* dXdata = output(0)->template mutable_data<T, Context>();
ctx().template Copy<T, Context, Context>(prob->count(), dXdata, Pdata);
math::Axpy<T, Context>(output(0)->count(), -1.0, Tdata, dXdata);
kernel::SigmoidCrossEntropyGrad<T, Context>(input(0).count(),
Xdata,
Tdata,
dXdata,
Vdata);
if (normalization == "UNIT") {
auto* dYdata = input(-1).template data<T, Context>();
......@@ -64,7 +73,8 @@ void SigmoidCrossEntropyGradientOp<Context>::RunWithType() {
}
T normalizer;
if (normalization == "BATCH_SIZE") normalizer = input(0).dim(0);
if (normalization == "VALID") normalizer = math::ASum<T, Context>(valid.count(), Vdata);
else if (normalization == "BATCH_SIZE") normalizer = input(0).dim(0);
else if (normalization == "FULL") normalizer = input(0).count();
else if (normalization == "NONE") normalizer = 1;
auto* dYdata = input(-1).template data<T, CPUContext>();
......@@ -73,8 +83,8 @@ void SigmoidCrossEntropyGradientOp<Context>::RunWithType() {
template <class Context>
void SigmoidCrossEntropyGradientOp<Context>::RunOnDevice() {
prob = ws()->GetTensor("/mnt/" + anchor() + "/sigmoid_prob");
output(0)->ReshapeLike(input(0));
valid.ReshapeLike(input(0));
if (input(0).template IsType<float>()) RunWithType<float>();
else LOG(FATAL) << "Unsupported input types.";
......
#include "operators/ndarray/at_op.h"
#include "operators/ndarray/gather_op.h"
#include "core/workspace.h"
#include "utils/math_functions.h"
#include "utils/op_kernel.h"
......@@ -6,12 +6,12 @@
namespace dragon {
template <class Context> template <typename T>
void AtOp<Context>::RunWithType() {
void GatherOp<Context>::RunWithType() {
auto* Xdata = input(0).template data<T, Context>();
auto* indices = input(1).template mutable_data<int, Context>();
auto* Ydata = output(0)->template mutable_data<T, Context>();
kernel::CanonicalAxis<int, Context>(input(1).count(), x_slice_dim, indices);
kernel::At<T, Context>(output(0)->count(), outer_dim, inner_dim,
kernel::Gather<T, Context>(output(0)->count(), outer_dim, inner_dim,
x_slice_dim, y_slice_dim,
indices,
Xdata,
......@@ -20,7 +20,7 @@ void AtOp<Context>::RunWithType() {
}
template <class Context>
void AtOp<Context>::RunOnDevice() {
void GatherOp<Context>::RunOnDevice() {
output_dims = input(0).dims();
x_slice_dim = input(0).dim(axis);
output_dims[axis] = y_slice_dim = input(1).count();
......@@ -35,19 +35,19 @@ void AtOp<Context>::RunOnDevice() {
else LOG(FATAL) << "Unsupported input types.";
}
DEPLOY_CPU(At);
DEPLOY_CPU(Gather);
#ifdef WITH_CUDA
DEPLOY_CUDA(At);
DEPLOY_CUDA(Gather);
#endif
OPERATOR_SCHEMA(At).NumInputs(2).NumOutputs(1);
OPERATOR_SCHEMA(Gather).NumInputs(2).NumOutputs(1);
template <class Context> template <typename T>
void AtGradientOp<Context>::RunWithType() {
void GatherGradientOp<Context>::RunWithType() {
auto* indices = input(1).template data<int, Context>();
auto* dYdata = input(-1).template data<T, Context>();
auto* dXdata = output(0)->template mutable_data<T, Context>();
if (!acc_grad) math::Set<T, Context>(output(0)->count(), 0, dXdata);
kernel::AtGrad<T, Context>(input(-1).count(), outer_dim, inner_dim,
kernel::GatherGrad<T, Context>(input(-1).count(), outer_dim, inner_dim,
x_slice_dim, y_slice_dim,
indices,
dYdata,
......@@ -55,7 +55,7 @@ void AtGradientOp<Context>::RunWithType() {
}
template <class Context>
void AtGradientOp<Context>::RunOnDevice() {
void GatherGradientOp<Context>::RunOnDevice() {
x_slice_dim = input(0).dim(axis);
y_slice_dim = input(1).count();
outer_dim = input(0).count(0, axis);
......@@ -68,21 +68,21 @@ void AtGradientOp<Context>::RunOnDevice() {
else LOG(FATAL) << "Unsupported input types.";
}
DEPLOY_CPU(AtGradient);
DEPLOY_CPU(GatherGradient);
#ifdef WITH_CUDA
DEPLOY_CUDA(AtGradient);
DEPLOY_CUDA(GatherGradient);
#endif
OPERATOR_SCHEMA(AtGradient).NumInputs(3).NumOutputs(1);
OPERATOR_SCHEMA(GatherGradient).NumInputs(3).NumOutputs(1);
class GetAtGradient final : public GradientMakerBase {
class GetGatherGradient final : public GradientMakerBase {
public:
GRADIENT_MAKER_CTOR(GetAtGradient);
GRADIENT_MAKER_CTOR(GetGatherGradient);
vector<OperatorDef> MakeDefs() override {
return SingleDef(def.type() + "Gradient", "",
vector<string> {I(0), I(1), GO(0)},
vector<string> {GI(0)});
}
};
REGISTER_GRADIENT(At, GetAtGradient);
REGISTER_GRADIENT(Gather, GetGatherGradient);
} // namespace dragon
\ No newline at end of file
......@@ -14,9 +14,8 @@ void RandomPickOp<Context>::RunWithType() {
auto* Xdata = input(0).template data<T, Context>();
indices = pick_indices->template mutable_data<int, Context>();
auto* Ydata = output(0)->template mutable_data<T, Context>();
kernel::At<T, Context>(output(0)->count(), outer_dim, inner_dim,
x_slice_dim,
y_slice_dim,
kernel::Gather<T, Context>(output(0)->count(), outer_dim, inner_dim,
x_slice_dim, y_slice_dim,
indices,
Xdata,
Ydata,
......@@ -57,9 +56,8 @@ void RandomPickGradientOp<Context>::RunWithType() {
auto* dYdata = input(-1).template data<T, Context>();
auto* dXdata = output(0)->template mutable_data<T, Context>();
math::Set<T, Context>(output(0)->count(), 0, dXdata);
kernel::AtGrad<T, Context>(input(-1).count(), outer_dim, inner_dim,
x_slice_dim,
y_slice_dim,
kernel::GatherGrad<T, Context>(input(-1).count(), outer_dim, inner_dim,
x_slice_dim, y_slice_dim,
indices,
dYdata,
dXdata);
......
......@@ -9,22 +9,17 @@ template <class Context> template <typename T>
void ROIAlignOp<Context>::RunWithType() {
kernel::ROIAlign<T, Context>(spatial_scale,
pool_h, pool_w,
sampling_ratio,
&input(0),
&input(1),
mask_h,
mask_w,
output(0));
}
template <class Context>
void ROIAlignOp<Context>::RunOnDevice() {
mask_h = ws()->CreateTensor("/mnt/" + anchor() + "/roi_align_mask_h");
mask_w = ws()->CreateTensor("/mnt/" + anchor() + "/roi_align_mask_w");
vector<TIndex> dims({input(1).dim(0), input(0).dim(1), pool_h, pool_w});
output(0)->Reshape(dims);
mask_h->Reshape(dims);
mask_w->Reshape(dims);
output(0)->Reshape(vector<TIndex>({ input(1).dim(0),
input(0).dim(1),
pool_h, pool_w }));
if (input(0).template IsType<float>()) return RunWithType<float>();
else LOG(FATAL) << "Unsupported input types.";
......@@ -40,18 +35,14 @@ template <class Context> template <typename T>
void ROIAlignGradientOp<Context>::RunWithType() {
kernel::ROIAlignGrad<T, Context>(spatial_scale,
pool_h, pool_w,
sampling_ratio,
&input(-1),
&input(1),
mask_h,
mask_w,
output(0));
}
template <class Context>
void ROIAlignGradientOp<Context>::RunOnDevice() {
mask_h = ws()->GetTensor("/mnt/" + anchor() + "/roi_align_mask_h");
mask_w = ws()->GetTensor("/mnt/" + anchor() + "/roi_align_mask_w");
output(0)->ReshapeLike(input(0));
if (input(0).template IsType<float>()) return RunWithType<float>();
......
......@@ -503,13 +503,39 @@ template<> void AbsGrad<float, CPUContext>(const int count, const float* dy, flo
template <> void SigmoidCrossEntropy<float, CPUContext>(const int count,
const float* x,
const float* target,
float* loss) {
float* loss,
float* valid) {
#ifdef WITH_OMP
#pragma omp parallel for num_threads(GET_OMP_THREADS(count))
#endif
for (int i = 0; i < count; ++i) {
if (target[i] < 0) {
loss[i] = 0.;
valid[i] = 0.;
} else {
loss[i] = std::log(1 + std::exp(x[i] - 2 * x[i] * (x[i] >= 0)))
+ x[i] * ((x[i] >= 0) - target[i]);
valid[i] = 1.;
}
}
}
template <> void SigmoidCrossEntropyGrad<float, CPUContext>(const int count,
const float* x,
const float* target,
float* dx,
float* valid) {
#ifdef WITH_OMP
#pragma omp parallel for num_threads(GET_OMP_THREADS(count))
#endif
for (int i = 0; i < count; ++i) {
if (target[i] < 0) {
dx[i] = 0.;
valid[i] = 0.;
} else {
dx[i] = 1. / (1. + expf(-x[i])) - target[i];
valid[i] = 1.;
}
}
}
......@@ -902,7 +928,7 @@ template<> void Argmin<float, CPUContext>(const int count,
}
}
/******************** ndarray.at ********************/
/******************** ndarray.gather ********************/
template <> void CanonicalAxis<int, CPUContext>(const int count, const int dim, int* y) {
#ifdef WITH_OMP
......@@ -912,7 +938,7 @@ template <> void CanonicalAxis<int, CPUContext>(const int count, const int dim,
}
template <typename T>
void _At(const int count,
void _Gather(const int count,
const int outer_dim,
const int inner_dim,
const int x_slice_dim,
......@@ -935,7 +961,7 @@ void _At(const int count,
}
}
template <> void At<float, CPUContext>(const int count,
template <> void Gather<float, CPUContext>(const int count,
const int outer_dim,
const int inner_dim,
const int x_slice_dim,
......@@ -944,13 +970,13 @@ template <> void At<float, CPUContext>(const int count,
const float* x,
float* y,
CPUContext* ctx) {
_At<float>(count, outer_dim, inner_dim,
_Gather<float>(count, outer_dim, inner_dim,
x_slice_dim, y_slice_dim,
indices, x, y, ctx);
}
template <> void At<int, CPUContext>(const int count,
template <> void Gather<int, CPUContext>(const int count,
const int outer_dim,
const int inner_dim,
const int x_slice_dim,
......@@ -959,13 +985,13 @@ template <> void At<int, CPUContext>(const int count,
const int* x,
int* y,
CPUContext* ctx) {
_At<int>(count, outer_dim, inner_dim,
_Gather<int>(count, outer_dim, inner_dim,
x_slice_dim, y_slice_dim,
indices, x, y, ctx);
}
template <typename T>
void _AtGrad(const int count,
void _GatherGrad(const int count,
const int outer_dim,
const int inner_dim,
const int x_slice_dim,
......@@ -988,7 +1014,7 @@ void _AtGrad(const int count,
}
}
template <> void AtGrad<float, CPUContext>(const int count,
template <> void GatherGrad<float, CPUContext>(const int count,
const int outer_dim,
const int inner_dim,
const int x_slice_dim,
......@@ -996,12 +1022,12 @@ template <> void AtGrad<float, CPUContext>(const int count,
const int* indices,
const float* dy,
float* dx) {
_AtGrad<float>(count, outer_dim, inner_dim,
_GatherGrad<float>(count, outer_dim, inner_dim,
x_slice_dim, y_slice_dim,
indices, dy, dx);
}
template <> void AtGrad<int, CPUContext>(const int count,
template <> void GatherGrad<int, CPUContext>(const int count,
const int outer_dim,
const int inner_dim,
const int x_slice_dim,
......@@ -1009,7 +1035,7 @@ template <> void AtGrad<int, CPUContext>(const int count,
const int* indices,
const int* dy,
int* dx) {
_AtGrad<int>(count, outer_dim, inner_dim,
_GatherGrad<int>(count, outer_dim, inner_dim,
x_slice_dim, y_slice_dim,
indices, dy, dx);
}
......@@ -2694,20 +2720,18 @@ template<> void ROIPoolingGrad<float, CPUContext>(const float spatial_scale,
template<> void ROIAlign<float, CPUContext>(const float spatial_scale,
const int pool_h, const int pool_w,
const int sampling_ratio,
Tensor* x,
Tensor* roi,
Tensor* mask_h,
Tensor* mask_w,
Tensor* rois,
Tensor* y) {
NOT_IMPLEMENTED;
}
template<> void ROIAlignGrad<float, CPUContext>(const float spatial_scale,
const int pool_h, const int pool_w,
const int sampling_ratio,
Tensor* dy,
Tensor* roi,
Tensor* mask_h,
Tensor* mask_w,
Tensor* rois,
Tensor* dx) {
NOT_IMPLEMENTED;
}
......
......@@ -927,22 +927,61 @@ template<> void AbsGrad<float, CUDAContext>(const int count, const float* dy, fl
template <typename T>
__global__ void _SigmoidCrossEntropy(const int count,
const T* x,
const T* targets,
T* loss) {
const T* target,
T* loss,
T* valid) {
CUDA_KERNEL_LOOP(idx, count) {
if (target[idx] < 0) {
loss[idx] = 0.;
valid[idx] = 0.;
} else {
loss[idx] = std::log(1 + std::exp(x[idx] - 2 * x[idx] * (x[idx] >= 0)))
+ x[idx] * ((x[idx] >= 0) - targets[idx]);
+ x[idx] * ((x[idx] >= 0) - target[idx]);
valid[idx] = 1.;
}
}
}
template <> void SigmoidCrossEntropy<float, CUDAContext>(const int count,
const float* x,
const float* targets,
float* loss) {
const float* target,
float* loss,
float* valid) {
_SigmoidCrossEntropy<float> << <GET_BLOCKS(count), CUDA_NUM_THREADS >> >(count,
x,
targets,
loss);
target,
loss,
valid);
CUDA_POST_KERNEL_CHECK;
}
template <typename T>
__global__ void _SigmoidCrossEntropyGrad(const int count,
const T* x,
const T* target,
T* dx,
T* valid) {
CUDA_KERNEL_LOOP(idx, count) {
if (target[idx] < 0) {
dx[idx] = 0.;
valid[idx] = 0.;
} else {
dx[idx] = 1. / (1. + expf(-x[idx])) - target[idx];
valid[idx] = 1.;
}
}
}
template <> void SigmoidCrossEntropyGrad<float, CUDAContext>(const int count,
const float* x,
const float* target,
float* dx,
float* valid) {
_SigmoidCrossEntropyGrad<float> << <GET_BLOCKS(count), CUDA_NUM_THREADS >> >(count,
x,
target,
dx,
valid);
CUDA_POST_KERNEL_CHECK;
}
......@@ -1565,7 +1604,7 @@ template<> void Argmin<float, CUDAContext>(const int count,
CUDA_POST_KERNEL_CHECK;
}
/******************** ndarray.at ********************/
/******************** ndarray.gather ********************/
template <typename T>
__global__ void _CanonicalAxis(const int count, const int dim, T* y) {
......@@ -1580,7 +1619,7 @@ template <> void CanonicalAxis<int, CUDAContext>(const int count, const int dim,
}
template <typename T>
__global__ void _At(const int count,
__global__ void _Gather(const int count,
const int outer_dim,
const int inner_dim,
const int x_slice_dim,
......@@ -1599,7 +1638,7 @@ __global__ void _At(const int count,
}
}
template <> void At<float, CUDAContext>(const int count,
template <> void Gather<float, CUDAContext>(const int count,
const int outer_dim,
const int inner_dim,
const int x_slice_dim,
......@@ -1608,14 +1647,14 @@ template <> void At<float, CUDAContext>(const int count,
const float* x,
float* y,
CUDAContext* context) {
_At<float> << <GET_BLOCKS(count), CUDA_NUM_THREADS >> >(count,
_Gather<float> << <GET_BLOCKS(count), CUDA_NUM_THREADS >> >(count,
outer_dim, inner_dim,
x_slice_dim, y_slice_dim,
indices, x, y);
CUDA_POST_KERNEL_CHECK;
}
template <> void At<int, CUDAContext>(const int count,
template <> void Gather<int, CUDAContext>(const int count,
const int outer_dim,
const int inner_dim,
const int x_slice_dim,
......@@ -1624,7 +1663,7 @@ template <> void At<int, CUDAContext>(const int count,
const int* x,
int* y,
CUDAContext* context) {
_At<int> << <GET_BLOCKS(count), CUDA_NUM_THREADS >> >(count,
_Gather<int> << <GET_BLOCKS(count), CUDA_NUM_THREADS >> >(count,
outer_dim, inner_dim,
x_slice_dim, y_slice_dim,
indices, x, y);
......@@ -1632,7 +1671,7 @@ template <> void At<int, CUDAContext>(const int count,
}
template <typename T>
__global__ void _AtGrad(const int count,
__global__ void _GatherGrad(const int count,
const int outer_dim,
const int inner_dim,
const int x_slice_dim,
......@@ -1651,7 +1690,7 @@ __global__ void _AtGrad(const int count,
}
}
template <> void AtGrad<float, CUDAContext>(const int count,
template <> void GatherGrad<float, CUDAContext>(const int count,
const int outer_dim,
const int inner_dim,
const int x_slice_dim,
......@@ -1659,14 +1698,14 @@ template <> void AtGrad<float, CUDAContext>(const int count,
const int* indices,
const float* dy,
float* dx) {
_AtGrad<float> << <GET_BLOCKS(count), CUDA_NUM_THREADS >> >(count,
_GatherGrad<float> << <GET_BLOCKS(count), CUDA_NUM_THREADS >> >(count,
outer_dim, inner_dim,
x_slice_dim, y_slice_dim,
indices, dy, dx);
CUDA_POST_KERNEL_CHECK;
}
template <> void AtGrad<int, CUDAContext>(const int count,
template <> void GatherGrad<int, CUDAContext>(const int count,
const int outer_dim,
const int inner_dim,
const int x_slice_dim,
......@@ -1674,7 +1713,7 @@ template <> void AtGrad<int, CUDAContext>(const int count,
const int* indices,
const int* dy,
int* dx) {
_AtGrad<int> << <GET_BLOCKS(count), CUDA_NUM_THREADS >> >(count,
_GatherGrad<int> << <GET_BLOCKS(count), CUDA_NUM_THREADS >> >(count,
outer_dim, inner_dim,
x_slice_dim, y_slice_dim,
indices, dy, dx);
......@@ -3779,7 +3818,7 @@ __global__ void _ROIPooling(const int count,
const int height, const int width,
const int pool_h, const int pool_w,
const T* x,
const T* roi,
const T* rois,
int* mask,
T* y) {
CUDA_KERNEL_LOOP(idx, count) {
......@@ -3788,43 +3827,41 @@ __global__ void _ROIPooling(const int count,
int c = (idx / pool_w / pool_h) % channels;
int n = idx / pool_w / pool_h / channels;
roi += n * 5;
int im_idx = roi[0];
const T* offset_rois = rois + n * 5;
int roi_batch_ind = offset_rois[0];
if (im_idx < 0) {
if (roi_batch_ind < 0) {
y[idx] = 0;
mask[idx] = 0;
continue;
}
int x1 = round(roi[1] * spatial_scale);
int y1 = round(roi[2] * spatial_scale);
int x2 = round(roi[3] * spatial_scale);
int y2 = round(roi[4] * spatial_scale);
int roi_height = max(y2 - y1 + 1, 1);
int roi_width = max(x2 - x1 + 1, 1);
int roi_start_w = round(offset_rois[1] * spatial_scale);
int roi_start_h = round(offset_rois[2] * spatial_scale);
int roi_end_w = round(offset_rois[3] * spatial_scale);
int roi_end_h = round(offset_rois[4] * spatial_scale);
const float bin_size_h = (float)roi_height / (float)pool_h;
const float bin_size_w = (float)roi_width / (float)pool_w;
int roi_width = max(roi_end_w - roi_start_w + 1, 1);
int roi_height = max(roi_end_h - roi_start_h + 1, 1);
const T bin_size_h = (T)roi_height / (T)pool_h;
const T bin_size_w = (T)roi_width / (T)pool_w;
int start_h = floor(bin_size_h * ph);
int start_w = floor(bin_size_w * pw);
int end_h = ceil(bin_size_h * (ph + 1));
int end_w = ceil(bin_size_w * (pw + 1));
int hstart = floor(bin_size_h * ph);
int wstart = floor(bin_size_w * pw);
int hend = ceil(bin_size_h * (ph + 1));
int wend = ceil(bin_size_w * (pw + 1));
start_h = min(max(start_h + y1, 0), height);
start_w = min(max(start_w + x1, 0), width);
end_h = min(max(end_h + y1, 0), height);
end_w = min(max(end_w + x1, 0), width);
hstart = min(max(hstart + roi_start_h, 0), height);
hend = min(max(hend + roi_start_h, 0), height);
wstart = min(max(wstart + roi_start_w, 0), width);
wend = min(max(wend + roi_start_w, 0), width);
bool is_empty = (end_h <= start_h) || (end_w <= start_w);
bool is_empty = (hend <= hstart) || (wend <= wstart);
float max_val = is_empty ? 0 : -FLT_MAX;
int max_idx = -1;
x += ((im_idx * channels + c) * height * width);
for (int h = start_h; h < end_h; ++h) {
for (int w = start_w; w < end_w; ++w) {
x += ((roi_batch_ind * channels + c) * height * width);
for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) {
const int x_idx = h * width + w;
if (x[x_idx] > max_val) {
max_val = x[x_idx];
......@@ -3832,7 +3869,6 @@ __global__ void _ROIPooling(const int count,
}
}
}
y[idx] = max_val;
mask[idx] = max_idx;
}
......@@ -3841,11 +3877,11 @@ __global__ void _ROIPooling(const int count,
template<> void ROIPooling<float, CUDAContext>(const float spatial_scale,
const int pool_h, const int pool_w,
Tensor* x,
Tensor* roi,
Tensor* rois,
Tensor* mask,
Tensor* y) {
auto* Xdata = x->data<float, CUDAContext>();
auto* Rdata = roi->data<float, CUDAContext>();
auto* Rdata = rois->data<float, CUDAContext>();
auto* Ydata = y->mutable_data<float, CUDAContext>();
auto* Mdata = mask->mutable_data<int, CUDAContext>();
TIndex channels = x->dim(1), count = y->count();
......@@ -3870,79 +3906,82 @@ __global__ void _ROIPoolingGrad(const int count,
const int height, const int width,
const int pool_h, const int pool_w,
const T* dy,
const T* roi,
const T* rois,
const int* mask,
T* dx) {
CUDA_KERNEL_LOOP(idx, count) {
int w = idx % width;
int h = (idx / width) % height;
int c = (idx / width / height) % channels;
int im_idx = idx / width / height / channels;
int n = idx / width / height / channels;
T diff = 0;
T gradient = 0;
for (int n = 0; n < num_rois; ++n) {
const T* cur_roi = roi + n * 5;
const int im_idx_spec = cur_roi[0];
for (int roi_n = 0; roi_n < num_rois; ++roi_n) {
const T* offset_rois = rois + roi_n * 5;
int roi_batch_ind = offset_rois[0];
if (im_idx != im_idx_spec) continue;
if (n != roi_batch_ind) continue;
int x1 = round(cur_roi[1] * spatial_scale);
int y1 = round(cur_roi[2] * spatial_scale);
int x2 = round(cur_roi[3] * spatial_scale);
int y2 = round(cur_roi[4] * spatial_scale);
int roi_start_w = round(offset_rois[1] * spatial_scale);
int roi_start_h = round(offset_rois[2] * spatial_scale);
int roi_end_w = round(offset_rois[3] * spatial_scale);
int roi_end_h = round(offset_rois[4] * spatial_scale);
const bool is_in = (w >= x1 && w <= x2 && h >= y1 && h <= y2);
const bool in_roi = (w >= roi_start_w &&
w <= roi_end_w &&
h >= roi_start_h &&
h <= roi_end_h);
if (!is_in) continue;
if (!in_roi) continue;
int roi_height = max(y2 - y1 + 1, 1);
int roi_width = max(x2 - x1 + 1, 1);
int y_offset = (n * channels + c) * pool_h * pool_w;
const T* offset_dy = dy + y_offset;
const int* offset_mask = mask + y_offset;
const float bin_size_h = (float)roi_height / (float)pool_h;
const float bin_size_w = (float)roi_width / (float)pool_w;
int roi_width = max(roi_end_w - roi_start_w + 1, 1);
int roi_height = max(roi_end_h - roi_start_h + 1, 1);
int start_ph = floor((h - y1) / bin_size_h);
int start_pw = floor((w - x1) / bin_size_w);
int end_ph = ceil((h + 1 - y1) / bin_size_h);
int end_pw = ceil((w + 1 - x1) / bin_size_w);
const T bin_size_h = (T)roi_height / (T)pool_h;
const T bin_size_w = (T)roi_width / (T)pool_w;
start_ph = min(max(start_ph, 0), pool_h);
start_pw = min(max(start_pw, 0), pool_w);
end_ph = min(max(end_ph, 0), pool_h);
end_pw = min(max(end_pw, 0), pool_w);
int phstart = floor(static_cast<T>(h - roi_start_h) / bin_size_h);
int phend = ceil(static_cast<T>(h - roi_start_h + 1) / bin_size_h);
int pwstart = floor(static_cast<T>(w - roi_start_w) / bin_size_w);
int pwend = ceil(static_cast<T>(w - roi_start_w + 1) / bin_size_w);
int y_offset = (n * channels + c) * pool_h * pool_w;
const T* dy_off = dy + y_offset;
const int* mask_off = mask + y_offset;
phstart = min(max(phstart, 0), pool_h);
phend = min(max(phend, 0), pool_h);
pwstart = min(max(pwstart, 0), pool_w);
pwend = min(max(pwend, 0), pool_w);
for (int ph = start_ph; ph < end_ph; ++ph) {
for (int pw = start_pw; pw < end_pw; ++pw) {
for (int ph = phstart; ph < phend; ++ph) {
for (int pw = pwstart; pw < pwend; ++pw) {
int pool_idx = ph * pool_w + pw;
if (mask_off[pool_idx] == (h * width + w)) {
diff += dy_off[pool_idx];
if (offset_mask[pool_idx] == (h * width + w)) {
gradient += offset_dy[pool_idx];
}
}
}
}
dx[idx] = diff;
dx[idx] = gradient;
}
}
template<> void ROIPoolingGrad<float, CUDAContext>(const float spatial_scale,
const int pool_h, const int pool_w,
Tensor* dy,
Tensor* roi,
Tensor* rois,
Tensor* mask,
Tensor* dx) {
auto* dYdata = dy->data<float, CUDAContext>();
auto* Rdata = roi->data<float, CUDAContext>();
auto* Rdata = rois->data<float, CUDAContext>();
auto* Mdata = mask->data<int, CUDAContext>();
auto* dXdata = dx->mutable_data<float, CUDAContext>();
TIndex channels = dx->dim(1), count = dx->count();
TIndex height = dx->dim(2), width = dx->dim(3);
_ROIPoolingGrad<float> << <GET_BLOCKS(count), CUDA_NUM_THREADS >> >(count,
roi->dim(0),
rois->dim(0),
spatial_scale,
channels,
height, width,
......@@ -3957,100 +3996,112 @@ template<> void ROIPoolingGrad<float, CUDAContext>(const float spatial_scale,
/******************** vision.roi_align ********************/
template <typename T>
__device__ T _ROIAlignInterpolate(const T* Xdata,
const int height,
const int width,
T y,
T x) {
if (y < -1.0 || y > height || x < -1.0 || x > width) return 0;
if (y <= 0) y = 0;
if (x <= 0) x = 0;
int y_low = (int)y;
int x_low = (int)x;
int y_high;
int x_high;
if (y_low >= height - 1) {
y_high = y_low = height - 1;
y = (T)y_low;
} else {
y_high = y_low + 1;
}
if (x_low >= width - 1) {
x_high = x_low = width - 1;
x = (T)x_low;
} else {
x_high = x_low + 1;
}
T ly = y - y_low;
T lx = x - x_low;
T hy = 1. - ly, hx = 1. - lx;
T v1 = Xdata[y_low * width + x_low];
T v2 = Xdata[y_low * width + x_high];
T v3 = Xdata[y_high * width + x_low];
T v4 = Xdata[y_high * width + x_high];
T w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;
T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
return val;
}
template <typename T>
__global__ void _ROIAlign(const int count,
const float spatial_scale,
const int channels,
const int height, const int width,
const int pool_h, const int pool_w,
const T* x,
const T* roi,
T* mask_h,
T* mask_w,
T* y) {
const int sampling_ratio,
const T* Xdata,
const T* rois,
T* Ydata) {
CUDA_KERNEL_LOOP(idx, count) {
int pw = idx % pool_w;
int ph = (idx / pool_w) % pool_h;
int c = (idx / pool_w / pool_h) % channels;
int n = idx / pool_w / pool_h / channels;
roi += n * 5;
int roi_batch_ind = roi[0];
const T* offset_rois = rois + n * 5;
int roi_batch_ind = offset_rois[0];
if (roi_batch_ind < 0) {
y[idx] = 0;
mask_h[idx] = 0;
mask_w[idx] = 0;
Ydata[idx] = 0;
continue;
}
T roi_start_w = (roi[1]) * spatial_scale;
T roi_start_h = (roi[2]) * spatial_scale;
T roi_end_w = (roi[3]) * spatial_scale;
T roi_end_h = (roi[4]) * spatial_scale;
T roi_start_w = offset_rois[1] * spatial_scale;
T roi_start_h = offset_rois[2] * spatial_scale;
T roi_end_w = offset_rois[3] * spatial_scale;
T roi_end_h = offset_rois[4] * spatial_scale;
T roi_width = max(roi_end_w - roi_start_w, static_cast<T>(1));
T roi_height = max(roi_end_h - roi_start_h, static_cast<T>(1));
T roi_width = max(roi_end_w - roi_start_w, (T)1.);
T roi_height = max(roi_end_h - roi_start_h, (T)1.);
T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pool_h);
T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pool_w);
T hstart = static_cast<T>((ph)* bin_size_h);
T wstart = static_cast<T>((pw)* bin_size_w);
T hend = static_cast<T>((ph + 1) * bin_size_h);
T wend = static_cast<T>((pw + 1) * bin_size_w);
hstart = min(max(hstart + roi_start_h, static_cast<T>(0)), static_cast<T>(height));
hend = min(max(hend + roi_start_h, static_cast<T>(0)), static_cast<T>(height));
wstart = min(max(wstart + roi_start_w, static_cast<T>(0)), static_cast<T>(width));
wend = min(max(wend + roi_start_w, static_cast<T>(0)), static_cast<T>(width));
bool is_empty = (hend <= hstart) || (wend <= wstart);
const T* offset_Xdata = Xdata + (roi_batch_ind * channels + c) * height * width;
T maxval = is_empty ? 0 : -FLT_MAX;
T max_h_idx = -1;
T max_w_idx = -1;
x += (roi_batch_ind * channels + c) * height * width;
T h_stride = (hend - hstart) / 3.0;
T w_stride = (wend - wstart) / 3.0;
for (T h = hstart + h_stride; h <= hend - h_stride + 0.01; h += max(h_stride, 0.01)) {
for (T w = wstart + w_stride; w <= wend - w_stride + 0.01; w += max(w_stride, 0.01)) {
int hlow = min(max(static_cast<int>(floor(h)), 0), height - 1);
int hhigh = min(max(static_cast<int>(ceil(h)), 0), height - 1);
int wleft = min(max(static_cast<int>(floor(w)), 0), width - 1);
int wright = min(max(static_cast<int>(ceil(w)), 0), width - 1);
int topleft = hlow * width + wleft;
int topright = hlow * width + wright;
int bottomleft = hhigh * width + wleft;
int bottomright = hhigh * width + wright;
int roi_bin_grid_h = (sampling_ratio > 0) ? sampling_ratio : ceil(roi_height / pool_h);
int roi_bin_grid_w = (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pool_w);
T alpha = (hlow == hhigh) ? static_cast<T>(0.5) : (h - hlow) / (hhigh - hlow);
T beta = (wleft == wright) ? static_cast<T>(0.5) : (w - wleft) / (wright - wleft);
T value = (1 - alpha) * (1 - beta) * x[topleft] + alpha * (1 - beta) * x[bottomleft]
+ (1 - alpha) * beta * x[topright] + alpha * beta * x[bottomright];
const T num_bin_grids = roi_bin_grid_h * roi_bin_grid_w;
if (value > maxval) {
maxval = value;
max_h_idx = h;
max_w_idx = w;
T output_val = 0.;
for (int iy = 0; iy < roi_bin_grid_h; iy++) {
const T y = roi_start_h + ph * bin_size_h +
static_cast<T>(iy + .5f) * bin_size_h / static_cast<T>(roi_bin_grid_h);
for (int ix = 0; ix < roi_bin_grid_w; ix++) {
const T x = roi_start_w + pw * bin_size_w +
static_cast<T>(ix + .5f) * bin_size_w / static_cast<T>(roi_bin_grid_w);
T val = _ROIAlignInterpolate(offset_Xdata, height, width, y, x);
output_val += val;
}
}
}
y[idx] = maxval;
mask_h[idx] = max_h_idx;
mask_w[idx] = max_w_idx;
output_val /= num_bin_grids;
Ydata[idx] = output_val;
}
}
template<> void ROIAlign<float, CUDAContext>(const float spatial_scale,
const int pool_h, const int pool_w,
const int sampling_ratio,
Tensor* x,
Tensor* roi,
Tensor* mask_h,
Tensor* mask_w,
Tensor* rois,
Tensor* y) {
auto* Xdata = x->data<float, CUDAContext>();
auto* Rdata = roi->data<float, CUDAContext>();
auto* Rdata = rois->data<float, CUDAContext>();
auto* Ydata = y->mutable_data<float, CUDAContext>();
auto* MHdata = mask_h->mutable_data<float, CUDAContext>();
auto* MWdata = mask_w->mutable_data<float, CUDAContext>();
TIndex channels = x->dim(1), count = y->count();
TIndex height = x->dim(2), width = x->dim(3);
_ROIAlign<float> << <GET_BLOCKS(count), CUDA_NUM_THREADS >> >(count,
......@@ -4058,104 +4109,148 @@ template<> void ROIAlign<float, CUDAContext>(const float spatial_scale,
channels,
height, width,
pool_h, pool_w,
sampling_ratio,
Xdata,
Rdata,
MHdata,
MWdata,
Ydata);
CUDA_POST_KERNEL_CHECK;
}
template <typename T>
__device__ void _ROIAlignInterpolateGrad(const int height,
const int width,
T y, T x,
T & w1, T & w2, T & w3, T & w4,
int & x_low, int & x_high,
int & y_low, int & y_high) {
if (y < -1.0 || y > height || x < -1.0 || x > width) {
w1 = w2 = w3 = w4 = 0.;
x_low = x_high = y_low = y_high = -1;
return;
}
if (y <= 0) y = 0;
if (x <= 0) x = 0;
y_low = (int)y;
x_low = (int)x;
if (y_low >= height - 1) {
y_high = y_low = height - 1;
y = (T)y_low;
} else {
y_high = y_low + 1;
}
if (x_low >= width - 1) {
x_high = x_low = width - 1;
x = (T)x_low;
} else {
x_high = x_low + 1;
}
T ly = y - y_low;
T lx = x - x_low;
T hy = 1. - ly, hx = 1. - lx;
w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;
return;
}
template <typename T>
__global__ void _ROIAlignGrad(const int count,
const int num_rois,
const T spatial_scale,
const int channels,
const int height, const int width,
const int pool_h, const int pool_w,
const T* dy,
const T* roi,
const T* mask_h,
const T* mask_w,
T* dx) {
const int sampling_ratio,
const T* dYdata,
const T* rois,
T* dXdata) {
CUDA_KERNEL_LOOP(idx, count) {
int w = idx % width;
int h = (idx / width) % height;
int c = (idx / width / height) % channels;
int n = idx / width / height / channels;
int pw = idx % pool_w;
int ph = (idx / pool_w) % pool_h;
int c = (idx / pool_w / pool_h) % channels;
int n = idx / pool_w / pool_h / channels;
T gradient = 0;
for (int roi_n = 0; roi_n < num_rois; ++roi_n) {
const T* offset_roi = roi + roi_n * 5;
int roi_batch_ind = offset_roi[0];
if (n != roi_batch_ind) continue;
const T* offset_rois = rois + n * 5;
int roi_batch_ind = offset_rois[0];
T roi_start_w = (offset_roi[1]) * spatial_scale;
T roi_start_h = (offset_roi[2]) * spatial_scale;
T roi_end_w = (offset_roi[3]) * spatial_scale;
T roi_end_h = (offset_roi[4]) * spatial_scale;
if (roi_batch_ind < 0) continue;
const bool in_roi = (w > roi_start_w - 1.0 &&
w < roi_end_w + 1.0 &&
h > roi_start_h - 1.0
&& h < roi_end_h + 1.0);
if (!in_roi) continue;
T roi_start_w = offset_rois[1] * spatial_scale;
T roi_start_h = offset_rois[2] * spatial_scale;
T roi_end_w = offset_rois[3] * spatial_scale;
T roi_end_h = offset_rois[4] * spatial_scale;
int offset = (roi_n * channels + c) * pool_h * pool_w;
const T* offset_dy = dy + offset;
const T* offset_mask_h = mask_h + offset;
const T* offset_mask_w = mask_w + offset;
T roi_width = max(roi_end_w - roi_start_w, static_cast<T>(1));
T roi_height = max(roi_end_h - roi_start_h, static_cast<T>(1));
for (int ph = 0; ph < pool_h; ++ph) {
for (int pw = 0; pw < pool_w; ++pw) {
const int pool_idx = ph * pool_w + pw;
T a_h = offset_mask_h[pool_idx];
T a_w = offset_mask_w[pool_idx];
int hlow = min(max(static_cast<int>(floor(a_h)), 0), height - 1);
int hhigh = min(max(static_cast<int>(ceil(a_h)), 0), height - 1);
int wleft = min(max(static_cast<int>(floor(a_w)), 0), width - 1);
int wright = min(max(static_cast<int>(ceil(a_w)), 0), width - 1);
if (h != hlow && h != hhigh && w != wleft && w != wright) continue;
T alpha = (hlow == hhigh) ? static_cast<T>(0.5) : (a_h - hlow) / (hhigh - hlow);
T beta = (wleft == wright) ? static_cast<T>(0.5) : (a_w - wleft) / (wright - wleft);
if (h == hlow && w == wleft) gradient += offset_dy[pool_idx] * (1 - alpha) * (1 - beta);
else if (h == hlow && w == wright) gradient += offset_dy[pool_idx] * (1 - alpha) * beta;
else if (h == hhigh && w == wleft) gradient += offset_dy[pool_idx] * alpha * (1 - beta);
else if (h == hhigh && w == wright) gradient += offset_dy[pool_idx] * alpha * beta;
T roi_width = max(roi_end_w - roi_start_w, (T)1.);
T roi_height = max(roi_end_h - roi_start_h, (T)1.);
T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pool_h);
T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pool_w);
T* offset_dXdata = dXdata + (roi_batch_ind * channels + c) * height * width;
int y_offset = (n * channels + c) * pool_h * pool_w;
const T* offset_dYdata = dYdata + y_offset;
const T dYdata_this_bin = offset_dYdata[ph * pool_w + pw];
int roi_bin_grid_h = (sampling_ratio > 0) ? sampling_ratio : ceil(roi_height / pool_h);
int roi_bin_grid_w = (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pool_w);
const T num_bin_grids = roi_bin_grid_h * roi_bin_grid_w;
for (int iy = 0; iy < roi_bin_grid_h; iy++) {
const T y = roi_start_h + ph * bin_size_h +
static_cast<T>(iy + .5f) * bin_size_h / static_cast<T>(roi_bin_grid_h);
for (int ix = 0; ix < roi_bin_grid_w; ix++) {
const T x = roi_start_w + pw * bin_size_w +
static_cast<T>(ix + .5f) * bin_size_w / static_cast<T>(roi_bin_grid_w);
T w1, w2, w3, w4;
int x_low, x_high, y_low, y_high;
_ROIAlignInterpolateGrad(height, width,
y, x,
w1, w2, w3, w4,
x_low, x_high, y_low, y_high);
T g1 = dYdata_this_bin * w1 / num_bin_grids;
T g2 = dYdata_this_bin * w2 / num_bin_grids;
T g3 = dYdata_this_bin * w3 / num_bin_grids;
T g4 = dYdata_this_bin * w4 / num_bin_grids;
if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) {
atomicAdd(offset_dXdata + y_low * width + x_low, static_cast<T>(g1));
atomicAdd(offset_dXdata + y_low * width + x_high, static_cast<T>(g2));
atomicAdd(offset_dXdata + y_high * width + x_low, static_cast<T>(g3));
atomicAdd(offset_dXdata + y_high * width + x_high, static_cast<T>(g4));
}
}
}
dx[idx] = gradient;
}
}
template<> void ROIAlignGrad<float, CUDAContext>(const float spatial_scale,
const int pool_h, const int pool_w,
const int sampling_ratio,
Tensor* dy,
Tensor* roi,
Tensor* mask_h,
Tensor* mask_w,
Tensor* rois,
Tensor* dx) {
auto* dYdata = dy->data<float, CUDAContext>();
auto* Rdata = roi->data<float, CUDAContext>();
auto* MHdata = mask_h->data<float, CUDAContext>();
auto* MWdata = mask_w->data<float, CUDAContext>();
auto* Rdata = rois->data<float, CUDAContext>();
auto* dXdata = dx->mutable_data<float, CUDAContext>();
TIndex channels = dx->dim(1), count = dx->count();
TIndex channels = dx->dim(1), count = dy->count();
TIndex height = dx->dim(2), width = dx->dim(3);
math::Set<float, CUDAContext>(dx->count(), 0, dXdata);
_ROIAlignGrad<float> << <GET_BLOCKS(count), CUDA_NUM_THREADS >> >(count,
roi->dim(0),
rois->dim(0),
spatial_scale,
channels,
height, width,
pool_h, pool_w,
sampling_ratio,
dYdata,
Rdata,
MHdata,
MWdata,
dXdata);
CUDA_POST_KERNEL_CHECK;
}
......
Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!