Commit b4019faa by Ting PAN

Add Sort Operator

Summary:
This commit adds the sort op for dragon, torch and tensorflow.
Besides, cuda implementation of topk op is now available.
1 parent fdf26ef2
Showing with 869 additions and 277 deletions
...@@ -27,11 +27,11 @@ vm.dali ...@@ -27,11 +27,11 @@ vm.dali
.. toctree:: .. toctree::
:hidden: :hidden:
dali/Iterator
dali/Pipeline
dali/device dali/device
dali/get_device_type dali/get_device_type
dali/get_distributed_info dali/get_distributed_info
dali/Iterator
dali/Pipeline
.. raw:: html .. raw:: html
......
...@@ -138,6 +138,9 @@ dragon ...@@ -138,6 +138,9 @@ dragon
`slice(...) <dragon/slice.html>`_ `slice(...) <dragon/slice.html>`_
: Select the elements according to the given sections. : Select the elements according to the given sections.
`sort(...) <dragon/sort.html>`_
: Return the sorted elements along the given axis.
`split(...) <dragon/split.html>`_ `split(...) <dragon/split.html>`_
: Split the input into chunks along the given axis. : Split the input into chunks along the given axis.
...@@ -171,6 +174,10 @@ dragon ...@@ -171,6 +174,10 @@ dragon
.. toctree:: .. toctree::
:hidden: :hidden:
dragon/EagerTensor
dragon/GradientTape
dragon/Tensor
dragon/Workspace
dragon/assign dragon/assign
dragon/broadcast_to dragon/broadcast_to
dragon/cast dragon/cast
...@@ -182,7 +189,6 @@ dragon ...@@ -182,7 +189,6 @@ dragon
dragon/copy dragon/copy
dragon/create_function dragon/create_function
dragon/device dragon/device
dragon/EagerTensor
dragon/eager_mode dragon/eager_mode
dragon/eager_scope dragon/eager_scope
dragon/expand_dims dragon/expand_dims
...@@ -193,7 +199,6 @@ dragon ...@@ -193,7 +199,6 @@ dragon
dragon/function dragon/function
dragon/get_workspace dragon/get_workspace
dragon/gradients dragon/gradients
dragon/GradientTape
dragon/graph_mode dragon/graph_mode
dragon/index_select dragon/index_select
dragon/load_library dragon/load_library
...@@ -212,16 +217,15 @@ dragon ...@@ -212,16 +217,15 @@ dragon
dragon/reshape dragon/reshape
dragon/shape dragon/shape
dragon/slice dragon/slice
dragon/sort
dragon/split dragon/split
dragon/squeeze dragon/squeeze
dragon/stack dragon/stack
dragon/stop_gradient dragon/stop_gradient
dragon/Tensor
dragon/tile dragon/tile
dragon/transpose dragon/transpose
dragon/unique dragon/unique
dragon/where dragon/where
dragon/Workspace
dragon/zeros dragon/zeros
dragon/zeros_like dragon/zeros_like
......
...@@ -113,6 +113,9 @@ dragon.nn ...@@ -113,6 +113,9 @@ dragon.nn
.. toctree:: .. toctree::
:hidden: :hidden:
nn/GRU
nn/LSTM
nn/RNN
nn/batch_norm nn/batch_norm
nn/bias_add nn/bias_add
nn/conv2d nn/conv2d
...@@ -125,18 +128,15 @@ dragon.nn ...@@ -125,18 +128,15 @@ dragon.nn
nn/elu nn/elu
nn/fully_connected nn/fully_connected
nn/group_norm nn/group_norm
nn/GRU
nn/instance_norm nn/instance_norm
nn/layer_norm nn/layer_norm
nn/leaky_relu nn/leaky_relu
nn/local_response_norm nn/local_response_norm
nn/log_softmax nn/log_softmax
nn/LSTM
nn/pool2d nn/pool2d
nn/prelu nn/prelu
nn/relu nn/relu
nn/relu6 nn/relu6
nn/RNN
nn/selu nn/selu
nn/softmax nn/softmax
nn/space_to_depth nn/space_to_depth
......
sort
====
.. autofunction:: dragon.sort
.. raw:: html
<style>
h1:before {
content: "dragon.";
color: #103d3e;
}
</style>
...@@ -18,6 +18,9 @@ vm.tensorflow ...@@ -18,6 +18,9 @@ vm.tensorflow
Functions Functions
######### #########
`argsort(...) <tensorflow/argsort.html>`_
: Return the index of sorted elements along the given axis.
`broadcast_to(...) <dragon/broadcast_to.html>`_ `broadcast_to(...) <dragon/broadcast_to.html>`_
: Broadcast input according to a given shape. : Broadcast input according to a given shape.
...@@ -84,6 +87,9 @@ vm.tensorflow ...@@ -84,6 +87,9 @@ vm.tensorflow
`slice(...) <tensorflow/slice.html>`_ `slice(...) <tensorflow/slice.html>`_
: Select the elements according to the given sections. : Select the elements according to the given sections.
`sort(...) <tensorflow/sort.html>`_
: Return the sorted elements along the given axis.
`split(...) <tensorflow/split.html>`_ `split(...) <tensorflow/split.html>`_
: Split input into chunks along the given axis. : Split input into chunks along the given axis.
...@@ -108,6 +114,10 @@ vm.tensorflow ...@@ -108,6 +114,10 @@ vm.tensorflow
.. toctree:: .. toctree::
:hidden: :hidden:
tensorflow/GradientTape
tensorflow/TensorShape
tensorflow/TensorSpec
tensorflow/argsort
tensorflow/broadcast_to tensorflow/broadcast_to
tensorflow/cast tensorflow/cast
tensorflow/clip_by_value tensorflow/clip_by_value
...@@ -120,7 +130,6 @@ vm.tensorflow ...@@ -120,7 +130,6 @@ vm.tensorflow
tensorflow/function tensorflow/function
tensorflow/gather tensorflow/gather
tensorflow/gradients tensorflow/gradients
tensorflow/GradientTape
tensorflow/identity tensorflow/identity
tensorflow/name_scope tensorflow/name_scope
tensorflow/ones tensorflow/ones
...@@ -131,10 +140,9 @@ vm.tensorflow ...@@ -131,10 +140,9 @@ vm.tensorflow
tensorflow/reshape tensorflow/reshape
tensorflow/shape tensorflow/shape
tensorflow/slice tensorflow/slice
tensorflow/sort
tensorflow/split tensorflow/split
tensorflow/squeeze tensorflow/squeeze
tensorflow/TensorShape
tensorflow/TensorSpec
tensorflow/transpose tensorflow/transpose
tensorflow/unique tensorflow/unique
tensorflow/unique_with_counts tensorflow/unique_with_counts
......
argsort
=======
.. autofunction:: dragon.vm.tensorflow.argsort
.. raw:: html
<style>
h1:before {
content: "tf.";
color: #103d3e;
}
</style>
...@@ -18,8 +18,8 @@ vm.tensorflow.dtypes ...@@ -18,8 +18,8 @@ vm.tensorflow.dtypes
.. toctree:: .. toctree::
:hidden: :hidden:
dtypes/as_dtype
dtypes/DType dtypes/DType
dtypes/as_dtype
.. raw:: html .. raw:: html
......
...@@ -46,7 +46,6 @@ initializers ...@@ -46,7 +46,6 @@ initializers
:hidden: :hidden:
initializers/Constant initializers/Constant
initializers/get
initializers/GlorotNormal initializers/GlorotNormal
initializers/GlorotUniform initializers/GlorotUniform
initializers/Initializer initializers/Initializer
...@@ -56,6 +55,7 @@ initializers ...@@ -56,6 +55,7 @@ initializers
initializers/TruncatedNormal initializers/TruncatedNormal
initializers/VarianceScaling initializers/VarianceScaling
initializers/Zeros initializers/Zeros
initializers/get
.. raw:: html .. raw:: html
......
...@@ -49,16 +49,16 @@ losses ...@@ -49,16 +49,16 @@ losses
:hidden: :hidden:
losses/BinaryCrossentropy losses/BinaryCrossentropy
losses/binary_crossentropy
losses/CategoricalCrossentropy losses/CategoricalCrossentropy
losses/categorical_crossentropy
losses/get
losses/Loss losses/Loss
losses/MeanAbsoluteError losses/MeanAbsoluteError
losses/MeanSquaredError losses/MeanSquaredError
losses/SparseCategoricalCrossentropy
losses/binary_crossentropy
losses/categorical_crossentropy
losses/get
losses/mean_absolute_error losses/mean_absolute_error
losses/mean_squared_error losses/mean_squared_error
losses/SparseCategoricalCrossentropy
losses/sparse_categorical_crossentropy losses/sparse_categorical_crossentropy
.. raw:: html .. raw:: html
......
...@@ -30,12 +30,12 @@ regularizers ...@@ -30,12 +30,12 @@ regularizers
.. toctree:: .. toctree::
:hidden: :hidden:
regularizers/get
regularizers/L1 regularizers/L1
regularizers/L1L2 regularizers/L1L2
regularizers/l1_l2
regularizers/L2 regularizers/L2
regularizers/Regularizer regularizers/Regularizer
regularizers/get
regularizers/l1_l2
.. raw:: html .. raw:: html
......
sort
====
.. autofunction:: dragon.vm.tensorflow.sort
.. raw:: html
<style>
h1:before {
content: "tf.";
color: #103d3e;
}
</style>
...@@ -214,6 +214,9 @@ vm.torch ...@@ -214,6 +214,9 @@ vm.torch
`sin(...) <torch/sin.html>`_ `sin(...) <torch/sin.html>`_
: Compute the sin of input. : Compute the sin of input.
`sort(...) <torch/sort.html>`_
: Return the sorted elements along the given dimension.
`split(...) <torch/split.html>`_ `split(...) <torch/split.html>`_
: Split input into chunks along the given dimension. : Split input into chunks along the given dimension.
...@@ -256,6 +259,8 @@ vm.torch ...@@ -256,6 +259,8 @@ vm.torch
.. toctree:: .. toctree::
:hidden: :hidden:
torch/Size
torch/Tensor_
torch/abs torch/abs
torch/add torch/add
torch/arange torch/arange
...@@ -322,14 +327,13 @@ vm.torch ...@@ -322,14 +327,13 @@ vm.torch
torch/set_grad_enabled torch/set_grad_enabled
torch/sign torch/sign
torch/sin torch/sin
torch/Size torch/sort
torch/split torch/split
torch/sqrt torch/sqrt
torch/squeeze torch/squeeze
torch/stack torch/stack
torch/sub torch/sub
torch/sum torch/sum
torch/Tensor_
torch/tensor torch/tensor
torch/topk torch/topk
torch/unique torch/unique
......
...@@ -397,6 +397,10 @@ size ...@@ -397,6 +397,10 @@ size
#### ####
.. automethod:: dragon.vm.torch.Tensor.size .. automethod:: dragon.vm.torch.Tensor.size
sort
####
.. automethod:: dragon.vm.torch.Tensor.sort
sqrt sqrt
#### ####
.. automethod:: dragon.vm.torch.Tensor.sqrt .. automethod:: dragon.vm.torch.Tensor.sqrt
...@@ -503,6 +507,7 @@ zero\_ ...@@ -503,6 +507,7 @@ zero\_
.. _torch.rsqrt(...): rsqrt.html .. _torch.rsqrt(...): rsqrt.html
.. _torch.sign(...): sign.html .. _torch.sign(...): sign.html
.. _torch.sin(...): sin.html .. _torch.sin(...): sin.html
.. _torch.sort(...): sort.html
.. _torch.sqrt(...): sqrt.html .. _torch.sqrt(...): sqrt.html
.. _torch.squeeze(...): squeeze.html .. _torch.squeeze(...): squeeze.html
.. _torch.sub(...): sub.html .. _torch.sub(...): sub.html
......
...@@ -18,8 +18,8 @@ vm.torch.autograd ...@@ -18,8 +18,8 @@ vm.torch.autograd
.. toctree:: .. toctree::
:hidden: :hidden:
autograd/backward
autograd/Function autograd/Function
autograd/backward
.. raw:: html .. raw:: html
......
sort
====
.. autofunction:: dragon.vm.torch.sort
.. raw:: html
<style>
h1:before {
content: "torch.";
color: #103d3e;
}
</style>
...@@ -12,14 +12,14 @@ void _IndexSelect( ...@@ -12,14 +12,14 @@ void _IndexSelect(
const int outer_dim, const int outer_dim,
const int inner_dim, const int inner_dim,
const int axis_dim, const int axis_dim,
const int num_indices, const int select_dim,
const int64_t* indices, const int64_t* indices,
const T* x, const T* x,
T* y, T* y,
CPUContext* ctx) { CPUContext* ctx) {
int index; int index;
for (int i = 0; i < outer_dim; ++i) { for (int i = 0; i < outer_dim; ++i) {
for (int j = 0; j < num_indices; ++j) { for (int j = 0; j < select_dim; ++j) {
index = indices[j]; index = indices[j];
index = index >= 0 ? index : index + axis_dim; index = index >= 0 ? index : index + axis_dim;
const T* offset_x = x + (i * axis_dim + index) * inner_dim; const T* offset_x = x + (i * axis_dim + index) * inner_dim;
...@@ -34,14 +34,14 @@ void _IndexSelectGrad( ...@@ -34,14 +34,14 @@ void _IndexSelectGrad(
const int outer_dim, const int outer_dim,
const int inner_dim, const int inner_dim,
const int axis_dim, const int axis_dim,
const int num_indices, const int select_dim,
const int64_t* indices, const int64_t* indices,
const T* dy, const T* dy,
T* dx, T* dx,
CPUContext* ctx) { CPUContext* ctx) {
int index; int index;
for (int i = 0; i < outer_dim; ++i) { for (int i = 0; i < outer_dim; ++i) {
for (int j = 0; j < num_indices; ++j) { for (int j = 0; j < select_dim; ++j) {
index = indices[j]; index = indices[j];
index = index >= 0 ? index : index + axis_dim; index = index >= 0 ? index : index + axis_dim;
T* offset_dx = dx + (i * axis_dim + index) * inner_dim; T* offset_dx = dx + (i * axis_dim + index) * inner_dim;
...@@ -55,18 +55,18 @@ void _IndexSelectGrad( ...@@ -55,18 +55,18 @@ void _IndexSelectGrad(
/* ------------------- Launcher Separator ------------------- */ /* ------------------- Launcher Separator ------------------- */
#define DEFINE_KERNEL_LAUNCHER(name, T) \ #define DEFINE_KERNEL_LAUNCHER(name, T) \
template <> \ template <> \
void name<T, CPUContext>( \ void name<T, CPUContext>( \
const int outer_dim, \ const int outer_dim, \
const int inner_dim, \ const int inner_dim, \
const int axis_dim, \ const int axis_dim, \
const int num_indices, \ const int select_dim, \
const int64_t* indices, \ const int64_t* indices, \
const T* x, \ const T* x, \
T* y, \ T* y, \
CPUContext* ctx) { \ CPUContext* ctx) { \
_##name(outer_dim, inner_dim, axis_dim, num_indices, indices, x, y, ctx); \ _##name(outer_dim, inner_dim, axis_dim, select_dim, indices, x, y, ctx); \
} }
DEFINE_KERNEL_LAUNCHER(IndexSelect, bool); DEFINE_KERNEL_LAUNCHER(IndexSelect, bool);
......
...@@ -14,17 +14,17 @@ __global__ void _IndexSelect( ...@@ -14,17 +14,17 @@ __global__ void _IndexSelect(
const int nthreads, const int nthreads,
const int inner_dim, const int inner_dim,
const int axis_dim, const int axis_dim,
const int num_indices, const int select_dim,
const int64_t* indices, const int64_t* indices,
const T* x, const T* x,
T* y) { T* y) {
CUDA_1D_KERNEL_LOOP(yi, nthreads) { CUDA_1D_KERNEL_LOOP(yi, nthreads) {
const int j = yi % inner_dim; const int j = yi % inner_dim;
const int i = yi / inner_dim / num_indices; const int i = yi / inner_dim / select_dim;
#if __CUDA_ARCH__ >= 350 #if __CUDA_ARCH__ >= 350
int index = __ldg(indices + ((yi / inner_dim) % num_indices)); int index = __ldg(indices + ((yi / inner_dim) % select_dim));
#else #else
int index = indices[(yi / inner_dim) % num_indices]; int index = indices[(yi / inner_dim) % select_dim];
#endif #endif
index = index >= 0 ? index : index + axis_dim; index = index >= 0 ? index : index + axis_dim;
y[yi] = x[(i * axis_dim + index) * inner_dim + j]; y[yi] = x[(i * axis_dim + index) * inner_dim + j];
...@@ -36,7 +36,7 @@ __global__ void _IndexSelectGrad( ...@@ -36,7 +36,7 @@ __global__ void _IndexSelectGrad(
const int nthreads, const int nthreads,
const int inner_dim, const int inner_dim,
const int axis_dim, const int axis_dim,
const int num_indices, const int select_dim,
const int64_t* indices, const int64_t* indices,
const T* dy, const T* dy,
T* dx) { T* dx) {
...@@ -44,8 +44,8 @@ __global__ void _IndexSelectGrad( ...@@ -44,8 +44,8 @@ __global__ void _IndexSelectGrad(
const int i = ti / inner_dim; const int i = ti / inner_dim;
const int j = ti % inner_dim; const int j = ti % inner_dim;
const int c = i * axis_dim * inner_dim + j; const int c = i * axis_dim * inner_dim + j;
const T* offset_dy = dy + i * num_indices * inner_dim + j; const T* offset_dy = dy + i * select_dim * inner_dim + j;
for (int k = 0; k < num_indices; ++k) { for (int k = 0; k < select_dim; ++k) {
#if __CUDA_ARCH__ >= 350 #if __CUDA_ARCH__ >= 350
int index = __ldg(indices + k); int index = __ldg(indices + k);
#else #else
...@@ -63,7 +63,7 @@ __global__ void _IndexSelectGrad<half>( ...@@ -63,7 +63,7 @@ __global__ void _IndexSelectGrad<half>(
const int nthreads, const int nthreads,
const int inner_dim, const int inner_dim,
const int axis_dim, const int axis_dim,
const int num_indices, const int select_dim,
const int64_t* indices, const int64_t* indices,
const half* dy, const half* dy,
half* dx) { half* dx) {
...@@ -72,8 +72,8 @@ __global__ void _IndexSelectGrad<half>( ...@@ -72,8 +72,8 @@ __global__ void _IndexSelectGrad<half>(
const int i = ti / inner_dim; const int i = ti / inner_dim;
const int j = ti % inner_dim; const int j = ti % inner_dim;
const int c = i * axis_dim * inner_dim + j; const int c = i * axis_dim * inner_dim + j;
const half* offset_dy = dy + i * num_indices * inner_dim + j; const half* offset_dy = dy + i * select_dim * inner_dim + j;
for (int k = 0; k < num_indices; ++k) { for (int k = 0; k < select_dim; ++k) {
int index = __ldg(indices + j); int index = __ldg(indices + j);
index = index >= 0 ? index : index + axis_dim; index = index >= 0 ? index : index + axis_dim;
index = c + index * inner_dim; index = c + index * inner_dim;
...@@ -93,7 +93,7 @@ void IndexSelectGrad<float16, CUDAContext>( ...@@ -93,7 +93,7 @@ void IndexSelectGrad<float16, CUDAContext>(
const int outer_dim, const int outer_dim,
const int inner_dim, const int inner_dim,
const int axis_dim, const int axis_dim,
const int num_indices, const int select_dim,
const int64_t* indices, const int64_t* indices,
const float16* dy, const float16* dy,
float16* dx, float16* dx,
...@@ -107,50 +107,50 @@ void IndexSelectGrad<float16, CUDAContext>( ...@@ -107,50 +107,50 @@ void IndexSelectGrad<float16, CUDAContext>(
nthreads, nthreads,
inner_dim, inner_dim,
axis_dim, axis_dim,
num_indices, select_dim,
indices, indices,
reinterpret_cast<const half*>(dy), reinterpret_cast<const half*>(dy),
reinterpret_cast<half*>(dx)); reinterpret_cast<half*>(dx));
} // IndexSelectGrad } // IndexSelectGrad
#define DEFINE_KERNEL_LAUNCHER(T) \ #define DEFINE_KERNEL_LAUNCHER(T) \
template <> \ template <> \
void IndexSelect<T, CUDAContext>( \ void IndexSelect<T, CUDAContext>( \
const int outer_dim, \ const int outer_dim, \
const int inner_dim, \ const int inner_dim, \
const int axis_dim, \ const int axis_dim, \
const int num_indices, \ const int select_dim, \
const int64_t* indices, \ const int64_t* indices, \
const T* x, \ const T* x, \
T* y, \ T* y, \
CUDAContext* ctx) { \ CUDAContext* ctx) { \
const int nthreads = outer_dim * num_indices * inner_dim; \ const int nthreads = outer_dim * select_dim * inner_dim; \
_IndexSelect<<< \ _IndexSelect<<< \
CUDA_BLOCKS(nthreads), \ CUDA_BLOCKS(nthreads), \
CUDA_THREADS, \ CUDA_THREADS, \
0, \ 0, \
ctx->cuda_stream()>>>( \ ctx->cuda_stream()>>>( \
nthreads, inner_dim, axis_dim, num_indices, indices, x, y); \ nthreads, inner_dim, axis_dim, select_dim, indices, x, y); \
} }
#define DEFINE_GRAD_KERNEL_LAUNCHER(T) \ #define DEFINE_GRAD_KERNEL_LAUNCHER(T) \
template <> \ template <> \
void IndexSelectGrad<T, CUDAContext>( \ void IndexSelectGrad<T, CUDAContext>( \
const int outer_dim, \ const int outer_dim, \
const int inner_dim, \ const int inner_dim, \
const int axis_dim, \ const int axis_dim, \
const int num_indices, \ const int select_dim, \
const int64_t* indices, \ const int64_t* indices, \
const T* dy, \ const T* dy, \
T* dx, \ T* dx, \
CUDAContext* ctx) { \ CUDAContext* ctx) { \
const int nthreads = outer_dim * inner_dim; \ const int nthreads = outer_dim * inner_dim; \
_IndexSelectGrad<<< \ _IndexSelectGrad<<< \
CUDA_BLOCKS(nthreads), \ CUDA_BLOCKS(nthreads), \
CUDA_THREADS, \ CUDA_THREADS, \
0, \ 0, \
ctx->cuda_stream()>>>( \ ctx->cuda_stream()>>>( \
nthreads, inner_dim, axis_dim, num_indices, indices, dy, dx); \ nthreads, inner_dim, axis_dim, select_dim, indices, dy, dx); \
} }
DEFINE_KERNEL_LAUNCHER(bool); DEFINE_KERNEL_LAUNCHER(bool);
......
...@@ -25,11 +25,11 @@ struct SmallestComp { ...@@ -25,11 +25,11 @@ struct SmallestComp {
}; };
template <typename T, class Comp> template <typename T, class Comp>
void _TopK( void _TopSelect(
const int outer_dim, const int outer_dim,
const int inner_dim, const int inner_dim,
const int axis_dim, const int axis_dim,
const int top_k, const int select_dim,
const int largest, const int largest,
const T* x, const T* x,
T* value, T* value,
...@@ -38,8 +38,8 @@ void _TopK( ...@@ -38,8 +38,8 @@ void _TopK(
for (int j = 0; j < inner_dim; ++j) { for (int j = 0; j < inner_dim; ++j) {
auto* offset_x = x + (i * axis_dim * inner_dim + j); auto* offset_x = x + (i * axis_dim * inner_dim + j);
vector<std::pair<T, int64_t>> head_data; vector<std::pair<T, int64_t>> head_data;
head_data.reserve(top_k); head_data.reserve(select_dim);
for (int k = 0; k < top_k && k < axis_dim; ++k) { for (int k = 0; k < select_dim && k < axis_dim; ++k) {
head_data.emplace_back(*offset_x, k); head_data.emplace_back(*offset_x, k);
offset_x += inner_dim; offset_x += inner_dim;
} }
...@@ -49,7 +49,7 @@ void _TopK( ...@@ -49,7 +49,7 @@ void _TopK(
Comp> Comp>
pq(Comp(), std::move(head_data)); pq(Comp(), std::move(head_data));
if (largest > 0) { if (largest > 0) {
for (int k = top_k; k < axis_dim; ++k) { for (int k = select_dim; k < axis_dim; ++k) {
if (pq.top().first < *offset_x) { if (pq.top().first < *offset_x) {
pq.pop(); pq.pop();
pq.emplace(*offset_x, k); pq.emplace(*offset_x, k);
...@@ -57,7 +57,7 @@ void _TopK( ...@@ -57,7 +57,7 @@ void _TopK(
offset_x += inner_dim; offset_x += inner_dim;
} }
} else { } else {
for (int k = top_k; k < axis_dim; ++k) { for (int k = select_dim; k < axis_dim; ++k) {
if (pq.top().first > *offset_x) { if (pq.top().first > *offset_x) {
pq.pop(); pq.pop();
pq.emplace(*offset_x, k); pq.emplace(*offset_x, k);
...@@ -65,7 +65,8 @@ void _TopK( ...@@ -65,7 +65,8 @@ void _TopK(
offset_x += inner_dim; offset_x += inner_dim;
} }
} }
auto y_offset = i * top_k * inner_dim + j + (top_k - 1) * inner_dim; auto y_offset =
i * select_dim * inner_dim + j + (select_dim - 1) * inner_dim;
while (!pq.empty()) { while (!pq.empty()) {
const auto& p = pq.top(); const auto& p = pq.top();
value[y_offset] = p.first; value[y_offset] = p.first;
...@@ -82,11 +83,11 @@ void _TopK( ...@@ -82,11 +83,11 @@ void _TopK(
/* ------------------- Launcher Separator ------------------- */ /* ------------------- Launcher Separator ------------------- */
template <> template <>
void TopK<float16, CPUContext>( void TopSelect<float16, CPUContext>(
const int outer_dim, const int outer_dim,
const int inner_dim, const int inner_dim,
const int axis_dim, const int axis_dim,
const int top_k, const int select_dim,
const int largest, const int largest,
const float16* x, const float16* x,
float16* value, float16* value,
...@@ -95,25 +96,39 @@ void TopK<float16, CPUContext>( ...@@ -95,25 +96,39 @@ void TopK<float16, CPUContext>(
CPU_FP16_NOT_SUPPORTED; CPU_FP16_NOT_SUPPORTED;
} }
#define DEFINE_KERNEL_LAUNCHER(T) \ #define DEFINE_KERNEL_LAUNCHER(T) \
template <> \ template <> \
void TopK<T, CPUContext>( \ void TopSelect<T, CPUContext>( \
const int outer_dim, \ const int outer_dim, \
const int inner_dim, \ const int inner_dim, \
const int axis_dim, \ const int axis_dim, \
const int top_k, \ const int select_dim, \
const int largest, \ const int largest, \
const T* x, \ const T* x, \
T* value, \ T* value, \
int64_t* index, \ int64_t* index, \
CPUContext* ctx) { \ CPUContext* ctx) { \
if (largest > 0) { \ if (largest > 0) { \
_TopK<T, LargestComp<T>>( \ _TopSelect<T, LargestComp<T>>( \
outer_dim, inner_dim, axis_dim, top_k, largest, x, value, index); \ outer_dim, \
} else { \ inner_dim, \
_TopK<T, SmallestComp<T>>( \ axis_dim, \
outer_dim, inner_dim, axis_dim, top_k, largest, x, value, index); \ select_dim, \
} \ largest, \
x, \
value, \
index); \
} else { \
_TopSelect<T, SmallestComp<T>>( \
outer_dim, \
inner_dim, \
axis_dim, \
select_dim, \
largest, \
x, \
value, \
index); \
} \
} }
DEFINE_KERNEL_LAUNCHER(int8_t); DEFINE_KERNEL_LAUNCHER(int8_t);
......
...@@ -40,18 +40,6 @@ __global__ void _ComputeCounts( ...@@ -40,18 +40,6 @@ __global__ void _ComputeCounts(
/* ------------------- Launcher Separator ------------------- */ /* ------------------- Launcher Separator ------------------- */
template <>
void Unique<float16, CUDAContext>(
const int dim,
const float16* x,
float16* y,
int64_t* inverse_index,
int64_t* counts,
int* num,
CUDAContext* ctx) {
LOG(FATAL) << "FP16 is unsupported for CUDAContext.";
}
#define DEFINE_KERNEL_LAUNCHER(T) \ #define DEFINE_KERNEL_LAUNCHER(T) \
template <> \ template <> \
void Unique<T, CUDAContext>( \ void Unique<T, CUDAContext>( \
...@@ -67,8 +55,10 @@ void Unique<float16, CUDAContext>( ...@@ -67,8 +55,10 @@ void Unique<float16, CUDAContext>(
thrust::device_vector<int> order1(dim), order2(dim); \ thrust::device_vector<int> order1(dim), order2(dim); \
thrust::sequence(policy, order1.begin(), order1.end()); \ thrust::sequence(policy, order1.begin(), order1.end()); \
thrust::sequence(policy, order2.begin(), order2.end()); \ thrust::sequence(policy, order2.begin(), order2.end()); \
thrust::sort_by_key(policy, y, y + dim, order1.begin()); \ thrust::sort_by_key( \
auto last = thrust::unique_by_key(policy, y, y + dim, order2.begin()); \ policy, y, y + dim, order1.begin(), math::LessFunctor<T>()); \
auto last = thrust::unique_by_key( \
policy, y, y + dim, order2.begin(), math::EqualFunctor<T>()); \
int n = num[0] = last.first - y; \ int n = num[0] = last.first - y; \
if (inverse_index) { \ if (inverse_index) { \
_RemapInverse<<<CUDA_BLOCKS(n), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \ _RemapInverse<<<CUDA_BLOCKS(n), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
...@@ -84,6 +74,7 @@ DEFINE_KERNEL_LAUNCHER(int8_t); ...@@ -84,6 +74,7 @@ DEFINE_KERNEL_LAUNCHER(int8_t);
DEFINE_KERNEL_LAUNCHER(uint8_t); DEFINE_KERNEL_LAUNCHER(uint8_t);
DEFINE_KERNEL_LAUNCHER(int); DEFINE_KERNEL_LAUNCHER(int);
DEFINE_KERNEL_LAUNCHER(int64_t); DEFINE_KERNEL_LAUNCHER(int64_t);
DEFINE_KERNEL_LAUNCHER(float16);
DEFINE_KERNEL_LAUNCHER(float); DEFINE_KERNEL_LAUNCHER(float);
DEFINE_KERNEL_LAUNCHER(double); DEFINE_KERNEL_LAUNCHER(double);
#undef DEFINE_KERNEL_LAUNCHER #undef DEFINE_KERNEL_LAUNCHER
......
#include "dragon/operators/array/sort_op.h"
#include "dragon/utils/op_kernels.h"
namespace dragon {
template <class Context>
template <typename T>
void SortOp<Context>::DoRunWithType() {
auto &X = Input(0), *Y_value = Output(0), *Y_index = Output(1);
CANONICALIZE_AXIS_WITH_TENSOR(X);
axis = (axis == INT_MAX ? X.ndim() - 1 : axis);
kernel::TopSelect(
X.count(0, axis),
X.count(axis + 1),
X.dim(axis),
X.dim(axis),
descending_ > 0 ? 1 : 0,
X.template data<T, Context>(),
Y_value->ReshapeLike(X)->template mutable_data<T, Context>(),
Y_index->ReshapeLike(X)->template mutable_data<int64_t, Context>(),
ctx());
}
template <class Context>
void SortOp<Context>::RunOnDevice() {
DispatchHelper<NumericalTensorTypes>::Call(this, Input(0));
}
DEPLOY_CPU_OPERATOR(Sort);
#ifdef USE_CUDA
DEPLOY_CUDA_OPERATOR(Sort);
#endif
OPERATOR_SCHEMA(Sort)
/* X */
.NumInputs(1)
/* Value, Index */
.NumOutputs(2);
NO_GRADIENT(Sort);
} // namespace dragon
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_ARRAY_SORT_OP_H_
#define DRAGON_OPERATORS_ARRAY_SORT_OP_H_
#include "dragon/core/operator.h"
namespace dragon {
template <class Context>
class SortOp final : public Operator<Context> {
public:
SortOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
descending_(OP_SINGLE_ARG(int64_t, "descending", 0)) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T>
void DoRunWithType();
protected:
int64_t descending_;
};
} // namespace dragon
#endif // DRAGON_OPERATORS_ARRAY_SORT_OP_H_
...@@ -16,17 +16,16 @@ void TopKOp<Context>::DoRunWithType() { ...@@ -16,17 +16,16 @@ void TopKOp<Context>::DoRunWithType() {
auto Y_dims = X.dims(); auto Y_dims = X.dims();
Y_dims[axis] = k_; Y_dims[axis] = k_;
CPUContext cpu_ctx; kernel::TopSelect(
kernel::TopK(
X.count(0, axis), X.count(0, axis),
X.count(axis + 1), X.count(axis + 1),
X.dim(axis), X.dim(axis),
k_, k_,
largest_, largest_,
X.template data<T, CPUContext>(), X.template data<T, Context>(),
Y_value->Reshape(Y_dims)->template mutable_data<T, CPUContext>(), Y_value->Reshape(Y_dims)->template mutable_data<T, Context>(),
Y_index->Reshape(Y_dims)->template mutable_data<int64_t, CPUContext>(), Y_index->Reshape(Y_dims)->template mutable_data<int64_t, Context>(),
&cpu_ctx); ctx());
} }
template <class Context> template <class Context>
......
...@@ -105,12 +105,10 @@ void L1LossGradientOp<Context>::DoRunWithType() { ...@@ -105,12 +105,10 @@ void L1LossGradientOp<Context>::DoRunWithType() {
// Gradient w.r.t. the second input // Gradient w.r.t. the second input
if (OutputSize() > 1 && Output(1)->has_name()) { if (OutputSize() > 1 && Output(1)->has_name()) {
Output(1)->ReshapeLike(Input(1)); math::Neg(
math::Scale(
dX->count(), dX->count(),
-1.f,
dx, dx,
Output(1)->template mutable_data<T, Context>(), Output(1)->ReshapeLike(Input(1))->template mutable_data<T, Context>(),
ctx()); ctx());
} }
} }
......
...@@ -103,12 +103,10 @@ void L2LossGradientOp<Context>::DoRunWithType() { ...@@ -103,12 +103,10 @@ void L2LossGradientOp<Context>::DoRunWithType() {
// Gradient w.r.t. the second input // Gradient w.r.t. the second input
if (OutputSize() > 1 && Output(1)->has_name()) { if (OutputSize() > 1 && Output(1)->has_name()) {
Output(1)->ReshapeLike(Input(1)); math::Neg(
math::Scale(
dX->count(), dX->count(),
-1.f,
dx, dx,
Output(1)->template mutable_data<T, Context>(), Output(1)->ReshapeLike(Input(1))->template mutable_data<T, Context>(),
ctx()); ctx());
} }
} }
......
...@@ -105,12 +105,10 @@ void SmoothL1LossGradientOp<Context>::DoRunWithType() { ...@@ -105,12 +105,10 @@ void SmoothL1LossGradientOp<Context>::DoRunWithType() {
// Gradient w.r.t. the second input // Gradient w.r.t. the second input
if (OutputSize() > 1 && Output(1)->has_name()) { if (OutputSize() > 1 && Output(1)->has_name()) {
Output(1)->ReshapeLike(Input(1)); math::Neg(
math::Scale(
dX->count(), dX->count(),
-1.f,
dx, dx,
Output(1)->template mutable_data<T, Context>(), Output(1)->ReshapeLike(Input(1))->template mutable_data<T, Context>(),
ctx()); ctx());
} }
} }
......
...@@ -179,9 +179,8 @@ void DivGradientOp<Context>::DoRunWithType() { ...@@ -179,9 +179,8 @@ void DivGradientOp<Context>::DoRunWithType() {
B.template data<T, Context>(), B.template data<T, Context>(),
dB->template mutable_data<T, Context>(), dB->template mutable_data<T, Context>(),
ctx()); ctx());
math::Scale( math::Neg(
B_ref.count(), B_ref.count(),
-1.f,
dB->template data<T, Context>(), dB->template data<T, Context>(),
dB->template mutable_data<T, Context>(), dB->template mutable_data<T, Context>(),
ctx()); ctx());
......
...@@ -7,9 +7,8 @@ template <class Context> ...@@ -7,9 +7,8 @@ template <class Context>
template <typename T> template <typename T>
void NegOp<Context>::DoRunWithType() { void NegOp<Context>::DoRunWithType() {
auto &X = Input(0), *Y = Output(0); auto &X = Input(0), *Y = Output(0);
math::Scale( math::Neg(
X.count(), X.count(),
-1.f,
X.template data<T, Context>(), X.template data<T, Context>(),
Y->ReshapeLike(X)->template mutable_data<T, Context>(), Y->ReshapeLike(X)->template mutable_data<T, Context>(),
ctx()); ctx());
...@@ -25,9 +24,8 @@ template <class Context> ...@@ -25,9 +24,8 @@ template <class Context>
template <typename T> template <typename T>
void NegGradientOp<Context>::DoRunWithType() { void NegGradientOp<Context>::DoRunWithType() {
auto &dY = Input(0), *dX = Output(0); auto &dY = Input(0), *dX = Output(0);
math::Scale( math::Neg(
dY.count(), dY.count(),
-1.f,
dY.template data<T, Context>(), dY.template data<T, Context>(),
dX->ReshapeLike(dY)->template mutable_data<T, Context>(), dX->ReshapeLike(dY)->template mutable_data<T, Context>(),
ctx()); ctx());
......
...@@ -72,9 +72,8 @@ void SubGradientOp<Context>::DoRunWithType() { ...@@ -72,9 +72,8 @@ void SubGradientOp<Context>::DoRunWithType() {
if (dB->has_name()) { if (dB->has_name()) {
if (B_broadcast_axes.empty()) { if (B_broadcast_axes.empty()) {
math::Scale( math::Neg(
B.count(), B.count(),
-1.f,
dY.template data<T, Context>(), dY.template data<T, Context>(),
dB->ReshapeLike(B)->template mutable_data<T, Context>(), dB->ReshapeLike(B)->template mutable_data<T, Context>(),
ctx()); ctx());
......
...@@ -72,6 +72,7 @@ from dragon.core.ops.array_ops import repeat ...@@ -72,6 +72,7 @@ from dragon.core.ops.array_ops import repeat
from dragon.core.ops.array_ops import reshape from dragon.core.ops.array_ops import reshape
from dragon.core.ops.array_ops import shape from dragon.core.ops.array_ops import shape
from dragon.core.ops.array_ops import slice from dragon.core.ops.array_ops import slice
from dragon.core.ops.array_ops import sort
from dragon.core.ops.array_ops import split from dragon.core.ops.array_ops import split
from dragon.core.ops.array_ops import squeeze from dragon.core.ops.array_ops import squeeze
from dragon.core.ops.array_ops import stack from dragon.core.ops.array_ops import stack
......
...@@ -857,6 +857,20 @@ def softmax_loss_spec(args, inputs, outputs): ...@@ -857,6 +857,20 @@ def softmax_loss_spec(args, inputs, outputs):
return outputs return outputs
@register('Sort')
def sort_spec(args, inputs, outputs):
_ = locals()
outputs[0].dtype = inputs[0].dtype
outputs[1].dtype = 'int64'
try:
out_shape = list(inputs[0].shape[:])
outputs[0].shape = out_shape[:]
outputs[1].shape = out_shape[:]
except (TypeError, IndexError):
pass
return outputs
@register('SpaceToDepth') @register('SpaceToDepth')
def space_to_depth_spec(args, inputs, outputs): def space_to_depth_spec(args, inputs, outputs):
outputs[0].dtype = inputs[0].dtype outputs[0].dtype = inputs[0].dtype
...@@ -1029,8 +1043,8 @@ def top_k_spec(args, inputs, outputs): ...@@ -1029,8 +1043,8 @@ def top_k_spec(args, inputs, outputs):
try: try:
out_shape = list(inputs[0].shape[:]) out_shape = list(inputs[0].shape[:])
out_shape[axis] = k out_shape[axis] = k
outputs[0].shape = out_shape outputs[0].shape = out_shape[:]
outputs[1].shape = out_shape outputs[1].shape = out_shape[:]
except (TypeError, IndexError): except (TypeError, IndexError):
pass pass
return outputs return outputs
......
...@@ -1273,6 +1273,54 @@ def slice(inputs, starts, sizes, **kwargs): ...@@ -1273,6 +1273,54 @@ def slice(inputs, starts, sizes, **kwargs):
@OpSchema.num_inputs(1) @OpSchema.num_inputs(1)
def sort(inputs, axis=-1, descending=False, **kwargs):
"""Return the sorted elements along the given axis.
By default, the last axis is chosen:
```python
x = dragon.constant([[1, 2, 3], [3, 2, 1]])
value1, index1 = dragon.sort(x)
value2, index2 = dragon.sort(x, axis=1) # Equivalent
```
Sort in the inverse order if ``descending`` is **True**:
```python
x = dragon.constant([1, 2, 3])
_, index1 = dragon.sort(-x)
_, index2 = dragon.sort(x, descending=True) # Equivalent
```
Parameters
----------
inputs : dragon.Tensor
The input tensor.
axis : int, optional, default=-1
The axis to sort elements.
descending : bool, optional, default=False
Sort in the descending order or not.
Returns
-------
Sequence[dragon.vm.torch.Tensor]
The value and index tensor.
"""
args = parse_args(locals())
op_lib = array_ops_lib.Sort
if context.executing_eagerly():
return op_lib \
.instantiate(
axis=axis,
descending=descending,
).apply([inputs])
else:
args['num_outputs'] = 2
return op_lib.blend(**args)
@OpSchema.num_inputs(1)
def split( def split(
inputs, inputs,
num_or_size_splits, num_or_size_splits,
...@@ -1548,10 +1596,10 @@ def transpose(inputs, perm=None, **kwargs): ...@@ -1548,10 +1596,10 @@ def transpose(inputs, perm=None, **kwargs):
@OpSchema.num_inputs(1) @OpSchema.num_inputs(1)
def top_k(inputs, k=1, axis=None, largest=True, sorted=True, **kwargs): def top_k(inputs, k=1, axis=-1, largest=True, sorted=True, **kwargs):
"""Return the top-K largest or smallest elements along the given axis. """Return the top-K largest or smallest elements along the given axis.
If ``axis`` is not given, the last axis is chosen: By default, the last axis is chosen:
```python ```python
x = dragon.constant([[1, 2, 3], [3, 2, 1]]) x = dragon.constant([[1, 2, 3], [3, 2, 1]])
...@@ -1562,9 +1610,9 @@ def top_k(inputs, k=1, axis=None, largest=True, sorted=True, **kwargs): ...@@ -1562,9 +1610,9 @@ def top_k(inputs, k=1, axis=None, largest=True, sorted=True, **kwargs):
If ``largest`` is **False**, the k smallest elements are returned: If ``largest`` is **False**, the k smallest elements are returned:
```python ```python
x = dragon.constant([[1, 2, 3], [3, 2, 1]]) x = dragon.constant([1, 2, 3])
_, index1 = dragon.math.top_k(x, largest=False) _, index1 = dragon.math.top_k(-x)
_, index2 = dragon.math.top_k(-x, largest=True) # Equivalent _, index2 = dragon.math.top_k(x, largest=False) # Equivalent
``` ```
Parameters Parameters
...@@ -1573,11 +1621,11 @@ def top_k(inputs, k=1, axis=None, largest=True, sorted=True, **kwargs): ...@@ -1573,11 +1621,11 @@ def top_k(inputs, k=1, axis=None, largest=True, sorted=True, **kwargs):
The input tensor. The input tensor.
k : int, optional, default=1 k : int, optional, default=1
The number of top elements to select. The number of top elements to select.
axis : int, optional axis : int, optional, default=-1
The axis to reduce. The axis to select elements.
largest : bool, optional, default=True largest : bool, optional, default=True
Return largest or smallest elements. Return largest or smallest elements.
sorted : bool, optional sorted : bool, optional, default=True
Whether to return in the sorted order. Whether to return in the sorted order.
Returns Returns
......
...@@ -551,6 +551,25 @@ class Shape(Operator): ...@@ -551,6 +551,25 @@ class Shape(Operator):
return self.dispatch(inputs, [self.alloc()], no_grad=True) return self.dispatch(inputs, [self.alloc()], no_grad=True)
class Sort(Operator):
def __init__(self, key, dev, **kwargs):
super(Sort, self).__init__(key, dev, **kwargs)
self.axis = kwargs.get('axis', -1)
self.descending = kwargs.get('descending', False)
def attributes(self):
return {
'op_type': 'Sort',
'arguments': {
'axis': self.axis,
'descending': self.descending,
}
}
def forward(self, inputs):
return self.dispatch(inputs, [self.alloc(), self.alloc()], no_grad=True)
class Split(Operator): class Split(Operator):
def __init__(self, key, dev, **kwargs): def __init__(self, key, dev, **kwargs):
super(Split, self).__init__(key, dev, **kwargs) super(Split, self).__init__(key, dev, **kwargs)
...@@ -666,7 +685,7 @@ class TopK(Operator): ...@@ -666,7 +685,7 @@ class TopK(Operator):
def __init__(self, key, dev, **kwargs): def __init__(self, key, dev, **kwargs):
super(TopK, self).__init__(key, dev, **kwargs) super(TopK, self).__init__(key, dev, **kwargs)
self.k = kwargs.get('k', 1) self.k = kwargs.get('k', 1)
self.axis = kwargs.get('axis', None) self.axis = kwargs.get('axis', -1)
self.largest = kwargs.get('largest', True) self.largest = kwargs.get('largest', True)
self.sorted = kwargs.get('sorted', True) self.sorted = kwargs.get('sorted', True)
......
...@@ -110,54 +110,28 @@ inline float to<float, float16>(float16 val) { ...@@ -110,54 +110,28 @@ inline float to<float, float16>(float16 val) {
#ifdef USE_CUDA #ifdef USE_CUDA
template <> template <>
inline float16 to<float16, half>(half val) {
return float16{__half_raw(val).x};
}
template <>
inline half to<half, float>(float val) { inline half to<half, float>(float val) {
#if CUDA_VERSION_MIN(9, 0, 0) return __float2half(val);
__half_raw fp16_raw;
fp16_raw.x = cast::to<float16>(val).x;
return half(fp16_raw);
#else
half fp16;
fp16.x = dragon_cast<float16, float>(val).x;
return fp16;
#endif
} }
template <> template <>
inline half2 to<half2, float>(float val) { inline half to<half, float16>(float16 val) {
#if CUDA_VERSION_MIN(9, 0, 0) return __half_raw{val.x};
half fp16 = cast::to<half>(val);
return half2(fp16, fp16);
#else
half2 fp32;
fp32.x = cast::to<float32>(val).x;
return fp32;
#endif
} }
template <> template <>
inline half2 to<half2, float16>(float16 val) { inline half2 to<half2, float>(float val) {
#if CUDA_VERSION_MIN(9, 0, 0) return __float2half2_rn(val);
__half_raw fp16_raw;
fp16_raw.x = val.x;
return half2(half(fp16_raw), half(fp16_raw));
#else
half2 fp32;
fp32.x = dragon_cast<float32, float16>(val).x;
return fp32;
#endif
} }
template <> template <>
inline half to<half, float16>(float16 val) { inline half2 to<half2, float16>(float16 val) {
#if CUDA_VERSION_MIN(9, 0, 0) return half2(__half2_raw{val.x, val.x});
__half_raw fp16_raw;
fp16_raw.x = val.x;
return fp16_raw;
#else
half fp16;
fp16.x = val.x;
return fp16;
#endif
} }
#endif // USE_CUDA #endif // USE_CUDA
......
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_UTILS_DEVICE_COMMON_CUB_H_ #ifndef DRAGON_UTILS_DEVICE_COMMON_CUB_H_
#define DRAGON_UTILS_DEVICE_COMMON_CUB_H_ #define DRAGON_UTILS_DEVICE_COMMON_CUB_H_
#ifdef USE_CUDA #ifdef USE_CUDA
#include <cub/block/block_radix_sort.cuh>
#include <cub/block/block_reduce.cuh> #include <cub/block/block_reduce.cuh>
#include <cub/device/device_reduce.cuh> #include <cub/device/device_reduce.cuh>
#include <cub/device/device_select.cuh> #include <cub/device/device_select.cuh>
#include <cub/iterator/counting_input_iterator.cuh> #include <cub/iterator/counting_input_iterator.cuh>
#include "dragon/utils/device/common_cuda.h"
namespace cub {
struct SumHalf {
inline __device__ half operator()(const half& a, const half& b) const {
#if __CUDA_ARCH__ >= 530
return __hadd(a, b);
#else
return __float2half(__half2float(a) + __half2float(b));
#endif
}
};
struct MinHalf {
inline __device__ half operator()(const half& a, const half& b) const {
#if __CUDA_ARCH__ >= 530
return __hlt(a, b) ? a : b;
#else
return __half2float(a) < __half2float(b) ? a : b;
#endif
}
};
struct MaxHalf {
inline __device__ half operator()(const half& a, const half& b) const {
#if __CUDA_ARCH__ >= 530
return __hgt(a, b) ? a : b;
#else
return __half2float(a) > __half2float(b) ? a : b;
#endif
}
};
} // namespace cub
namespace dragon { namespace dragon {
template <typename T> template <typename T>
......
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_UTILS_DEVICE_COMMON_NCCL_H_ #ifndef DRAGON_UTILS_DEVICE_COMMON_NCCL_H_
#define DRAGON_UTILS_DEVICE_COMMON_NCCL_H_ #define DRAGON_UTILS_DEVICE_COMMON_NCCL_H_
......
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_UTILS_DEVICE_COMMON_THRUST_H_ #ifndef DRAGON_UTILS_DEVICE_COMMON_THRUST_H_
#define DRAGON_UTILS_DEVICE_COMMON_THRUST_H_ #define DRAGON_UTILS_DEVICE_COMMON_THRUST_H_
......
...@@ -50,6 +50,31 @@ DEFINE_COPY_FUNC(float); ...@@ -50,6 +50,31 @@ DEFINE_COPY_FUNC(float);
DEFINE_COPY_FUNC(double); DEFINE_COPY_FUNC(double);
#undef DEFINE_COPY_FUNC #undef DEFINE_COPY_FUNC
#define DEFINE_COPY_FUNC(T) \
template <> \
DRAGON_API void Copy<T, CPUContext>( \
const int n, \
const int incx, \
const int incy, \
const T* x, \
T* y, \
CPUContext* ctx) { \
if (x != y && n > 0) { \
EigenStridedVectorMap<T>(y, 1, n, EigenInnerStride(incy)) = \
ConstEigenStridedVectorMap<T>(x, 1, n, EigenInnerStride(incx)); \
} \
}
DEFINE_COPY_FUNC(bool);
DEFINE_COPY_FUNC(int8_t);
DEFINE_COPY_FUNC(uint8_t);
DEFINE_COPY_FUNC(int);
DEFINE_COPY_FUNC(int64_t);
DEFINE_COPY_FUNC(float16);
DEFINE_COPY_FUNC(float);
DEFINE_COPY_FUNC(double);
#undef DEFINE_COPY_FUNC
template <> template <>
DRAGON_API void Axpy<float16, CPUContext>( DRAGON_API void Axpy<float16, CPUContext>(
const int n, const int n,
......
...@@ -18,6 +18,14 @@ __global__ void _Scale(const int n, const T alpha, const T* x, T* y) { ...@@ -18,6 +18,14 @@ __global__ void _Scale(const int n, const T alpha, const T* x, T* y) {
} }
template <typename T> template <typename T>
__global__ void
_Copy(const int n, const int incx, const int incy, const T* x, T* y) {
CUDA_1D_KERNEL_LOOP(i, n) {
y[i * incy] = x[i * incx];
}
}
template <typename T>
__global__ void _Axpy(const int n, const T alpha, const T* x, T* y) { __global__ void _Axpy(const int n, const T alpha, const T* x, T* y) {
CUDA_1D_KERNEL_LOOP(i, n) { CUDA_1D_KERNEL_LOOP(i, n) {
y[i] += (alpha * x[i]); y[i] += (alpha * x[i]);
...@@ -200,6 +208,47 @@ DEFINE_COPY_FUNC(float); ...@@ -200,6 +208,47 @@ DEFINE_COPY_FUNC(float);
DEFINE_COPY_FUNC(double); DEFINE_COPY_FUNC(double);
#undef DEFINE_COPY_FUNC #undef DEFINE_COPY_FUNC
#define DEFINE_COPY_FUNC(T) \
template <> \
DRAGON_API void Copy<T, CUDAContext>( \
const int n, \
const int incx, \
const int incy, \
const T* x, \
T* y, \
CUDAContext* ctx) { \
if (x != y && n > 0) { \
_Copy<<<CUDA_BLOCKS(n), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
n, incx, incy, x, y); \
} \
}
DEFINE_COPY_FUNC(bool);
DEFINE_COPY_FUNC(int8_t);
DEFINE_COPY_FUNC(uint8_t);
DEFINE_COPY_FUNC(int);
DEFINE_COPY_FUNC(int64_t);
DEFINE_COPY_FUNC(float16);
#undef DEFINE_COPY_FUNC
#define DEFINE_COPY_FUNC(T, cublas_func) \
template <> \
DRAGON_API void Copy<T, CUDAContext>( \
const int n, \
const int incx, \
const int incy, \
const T* x, \
T* y, \
CUDAContext* ctx) { \
if (x != y && n > 0) { \
cublas_func(ctx->cublas_handle(), n, x, incx, y, incy); \
} \
}
DEFINE_COPY_FUNC(float, cublasScopy);
DEFINE_COPY_FUNC(double, cublasDcopy);
#undef DEFINE_COPY_FUNC
#define DEFINE_AXPY_FUNC(T) \ #define DEFINE_AXPY_FUNC(T) \
template <> \ template <> \
DRAGON_API void Axpy<T, CUDAContext>( \ DRAGON_API void Axpy<T, CUDAContext>( \
......
...@@ -32,6 +32,15 @@ template <typename T, class Context> ...@@ -32,6 +32,15 @@ template <typename T, class Context>
DRAGON_API void Copy(const int n, const T* x, T* y, Context* ctx); DRAGON_API void Copy(const int n, const T* x, T* y, Context* ctx);
template <typename T, class Context> template <typename T, class Context>
DRAGON_API void Copy(
const int n,
const int incx,
const int incy,
const T* x,
T* y,
Context* ctx);
template <typename T, class Context>
DRAGON_API void DRAGON_API void
Axpy(const int n, const float alpha, const T* x, T* y, Context* ctx); Axpy(const int n, const float alpha, const T* x, T* y, Context* ctx);
......
...@@ -102,6 +102,29 @@ DEFINE_UNARY_FUNC(Sign, double, [](double x) { ...@@ -102,6 +102,29 @@ DEFINE_UNARY_FUNC(Sign, double, [](double x) {
}); });
#undef DEFINE_UNARY_FUNC #undef DEFINE_UNARY_FUNC
template <>
#define DEFINE_NEG_FUNC(T) \
template <> \
DRAGON_API void Neg<T, CPUContext>( \
const int n, const T* x, T* y, CPUContext* ctx) { \
EigenVectorArrayMap<T>(y, n) = -ConstEigenVectorArrayMap<T>(x, n); \
}
DRAGON_API void Neg<float16, CPUContext>(
const int n,
const float16* x,
float16* y,
CPUContext* ctx) {
CPU_FP16_NOT_SUPPORTED;
}
DEFINE_NEG_FUNC(int8_t);
DEFINE_NEG_FUNC(int);
DEFINE_NEG_FUNC(int64_t);
DEFINE_NEG_FUNC(float);
DEFINE_NEG_FUNC(double);
#undef DEFINE_NEG_FUNC
/* y = value */ /* y = value */
#define DEFINE_SET_FUNC(T) \ #define DEFINE_SET_FUNC(T) \
......
...@@ -23,6 +23,9 @@ template <typename T, class Context> ...@@ -23,6 +23,9 @@ template <typename T, class Context>
DRAGON_API void Abs(const int n, const T* x, T* y, Context* ctx); DRAGON_API void Abs(const int n, const T* x, T* y, Context* ctx);
template <typename T, class Context> template <typename T, class Context>
DRAGON_API void Neg(const int n, const T* x, T* y, Context* ctx);
template <typename T, class Context>
DRAGON_API void Ceil(const int n, const T* x, T* y, Context* ctx); DRAGON_API void Ceil(const int n, const T* x, T* y, Context* ctx);
template <typename T, class Context> template <typename T, class Context>
......
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_UTILS_MATH_SORT_H_
#define DRAGON_UTILS_MATH_SORT_H_
#endif // DRAGON_UTILS_MATH_SORT_H_
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include "dragon/utils/math/blas.h" #include "dragon/utils/math/blas.h"
#include "dragon/utils/math/broadcast.h" #include "dragon/utils/math/broadcast.h"
#include "dragon/utils/math/elementwise.h" #include "dragon/utils/math/elementwise.h"
#include "dragon/utils/math/functional.h"
#include "dragon/utils/math/random.h" #include "dragon/utils/math/random.h"
#include "dragon/utils/math/reduce.h" #include "dragon/utils/math/reduce.h"
#include "dragon/utils/math/utils.h" #include "dragon/utils/math/utils.h"
......
...@@ -308,7 +308,7 @@ void IndexSelect( ...@@ -308,7 +308,7 @@ void IndexSelect(
const int outer_dim, const int outer_dim,
const int inner_dim, const int inner_dim,
const int axis_dim, const int axis_dim,
const int num_indices, const int select_dim,
const int64_t* indices, const int64_t* indices,
const T* x, const T* x,
T* y, T* y,
...@@ -319,7 +319,7 @@ void IndexSelectGrad( ...@@ -319,7 +319,7 @@ void IndexSelectGrad(
const int outer_dim, const int outer_dim,
const int inner_dim, const int inner_dim,
const int axis_dim, const int axis_dim,
const int num_indices, const int select_dim,
const int64_t* index, const int64_t* index,
const T* dy, const T* dy,
T* dx, T* dx,
...@@ -539,11 +539,11 @@ void TransposeGrad( ...@@ -539,11 +539,11 @@ void TransposeGrad(
/* array.top_k */ /* array.top_k */
template <typename T, class Context> template <typename T, class Context>
void TopK( void TopSelect(
const int outer_dim, const int outer_dim,
const int inner_dim, const int inner_dim,
const int axis_dim, const int axis_dim,
const int top_k, const int topk,
const int largest, const int largest,
const T* x, const T* x,
T* value, T* value,
...@@ -551,6 +551,7 @@ void TopK( ...@@ -551,6 +551,7 @@ void TopK(
Context* ctx); Context* ctx);
/* array.unique */ /* array.unique */
template <typename T, class Context> template <typename T, class Context>
void Unique( void Unique(
const int dim, const int dim,
...@@ -562,6 +563,7 @@ void Unique( ...@@ -562,6 +563,7 @@ void Unique(
Context* ctx); Context* ctx);
/* control_flow.assgin */ /* control_flow.assgin */
template <typename T, class Context> template <typename T, class Context>
void Assign( void Assign(
const int num_dims, const int num_dims,
......
...@@ -112,6 +112,8 @@ from dragon.vm.tensorflow.core.ops.math_ops import square ...@@ -112,6 +112,8 @@ from dragon.vm.tensorflow.core.ops.math_ops import square
from dragon.vm.tensorflow.core.ops.math_ops import subtract from dragon.vm.tensorflow.core.ops.math_ops import subtract
from dragon.vm.tensorflow.core.ops.math_ops import tanh from dragon.vm.tensorflow.core.ops.math_ops import tanh
from dragon.vm.tensorflow.core.ops.gradients_impl import gradients from dragon.vm.tensorflow.core.ops.gradients_impl import gradients
from dragon.vm.tensorflow.core.ops.sort_ops import argsort
from dragon.vm.tensorflow.core.ops.sort_ops import sort
from dragon.vm.tensorflow.core.ops.variables import Variable from dragon.vm.tensorflow.core.ops.variables import Variable
# Attributes # Attributes
......
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
# <https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/array_ops.py> # <https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/array_ops.py>
# #
# ------------------------------------------------------------ # ------------------------------------------------------------
"""The array ops.""" """Array ops."""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
......
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
# <https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/bitwise_ops.py> # <https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/bitwise_ops.py>
# #
# ------------------------------------------------------------ # ------------------------------------------------------------
"""The bitwise ops.""" """Bitwise ops."""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
......
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
# <https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/clip_ops.py> # <https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/clip_ops.py>
# #
# ------------------------------------------------------------ # ------------------------------------------------------------
"""The clip ops.""" """Clip ops."""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
......
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
# <https://opensource.org/licenses/BSD-2-Clause> # <https://opensource.org/licenses/BSD-2-Clause>
# #
# ------------------------------------------------------------ # ------------------------------------------------------------
"""Grad implementation.""" """Gradient implementation."""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
......
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
# <https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/init_ops.py> # <https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/init_ops.py>
# #
# ------------------------------------------------------------ # ------------------------------------------------------------
"""The init ops.""" """Init ops."""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
......
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
# <https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/linalg_ops.py> # <https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/linalg_ops.py>
# #
# ------------------------------------------------------------ # ------------------------------------------------------------
"""The linalg ops.""" """Linalg ops."""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
......
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
# <https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/math_ops.py> # <https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/math_ops.py>
# #
# ------------------------------------------------------------ # ------------------------------------------------------------
"""The math ops.""" """Math ops."""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
......
...@@ -8,8 +8,7 @@ ...@@ -8,8 +8,7 @@
# <https://opensource.org/licenses/BSD-2-Clause> # <https://opensource.org/licenses/BSD-2-Clause>
# #
# ------------------------------------------------------------ # ------------------------------------------------------------
"""NN components."""
"""The nn components."""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
......
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
# <https://opensource.org/licenses/BSD-2-Clause> # <https://opensource.org/licenses/BSD-2-Clause>
# #
# ------------------------------------------------------------ # ------------------------------------------------------------
"""The nn ops implementation.""" """NN implementation."""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
......
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
# <https://opensource.org/licenses/BSD-2-Clause> # <https://opensource.org/licenses/BSD-2-Clause>
# #
# ------------------------------------------------------------ # ------------------------------------------------------------
"""The nn ops.""" """NN ops."""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
......
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
# <https://opensource.org/licenses/BSD-2-Clause> # <https://opensource.org/licenses/BSD-2-Clause>
# #
# ------------------------------------------------------------ # ------------------------------------------------------------
"""The random ops.""" """Random ops."""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
......
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Sort ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from dragon.core.ops import array_ops
def argsort(values, axis=-1, direction='ASCENDING', name=None):
"""Return the index of sorted elements along the given axis.
By default, the last axis is chosen:
```python
x = tf.constant([[1, 2, 3], [3, 2, 1]])
index1 = tf.argsort(x)
index2 = tf.argsort(x, axis=1) # Equivalent
```
Sort in the inverse order if ``direction`` is ``DESCENDING``:
```python
x = tf.constant([1, 2, 3])
index1 = tf.argsort(-x)
index2 = tf.argsort(x, direction='DESCENDING') # Equivalent
```
Parameters
----------
values : dragon.Tensor
The input tensor.
axis : int, optional, default=-1
The axis to sort elements.
direction : {'ASCENDING', 'DESCENDING'}, optional
The sorting direction.
name : str, optional
The operation name.
Returns
-------
dragon.Tensor
The index tensor.
"""
if direction not in ('ASCENDING', 'DESCENDING'):
raise ValueError('Unknown direction: ' + direction)
value_and_index = array_ops.sort(
values,
axis=axis,
descending=direction == 'DESCENDING',
name=name,
)
return value_and_index[1]
def sort(values, axis=-1, direction='ASCENDING', name=None):
"""Return the sorted elements along the given axis.
By default, the last axis is chosen:
```python
x = tf.constant([[1, 2, 3], [3, 2, 1]])
value1 = tf.sort(x)
value2 = tf.sort(x, axis=1) # Equivalent
```
Sort in the inverse order if ``direction`` is ``DESCENDING``:
```python
x = tf.constant([1, 2, 3])
value1 = -tf.sort(-x)
value2 = tf.sort(x, direction='DESCENDING') # Equivalent
```
Parameters
----------
values : dragon.Tensor
The input tensor.
axis : int, optional, default=-1
The axis to sort elements.
direction : {'ASCENDING', 'DESCENDING'}, optional
The sorting direction.
name : str, optional
The operation name.
Returns
-------
dragon.Tensor
The value tensor.
"""
if direction not in ('ASCENDING', 'DESCENDING'):
raise ValueError('Unknown direction: ' + direction)
value_and_index = array_ops.sort(
values,
axis=axis,
descending=direction == 'DESCENDING',
name=name,
)
return value_and_index[0]
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
# <https://opensource.org/licenses/BSD-2-Clause> # <https://opensource.org/licenses/BSD-2-Clause>
# #
# ------------------------------------------------------------ # ------------------------------------------------------------
"""The standard ops.""" """Standard ops."""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
......
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
# <https://opensource.org/licenses/BSD-2-Clause> # <https://opensource.org/licenses/BSD-2-Clause>
# #
# ------------------------------------------------------------ # ------------------------------------------------------------
"""The Variable class.""" """Variable class."""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
......
...@@ -810,6 +810,28 @@ class TestArrayOps(OpTestCase): ...@@ -810,6 +810,28 @@ class TestArrayOps(OpTestCase):
with dragon.device('cuda'): with dragon.device('cuda'):
self.test_slice() self.test_slice()
def test_sort(self):
entries = [(None, True),
(0, True),
(-1, True),
(0, False),
(-1, False)]
for execution in ('EAGER_MODE', 'GRAPH_MODE'):
with execution_context().mode(execution):
for axis, descending in entries:
data = uniform((5, 10))
x = new_tensor(data)
y = dragon.sort(x, axis=axis, descending=descending)
axis = axis if axis is not None else -1
result = np.argsort(-data if descending else data, axis=axis)
result = np.take(result, np.arange(data.shape[axis]), axis=axis)
self.assertEqual(y[1], result)
@unittest.skipIf(not TEST_CUDA, 'CUDA unavailable')
def test_sort_cuda(self):
with dragon.device('cuda'):
self.test_sort()
def test_split(self): def test_split(self):
entries = [(2, 1, None), ((2, 1), 1, None), (2, 1, (2,))] entries = [(2, 1, None), ((2, 1), 1, None), (2, 1, (2,))]
for execution in ('EAGER_MODE', 'GRAPH_MODE'): for execution in ('EAGER_MODE', 'GRAPH_MODE'):
......
...@@ -483,6 +483,21 @@ class TestTensorOps(OpTestCase): ...@@ -483,6 +483,21 @@ class TestTensorOps(OpTestCase):
x = new_tensor(data) x = new_tensor(data)
self.assertEqual(x.sin(), np.sin(data)) self.assertEqual(x.sin(), np.sin(data))
def test_sort(self):
entries = [(None, True),
(0, True),
(-1, True),
(0, False),
(-1, False)]
for axis, descending in entries:
data = uniform((5, 10))
x = new_tensor(data)
y = x.sort(axis, descending)[1]
axis = axis if axis is not None else -1
result = np.argsort(-data if descending else data, axis=axis)
result = np.take(result, np.arange(data.shape[axis]), axis=axis)
self.assertEqual(y, result)
def test_sqrt(self): def test_sqrt(self):
data = np.array([4., 9., 16], 'float32') data = np.array([4., 9., 16], 'float32')
x = new_tensor(data) x = new_tensor(data)
......
Subproject commit c3cceac115c072fb63df1836ff46d8c60d9eb304 Subproject commit a3ee304a1f8e22f278df10600df2e4b333012592
...@@ -69,6 +69,7 @@ from dragon.vm.torch.core.ops.array.functional import one_hot ...@@ -69,6 +69,7 @@ from dragon.vm.torch.core.ops.array.functional import one_hot
from dragon.vm.torch.core.ops.array.functional import permute from dragon.vm.torch.core.ops.array.functional import permute
from dragon.vm.torch.core.ops.array.functional import repeat from dragon.vm.torch.core.ops.array.functional import repeat
from dragon.vm.torch.core.ops.array.functional import reshape from dragon.vm.torch.core.ops.array.functional import reshape
from dragon.vm.torch.core.ops.array.functional import sort
from dragon.vm.torch.core.ops.array.functional import split from dragon.vm.torch.core.ops.array.functional import split
from dragon.vm.torch.core.ops.array.functional import squeeze from dragon.vm.torch.core.ops.array.functional import squeeze
from dragon.vm.torch.core.ops.array.functional import stack from dragon.vm.torch.core.ops.array.functional import stack
......
...@@ -428,6 +428,26 @@ class Slice(function.Function): ...@@ -428,6 +428,26 @@ class Slice(function.Function):
) )
class Sort(function.Function):
def __init__(self, key, dev, **kwargs):
super(Sort, self).__init__(key, dev, **kwargs)
self.axis = kwargs.get('axis', -1)
self.descending = kwargs.get('descending', False)
def attributes(self):
return {
'op_type': 'Sort',
'arguments': {
'axis': self.axis,
'descending': self.descending,
}
}
def forward(self, input, outputs=(None, None)):
outputs = [self.alloc(outputs[0]), self.alloc(outputs[1])]
return self.dispatch([input], outputs, no_grad=True)
class Split(function.Function): class Split(function.Function):
def __init__(self, key, dev, **kwargs): def __init__(self, key, dev, **kwargs):
super(Split, self).__init__(key, dev, **kwargs) super(Split, self).__init__(key, dev, **kwargs)
...@@ -546,7 +566,7 @@ class TopK(function.Function): ...@@ -546,7 +566,7 @@ class TopK(function.Function):
def __init__(self, key, dev, **kwargs): def __init__(self, key, dev, **kwargs):
super(TopK, self).__init__(key, dev, **kwargs) super(TopK, self).__init__(key, dev, **kwargs)
self.k = kwargs.get('k', 1) self.k = kwargs.get('k', 1)
self.axis = kwargs.get('axis', None) self.axis = kwargs.get('axis', -1)
self.largest = kwargs.get('largest', True) self.largest = kwargs.get('largest', True)
self.sorted = kwargs.get('sorted', True) self.sorted = kwargs.get('sorted', True)
......
...@@ -795,6 +795,50 @@ def slice(input, starts, sizes): ...@@ -795,6 +795,50 @@ def slice(input, starts, sizes):
.apply(input, starts, sizes) .apply(input, starts, sizes)
def sort(input, dim=-1, descending=False, out=None):
"""Return the sorted elements along the given dimension.
By default, the last dimension is chosen:
```python
x = torch.tensor([[1, 2, 3], [3, 2, 1]])
value1, index1 = torch.sort(x)
value2, index2 = torch.sort(x, dim=1) # Equivalent
```
Sort in the descending order if ``descending`` is ``True``:
```python
x = torch.tensor([1, 2, 3])
_, index1 = torch.sort(-x)
_, index2 = torch.sort(x, descending=True) # Equivalent
```
Parameters
----------
input : dragon.vm.torch.Tensor
The input tensor.
dim : int, optional, default=-1
The dimension to sort elements.
descending : bool, optional, default=False
Sort in the descending order or not.
out : Sequence[dragon.vm.torch.Tensor], optional
The optional output value and index.
Returns
-------
Sequence[dragon.vm.torch.Tensor]
The value and index tensor.
"""
return _functions.Sort \
.instantiate(
input.device,
axis=dim,
descending=descending,
).apply(input, out if out else (None, None))
def split(tensor, split_size_or_sections, dim=0): def split(tensor, split_size_or_sections, dim=0):
"""Split input into chunks along the given dimension. """Split input into chunks along the given dimension.
...@@ -960,10 +1004,10 @@ def sum(input, dim=None, keepdim=False, out=None): ...@@ -960,10 +1004,10 @@ def sum(input, dim=None, keepdim=False, out=None):
return _reduce(input, 'Sum', dim, keepdim, out) return _reduce(input, 'Sum', dim, keepdim, out)
def topk(input, k, dim=None, largest=True, sorted=True, out=None): def topk(input, k, dim=-1, largest=True, sorted=True, out=None):
"""Return the top-K largest or smallest elements along the given dimension. """Return the top-K largest or smallest elements along the given dimension.
If ``dim`` is not given, the last dimension is chosen: By default, the last dimension is chosen:
```python ```python
x = torch.tensor([[1, 2, 3], [3, 2, 1]]) x = torch.tensor([[1, 2, 3], [3, 2, 1]])
...@@ -974,9 +1018,9 @@ def topk(input, k, dim=None, largest=True, sorted=True, out=None): ...@@ -974,9 +1018,9 @@ def topk(input, k, dim=None, largest=True, sorted=True, out=None):
If ``largest`` is ``False``, the k smallest elements are returned: If ``largest`` is ``False``, the k smallest elements are returned:
```python ```python
x = torch.tensor([[1, 2, 3], [3, 2, 1]]) x = torch.tensor([1, 2, 3])
_, index1 = torch.topk(x, 1, largest=False) _, index1 = torch.topk(-x, 1)
_, index2 = torch.topk(-x, 1, largest=True) # Equivalent _, index2 = torch.topk(x, 1, largest=False) # Equivalent
``` ```
Parameters Parameters
...@@ -985,8 +1029,8 @@ def topk(input, k, dim=None, largest=True, sorted=True, out=None): ...@@ -985,8 +1029,8 @@ def topk(input, k, dim=None, largest=True, sorted=True, out=None):
The input tensor. The input tensor.
k : int k : int
The number of top elements to select. The number of top elements to select.
dim : int, optional dim : int, optional, default=-1
The dimension to reduce. The dimension to select elements.
largest : bool, optional largest : bool, optional
Return largest or smallest elements. Return largest or smallest elements.
sorted : bool, optional sorted : bool, optional
...@@ -1000,8 +1044,6 @@ def topk(input, k, dim=None, largest=True, sorted=True, out=None): ...@@ -1000,8 +1044,6 @@ def topk(input, k, dim=None, largest=True, sorted=True, out=None):
The value and index tensor. The value and index tensor.
""" """
if dim is None:
dim = input.ndimension() - 1
return _functions.TopK \ return _functions.TopK \
.instantiate( .instantiate(
input.device, input.device,
......
...@@ -1513,6 +1513,29 @@ def sin(self): ...@@ -1513,6 +1513,29 @@ def sin(self):
return math_funcs.sin(self) return math_funcs.sin(self)
def sort(self, dim=-1, descending=False):
"""Return the sorted elements.
Parameters
----------
dim : int, optional, default=-1
The dimension to sort elements.
descending : bool, optional, default=False
Sort in the descending order or not.
Returns
-------
Sequence[dragon.vm.torch.Tensor]
The value and index tensor.
See Also
--------
`torch.sort(...)`_
"""
return array_funcs.sort(self, dim, descending)
def sqrt(self): def sqrt(self):
r"""Compute the square root. r"""Compute the square root.
...@@ -1660,18 +1683,18 @@ def sub_(self, other): ...@@ -1660,18 +1683,18 @@ def sub_(self, other):
return math_funcs.sub(self, other, self) return math_funcs.sub(self, other, self)
def topk(self, k, dim=None, largest=True, sorted=True): def topk(self, k, dim=-1, largest=True, sorted=True):
"""Return the top-K largest or smallest elements. """Return the top-K largest or smallest elements.
Parameters Parameters
---------- ----------
k : int k : int
The number of top elements to select. The number of top elements to select.
dim : int, optional dim : int, optional, default=-1
The dimension to reduce. The dimension to select elements.
largest : bool, optional largest : bool, optional, default=True
Return largest or smallest elements. Return largest or smallest elements.
sorted : bool, optional sorted : bool, optional, default=True
Whether to return in the sorted order. Whether to return in the sorted order.
Returns Returns
...@@ -1939,6 +1962,7 @@ Tensor.rsqrt_ = rsqrt_ ...@@ -1939,6 +1962,7 @@ Tensor.rsqrt_ = rsqrt_
Tensor.sign = sign Tensor.sign = sign
Tensor.sign_ = sign_ Tensor.sign_ = sign_
Tensor.sin = sin Tensor.sin = sin
Tensor.sort = sort
Tensor.sqrt = sqrt Tensor.sqrt = sqrt
Tensor.sqrt_ = sqrt_ Tensor.sqrt_ = sqrt_
Tensor.squeeze = squeeze Tensor.squeeze = squeeze
......
...@@ -1668,6 +1668,27 @@ class Tensor(object): ...@@ -1668,6 +1668,27 @@ class Tensor(object):
s = cpp.Size(self._impl.dims) s = cpp.Size(self._impl.dims)
return s[axis] if axis is not None else s return s[axis] if axis is not None else s
def sort(self, dim=-1, descending=False):
"""Return the sorted elements.
Parameters
----------
dim : int, optional, default=-1
The dimension to sort elements.
descending : bool, optional, default=False
Sort in the descending order or not.
Returns
-------
Sequence[dragon.vm.torch.Tensor]
The value and index tensor.
See Also
--------
`torch.sort(...)`_
"""
def sqrt(self): def sqrt(self):
r"""Compute the square root. r"""Compute the square root.
...@@ -1849,15 +1870,15 @@ class Tensor(object): ...@@ -1849,15 +1870,15 @@ class Tensor(object):
return self.type(dtype) return self.type(dtype)
return self return self
def topk(self, k, dim=None, largest=True, sorted=True): def topk(self, k, dim=-1, largest=True, sorted=True):
"""Return the top-K largest or smallest elements. """Return the top-K largest or smallest elements.
Parameters Parameters
---------- ----------
k : int k : int
The number of top elements to select. The number of top elements to select.
dim : int, optional dim : int, optional, default=-1
The dimension to reduce. The dimension to select elements.
largest : bool, optional largest : bool, optional
Return largest or smallest elements. Return largest or smallest elements.
sorted : bool, optional sorted : bool, optional
......
Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!