SeetaResearch / Dragon
Commit ae11a987, authored May 29, 2019 by Ting PAN

Add ChannelShuffleOp

1 parent 6b82cb26
Showing 36 changed files with 718 additions and 315 deletions
.gitignore
Docs/api/python/Makefile
Docs/api/python/contents/core/tensor_utils.rst
Docs/api/python/contents/ops.rst
Docs/api/python/contents/vm/caffe/layer.rst
Dragon/CMakeLists.txt
Dragon/include/core/context_cuda.h
Dragon/include/core/graph.h
Dragon/include/core/operator.h
Dragon/include/core/typeid.h
Dragon/include/operators/array/channel_shuffle_op.h
Dragon/include/operators/array/slice_op.h
Dragon/include/operators/misc/image_data_op.h
Dragon/include/utils/op_kernel.h
Dragon/modules/python/py_config.h
Dragon/modules/python/py_tensor.h
Dragon/python/dragon/config.py
Dragon/python/dragon/core/helper.py
Dragon/python/dragon/operators/array.py
Dragon/python/dragon/operators/data.py
Dragon/python/dragon/ops.py
Dragon/python/dragon/utils/vision/data_batch.py
Dragon/python/dragon/utils/vision/data_reader.py
Dragon/python/dragon/utils/vision/data_transformer.py
Dragon/python/dragon/vm/caffe/layers/__init__.py
Dragon/python/dragon/vm/caffe/layers/common.py
Dragon/python/dragon/vm/caffe/layers/data.py
Dragon/python/dragon/vm/caffe/proto/caffe.proto
Dragon/python/dragon/vm/torch/ops/builtin.py
Dragon/python/dragon/vm/torch/ops/modules/array.py
Dragon/python/dragon/vm/torch/ops/tensor.py
Dragon/python/dragon/vm/torch/tensor.py
Dragon/src/kernels/array/channel_shuffle_op_kernel.cc
Dragon/src/kernels/array/channel_shuffle_op_kernel.cu
Dragon/src/operators/array/channel_shuffle_op.cc
Dragon/src/operators/array/slice_op.cc
.gitignore (View file @ ae11a98)

 ## General
 # Compiled Object files
 *.slo
 *.lo
@@ -7,13 +5,15 @@
 *.cuo
 # Compiled Dynamic libraries
-# *.so
+*.so
+*.dll
 *.dylib
 # Compiled Static libraries
 *.lai
 *.la
-#*.a
+*.a
+*.lib
 # Compiled python
 *.pyc

@@ -40,6 +40,9 @@ __pycache__
 # QtCreator files
 *.user
+# VSCode files
+.vscode
 # PyCharm files
 .idea
Docs/api/python/Makefile (View file @ ae11a98)

@@ -5,7 +5,7 @@
 SPHINXOPTS    =
 SPHINXBUILD   = sphinx-build
 PAPER         =
-BUILDDIR      = _build
+BUILDDIR      = _build/api

 # User-friendly check for sphinx-build
 ifeq ($(shell which $(SPHINXBUILD) >/dev/null 2>&1; echo $$?), 1)

@@ -52,9 +52,9 @@ clean:
 	rm -rf $(BUILDDIR)/*

 html:
-	$(SPHINXBUILD) -b html $(ALLSPHINXOPTS) $(BUILDDIR)/html
+	$(SPHINXBUILD) -b html $(ALLSPHINXOPTS) $(BUILDDIR)/python
 	@echo
-	@echo "Build finished. The HTML pages are in $(BUILDDIR)/html."
+	@echo "Build finished. The HTML pages are in $(BUILDDIR)/python."

 dirhtml:
 	$(SPHINXBUILD) -b dirhtml $(ALLSPHINXOPTS) $(BUILDDIR)/dirhtml
Docs/api/python/contents/core/tensor_utils.rst (View file @ ae11a98)

@@ -11,14 +11,12 @@ Quick Reference
 ============================== =============================================================================
 List                           Brief
 ============================== =============================================================================
-`FromShape`_                   Create a Tensor from the shape.
+`FromShape`_                   Create a Tensor from the given shape.
+`FromArray`_                   Create a Tensor from an existing Array.
+`ToArray`_                     Create an Array from an existing Tensor.
 `SetShape`_                    Set a Tensor with the shape.
-`FromPyArray`_                 Create a Tensor from an existing Array.
+`SetArray`_                    Set a Tensor from an existing Array.
-`SetPyArray`_                  Set a Tensor from an existing Array.
-`ToPyArray`_                   Create an Array from an existing Tensor.
 `GetStorage`_                  Get the storage of an existing Tensor.
-`ToCPUTensor`_                 Switch the storage of an existing Tensor to cpu memory.
-`ToCUDATensor`_                Switch the storage of an existing Tensor to cuda memory.
 ============================== =============================================================================

@@ -28,10 +26,8 @@ API Reference
   :members:

 .. _FromShape: #dragon.core.tensor_utils.FromShape
+.. _FromArray: #dragon.core.tensor_utils.FromArray
+.. _ToArray: #dragon.core.tensor_utils.ToArray
 .. _SetShape: #dragon.core.tensor_utils.SetShape
-.. _FromPyArray: #dragon.core.tensor_utils.FromPyArray
+.. _SetArray: #dragon.core.tensor_utils.SetArray
-.. _SetPyArray: #dragon.core.tensor_utils.SetPyArray
-.. _ToPyArray: #dragon.core.tensor_utils.ToPyArray
 .. _GetStorage: #dragon.core.tensor_utils.GetStorage
-.. _ToCPUTensor: #dragon.core.tensor_utils.ToCPUTensor
-.. _ToCUDATensor: #dragon.core.tensor_utils.ToCUDATensor
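The rename pairs each conversion with its NumPy counterpart. A minimal round-trip sketch, assuming `FromArray`/`ToArray` accept a NumPy array and a Tensor respectively (the exact keyword arguments are not shown in this diff):

    import numpy as np
    from dragon.core import tensor_utils

    arr = np.ones((2, 3), dtype='float32')
    t = tensor_utils.FromArray(arr)   # Tensor from an existing array
    back = tensor_utils.ToArray(t)    # ...and back to an array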
Docs/api/python/contents/ops.rst (View file @ ae11a98)

@@ -132,37 +132,38 @@ List Brief
 Array
 -----
-=============== ======================================================================
+================== ======================================================================
 List               Brief
-=============== ======================================================================
+================== ======================================================================
 `Where`_           Select elements from either *x* or *y*.
 `IndexSelect`_     Select the elements according to the indices along the given axis.
 `MaskedSelect`_    Select the elements where *mask* is *1*.
 `Reduce`_          Reduce the inputs along the axis in given axes.
 `Sum`_             Compute the sum along the given axis.
 `Mean`_            Compute the mean along the given axis.
 `Max`_             Compute the values of maximum elements along the given axis.
 `ArgMax`_          Compute the indices of maximum elements along the given axis.
 `Min`_             Compute the values of minimum elements along the given axis.
 `ArgMin`_          Compute the indices of minimum elements along the given axis.
 `Slice`_           Slice the inputs into several parts along the given axis.
 `Stack`_           Stack the inputs along the given axis.
 `Concat`_          Concatenate the inputs along the given axis.
+`ChannelShuffle`_  Shuffle channels between groups along the given axis. `[Zhang et.al, 2017] <https://arxiv.org/abs/1707.01083>`_.
 `Repeat`_          Repeat the input along the given axis.
 `Transpose`_       Transpose the input according to the given permutations.
 `Tile`_            Tile the input according to the given multiples.
 `Pad`_             Pad the input according to the given sizes.
 `Crop`_            Crop the input according to the given starts and sizes.
 `OneHot`_          Generate the one-hot representation of inputs.
 `Flatten`_         Flatten the input along the given axes.
 `Reshape`_         Reshape the dimensions of input.
 `Squeeze`_         Remove the dimensions with size 1.
 `ExpandDims`_      Expand the new dimension with size 1 to specific axis.
 `Shape`_           Get the dynamic shape of a Tensor.
 `NonZero`_         Return the indices of non-zero elements.
 `Arange`_          Return evenly spaced values within a given interval.
 `Multinomial`_     Return indices sampled from the multinomial distribution.
-=============== ======================================================================
+================== ======================================================================

 Control Flow
 ------------

@@ -302,6 +303,7 @@ List Brief
 .. _Slice: operators/array.html#dragon.operators.array.Slice
 .. _Stack: operators/array.html#dragon.operators.array.Stack
 .. _Concat: operators/array.html#dragon.operators.array.Concat
+.. _ChannelShuffle: operators/array.html#dragon.operators.array.ChannelShuffle
 .. _Transpose: operators/array.html#dragon.operators.array.Transpose
 .. _Repeat: operators/array.html#dragon.operators.array.Repeat
 .. _Tile: operators/array.html#dragon.operators.array.Tile
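For reference, the newly listed ChannelShuffle follows the ShuffleNet trick: view the channel axis as (group, channels_per_group), transpose the two sub-axes, then flatten back. A NumPy sketch of the semantics (shapes are illustrative, and `group` must divide the size of `axis`):

    import numpy as np

    def channel_shuffle(x, axis=0, group=1):
        k = x.shape[axis]
        assert k % group == 0
        split = x.shape[:axis] + (group, k // group) + x.shape[axis + 1:]
        # Swap the (group, channels_per_group) sub-axes, then flatten back.
        return x.reshape(split).swapaxes(axis, axis + 1).reshape(x.shape)

    x = np.arange(12).reshape(1, 6, 2)       # (N, C=6, W)
    y = channel_shuffle(x, axis=1, group=2)  # channels interleaved across groups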
Docs/api/python/contents/vm/caffe/layer.rst (View file @ ae11a98)

Throughout this file the layer names in the Brief column change from literal markup (``DataLayer``) to emphasis (*DataLayer*); the hunks below all follow the pattern shown for the Data table.

@@ -14,8 +14,8 @@ Data
 ==================== =============================================================================
 List                 Brief
 ==================== =============================================================================
-`DataLayer`_         The implementation of ``DataLayer``.
+`DataLayer`_         The implementation of *DataLayer*.
-`MemoryDataLayer`_   The implementation of ``MemoryDataLayer``.
+`MemoryDataLayer`_   The implementation of *MemoryDataLayer*.
 ==================== =============================================================================

@@ -24,16 +24,16 @@ Vision
 The same markup change for `ConvolutionLayer`_, `DepthwiseConvolutionLayer`_,
 `DeconvolutionLayer`_, `PoolingLayer`_, `ROIPoolingLayer`_, `ROIAlignLayer`_,
 `LRNLayer`_, `NNResizeLayer`_, `BilinearResizeLayer`_, and `DropBlockLayer`_.

@@ -43,14 +43,14 @@ Neuron
 The same markup change for `ReLULayer`_, `PReLULayer`_, `ELULayer`_, `SELULayer`_,
 `SigmoidLayer`_, `TanHLayer`_, `DropoutLayer`_, and `PowerLayer`_.

@@ -59,31 +59,30 @@ Common
 ======================== =============================================================================
 List                     Brief
 ======================== =============================================================================
 `InnerProductLayer`_     The implementation of *InnerProductLayer*.
 `AccuracyLayer`_         The implementation of *AccuracyLayer*.
 `PythonLayer`_           The implementation of *PythonLayer*.
 `EltwiseLayer`_          The implementation of *EltwiseLayer*
 `AddLayer`_              The extended implementation of *EltwiseLayer*.
 `ConcatLayer`_           The implementation of *ConcatLayer*.
 `SliceLayer`_            The implementation of *SliceLayer*.
 `CropLayer`_             The implementation of *CropLayer*.
 `ReshapeLayer`_          The implementation of *ReshapeLayer*.
 `PermuteLayer`_          The implementation of *PermuteLayer*.
 `FlattenLayer`_          The implementation of *FlattenLayer*.
 `GatherLayer`_           The extended implementation for *GatherOp*.
 `SoftmaxLayer`_          The implementation of *SoftmaxLayer*.
 `ArgMaxLayer`_           The implementation of *ArgMaxLayer*.
 `BatchNormLayer`_        The implementation of *BatchNormLayer*.
 `GroupNormLayer`_        The implementation of *GroupNormLayer*.
-`InstanceNormLayer`_     The implementation of ``InstanceNormLayer``.
 `ScaleLayer`_            The implementation of *ScaleLayer*.
 `BNLayer`_               The implementation of *BNLayer*.
 `GNLayer`_               The implementation of *GNLayer*.
 `NormalizeLayer`_        The implementation of *NormalizeLayer*.
 `TileLayer`_             The extended implementation of *TileLayer*.
 `ExpandDimsLayer`_       The implementation of *ExpandDimsLayer*.
 `StopGradientLayer`_     The implementation of *StopGradientLayer*.
 `ProposalLayer`_         The implementation of *ProposalLayer*.
 ======================== =============================================================================

@@ -92,12 +91,12 @@ Loss
 The same markup change for `SoftmaxWithLossLayer`_, `SigmoidCrossEntropyLossLayer`_,
 `L2LossLayer`_, `SmoothL1LossLayer`_, `SigmoidWithFocalLossLayer`_, and
 `SoftmaxWithFocalLossLayer`_.

@@ -106,8 +105,8 @@ MPI
 The same markup change for `MPIBroadcastLayer`_ and `MPIGatherLayer`_.

@@ -188,7 +187,6 @@ API Reference
 .. _ArgMaxLayer: #dragon.vm.caffe.layers.common.ArgMaxLayer
 .. _BatchNormLayer: #dragon.vm.caffe.layers.common.BatchNormLayer
 .. _GroupNormLayer: #dragon.vm.caffe.layers.common.GroupNormLayer
-.. _InstanceNormLayer: #dragon.vm.caffe.layers.common.InstanceNormLayer
 .. _ScaleLayer: #dragon.vm.caffe.layers.common.ScaleLayer
 .. _BNLayer: #dragon.vm.caffe.layers.common.BNLayer
 .. _GNLayer: #dragon.vm.caffe.layers.common.GNLayer
Dragon/CMakeLists.txt (View file @ ae11a98)

@@ -94,6 +94,9 @@ include_directories(${THIRD_PARTY_DIR}/eigen)
 include_directories(${THIRD_PARTY_DIR}/protobuf/include)
 include_directories(${PROJECT_SOURCE_DIR}/include)
 include_directories(${PROJECT_SOURCE_DIR}/src)
+if (APPLE)
+    include_directories(/usr/local/include)
+endif()
 if (BUILD_PYTHON_API)
     include_directories(${PYTHON_INCLUDE_DIRS})
     include_directories(${NUMPY_INCLUDE_DIR})

@@ -112,6 +115,9 @@ endif()
 # ---[ Lib Directories
 list(APPEND THIRD_PARTY_LIBRARY_DIRS ${THIRD_PARTY_DIR}/protobuf/lib)
+if (APPLE)
+    list(APPEND THIRD_PARTY_LIBRARY_DIRS /usr/local/lib)
+endif()
 if (WITH_CUDA)
     list(APPEND THIRD_PARTY_LIBRARY_DIRS ${CUDA_TOOLKIT_ROOT_DIR}/lib)
     list(APPEND THIRD_PARTY_LIBRARY_DIRS ${CUDA_TOOLKIT_ROOT_DIR}/lib64)

@@ -177,8 +183,13 @@ if(WIN32)
     endif()
 endif()
 if (UNIX)
-    set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -s -fPIC")
-    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -s -fPIC -O3 -m64 -std=c++11")
+    if (APPLE)
+        set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -w -fPIC")
+        set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -w -fPIC -O3 -m64 -std=c++11")
+    else()
+        set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -s -fPIC")
+        set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -s -fPIC -O3 -m64 -std=c++11")
+    endif()
     if (WITH_OMP AND (NOT APPLE))
         set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fopenmp")
         set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fopenmp")
Dragon/include/core/context_cuda.h (View file @ ae11a98)

@@ -346,7 +346,7 @@ class CUDAContext {
     void FinishDeviceCompution() { CUDA_NOT_COMPILED; }

     /*! \brief Malloc the memory */
-    static void* New(size_t nbytes) { CUDA_NOT_COMPILED; }
+    static void* New(size_t nbytes) { return nullptr; }

     /*! \brief Zero-Reset the memory */
     static void Memset(

(The old stub expanded CUDA_NOT_COMPILED inside a function declared to return void*; the CPU-only build now returns nullptr, presumably to give the non-void function a real return path.)
Dragon/include/core/graph.h (View file @ ae11a98)

@@ -21,23 +21,16 @@ namespace dragon {
 class GraphBase {
  public:
     /*! \brief Default constructor */
-    GraphBase(const GraphDef&, Workspace*);
+    GraphBase(const GraphDef& def, Workspace* ws);

     /*! \brief Default deconstructor */
     virtual ~GraphBase() {}

     /*! \brief Create a graph from the optimized def */
-    virtual bool Create(const GraphDef&, Workspace*) = 0;
+    virtual bool Create(const GraphDef& def, Workspace* ws) = 0;

     /*! \brief Run the graph once synchronously */
-    virtual bool Run(const string&, const string&, int = 0) = 0;
+    virtual bool Run(const string& include, const string& exclude, int stream_id = 0) = 0;

     /*! \brief Return the graph name */
     string name() const { return name_; }

@@ -46,7 +39,7 @@ class GraphBase {
     const string& phase() const { return phase_; }

     /*! \brief Return the argument map */
-    const Map<std::string, const Argument*>& args() { return args_; }
+    const Map<string, const Argument*>& args() { return args_; }

     /*! \brief Return the specified argument */
     const Argument& arg(const string& name) { return *(args_[name]); }

@@ -83,15 +76,10 @@ class Graph : public GraphBase {
     virtual ~Graph() { for (auto* op : ops_) delete op; }

     /*! \brief Create a graph from the optimized def */
-    bool Create(const GraphDef&, Workspace*) override;
+    bool Create(const GraphDef& def, Workspace* ws) override;

     /*! \brief Run the graph once synchronously */
-    bool Run(const string&, const string&, int = 0) override;
+    bool Run(const string& include, const string& exclude, int stream_id = 0) override;

  protected:
     /*! \brief Store the internal operators */

@@ -99,9 +87,7 @@ class Graph : public GraphBase {
 };

 /*! \brief Create a graph from the raw def */
-GraphBase* NewGraph(const GraphDef&, Workspace*);
+GraphBase* NewGraph(const GraphDef& def, Workspace* ws);

 /* Macros */
Dragon/include/core/operator.h (View file @ ae11a98)

@@ -31,7 +31,7 @@ class Workspace;
 class OperatorBase {
  public:
     /*! \brief Default constructor */
-    OperatorBase(const OperatorDef& def, Workspace* ws);
+    OperatorBase(const OperatorDef&, Workspace*);

     /*! \brief Default deconstructor */
     virtual ~OperatorBase() {}

@@ -49,7 +49,7 @@ class OperatorBase {
     int YSize() { return (int)outputs_.size(); }

     /*! \brief Modify operator according to the given def */
-    void UpdateFrom(const OperatorDef& def);
+    void UpdateFrom(const OperatorDef&);

     /*! \brief Switch the internal running phase */
     void SwitchToPhase(const string& phase) { phase_ = phase; }

@@ -95,7 +95,7 @@ class OperatorBase {
     vector<T> Args(const string& name);

     /*! \brief Return the argument map */
-    const Map<std::string, const Argument*>& args() { return args_; }
+    const Map<string, const Argument*>& args() { return args_; }

     /*! \brief Return the specified argument */
     const Argument& arg(const string& name) { return *(args_[name]); }

@@ -212,12 +212,11 @@ class Operator : public OperatorBase {
 #ifndef WITH_MPI
         return true;
 #else
-        vec32_t allow_ranks =
-            OperatorBase::Args<int>("mpi_ranks");
-        if (allow_ranks.empty()) return true;
+        auto ranks =
+            OperatorBase::Args<int>("mpi_ranks");
+        if (ranks.empty()) return true;
         int cur_rank;
         MPI_Comm_rank(MPI_COMM_WORLD, &cur_rank);
-        for (auto mpi_rank : allow_ranks)
+        for (auto mpi_rank : ranks)
             if (cur_rank == mpi_rank) return true;
         return false;
 #endif

@@ -225,10 +224,7 @@ class Operator : public OperatorBase {
 };

 /*! \brief New a operator from the raw def */
-OperatorBase* NewOperator(const OperatorDef&, Workspace*);
+OperatorBase* NewOperator(const OperatorDef& def, Workspace* ws);

 /* Macros */
Dragon/include/core/typeid.h (View file @ ae11a98)

@@ -40,7 +40,7 @@ class TypeMeta {
     TypeMeta(const TypeMeta& src)
         : id_(src.id_), itemsize_(src.itemsize_),
           ctor_(src.ctor_), copy_(src.copy_), dtor_(src.dtor_) {}

     TypeMeta& operator=(const TypeMeta& src) {
         if (this == &src) return *this;

(The scraped diff shows identical text on both sides of this hunk; the change appears to be whitespace-only.)
Dragon/include/operators/array/channel_shuffle_op.h (new file, 0 → 100644) (View file @ ae11a98)

/*!
 * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
 *
 * Licensed under the BSD 2-Clause License.
 * You should have received a copy of the BSD 2-Clause License
 * along with the software. If not, See,
 *
 *      <https://opensource.org/licenses/BSD-2-Clause>
 *
 * ------------------------------------------------------------
 */

#ifndef DRAGON_OPERATORS_ARRAY_CHANNEL_SHUFFLE_OP_H_
#define DRAGON_OPERATORS_ARRAY_CHANNEL_SHUFFLE_OP_H_

#include "core/operator.h"

namespace dragon {

template <class Context>
class ChannelShuffleOp final : public Operator<Context> {
 public:
    ChannelShuffleOp(const OperatorDef& def, Workspace* ws)
        : Operator<Context>(def, ws),
          axis_(OpArg<int64_t>("axis", 0)),
          group_(OpArg<int64_t>("group", 1)) {}
    USE_OPERATOR_FUNCTIONS;

    void RunOnDevice() override;
    template <typename T> void RunImpl();

 protected:
    int64_t outer_dim_, inner_dim_;
    int64_t axis_, axis_dim_, group_;
};

template <class Context>
class ChannelShuffleGradientOp final : public Operator<Context> {
 public:
    ChannelShuffleGradientOp(const OperatorDef& def, Workspace* ws)
        : Operator<Context>(def, ws),
          axis_(OpArg<int64_t>("axis", 0)),
          group_(OpArg<int64_t>("group", 1)) {}
    USE_OPERATOR_FUNCTIONS;

    void RunOnDevice() override;
    template <typename T> void RunImpl();

 protected:
    int64_t outer_dim_, inner_dim_;
    int64_t axis_, axis_dim_, group_;
};

}  // namespace dragon

#endif  // DRAGON_OPERATORS_ARRAY_CHANNEL_SHUFFLE_OP_H_
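For orientation: the fields above factor the input as (outer_dim_, axis_dim_, inner_dim_) around the shuffled axis, with axis_dim_ split into group_ sub-groups. The gradient op carries the same axis_/group_ arguments because the inverse of a channel shuffle is the same shuffle with the two sub-axes swapped. A NumPy sanity check of that claim (dimensions are hypothetical):

    import numpy as np

    outer, group, k, inner = 2, 3, 4, 5                 # axis_dim = group * k
    x = np.random.rand(outer, group * k, inner)
    # Forward: split into (group, k), transpose, flatten back.
    y = x.reshape(outer, group, k, inner).swapaxes(1, 2).reshape(x.shape)
    # Backward/inverse: the same shuffle with the sub-axes swapped.
    x2 = y.reshape(outer, k, group, inner).swapaxes(1, 2).reshape(x.shape)
    assert np.allclose(x, x2)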
Dragon/include/operators/array/slice_op.h (View file @ ae11a98)

@@ -30,7 +30,7 @@ class SliceOp final : public Operator<Context> {
     template <typename T> void RunImpl();

  protected:
-    vec64_t points_;
+    vec64_t points_, sections_;
     int64_t outer_dim_, inner_dim_;
     int64_t axis_, axis_dim_, slice_dim_, N_;
 };

@@ -48,7 +48,7 @@ class SliceGradientOp final : public Operator<Context> {
     template <typename T> void RunImpl();

  protected:
-    vec64_t points_;
+    vec64_t points_, sections_;
     int64_t outer_dim_, inner_dim_;
     int64_t axis_, axis_dim_, slice_dim_, N_;
 };
Dragon/include/operators/misc/image_data_op.h (View file @ ae11a98)

@@ -27,13 +27,13 @@ class ImageDataOp final : public Operator<Context> {
         if (mean_vec_.size() > 0) {
             CHECK_EQ((int)mean_vec_.size(), 3);
             auto* mean = mean_.Reshape({ 3 })
-                ->mutable_data<float, CPUContext>();
+                ->template mutable_data<float, CPUContext>();
             for (int i = 0; i < 3; ++i) mean[i] = mean_vec_[i];
         }
         if (std_vec_.size() > 0) {
             CHECK_EQ((int)std_vec_.size(), 3);
             auto* std = std_.Reshape({ 3 })
-                ->mutable_data<float, CPUContext>();
+                ->template mutable_data<float, CPUContext>();
             for (int i = 0; i < 3; ++i) std[i] = std_vec_[i];
         }
     }

(The added `template` keyword is the standard dependent-name fix: `mutable_data` is a member template invoked through a type that depends on the `Context` template parameter, so C++ requires the disambiguator.)
Dragon/include/utils/op_kernel.h (View file @ ae11a98)

@@ -378,6 +378,18 @@ void ArgMin(
     T*              values,
     Context*        ctx);

+/*! array.channel_shuffle */
+
+template <typename T, class Context>
+void ChannelShuffle(
+    const int       outer_dim,
+    const int       inner_dim,
+    const int       axis_dim,
+    const int       group,
+    const T*        x,
+    T*              y,
+    Context*        ctx);
+
 /*! array.concat */

 template <typename T, class Context>
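A plain-Python reading of the declared kernel contract over flat buffers, to pin down what the CPU/CUDA implementations (added below in channel_shuffle_op_kernel.cc/.cu) must compute; this is an illustrative sketch, not the shipped kernel:

    def channel_shuffle(outer_dim, inner_dim, axis_dim, group, x, y):
        k = axis_dim // group                   # channels per group
        for o in range(outer_dim):
            for g in range(group):
                for c in range(k):
                    # Input channel (g, c) moves to output channel c * group + g.
                    src = (o * axis_dim + g * k + c) * inner_dim
                    dst = (o * axis_dim + c * group + g) * inner_dim
                    y[dst:dst + inner_dim] = x[src:src + inner_dim]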
Dragon/modules/python/py_config.h (View file @ ae11a98)

@@ -10,7 +10,7 @@
  * ------------------------------------------------------------
  */

-#ifndef DRAGON_PYTHON_PY_CONIFG_H_
+#ifndef DRAGON_PYTHON_PY_CONFIG_H_
 #define DRAGON_PYTHON_PY_CONFIG_H_

 #include "py_dragon.h"
Dragon/modules/python/py_tensor.h (View file @ ae11a98)

@@ -11,7 +11,7 @@
  */

 #ifndef DRAGON_PYTHON_PY_TENSOR_H_
-#define DRAGON_PYTHON_PY_TENOSR_H_
+#define DRAGON_PYTHON_PY_TENSOR_H_

 #include "py_dragon.h"
Dragon/python/dragon/config.py (View file @ ae11a98)

@@ -98,7 +98,7 @@ def EnableCNML(mlu_id=0):
     Parameters
     ----------
-    device_id : int
+    mlu_id : int
         The index of MLU to use.

     Returns

@@ -193,14 +193,14 @@ def SetGraphOptimizationLevel(level=3):
     We have predefined four levels:

-    -O0(level=0): Do nothing.
+    *-O0* : Do nothing.

-    -O1(level=1): Prune the redundant nodes.
+    *-O1* : Prune the redundant nodes.

-    -O2(level=2): Add the inplace to outputs.
+    *-O2* : Add the inplace to outputs.
         Note that the graph will no longer be a DAG.

-    -O3(level=3): Allocate the buffer for outputs.
+    *-O3* : Allocate the buffer for outputs.
         This level is memory-efficient while debugging will be non-trivial.

     Parameters
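A quick usage sketch for the two functions whose docstrings change here (the device index and level values are illustrative):

    import dragon.config as config

    config.EnableCNML(mlu_id=0)           # select the first MLU device
    config.SetGraphOptimizationLevel(2)   # -O2: prune nodes + in-place outputs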
Dragon/python/dragon/core/helper.py (View file @ ae11a98)

@@ -33,7 +33,7 @@ class OperatorHelper(object):
         # Following operators is the simplest case:
         # Input(0) => Output(0), shape and data type unchanged.
         'Relu', 'PRelu', 'Elu', 'SElu', 'Sigmoid', 'Tanh', 'Softmax',
-        'Dropout', 'DropPath', 'DropBlock2d',
+        'Dropout', 'DropPath', 'DropBlock2d', 'ChannelShuffle',
         'Add', 'Sub', 'Mul', 'Div', 'Clip', 'Log', 'Exp', 'Pow', 'Square', 'Sqrt',
         'Accumulate', 'Affine', 'Copy', 'StopGradient', 'MPIBroadcast',
         'BatchNorm', 'GroupNorm', 'L2Norm', 'LRN', 'BiasAdd',
Dragon/python/dragon/operators/array.py (View file @ ae11a98)

@@ -52,7 +52,7 @@ def IndexSelect(inputs, indices, axis=0, **kwargs):
         The input tensor.
     indices : sequence or Tensor
         The indices to select elements.
-    axis : int, optional
+    axis : int, optional, default=0
         The axis of indices.

     Returns

@@ -156,25 +156,23 @@ def Crop(
 @OpSchema.Inputs(1)
-def Slice(inputs, axis=0, num_outputs=1, slice_points=None, **kwargs):
+def Slice(inputs, axis=0, num_slices=1, slice_points=None, **kwargs):
     """Slice the inputs into several parts along the given axis.

     All dimensions except the specified ``axis`` should be same.
-    The number of ``slice_points`` should be *len(X.shape) - 1*.
+    if ``slice_points`` is *None*, dimension of axis should be divided by ``num_outputs``.

     **Type Constraints**: (*bool*, *int8*, *uint8*, *int32*, *int64*, *float16*, *float32*, *float64*)

     Parameters
     ----------
     inputs : Tensor
         The input tensor.
-    axis : int, optional
+    axis : int, optional, default=0
         The axis to slice, can be negative.
-    num_outputs : int, optional
-        The optional number of slices.
+    num_slices : int, optional, default=1
+        The number of slices.
     slice_points : sequence of int, optional
         The optional slice points.

@@ -184,9 +182,12 @@ def Slice(inputs, axis=0, num_outputs=1, slice_points=None, **kwargs):
         The outputs.

     """
-    if slice_points is not None and len(slice_points) > 0:
-        num_outputs = len(slice_points) + 1
-    return Tensor.CreateOperator('Slice', **ParseArgs(locals()))
+    arguments = ParseArgs(locals())
+    if slice_points is not None:
+        arguments['num_outputs'] = len(slice_points) + 1
+    else:
+        arguments['num_outputs'] = num_slices
+    return Tensor.CreateOperator('Slice', **arguments)


 @OpSchema.Inputs(1, INT_MAX)

@@ -201,7 +202,7 @@ def Stack(inputs, axis=0, **kwargs):
     ----------
     inputs : sequence of Tensor
         The inputs.
-    axis : int
+    axis : int, optional, default=0
         The axis to stack, can be negative.

     Returns

@@ -225,7 +226,7 @@ def Concat(inputs, axis=0, **kwargs):
     ----------
     inputs : sequence of Tensor
         The inputs.
-    axis : int
+    axis : int, optional, default=0
         The axis to concatenate, can be negative.

     Returns

@@ -238,6 +239,30 @@ def Concat(inputs, axis=0, **kwargs):
+@OpSchema.Inputs(1)
+def ChannelShuffle(inputs, axis=0, group=1, **kwargs):
+    """Shuffle channels between groups along the given axis. `[Zhang et.al, 2017] <https://arxiv.org/abs/1707.01083>`_.
+
+    **Type Constraints**: (*bool*, *int8*, *uint8*, *int32*, *int64*, *float16*, *float32*, *float64*)
+
+    Parameters
+    ----------
+    inputs : Tensor
+        The inputs.
+    axis : int, optional, default=0
+        The axis of channels, can be negative.
+    group : int, optional, default=1
+        The number of groups.
+
+    Returns
+    -------
+    Tensor
+        The output tensor.
+
+    """
+    return Tensor.CreateOperator('ChannelShuffle', **ParseArgs(locals()))
+
+
 @OpSchema.Inputs(1)
 def Reduce(inputs, axes=None, operation='SUM', keep_dims=False, **kwargs):
     """Reduce the inputs along the axis in given axes.

@@ -411,9 +436,9 @@ def ArgMin(inputs, axis=None, top_k=1, keep_dims=False, **kwargs):
         The input tensor.
     axis : int, optional
         The axis to compute, can be negative.
-    top_k : int, optional
+    top_k : int, optional, default=1
         The top k results to keep.
-    keep_dims : bool, optional
+    keep_dims : bool, optional, default=False
         Whether to keep dims after computing.

     Returns

@@ -439,9 +464,9 @@ def Min(inputs, axis=None, top_k=1, keep_dims=False, **kwargs):
         The input tensor.
     axis : int, optional
         The axis to compute, can be negative.
-    top_k : int, optional
+    top_k : int, optional, default=1
         The top k results to keep.
-    keep_dims : bool, optional
+    keep_dims : bool, optional, default=False
         Whether to keep dims after computing.

     Returns

@@ -494,7 +519,7 @@ def Repeat(inputs, axis=None, repeats=1, **kwargs):
         The input tensor.
     axis : int, optional
         The axis to repeat.
-    repeats : int or Tensor, optional
+    repeats : int or Tensor, optional, default=1
         The magnitude of repeating.

     Returns

@@ -586,11 +611,11 @@ def OneHot(inputs, depth, on_value=1, off_value=0, **kwargs):
     inputs : Tensor
         The input tensor.
     depth : int
-        The depth of one-hot representation.
+        The depth of representation.
-    on_value : int, optional
-        The value when ``indices[j] = i``.
+    on_value : int, optional, default=1
+        The value when *indices[j]* = *i*.
-    off_value : int, optional
-        The value when ``indices[j] != i``.
+    off_value : int, optional, default=0
+        The value when *indices[j]* != *i*.

     Returns
     -------

@@ -613,12 +638,12 @@ def Flatten(inputs, axis=0, num_axes=-1, keep_axes=None, **kwargs):
     ----------
     inputs : Tensor
         The input tensor.
-    axis : int, optional
+    axis : int, optional, default=0
         The start axis to flatten, can be negative.
-    num_axes : int, optional
-        The number of axes to flatten. Default is ``-1`` (Along all axes).
+    num_axes : int, optional, default=-1
+        The number of axes to flatten.
     keep_axes : int, optional
-        The number of axes to keep. Default is ``None`` (Disabled).
+        The number of axes to keep.

     Returns
     -------
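Putting the renamed `Slice` and the new `ChannelShuffle` together; this assumes the usual symbolic-Tensor declaration from this codebase (the name and shape below are illustrative):

    import dragon

    x = dragon.Tensor('x', shape=[4, 12, 8]).Variable()  # hypothetical declaration
    parts = dragon.ops.Slice(x, axis=1, num_slices=3)    # three [4, 4, 8] tensors
    y = dragon.ops.ChannelShuffle(x, axis=1, group=4)    # same shape as x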
Dragon/python/dragon/operators/data.py (View file @ ae11a98)

@@ -27,20 +27,16 @@ def LMDBData(**kwargs):
         The path of database.
     shuffle : bool, optional, default=False
         Whether to shuffle the data.
-    node_step : bool
-        Whether to split data for multiple parallel nodes.
     num_chunks : int, optional, default=2048
         The number of chunks to split.
-    chunk_size : int, optional, default=-1
-        The size(MB) of each chunk.
-    mean_values : list, optional
-        The mean value of each image channel.
-    scale : float, optional, default=1.
-        The scale performed after mean subtraction.
     padding : int, optional, default=0
         The zero-padding size.
+    fill_value : int or sequence, optional, default=127
+        The value(s) to fill for padding or cutout.
     crop_size : int, optional, default=0
         The cropping size.
+    cutout_size : int, optional, default=0
+        The square size to cutout.
     mirror : bool, optional, default=False
         Whether to mirror(flip horizontally) images.
     color_augmentation : bool, optional, default=False
Dragon/python/dragon/ops.py (View file @ ae11a98)

@@ -126,6 +126,7 @@ ArgMin = _array_ops.ArgMin
 Slice = _array_ops.Slice
 Stack = _array_ops.Stack
 Concat = _array_ops.Concat
+ChannelShuffle = _array_ops.ChannelShuffle
 Transpose = _array_ops.Transpose
 Repeat = _array_ops.Repeat
 Tile = _array_ops.Tile
Dragon/python/dragon/utils/vision/data_batch.py (View file @ ae11a98)

@@ -45,8 +45,8 @@ class DataBatch(object):
         The number of chunks to split.
     padding : int, optional, default=0
         The zero-padding size.
-    fill_value : int, optional, default=127
-        The value to fill when padding is valid.
+    fill_value : int or sequence, optional, default=127
+        The value(s) to fill for padding or cutout.
     crop_size : int, optional, default=0
         The cropping size.
     cutout_size : int, optional, default=0
Dragon/python/dragon/utils/vision/data_reader.py (View file @ ae11a98)

@@ -120,7 +120,7 @@ class DataReader(multiprocessing.Process):
             self._cursor += 1

     def next_chunk(self):
-        """Step the cursor of shuffling chunks.
+        """Step the cursor of chunks.

         Returns
         -------

@@ -166,7 +166,7 @@ class DataReader(multiprocessing.Process):
         # Search a optimal chunk size (Chunk-Wise)
         min_size, max_size = \
             1, self._db._total_size * 1.0 \
-                / ((self._num_chunks * (1 << 20)))
+                / (self._num_chunks * (1 << 20))
         while min_size * 2 < max_size: min_size *= 2
         self._perm_size = int(math.ceil(
             self._db._total_size * 1.1 /
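The search above just walks min_size up to the largest power of two below max_size. Working the numbers for a hypothetical 10 GiB database split into the default 2048 chunks:

    total_size, num_chunks = 10 * (1 << 30), 2048
    min_size, max_size = 1, total_size * 1.0 / (num_chunks * (1 << 20))  # 5.0 MiB
    while min_size * 2 < max_size:
        min_size *= 2
    # min_size == 4, i.e. 4 MiB chunks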
Dragon/python/dragon/utils/vision/data_transformer.py (View file @ ae11a98)

@@ -43,8 +43,8 @@ class DataTransformer(multiprocessing.Process):
     ----------
     padding : int, optional, default=0
         The zero-padding size.
-    fill_value : int, optional, default=127
-        The value to fill when padding is valid.
+    fill_value : int or sequence, optional, default=127
+        The value(s) to fill for padding or cutout.
     crop_size : int, optional, default=0
         The cropping size.
     cutout_size : int, optional, default=0

@@ -92,7 +92,7 @@ class DataTransformer(multiprocessing.Process):
         The tuple image and labels.

         """
-        # decode
+        # Decode
         datum = _proto_def.Datum()
         datum.ParseFromString(serialized)
         im = numpy.fromstring(datum.data, numpy.uint8)

@@ -115,13 +115,14 @@ class DataTransformer(multiprocessing.Process):
         # Padding
         if self._padding > 0:
-            pad_img = numpy.empty((
+            pad_im = numpy.empty((
                 im.shape[0] + 2 * self._padding,
                 im.shape[1] + 2 * self._padding,
                 im.shape[2]), dtype=im.dtype)
-            pad_img.fill(self._fill_value)
-            pad_img[self._padding : self._padding + im.shape[0],
-                    self._padding : self._padding + im.shape[1], :] = im
-            im = pad_img
+            pad_im[:] = self._fill_value
+            pad_im[self._padding : self._padding + im.shape[0],
+                   self._padding : self._padding + im.shape[1], :] = im
+            im = pad_im
         # Random crop
         if self._crop_size > 0:

(Replacing `fill()` with slice assignment lets a sequence-valued fill_value broadcast per channel, matching the widened docstring.)
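A standalone NumPy check of that broadcasting behavior (the image size and fill values are illustrative):

    import numpy as np

    im = np.zeros((4, 4, 3), dtype=np.uint8)
    pad, fill_value = 2, (104, 117, 123)    # per-channel fill now works
    pad_im = np.empty((im.shape[0] + 2 * pad,
                       im.shape[1] + 2 * pad,
                       im.shape[2]), im.dtype)
    pad_im[:] = fill_value                  # broadcasts across H and W
    pad_im[pad:pad + im.shape[0], pad:pad + im.shape[1], :] = im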
Dragon/python/dragon/vm/caffe/layers/__init__.py (View file @ ae11a98)

@@ -84,7 +84,6 @@ from .common import (
     FlattenLayer,
     ConcatLayer,
     NormalizeLayer,
-    InstanceNormLayer,
     TileLayer,
     ReductionLayer,
     ExpandDimsLayer,
Dragon/python/dragon/vm/caffe/layers/common.py (View file @ ae11a98)

@@ -370,28 +370,6 @@ class GroupNormLayer(_Layer):
         return _ops.GroupNorm(inputs, **self.arguments)


-class InstanceNormLayer(_Layer):
-    """The implementation of *InstanceNormLayer*.
-
-    Introduced by `[Ulyanov et.al, 2016] <https://arxiv.org/abs/1607.08022>`_
-
-    Parameters
-    ----------
-    eps : float
-        Refer ``InstanceNormParameter.eps``.
-
-    """
-    def __init__(self, LayerParameter):
-        super(InstanceNormLayer, self).__init__(LayerParameter)
-        self.arguments = {
-            'axis': 1,
-            'eps': LayerParameter.instance_norm_param.eps,
-        }
-
-    def LayerSetup(self, bottom):
-        return _ops.InstanceNorm(bottom, **self.arguments)
-
-
 class ScaleLayer(_Layer):
     """The implementation of *ScaleLayer*.
Dragon/python/dragon/vm/caffe/layers/data.py (View file @ ae11a98)

@@ -30,11 +30,15 @@ class DataLayer(_Layer):
         The path of database. Refer `DataParameter.source`_.
     prefetch : int
         The prefetch count. Refer `DataParameter.prefetch`_.
+    shuffle : boolean
+        Whether to shuffle the data. Refer ``DataParameter.shuffle``.
+    num_chunks : int
+        The number of chunks to shuffle. Refer ``DataParameter.num_chunks``.
     batch_size : int
         The size of a mini-batch. Refer `DataParameter.batch_size`_.
     phase : Phase
         The phase of layer. Refer `LayerParameter.phase`_.
-    mirrow : boolean
+    mirror : boolean
         Whether to randomly mirror. Refer `TransformationParameter.mirror`_.
     crop_size : int
         The crop size. Refer `TransformationParameter.crop_size`_.

@@ -62,11 +66,12 @@ class DataLayer(_Layer):
         param = LayerParameter.data_param
         memory_param = LayerParameter.memory_data_param
         transform_param = LayerParameter.transform_param
-        parallel_param = LayerParameter.parallel_param
         self.arguments = {
             'source': param.source,
             'prefetch': param.prefetch,
+            'shuffle': param.shuffle,
+            'num_chunks': param.num_chunks,
             'batch_size': param.batch_size,
             'phase': {0: 'TRAIN', 1: 'TEST'}[int(LayerParameter.phase)],
             'mirror': transform_param.mirror,

@@ -76,9 +81,6 @@ class DataLayer(_Layer):
             'padding': transform_param.padding,
             'min_random_scale': transform_param.min_random_scale,
             'max_random_scale': transform_param.max_random_scale,
-            'shuffle': parallel_param.shuffle,
-            'multiple_nodes': parallel_param.multiple_nodes,
-            'partition': parallel_param.partition,
             'dtype': {0: 'float32', 1: 'float16'}[memory_param.dtype],
             'data_format': 'NCHW',
         }
Dragon/python/dragon/vm/caffe/proto/caffe.proto
View file @
ae11a98
...
@@ -416,18 +416,13 @@ message LayerParameter {
   optional MPIParameter mpi_param = 153;
   optional PermuteParameter permute_param = 154;
   optional NormalizeParameter normalize_param = 155;
-  optional ParallelParameter parallel_param = 157;
-  optional ResizeParameter resize_param = 158;
-  optional ExpandDimsParameter expand_dims_param = 159;
-  optional ProposalParameter proposal_param = 160;
-  optional BatchRenormParameter batch_renorm_param = 161;
-  optional DenseConcatParameter dense_concat_param = 163;
-  optional FocalLossParameter focal_loss_param = 164;
-  optional GatherParameter gather_param = 165;
-  optional InstanceNormParameter instance_norm_param = 166;
-  optional GroupNormParameter group_norm_param = 167;
-  optional DropBlockParameter drop_block_param = 168;
-  optional CastParameter cast_param = 169;
+  optional ResizeParameter resize_param = 156;
+  optional ExpandDimsParameter expand_dims_param = 157;
+  optional ProposalParameter proposal_param = 158;
+  optional FocalLossParameter focal_loss_param = 159;
+  optional GroupNormParameter group_norm_param = 160;
+  optional DropBlockParameter drop_block_param = 161;
+  optional CastParameter cast_param = 163;
 }

 // Message that stores parameters used to apply transformation
...
@@ -690,6 +685,10 @@ message DataParameter {
   // Prefetch queue (Number of batches to prefetch to host memory, increase if
   // data access bandwidth varies).
   optional uint32 prefetch = 10 [default = 5];
+  // Whether to shuffle the data.
+  optional bool shuffle = 11 [default = false];
+  // The number of chunks to shuffle.
+  optional int32 num_chunks = 12 [default = 2048];
 }

 message DropoutParameter {
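The new fields slot into DataParameter right after prefetch, using the previously unused numbers 11 and 12. A hedged sketch of setting them through the generated Python bindings, assuming the proto compiles to dragon.vm.caffe.proto.caffe_pb2:

from dragon.vm.caffe.proto import caffe_pb2

layer = caffe_pb2.LayerParameter(name='data', type='Data')
layer.data_param.source = '/data/train_lmdb'  # illustrative path
layer.data_param.shuffle = True               # new field 11
layer.data_param.num_chunks = 2048            # new field 12 (the default)
print(layer)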
...
@@ -1462,12 +1461,6 @@ message NormalizeParameter {
   optional float eps = 4 [default = 1e-5];
 }

-message ParallelParameter {
-  optional bool multiple_nodes = 1 [default = false];
-  optional bool shuffle = 2 [default = false];
-  optional bool partition = 3 [default = false];
-}
-
 message ResizeParameter {
   optional BlobShape shape = 1;
   optional float fx = 2 [default = -1.0];
...
@@ -1492,37 +1485,15 @@ message ProposalParameter {
   optional int32 canonical_level = 11 [default = 4];
 }

-message BatchRenormParameter {
-  optional bool use_global_stats = 1;
-  optional float moving_average_fraction = 2 [default = 0.9];
-  optional float eps = 3 [default = 1e-5];
-  optional float r_max = 4 [default = 3.0];
-  optional float d_max = 5 [default = 5.0];
-  optional float t_delta = 6 [default = 0.001];
-}
-
-message DenseConcatParameter {
-  optional int32 axis = 1 [default = 1];
-  optional int32 growth_rate = 2 [default = 0];
-}
-
 message FocalLossParameter {
   optional float alpha = 1 [default = 0.25];
   optional float gamma = 2 [default = 2.0];
   optional int32 neg_id = 3 [default = 0];
 }

-message GatherParameter {
-  optional int32 axis = 1 [default = 0];
-}
-
-message InstanceNormParameter {
-  optional float eps = 1 [default = 1e-5];
-}
-
 message GroupNormParameter {
   optional float eps = 1 [default = 1e-5];
-  optional int32 group = 2 [default = 32];
+  // The group size
+  optional int32 group = 2 [default = 32];
 }

 message DropBlockParameter {
...
Dragon/python/dragon/vm/torch/ops/builtin.py
View file @ ae11a98
...
@@ -34,9 +34,9 @@ from dragon.vm.torch.ops.modules.init import (
 )
 from dragon.vm.torch.ops.modules.array import (
     Reshape, Squeeze, UnSqueeze, Permute,
-    Indexing, Repeat, Concat, Stack,
-    IndexSelect, MaskedSelect,
+    ChannelShuffle, Repeat, Concat, Stack, Chunk,
+    Indexing, IndexSelect, MaskedSelect,
     Reduce, ArgReduce,
     NonZero, Where,
     OneHot, Multinomial,
...
@@ -60,7 +60,8 @@ __all__ = [
     'mean', 'sum', 'min', 'max', 'topk',
     'nonzero', 'where', 'argmin', 'argmax',
     'gt', 'lt', 'eq', 'ne', 'ge', 'le',
-    'cat', 'stack', 'narrow',
+    'cat', 'stack', 'chunk', 'narrow',
+    'channel_shuffle',
     'index_select', 'masked_select',
     'one_hot', 'multinomial',
     'rand', 'randn',
...
@@ -422,7 +423,7 @@ def _reshape(input, shape, shape_like=None):
 def _permute(input, perm):
-    dev = MakeDevice(inputs=[input]); nperm = len(perm)
+    dev, nperm = MakeDevice([input]), len(perm)
     key = 'Permute/{}/nperm:{}'.format(dev, nperm)
     module = get_module(Permute, key, dev, nperm=nperm)
     return module.forward(input, perm)
...
@@ -576,7 +577,9 @@ def squeeze(input, dim=None, out=None):
     Parameters
     ----------
-    dim : int
+    input : dragon.vm.torch.Tensor
+        The input tensor.
+    dim : int, optional
         The optional dim to remove.
     out : dragon.vm.torch.Tensor, optional
         The output tensor.
...
@@ -598,6 +601,8 @@ def unsqueeze(input, dim, out=None):
     Parameters
     ----------
+    input : dragon.vm.torch.Tensor
+        The input tensor.
     dim : int
         The dim to remove.
     out : dragon.vm.torch.Tensor, optional
...
@@ -856,7 +861,7 @@ def le(input, other, out=None):
     ----------
     input : dragon.vm.torch.Tensor
         The input tensor.
-    other : dragon.vm.torch.Tensor, number
+    other : dragon.vm.torch.Tensor or number
         The other tensor.
     out : dragon.vm.torch.Tensor, optional
         The optional output tensor.
...
@@ -877,7 +882,7 @@ def eq(input, other, out=None):
     ----------
     input : dragon.vm.torch.Tensor
         The input tensor.
-    other : dragon.vm.torch.Tensor, number
+    other : dragon.vm.torch.Tensor or number
         The other tensor.
     out : dragon.vm.torch.Tensor, optional
         The optional output tensor.
...
@@ -898,7 +903,7 @@ def ne(input, other, out=None):
     ----------
     input : dragon.vm.torch.Tensor
         The input tensor.
-    other : dragon.vm.torch.Tensor, number
+    other : dragon.vm.torch.Tensor or number
         The other tensor.
     out : dragon.vm.torch.Tensor, optional
         The optional output tensor.
@@ -930,7 +935,7 @@ def cat(seq, dim=0, out=None):
...
@@ -930,7 +935,7 @@ def cat(seq, dim=0, out=None):
The output tensor.
The output tensor.
"""
"""
dev
=
MakeDevice
(
inputs
=
seq
,
outputs
=
[
out
]
if
out
else
[])
dev
=
MakeDevice
(
seq
,
[
out
]
if
out
else
[])
key
=
'Concat/{}/dim:{}'
.
format
(
dev
,
dim
)
key
=
'Concat/{}/dim:{}'
.
format
(
dev
,
dim
)
module
=
get_module
(
Concat
,
key
,
dev
,
axis
=
dim
)
module
=
get_module
(
Concat
,
key
,
dev
,
axis
=
dim
)
return
module
.
forward
(
seq
,
out
)
return
module
.
forward
(
seq
,
out
)
...
@@ -943,7 +948,7 @@ def stack(seq, dim=0, out=None):
...
@@ -943,7 +948,7 @@ def stack(seq, dim=0, out=None):
----------
----------
seq : sequence of dragon.vm.torch.Tensor
seq : sequence of dragon.vm.torch.Tensor
The sequence.
The sequence.
dim : int, optional
dim : int, optional
, default=0
The dim to stack.
The dim to stack.
out : dragon.vm.torch.Tensor, optional
out : dragon.vm.torch.Tensor, optional
The optional output tensor.
The optional output tensor.
...
@@ -960,6 +965,60 @@ def stack(seq, dim=0, out=None):
     return module.forward(seq, out)


+def channel_shuffle(input, dim=0, group=1, out=None):
+    """Shuffle channels between groups along the given axis.
+
+    Parameters
+    ----------
+    input : dragon.vm.torch.Tensor
+        The input tensor.
+    dim : int, optional, default=0
+        The axis of channels.
+    group : int, optional, default=1
+        The number of groups.
+    out : dragon.vm.torch.Tensor, optional
+        The output tensor.
+
+    Returns
+    -------
+    dragon.vm.torch.Tensor
+        The new tensor.
+
+    """
+    dev = MakeDevice([input])
+    key = 'ChannelShuffle/{}/dim:{}/group:{}'.format(dev, dim, group)
+    module = get_module(
+        ChannelShuffle, key, dev,
+        axis=dim, group=group,
+    )
+    return module.forward(input, out)
+
+
+def chunk(tensor, chunks, dim=0):
+    """Split the input into several parts along the given axis.
+
+    Parameters
+    ----------
+    tensor : dragon.vm.torch.Tensor
+        The input to split.
+    chunks : int
+        The number of chunks to split.
+    dim : int, optional, default=0
+        The dim to split.
+
+    Returns
+    -------
+    sequence of dragon.vm.torch.Tensor
+        The output chunks.
+
+    """
+    dev = MakeDevice([tensor])
+    key = 'Chunk/{}/chunks:{}/dim:{}'.format(dev, chunks, dim)
+    module = get_module(Chunk, key, dev, axis=dim, chunks=chunks)
+    return module.forward(tensor)
+
+
 def index_select(input, dim, index, out=None):
     """Select the input values along the given axis using index.
...
@@ -1047,7 +1106,7 @@ def nonzero(input, out=None):
     Returns
     -------
-    dragon.vm.torch.FloatTensor
+    dragon.vm.torch.Tensor
         The output tensor.

     """
...
@@ -1069,7 +1128,7 @@ def one_hot(input, depth):
     Returns
     -------
-    dragon.vm.torch.FloatTensor
+    dragon.vm.torch.Tensor
         The output tensor.

     """
...
@@ -1137,7 +1196,7 @@ def zeros(*sizes, **kwargs):
     Returns
     -------
-    vm.torch.FloatTensor
+    dragon.vm.torch.Tensor
         The output tensor.

     """
...
@@ -1158,7 +1217,7 @@ def zeros_like(input, out=None, **kwargs):
     Returns
     -------
-    vm.torch.FloatTensor
+    dragon.vm.torch.Tensor
         The output tensor.

     """
...
@@ -1180,7 +1239,7 @@ def ones(*sizes, **kwargs):
     Returns
     -------
-    vm.torch.FloatTensor
+    dragon.vm.torch.Tensor
         The output tensor.

     """
...
@@ -1201,7 +1260,7 @@ def ones_like(input, out=None, **kwargs):
     Returns
     -------
-    vm.torch.FloatTensor
+    dragon.vm.torch.Tensor
         The output tensor.

     """
...
@@ -1223,7 +1282,7 @@ def rand(*sizes, **kwargs):
     Returns
     -------
-    vm.torch.FloatTensor
+    dragon.vm.torch.Tensor
         The output tensor.

     """
...
@@ -1244,7 +1303,7 @@ def randn(*sizes, **kwargs):
     Returns
     -------
-    vm.torch.FloatTensor
+    dragon.vm.torch.Tensor
         The output tensor.

     """
...
Dragon/python/dragon/vm/torch/ops/modules/array.py
View file @ ae11a98
...
@@ -106,6 +106,32 @@ class Stack(BaseModule):
         return self.run(inputs, outputs)


+class Chunk(BaseModule):
+    """This module imports the *SliceOp* from backend.
+
+    Slice the inputs into several parts along the given axis.
+
+    """
+    def __init__(self, key, dev, **kwargs):
+        super(Chunk, self).__init__(key, dev, **kwargs)
+        self.axis = kwargs.get('axis', 0)
+        self.chunks = kwargs.get('chunks', 1)
+        self.register_op()
+
+    def register_op(self):
+        self.op_meta = {
+            'op_type': 'Slice',
+            'arguments': {
+                'axis': self.axis,
+            },
+        }
+
+    def forward(self, x):
+        inputs = [x]; self.unify_devices(inputs)
+        outputs = [self.register_output() for _ in range(self.chunks)]
+        return self.run(inputs, outputs)
+
+
 class IndexSelect(BaseModule):
     """This module imports the *IndexSelectOp* from backend.
...
@@ -315,6 +341,28 @@ class Permute(BaseModule):
         return self.run(inputs, outputs, callback=callback)


+class ChannelShuffle(BaseModule):
+    def __init__(self, key, dev, **kwargs):
+        super(ChannelShuffle, self).__init__(key, dev, **kwargs)
+        self.axis = kwargs.get('axis', 0)
+        self.group = kwargs.get('group', 1)
+        self.register_op()
+
+    def register_op(self):
+        self.op_meta = {
+            'op_type': 'ChannelShuffle',
+            'arguments': {
+                'axis': self.axis,
+                'group': self.group,
+            },
+        }
+
+    def forward(self, x, y):
+        inputs = [x]; self.unify_devices(inputs)
+        outputs = [y] if y else [self.register_output()]
+        return self.run(inputs, outputs)
+
+
 class Repeat(BaseModule):
     def __init__(self, key, dev, **kwargs):
         super(Repeat, self).__init__(key, dev, **kwargs)
...
Dragon/python/dragon/vm/torch/ops/tensor.py
View file @ ae11a98
...
@@ -23,7 +23,7 @@ from dragon.vm.torch.ops.builtin import (
     _fundamental, _rfundamental,
     log, exp, sqrt, clamp,
     _reshape, squeeze, unsqueeze,
-    _permute, _repeat, narrow, _index,
+    _permute, _repeat, chunk, narrow, _index,
     _assign, _masked_assign,
     index_select, masked_select,
     mean, sum, max, min,
@@ -76,6 +76,7 @@ Tensor.view = lambda self, *shape: _reshape(self, shape)
...
@@ -76,6 +76,7 @@ Tensor.view = lambda self, *shape: _reshape(self, shape)
Tensor
.
view_as
=
lambda
*
args
,
**
kwargs
:
_reshape
(
*
args
,
**
kwargs
)
Tensor
.
view_as
=
lambda
*
args
,
**
kwargs
:
_reshape
(
*
args
,
**
kwargs
)
Tensor
.
permute
=
lambda
self
,
*
dims
:
_permute
(
self
,
dims
)
Tensor
.
permute
=
lambda
self
,
*
dims
:
_permute
(
self
,
dims
)
Tensor
.
repeat
=
lambda
self
,
*
args
:
_repeat
(
self
,
args
)
Tensor
.
repeat
=
lambda
self
,
*
args
:
_repeat
(
self
,
args
)
Tensor
.
chunk
=
lambda
*
args
,
**
kwargs
:
chunk
(
*
args
,
**
kwargs
)
Tensor
.
mean
=
lambda
*
args
,
**
kwargs
:
mean
(
*
args
,
**
kwargs
)
Tensor
.
mean
=
lambda
*
args
,
**
kwargs
:
mean
(
*
args
,
**
kwargs
)
Tensor
.
sum
=
lambda
*
args
,
**
kwargs
:
sum
(
*
args
,
**
kwargs
)
Tensor
.
sum
=
lambda
*
args
,
**
kwargs
:
sum
(
*
args
,
**
kwargs
)
Tensor
.
max
=
lambda
*
args
,
**
kwargs
:
max
(
*
args
,
**
kwargs
)
Tensor
.
max
=
lambda
*
args
,
**
kwargs
:
max
(
*
args
,
**
kwargs
)
...
...
Dragon/python/dragon/vm/torch/tensor.py
View file @ ae11a98
...
@@ -355,7 +355,7 @@ class Tensor(object):
         Parameters
         ----------
-        other : dragon.vm.torch.Tensor, number
+        other : dragon.vm.torch.Tensor or number
             The other tensor.

         Returns
...
@@ -371,7 +371,7 @@ class Tensor(object):
         Parameters
         ----------
-        other : dragon.vm.torch.Tensor, number
+        other : dragon.vm.torch.Tensor or number
             The other tensor.

         Returns
...
@@ -387,7 +387,7 @@ class Tensor(object):
         Parameters
         ----------
-        other : dragon.vm.torch.Tensor, number
+        other : dragon.vm.torch.Tensor or number
             The other tensor.

         Returns
...
@@ -403,7 +403,7 @@ class Tensor(object):
         Parameters
         ----------
-        other : dragon.vm.torch.Tensor, number
+        other : dragon.vm.torch.Tensor or number
             The other tensor.

         Returns
...
@@ -419,7 +419,7 @@ class Tensor(object):
         Parameters
         ----------
-        other : dragon.vm.torch.Tensor, number
+        other : dragon.vm.torch.Tensor or number
             The other tensor.

         Returns
...
@@ -847,6 +847,24 @@ class Tensor(object):
         """
         raise NotImplementedError('Refer torch.ops.tensor.repeat')

+    def chunk(self, chunks, dim=0):
+        """Split self into several parts along the given axis.
+
+        Parameters
+        ----------
+        chunks : int
+            The number of chunks to split.
+        dim : int, optional
+            The dim to split.
+
+        Returns
+        -------
+        sequence of dragon.vm.torch.Tensor
+            The output chunks.
+
+        """
+        raise NotImplementedError('Refer torch.ops.tensor.chunk')
+
     def nonzero(self):
         """Return the indices of non-zero elements.
...
Dragon/src/kernels/array/channel_shuffle_op_kernel.cc
0 → 100644
View file @ ae11a98
#include "utils/op_kernel.h"
#include "utils/math_functions.h"
namespace
dragon
{
namespace
kernel
{
/* <T = ?, Device = CPU> */
template
<
typename
T
>
void
_ChannelShuffle
(
const
int
outer_dim
,
const
int
inner_dim
,
const
int
G
,
const
int
K
,
const
T
*
x
,
T
*
y
,
CPUContext
*
ctx
)
{
int64_t
x_ofs
,
y_ofs
;
for
(
int
n
=
0
;
n
<
outer_dim
;
++
n
)
{
for
(
int
gi
=
0
;
gi
<
G
;
++
gi
)
{
for
(
int
ki
=
0
;
ki
<
K
;
++
ki
)
{
x_ofs
=
((
n
*
G
+
gi
)
*
K
+
ki
)
*
inner_dim
;
y_ofs
=
((
n
*
K
+
ki
)
*
G
+
gi
)
*
inner_dim
;
math
::
Copy
(
inner_dim
,
x
+
x_ofs
,
y
+
y_ofs
,
ctx
);
}
}
}
}
/* Kernel Launchers */
#define DEFINE_SHUFFLE_KERNEL_LAUNCHER(T) \
template <> void ChannelShuffle<T, CPUContext>( \
const int outer_dim, \
const int inner_dim, \
const int axis_dim, \
const int group, \
const T* x, \
T* y, \
CPUContext* ctx) { \
_ChannelShuffle( \
outer_dim, \
inner_dim, \
group, \
axis_dim / group, \
x, y, ctx \
); \
}
DEFINE_SHUFFLE_KERNEL_LAUNCHER
(
bool
);
DEFINE_SHUFFLE_KERNEL_LAUNCHER
(
int8_t
);
DEFINE_SHUFFLE_KERNEL_LAUNCHER
(
uint8_t
);
DEFINE_SHUFFLE_KERNEL_LAUNCHER
(
int
);
DEFINE_SHUFFLE_KERNEL_LAUNCHER
(
int64_t
);
DEFINE_SHUFFLE_KERNEL_LAUNCHER
(
float16
);
DEFINE_SHUFFLE_KERNEL_LAUNCHER
(
float
);
DEFINE_SHUFFLE_KERNEL_LAUNCHER
(
double
);
#undef DEFINE_SHUFFLE_KERNEL_LAUNCHER
}
// namespace kernel
}
//
namepsace
dragon
\ No newline at end of file
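The offset arithmetic above is the reshape-transpose-reshape formulation of channel shuffle: the channel axis is viewed as (G, K), swapped to (K, G), and flattened back. A NumPy reference (not part of the commit) that reproduces the same x_ofs/y_ofs mapping:

import numpy as np

def channel_shuffle_ref(x, axis, group):
    G, K = group, x.shape[axis] // group
    y = x.reshape(x.shape[:axis] + (G, K) + x.shape[axis + 1:])
    y = np.swapaxes(y, axis, axis + 1)  # (..., G, K, ...) -> (..., K, G, ...)
    return y.reshape(x.shape)

x = np.arange(2 * 6 * 4).reshape(2, 6, 4)
y = channel_shuffle_ref(x, axis=1, group=3)
# Output channel ki*G + gi reads input channel gi*K + ki, matching
# y_ofs = ((n*K + ki)*G + gi) and x_ofs = ((n*G + gi)*K + ki) above.
assert np.array_equal(y[:, 1], x[:, 2])  # ki=0, gi=1 -> gi*K + ki = 2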
Dragon/src/kernels/array/channel_shuffle_op_kernel.cu
0 → 100644
View file @ ae11a98
#ifdef WITH_CUDA
#include "core/context_cuda.h"
#include "utils/op_kernel.h"
namespace dragon {
namespace kernel {
/* <T = ?, Device = CUDA> */
template <typename T>
__global__ void _ChannelShuffle(
const int nthreads,
const int inner_dim,
const int G,
const int K,
const T* x,
T* y) {
CUDA_1D_KERNEL_LOOP(yi, nthreads) {
const int inner_idx = yi % inner_dim;
const int gi = (yi / inner_dim) % G;
const int ki = (yi / inner_dim / G) % K;
const int outer_idx = yi / inner_dim / G / K;
y[yi] = x[((outer_idx * G + gi) * K + ki
) * inner_dim + inner_idx];
}
}
/* Kernel Launchers */
#define DEFINE_SHUFFLE_KERNEL_LAUNCHER(T) \
template <> void ChannelShuffle<T, CUDAContext>( \
const int outer_dim, \
const int inner_dim, \
const int axis_dim, \
const int group, \
const T* x, \
T* y, \
CUDAContext* ctx) { \
auto nthreads = outer_dim * axis_dim * inner_dim; \
_ChannelShuffle \
<<< CUDA_BLOCKS(nthreads), CUDA_THREADS, \
0, ctx->cuda_stream() >>>( \
nthreads, \
inner_dim, \
group, \
axis_dim / group, \
x, y \
); \
}
DEFINE_SHUFFLE_KERNEL_LAUNCHER(bool);
DEFINE_SHUFFLE_KERNEL_LAUNCHER(int8_t);
DEFINE_SHUFFLE_KERNEL_LAUNCHER(uint8_t);
DEFINE_SHUFFLE_KERNEL_LAUNCHER(int);
DEFINE_SHUFFLE_KERNEL_LAUNCHER(int64_t);
DEFINE_SHUFFLE_KERNEL_LAUNCHER(float16);
DEFINE_SHUFFLE_KERNEL_LAUNCHER(float);
DEFINE_SHUFFLE_KERNEL_LAUNCHER(double);
#undef DEFINE_SHUFFLE_KERNEL_LAUNCHER
} // namespace kernel
} // namespace dragon
#endif // WITH_CUDA
\ No newline at end of file
Dragon/src/operators/array/channel_shuffle_op.cc
0 → 100644
View file @ ae11a98
#include "core/workspace.h"
#include "utils/op_kernel.h"
#include "operators/array/channel_shuffle_op.h"
namespace dragon {

#define DETERMINE_RUNTIME_ARGS(X) \
    axis_ = OpArg<int64_t>("axis", 0); \
    axis_ = axis_ < 0 ? axis_ + X.ndim() : axis_; \
    CHECK(axis_ >= 0 && axis_ < X.ndim()) \
        << "\nExpected the axis in [-" << X.ndim() \
        << ", " << X.ndim() << "), got " \
        << OpArg<int64_t>("axis", 0) << ".";

template <class Context> template <typename T>
void ChannelShuffleOp<Context>::RunImpl() {
    auto* x = X(0).template data<T, Context>();
    auto* y = Y(0)->template mutable_data<T, Context>();
    kernel::ChannelShuffle(
        outer_dim_,
        inner_dim_,
        axis_dim_,
        group_,
        x, y, ctx()
    );
}

template <class Context>
void ChannelShuffleOp<Context>::RunOnDevice() {
    DETERMINE_RUNTIME_ARGS(X(0));
    axis_dim_ = X(0).dim(axis_);
    outer_dim_ = X(0).count(0, axis_);
    inner_dim_ = X(0).count(axis_ + 1);
    CHECK_EQ(axis_dim_ % group_, 0)
        << "\nThe " << axis_dim_ << " channels "
        << "can not be split into " << group_ << " groups.";
    Y(0)->ReshapeLike(X(0));
    DispatchHelper<TensorTypes
        <bool, int8_t, uint8_t, int, int64_t,
         float16, float, double>>::Call(this, X(0));
}

template <class Context> template <typename T>
void ChannelShuffleGradientOp<Context>::RunImpl() {
    auto* dy = X(1).template data<T, Context>();
    auto* dx = Y(0)->template mutable_data<T, Context>();
    kernel::ChannelShuffle(
        outer_dim_,
        inner_dim_,
        axis_dim_,
        axis_dim_ / group_,
        dy, dx, ctx()
    );
}

template <class Context>
void ChannelShuffleGradientOp<Context>::RunOnDevice() {
    DETERMINE_RUNTIME_ARGS(X(0));
    axis_dim_ = X(0).dim(axis_);
    outer_dim_ = X(0).count(0, axis_);
    inner_dim_ = X(0).count(axis_ + 1);
    CHECK_EQ(axis_dim_ % group_, 0)
        << "\nThe " << axis_dim_ << " channels "
        << "can not be split into " << group_ << " groups.";
    Y(0)->ReshapeLike(X(0));
    DispatchHelper<TensorTypes
        <bool, int8_t, uint8_t, int, int64_t,
         float16, float, double>>::Call(this, X(0));
}

DEPLOY_CPU(ChannelShuffle);
#ifdef WITH_CUDA
DEPLOY_CUDA(ChannelShuffle);
#endif

DEPLOY_CPU(ChannelShuffleGradient);
#ifdef WITH_CUDA
DEPLOY_CUDA(ChannelShuffleGradient);
#endif

OPERATOR_SCHEMA(ChannelShuffle)
     /* X */
    .NumInputs(1)
     /* Y */
    .NumOutputs(1);

OPERATOR_SCHEMA(ChannelShuffleGradient)
     /* X, dY */
    .NumInputs(2)
     /* dX */
    .NumOutputs(1);

REGISTER_GRADIENT(ChannelShuffle, SimpleGradientMaker);

#undef DETERMINE_RUNTIME_ARGS

}  // namespace dragon
\ No newline at end of file
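Note how the gradient pass reuses the forward kernel with axis_dim_ / group_ (that is, K) as the group count: shuffling K groups of size G is the exact inverse of shuffling G groups of size K, so no dedicated backward kernel is needed. A quick NumPy check of that identity:

import numpy as np

def shuffle(c, G):  # channel shuffle of a 1-D channel vector
    K = c.shape[0] // G
    return c.reshape(G, K).T.reshape(-1)

c = np.arange(12)
assert np.array_equal(shuffle(shuffle(c, 3), 12 // 3), c)  # K-group shuffle inverts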
Dragon/src/operators/array/slice_op.cc
View file @ ae11a98
...
@@ -12,9 +12,16 @@ namespace dragon {
         << ", " << Input.ndim() << "), got " \
         << OpArg<int64_t>("axis", 0) << "."; \
     if (points_.empty()) { \
-        CHECK_EQ(X(0).dim(axis_) % N_, 0) \
-            << "\nSelected dim is " << X(0).dim(axis_) \
-            << ", can't be sliced into " << N_ << " parts."; \
+        auto dim = (X(0).dim(axis_) + N_ - 1) / N_; \
+        sections_ = vec64_t(N_, dim); \
+        sections_[N_ - 1] = X(0).dim(axis_) - dim * (N_ - 1); \
+    } else { \
+        int64_t slice_ofs = 0; \
+        sections_ = vec64_t(N_); \
+        for (int i = 0; i < N_; i++) \
+            sections_[i] = i < N_ - 1 ? \
+                points_[i] - slice_ofs : \
+                axis_dim_ - slice_ofs; \
     }

 template <class Context> template <typename T>
...
@@ -24,16 +31,11 @@ void SliceOp<Context>::RunImpl() {
     auto* x = X(0).template data<T, Context>();

     for (int i = 0; i < N_; i++) {
-        if (!points_.empty()) {
-            slice_dim_ = i < N_ - 1 ?
-                points_[i] - slice_ofs :
-                axis_dim_ - slice_ofs;
-        }
-        CHECK(slice_dim_ > 0 &&
-            slice_ofs + slice_dim_ <= axis_dim_)
-            << "\nIllegal slicing points: "
-            << Tensor::DimString(points_)
+        slice_dim_ = sections_[i];
+        CHECK(slice_dim_ > 0)
+            << "\nIllegal slicing sections: "
+            << Tensor::DimString(sections_)
             << " for dimension: " << axis_dim_;
         out_shape[axis_] = slice_dim_;
...
@@ -77,16 +79,11 @@ void SliceGradientOp<Context>::RunImpl() {
     auto* dx = Y(0)->template mutable_data<T, Context>();

     for (int i = 0; i < N_; i++) {
-        if (!points_.empty()) {
-            slice_dim_ = i < N_ - 1 ?
-                points_[i] - slice_ofs :
-                axis_dim_ - slice_ofs;
-        }
-        CHECK(slice_dim_ > 0 &&
-            slice_ofs + slice_dim_ <= axis_dim_)
-            << "\nIllegal slicing points: "
-            << Tensor::DimString(points_)
+        slice_dim_ = sections_[i];
+        CHECK(slice_dim_ > 0)
+            << "\nIllegal slicing points: "
+            << Tensor::DimString(sections_)
             << " for dimension: " << axis_dim_;
         if (X(i + 1).name() != "NULL") {
...
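For reference, the new section logic in the macro is a ceil-divide even split: the first N-1 outputs get ceil(dim / N) elements and the last output takes the remainder, so chunk() no longer requires the sliced dimension to divide evenly. The same arithmetic in Python (a sketch of the computation, not Dragon code):

def split_sections(dim, n):
    size = (dim + n - 1) // n            # ceil division
    sections = [size] * n
    sections[-1] = dim - size * (n - 1)  # remainder for the last chunk
    return sections

print(split_sections(10, 3))  # [4, 4, 2]
print(split_sections(4, 3))   # [2, 2, 0] -- rejected by CHECK(slice_dim_ > 0)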