Commit ae11a987 by Ting PAN

Add ChannelShuffleOp

1 parent 6b82cb26
## General
# Compiled Object files # Compiled Object files
*.slo *.slo
*.lo *.lo
...@@ -7,13 +5,15 @@ ...@@ -7,13 +5,15 @@
*.cuo *.cuo
# Compiled Dynamic libraries # Compiled Dynamic libraries
# *.so *.so
*.dll
*.dylib *.dylib
# Compiled Static libraries # Compiled Static libraries
*.lai *.lai
*.la *.la
#*.a *.a
*.lib
# Compiled python # Compiled python
*.pyc *.pyc
...@@ -40,6 +40,9 @@ __pycache__ ...@@ -40,6 +40,9 @@ __pycache__
# QtCreator files # QtCreator files
*.user *.user
# VSCode files
.vscode
# PyCharm files # PyCharm files
.idea .idea
......
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
SPHINXOPTS = SPHINXOPTS =
SPHINXBUILD = sphinx-build SPHINXBUILD = sphinx-build
PAPER = PAPER =
BUILDDIR = _build BUILDDIR = _build/api
# User-friendly check for sphinx-build # User-friendly check for sphinx-build
ifeq ($(shell which $(SPHINXBUILD) >/dev/null 2>&1; echo $$?), 1) ifeq ($(shell which $(SPHINXBUILD) >/dev/null 2>&1; echo $$?), 1)
...@@ -52,9 +52,9 @@ clean: ...@@ -52,9 +52,9 @@ clean:
rm -rf $(BUILDDIR)/* rm -rf $(BUILDDIR)/*
html: html:
$(SPHINXBUILD) -b html $(ALLSPHINXOPTS) $(BUILDDIR)/html $(SPHINXBUILD) -b html $(ALLSPHINXOPTS) $(BUILDDIR)/python
@echo @echo
@echo "Build finished. The HTML pages are in $(BUILDDIR)/html." @echo "Build finished. The HTML pages are in $(BUILDDIR)/python."
dirhtml: dirhtml:
$(SPHINXBUILD) -b dirhtml $(ALLSPHINXOPTS) $(BUILDDIR)/dirhtml $(SPHINXBUILD) -b dirhtml $(ALLSPHINXOPTS) $(BUILDDIR)/dirhtml
......
...@@ -11,14 +11,12 @@ Quick Reference ...@@ -11,14 +11,12 @@ Quick Reference
============================== ============================================================================= ============================== =============================================================================
List Brief List Brief
============================== ============================================================================= ============================== =============================================================================
`FromShape`_ Create a Tensor from the shape. `FromShape`_ Create a Tensor from the given shape.
`FromArray`_ Create a Tensor from an existing Array.
`ToArray`_ Create an Array from an existing Tensor.
`SetShape`_ Set a Tensor with the shape. `SetShape`_ Set a Tensor with the shape.
`FromPyArray`_ Create a Tensor from a existing Array. `SetArray`_ Set a Tensor from an existing Array.
`SetPyArray`_ Set a Tensor from a existing Array. `GetStorage`_ Get the storage of an existing Tensor.
`ToPyArray`_ Create a Array from a existing Tensor.
`GetStorage`_ Get the storage of a existing Tensor.
`ToCPUTensor`_ Switch the storage of a existing Tensor on cpu memory.
`ToCUDATensor`_ Switch the storage of a existing Tensor on cuda memory.
============================== ============================================================================= ============================== =============================================================================
API Reference API Reference
...@@ -28,10 +26,8 @@ API Reference ...@@ -28,10 +26,8 @@ API Reference
:members: :members:
.. _FromShape: #dragon.core.tensor_utils.FromShape .. _FromShape: #dragon.core.tensor_utils.FromShape
.. _FromArray: #dragon.core.tensor_utils.FromArray
.. _ToArray: #dragon.core.tensor_utils.ToArray
.. _SetShape: #dragon.core.tensor_utils.SetShape .. _SetShape: #dragon.core.tensor_utils.SetShape
.. _FromPyArray: #dragon.core.tensor_utils.FromPyArray .. _SetArray: #dragon.core.tensor_utils.SetArray
.. _SetPyArray: #dragon.core.tensor_utils.SetPyArray .. _GetStorage: #dragon.core.tensor_utils.GetStorage
.. _ToPyArray: #dragon.core.tensor_utils.ToPyArray \ No newline at end of file
.. _GetStorage: #dragon.core.tensor_utils.GetStorage
.. _ToCPUTensor: #dragon.core.tensor_utils.ToCPUTensor
.. _ToCUDATensor: #dragon.core.tensor_utils.ToCUDATensor
\ No newline at end of file
...@@ -132,37 +132,38 @@ List Brief ...@@ -132,37 +132,38 @@ List Brief
Array Array
----- -----
=============== ====================================================================== ================== ======================================================================
List Brief List Brief
=============== ====================================================================== ================== ======================================================================
`Where`_ Select elements from either *x* or *y*. `Where`_ Select elements from either *x* or *y*.
`IndexSelect`_ Select the elements according to the indices along the given axis. `IndexSelect`_ Select the elements according to the indices along the given axis.
`MaskedSelect`_ Select the elements where *mask* is *1*. `MaskedSelect`_ Select the elements where *mask* is *1*.
`Reduce`_ Reduce the inputs along the given axes. `Reduce`_ Reduce the inputs along the given axes.
`Sum`_ Compute the sum along the given axis. `Sum`_ Compute the sum along the given axis.
`Mean`_ Compute the mean along the given axis. `Mean`_ Compute the mean along the given axis.
`Max`_ Compute the values of maximum elements along the given axis. `Max`_ Compute the values of maximum elements along the given axis.
`ArgMax`_ Compute the indices of maximum elements along the given axis. `ArgMax`_ Compute the indices of maximum elements along the given axis.
`Min`_ Compute the values of minimum elements along the given axis. `Min`_ Compute the values of minimum elements along the given axis.
`ArgMin`_ Compute the indices of minimum elements along the given axis. `ArgMin`_ Compute the indices of minimum elements along the given axis.
`Slice`_ Slice the inputs into several parts along the given axis. `Slice`_ Slice the inputs into several parts along the given axis.
`Stack`_ Stack the inputs along the given axis. `Stack`_ Stack the inputs along the given axis.
`Concat`_ Concatenate the inputs along the given axis. `Concat`_ Concatenate the inputs along the given axis.
`Repeat`_ Repeat the input along the given axis. `ChannelShuffle`_ Shuffle channels between groups along the given axis. `[Zhang et al., 2017] <https://arxiv.org/abs/1707.01083>`_.
`Transpose`_ Transpose the input according to the given permutations. `Repeat`_ Repeat the input along the given axis.
`Tile`_ Tile the input according to the given multiples. `Transpose`_ Transpose the input according to the given permutations.
`Pad`_ Pad the input according to the given sizes. `Tile`_ Tile the input according to the given multiples.
`Crop`_ Crop the input according to the given starts and sizes. `Pad`_ Pad the input according to the given sizes.
`OneHot`_ Generate the one-hot representation of inputs. `Crop`_ Crop the input according to the given starts and sizes.
`Flatten`_ Flatten the input along the given axes. `OneHot`_ Generate the one-hot representation of inputs.
`Reshape`_ Reshape the dimensions of input. `Flatten`_ Flatten the input along the given axes.
`Squeeze`_ Remove the dimensions with size 1. `Reshape`_ Reshape the dimensions of input.
`ExpandDims`_ Insert a new dimension with size 1 at the specified axis. `Squeeze`_ Remove the dimensions with size 1.
`Shape`_ Get the dynamic shape of a Tensor. `ExpandDims`_ Insert a new dimension with size 1 at the specified axis.
`NonZero`_ Return the indices of non-zero elements. `Shape`_ Get the dynamic shape of a Tensor.
`Arange`_ Return evenly spaced values within a given interval. `NonZero`_ Return the indices of non-zero elements.
`Multinomial`_ Return indices sampled from the multinomial distribution. `Arange`_ Return evenly spaced values within a given interval.
=============== ====================================================================== `Multinomial`_ Return indices sampled from the multinomial distribution.
================== ======================================================================
Control Flow Control Flow
------------ ------------
...@@ -302,6 +303,7 @@ List Brief ...@@ -302,6 +303,7 @@ List Brief
.. _Slice: operators/array.html#dragon.operators.array.Slice .. _Slice: operators/array.html#dragon.operators.array.Slice
.. _Stack: operators/array.html#dragon.operators.array.Stack .. _Stack: operators/array.html#dragon.operators.array.Stack
.. _Concat: operators/array.html#dragon.operators.array.Concat .. _Concat: operators/array.html#dragon.operators.array.Concat
.. _ChannelShuffle: operators/array.html#dragon.operators.array.ChannelShuffle
.. _Transpose: operators/array.html#dragon.operators.array.Transpose .. _Transpose: operators/array.html#dragon.operators.array.Transpose
.. _Repeat: operators/array.html#dragon.operators.array.Repeat .. _Repeat: operators/array.html#dragon.operators.array.Repeat
.. _Tile: operators/array.html#dragon.operators.array.Tile .. _Tile: operators/array.html#dragon.operators.array.Tile
......
...@@ -14,8 +14,8 @@ Data ...@@ -14,8 +14,8 @@ Data
==================== ============================================================================= ==================== =============================================================================
List Brief List Brief
==================== ============================================================================= ==================== =============================================================================
`DataLayer`_ The implementation of ``DataLayer``. `DataLayer`_ The implementation of *DataLayer*.
`MemoryDataLayer`_ The implementation of ``MemoryDataLayer``. `MemoryDataLayer`_ The implementation of *MemoryDataLayer*.
==================== ============================================================================= ==================== =============================================================================
Vision Vision
...@@ -24,16 +24,16 @@ Vision ...@@ -24,16 +24,16 @@ Vision
============================== ============================================================================= ============================== =============================================================================
List Brief List Brief
============================== ============================================================================= ============================== =============================================================================
`ConvolutionLayer`_ The implementation of ``ConvolutionLayer``. `ConvolutionLayer`_ The implementation of *ConvolutionLayer*.
`DepthwiseConvolutionLayer`_ The implementation of ``DepthwiseConvolutionLayer``. `DepthwiseConvolutionLayer`_ The implementation of *DepthwiseConvolutionLayer*.
`DeconvolutionLayer`_ The implementation of ``DeconvolutionLayer``. `DeconvolutionLayer`_ The implementation of *DeconvolutionLayer*.
`PoolingLayer`_ The implementation of ``PoolingLayer``. `PoolingLayer`_ The implementation of *PoolingLayer*.
`ROIPoolingLayer`_ The implementation of ``ROIPoolingLayer``. `ROIPoolingLayer`_ The implementation of *ROIPoolingLayer*.
`ROIAlignLayer`_ The implementation of ``ROIAlignLayer``. `ROIAlignLayer`_ The implementation of *ROIAlignLayer*.
`LRNLayer`_ The implementation of ``LRNLayer``. `LRNLayer`_ The implementation of *LRNLayer*.
`NNResizeLayer`_ The implementation of ``NNResizeLayer``. `NNResizeLayer`_ The implementation of *NNResizeLayer*.
`BilinearResizeLayer`_ The implementation of ``BilinearResizeLayer``. `BilinearResizeLayer`_ The implementation of *BilinearResizeLayer*.
`DropBlockLayer`_ The implementation of ``DropBlockLayer``. `DropBlockLayer`_ The implementation of *DropBlockLayer*.
============================== ============================================================================= ============================== =============================================================================
...@@ -43,14 +43,14 @@ Neuron ...@@ -43,14 +43,14 @@ Neuron
==================== ============================================================================= ==================== =============================================================================
List Brief List Brief
==================== ============================================================================= ==================== =============================================================================
`ReLULayer`_ The implementation of ``ReLULayer``. `ReLULayer`_ The implementation of *ReLULayer*.
`PReLULayer`_ The implementation of ``PReLULayer``. `PReLULayer`_ The implementation of *PReLULayer*.
`ELULayer`_ The implementation of ``ELULayer``. `ELULayer`_ The implementation of *ELULayer*.
`SELULayer`_ The implementation of ``SELULayer``. `SELULayer`_ The implementation of *SELULayer*.
`SigmoidLayer`_ The implementation of ``SigmoidLayer``. `SigmoidLayer`_ The implementation of *SigmoidLayer*.
`TanHLayer`_ The implementation of ``TanHLayer``. `TanHLayer`_ The implementation of *TanHLayer*.
`DropoutLayer`_ The implementation of ``DropoutLayer``. `DropoutLayer`_ The implementation of *DropoutLayer*.
`PowerLayer`_ The implementation of ``PowerLayer``. `PowerLayer`_ The implementation of *PowerLayer*.
==================== ============================================================================= ==================== =============================================================================
Common Common
...@@ -59,31 +59,30 @@ Common ...@@ -59,31 +59,30 @@ Common
======================== ============================================================================= ======================== =============================================================================
List Brief List Brief
======================== ============================================================================= ======================== =============================================================================
`InnerProductLayer`_ The implementation of ``InnerProductLayer``. `InnerProductLayer`_ The implementation of *InnerProductLayer*.
`AccuracyLayer`_ The implementation of ``AccuracyLayer``. `AccuracyLayer`_ The implementation of *AccuracyLayer*.
`PythonLayer`_ The implementation of ``PythonLayer``. `PythonLayer`_ The implementation of *PythonLayer*.
`EltwiseLayer`_ The implementation of ``EltwiseLayer``. `EltwiseLayer`_ The implementation of *EltwiseLayer*.
`AddLayer`_ The extended implementation of ``EltwiseLayer``. `AddLayer`_ The extended implementation of *EltwiseLayer*.
`ConcatLayer`_ The implementation of ``ConcatLayer``. `ConcatLayer`_ The implementation of *ConcatLayer*.
`SliceLayer`_ The implementation of ``SliceLayer``. `SliceLayer`_ The implementation of *SliceLayer*.
`CropLayer`_ The implementation of ``CropLayer``. `CropLayer`_ The implementation of *CropLayer*.
`ReshapeLayer`_ The implementation of ``ReshapeLayer``. `ReshapeLayer`_ The implementation of *ReshapeLayer*.
`PermuteLayer`_ The implementation of ``PermuteLayer``. `PermuteLayer`_ The implementation of *PermuteLayer*.
`FlattenLayer`_ The implementation of ``FlattenLayer``. `FlattenLayer`_ The implementation of *FlattenLayer*.
`GatherLayer`_ The extended implementation for ``GatherOp``. `GatherLayer`_ The extended implementation for *GatherOp*.
`SoftmaxLayer`_ The implementation of ``SoftmaxLayer``. `SoftmaxLayer`_ The implementation of *SoftmaxLayer*.
`ArgMaxLayer`_ The implementation of ``ArgMaxLayer``. `ArgMaxLayer`_ The implementation of *ArgMaxLayer*.
`BatchNormLayer`_ The implementation of ``BatchNormLayer``. `BatchNormLayer`_ The implementation of *BatchNormLayer*.
`GroupNormLayer`_ The implementation of ``GroupNormLayer``. `GroupNormLayer`_ The implementation of *GroupNormLayer*.
`InstanceNormLayer`_ The implementation of ``InstanceNormLayer``. `ScaleLayer`_ The implementation of *ScaleLayer*.
`ScaleLayer`_ The implementation of ``ScaleLayer``. `BNLayer`_ The implementation of *BNLayer*.
`BNLayer`_ The implementation of ``BNLayer``. `GNLayer`_ The implementation of *GNLayer*.
`GNLayer`_ The implementation of ``GNLayer``. `NormalizeLayer`_ The implementation of *NormalizeLayer*.
`NormalizeLayer`_ The implementation of ``NormalizeLayer``. `TileLayer`_ The extended implementation of *TileLayer*.
`TileLayer`_ The extended implementation of ``TileLayer``. `ExpandDimsLayer`_ The implementation of *ExpandDimsLayer*.
`ExpandDimsLayer`_ The implementation of ``ExpandDimsLayer``. `StopGradientLayer`_ The implementation of *StopGradientLayer*.
`StopGradientLayer`_ The implementation of ``StopGradientLayer``. `ProposalLayer`_ The implementation of *ProposalLayer*.
`ProposalLayer`_ The implementation of ``ProposalLayer``.
======================== ============================================================================= ======================== =============================================================================
Loss Loss
...@@ -92,12 +91,12 @@ Loss ...@@ -92,12 +91,12 @@ Loss
================================= ============================================================================= ================================= =============================================================================
List Brief List Brief
================================= ============================================================================= ================================= =============================================================================
`SoftmaxWithLossLayer`_ The implementation of ``SoftmaxWithLossLayer``. `SoftmaxWithLossLayer`_ The implementation of *SoftmaxWithLossLayer*.
`SigmoidCrossEntropyLossLayer`_ The implementation of ``SigmoidCrossEntropyLossLayer``. `SigmoidCrossEntropyLossLayer`_ The implementation of *SigmoidCrossEntropyLossLayer*.
`L2LossLayer`_ The implementation of ``L2LossLayer``. `L2LossLayer`_ The implementation of *L2LossLayer*.
`SmoothL1LossLayer`_ The implementation of ``SmoothL1LossLayer``. `SmoothL1LossLayer`_ The implementation of *SmoothL1LossLayer*.
`SigmoidWithFocalLossLayer`_ The implementation of ``SigmoidWithFocalLossLayer``. `SigmoidWithFocalLossLayer`_ The implementation of *SigmoidWithFocalLossLayer*.
`SoftmaxWithFocalLossLayer`_ The implementation of ``SoftmaxWithFocalLossLayer``. `SoftmaxWithFocalLossLayer`_ The implementation of *SoftmaxWithFocalLossLayer*.
================================= ============================================================================= ================================= =============================================================================
MPI MPI
...@@ -106,8 +105,8 @@ MPI ...@@ -106,8 +105,8 @@ MPI
================================= ============================================================================= ================================= =============================================================================
List Brief List Brief
================================= ============================================================================= ================================= =============================================================================
`MPIBroadcastLayer`_ The implementation of ``MPIBroadcastLayer``. `MPIBroadcastLayer`_ The implementation of *MPIBroadcastLayer*.
`MPIGatherLayer`_ The implementation of ``MPIGatherLayer``. `MPIGatherLayer`_ The implementation of *MPIGatherLayer*.
================================= ============================================================================= ================================= =============================================================================
...@@ -188,7 +187,6 @@ API Reference ...@@ -188,7 +187,6 @@ API Reference
.. _ArgMaxLayer: #dragon.vm.caffe.layers.common.ArgMaxLayer .. _ArgMaxLayer: #dragon.vm.caffe.layers.common.ArgMaxLayer
.. _BatchNormLayer: #dragon.vm.caffe.layers.common.BatchNormLayer .. _BatchNormLayer: #dragon.vm.caffe.layers.common.BatchNormLayer
.. _GroupNormLayer: #dragon.vm.caffe.layers.common.GroupNormLayer .. _GroupNormLayer: #dragon.vm.caffe.layers.common.GroupNormLayer
.. _InstanceNormLayer: #dragon.vm.caffe.layers.common.InstanceNormLayer
.. _ScaleLayer: #dragon.vm.caffe.layers.common.ScaleLayer .. _ScaleLayer: #dragon.vm.caffe.layers.common.ScaleLayer
.. _BNLayer: #dragon.vm.caffe.layers.common.BNLayer .. _BNLayer: #dragon.vm.caffe.layers.common.BNLayer
.. _GNLayer: #dragon.vm.caffe.layers.common.GNLayer .. _GNLayer: #dragon.vm.caffe.layers.common.GNLayer
......
...@@ -94,6 +94,9 @@ include_directories(${THIRD_PARTY_DIR}/eigen) ...@@ -94,6 +94,9 @@ include_directories(${THIRD_PARTY_DIR}/eigen)
include_directories(${THIRD_PARTY_DIR}/protobuf/include) include_directories(${THIRD_PARTY_DIR}/protobuf/include)
include_directories(${PROJECT_SOURCE_DIR}/include) include_directories(${PROJECT_SOURCE_DIR}/include)
include_directories(${PROJECT_SOURCE_DIR}/src) include_directories(${PROJECT_SOURCE_DIR}/src)
if(APPLE)
include_directories(/usr/local/include)
endif()
if (BUILD_PYTHON_API) if (BUILD_PYTHON_API)
include_directories(${PYTHON_INCLUDE_DIRS}) include_directories(${PYTHON_INCLUDE_DIRS})
include_directories(${NUMPY_INCLUDE_DIR}) include_directories(${NUMPY_INCLUDE_DIR})
...@@ -112,6 +115,9 @@ endif() ...@@ -112,6 +115,9 @@ endif()
# ---[ Lib Directories # ---[ Lib Directories
list(APPEND THIRD_PARTY_LIBRARY_DIRS ${THIRD_PARTY_DIR}/protobuf/lib) list(APPEND THIRD_PARTY_LIBRARY_DIRS ${THIRD_PARTY_DIR}/protobuf/lib)
if(APPLE)
list(APPEND THIRD_PARTY_LIBRARY_DIRS /usr/local/lib)
endif()
if (WITH_CUDA) if (WITH_CUDA)
list(APPEND THIRD_PARTY_LIBRARY_DIRS ${CUDA_TOOLKIT_ROOT_DIR}/lib) list(APPEND THIRD_PARTY_LIBRARY_DIRS ${CUDA_TOOLKIT_ROOT_DIR}/lib)
list(APPEND THIRD_PARTY_LIBRARY_DIRS ${CUDA_TOOLKIT_ROOT_DIR}/lib64) list(APPEND THIRD_PARTY_LIBRARY_DIRS ${CUDA_TOOLKIT_ROOT_DIR}/lib64)
...@@ -177,8 +183,13 @@ if(WIN32) ...@@ -177,8 +183,13 @@ if(WIN32)
endif() endif()
endif() endif()
if(UNIX) if(UNIX)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -s -fPIC") if(APPLE)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -s -fPIC -O3 -m64 -std=c++11") set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -w -fPIC")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -w -fPIC -O3 -m64 -std=c++11")
else()
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -s -fPIC")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -s -fPIC -O3 -m64 -std=c++11")
endif()
if (WITH_OMP AND (NOT APPLE)) if (WITH_OMP AND (NOT APPLE))
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fopenmp") set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fopenmp")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fopenmp") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fopenmp")
......
...@@ -346,7 +346,7 @@ class CUDAContext { ...@@ -346,7 +346,7 @@ class CUDAContext {
void FinishDeviceCompution() { CUDA_NOT_COMPILED; } void FinishDeviceCompution() { CUDA_NOT_COMPILED; }
/*! \brief Malloc the memory */ /*! \brief Malloc the memory */
static void* New(size_t nbytes) { CUDA_NOT_COMPILED; } static void* New(size_t nbytes) { return nullptr; }
/*! \brief Zero-Reset the memory */ /*! \brief Zero-Reset the memory */
static void Memset( static void Memset(
......
...@@ -21,23 +21,16 @@ namespace dragon { ...@@ -21,23 +21,16 @@ namespace dragon {
class GraphBase { class GraphBase {
public: public:
/*! \brief Default constructor */ /*! \brief Default constructor */
GraphBase( GraphBase(const GraphDef&, Workspace*);
const GraphDef& def,
Workspace* ws);
/*! \brief Default destructor */ /*! \brief Default destructor */
virtual ~GraphBase() {} virtual ~GraphBase() {}
/*! \brief Create a graph from the optimized def */ /*! \brief Create a graph from the optimized def */
virtual bool Create( virtual bool Create(const GraphDef&, Workspace*) = 0;
const GraphDef& def,
Workspace* ws) = 0;
/*! \brief Run the graph once synchronously */ /*! \brief Run the graph once synchronously */
virtual bool Run( virtual bool Run(const string&, const string&, int = 0) = 0;
const string& include,
const string& exclude,
int stream_id = 0) = 0;
/*! \brief Return the graph name */ /*! \brief Return the graph name */
string name() const { return name_; } string name() const { return name_; }
...@@ -46,7 +39,7 @@ class GraphBase { ...@@ -46,7 +39,7 @@ class GraphBase {
const string& phase() const { return phase_; } const string& phase() const { return phase_; }
/*! \brief Return the argument map */ /*! \brief Return the argument map */
const Map<std::string, const Argument*>& args() { return args_; } const Map<string, const Argument*>& args() { return args_; }
/*! \brief Return the specified argument */ /*! \brief Return the specified argument */
const Argument& arg(const string& name) { return *(args_[name]); } const Argument& arg(const string& name) { return *(args_[name]); }
...@@ -83,15 +76,10 @@ class Graph : public GraphBase { ...@@ -83,15 +76,10 @@ class Graph : public GraphBase {
virtual ~Graph() { for (auto* op : ops_) delete op; } virtual ~Graph() { for (auto* op : ops_) delete op; }
/*! \brief Create a graph from the optimized def */ /*! \brief Create a graph from the optimized def */
bool Create( bool Create(const GraphDef&, Workspace*) override;
const GraphDef& def,
Workspace* ws) override;
/*! \brief Run the graph once synchronously */ /*! \brief Run the graph once synchronously */
bool Run( bool Run(const string&, const string&, int = 0) override;
const string& include,
const string& exclude,
int stream_id = 0) override;
protected: protected:
/*! \brief Store the internal operators */ /*! \brief Store the internal operators */
...@@ -99,9 +87,7 @@ class Graph : public GraphBase { ...@@ -99,9 +87,7 @@ class Graph : public GraphBase {
}; };
/*! \brief Create a graph from the raw def */ /*! \brief Create a graph from the raw def */
GraphBase* NewGraph( GraphBase* NewGraph(const GraphDef&, Workspace*);
const GraphDef& def,
Workspace* ws);
/* Macros */ /* Macros */
......
...@@ -31,7 +31,7 @@ class Workspace; ...@@ -31,7 +31,7 @@ class Workspace;
class OperatorBase { class OperatorBase {
public: public:
/*! \brief Default constructor */ /*! \brief Default constructor */
OperatorBase(const OperatorDef& def, Workspace* ws); OperatorBase(const OperatorDef&, Workspace*);
/*! \brief Default destructor */ /*! \brief Default destructor */
virtual ~OperatorBase() {} virtual ~OperatorBase() {}
...@@ -49,7 +49,7 @@ class OperatorBase { ...@@ -49,7 +49,7 @@ class OperatorBase {
int YSize() { return (int)outputs_.size(); } int YSize() { return (int)outputs_.size(); }
/*! \brief Modify operator according to the given def */ /*! \brief Modify operator according to the given def */
void UpdateFrom(const OperatorDef& def); void UpdateFrom(const OperatorDef&);
/*! \brief Switch the internal running phase */ /*! \brief Switch the internal running phase */
void SwitchToPhase(const string& phase) { phase_ = phase; } void SwitchToPhase(const string& phase) { phase_ = phase; }
...@@ -95,7 +95,7 @@ class OperatorBase { ...@@ -95,7 +95,7 @@ class OperatorBase {
vector<T> Args(const string& name); vector<T> Args(const string& name);
/*! \brief Return the argument map */ /*! \brief Return the argument map */
const Map<std::string, const Argument*>& args() { return args_; } const Map<string, const Argument*>& args() { return args_; }
/*! \brief Return the specified argument */ /*! \brief Return the specified argument */
const Argument& arg(const string& name) { return *(args_[name]); } const Argument& arg(const string& name) { return *(args_[name]); }
...@@ -212,12 +212,11 @@ class Operator : public OperatorBase { ...@@ -212,12 +212,11 @@ class Operator : public OperatorBase {
#ifndef WITH_MPI #ifndef WITH_MPI
return true; return true;
#else #else
vec32_t allow_ranks = auto ranks = OperatorBase::Args<int>("mpi_ranks");
OperatorBase::Args<int>("mpi_ranks"); if (ranks.empty()) return true;
if (allow_ranks.empty()) return true;
int cur_rank; int cur_rank;
MPI_Comm_rank(MPI_COMM_WORLD, &cur_rank); MPI_Comm_rank(MPI_COMM_WORLD, &cur_rank);
for (auto mpi_rank : allow_ranks) for (auto mpi_rank : ranks)
if (cur_rank == mpi_rank) return true; if (cur_rank == mpi_rank) return true;
return false; return false;
#endif #endif
...@@ -225,10 +224,7 @@ class Operator : public OperatorBase { ...@@ -225,10 +224,7 @@ class Operator : public OperatorBase {
}; };
/*! \brief Create an operator from the raw def */ /*! \brief Create an operator from the raw def */
OperatorBase* NewOperator(const OperatorDef&, Workspace*);
OperatorBase* NewOperator(
const OperatorDef& def,
Workspace* ws);
/* Macros */ /* Macros */
......
...@@ -40,7 +40,7 @@ class TypeMeta { ...@@ -40,7 +40,7 @@ class TypeMeta {
TypeMeta(const TypeMeta& src) TypeMeta(const TypeMeta& src)
: id_(src.id_), itemsize_(src.itemsize_), : id_(src.id_), itemsize_(src.itemsize_),
ctor_(src.ctor_), copy_(src.copy_), dtor_(src.dtor_) {} ctor_(src.ctor_), copy_(src.copy_), dtor_(src.dtor_) {}
TypeMeta& operator = (const TypeMeta& src) { TypeMeta& operator = (const TypeMeta& src) {
if (this == &src) return *this; if (this == &src) return *this;
......
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_ARRAY_CHANNEL_SHUFFLE_OP_H_
#define DRAGON_OPERATORS_ARRAY_CHANNEL_SHUFFLE_OP_H_
#include "core/operator.h"
namespace dragon {
template <class Context>
class ChannelShuffleOp final: public Operator<Context> {
public:
ChannelShuffleOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
axis_(OpArg<int64_t>("axis", 0)),
group_(OpArg<int64_t>("group", 1)) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T> void RunImpl();
protected:
int64_t outer_dim_, inner_dim_;
int64_t axis_, axis_dim_, group_;
};
template <class Context>
class ChannelShuffleGradientOp final: public Operator<Context> {
public:
ChannelShuffleGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
axis_(OpArg<int64_t>("axis", 0)),
group_(OpArg<int64_t>("group", 1)) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T> void RunImpl();
protected:
int64_t outer_dim_, inner_dim_;
int64_t axis_, axis_dim_, group_;
};
} // namespace dragon
#endif // DRAGON_OPERATORS_ARRAY_CHANNEL_SHUFFLE_OP_H_
\ No newline at end of file
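As a reference for what the new operator computes, here is a NumPy sketch of channel shuffle along ``axis`` with ``group`` groups. This only illustrates the semantics; the real computation is the ``ChannelShuffle`` kernel declared later in ``utils/op_kernel.h``::

    import numpy as np

    def channel_shuffle_reference(x, axis=1, group=2):
        axis_dim = x.shape[axis]
        assert axis_dim % group == 0, 'channel dim must be divisible by group'
        # Split the channel axis into (group, axis_dim // group), swap, and merge back.
        shape = list(x.shape[:axis]) + [group, axis_dim // group] + list(x.shape[axis + 1:])
        y = x.reshape(shape)
        y = np.swapaxes(y, axis, axis + 1)
        return y.reshape(x.shape)

    x = np.arange(2 * 6 * 4 * 4, dtype='float32').reshape(2, 6, 4, 4)
    y = channel_shuffle_reference(x, axis=1, group=3)
    assert y.shape == x.shape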
...@@ -30,7 +30,7 @@ class SliceOp final : public Operator<Context> { ...@@ -30,7 +30,7 @@ class SliceOp final : public Operator<Context> {
template <typename T> void RunImpl(); template <typename T> void RunImpl();
protected: protected:
vec64_t points_; vec64_t points_, sections_;
int64_t outer_dim_, inner_dim_; int64_t outer_dim_, inner_dim_;
int64_t axis_, axis_dim_, slice_dim_, N_; int64_t axis_, axis_dim_, slice_dim_, N_;
}; };
...@@ -48,7 +48,7 @@ class SliceGradientOp final : public Operator<Context> { ...@@ -48,7 +48,7 @@ class SliceGradientOp final : public Operator<Context> {
template <typename T> void RunImpl(); template <typename T> void RunImpl();
protected: protected:
vec64_t points_; vec64_t points_, sections_;
int64_t outer_dim_, inner_dim_; int64_t outer_dim_, inner_dim_;
int64_t axis_, axis_dim_, slice_dim_, N_; int64_t axis_, axis_dim_, slice_dim_, N_;
}; };
......
...@@ -27,13 +27,13 @@ class ImageDataOp final : public Operator<Context> { ...@@ -27,13 +27,13 @@ class ImageDataOp final : public Operator<Context> {
if (mean_vec_.size() > 0) { if (mean_vec_.size() > 0) {
CHECK_EQ((int)mean_vec_.size(), 3); CHECK_EQ((int)mean_vec_.size(), 3);
auto* mean = mean_.Reshape({ 3 }) auto* mean = mean_.Reshape({ 3 })
->mutable_data<float, CPUContext>(); ->template mutable_data<float, CPUContext>();
for (int i = 0; i < 3; ++i) mean[i] = mean_vec_[i]; for (int i = 0; i < 3; ++i) mean[i] = mean_vec_[i];
} }
if (std_vec_.size() > 0) { if (std_vec_.size() > 0) {
CHECK_EQ((int)std_vec_.size(), 3); CHECK_EQ((int)std_vec_.size(), 3);
auto* std = std_.Reshape({ 3 }) auto* std = std_.Reshape({ 3 })
->mutable_data<float, CPUContext>(); ->template mutable_data<float, CPUContext>();
for (int i = 0; i < 3; ++i) std[i] = std_vec_[i]; for (int i = 0; i < 3; ++i) std[i] = std_vec_[i];
} }
} }
......
...@@ -378,6 +378,18 @@ void ArgMin( ...@@ -378,6 +378,18 @@ void ArgMin(
T* values, T* values,
Context* ctx); Context* ctx);
/*! array.channel_shuffle */
template <typename T, class Context>
void ChannelShuffle(
const int outer_dim,
const int inner_dim,
const int axis_dim,
const int group,
const T* x,
T* y,
Context* ctx);
/*! array.concat */ /*! array.concat */
template <typename T, class Context> template <typename T, class Context>
......
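The declaration above only fixes the calling convention; a plain-Python sketch of the index mapping it implies follows (an assumption about the layout, since the CPU/CUDA bodies are not part of this hunk)::

    def channel_shuffle_kernel_ref(outer_dim, inner_dim, axis_dim, group, x):
        dim_per_group = axis_dim // group
        y = [0] * (outer_dim * axis_dim * inner_dim)
        for n in range(outer_dim):
            for g in range(group):
                for k in range(dim_per_group):
                    for i in range(inner_dim):
                        # y[n, k * group + g, i] = x[n, g * dim_per_group + k, i]
                        src = (n * axis_dim + g * dim_per_group + k) * inner_dim + i
                        dst = (n * axis_dim + k * group + g) * inner_dim + i
                        y[dst] = x[src]
        return y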
...@@ -10,7 +10,7 @@ ...@@ -10,7 +10,7 @@
* ------------------------------------------------------------ * ------------------------------------------------------------
*/ */
#ifndef DRAGON_PYTHON_PY_CONIFG_H_ #ifndef DRAGON_PYTHON_PY_CONFIG_H_
#define DRAGON_PYTHON_PY_CONFIG_H_ #define DRAGON_PYTHON_PY_CONFIG_H_
#include "py_dragon.h" #include "py_dragon.h"
......
...@@ -11,7 +11,7 @@ ...@@ -11,7 +11,7 @@
*/ */
#ifndef DRAGON_PYTHON_PY_TENSOR_H_ #ifndef DRAGON_PYTHON_PY_TENSOR_H_
#define DRAGON_PYTHON_PY_TENOSR_H_ #define DRAGON_PYTHON_PY_TENSOR_H_
#include "py_dragon.h" #include "py_dragon.h"
......
...@@ -98,7 +98,7 @@ def EnableCNML(mlu_id=0): ...@@ -98,7 +98,7 @@ def EnableCNML(mlu_id=0):
Parameters Parameters
---------- ----------
device_id : int mlu_id : int
The index of MLU to use. The index of MLU to use.
Returns Returns
...@@ -193,14 +193,14 @@ def SetGraphOptimizationLevel(level=3): ...@@ -193,14 +193,14 @@ def SetGraphOptimizationLevel(level=3):
We have predefined four levels: We have predefined four levels:
-O0(level=0): Do nothing. *-O0*: Do nothing.
-O1(level=1): Prune the redundant nodes. *-O1*: Prune the redundant nodes.
-O2(level=2): Add the inplace to outputs. *-O2*: Add the inplace to outputs.
Note that the graph will no longer be a DAG. Note that the graph will no longer be a DAG.
-O3(level=3): Allocate the buffer for outputs. *-O3*: Allocate the buffer for outputs.
This level is memory-efficient, but debugging will be non-trivial. This level is memory-efficient, but debugging will be non-trivial.
Parameters Parameters
......
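A minimal usage sketch of the documented levels, assuming the function is reached through ``dragon.config`` like the other settings in this module::

    import dragon

    # -O1 keeps the graph a DAG and only prunes redundant nodes;
    # -O3 shares output buffers, which is memory-efficient but harder to debug.
    dragon.config.SetGraphOptimizationLevel(level=1)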
...@@ -33,7 +33,7 @@ class OperatorHelper(object): ...@@ -33,7 +33,7 @@ class OperatorHelper(object):
# The following operators are the simplest case: # The following operators are the simplest case:
# Input(0) => Output(0), shape and data type unchanged. # Input(0) => Output(0), shape and data type unchanged.
'Relu', 'PRelu', 'Elu', 'SElu', 'Sigmoid', 'Tanh', 'Softmax', 'Relu', 'PRelu', 'Elu', 'SElu', 'Sigmoid', 'Tanh', 'Softmax',
'Dropout', 'DropPath', 'DropBlock2d', 'Dropout', 'DropPath', 'DropBlock2d', 'ChannelShuffle',
'Add', 'Sub', 'Mul', 'Div', 'Clip', 'Log', 'Exp', 'Pow', 'Square', 'Sqrt', 'Add', 'Sub', 'Mul', 'Div', 'Clip', 'Log', 'Exp', 'Pow', 'Square', 'Sqrt',
'Accumulate', 'Affine', 'Copy', 'StopGradient', 'MPIBroadcast', 'Accumulate', 'Affine', 'Copy', 'StopGradient', 'MPIBroadcast',
'BatchNorm', 'GroupNorm', 'L2Norm', 'LRN', 'BiasAdd', 'BatchNorm', 'GroupNorm', 'L2Norm', 'LRN', 'BiasAdd',
......
...@@ -52,7 +52,7 @@ def IndexSelect(inputs, indices, axis=0, **kwargs): ...@@ -52,7 +52,7 @@ def IndexSelect(inputs, indices, axis=0, **kwargs):
The input tensor. The input tensor.
indices : sequence or Tensor indices : sequence or Tensor
The indices to select elements. The indices to select elements.
axis : int, optional axis : int, optional, default=0
The axis of indices. The axis of indices.
Returns Returns
...@@ -156,25 +156,23 @@ def Crop( ...@@ -156,25 +156,23 @@ def Crop(
@OpSchema.Inputs(1) @OpSchema.Inputs(1)
def Slice(inputs, axis=0, num_outputs=1, slice_points=None, **kwargs): def Slice(inputs, axis=0, num_slices=1, slice_points=None, **kwargs):
"""Slice the inputs into several parts along the given axis. """Slice the inputs into several parts along the given axis.
All dimensions except the specified ``axis`` should be same. All dimensions except the specified ``axis`` should be same.
The number of ``slice_points`` should be *len(X.shape) - 1*. The number of ``slice_points`` should be *len(X.shape) - 1*.
if ``slice_points`` is *None*, dimension of axis should be divided by ``num_outputs``.
**Type Constraints**: (*bool*, *int8*, *uint8*, *int32*, *int64*, *float16*, *float32*, *float64*) **Type Constraints**: (*bool*, *int8*, *uint8*, *int32*, *int64*, *float16*, *float32*, *float64*)
Parameters Parameters
---------- ----------
inputs : Tensor inputs : Tensor
The input tensor. The input tensor.
axis : int, optional axis : int, optional, default=0
The axis to slice, can be negative. The axis to slice, can be negative.
num_outputs : int, optional num_slices : int, optional, default=1
The optional number number of slices. The optional number of slices.
slice_points : sequence of int, optional slice_points : sequence of int, optional
The optional slice points. The optional slice points.
...@@ -184,9 +182,12 @@ def Slice(inputs, axis=0, num_outputs=1, slice_points=None, **kwargs): ...@@ -184,9 +182,12 @@ def Slice(inputs, axis=0, num_outputs=1, slice_points=None, **kwargs):
The outputs. The outputs.
""" """
if slice_points is not None and len(slice_points) > 0: arguments = ParseArgs(locals())
num_outputs = len(slice_points) + 1 if slice_points is not None:
return Tensor.CreateOperator('Slice', **ParseArgs(locals())) arguments['num_outputs'] = len(slice_points) + 1
else:
arguments['num_outputs'] = num_slices
return Tensor.CreateOperator('Slice', **arguments)
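A short usage sketch of the renamed argument (the Tensor construction here is illustrative only)::

    import dragon

    x = dragon.Tensor(shape=[4, 6], dtype='float32')
    a, b, c = dragon.ops.Slice(x, axis=1, num_slices=3)           # three 2-column slices
    left, right = dragon.ops.Slice(x, axis=1, slice_points=[2])   # columns [0, 2) and [2, 6)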
@OpSchema.Inputs(1, INT_MAX) @OpSchema.Inputs(1, INT_MAX)
...@@ -201,7 +202,7 @@ def Stack(inputs, axis=0, **kwargs): ...@@ -201,7 +202,7 @@ def Stack(inputs, axis=0, **kwargs):
---------- ----------
inputs : sequence of Tensor inputs : sequence of Tensor
The inputs. The inputs.
axis : int axis : int, optional, default=0
The axis to stack, can be negative. The axis to stack, can be negative.
Returns Returns
...@@ -225,7 +226,7 @@ def Concat(inputs, axis=0, **kwargs): ...@@ -225,7 +226,7 @@ def Concat(inputs, axis=0, **kwargs):
---------- ----------
inputs : sequence of Tensor inputs : sequence of Tensor
The inputs. The inputs.
axis : int axis : int, optional, default=0
The axis to concatenate, can be negative. The axis to concatenate, can be negative.
Returns Returns
...@@ -238,6 +239,30 @@ def Concat(inputs, axis=0, **kwargs): ...@@ -238,6 +239,30 @@ def Concat(inputs, axis=0, **kwargs):
@OpSchema.Inputs(1) @OpSchema.Inputs(1)
def ChannelShuffle(inputs, axis=0, group=1, **kwargs):
"""Shuffle channels between groups along the given axis. `[Zhang et.al, 2017] <https://arxiv.org/abs/1707.01083>`_.
**Type Constraints**: (*bool*, *int8*, *uint8*, *int32*, *int64*, *float16*, *float32*, *float64*)
Parameters
----------
inputs : Tensor
The input tensor.
axis : int, optional, default=0
The axis of channels, can be negative.
group : int, optional, default=1
The number of groups.
Returns
-------
Tensor
The output tensor.
"""
return Tensor.CreateOperator('ChannelShuffle', **ParseArgs(locals()))
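A matching usage sketch for the new Python operator (again, the Tensor construction is only illustrative)::

    import dragon

    x = dragon.Tensor(shape=[1, 8, 14, 14], dtype='float32')   # hypothetical NCHW feature map
    y = dragon.ops.ChannelShuffle(x, axis=1, group=2)          # same shape, channels regrouped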
@OpSchema.Inputs(1)
def Reduce(inputs, axes=None, operation='SUM', keep_dims=False, **kwargs): def Reduce(inputs, axes=None, operation='SUM', keep_dims=False, **kwargs):
"""Reduce the inputs along the axis in given axes. """Reduce the inputs along the axis in given axes.
...@@ -411,9 +436,9 @@ def ArgMin(inputs, axis=None, top_k=1, keep_dims=False, **kwargs): ...@@ -411,9 +436,9 @@ def ArgMin(inputs, axis=None, top_k=1, keep_dims=False, **kwargs):
The input tensor. The input tensor.
axis : int, optional axis : int, optional
The axis to compute, can be negative. The axis to compute, can be negative.
top_k : int, optional top_k : int, optional, default=1
The top k results to keep. The top k results to keep.
keep_dims : bool, optional keep_dims : bool, optional, default=False
Whether to keep dims after computing. Whether to keep dims after computing.
Returns Returns
...@@ -439,9 +464,9 @@ def Min(inputs, axis=None, top_k=1, keep_dims=False, **kwargs): ...@@ -439,9 +464,9 @@ def Min(inputs, axis=None, top_k=1, keep_dims=False, **kwargs):
The input tensor. The input tensor.
axis : int, optional axis : int, optional
The axis to compute, can be negative. The axis to compute, can be negative.
top_k : int, optional top_k : int, optional, default=1
The top k results to keep. The top k results to keep.
keep_dims : bool, optional keep_dims : bool, optional, default=False
Whether to keep dims after computing. Whether to keep dims after computing.
Returns Returns
...@@ -494,7 +519,7 @@ def Repeat(inputs, axis=None, repeats=1, **kwargs): ...@@ -494,7 +519,7 @@ def Repeat(inputs, axis=None, repeats=1, **kwargs):
The input tensor. The input tensor.
axis : int, optional axis : int, optional
The axis to repeat. The axis to repeat.
repeats : int or Tensor, optional repeats : int or Tensor, optional, default=1
The magnitude of repeating. The magnitude of repeating.
Returns Returns
...@@ -586,11 +611,11 @@ def OneHot(inputs, depth, on_value=1, off_value=0, **kwargs): ...@@ -586,11 +611,11 @@ def OneHot(inputs, depth, on_value=1, off_value=0, **kwargs):
inputs : Tensor inputs : Tensor
The input tensor. The input tensor.
depth : int depth : int
The depth of one-hot representation. The depth of the representation.
on_value : int, optional on_value : int, optional, default=1
The value when ``indices[j] = i``. The value when *indices[j]* = *i*.
off_value : int, optional off_value : int, optional, default=0
The value when ``indices[j] != i``. The value when *indices[j]* != *i*.
Returns Returns
------- -------
...@@ -613,12 +638,12 @@ def Flatten(inputs, axis=0, num_axes=-1, keep_axes=None, **kwargs): ...@@ -613,12 +638,12 @@ def Flatten(inputs, axis=0, num_axes=-1, keep_axes=None, **kwargs):
---------- ----------
inputs : Tensor inputs : Tensor
The input tensor. The input tensor.
axis : int, optional axis : int, optional, default=0
The start axis to flatten, can be negative. The start axis to flatten, can be negative.
num_axes : int, optional num_axes : int, optional, default=-1
The number of axes to flatten. Default is ``-1`` (Along all axes). The number of axes to flatten.
keep_axes : int, optional keep_axes : int, optional
The number of axes to keep. Default is ``None`` (Disabled). The number of axes to keep.
Returns Returns
------- -------
......
...@@ -27,20 +27,16 @@ def LMDBData(**kwargs): ...@@ -27,20 +27,16 @@ def LMDBData(**kwargs):
The path of database. The path of database.
shuffle : bool, optional, default=False shuffle : bool, optional, default=False
Whether to shuffle the data. Whether to shuffle the data.
node_step: bool
Whether to split data for multiple parallel nodes.
num_chunks : int, optional, default=2048 num_chunks : int, optional, default=2048
The number of chunks to split. The number of chunks to split.
chunk_size : int, optional, default=-1
The size(MB) of each chunk.
mean_values : list, optional
The mean value of each image channel.
scale : float, optional, default=1.
The scale performed after mean subtraction.
padding : int, optional, default=0 padding : int, optional, default=0
The zero-padding size. The zero-padding size.
fill_value : int or sequence, optional, default=127
The value(s) to fill for padding or cutout.
crop_size : int, optional, default=0 crop_size : int, optional, default=0
The cropping size. The cropping size.
cutout_size : int, optional, default=0
The square size to cutout.
mirror : bool, optional, default=False mirror : bool, optional, default=False
Whether to mirror(flip horizontally) images. Whether to mirror(flip horizontally) images.
color_augmentation : bool, optional, default=False color_augmentation : bool, optional, default=False
......
...@@ -126,6 +126,7 @@ ArgMin = _array_ops.ArgMin ...@@ -126,6 +126,7 @@ ArgMin = _array_ops.ArgMin
Slice = _array_ops.Slice Slice = _array_ops.Slice
Stack = _array_ops.Stack Stack = _array_ops.Stack
Concat = _array_ops.Concat Concat = _array_ops.Concat
ChannelShuffle = _array_ops.ChannelShuffle
Transpose = _array_ops.Transpose Transpose = _array_ops.Transpose
Repeat = _array_ops.Repeat Repeat = _array_ops.Repeat
Tile = _array_ops.Tile Tile = _array_ops.Tile
......
...@@ -45,8 +45,8 @@ class DataBatch(object): ...@@ -45,8 +45,8 @@ class DataBatch(object):
The number of chunks to split. The number of chunks to split.
padding : int, optional, default=0 padding : int, optional, default=0
The zero-padding size. The zero-padding size.
fill_value : int, optional, default=127 fill_value : int or sequence, optional, default=127
The value to fill when padding is valid. The value(s) to fill for padding or cutout.
crop_size : int, optional, default=0 crop_size : int, optional, default=0
The cropping size. The cropping size.
cutout_size : int, optional, default=0 cutout_size : int, optional, default=0
......
...@@ -120,7 +120,7 @@ class DataReader(multiprocessing.Process): ...@@ -120,7 +120,7 @@ class DataReader(multiprocessing.Process):
self._cursor += 1 self._cursor += 1
def next_chunk(self): def next_chunk(self):
"""Step the cursor of shuffling chunks. """Step the cursor of chunks.
Returns Returns
------- -------
...@@ -166,7 +166,7 @@ class DataReader(multiprocessing.Process): ...@@ -166,7 +166,7 @@ class DataReader(multiprocessing.Process):
# Search an optimal chunk size (Chunk-Wise) # Search an optimal chunk size (Chunk-Wise)
min_size, max_size = \ min_size, max_size = \
1, self._db._total_size * 1.0 \ 1, self._db._total_size * 1.0 \
/ ((self._num_chunks * (1 << 20))) / (self._num_chunks * (1 << 20))
while min_size * 2 < max_size: min_size *= 2 while min_size * 2 < max_size: min_size *= 2
self._perm_size = int(math.ceil( self._perm_size = int(math.ceil(
self._db._total_size * 1.1 / self._db._total_size * 1.1 /
......
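For reference, the chunk-size search shown above, pulled out of the class with made-up sizes (the ``perm_size`` line is truncated in this hunk, so it is omitted here)::

    db_total_size = 16384 * (1 << 20)   # hypothetical database size: 16 GB
    num_chunks = 2048

    # Grow min_size by powers of two toward the per-chunk budget (in MB).
    min_size, max_size = 1, db_total_size * 1.0 / (num_chunks * (1 << 20))
    while min_size * 2 < max_size:
        min_size *= 2
    # Here max_size is 8.0 MB, so the loop stops at min_size = 4.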
...@@ -43,8 +43,8 @@ class DataTransformer(multiprocessing.Process): ...@@ -43,8 +43,8 @@ class DataTransformer(multiprocessing.Process):
---------- ----------
padding : int, optional, default=0 padding : int, optional, default=0
The zero-padding size. The zero-padding size.
fill_value : int, optional, default=127 fill_value : int or sequence, optional, default=127
The value to fill when padding is valid. The value(s) to fill for padding or cutout.
crop_size : int, optional, default=0 crop_size : int, optional, default=0
The cropping size. The cropping size.
cutout_size : int, optional, default=0 cutout_size : int, optional, default=0
...@@ -92,7 +92,7 @@ class DataTransformer(multiprocessing.Process): ...@@ -92,7 +92,7 @@ class DataTransformer(multiprocessing.Process):
The tuple of image and labels. The tuple of image and labels.
""" """
# decode # Decode
datum = _proto_def.Datum() datum = _proto_def.Datum()
datum.ParseFromString(serialized) datum.ParseFromString(serialized)
im = numpy.fromstring(datum.data, numpy.uint8) im = numpy.fromstring(datum.data, numpy.uint8)
...@@ -115,13 +115,14 @@ class DataTransformer(multiprocessing.Process): ...@@ -115,13 +115,14 @@ class DataTransformer(multiprocessing.Process):
# Padding # Padding
if self._padding > 0: if self._padding > 0:
pad_img = numpy.empty(( pad_im = numpy.empty((
im.shape[0] + 2 * self._padding, im.shape[0] + 2 * self._padding,
im.shape[1] + 2 * self._padding, im.shape[2]), dtype=im.dtype) im.shape[1] + 2 * self._padding, im.shape[2]
pad_img.fill(self._fill_value) ), dtype=im.dtype)
pad_img[self._padding : self._padding + im.shape[0], pad_im[:] = self._fill_value
self._padding : self._padding + im.shape[1], :] = im pad_im[self._padding : self._padding + im.shape[0],
im = pad_img self._padding : self._padding + im.shape[1], :] = im
im = pad_im
# Random crop # Random crop
if self._crop_size > 0: if self._crop_size > 0:
......
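The switch from ``pad_img.fill(...)`` to ``pad_im[:] = self._fill_value`` is what allows ``fill_value`` to be a per-channel sequence; a standalone NumPy check of that broadcasting::

    import numpy as np

    im = np.zeros((4, 4, 3), dtype='uint8')
    padding, fill_value = 2, (104, 117, 123)    # sequence form: one value per channel

    pad_im = np.empty((im.shape[0] + 2 * padding,
                       im.shape[1] + 2 * padding, im.shape[2]), dtype=im.dtype)
    pad_im[:] = fill_value                      # broadcasts over the channel axis
    pad_im[padding:padding + im.shape[0],
           padding:padding + im.shape[1], :] = im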
...@@ -84,7 +84,6 @@ from .common import ( ...@@ -84,7 +84,6 @@ from .common import (
FlattenLayer, FlattenLayer,
ConcatLayer, ConcatLayer,
NormalizeLayer, NormalizeLayer,
InstanceNormLayer,
TileLayer, TileLayer,
ReductionLayer, ReductionLayer,
ExpandDimsLayer, ExpandDimsLayer,
......
...@@ -370,28 +370,6 @@ class GroupNormLayer(_Layer): ...@@ -370,28 +370,6 @@ class GroupNormLayer(_Layer):
return _ops.GroupNorm(inputs, **self.arguments) return _ops.GroupNorm(inputs, **self.arguments)
class InstanceNormLayer(_Layer):
"""The implementation of *InstanceNormLayer*.
Introduced by `[Ulyanov et al., 2016] <https://arxiv.org/abs/1607.08022>`_
Parameters
----------
eps : float
Refer ``InstanceNormParameter.eps``.
"""
def __init__(self, LayerParameter):
super(InstanceNormLayer, self).__init__(LayerParameter)
self.arguments = {
'axis': 1,
'eps': LayerParameter.instance_norm_param.eps,
}
def LayerSetup(self, bottom):
return _ops.InstanceNorm(bottom, **self.arguments)
class ScaleLayer(_Layer): class ScaleLayer(_Layer):
"""The implementation of *ScaleLayer*. """The implementation of *ScaleLayer*.
......
...@@ -30,11 +30,15 @@ class DataLayer(_Layer): ...@@ -30,11 +30,15 @@ class DataLayer(_Layer):
The path of database. Refer `DataParameter.source`_. The path of database. Refer `DataParameter.source`_.
prefetch: int prefetch: int
The prefetch count. Refer `DataParameter.prefetch`_. The prefetch count. Refer `DataParameter.prefetch`_.
shuffle : boolean
Whether to shuffle the data. Refer ``DataParameter.shuffle``.
num_chunks : int
The number of chunks to shuffle. Refer ``DataParameter.num_chunks``.
batch_size : int batch_size : int
The size of a mini-batch. Refer `DataParameter.batch_size`_. The size of a mini-batch. Refer `DataParameter.batch_size`_.
phase : Phase phase : Phase
The phase of layer. Refer `LayerParameter.phase`_. The phase of layer. Refer `LayerParameter.phase`_.
mirrow : boolean mirror : boolean
Whether to randomly mirror. Refer `TransformationParameter.mirror`_. Whether to randomly mirror. Refer `TransformationParameter.mirror`_.
crop_size : int crop_size : int
The crop size. Refer `TransformationParameter.crop_size`_. The crop size. Refer `TransformationParameter.crop_size`_.
...@@ -62,11 +66,12 @@ class DataLayer(_Layer): ...@@ -62,11 +66,12 @@ class DataLayer(_Layer):
param = LayerParameter.data_param param = LayerParameter.data_param
memory_param = LayerParameter.memory_data_param memory_param = LayerParameter.memory_data_param
transform_param = LayerParameter.transform_param transform_param = LayerParameter.transform_param
parallel_param = LayerParameter.parallel_param
self.arguments = { self.arguments = {
'source': param.source, 'source': param.source,
'prefetch': param.prefetch, 'prefetch': param.prefetch,
'shuffle': param.shuffle,
'num_chunks': param.num_chunks,
'batch_size': param.batch_size, 'batch_size': param.batch_size,
'phase': {0: 'TRAIN', 1: 'TEST'}[int(LayerParameter.phase)], 'phase': {0: 'TRAIN', 1: 'TEST'}[int(LayerParameter.phase)],
'mirror': transform_param.mirror, 'mirror': transform_param.mirror,
...@@ -76,9 +81,6 @@ class DataLayer(_Layer): ...@@ -76,9 +81,6 @@ class DataLayer(_Layer):
'padding': transform_param.padding, 'padding': transform_param.padding,
'min_random_scale': transform_param.min_random_scale, 'min_random_scale': transform_param.min_random_scale,
'max_random_scale': transform_param.max_random_scale, 'max_random_scale': transform_param.max_random_scale,
'shuffle': parallel_param.shuffle,
'multiple_nodes': parallel_param.multiple_nodes,
'partition': parallel_param.partition,
'dtype': {0: 'float32', 1: 'float16'}[memory_param.dtype], 'dtype': {0: 'float32', 1: 'float16'}[memory_param.dtype],
'data_format': 'NCHW', 'data_format': 'NCHW',
} }
......
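With the layer now reading ``shuffle`` and ``num_chunks`` from ``DataParameter`` (see the proto change below), an equivalent setup can be sketched directly on the generated message; the import path of the compiled proto is assumed here::

    from dragon.vm.caffe.proto import caffe_pb2

    layer = caffe_pb2.LayerParameter(type='Data', name='data')
    layer.data_param.source = '/data/train_lmdb'
    layer.data_param.shuffle = True        # new field, default False
    layer.data_param.num_chunks = 2048     # new field, default 2048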
...@@ -416,18 +416,13 @@ message LayerParameter { ...@@ -416,18 +416,13 @@ message LayerParameter {
optional MPIParameter mpi_param = 153; optional MPIParameter mpi_param = 153;
optional PermuteParameter permute_param = 154; optional PermuteParameter permute_param = 154;
optional NormalizeParameter normalize_param = 155; optional NormalizeParameter normalize_param = 155;
optional ParallelParameter parallel_param = 157; optional ResizeParameter resize_param = 156;
optional ResizeParameter resize_param = 158; optional ExpandDimsParameter expand_dims_param = 157;
optional ExpandDimsParameter expand_dims_param = 159; optional ProposalParameter proposal_param = 158;
optional ProposalParameter proposal_param = 160; optional FocalLossParameter focal_loss_param = 159;
optional BatchRenormParameter batch_renorm_param = 161; optional GroupNormParameter group_norm_param = 160;
optional DenseConcatParameter dense_concat_param = 163; optional DropBlockParameter drop_block_param = 161;
optional FocalLossParameter focal_loss_param = 164; optional CastParameter cast_param = 163;
optional GatherParameter gather_param = 165;
optional InstanceNormParameter instance_norm_param = 166;
optional GroupNormParameter group_norm_param = 167;
optional DropBlockParameter drop_block_param = 168;
optional CastParameter cast_param = 169;
} }
// Message that stores parameters used to apply transformation // Message that stores parameters used to apply transformation
...@@ -690,6 +685,10 @@ message DataParameter { ...@@ -690,6 +685,10 @@ message DataParameter {
// Prefetch queue (Number of batches to prefetch to host memory, increase if // Prefetch queue (Number of batches to prefetch to host memory, increase if
// data access bandwidth varies). // data access bandwidth varies).
optional uint32 prefetch = 10 [default = 5]; optional uint32 prefetch = 10 [default = 5];
// Whether to shuffle the data.
optional bool shuffle = 11 [default = false];
// The number of chunks to shuffle.
optional int32 num_chunks = 12 [default = 2048];
} }
message DropoutParameter { message DropoutParameter {
...@@ -1462,12 +1461,6 @@ message NormalizeParameter { ...@@ -1462,12 +1461,6 @@ message NormalizeParameter {
optional float eps = 4 [default = 1e-5]; optional float eps = 4 [default = 1e-5];
} }
message ParallelParameter {
optional bool multiple_nodes = 1 [default = false];
optional bool shuffle = 2 [default = false];
optional bool partition = 3 [default = false];
}
message ResizeParameter { message ResizeParameter {
optional BlobShape shape = 1; optional BlobShape shape = 1;
optional float fx = 2 [default = -1.0]; optional float fx = 2 [default = -1.0];
...@@ -1492,37 +1485,15 @@ message ProposalParameter { ...@@ -1492,37 +1485,15 @@ message ProposalParameter {
optional int32 canonical_level = 11 [default = 4]; optional int32 canonical_level = 11 [default = 4];
} }
message BatchRenormParameter {
optional bool use_global_stats = 1;
optional float moving_average_fraction = 2 [default = 0.9];
optional float eps = 3 [default = 1e-5];
optional float r_max = 4 [default = 3.0];
optional float d_max = 5 [default = 5.0];
optional float t_delta = 6 [default = 0.001];
}
message DenseConcatParameter {
optional int32 axis = 1 [default = 1];
optional int32 growth_rate = 2 [default = 0];
}
message FocalLossParameter { message FocalLossParameter {
optional float alpha = 1 [default = 0.25]; optional float alpha = 1 [default = 0.25];
optional float gamma = 2 [default = 2.0]; optional float gamma = 2 [default = 2.0];
optional int32 neg_id = 3 [default = 0]; optional int32 neg_id = 3 [default = 0];
} }
message GatherParameter {
optional int32 axis = 1 [default = 0];
}
message InstanceNormParameter {
optional float eps = 1 [default = 1e-5];
}
message GroupNormParameter { message GroupNormParameter {
optional float eps = 1 [default = 1e-5]; optional float eps = 1 [default = 1e-5];
optional int32 group = 2 [default = 32]; // The group size optional int32 group = 2 [default = 32];
} }
message DropBlockParameter { message DropBlockParameter {
......
...@@ -34,9 +34,9 @@ from dragon.vm.torch.ops.modules.init import ( ...@@ -34,9 +34,9 @@ from dragon.vm.torch.ops.modules.init import (
) )
from dragon.vm.torch.ops.modules.array import ( from dragon.vm.torch.ops.modules.array import (
Reshape, Squeeze, UnSqueeze, Permute, Reshape, Squeeze, UnSqueeze, Permute,
Indexing, Repeat, Concat, Stack, ChannelShuffle, Repeat, Concat, Stack, Chunk,
IndexSelect, MaskedSelect, Indexing, IndexSelect, MaskedSelect,
Reduce, ArgReduce, Reduce, ArgReduce,
NonZero, Where, NonZero, Where,
OneHot, Multinomial, OneHot, Multinomial,
...@@ -60,7 +60,8 @@ __all__ = [ ...@@ -60,7 +60,8 @@ __all__ = [
'mean', 'sum', 'min', 'max', 'topk', 'mean', 'sum', 'min', 'max', 'topk',
'nonzero', 'where', 'argmin', 'argmax', 'nonzero', 'where', 'argmin', 'argmax',
'gt', 'lt', 'eq', 'ne', 'ge', 'le', 'gt', 'lt', 'eq', 'ne', 'ge', 'le',
'cat', 'stack', 'narrow', 'cat', 'stack', 'chunk',
'narrow', 'channel_shuffle',
'index_select', 'masked_select', 'index_select', 'masked_select',
'one_hot', 'multinomial', 'one_hot', 'multinomial',
'rand', 'randn', 'rand', 'randn',
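A usage sketch for the newly exported ``channel_shuffle`` builtin (``rand`` is used only to get a tensor of the right shape)::

    from dragon.vm import torch

    x = torch.rand(1, 8, 14, 14)
    y = torch.channel_shuffle(x, dim=1, group=2)   # same shape, channels regrouped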
...@@ -422,7 +423,7 @@ def _reshape(input, shape, shape_like=None): ...@@ -422,7 +423,7 @@ def _reshape(input, shape, shape_like=None):
def _permute(input, perm): def _permute(input, perm):
dev = MakeDevice(inputs=[input]); nperm = len(perm) dev, nperm = MakeDevice([input]), len(perm)
key = 'Permute/{}/nperm:{}'.format(dev, nperm) key = 'Permute/{}/nperm:{}'.format(dev, nperm)
module = get_module(Permute, key, dev, nperm=nperm) module = get_module(Permute, key, dev, nperm=nperm)
return module.forward(input, perm) return module.forward(input, perm)
...@@ -576,7 +577,9 @@ def squeeze(input, dim=None, out=None): ...@@ -576,7 +577,9 @@ def squeeze(input, dim=None, out=None):
Parameters Parameters
---------- ----------
dim : int input : dragon.vm.torch.Tensor
The input tensor.
dim : int, optional
The optional dim to remove. The optional dim to remove.
out : dragon.vm.torch.Tensor, optional out : dragon.vm.torch.Tensor, optional
The output tensor. The output tensor.
...@@ -598,6 +601,8 @@ def unsqueeze(input, dim, out=None): ...@@ -598,6 +601,8 @@ def unsqueeze(input, dim, out=None):
Parameters Parameters
---------- ----------
input : dragon.vm.torch.Tensor
The input tensor.
dim : int dim : int
The dim to remove. The dim to remove.
out : dragon.vm.torch.Tensor, optional out : dragon.vm.torch.Tensor, optional
...@@ -856,7 +861,7 @@ def le(input, other, out=None): ...@@ -856,7 +861,7 @@ def le(input, other, out=None):
---------- ----------
input : dragon.vm.torch.Tensor input : dragon.vm.torch.Tensor
The input tensor. The input tensor.
other : dragon.vm.torch.Tensor, number other : dragon.vm.torch.Tensor or number
The other tensor. The other tensor.
out : dragon.vm.torch.Tensor, optional out : dragon.vm.torch.Tensor, optional
The optional output tensor. The optional output tensor.
...@@ -877,7 +882,7 @@ def eq(input, other, out=None): ...@@ -877,7 +882,7 @@ def eq(input, other, out=None):
---------- ----------
input : dragon.vm.torch.Tensor input : dragon.vm.torch.Tensor
The input tensor. The input tensor.
other : dragon.vm.torch.Tensor, number other : dragon.vm.torch.Tensor or number
The other tensor. The other tensor.
out : dragon.vm.torch.Tensor, optional out : dragon.vm.torch.Tensor, optional
The optional output tensor. The optional output tensor.
...@@ -898,7 +903,7 @@ def ne(input, other, out=None): ...@@ -898,7 +903,7 @@ def ne(input, other, out=None):
---------- ----------
input : dragon.vm.torch.Tensor input : dragon.vm.torch.Tensor
The input tensor. The input tensor.
other : dragon.vm.torch.Tensor, number other : dragon.vm.torch.Tensor or number
The other tensor. The other tensor.
out : dragon.vm.torch.Tensor, optional out : dragon.vm.torch.Tensor, optional
The optional output tensor. The optional output tensor.
...@@ -930,7 +935,7 @@ def cat(seq, dim=0, out=None): ...@@ -930,7 +935,7 @@ def cat(seq, dim=0, out=None):
The output tensor. The output tensor.
""" """
dev = MakeDevice(inputs=seq, outputs=[out] if out else []) dev = MakeDevice(seq, [out] if out else [])
key = 'Concat/{}/dim:{}'.format(dev, dim) key = 'Concat/{}/dim:{}'.format(dev, dim)
module = get_module(Concat, key, dev, axis=dim) module = get_module(Concat, key, dev, axis=dim)
return module.forward(seq, out) return module.forward(seq, out)
...@@ -943,7 +948,7 @@ def stack(seq, dim=0, out=None): ...@@ -943,7 +948,7 @@ def stack(seq, dim=0, out=None):
---------- ----------
seq : sequence of dragon.vm.torch.Tensor seq : sequence of dragon.vm.torch.Tensor
The sequence. The sequence.
dim : int, optional dim : int, optional, default=0
The dim to stack. The dim to stack.
out : dragon.vm.torch.Tensor, optional out : dragon.vm.torch.Tensor, optional
The optional output tensor. The optional output tensor.
...@@ -960,6 +965,60 @@ def stack(seq, dim=0, out=None): ...@@ -960,6 +965,60 @@ def stack(seq, dim=0, out=None):
return module.forward(seq, out) return module.forward(seq, out)
def channel_shuffle(input, dim=0, group=1, out=None):
"""Shuffle channels between groups along the given axis.
Parameters
----------
input : dragon.vm.torch.Tensor
The input tensor.
dim : int, optional, default=0
The axis of channels.
group : int, optional, default=1
The number of groups.
out : dragon.vm.torch.Tensor, optional
The output tensor.
Returns
-------
dragon.vm.torch.Tensor
The new tensor.
"""
dev = MakeDevice([input])
key = 'ChannelShuffle/{}/dim:{}/group:{}'.format(dev, dim, group)
module = get_module(
ChannelShuffle, key, dev,
axis=dim,
group=group,
)
return module.forward(input, out)
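For orientation, a minimal usage sketch of the new function (an assumption for illustration, not part of this commit; the import path follows the module layout above and the shapes are arbitrary):

# Shuffle the 8 channels of an NCHW tensor into 2 groups along axis 1.
import dragon.vm.torch as torch

x = torch.rand(2, 8, 4, 4)
y = torch.channel_shuffle(x, dim=1, group=2)  # same shape as x, channels interleaved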
def chunk(tensor, chunks, dim=0):
"""Split the input into several parts along the given axis.
Parameters
----------
tensor : dragon.vm.torch.Tensor
The input to split.
chunks : int
The number of chunks to split into.

dim : int, optional, default=0
The dim to split.
Returns
-------
sequence of dragon.vm.torch.Tensor
The output chunks.
"""
dev = MakeDevice([tensor])
key = 'Chunk/{}/chunks:{}/dim:{}'.format(dev, chunks, dim)
module = get_module(Chunk, key, dev, axis=dim, chunks=chunks)
return module.forward(tensor)
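Similarly, a hedged usage sketch of chunk; the chunk sizes follow the even-split rule introduced for the SliceOp backend later in this commit:

# Usage sketch (assumption): split 10 rows into 3 chunks along dim 0.
import dragon.vm.torch as torch

x = torch.rand(10, 5)
parts = torch.chunk(x, chunks=3, dim=0)  # expected sizes along dim 0: 4, 4, 2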
def index_select(input, dim, index, out=None): def index_select(input, dim, index, out=None):
"""Select the input values along the given axis using index. """Select the input values along the given axis using index.
...@@ -1047,7 +1106,7 @@ def nonzero(input, out=None): ...@@ -1047,7 +1106,7 @@ def nonzero(input, out=None):
Returns Returns
------- -------
dragon.vm.torch.FloatTensor dragon.vm.torch.Tensor
The output tensor. The output tensor.
""" """
...@@ -1069,7 +1128,7 @@ def one_hot(input, depth): ...@@ -1069,7 +1128,7 @@ def one_hot(input, depth):
Returns Returns
------- -------
dragon.vm.torch.FloatTensor dragon.vm.torch.Tensor
The output tensor. The output tensor.
""" """
...@@ -1137,7 +1196,7 @@ def zeros(*sizes, **kwargs): ...@@ -1137,7 +1196,7 @@ def zeros(*sizes, **kwargs):
Returns Returns
------- -------
vm.torch.FloatTensor dragon.vm.torch.Tensor
The output tensor. The output tensor.
""" """
...@@ -1158,7 +1217,7 @@ def zeros_like(input, out=None, **kwargs): ...@@ -1158,7 +1217,7 @@ def zeros_like(input, out=None, **kwargs):
Returns Returns
------- -------
vm.torch.FloatTensor dragon.vm.torch.Tensor
The output tensor. The output tensor.
""" """
...@@ -1180,7 +1239,7 @@ def ones(*sizes, **kwargs): ...@@ -1180,7 +1239,7 @@ def ones(*sizes, **kwargs):
Returns Returns
------- -------
vm.torch.FloatTensor dragon.vm.torch.Tensor
The output tensor. The output tensor.
""" """
...@@ -1201,7 +1260,7 @@ def ones_like(input, out=None, **kwargs): ...@@ -1201,7 +1260,7 @@ def ones_like(input, out=None, **kwargs):
Returns Returns
------- -------
vm.torch.FloatTensor dragon.vm.torch.Tensor
The output tensor. The output tensor.
""" """
...@@ -1223,7 +1282,7 @@ def rand(*sizes, **kwargs): ...@@ -1223,7 +1282,7 @@ def rand(*sizes, **kwargs):
Returns Returns
------- -------
vm.torch.FloatTensor dragon.vm.torch.Tensor
The output tensor. The output tensor.
""" """
...@@ -1244,7 +1303,7 @@ def randn(*sizes, **kwargs): ...@@ -1244,7 +1303,7 @@ def randn(*sizes, **kwargs):
Returns Returns
------- -------
vm.torch.FloatTensor dragon.vm.torch.Tensor
The output tensor. The output tensor.
""" """
......
...@@ -106,6 +106,32 @@ class Stack(BaseModule): ...@@ -106,6 +106,32 @@ class Stack(BaseModule):
return self.run(inputs, outputs) return self.run(inputs, outputs)
class Chunk(BaseModule):
"""This module imports the *SliceOp* from backend.
Slice the input into several parts along the given axis.
"""
def __init__(self, key, dev, **kwargs):
super(Chunk, self).__init__(key, dev, **kwargs)
self.axis = kwargs.get('axis', 0)
self.chunks = kwargs.get('chunks', 1)
self.register_op()
def register_op(self):
self.op_meta = {
'op_type': 'Slice',
'arguments': {
'axis': self.axis,
},
}
def forward(self, x):
inputs = [x]; self.unify_devices(inputs)
outputs = [self.register_output() for _ in range(self.chunks)]
return self.run(inputs, outputs)
class IndexSelect(BaseModule): class IndexSelect(BaseModule):
"""This module imports the *IndexSelectOp* from backend. """This module imports the *IndexSelectOp* from backend.
...@@ -315,6 +341,28 @@ class Permute(BaseModule): ...@@ -315,6 +341,28 @@ class Permute(BaseModule):
return self.run(inputs, outputs, callback=callback) return self.run(inputs, outputs, callback=callback)
class ChannelShuffle(BaseModule):
def __init__(self, key, dev, **kwargs):
super(ChannelShuffle, self).__init__(key, dev, **kwargs)
self.axis = kwargs.get('axis', 0)
self.group = kwargs.get('group', 1)
self.register_op()
def register_op(self):
self.op_meta = {
'op_type': 'ChannelShuffle',
'arguments': {
'axis': self.axis,
'group': self.group,
},
}
def forward(self, x, y):
inputs = [x]; self.unify_devices(inputs)
outputs = [y] if y else [self.register_output()]
return self.run(inputs, outputs)
class Repeat(BaseModule): class Repeat(BaseModule):
def __init__(self, key, dev, **kwargs): def __init__(self, key, dev, **kwargs):
super(Repeat, self).__init__(key, dev, **kwargs) super(Repeat, self).__init__(key, dev, **kwargs)
......
...@@ -23,7 +23,7 @@ from dragon.vm.torch.ops.builtin import ( ...@@ -23,7 +23,7 @@ from dragon.vm.torch.ops.builtin import (
_fundamental, _rfundamental, _fundamental, _rfundamental,
log, exp, sqrt, clamp, log, exp, sqrt, clamp,
_reshape, squeeze, unsqueeze, _reshape, squeeze, unsqueeze,
_permute, _repeat, narrow, _index, _permute, _repeat, chunk, narrow, _index,
_assign, _masked_assign, _assign, _masked_assign,
index_select, masked_select, index_select, masked_select,
mean, sum, max, min, mean, sum, max, min,
...@@ -76,6 +76,7 @@ Tensor.view = lambda self, *shape: _reshape(self, shape) ...@@ -76,6 +76,7 @@ Tensor.view = lambda self, *shape: _reshape(self, shape)
Tensor.view_as = lambda *args, **kwargs: _reshape(*args, **kwargs) Tensor.view_as = lambda *args, **kwargs: _reshape(*args, **kwargs)
Tensor.permute = lambda self, *dims: _permute(self, dims) Tensor.permute = lambda self, *dims: _permute(self, dims)
Tensor.repeat = lambda self, *args: _repeat(self, args) Tensor.repeat = lambda self, *args: _repeat(self, args)
Tensor.chunk = lambda *args, **kwargs: chunk(*args, **kwargs)
Tensor.mean = lambda *args, **kwargs: mean(*args, **kwargs) Tensor.mean = lambda *args, **kwargs: mean(*args, **kwargs)
Tensor.sum = lambda *args, **kwargs: sum(*args, **kwargs) Tensor.sum = lambda *args, **kwargs: sum(*args, **kwargs)
Tensor.max = lambda *args, **kwargs: max(*args, **kwargs) Tensor.max = lambda *args, **kwargs: max(*args, **kwargs)
......
...@@ -355,7 +355,7 @@ class Tensor(object): ...@@ -355,7 +355,7 @@ class Tensor(object):
Parameters Parameters
---------- ----------
other : dragon.vm.torch.Tensor, number other : dragon.vm.torch.Tensor or number
The other tensor. The other tensor.
Returns Returns
...@@ -371,7 +371,7 @@ class Tensor(object): ...@@ -371,7 +371,7 @@ class Tensor(object):
Parameters Parameters
---------- ----------
other : dragon.vm.torch.Tensor, number other : dragon.vm.torch.Tensor or number
The other tensor. The other tensor.
Returns Returns
...@@ -387,7 +387,7 @@ class Tensor(object): ...@@ -387,7 +387,7 @@ class Tensor(object):
Parameters Parameters
---------- ----------
other : dragon.vm.torch.Tensor, number other : dragon.vm.torch.Tensor or number
The other tensor. The other tensor.
Returns Returns
...@@ -403,7 +403,7 @@ class Tensor(object): ...@@ -403,7 +403,7 @@ class Tensor(object):
Parameters Parameters
---------- ----------
other : dragon.vm.torch.Tensor, number other : dragon.vm.torch.Tensor or number
The other tensor. The other tensor.
Returns Returns
...@@ -419,7 +419,7 @@ class Tensor(object): ...@@ -419,7 +419,7 @@ class Tensor(object):
Parameters Parameters
---------- ----------
other : dragon.vm.torch.Tensor, number other : dragon.vm.torch.Tensor or number
The other tensor. The other tensor.
Returns Returns
...@@ -847,6 +847,24 @@ class Tensor(object): ...@@ -847,6 +847,24 @@ class Tensor(object):
""" """
raise NotImplementedError('Refer torch.ops.tensor.repeat') raise NotImplementedError('Refer torch.ops.tensor.repeat')
def chunk(self, chunks, dim=0):
"""Split self into several parts along the given axis.
Parameters
----------
chunks : int
The number of chunks to split into.
dim : int, optional, default=0
The dim to split.
Returns
-------
sequence of dragon.vm.torch.Tensor
The output chunks.
"""
raise NotImplementedError('Refer torch.ops.tensor.chunk')
def nonzero(self): def nonzero(self):
"""Return the indices of non-zero elements. """Return the indices of non-zero elements.
......
#include "utils/op_kernel.h"
#include "utils/math_functions.h"
namespace dragon {
namespace kernel {
/* <T = ?, Device = CPU> */
template <typename T>
void _ChannelShuffle(
const int outer_dim,
const int inner_dim,
const int G,
const int K,
const T* x,
T* y,
CPUContext* ctx) {
int64_t x_ofs, y_ofs;
for (int n = 0; n < outer_dim; ++n) {
for (int gi = 0; gi < G; ++gi) {
for (int ki = 0; ki < K; ++ki) {
x_ofs = ((n * G + gi) * K + ki) * inner_dim;
y_ofs = ((n * K + ki) * G + gi) * inner_dim;
math::Copy(
inner_dim,
x + x_ofs,
y + y_ofs, ctx
);
}
}
}
}
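The nested loop above is the usual reshape-transpose formulation of channel shuffle; a NumPy reference (an assumption for illustration, not shipped with this commit) would be:

import numpy as np

def channel_shuffle_ref(x, outer_dim, inner_dim, G, K):
    # View the shuffled axis as (G, K), swap the two factors, and flatten back.
    x = x.reshape(outer_dim, G, K, inner_dim)
    return x.transpose(0, 2, 1, 3).reshape(outer_dim, G * K, inner_dim)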
/* Kernel Launchers */
#define DEFINE_SHUFFLE_KERNEL_LAUNCHER(T) \
template <> void ChannelShuffle<T, CPUContext>( \
const int outer_dim, \
const int inner_dim, \
const int axis_dim, \
const int group, \
const T* x, \
T* y, \
CPUContext* ctx) { \
_ChannelShuffle( \
outer_dim, \
inner_dim, \
group, \
axis_dim / group, \
x, y, ctx \
); \
}
DEFINE_SHUFFLE_KERNEL_LAUNCHER(bool);
DEFINE_SHUFFLE_KERNEL_LAUNCHER(int8_t);
DEFINE_SHUFFLE_KERNEL_LAUNCHER(uint8_t);
DEFINE_SHUFFLE_KERNEL_LAUNCHER(int);
DEFINE_SHUFFLE_KERNEL_LAUNCHER(int64_t);
DEFINE_SHUFFLE_KERNEL_LAUNCHER(float16);
DEFINE_SHUFFLE_KERNEL_LAUNCHER(float);
DEFINE_SHUFFLE_KERNEL_LAUNCHER(double);
#undef DEFINE_SHUFFLE_KERNEL_LAUNCHER
} // namespace kernel
} // namespace dragon
\ No newline at end of file
#ifdef WITH_CUDA
#include "core/context_cuda.h"
#include "utils/op_kernel.h"
namespace dragon {
namespace kernel {
/* <T = ?, Device = CUDA> */
template <typename T>
__global__ void _ChannelShuffle(
const int nthreads,
const int inner_dim,
const int G,
const int K,
const T* x,
T* y) {
CUDA_1D_KERNEL_LOOP(yi, nthreads) {
const int inner_idx = yi % inner_dim;
const int gi = (yi / inner_dim) % G;
const int ki = (yi / inner_dim / G) % K;
const int outer_idx = yi / inner_dim / G / K;
y[yi] = x[((outer_idx * G + gi) * K + ki
) * inner_dim + inner_idx];
}
}
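Unlike the CPU path, which copies contiguous inner_dim blocks, the CUDA kernel assigns one thread per output element and gathers its source value; the index arithmetic, written out in Python for clarity (a sketch, not library code):

def gather_source_index(yi, inner_dim, G, K):
    # Decompose the output index as (outer, ki, gi, inner) and rebuild the
    # input index with the (G, K) factors swapped back.
    inner_idx = yi % inner_dim
    gi = (yi // inner_dim) % G
    ki = (yi // inner_dim // G) % K
    outer_idx = yi // inner_dim // G // K
    return ((outer_idx * G + gi) * K + ki) * inner_dim + inner_idx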
/* Kernel Launchers */
#define DEFINE_SHUFFLE_KERNEL_LAUNCHER(T) \
template <> void ChannelShuffle<T, CUDAContext>( \
const int outer_dim, \
const int inner_dim, \
const int axis_dim, \
const int group, \
const T* x, \
T* y, \
CUDAContext* ctx) { \
auto nthreads = outer_dim * axis_dim * inner_dim; \
_ChannelShuffle \
<<< CUDA_BLOCKS(nthreads), CUDA_THREADS, \
0, ctx->cuda_stream() >>>( \
nthreads, \
inner_dim, \
group, \
axis_dim / group, \
x, y \
); \
}
DEFINE_SHUFFLE_KERNEL_LAUNCHER(bool);
DEFINE_SHUFFLE_KERNEL_LAUNCHER(int8_t);
DEFINE_SHUFFLE_KERNEL_LAUNCHER(uint8_t);
DEFINE_SHUFFLE_KERNEL_LAUNCHER(int);
DEFINE_SHUFFLE_KERNEL_LAUNCHER(int64_t);
DEFINE_SHUFFLE_KERNEL_LAUNCHER(float16);
DEFINE_SHUFFLE_KERNEL_LAUNCHER(float);
DEFINE_SHUFFLE_KERNEL_LAUNCHER(double);
#undef DEFINE_SHUFFLE_KERNEL_LAUNCHER
} // namespace kernel
} // namespace dragon
#endif // WITH_CUDA
\ No newline at end of file
#include "core/workspace.h"
#include "utils/op_kernel.h"
#include "operators/array/channel_shuffle_op.h"
namespace dragon {
#define DETERMINE_RUNTIME_ARGS(X) \
axis_ = OpArg<int64_t>("axis", 0); \
axis_ = axis_ < 0 ? axis_ + X.ndim() : axis_; \
CHECK(axis_ >= 0 && axis_ < X.ndim()) \
<< "\nExcepted the axis in [-" << X.ndim() \
<< ", " << X.ndim() << "), got " \
<< OpArg<int64_t>("axis", 0) << ".";
template <class Context> template <typename T>
void ChannelShuffleOp<Context>::RunImpl() {
auto* x = X(0).template data<T, Context>();
auto* y = Y(0)->template mutable_data<T, Context>();
kernel::ChannelShuffle(
outer_dim_,
inner_dim_,
axis_dim_,
group_,
x, y, ctx()
);
}
template <class Context>
void ChannelShuffleOp<Context>::RunOnDevice() {
DETERMINE_RUNTIME_ARGS(X(0));
axis_dim_ = X(0).dim(axis_);
outer_dim_ = X(0).count(0, axis_);
inner_dim_ = X(0).count(axis_ + 1);
CHECK_EQ(axis_dim_ % group_, 0)
<< "\nThe " << axis_dim_ << " channels "
<< "can not be split into " << group_ << " groups.";
Y(0)->ReshapeLike(X(0));
DispatchHelper<TensorTypes
<bool, int8_t, uint8_t, int, int64_t,
float16, float, double>
>::Call(this, X(0));
}
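The shape bookkeeping above factors the input around the shuffled axis; for instance (an assumed NCHW input, not taken from the commit):

import numpy as np

shape, axis, group = (8, 32, 14, 14), 1, 4    # assumed example values
axis_dim  = shape[axis]                       # 32 channels to shuffle
outer_dim = int(np.prod(shape[:axis]))        # 8
inner_dim = int(np.prod(shape[axis + 1:]))    # 14 * 14 = 196
assert axis_dim % group == 0                  # mirrors the CHECK_EQ above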
template <class Context> template <typename T>
void ChannelShuffleGradientOp<Context>::RunImpl() {
auto* dy = X(1).template data<T, Context>();
auto* dx = Y(0)->template mutable_data<T, Context>();
kernel::ChannelShuffle(
outer_dim_,
inner_dim_,
axis_dim_,
axis_dim_ / group_,
dy, dx, ctx()
);
}
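The gradient pass reuses the same kernel but passes axis_dim_ / group_ as the group count: shuffling a (G, K) factorization produces a (K, G) layout, and shuffling again with K groups restores the original order. A toy verification (a sketch, not part of the source):

import numpy as np

def shuffle(c, groups):
    # 1-D channel vector: view as (groups, K), transpose, flatten.
    return c.reshape(groups, -1).T.reshape(-1)

c = np.arange(12)
assert np.array_equal(shuffle(shuffle(c, 3), 4), c)  # G=3 forward, K=4 inverts it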
template <class Context>
void ChannelShuffleGradientOp<Context>::RunOnDevice() {
DETERMINE_RUNTIME_ARGS(X(0));
axis_dim_ = X(0).dim(axis_);
outer_dim_ = X(0).count(0, axis_);
inner_dim_ = X(0).count(axis_ + 1);
CHECK_EQ(axis_dim_ % group_, 0)
<< "\nThe " << axis_dim_ << " channels "
<< "can not be split into " << group_ << " groups.";
Y(0)->ReshapeLike(X(0));
DispatchHelper<TensorTypes
<bool, int8_t, uint8_t, int, int64_t,
float16, float, double>
>::Call(this, X(0));
}
DEPLOY_CPU(ChannelShuffle);
#ifdef WITH_CUDA
DEPLOY_CUDA(ChannelShuffle);
#endif
DEPLOY_CPU(ChannelShuffleGradient);
#ifdef WITH_CUDA
DEPLOY_CUDA(ChannelShuffleGradient);
#endif
OPERATOR_SCHEMA(ChannelShuffle)
/* X */
.NumInputs(1)
/* Y */
.NumOutputs(1);
OPERATOR_SCHEMA(ChannelShuffleGradient)
/* X, dY */
.NumInputs(2)
/* dX */
.NumOutputs(1);
REGISTER_GRADIENT(ChannelShuffle, SimpleGradientMaker);
#undef DETERMINE_RUNTIME_ARGS
} // namespace dragon
\ No newline at end of file
...@@ -12,9 +12,16 @@ namespace dragon { ...@@ -12,9 +12,16 @@ namespace dragon {
<< ", " << Input.ndim() << "), got " \ << ", " << Input.ndim() << "), got " \
<< OpArg<int64_t>("axis", 0) << "."; \ << OpArg<int64_t>("axis", 0) << "."; \
if (points_.empty()) { \ if (points_.empty()) { \
CHECK_EQ(X(0).dim(axis_) % N_, 0) \ auto dim = (X(0).dim(axis_) + N_ - 1) / N_; \
<< "\nSelected dim is " << X(0).dim(axis_) \ sections_ = vec64_t(N_, dim); \
<< ", can't be sliced into " << N_ << " parts."; \ sections_[N_ - 1] = X(0).dim(axis_) - dim * (N_ - 1); \
} else { \
int64_t slice_ofs = 0; \
sections_ = vec64_t(N_); \
for (int i = 0; i < N_; i++) \
sections_[i] = i < N_ - 1 ? \
points_[i] - slice_ofs : \
axis_dim_ - slice_ofs; \
} }
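When no explicit split points are given, the rewritten macro sizes the sections by ceiling division, with the last section absorbing the remainder; a Python sketch of that rule (names are illustrative assumptions):

def even_sections(axis_dim, N):
    dim = (axis_dim + N - 1) // N                 # ceil(axis_dim / N)
    sections = [dim] * (N - 1)
    sections.append(axis_dim - dim * (N - 1))     # last section takes the remainder
    return sections

# even_sections(10, 3) -> [4, 4, 2]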
template <class Context> template <typename T> template <class Context> template <typename T>
...@@ -24,16 +31,11 @@ void SliceOp<Context>::RunImpl() { ...@@ -24,16 +31,11 @@ void SliceOp<Context>::RunImpl() {
auto* x = X(0).template data<T, Context>(); auto* x = X(0).template data<T, Context>();
for (int i = 0; i < N_; i++) { for (int i = 0; i < N_; i++) {
if (!points_.empty()) { slice_dim_ = sections_[i];
slice_dim_ = i < N_ - 1 ?
points_[i] - slice_ofs :
axis_dim_ - slice_ofs;
}
CHECK(slice_dim_ > 0 && CHECK(slice_dim_ > 0)
slice_ofs + slice_dim_ <= axis_dim_) << "\nIllegal slicing sections: "
<< "\nIllegal slicing points: " << Tensor::DimString(sections_)
<< Tensor::DimString(points_)
<< " for dimension: " << axis_dim_; << " for dimension: " << axis_dim_;
out_shape[axis_] = slice_dim_; out_shape[axis_] = slice_dim_;
...@@ -77,16 +79,11 @@ void SliceGradientOp<Context>::RunImpl() { ...@@ -77,16 +79,11 @@ void SliceGradientOp<Context>::RunImpl() {
auto* dx = Y(0)->template mutable_data<T, Context>(); auto* dx = Y(0)->template mutable_data<T, Context>();
for (int i = 0; i < N_; i++) { for (int i = 0; i < N_; i++) {
if (!points_.empty()) { slice_dim_ = sections_[i];
slice_dim_ = i < N_ - 1 ?
points_[i] - slice_ofs :
axis_dim_ - slice_ofs;
}
CHECK(slice_dim_ > 0 && CHECK(slice_dim_ > 0)
slice_ofs + slice_dim_ <= axis_dim_)
<< "\nIllegal slicing points: " << "\nIllegal slicing points: "
<< Tensor::DimString(points_) << Tensor::DimString(sections_)
<< " for dimension: " << axis_dim_; << " for dimension: " << axis_dim_;
if (X(i + 1).name() != "NULL") { if (X(i + 1).name() != "NULL") {
......