Commit 36f27485 by Ting PAN

Refer the RoIAlign@Caffe2

1 parent 51056e19
Showing with 350 additions and 173 deletions
...@@ -16,14 +16,13 @@ class SigmoidCrossEntropyOp final : public Operator<Context> { ...@@ -16,14 +16,13 @@ class SigmoidCrossEntropyOp final : public Operator<Context> {
public: public:
SigmoidCrossEntropyOp(const OperatorDef& op_def, Workspace* ws) SigmoidCrossEntropyOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(op_def, ws),
normalization(OperatorBase::GetSingleArg<string>("normalization", "FULL")) {} normalization(OperatorBase::GetSingleArg<string>("normalization", "VALID")) {}
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
protected: protected:
Tensor losses; Tensor valid, losses;
Tensor* prob;
string normalization; string normalization;
}; };
...@@ -32,13 +31,13 @@ class SigmoidCrossEntropyGradientOp final : public Operator<Context> { ...@@ -32,13 +31,13 @@ class SigmoidCrossEntropyGradientOp final : public Operator<Context> {
public: public:
SigmoidCrossEntropyGradientOp(const OperatorDef& op_def, Workspace* ws) SigmoidCrossEntropyGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(op_def, ws),
normalization(OperatorBase::GetSingleArg<string>("normalization", "FULL")) {} normalization(OperatorBase::GetSingleArg<string>("normalization", "VALID")) {}
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
protected: protected:
Tensor* prob; Tensor valid;
string normalization; string normalization;
}; };
......
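For context, the forward loss that the reworked op computes element-wise can be sketched in NumPy as follows; entries with a negative target are ignored and excluded from the valid mask that the new ``VALID`` normalization divides by (array names and shapes here are illustrative only, not part of the diff):

```python
import numpy as np

def sigmoid_cross_entropy(x, target):
    """Element-wise sigmoid cross-entropy with an ignore mask (sketch).

    Entries with target < 0 contribute zero loss and are marked invalid,
    mirroring the CPU kernel added later in this commit.
    """
    x = np.asarray(x, dtype=np.float64)
    target = np.asarray(target, dtype=np.float64)
    valid = (target >= 0).astype(np.float64)
    # numerically stable: log(1 + exp(-|x|)) + x * (1[x >= 0] - target)
    loss = np.log1p(np.exp(x - 2.0 * x * (x >= 0))) + x * ((x >= 0) - target)
    return loss * valid, valid
```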
...@@ -4,17 +4,17 @@ ...@@ -4,17 +4,17 @@
// Written by Ting Pan // Written by Ting Pan
// -------------------------------------------------------- // --------------------------------------------------------
#ifndef DRAGON_OPERATORS_NDARRAY_AT_OP_H_ #ifndef DRAGON_OPERATORS_NDARRAY_GATHER_OP_H_
#define DRAGON_OPERATORS_NDARRAY_AT_OP_H_ #define DRAGON_OPERATORS_NDARRAY_GATHER_OP_H_
#include "core/operator.h" #include "core/operator.h"
namespace dragon { namespace dragon {
template <class Context> template <class Context>
class AtOp final : public Operator<Context> { class GatherOp final : public Operator<Context> {
public: public:
AtOp(const OperatorDef& op_def, Workspace* ws) GatherOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 0)) {} axis(OperatorBase::GetSingleArg<int>("axis", 0)) {}
...@@ -27,9 +27,9 @@ class AtOp final : public Operator<Context> { ...@@ -27,9 +27,9 @@ class AtOp final : public Operator<Context> {
}; };
template <class Context> template <class Context>
class AtGradientOp final : public Operator<Context> { class GatherGradientOp final : public Operator<Context> {
public: public:
AtGradientOp(const OperatorDef& op_def, Workspace* ws) GatherGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(op_def, ws),
axis(OperatorBase::GetSingleArg<int>("axis", 0)), axis(OperatorBase::GetSingleArg<int>("axis", 0)),
acc_grad(OperatorBase::GetSingleArg<bool>("acc_gradient", false)) {} acc_grad(OperatorBase::GetSingleArg<bool>("acc_gradient", false)) {}
...@@ -44,4 +44,4 @@ class AtGradientOp final : public Operator<Context> { ...@@ -44,4 +44,4 @@ class AtGradientOp final : public Operator<Context> {
} // namespace dragon } // namespace dragon
#endif // DRAGON_OPERATORS_NDARRAY_AT_OP_H_ #endif // DRAGON_OPERATORS_NDARRAY_GATHER_OP_H_
\ No newline at end of file \ No newline at end of file
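The renamed Gather op keeps the same forward semantics as the old At op: pick whole slices along one axis by index. A minimal NumPy analogue (``np.take``) for reference:

```python
import numpy as np

x = np.arange(24).reshape(2, 4, 3)
indices = [3, 0, 0]

# Gather along axis=1: the gathered axis takes the length of `indices`.
y = np.take(x, indices, axis=1)
assert y.shape == (2, 3, 3)
```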
...@@ -18,7 +18,8 @@ class ROIAlignOp : public Operator<Context> { ...@@ -18,7 +18,8 @@ class ROIAlignOp : public Operator<Context> {
: Operator<Context>(op_def, ws), : Operator<Context>(op_def, ws),
pool_h(OperatorBase::GetSingleArg<int>("pool_h", 0)), pool_h(OperatorBase::GetSingleArg<int>("pool_h", 0)),
pool_w(OperatorBase::GetSingleArg<int>("pool_w", 0)), pool_w(OperatorBase::GetSingleArg<int>("pool_w", 0)),
spatial_scale(OperatorBase::GetSingleArg<float>("spatial_scale", 1.0)) { spatial_scale(OperatorBase::GetSingleArg<float>("spatial_scale", 1.0)),
sampling_ratio(OperatorBase::GetSingleArg<int>("sampling_ratio", 2)) {
CHECK_GT(pool_h, 0) << "\npool_h must > 0"; CHECK_GT(pool_h, 0) << "\npool_h must > 0";
CHECK_GT(pool_w, 0) << "\npool_w must > 0"; CHECK_GT(pool_w, 0) << "\npool_w must > 0";
} }
...@@ -27,9 +28,8 @@ class ROIAlignOp : public Operator<Context> { ...@@ -27,9 +28,8 @@ class ROIAlignOp : public Operator<Context> {
template <typename T> void RunWithType(); template <typename T> void RunWithType();
protected: protected:
int pool_h, pool_w; int pool_h, pool_w, sampling_ratio;
float spatial_scale; float spatial_scale;
Tensor* mask_h, *mask_w;
}; };
template <class Context> template <class Context>
...@@ -39,7 +39,8 @@ class ROIAlignGradientOp : public Operator<Context> { ...@@ -39,7 +39,8 @@ class ROIAlignGradientOp : public Operator<Context> {
: Operator<Context>(op_def, ws), : Operator<Context>(op_def, ws),
pool_h(OperatorBase::GetSingleArg<int>("pool_h", 0)), pool_h(OperatorBase::GetSingleArg<int>("pool_h", 0)),
pool_w(OperatorBase::GetSingleArg<int>("pool_w", 0)), pool_w(OperatorBase::GetSingleArg<int>("pool_w", 0)),
spatial_scale(OperatorBase::GetSingleArg<float>("spatial_scale", 1.0)) { spatial_scale(OperatorBase::GetSingleArg<float>("spatial_scale", 1.0)),
sampling_ratio(OperatorBase::GetSingleArg<int>("sampling_ratio", 2)) {
CHECK_GT(pool_h, 0) << "\npool_h must > 0"; CHECK_GT(pool_h, 0) << "\npool_h must > 0";
CHECK_GT(pool_w, 0) << "\npool_w must > 0"; CHECK_GT(pool_w, 0) << "\npool_w must > 0";
} }
...@@ -48,9 +49,8 @@ class ROIAlignGradientOp : public Operator<Context> { ...@@ -48,9 +49,8 @@ class ROIAlignGradientOp : public Operator<Context> {
template <typename T> void RunWithType(); template <typename T> void RunWithType();
protected: protected:
int pool_h, pool_w; int pool_h, pool_w, sampling_ratio;
float spatial_scale; float spatial_scale;
Tensor* mask_h, *mask_w;
}; };
} // namespace dragon } // namespace dragon
......
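The mask tensors (mask_h, mask_w) disappear because the Caffe2-style RoIAlign no longer records argmax positions; each output value is an average of bilinearly interpolated samples, so the backward pass only needs the sampling geometry. A simplified, illustrative sketch of the bilinear sampling primitive (not the kernel itself):

```python
import numpy as np

def bilinear_interpolate(feat, y, x):
    """Sample a 2-D feature map at a continuous (y, x) location (sketch)."""
    H, W = feat.shape
    if y < -1.0 or y > H or x < -1.0 or x > W:   # completely outside -> 0
        return 0.0
    y = min(max(y, 0.0), H - 1.0)
    x = min(max(x, 0.0), W - 1.0)
    y0, x0 = int(np.floor(y)), int(np.floor(x))
    y1, x1 = min(y0 + 1, H - 1), min(x0 + 1, W - 1)
    ly, lx = y - y0, x - x0
    return ((1 - ly) * (1 - lx) * feat[y0, x0] + (1 - ly) * lx * feat[y0, x1] +
            ly * (1 - lx) * feat[y1, x0] + ly * lx * feat[y1, x1])
```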
...@@ -198,7 +198,18 @@ void AbsGrad(const int count, const T* dy, T* dx); ...@@ -198,7 +198,18 @@ void AbsGrad(const int count, const T* dy, T* dx);
/******************** loss.sigmoid_cross_entropy ********************/ /******************** loss.sigmoid_cross_entropy ********************/
template <typename T, class Context> template <typename T, class Context>
void SigmoidCrossEntropy(const int count, const T* x, const T* target, T* loss); void SigmoidCrossEntropy(const int count,
const T* x,
const T* target,
T* loss,
T* valid);
template <typename T, class Context>
void SigmoidCrossEntropyGrad(const int count,
const T* x,
const T* target,
T* dx,
T* valid);
/******************** loss.smooth_l1_loss ********************/ /******************** loss.smooth_l1_loss ********************/
...@@ -312,13 +323,13 @@ void Argmin(const int count, ...@@ -312,13 +323,13 @@ void Argmin(const int count,
const T* x, const T* x,
T* y); T* y);
/******************** ndarray.at ********************/ /******************** ndarray.gather ********************/
template <typename T, class Context> template <typename T, class Context>
void CanonicalAxis(const int count, const int dim, T* y); void CanonicalAxis(const int count, const int dim, T* y);
template <typename T, class Context> template <typename T, class Context>
void At(const int count, void Gather(const int count,
const int outer_dim, const int outer_dim,
const int inner_dim, const int inner_dim,
const int x_slice_dim, const int x_slice_dim,
...@@ -329,7 +340,7 @@ void At(const int count, ...@@ -329,7 +340,7 @@ void At(const int count,
Context* ctx); Context* ctx);
template <typename T, class Context> template <typename T, class Context>
void AtGrad(const int count, void GatherGrad(const int count,
const int outer_dim, const int outer_dim,
const int inner_dim, const int inner_dim,
const int x_slice_dim, const int x_slice_dim,
...@@ -791,7 +802,7 @@ void ROIPooling(const float spatial_scale, ...@@ -791,7 +802,7 @@ void ROIPooling(const float spatial_scale,
const int pool_h, const int pool_h,
const int pool_w, const int pool_w,
Tensor* x, Tensor* x,
Tensor* roi, Tensor* rois,
Tensor* mask, Tensor* mask,
Tensor* y); Tensor* y);
...@@ -800,7 +811,7 @@ void ROIPoolingGrad(const float spatial_scale, ...@@ -800,7 +811,7 @@ void ROIPoolingGrad(const float spatial_scale,
const int pool_h, const int pool_h,
const int pool_w, const int pool_w,
Tensor* dy, Tensor* dy,
Tensor* roi, Tensor* rois,
Tensor* mask, Tensor* mask,
Tensor* dx); Tensor* dx);
...@@ -810,20 +821,18 @@ template <typename T, class Context> ...@@ -810,20 +821,18 @@ template <typename T, class Context>
void ROIAlign(const float spatial_scale, void ROIAlign(const float spatial_scale,
const int pool_h, const int pool_h,
const int pool_w, const int pool_w,
const int sampling_ratio,
Tensor* x, Tensor* x,
Tensor* roi, Tensor* rois,
Tensor* mask_h,
Tensor* mask_w,
Tensor* y); Tensor* y);
template <typename T, class Context> template <typename T, class Context>
void ROIAlignGrad(const float spatial_scale, void ROIAlignGrad(const float spatial_scale,
const int pool_h, const int pool_h,
const int pool_w, const int pool_w,
const int sampling_ratio,
Tensor* dy, Tensor* dy,
Tensor* roi, Tensor* rois,
Tensor* mask_h,
Tensor* mask_w,
Tensor* dx); Tensor* dx);
} // namespace kernel } // namespace kernel
......
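To make the Gather kernel signature above concrete, here is a rough NumPy reference of the outer_dim / inner_dim / x_slice_dim / y_slice_dim decomposition, cross-checked against ``np.take``; this is a sketch of the layout convention, not the actual kernel:

```python
import numpy as np

def gather_ref(x, indices, axis):
    x = np.asarray(x)
    outer_dim = int(np.prod(x.shape[:axis], dtype=np.int64))
    inner_dim = int(np.prod(x.shape[axis + 1:], dtype=np.int64))
    x_slice_dim, y_slice_dim = x.shape[axis], len(indices)

    x2 = x.reshape(outer_dim, x_slice_dim, inner_dim)
    y2 = np.empty((outer_dim, y_slice_dim, inner_dim), dtype=x.dtype)
    for n in range(outer_dim):               # one inner slice per (outer, index) pair
        for i, idx in enumerate(indices):
            y2[n, i, :] = x2[n, idx, :]
    return y2.reshape(x.shape[:axis] + (y_slice_dim,) + x.shape[axis + 1:])

x = np.arange(24).reshape(2, 4, 3)
assert np.array_equal(gather_ref(x, [3, 0], axis=1), np.take(x, [3, 0], axis=1))
```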
...@@ -4,16 +4,22 @@ ...@@ -4,16 +4,22 @@
# Written by Ting Pan # Written by Ting Pan
# -------------------------------------------------------- # --------------------------------------------------------
import logging # core
import sys from dragon.core.tensor import Tensor
import dragon.core.workspace as workspace
try: # ops
from dragon.libdragon import * from dragon.ops import *
except ImportError as e:
logging.critical(
'cannot load dragon. Error: {0}'.format(str(e)))
sys.exit(1)
# updaters
from dragon.updaters import *
# theano utilities
from dragon.vm.theano.compile.function import function as function
from dragon.vm.theano.tensor import grad as grad
# scope
from dragon.core.scope import TensorScope as name_scope from dragon.core.scope import TensorScope as name_scope
from dragon.core.scope import PhaseScope as phase_scope from dragon.core.scope import PhaseScope as phase_scope
from dragon.core.scope import DeviceScope as device_scope from dragon.core.scope import DeviceScope as device_scope
...@@ -4,13 +4,18 @@ ...@@ -4,13 +4,18 @@
# Written by Ting Pan # Written by Ting Pan
# -------------------------------------------------------- # --------------------------------------------------------
from dragon.__init__ import * from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import sys
import logging import logging
logger = logging.getLogger('dragon') logger = logging.getLogger('dragon')
logger.setLevel(logging.INFO) logger.setLevel(logging.INFO)
logger.addHandler(logging.StreamHandler(sys.stdout)) logger.addHandler(logging.StreamHandler(sys.stdout))
from dragon.import_c_apis import *
option = {} option = {}
REGISTERED_OPERATORS = set(s for s in RegisteredOperatorsCC()) REGISTERED_OPERATORS = set(s for s in RegisteredOperatorsCC())
......
...@@ -4,14 +4,20 @@ ...@@ -4,14 +4,20 @@
# Written by Ting Pan # Written by Ting Pan
# -------------------------------------------------------- # --------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from collections import defaultdict
from dragon.import_c_apis import *
import dragon.config as config import dragon.config as config
import dragon.protos.dragon_pb2 as pb import dragon.protos.dragon_pb2 as pb
from collections import defaultdict
from dragon.core.utils import MakeOperatorDef from dragon.core.utils import MakeOperatorDef
from dragon.__init__ import *
from .scope import GetOperatorName from .scope import GetOperatorName
class GraphGradientMaker(object): class GraphGradientMaker(object):
""" """
GraphGradientMaker is designed to generate gradient operators automatically. GraphGradientMaker is designed to generate gradient operators automatically.
......
...@@ -4,11 +4,14 @@ ...@@ -4,11 +4,14 @@
# Written by Ting Pan # Written by Ting Pan
# -------------------------------------------------------- # --------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np import numpy as np
from six.moves import range as xrange from six.moves import range as xrange
from dragon import MPIInitCC, MPIRankCC, MPISizeCC, \ from dragon.import_c_apis import *
MPICreateGroupCC, MPIFinalizeCC
_is_init = False _is_init = False
_snapshot_ranks = [] _snapshot_ranks = []
......
...@@ -4,7 +4,9 @@ ...@@ -4,7 +4,9 @@
# Written by Ting Pan # Written by Ting Pan
# -------------------------------------------------------- # --------------------------------------------------------
from collections import defaultdict from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
_TENSOR_SCOPE = '' _TENSOR_SCOPE = ''
_PHASE_SCOPE = '' _PHASE_SCOPE = ''
......
...@@ -4,13 +4,18 @@ ...@@ -4,13 +4,18 @@
# Written by Ting Pan # Written by Ting Pan
# -------------------------------------------------------- # --------------------------------------------------------
import dragon.core.workspace as ws from __future__ import absolute_import
import dragon.protos.dragon_pb2 as pb from __future__ import division
from __future__ import print_function
import numpy as np import numpy as np
from collections import OrderedDict from collections import OrderedDict
from six.moves import range as xrange
import dragon.core.workspace as ws
import dragon.protos.dragon_pb2 as pb
from dragon.core.utils import MakeOperatorDef from dragon.core.utils import MakeOperatorDef
from dragon.core.scope import GetOperatorName, GetTensorName from dragon.core.scope import GetOperatorName, GetTensorName
from six.moves import range as xrange
class Tensor(object): class Tensor(object):
...@@ -416,7 +421,7 @@ class Tensor(object): ...@@ -416,7 +421,7 @@ class Tensor(object):
if not isinstance(item, tuple): if not isinstance(item, tuple):
# 1D At # 1D At
if isinstance(item, int): if isinstance(item, int):
output = self.CreateOperator(inputs=[self, wrapper_indices([item])], nout=1, op_type='At') output = self.CreateOperator(inputs=[self, wrapper_indices([item])], nout=1, op_type='Gather')
if self.shape is not None: if self.shape is not None:
output.shape = self.shape[:] output.shape = self.shape[:]
output.shape[0] = 1 output.shape[0] = 1
......
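The __getitem__ change above only swaps the op type (At to Gather); indexing with a single integer still keeps the gathered axis with size 1. In NumPy terms (purely illustrative):

```python
import numpy as np

x = np.arange(12).reshape(3, 4)
# tensor[1] lowers to Gather with indices=[1], so the leading dimension is
# kept with size 1 (unlike plain NumPy x[1], which drops it).
row = np.take(x, [1], axis=0)
assert row.shape == (1, 4)
```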
...@@ -4,11 +4,16 @@ ...@@ -4,11 +4,16 @@
# Written by Ting Pan # Written by Ting Pan
# -------------------------------------------------------- # --------------------------------------------------------
import sys from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import sys
import numpy as np
from google.protobuf.message import Message from google.protobuf.message import Message
from dragon.protos import dragon_pb2 as pb from dragon.protos import dragon_pb2 as pb
import numpy as np
if sys.version_info >= (3,0): if sys.version_info >= (3,0):
def MakeArgument(key, value): def MakeArgument(key, value):
......
...@@ -4,19 +4,25 @@ ...@@ -4,19 +4,25 @@
# Written by Ting Pan # Written by Ting Pan
# -------------------------------------------------------- # --------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
try: try:
import cPickle import cPickle
except: except:
import pickle as cPickle import pickle as cPickle
import dragon.core.utils as utils
import dragon.core.mpi as mpi
import dragon.protos.dragon_pb2 as pb
import numpy as np
import os import os
from dragon import * import numpy as np
from google.protobuf.message import Message from google.protobuf.message import Message
from six.moves import range as xrange from six.moves import range as xrange
from dragon.import_c_apis import *
import dragon.core.utils as utils
import dragon.core.mpi as mpi
import dragon.protos.dragon_pb2 as pb
CURRENT_GRAPH_IDX = 0 CURRENT_GRAPH_IDX = 0
__all__ = [ __all__ = [
...@@ -44,6 +50,7 @@ _DATA_TYPES = { ...@@ -44,6 +50,7 @@ _DATA_TYPES = {
'float64': np.float64, 'float64': np.float64,
} }
def _stringify_proto(obj): def _stringify_proto(obj):
""" """
Stringify a protobuf structure. Stringify a protobuf structure.
......
...@@ -32,9 +32,9 @@ Vision ...@@ -32,9 +32,9 @@ Vision
=================== ====================================================================== =================== ======================================================================
List Brief List Brief
=================== ====================================================================== =================== ======================================================================
`Conv2D`_ 2D Convolution. `Conv2d`_ 2d Convolution.
`Deconv2D`_ 2D Deconvolution. `Conv2dTranspose`_ 2d Deconvolution.
`Pool2D`_ 2D Pooling, MAX or AVG. `Pool2d`_ 2d Pooling, MAX or AVG.
`ROIPooling`_ ROIPooling(MAX), introduced by `[Girshick, 2015] <https://arxiv.org/abs/1504.08083>`_. `ROIPooling`_ ROIPooling(MAX), introduced by `[Girshick, 2015] <https://arxiv.org/abs/1504.08083>`_.
`ROIAlign`_ ROIAlign(MAX), introduced by `[He et.al, 2017] <https://arxiv.org/abs/1703.06870>`_. `ROIAlign`_ ROIAlign(MAX), introduced by `[He et.al, 2017] <https://arxiv.org/abs/1703.06870>`_.
`LRN`_ Local Response Normalization, introduced by `[Krizhevsky et.al, 2012] <http://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networks>`_. `LRN`_ Local Response Normalization, introduced by `[Krizhevsky et.al, 2012] <http://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networks>`_.
...@@ -122,8 +122,8 @@ NDArray ...@@ -122,8 +122,8 @@ NDArray
=============== ====================================================================== =============== ======================================================================
List Brief List Brief
=============== ====================================================================== =============== ======================================================================
`At`_ 1D At interface of NDArray. `Gather`_ Gather the input according to the indices along the given axis.
`RandomPick`_ 1D RandomPick interface of NDArray. `RandomPick`_ Randomly pick the input along the given axis.
`Reduce`_ The general reduce operator. `Reduce`_ The general reduce operator.
`Sum`_ Compute the sum along the given axis. `Sum`_ Compute the sum along the given axis.
`Mean`_ Compute the mean along the given axis. `Mean`_ Compute the mean along the given axis.
...@@ -195,9 +195,9 @@ List Brief ...@@ -195,9 +195,9 @@ List Brief
.. _GlorotUniform: operators/initializer.html#dragon.operators.initializer.GlorotUniform .. _GlorotUniform: operators/initializer.html#dragon.operators.initializer.GlorotUniform
.. _GlorotNormal: operators/initializer.html#dragon.operators.initializer.GlorotNormal .. _GlorotNormal: operators/initializer.html#dragon.operators.initializer.GlorotNormal
.. _Conv2D: operators/vision.html#dragon.operators.vision.Conv2D .. _Conv2d: operators/vision.html#dragon.operators.vision.Conv2d
.. _Deconv2D: operators/vision.html#dragon.operators.vision.Deconv2D .. _Conv2dTranspose: operators/vision.html#dragon.operators.vision.Conv2dTranspose
.. _Pool2D: operators/vision.html#dragon.operators.vision.Pool2D .. _Pool2d: operators/vision.html#dragon.operators.vision.Pool2d
.. _ROIPooling: operators/vision.html#dragon.operators.vision.ROIPooling .. _ROIPooling: operators/vision.html#dragon.operators.vision.ROIPooling
.. _ROIAlign: operators/vision.html#dragon.operators.vision.ROIAlign .. _ROIAlign: operators/vision.html#dragon.operators.vision.ROIAlign
.. _LRN: operators/vision.html#dragon.operators.vision.LRN .. _LRN: operators/vision.html#dragon.operators.vision.LRN
...@@ -249,7 +249,7 @@ List Brief ...@@ -249,7 +249,7 @@ List Brief
.. _InstanceNorm: operators/norm.html#dragon.operators.norm.InstanceNorm .. _InstanceNorm: operators/norm.html#dragon.operators.norm.InstanceNorm
.. _L2Norm: operators/norm.html#dragon.operators.norm.L2Norm .. _L2Norm: operators/norm.html#dragon.operators.norm.L2Norm
.. _At: operators/ndarray.html#dragon.operators.ndarray.At .. _Gather: operators/ndarray.html#dragon.operators.ndarray.Gather
.. _RandomPick: operators/ndarray.html#dragon.operators.ndarray.RandomPick .. _RandomPick: operators/ndarray.html#dragon.operators.ndarray.RandomPick
.. _Crop: operators/ndarray.html#dragon.operators.ndarray.Crop .. _Crop: operators/ndarray.html#dragon.operators.ndarray.Crop
.. _Reduce: operators/ndarray.html#dragon.operators.ndarray.Reduce .. _Reduce: operators/ndarray.html#dragon.operators.ndarray.Reduce
......
...@@ -68,6 +68,7 @@ List Brief ...@@ -68,6 +68,7 @@ List Brief
`ReshapeLayer`_ The implementation of ``ReshapeLayer``. `ReshapeLayer`_ The implementation of ``ReshapeLayer``.
`PermuteLayer`_ The implementation of ``PermuteLayer``. `PermuteLayer`_ The implementation of ``PermuteLayer``.
`FlattenLayer`_ The implementation of ``FlattenLayer``. `FlattenLayer`_ The implementation of ``FlattenLayer``.
`GatherLayer`_ The extended implementation for ``GatherOp``.
`SoftmaxLayer`_ The implementation of ``SoftmaxLayer``. `SoftmaxLayer`_ The implementation of ``SoftmaxLayer``.
`ArgMaxLayer`_ The implementation of ``ArgMaxLayer``. `ArgMaxLayer`_ The implementation of ``ArgMaxLayer``.
`BatchNormLayer`_ The implementation of ``BatchNormLayer``. `BatchNormLayer`_ The implementation of ``BatchNormLayer``.
...@@ -174,6 +175,7 @@ API Reference ...@@ -174,6 +175,7 @@ API Reference
.. _ReshapeLayer: #dragon.vm.caffe.layers.common.ReshapeLayer .. _ReshapeLayer: #dragon.vm.caffe.layers.common.ReshapeLayer
.. _PermuteLayer: #dragon.vm.caffe.layers.common.PermuteLayer .. _PermuteLayer: #dragon.vm.caffe.layers.common.PermuteLayer
.. _FlattenLayer: #dragon.vm.caffe.layers.common.FlattenLayer .. _FlattenLayer: #dragon.vm.caffe.layers.common.FlattenLayer
.. _GatherLayer: #dragon.vm.caffe.layers.common.GatherLayer
.. _SoftmaxLayer: #dragon.vm.caffe.layers.common.SoftmaxLayer .. _SoftmaxLayer: #dragon.vm.caffe.layers.common.SoftmaxLayer
.. _ArgMaxLayer: #dragon.vm.caffe.layers.common.ArgMaxLayer .. _ArgMaxLayer: #dragon.vm.caffe.layers.common.ArgMaxLayer
.. _BatchNormLayer: #dragon.vm.caffe.layers.common.BatchNormLayer .. _BatchNormLayer: #dragon.vm.caffe.layers.common.BatchNormLayer
......
# --------------------------------------------------------
# Dragon
# Copyright(c) 2017 SeetaTech
# Written by Ting Pan
# --------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import sys
import logging
try:
from dragon.libdragon import *
except ImportError as e:
logging.critical(
'cannot load dragon. Error: {0}'.format(str(e)))
sys.exit(1)
\ No newline at end of file
...@@ -4,8 +4,13 @@ ...@@ -4,8 +4,13 @@
# Written by Ting Pan # Written by Ting Pan
# -------------------------------------------------------- # --------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from . import * from . import *
def Relu(inputs, **kwargs): def Relu(inputs, **kwargs):
"""Rectified Linear Unit function, introduces by `[Nair & Hinton, 2010] <http://www.csri.utoronto.ca/~hinton/absps/reluICML.pdf>`_. """Rectified Linear Unit function, introduces by `[Nair & Hinton, 2010] <http://www.csri.utoronto.ca/~hinton/absps/reluICML.pdf>`_.
......
...@@ -4,8 +4,13 @@ ...@@ -4,8 +4,13 @@
# Written by Ting Pan # Written by Ting Pan
# -------------------------------------------------------- # --------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from . import * from . import *
def Add(inputs, **kwargs): def Add(inputs, **kwargs):
"""Calculate A + B. """Calculate A + B.
......
...@@ -4,8 +4,13 @@ ...@@ -4,8 +4,13 @@
# Written by Ting Pan # Written by Ting Pan
# -------------------------------------------------------- # --------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from . import * from . import *
def FloatToHalf(inputs, **kwargs): def FloatToHalf(inputs, **kwargs):
"""Cast the type of tensor from ``float32`` to ``float16``. """Cast the type of tensor from ``float32`` to ``float16``.
......
...@@ -4,8 +4,13 @@ ...@@ -4,8 +4,13 @@
# Written by Ting Pan # Written by Ting Pan
# -------------------------------------------------------- # --------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from . import * from . import *
def Copy(inputs, **kwargs): def Copy(inputs, **kwargs):
"""Copy A to B. """Copy A to B.
......
...@@ -4,11 +4,15 @@ ...@@ -4,11 +4,15 @@
# Written by Ting Pan # Written by Ting Pan
# -------------------------------------------------------- # --------------------------------------------------------
import numpy as np from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from dragon.operators.misc import Run from dragon.operators.misc import Run
from . import * from . import *
def LMDBData(**kwargs): def LMDBData(**kwargs):
"""Prefetch Image data with `LMDB`_ database. """Prefetch Image data with `LMDB`_ database.
......
...@@ -4,6 +4,10 @@ ...@@ -4,6 +4,10 @@
# Written by Ting Pan # Written by Ting Pan
# -------------------------------------------------------- # --------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from . import * from . import *
......
...@@ -4,10 +4,15 @@ ...@@ -4,10 +4,15 @@
# Written by Ting Pan # Written by Ting Pan
# -------------------------------------------------------- # --------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np import numpy as np
from . import * from . import *
def SparseSoftmaxCrossEntropy(inputs, axis=1, normalization='VALID', ignore_labels=(), **kwargs): def SparseSoftmaxCrossEntropy(inputs, axis=1, normalization='VALID', ignore_labels=(), **kwargs):
"""SoftmaxCrossEntropy with sparse labels. """SoftmaxCrossEntropy with sparse labels.
...@@ -48,7 +53,7 @@ def SparseSoftmaxCrossEntropy(inputs, axis=1, normalization='VALID', ignore_labe ...@@ -48,7 +53,7 @@ def SparseSoftmaxCrossEntropy(inputs, axis=1, normalization='VALID', ignore_labe
return output return output
def SigmoidCrossEntropy(inputs, normalization='FULL', **kwargs): def SigmoidCrossEntropy(inputs, normalization='VALID', **kwargs):
"""SigmoidCrossEntropy with binary labels. """SigmoidCrossEntropy with binary labels.
Parameters Parameters
...@@ -56,7 +61,7 @@ def SigmoidCrossEntropy(inputs, normalization='FULL', **kwargs): ...@@ -56,7 +61,7 @@ def SigmoidCrossEntropy(inputs, normalization='FULL', **kwargs):
inputs : list of Tensor inputs : list of Tensor
The inputs, represent [input, labels]. The inputs, represent [input, labels].
normalization : str normalization : str
The normalization, ``UNIT``, ``FULL``, ``BATCH_SIZE`` or ``NONE``. The normalization, ``UNIT``, ``FULL``, ``VALID``, ``BATCH_SIZE`` or ``NONE``.
Returns Returns
------- -------
......
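The default normalization moves from ``FULL`` to ``VALID``, so the loss is averaged only over non-ignored targets rather than over every element. A small sketch of the supported reduction modes (mirroring the operator code later in this diff; names are illustrative):

```python
import numpy as np

def reduce_sigmoid_ce(losses, valid, normalization='VALID'):
    """Reduce element-wise losses to a scalar under the supported modes."""
    if normalization == 'UNIT':
        return losses                       # keep the per-element losses
    if normalization == 'VALID':
        normalizer = valid.sum()            # number of non-ignored targets
    elif normalization == 'BATCH_SIZE':
        normalizer = losses.shape[0]
    elif normalization == 'FULL':
        normalizer = losses.size
    else:                                   # 'NONE'
        normalizer = 1.0
    return losses.sum() / normalizer
```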
...@@ -4,8 +4,13 @@ ...@@ -4,8 +4,13 @@
# Written by Ting Pan # Written by Ting Pan
# -------------------------------------------------------- # --------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from . import * from . import *
def Run(inputs, module, op, param_str='', nout=1, **kwargs): def Run(inputs, module, op, param_str='', nout=1, **kwargs):
"""Run a custom operator. (Without GradientFlow) """Run a custom operator. (Without GradientFlow)
......
...@@ -4,11 +4,17 @@ ...@@ -4,11 +4,17 @@
# Written by Ting Pan # Written by Ting Pan
# -------------------------------------------------------- # --------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from six.moves import range as xrange from six.moves import range as xrange
import dragon.core.mpi as mpi import dragon.core.mpi as mpi
from . import * from . import *
def MPIBroadcast(inputs, root, mpi_ranks=None, **kwargs): def MPIBroadcast(inputs, root, mpi_ranks=None, **kwargs):
"""Broadcast a tensor to all nodes in the ``MPIGroup``. """Broadcast a tensor to all nodes in the ``MPIGroup``.
......
...@@ -4,21 +4,24 @@ ...@@ -4,21 +4,24 @@
# Written by Ting Pan # Written by Ting Pan
# -------------------------------------------------------- # --------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np import numpy as np
from six.moves import range as xrange from six.moves import range as xrange
from dragon.core.tensor import GetTensorName
import dragon.core.workspace as ws
from . import * from . import *
def At(inputs, indices, axis=0, acc_gradient=False, **kwargs):
"""1D At interface of NDArray. def Gather(inputs, indices, axis=0, acc_gradient=False, **kwargs):
"""Gather the input according to the indices along the given axis.
Parameters Parameters
---------- ----------
inputs : Tensor inputs : Tensor
The input tensor. The input tensor.
indices : list or Tensor indices : int, list or Tensor
The indices to form output tensor. The indices to form output tensor.
axis : int axis : int
The start axis. The start axis.
...@@ -31,29 +34,28 @@ def At(inputs, indices, axis=0, acc_gradient=False, **kwargs): ...@@ -31,29 +34,28 @@ def At(inputs, indices, axis=0, acc_gradient=False, **kwargs):
The output tensor. The output tensor.
""" """
CheckInputs(inputs, 1)
arguments = ParseArguments(locals()) arguments = ParseArguments(locals())
arguments['inputs'] = [arguments['inputs'],
Tensor.Convert(indices, dtype='int32')]
arguments['indices'] = None
if isinstance(inputs, list): CheckInputs(inputs, 2) output = Tensor.CreateOperator(op_type='Gather', nout=1, **arguments)
elif isinstance(inputs, Tensor):
if not isinstance(indices, list):
raise ValueError('The type of indices should be list.')
indices = np.array(indices, dtype=np.float32)
tensor = GetTensorName()
ws.FeedTensor(tensor, indices)
arguments['inputs'] = [arguments['inputs'], Tensor(tensor)]
output = Tensor.CreateOperator(op_type='At', nout=1, **arguments)
if isinstance(inputs, Tensor):
if inputs.shape is not None: if inputs.shape is not None:
output.shape = inputs.shape[:] output.shape = inputs.shape[:]
if not isinstance(indices, Tensor):
if not isinstance(indices, (list, tuple)):
indices = [indices]
output.shape[axis] = len(indices) output.shape[axis] = len(indices)
else:
output.shape[axis] = None
return output return output
def RandomPick(inputs, max_samples=1, axis=0, **kwargs): def RandomPick(inputs, max_samples=1, axis=0, **kwargs):
"""1D RandomPick interface of NDArray. """Randomly pick the input along the given axis.
Parameters Parameters
---------- ----------
...@@ -541,8 +543,6 @@ def Pad(inputs, paddings, mode='CONSTANT', value=0, **kwargs): ...@@ -541,8 +543,6 @@ def Pad(inputs, paddings, mode='CONSTANT', value=0, **kwargs):
output = Tensor.CreateOperator(nout=1, op_type='Pad', **arguments) output = Tensor.CreateOperator(nout=1, op_type='Pad', **arguments)
return output return output
......
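The rewritten Python wrapper always feeds the indices as a second input tensor (via ``Tensor.Convert``) and infers the output shape from the kind of indices it was given. The shape rule, restated as a small sketch (a hypothetical helper, not part of the diff):

```python
def infer_gather_shape(input_shape, indices, axis):
    out = list(input_shape)
    if isinstance(indices, int):
        out[axis] = 1                 # a scalar index is wrapped into [index]
    elif isinstance(indices, (list, tuple)):
        out[axis] = len(indices)      # concrete list: length known at graph time
    else:
        out[axis] = None              # symbolic Tensor of indices: unknown length
    return out

assert infer_gather_shape([8, 16], [0, 3, 5], axis=0) == [3, 16]
```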
...@@ -4,8 +4,13 @@ ...@@ -4,8 +4,13 @@
# Written by Ting Pan # Written by Ting Pan
# -------------------------------------------------------- # --------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from . import * from . import *
def BatchNorm(inputs, axis=-1, momentum=0.9, eps=1e-3, def BatchNorm(inputs, axis=-1, momentum=0.9, eps=1e-3,
use_stats=-1, mode='DEFAULT', **kwargs): use_stats=-1, mode='DEFAULT', **kwargs):
"""Batch Normalization, introduced by `[Ioffe & Szegedy, 2015] <https://arxiv.org/abs/1502.03167>`_. """Batch Normalization, introduced by `[Ioffe & Szegedy, 2015] <https://arxiv.org/abs/1502.03167>`_.
......
...@@ -4,8 +4,13 @@ ...@@ -4,8 +4,13 @@
# Written by Ting Pan # Written by Ting Pan
# -------------------------------------------------------- # --------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from . import * from . import *
def LSTMUnit(c_t_1, gate_input, cont_t=None, **kwargs): def LSTMUnit(c_t_1, gate_input, cont_t=None, **kwargs):
"""Simple LSTMCell module. """Simple LSTMCell module.
...@@ -31,13 +36,3 @@ def LSTMUnit(c_t_1, gate_input, cont_t=None, **kwargs): ...@@ -31,13 +36,3 @@ def LSTMUnit(c_t_1, gate_input, cont_t=None, **kwargs):
arguments['cont_t'] = cont_t.name arguments['cont_t'] = cont_t.name
return Tensor.CreateOperator(inputs=[c_t_1, gate_input], nout=2, return Tensor.CreateOperator(inputs=[c_t_1, gate_input], nout=2,
op_type='LSTMUnit', **arguments) op_type='LSTMUnit', **arguments)
\ No newline at end of file
...@@ -13,6 +13,7 @@ from six.moves import range as xrange ...@@ -13,6 +13,7 @@ from six.moves import range as xrange
from . import * from . import *
def Conv2d(inputs, num_output, kernel_size, def Conv2d(inputs, num_output, kernel_size,
stride=1, pad=0, dilation=1, group=1, stride=1, pad=0, dilation=1, group=1,
padding='VALID', data_format='NCHW', **kwargs): padding='VALID', data_format='NCHW', **kwargs):
...@@ -327,7 +328,7 @@ def ROIPooling(inputs, pool_h, pool_w, spatial_scale, **kwargs): ...@@ -327,7 +328,7 @@ def ROIPooling(inputs, pool_h, pool_w, spatial_scale, **kwargs):
return Tensor.CreateOperator(nout=1, op_type='ROIPooling', **arguments) return Tensor.CreateOperator(nout=1, op_type='ROIPooling', **arguments)
def ROIAlign(inputs, pool_h=0, pool_w=0, spatial_scale=1.0, **kwargs): def ROIAlign(inputs, pool_h=0, pool_w=0, spatial_scale=1.0, sampling_ratio=2, **kwargs):
"""Max ROIAlign, introduced by `[He et.al, 2017] <https://arxiv.org/abs/1703.06870>`_. """Max ROIAlign, introduced by `[He et.al, 2017] <https://arxiv.org/abs/1703.06870>`_.
The first dimension of input must be ``1``. The first dimension of input must be ``1``.
...@@ -342,6 +343,8 @@ def ROIAlign(inputs, pool_h=0, pool_w=0, spatial_scale=1.0, **kwargs): ...@@ -342,6 +343,8 @@ def ROIAlign(inputs, pool_h=0, pool_w=0, spatial_scale=1.0, **kwargs):
The width of pooled tensor. The width of pooled tensor.
spatial_scale : float spatial_scale : float
The ``inverse`` of total down-sampling multiples on input tensor. The ``inverse`` of total down-sampling multiples on input tensor.
sampling_ratio : int
The number of sampling grids for each RoI bin.
Returns Returns
------- -------
......
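The new sampling_ratio argument controls how many bilinear samples are averaged inside each output bin: sampling_ratio x sampling_ratio points when positive, with Caffe2 falling back to an adaptive grid otherwise. A self-contained sketch of the per-bin averaging, using SciPy's ``map_coordinates`` in place of a hand-written bilinear sampler (illustrative only, not the kernel):

```python
import numpy as np
from scipy.ndimage import map_coordinates

def roi_align_bin(feat, y0, x0, bin_h, bin_w, sampling_ratio=2):
    """Average bilinear samples over one RoIAlign output bin (sketch)."""
    grid_h = sampling_ratio if sampling_ratio > 0 else int(np.ceil(bin_h))
    grid_w = sampling_ratio if sampling_ratio > 0 else int(np.ceil(bin_w))
    ys = y0 + (np.arange(grid_h) + 0.5) * bin_h / grid_h
    xs = x0 + (np.arange(grid_w) + 0.5) * bin_w / grid_w
    yy, xx = np.meshgrid(ys, xs, indexing='ij')
    # order=1 -> bilinear interpolation at the sample points
    samples = map_coordinates(feat, [yy.ravel(), xx.ravel()], order=1)
    return samples.mean()

feat = np.arange(25, dtype=np.float64).reshape(5, 5)
print(roi_align_bin(feat, 0.5, 0.5, 2.0, 2.0))
```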
...@@ -4,6 +4,10 @@ ...@@ -4,6 +4,10 @@
# Written by Ting Pan # Written by Ting Pan
# -------------------------------------------------------- # --------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from .operators import initializer as init from .operators import initializer as init
from .operators import vision from .operators import vision
from .operators import loss from .operators import loss
...@@ -92,7 +96,7 @@ InstanceNorm = norm.InstanceNorm ...@@ -92,7 +96,7 @@ InstanceNorm = norm.InstanceNorm
L2Norm = norm.L2Norm L2Norm = norm.L2Norm
# ndarray # ndarray
At = ndarray.At Gather = ndarray.Gather
RandomPick = ndarray.RandomPick RandomPick = ndarray.RandomPick
Crop = ndarray.Crop Crop = ndarray.Crop
Reduce = ndarray.Reduce Reduce = ndarray.Reduce
......
...@@ -4,11 +4,17 @@ ...@@ -4,11 +4,17 @@
# Written by Ting Pan # Written by Ting Pan
# -------------------------------------------------------- # --------------------------------------------------------
import numpy as np from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import pprint import pprint
import numpy as np
import dragon.core.workspace as ws import dragon.core.workspace as ws
from dragon.core.tensor import Tensor from dragon.core.tensor import Tensor
class BaseUpdater(object): class BaseUpdater(object):
""" """
BaseUpdater is designed to preprocess the gradients. BaseUpdater is designed to preprocess the gradients.
......
...@@ -50,6 +50,7 @@ from .common import InnerProductLayer, \ ...@@ -50,6 +50,7 @@ from .common import InnerProductLayer, \
ArgMaxLayer, \ ArgMaxLayer, \
PermuteLayer, \ PermuteLayer, \
FlattenLayer, \ FlattenLayer, \
GatherLayer, \
ConcatLayer, \ ConcatLayer, \
NormalizeLayer, \ NormalizeLayer, \
InstanceNormLayer, \ InstanceNormLayer, \
......
...@@ -266,6 +266,25 @@ class FlattenLayer(Layer): ...@@ -266,6 +266,25 @@ class FlattenLayer(Layer):
return ops.Flatten(input, **self._param) return ops.Flatten(input, **self._param)
class GatherLayer(Layer):
"""The extended implementation of ``GatherOp``.
Parameters
----------
axis : int
The axis for gathering. Refer `GatherParameter.axis`_.
"""
def __init__(self, LayerParameter):
super(GatherLayer, self).__init__(LayerParameter)
param = LayerParameter.gather_param
self._param = {'axis': param.axis}
def Setup(self, bottom):
super(GatherLayer, self).Setup(bottom)
return ops.Gather(bottom[0], indices=bottom[1], **self._param)
class SoftmaxLayer(Layer): class SoftmaxLayer(Layer):
"""The implementation of ``SoftmaxLayer``. """The implementation of ``SoftmaxLayer``.
......
...@@ -57,12 +57,12 @@ class SigmoidCrossEntropyLossLayer(Layer): ...@@ -57,12 +57,12 @@ class SigmoidCrossEntropyLossLayer(Layer):
def __init__(self, LayerParameter): def __init__(self, LayerParameter):
super(SigmoidCrossEntropyLossLayer, self).__init__(LayerParameter) super(SigmoidCrossEntropyLossLayer, self).__init__(LayerParameter)
param = LayerParameter.loss_param param = LayerParameter.loss_param
norm_mode = {0: 'FULL', 1: 'BATCH_SIZE', 2: 'BATCH_SIZE', 3: 'NONE', 4: 'UNIT'} norm_mode = {0: 'FULL', 1: 'VALID', 2: 'BATCH_SIZE', 3: 'NONE', 4: 'UNIT'}
normalization = 'BATCH_SIZE' normalization = 'VALID'
if param.HasField('normalize'): if param.HasField('normalize'):
if param.normalize: normalization = 'FULL' if not param.normalize: normalization = 'BATCH_SIZE'
else: normalization = norm_mode[param.normalization] else: normalization = norm_mode[param.normalization]
self._param = { 'normalization': normalization } self._param = {'normalization': normalization}
def Setup(self, bottom): def Setup(self, bottom):
super(SigmoidCrossEntropyLossLayer, self).Setup(bottom) super(SigmoidCrossEntropyLossLayer, self).Setup(bottom)
......
...@@ -422,6 +422,7 @@ message LayerParameter { ...@@ -422,6 +422,7 @@ message LayerParameter {
optional BatchRenormParameter batch_renorm_param = 161; optional BatchRenormParameter batch_renorm_param = 161;
optional DenseConcatParameter dense_concat_param = 163; optional DenseConcatParameter dense_concat_param = 163;
optional FocalLossParameter focal_loss_param = 164; optional FocalLossParameter focal_loss_param = 164;
optional GatherParameter gather_param = 165;
} }
// Message that stores parameters used to apply transformation // Message that stores parameters used to apply transformation
...@@ -1504,3 +1505,7 @@ message FocalLossParameter { ...@@ -1504,3 +1505,7 @@ message FocalLossParameter {
optional int32 neg_id = 4 [default = -1]; optional int32 neg_id = 4 [default = -1];
} }
message GatherParameter {
optional int32 axis = 1 [default = 0];
}
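A small usage sketch of the new GatherParameter message (assuming dragon_pb2 has been regenerated from this proto; the surrounding LayerParameter fields are the usual Caffe ones):

```python
import dragon.protos.dragon_pb2 as pb

layer = pb.LayerParameter()
layer.name, layer.type = 'feat_gather', 'Gather'
layer.bottom.extend(['feat', 'indices'])
layer.top.append('gathered')
layer.gather_param.axis = 0    # GatherParameter.axis, default = 0
```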
...@@ -4,10 +4,11 @@ ...@@ -4,10 +4,11 @@
# Written by Ting Pan # Written by Ting Pan
# -------------------------------------------------------- # --------------------------------------------------------
import sys
import copy import copy
from collections import OrderedDict
import numpy as np import numpy as np
import sys from collections import OrderedDict
from six.moves import xrange
import dragon.core.mpi as mpi import dragon.core.mpi as mpi
import dragon.core.workspace as ws import dragon.core.workspace as ws
......
...@@ -36,7 +36,7 @@ find_packages('dragon') ...@@ -36,7 +36,7 @@ find_packages('dragon')
find_modules() find_modules()
setup(name = 'dragon', setup(name = 'dragon',
version='0.2.1.2', version='0.2.1.3',
description = 'Dragon: A Computation Graph Virtual Machine Based Deep Learning Framework', description = 'Dragon: A Computation Graph Virtual Machine Based Deep Learning Framework',
url='https://github.com/neopenx/Dragon', url='https://github.com/neopenx/Dragon',
author='Ting Pan', author='Ting Pan',
......
...@@ -8,12 +8,15 @@ namespace dragon { ...@@ -8,12 +8,15 @@ namespace dragon {
template <class Context> template <typename T> template <class Context> template <typename T>
void SigmoidCrossEntropyOp<Context>::RunWithType() { void SigmoidCrossEntropyOp<Context>::RunWithType() {
auto* Xdata = input(0).template data<T, Context>(); auto* Xdata = input(0).template data<T, Context>();
auto* Pdata = prob->template mutable_data<T, Context>();
kernel::Sigmoid<T, Context>(prob->count(), Xdata, Pdata);
auto* Tdata = input(1).template data<T, Context>(); auto* Tdata = input(1).template data<T, Context>();
auto* Ldata = losses.template mutable_data<T, Context>(); auto* Ldata = losses.template mutable_data<T, Context>();
kernel::SigmoidCrossEntropy<T, Context>(input(0).count(), Xdata, Tdata, Ldata); auto* Vdata = valid.template mutable_data<T, Context>();
kernel::SigmoidCrossEntropy<T, Context>(input(0).count(),
Xdata,
Tdata,
Ldata,
Vdata);
if (normalization == "UNIT") { if (normalization == "UNIT") {
output(0)->ReshapeLike(losses); output(0)->ReshapeLike(losses);
...@@ -22,7 +25,9 @@ void SigmoidCrossEntropyOp<Context>::RunWithType() { ...@@ -22,7 +25,9 @@ void SigmoidCrossEntropyOp<Context>::RunWithType() {
} }
T normalizer; T normalizer;
if (normalization == "BATCH_SIZE") normalizer = input(0).dim(0); if (normalization == "VALID")
normalizer = math::ASum<T, Context>(valid.count(), Vdata);
else if (normalization == "BATCH_SIZE") normalizer = input(0).dim(0);
else if (normalization == "FULL") normalizer = input(0).count(); else if (normalization == "FULL") normalizer = input(0).count();
else if (normalization == "NONE") normalizer = 1; else if (normalization == "NONE") normalizer = 1;
T loss = math::ASum<T, Context>(losses.count(), Ldata); T loss = math::ASum<T, Context>(losses.count(), Ldata);
...@@ -35,9 +40,8 @@ template <class Context> ...@@ -35,9 +40,8 @@ template <class Context>
void SigmoidCrossEntropyOp<Context>::RunOnDevice() { void SigmoidCrossEntropyOp<Context>::RunOnDevice() {
CHECK_EQ(input(0).count(), input(1).count()) CHECK_EQ(input(0).count(), input(1).count())
<< "\nNumber of predictions must match the number of labels."; << "\nNumber of predictions must match the number of labels.";
prob = ws()->CreateTensor("/mnt/" + anchor() + "/sigmoid_prob");
prob->ReshapeLike(input(0));
losses.ReshapeLike(input(0)); losses.ReshapeLike(input(0));
valid.ReshapeLike(input(0));
if (input(0).template IsType<float>()) RunWithType<float>(); if (input(0).template IsType<float>()) RunWithType<float>();
else LOG(FATAL) << "Unsupported input types."; else LOG(FATAL) << "Unsupported input types.";
...@@ -51,11 +55,16 @@ OPERATOR_SCHEMA(SigmoidCrossEntropy).NumInputs(2).NumOutputs(1); ...@@ -51,11 +55,16 @@ OPERATOR_SCHEMA(SigmoidCrossEntropy).NumInputs(2).NumOutputs(1);
template <class Context> template <typename T> template <class Context> template <typename T>
void SigmoidCrossEntropyGradientOp<Context>::RunWithType() { void SigmoidCrossEntropyGradientOp<Context>::RunWithType() {
auto* Pdata = prob->template data<T, Context>(); auto* Xdata = input(0).template data<T, Context>();
auto* Tdata = input(1).template data<T, Context>(); auto* Tdata = input(1).template data<T, Context>();
auto* Vdata = valid.template mutable_data<T, Context>();
auto* dXdata = output(0)->template mutable_data<T, Context>(); auto* dXdata = output(0)->template mutable_data<T, Context>();
ctx().template Copy<T, Context, Context>(prob->count(), dXdata, Pdata);
math::Axpy<T, Context>(output(0)->count(), -1.0, Tdata, dXdata); kernel::SigmoidCrossEntropyGrad<T, Context>(input(0).count(),
Xdata,
Tdata,
dXdata,
Vdata);
if (normalization == "UNIT") { if (normalization == "UNIT") {
auto* dYdata = input(-1).template data<T, Context>(); auto* dYdata = input(-1).template data<T, Context>();
...@@ -64,7 +73,8 @@ void SigmoidCrossEntropyGradientOp<Context>::RunWithType() { ...@@ -64,7 +73,8 @@ void SigmoidCrossEntropyGradientOp<Context>::RunWithType() {
} }
T normalizer; T normalizer;
if (normalization == "BATCH_SIZE") normalizer = input(0).dim(0); if (normalization == "VALID") normalizer = math::ASum<T, Context>(valid.count(), Vdata);
else if (normalization == "BATCH_SIZE") normalizer = input(0).dim(0);
else if (normalization == "FULL") normalizer = input(0).count(); else if (normalization == "FULL") normalizer = input(0).count();
else if (normalization == "NONE") normalizer = 1; else if (normalization == "NONE") normalizer = 1;
auto* dYdata = input(-1).template data<T, CPUContext>(); auto* dYdata = input(-1).template data<T, CPUContext>();
...@@ -73,8 +83,8 @@ void SigmoidCrossEntropyGradientOp<Context>::RunWithType() { ...@@ -73,8 +83,8 @@ void SigmoidCrossEntropyGradientOp<Context>::RunWithType() {
template <class Context> template <class Context>
void SigmoidCrossEntropyGradientOp<Context>::RunOnDevice() { void SigmoidCrossEntropyGradientOp<Context>::RunOnDevice() {
prob = ws()->GetTensor("/mnt/" + anchor() + "/sigmoid_prob");
output(0)->ReshapeLike(input(0)); output(0)->ReshapeLike(input(0));
valid.ReshapeLike(input(0));
if (input(0).template IsType<float>()) RunWithType<float>(); if (input(0).template IsType<float>()) RunWithType<float>();
else LOG(FATAL) << "Unsupported input types."; else LOG(FATAL) << "Unsupported input types.";
......
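Since the gradient no longer copies a stored probability tensor, dx is recomputed directly as sigmoid(x) - target for valid entries. A quick NumPy finite-difference check of that identity against the stable loss used by the CPU kernel later in this diff (a verification sketch only):

```python
import numpy as np

def loss_fn(x, t):
    # numerically stable sigmoid cross-entropy (element-wise)
    return np.log1p(np.exp(x - 2 * x * (x >= 0))) + x * ((x >= 0) - t)

rng = np.random.RandomState(0)
x = rng.randn(5)
t = rng.randint(0, 2, size=5).astype(np.float64)

analytic = 1.0 / (1.0 + np.exp(-x)) - t        # sigmoid(x) - target
eps = 1e-6
numeric = np.array([(loss_fn(x + eps * e, t) - loss_fn(x - eps * e, t)).sum()
                    / (2 * eps) for e in np.eye(5)])
assert np.allclose(analytic, numeric, atol=1e-4)
```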
#include "operators/ndarray/at_op.h" #include "operators/ndarray/gather_op.h"
#include "core/workspace.h" #include "core/workspace.h"
#include "utils/math_functions.h" #include "utils/math_functions.h"
#include "utils/op_kernel.h" #include "utils/op_kernel.h"
...@@ -6,12 +6,12 @@ ...@@ -6,12 +6,12 @@
namespace dragon { namespace dragon {
template <class Context> template <typename T> template <class Context> template <typename T>
void AtOp<Context>::RunWithType() { void GatherOp<Context>::RunWithType() {
auto* Xdata = input(0).template data<T, Context>(); auto* Xdata = input(0).template data<T, Context>();
auto* indices = input(1).template mutable_data<int, Context>(); auto* indices = input(1).template mutable_data<int, Context>();
auto* Ydata = output(0)->template mutable_data<T, Context>(); auto* Ydata = output(0)->template mutable_data<T, Context>();
kernel::CanonicalAxis<int, Context>(input(1).count(), x_slice_dim, indices); kernel::CanonicalAxis<int, Context>(input(1).count(), x_slice_dim, indices);
kernel::At<T, Context>(output(0)->count(), outer_dim, inner_dim, kernel::Gather<T, Context>(output(0)->count(), outer_dim, inner_dim,
x_slice_dim, y_slice_dim, x_slice_dim, y_slice_dim,
indices, indices,
Xdata, Xdata,
...@@ -20,7 +20,7 @@ void AtOp<Context>::RunWithType() { ...@@ -20,7 +20,7 @@ void AtOp<Context>::RunWithType() {
} }
template <class Context> template <class Context>
void AtOp<Context>::RunOnDevice() { void GatherOp<Context>::RunOnDevice() {
output_dims = input(0).dims(); output_dims = input(0).dims();
x_slice_dim = input(0).dim(axis); x_slice_dim = input(0).dim(axis);
output_dims[axis] = y_slice_dim = input(1).count(); output_dims[axis] = y_slice_dim = input(1).count();
...@@ -35,19 +35,19 @@ void AtOp<Context>::RunOnDevice() { ...@@ -35,19 +35,19 @@ void AtOp<Context>::RunOnDevice() {
else LOG(FATAL) << "Unsupported input types."; else LOG(FATAL) << "Unsupported input types.";
} }
DEPLOY_CPU(At); DEPLOY_CPU(Gather);
#ifdef WITH_CUDA #ifdef WITH_CUDA
DEPLOY_CUDA(At); DEPLOY_CUDA(Gather);
#endif #endif
OPERATOR_SCHEMA(At).NumInputs(2).NumOutputs(1); OPERATOR_SCHEMA(Gather).NumInputs(2).NumOutputs(1);
template <class Context> template <typename T> template <class Context> template <typename T>
void AtGradientOp<Context>::RunWithType() { void GatherGradientOp<Context>::RunWithType() {
auto* indices = input(1).template data<int, Context>(); auto* indices = input(1).template data<int, Context>();
auto* dYdata = input(-1).template data<T, Context>(); auto* dYdata = input(-1).template data<T, Context>();
auto* dXdata = output(0)->template mutable_data<T, Context>(); auto* dXdata = output(0)->template mutable_data<T, Context>();
if (!acc_grad) math::Set<T, Context>(output(0)->count(), 0, dXdata); if (!acc_grad) math::Set<T, Context>(output(0)->count(), 0, dXdata);
kernel::AtGrad<T, Context>(input(-1).count(), outer_dim, inner_dim, kernel::GatherGrad<T, Context>(input(-1).count(), outer_dim, inner_dim,
x_slice_dim, y_slice_dim, x_slice_dim, y_slice_dim,
indices, indices,
dYdata, dYdata,
...@@ -55,7 +55,7 @@ void AtGradientOp<Context>::RunWithType() { ...@@ -55,7 +55,7 @@ void AtGradientOp<Context>::RunWithType() {
} }
template <class Context> template <class Context>
void AtGradientOp<Context>::RunOnDevice() { void GatherGradientOp<Context>::RunOnDevice() {
x_slice_dim = input(0).dim(axis); x_slice_dim = input(0).dim(axis);
y_slice_dim = input(1).count(); y_slice_dim = input(1).count();
outer_dim = input(0).count(0, axis); outer_dim = input(0).count(0, axis);
...@@ -68,21 +68,21 @@ void AtGradientOp<Context>::RunOnDevice() { ...@@ -68,21 +68,21 @@ void AtGradientOp<Context>::RunOnDevice() {
else LOG(FATAL) << "Unsupported input types."; else LOG(FATAL) << "Unsupported input types.";
} }
DEPLOY_CPU(AtGradient); DEPLOY_CPU(GatherGradient);
#ifdef WITH_CUDA #ifdef WITH_CUDA
DEPLOY_CUDA(AtGradient); DEPLOY_CUDA(GatherGradient);
#endif #endif
OPERATOR_SCHEMA(AtGradient).NumInputs(3).NumOutputs(1); OPERATOR_SCHEMA(GatherGradient).NumInputs(3).NumOutputs(1);
class GetAtGradient final : public GradientMakerBase { class GetGatherGradient final : public GradientMakerBase {
public: public:
GRADIENT_MAKER_CTOR(GetAtGradient); GRADIENT_MAKER_CTOR(GetGatherGradient);
vector<OperatorDef> MakeDefs() override { vector<OperatorDef> MakeDefs() override {
return SingleDef(def.type() + "Gradient", "", return SingleDef(def.type() + "Gradient", "",
vector<string> {I(0), I(1), GO(0)}, vector<string> {I(0), I(1), GO(0)},
vector<string> {GI(0)}); vector<string> {GI(0)});
} }
}; };
REGISTER_GRADIENT(At, GetAtGradient); REGISTER_GRADIENT(Gather, GetGatherGradient);
} // namespace dragon } // namespace dragon
\ No newline at end of file
...@@ -14,9 +14,8 @@ void RandomPickOp<Context>::RunWithType() { ...@@ -14,9 +14,8 @@ void RandomPickOp<Context>::RunWithType() {
auto* Xdata = input(0).template data<T, Context>(); auto* Xdata = input(0).template data<T, Context>();
indices = pick_indices->template mutable_data<int, Context>(); indices = pick_indices->template mutable_data<int, Context>();
auto* Ydata = output(0)->template mutable_data<T, Context>(); auto* Ydata = output(0)->template mutable_data<T, Context>();
kernel::At<T, Context>(output(0)->count(), outer_dim, inner_dim, kernel::Gather<T, Context>(output(0)->count(), outer_dim, inner_dim,
x_slice_dim, x_slice_dim, y_slice_dim,
y_slice_dim,
indices, indices,
Xdata, Xdata,
Ydata, Ydata,
...@@ -57,9 +56,8 @@ void RandomPickGradientOp<Context>::RunWithType() { ...@@ -57,9 +56,8 @@ void RandomPickGradientOp<Context>::RunWithType() {
auto* dYdata = input(-1).template data<T, Context>(); auto* dYdata = input(-1).template data<T, Context>();
auto* dXdata = output(0)->template mutable_data<T, Context>(); auto* dXdata = output(0)->template mutable_data<T, Context>();
math::Set<T, Context>(output(0)->count(), 0, dXdata); math::Set<T, Context>(output(0)->count(), 0, dXdata);
kernel::AtGrad<T, Context>(input(-1).count(), outer_dim, inner_dim, kernel::GatherGrad<T, Context>(input(-1).count(), outer_dim, inner_dim,
x_slice_dim, x_slice_dim, y_slice_dim,
y_slice_dim,
indices, indices,
dYdata, dYdata,
dXdata); dXdata);
......
...@@ -9,22 +9,17 @@ template <class Context> template <typename T> ...@@ -9,22 +9,17 @@ template <class Context> template <typename T>
void ROIAlignOp<Context>::RunWithType() { void ROIAlignOp<Context>::RunWithType() {
kernel::ROIAlign<T, Context>(spatial_scale, kernel::ROIAlign<T, Context>(spatial_scale,
pool_h, pool_w, pool_h, pool_w,
sampling_ratio,
&input(0), &input(0),
&input(1), &input(1),
mask_h,
mask_w,
output(0)); output(0));
} }
template <class Context> template <class Context>
void ROIAlignOp<Context>::RunOnDevice() { void ROIAlignOp<Context>::RunOnDevice() {
mask_h = ws()->CreateTensor("/mnt/" + anchor() + "/roi_align_mask_h"); output(0)->Reshape(vector<TIndex>({ input(1).dim(0),
mask_w = ws()->CreateTensor("/mnt/" + anchor() + "/roi_align_mask_w"); input(0).dim(1),
vector<TIndex> dims({input(1).dim(0), input(0).dim(1), pool_h, pool_w}); pool_h, pool_w }));
output(0)->Reshape(dims);
mask_h->Reshape(dims);
mask_w->Reshape(dims);
if (input(0).template IsType<float>()) return RunWithType<float>(); if (input(0).template IsType<float>()) return RunWithType<float>();
else LOG(FATAL) << "Unsupported input types."; else LOG(FATAL) << "Unsupported input types.";
...@@ -40,18 +35,14 @@ template <class Context> template <typename T> ...@@ -40,18 +35,14 @@ template <class Context> template <typename T>
void ROIAlignGradientOp<Context>::RunWithType() { void ROIAlignGradientOp<Context>::RunWithType() {
kernel::ROIAlignGrad<T, Context>(spatial_scale, kernel::ROIAlignGrad<T, Context>(spatial_scale,
pool_h, pool_w, pool_h, pool_w,
sampling_ratio,
&input(-1), &input(-1),
&input(1), &input(1),
mask_h,
mask_w,
output(0)); output(0));
} }
template <class Context> template <class Context>
void ROIAlignGradientOp<Context>::RunOnDevice() { void ROIAlignGradientOp<Context>::RunOnDevice() {
mask_h = ws()->GetTensor("/mnt/" + anchor() + "/roi_align_mask_h");
mask_w = ws()->GetTensor("/mnt/" + anchor() + "/roi_align_mask_w");
output(0)->ReshapeLike(input(0)); output(0)->ReshapeLike(input(0));
if (input(0).template IsType<float>()) return RunWithType<float>(); if (input(0).template IsType<float>()) return RunWithType<float>();
......
...@@ -503,13 +503,39 @@ template<> void AbsGrad<float, CPUContext>(const int count, const float* dy, flo ...@@ -503,13 +503,39 @@ template<> void AbsGrad<float, CPUContext>(const int count, const float* dy, flo
template <> void SigmoidCrossEntropy<float, CPUContext>(const int count, template <> void SigmoidCrossEntropy<float, CPUContext>(const int count,
const float* x, const float* x,
const float* target, const float* target,
float* loss) { float* loss,
float* valid) {
#ifdef WITH_OMP #ifdef WITH_OMP
#pragma omp parallel for num_threads(GET_OMP_THREADS(count)) #pragma omp parallel for num_threads(GET_OMP_THREADS(count))
#endif #endif
for (int i = 0; i < count; ++i) { for (int i = 0; i < count; ++i) {
if (target[i] < 0) {
loss[i] = 0.;
valid[i] = 0.;
} else {
loss[i] = std::log(1 + std::exp(x[i] - 2 * x[i] * (x[i] >= 0))) loss[i] = std::log(1 + std::exp(x[i] - 2 * x[i] * (x[i] >= 0)))
+ x[i] * ((x[i] >= 0) - target[i]); + x[i] * ((x[i] >= 0) - target[i]);
valid[i] = 1.;
}
}
}
template <> void SigmoidCrossEntropyGrad<float, CPUContext>(const int count,
const float* x,
const float* target,
float* dx,
float* valid) {
#ifdef WITH_OMP
#pragma omp parallel for num_threads(GET_OMP_THREADS(count))
#endif
for (int i = 0; i < count; ++i) {
if (target[i] < 0) {
dx[i] = 0.;
valid[i] = 0.;
} else {
dx[i] = 1. / (1. + expf(-x[i])) - target[i];
valid[i] = 1.;
}
} }
} }
...@@ -902,7 +928,7 @@ template<> void Argmin<float, CPUContext>(const int count, ...@@ -902,7 +928,7 @@ template<> void Argmin<float, CPUContext>(const int count,
} }
} }
/******************** ndarray.at ********************/ /******************** ndarray.gather ********************/
template <> void CanonicalAxis<int, CPUContext>(const int count, const int dim, int* y) { template <> void CanonicalAxis<int, CPUContext>(const int count, const int dim, int* y) {
#ifdef WITH_OMP #ifdef WITH_OMP
...@@ -912,7 +938,7 @@ template <> void CanonicalAxis<int, CPUContext>(const int count, const int dim, ...@@ -912,7 +938,7 @@ template <> void CanonicalAxis<int, CPUContext>(const int count, const int dim,
} }
template <typename T> template <typename T>
void _At(const int count, void _Gather(const int count,
const int outer_dim, const int outer_dim,
const int inner_dim, const int inner_dim,
const int x_slice_dim, const int x_slice_dim,
...@@ -935,7 +961,7 @@ void _At(const int count, ...@@ -935,7 +961,7 @@ void _At(const int count,
} }
} }
template <> void At<float, CPUContext>(const int count, template <> void Gather<float, CPUContext>(const int count,
const int outer_dim, const int outer_dim,
const int inner_dim, const int inner_dim,
const int x_slice_dim, const int x_slice_dim,
...@@ -944,13 +970,13 @@ template <> void At<float, CPUContext>(const int count, ...@@ -944,13 +970,13 @@ template <> void At<float, CPUContext>(const int count,
const float* x, const float* x,
float* y, float* y,
CPUContext* ctx) { CPUContext* ctx) {
_At<float>(count, outer_dim, inner_dim, _Gather<float>(count, outer_dim, inner_dim,
x_slice_dim, y_slice_dim, x_slice_dim, y_slice_dim,
indices, x, y, ctx); indices, x, y, ctx);
} }
template <> void At<int, CPUContext>(const int count, template <> void Gather<int, CPUContext>(const int count,
const int outer_dim, const int outer_dim,
const int inner_dim, const int inner_dim,
const int x_slice_dim, const int x_slice_dim,
...@@ -959,13 +985,13 @@ template <> void At<int, CPUContext>(const int count, ...@@ -959,13 +985,13 @@ template <> void At<int, CPUContext>(const int count,
const int* x, const int* x,
int* y, int* y,
CPUContext* ctx) { CPUContext* ctx) {
_At<int>(count, outer_dim, inner_dim, _Gather<int>(count, outer_dim, inner_dim,
x_slice_dim, y_slice_dim, x_slice_dim, y_slice_dim,
indices, x, y, ctx); indices, x, y, ctx);
} }
template <typename T> template <typename T>
void _AtGrad(const int count, void _GatherGrad(const int count,
const int outer_dim, const int outer_dim,
const int inner_dim, const int inner_dim,
const int x_slice_dim, const int x_slice_dim,
...@@ -988,7 +1014,7 @@ void _AtGrad(const int count, ...@@ -988,7 +1014,7 @@ void _AtGrad(const int count,
} }
} }
template <> void AtGrad<float, CPUContext>(const int count, template <> void GatherGrad<float, CPUContext>(const int count,
const int outer_dim, const int outer_dim,
const int inner_dim, const int inner_dim,
const int x_slice_dim, const int x_slice_dim,
...@@ -996,12 +1022,12 @@ template <> void AtGrad<float, CPUContext>(const int count, ...@@ -996,12 +1022,12 @@ template <> void AtGrad<float, CPUContext>(const int count,
const int* indices, const int* indices,
const float* dy, const float* dy,
float* dx) { float* dx) {
_AtGrad<float>(count, outer_dim, inner_dim, _GatherGrad<float>(count, outer_dim, inner_dim,
x_slice_dim, y_slice_dim, x_slice_dim, y_slice_dim,
indices, dy, dx); indices, dy, dx);
} }
template <> void AtGrad<int, CPUContext>(const int count, template <> void GatherGrad<int, CPUContext>(const int count,
const int outer_dim, const int outer_dim,
const int inner_dim, const int inner_dim,
const int x_slice_dim, const int x_slice_dim,
...@@ -1009,7 +1035,7 @@ template <> void AtGrad<int, CPUContext>(const int count, ...@@ -1009,7 +1035,7 @@ template <> void AtGrad<int, CPUContext>(const int count,
const int* indices, const int* indices,
const int* dy, const int* dy,
int* dx) { int* dx) {
_AtGrad<int>(count, outer_dim, inner_dim, _GatherGrad<int>(count, outer_dim, inner_dim,
x_slice_dim, y_slice_dim, x_slice_dim, y_slice_dim,
indices, dy, dx); indices, dy, dx);
} }
...@@ -2694,20 +2720,18 @@ template<> void ROIPoolingGrad<float, CPUContext>(const float spatial_scale, ...@@ -2694,20 +2720,18 @@ template<> void ROIPoolingGrad<float, CPUContext>(const float spatial_scale,
template<> void ROIAlign<float, CPUContext>(const float spatial_scale, template<> void ROIAlign<float, CPUContext>(const float spatial_scale,
const int pool_h, const int pool_w, const int pool_h, const int pool_w,
const int sampling_ratio,
Tensor* x, Tensor* x,
Tensor* roi, Tensor* rois,
Tensor* mask_h,
Tensor* mask_w,
Tensor* y) { Tensor* y) {
NOT_IMPLEMENTED; NOT_IMPLEMENTED;
} }
template<> void ROIAlignGrad<float, CPUContext>(const float spatial_scale, template<> void ROIAlignGrad<float, CPUContext>(const float spatial_scale,
const int pool_h, const int pool_w, const int pool_h, const int pool_w,
const int sampling_ratio,
Tensor* dy, Tensor* dy,
Tensor* roi, Tensor* rois,
Tensor* mask_h,
Tensor* mask_w,
Tensor* dx) { Tensor* dx) {
NOT_IMPLEMENTED; NOT_IMPLEMENTED;
} }
......