Commit b4019faa by Ting PAN

Add Sort Operator

Summary:
This commit adds the sort op for dragon, torch and tensorflow.
In addition, a CUDA implementation of the topk op is now available.
1 parent fdf26ef2
Showing with 2347 additions and 1033 deletions
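For context, a minimal usage sketch of the new op (the dragon signature is taken from this commit; the torch frontend call is assumed to mirror the upstream `torch.sort` signature):

```python
import dragon
import dragon.vm.torch as torch

x = dragon.constant([[1, 2, 3], [3, 2, 1]])
# dragon.sort returns both the sorted values and their indices (int64).
value, index = dragon.sort(x, axis=1, descending=False)

# Assumed to follow the upstream API: (values, indices) along a dimension.
v, i = torch.sort(torch.tensor([[1., 2., 3.], [3., 2., 1.]]), dim=1, descending=True)
```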
......@@ -27,11 +27,11 @@ vm.dali
.. toctree::
:hidden:
dali/Iterator
dali/Pipeline
dali/device
dali/get_device_type
dali/get_distributed_info
dali/Iterator
dali/Pipeline
.. raw:: html
......
......@@ -138,6 +138,9 @@ dragon
`slice(...) <dragon/slice.html>`_
: Select the elements according to the given sections.
`sort(...) <dragon/sort.html>`_
: Return the sorted elements along the given axis.
`split(...) <dragon/split.html>`_
: Split the input into chunks along the given axis.
......@@ -171,6 +174,10 @@ dragon
.. toctree::
:hidden:
dragon/EagerTensor
dragon/GradientTape
dragon/Tensor
dragon/Workspace
dragon/assign
dragon/broadcast_to
dragon/cast
......@@ -182,7 +189,6 @@ dragon
dragon/copy
dragon/create_function
dragon/device
dragon/EagerTensor
dragon/eager_mode
dragon/eager_scope
dragon/expand_dims
......@@ -193,7 +199,6 @@ dragon
dragon/function
dragon/get_workspace
dragon/gradients
dragon/GradientTape
dragon/graph_mode
dragon/index_select
dragon/load_library
......@@ -212,16 +217,15 @@ dragon
dragon/reshape
dragon/shape
dragon/slice
dragon/sort
dragon/split
dragon/squeeze
dragon/stack
dragon/stop_gradient
dragon/Tensor
dragon/tile
dragon/transpose
dragon/unique
dragon/where
dragon/Workspace
dragon/zeros
dragon/zeros_like
......
......@@ -113,6 +113,9 @@ dragon.nn
.. toctree::
:hidden:
nn/GRU
nn/LSTM
nn/RNN
nn/batch_norm
nn/bias_add
nn/conv2d
......@@ -125,18 +128,15 @@ dragon.nn
nn/elu
nn/fully_connected
nn/group_norm
nn/GRU
nn/instance_norm
nn/layer_norm
nn/leaky_relu
nn/local_response_norm
nn/log_softmax
nn/LSTM
nn/pool2d
nn/prelu
nn/relu
nn/relu6
nn/RNN
nn/selu
nn/softmax
nn/space_to_depth
......
sort
====
.. autofunction:: dragon.sort
.. raw:: html
<style>
h1:before {
content: "dragon.";
color: #103d3e;
}
</style>
......@@ -18,6 +18,9 @@ vm.tensorflow
Functions
#########
`argsort(...) <tensorflow/argsort.html>`_
: Return the indices of sorted elements along the given axis.
`broadcast_to(...) <tensorflow/broadcast_to.html>`_
: Broadcast input according to a given shape.
......@@ -84,6 +87,9 @@ vm.tensorflow
`slice(...) <tensorflow/slice.html>`_
: Select the elements according to the given sections.
`sort(...) <tensorflow/sort.html>`_
: Return the sorted elements along the given axis.
`split(...) <tensorflow/split.html>`_
: Split input into chunks along the given axis.
......@@ -108,6 +114,10 @@ vm.tensorflow
.. toctree::
:hidden:
tensorflow/GradientTape
tensorflow/TensorShape
tensorflow/TensorSpec
tensorflow/argsort
tensorflow/broadcast_to
tensorflow/cast
tensorflow/clip_by_value
......@@ -120,7 +130,6 @@ vm.tensorflow
tensorflow/function
tensorflow/gather
tensorflow/gradients
tensorflow/GradientTape
tensorflow/identity
tensorflow/name_scope
tensorflow/ones
......@@ -131,10 +140,9 @@ vm.tensorflow
tensorflow/reshape
tensorflow/shape
tensorflow/slice
tensorflow/sort
tensorflow/split
tensorflow/squeeze
tensorflow/TensorShape
tensorflow/TensorSpec
tensorflow/transpose
tensorflow/unique
tensorflow/unique_with_counts
......
argsort
=======
.. autofunction:: dragon.vm.tensorflow.argsort
.. raw:: html
<style>
h1:before {
content: "tf.";
color: #103d3e;
}
</style>
......@@ -18,8 +18,8 @@ vm.tensorflow.dtypes
.. toctree::
:hidden:
dtypes/as_dtype
dtypes/DType
dtypes/as_dtype
.. raw:: html
......
......@@ -46,7 +46,6 @@ initializers
:hidden:
initializers/Constant
initializers/get
initializers/GlorotNormal
initializers/GlorotUniform
initializers/Initializer
......@@ -56,6 +55,7 @@ initializers
initializers/TruncatedNormal
initializers/VarianceScaling
initializers/Zeros
initializers/get
.. raw:: html
......
......@@ -49,16 +49,16 @@ losses
:hidden:
losses/BinaryCrossentropy
losses/binary_crossentropy
losses/CategoricalCrossentropy
losses/categorical_crossentropy
losses/get
losses/Loss
losses/MeanAbsoluteError
losses/MeanSquaredError
losses/SparseCategoricalCrossentropy
losses/binary_crossentropy
losses/categorical_crossentropy
losses/get
losses/mean_absolute_error
losses/mean_squared_error
losses/SparseCategoricalCrossentropy
losses/sparse_categorical_crossentropy
.. raw:: html
......
......@@ -30,12 +30,12 @@ regularizers
.. toctree::
:hidden:
regularizers/get
regularizers/L1
regularizers/L1L2
regularizers/l1_l2
regularizers/L2
regularizers/Regularizer
regularizers/get
regularizers/l1_l2
.. raw:: html
......
sort
====
.. autofunction:: dragon.vm.tensorflow.sort
.. raw:: html
<style>
h1:before {
content: "tf.";
color: #103d3e;
}
</style>
......@@ -214,6 +214,9 @@ vm.torch
`sin(...) <torch/sin.html>`_
: Compute the sin of input.
`sort(...) <torch/sort.html>`_
: Return the sorted elements along the given dimension.
`split(...) <torch/split.html>`_
: Split input into chunks along the given dimension.
......@@ -256,6 +259,8 @@ vm.torch
.. toctree::
:hidden:
torch/Size
torch/Tensor_
torch/abs
torch/add
torch/arange
......@@ -322,14 +327,13 @@ vm.torch
torch/set_grad_enabled
torch/sign
torch/sin
torch/Size
torch/sort
torch/split
torch/sqrt
torch/squeeze
torch/stack
torch/sub
torch/sum
torch/Tensor_
torch/tensor
torch/topk
torch/unique
......
......@@ -397,6 +397,10 @@ size
####
.. automethod:: dragon.vm.torch.Tensor.size
sort
####
.. automethod:: dragon.vm.torch.Tensor.sort
sqrt
####
.. automethod:: dragon.vm.torch.Tensor.sqrt
......@@ -503,6 +507,7 @@ zero\_
.. _torch.rsqrt(...): rsqrt.html
.. _torch.sign(...): sign.html
.. _torch.sin(...): sin.html
.. _torch.sort(...): sort.html
.. _torch.sqrt(...): sqrt.html
.. _torch.squeeze(...): squeeze.html
.. _torch.sub(...): sub.html
......
......@@ -18,8 +18,8 @@ vm.torch.autograd
.. toctree::
:hidden:
autograd/backward
autograd/Function
autograd/backward
.. raw:: html
......
sort
====
.. autofunction:: dragon.vm.torch.sort
.. raw:: html
<style>
h1:before {
content: "torch.";
color: #103d3e;
}
</style>
......@@ -12,14 +12,14 @@ void _IndexSelect(
const int outer_dim,
const int inner_dim,
const int axis_dim,
const int num_indices,
const int select_dim,
const int64_t* indices,
const T* x,
T* y,
CPUContext* ctx) {
int index;
for (int i = 0; i < outer_dim; ++i) {
for (int j = 0; j < num_indices; ++j) {
for (int j = 0; j < select_dim; ++j) {
index = indices[j];
index = index >= 0 ? index : index + axis_dim;
const T* offset_x = x + (i * axis_dim + index) * inner_dim;
......@@ -34,14 +34,14 @@ void _IndexSelectGrad(
const int outer_dim,
const int inner_dim,
const int axis_dim,
const int num_indices,
const int select_dim,
const int64_t* indices,
const T* dy,
T* dx,
CPUContext* ctx) {
int index;
for (int i = 0; i < outer_dim; ++i) {
for (int j = 0; j < num_indices; ++j) {
for (int j = 0; j < select_dim; ++j) {
index = indices[j];
index = index >= 0 ? index : index + axis_dim;
T* offset_dx = dx + (i * axis_dim + index) * inner_dim;
......@@ -55,18 +55,18 @@ void _IndexSelectGrad(
/* ------------------- Launcher Separator ------------------- */
#define DEFINE_KERNEL_LAUNCHER(name, T) \
template <> \
void name<T, CPUContext>( \
const int outer_dim, \
const int inner_dim, \
const int axis_dim, \
const int num_indices, \
const int64_t* indices, \
const T* x, \
T* y, \
CPUContext* ctx) { \
_##name(outer_dim, inner_dim, axis_dim, num_indices, indices, x, y, ctx); \
#define DEFINE_KERNEL_LAUNCHER(name, T) \
template <> \
void name<T, CPUContext>( \
const int outer_dim, \
const int inner_dim, \
const int axis_dim, \
const int select_dim, \
const int64_t* indices, \
const T* x, \
T* y, \
CPUContext* ctx) { \
_##name(outer_dim, inner_dim, axis_dim, select_dim, indices, x, y, ctx); \
}
DEFINE_KERNEL_LAUNCHER(IndexSelect, bool);
......
......@@ -14,17 +14,17 @@ __global__ void _IndexSelect(
const int nthreads,
const int inner_dim,
const int axis_dim,
const int num_indices,
const int select_dim,
const int64_t* indices,
const T* x,
T* y) {
CUDA_1D_KERNEL_LOOP(yi, nthreads) {
const int j = yi % inner_dim;
const int i = yi / inner_dim / num_indices;
const int i = yi / inner_dim / select_dim;
#if __CUDA_ARCH__ >= 350
int index = __ldg(indices + ((yi / inner_dim) % num_indices));
int index = __ldg(indices + ((yi / inner_dim) % select_dim));
#else
int index = indices[(yi / inner_dim) % num_indices];
int index = indices[(yi / inner_dim) % select_dim];
#endif
index = index >= 0 ? index : index + axis_dim;
y[yi] = x[(i * axis_dim + index) * inner_dim + j];
......@@ -36,7 +36,7 @@ __global__ void _IndexSelectGrad(
const int nthreads,
const int inner_dim,
const int axis_dim,
const int num_indices,
const int select_dim,
const int64_t* indices,
const T* dy,
T* dx) {
......@@ -44,8 +44,8 @@ __global__ void _IndexSelectGrad(
const int i = ti / inner_dim;
const int j = ti % inner_dim;
const int c = i * axis_dim * inner_dim + j;
const T* offset_dy = dy + i * num_indices * inner_dim + j;
for (int k = 0; k < num_indices; ++k) {
const T* offset_dy = dy + i * select_dim * inner_dim + j;
for (int k = 0; k < select_dim; ++k) {
#if __CUDA_ARCH__ >= 350
int index = __ldg(indices + k);
#else
......@@ -63,7 +63,7 @@ __global__ void _IndexSelectGrad<half>(
const int nthreads,
const int inner_dim,
const int axis_dim,
const int num_indices,
const int select_dim,
const int64_t* indices,
const half* dy,
half* dx) {
......@@ -72,8 +72,8 @@ __global__ void _IndexSelectGrad<half>(
const int i = ti / inner_dim;
const int j = ti % inner_dim;
const int c = i * axis_dim * inner_dim + j;
const half* offset_dy = dy + i * num_indices * inner_dim + j;
for (int k = 0; k < num_indices; ++k) {
const half* offset_dy = dy + i * select_dim * inner_dim + j;
for (int k = 0; k < select_dim; ++k) {
int index = __ldg(indices + k);
index = index >= 0 ? index : index + axis_dim;
index = c + index * inner_dim;
......@@ -93,7 +93,7 @@ void IndexSelectGrad<float16, CUDAContext>(
const int outer_dim,
const int inner_dim,
const int axis_dim,
const int num_indices,
const int select_dim,
const int64_t* indices,
const float16* dy,
float16* dx,
......@@ -107,50 +107,50 @@ void IndexSelectGrad<float16, CUDAContext>(
nthreads,
inner_dim,
axis_dim,
num_indices,
select_dim,
indices,
reinterpret_cast<const half*>(dy),
reinterpret_cast<half*>(dx));
} // IndexSelectGrad
#define DEFINE_KERNEL_LAUNCHER(T) \
template <> \
void IndexSelect<T, CUDAContext>( \
const int outer_dim, \
const int inner_dim, \
const int axis_dim, \
const int num_indices, \
const int64_t* indices, \
const T* x, \
T* y, \
CUDAContext* ctx) { \
const int nthreads = outer_dim * num_indices * inner_dim; \
_IndexSelect<<< \
CUDA_BLOCKS(nthreads), \
CUDA_THREADS, \
0, \
ctx->cuda_stream()>>>( \
nthreads, inner_dim, axis_dim, num_indices, indices, x, y); \
#define DEFINE_KERNEL_LAUNCHER(T) \
template <> \
void IndexSelect<T, CUDAContext>( \
const int outer_dim, \
const int inner_dim, \
const int axis_dim, \
const int select_dim, \
const int64_t* indices, \
const T* x, \
T* y, \
CUDAContext* ctx) { \
const int nthreads = outer_dim * select_dim * inner_dim; \
_IndexSelect<<< \
CUDA_BLOCKS(nthreads), \
CUDA_THREADS, \
0, \
ctx->cuda_stream()>>>( \
nthreads, inner_dim, axis_dim, select_dim, indices, x, y); \
}
#define DEFINE_GRAD_KERNEL_LAUNCHER(T) \
template <> \
void IndexSelectGrad<T, CUDAContext>( \
const int outer_dim, \
const int inner_dim, \
const int axis_dim, \
const int num_indices, \
const int64_t* indices, \
const T* dy, \
T* dx, \
CUDAContext* ctx) { \
const int nthreads = outer_dim * inner_dim; \
_IndexSelectGrad<<< \
CUDA_BLOCKS(nthreads), \
CUDA_THREADS, \
0, \
ctx->cuda_stream()>>>( \
nthreads, inner_dim, axis_dim, num_indices, indices, dy, dx); \
#define DEFINE_GRAD_KERNEL_LAUNCHER(T) \
template <> \
void IndexSelectGrad<T, CUDAContext>( \
const int outer_dim, \
const int inner_dim, \
const int axis_dim, \
const int select_dim, \
const int64_t* indices, \
const T* dy, \
T* dx, \
CUDAContext* ctx) { \
const int nthreads = outer_dim * inner_dim; \
_IndexSelectGrad<<< \
CUDA_BLOCKS(nthreads), \
CUDA_THREADS, \
0, \
ctx->cuda_stream()>>>( \
nthreads, inner_dim, axis_dim, select_dim, indices, dy, dx); \
}
DEFINE_KERNEL_LAUNCHER(bool);
......
......@@ -25,11 +25,11 @@ struct SmallestComp {
};
template <typename T, class Comp>
void _TopK(
void _TopSelect(
const int outer_dim,
const int inner_dim,
const int axis_dim,
const int top_k,
const int select_dim,
const int largest,
const T* x,
T* value,
......@@ -38,8 +38,8 @@ void _TopK(
for (int j = 0; j < inner_dim; ++j) {
auto* offset_x = x + (i * axis_dim * inner_dim + j);
vector<std::pair<T, int64_t>> head_data;
head_data.reserve(top_k);
for (int k = 0; k < top_k && k < axis_dim; ++k) {
head_data.reserve(select_dim);
for (int k = 0; k < select_dim && k < axis_dim; ++k) {
head_data.emplace_back(*offset_x, k);
offset_x += inner_dim;
}
......@@ -49,7 +49,7 @@ void _TopK(
Comp>
pq(Comp(), std::move(head_data));
if (largest > 0) {
for (int k = top_k; k < axis_dim; ++k) {
for (int k = select_dim; k < axis_dim; ++k) {
if (pq.top().first < *offset_x) {
pq.pop();
pq.emplace(*offset_x, k);
......@@ -57,7 +57,7 @@ void _TopK(
offset_x += inner_dim;
}
} else {
for (int k = top_k; k < axis_dim; ++k) {
for (int k = select_dim; k < axis_dim; ++k) {
if (pq.top().first > *offset_x) {
pq.pop();
pq.emplace(*offset_x, k);
......@@ -65,7 +65,8 @@ void _TopK(
offset_x += inner_dim;
}
}
auto y_offset = i * top_k * inner_dim + j + (top_k - 1) * inner_dim;
auto y_offset =
i * select_dim * inner_dim + j + (select_dim - 1) * inner_dim;
while (!pq.empty()) {
const auto& p = pq.top();
value[y_offset] = p.first;
......@@ -82,11 +83,11 @@ void _TopK(
/* ------------------- Launcher Separator ------------------- */
template <>
void TopK<float16, CPUContext>(
void TopSelect<float16, CPUContext>(
const int outer_dim,
const int inner_dim,
const int axis_dim,
const int top_k,
const int select_dim,
const int largest,
const float16* x,
float16* value,
......@@ -95,25 +96,39 @@ void TopK<float16, CPUContext>(
CPU_FP16_NOT_SUPPORTED;
}
#define DEFINE_KERNEL_LAUNCHER(T) \
template <> \
void TopK<T, CPUContext>( \
const int outer_dim, \
const int inner_dim, \
const int axis_dim, \
const int top_k, \
const int largest, \
const T* x, \
T* value, \
int64_t* index, \
CPUContext* ctx) { \
if (largest > 0) { \
_TopK<T, LargestComp<T>>( \
outer_dim, inner_dim, axis_dim, top_k, largest, x, value, index); \
} else { \
_TopK<T, SmallestComp<T>>( \
outer_dim, inner_dim, axis_dim, top_k, largest, x, value, index); \
} \
#define DEFINE_KERNEL_LAUNCHER(T) \
template <> \
void TopSelect<T, CPUContext>( \
const int outer_dim, \
const int inner_dim, \
const int axis_dim, \
const int select_dim, \
const int largest, \
const T* x, \
T* value, \
int64_t* index, \
CPUContext* ctx) { \
if (largest > 0) { \
_TopSelect<T, LargestComp<T>>( \
outer_dim, \
inner_dim, \
axis_dim, \
select_dim, \
largest, \
x, \
value, \
index); \
} else { \
_TopSelect<T, SmallestComp<T>>( \
outer_dim, \
inner_dim, \
axis_dim, \
select_dim, \
largest, \
x, \
value, \
index); \
} \
}
DEFINE_KERNEL_LAUNCHER(int8_t);
......
#ifdef USE_CUDA
#include "dragon/core/context_cuda.h"
#include "dragon/core/workspace.h"
#include "dragon/utils/device/common_cub.h"
#include "dragon/utils/device/common_thrust.h"
#include "dragon/utils/math_functions.h"
#include "dragon/utils/op_kernels.h"
namespace dragon {
namespace kernel {
namespace {
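// Comparators for the device-wide sort path. The zipped (flat index, value)
// tuples are ordered first by the remapped index
// (i / stride0) * stride1 + i % stride1, which drops the axis coordinate so
// elements of the same (outer, inner) row compare equal, and then by value;
// thrust::sort therefore performs an independent sort of every row.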
template <typename T>
struct LessFunctorWrapper {
LessFunctorWrapper(int64_t stride0, int64_t stride1)
: stride0_(stride0), stride1_(stride1) {}
inline __device__ bool operator()(
const thrust::tuple<int64_t, T>& lhs,
const thrust::tuple<int64_t, T>& rhs) const {
int64_t li = thrust::get<0>(lhs), ri = thrust::get<0>(rhs);
li = (li / stride0_) * stride1_ + li % stride1_;
ri = (ri / stride0_) * stride1_ + ri % stride1_;
if (li != ri) {
return li < ri;
} else {
return functor_(thrust::get<1>(lhs), thrust::get<1>(rhs));
}
}
int64_t stride0_, stride1_;
math::LessFunctor<T> functor_;
};
template <typename T>
struct GreaterFunctorWrapper {
GreaterFunctorWrapper(int64_t stride0, int64_t stride1)
: stride0_(stride0), stride1_(stride1) {}
inline __device__ bool operator()(
const thrust::tuple<int64_t, T>& lhs,
const thrust::tuple<int64_t, T>& rhs) const {
int64_t li = thrust::get<0>(lhs), ri = thrust::get<0>(rhs);
li = (li / stride0_) * stride1_ + li % stride1_;
ri = (ri / stride0_) * stride1_ + ri % stride1_;
if (li != ri) {
return li < ri;
} else {
return functor_(thrust::get<1>(lhs), thrust::get<1>(rhs));
}
}
int64_t stride0_, stride1_;
math::GreaterFunctor<T> functor_;
};
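// Block-sort path: one thread block sorts a single row with
// cub::BlockRadixSort, padding out-of-range items with `init` (the worst
// key for the requested order) so they land behind real elements; only the
// leading `select_dim` results are written to the outputs.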
template <typename T, int ItemsPerThread>
__global__ void _SelectViaBlockSort(
const int rows,
const int cols,
const int inner_dim,
const int select_dim,
const bool largest,
const T init,
const T* x,
T* y,
int64_t* index) {
typedef cub::BlockRadixSort<T, CUDA_THREADS, ItemsPerThread, int64_t>
BlockSort;
__shared__ typename BlockSort::TempStorage storage;
CUDA_2D_KERNEL_LOOP1(i, rows) {
T keys[ItemsPerThread];
int64_t values[ItemsPerThread];
const int thread_offset = threadIdx.x * ItemsPerThread;
const int x_offset = (i / inner_dim) * cols * inner_dim + (i % inner_dim);
const int y_offset =
(i / inner_dim) * select_dim * inner_dim + (i % inner_dim);
#pragma unroll
for (int j = 0; j < ItemsPerThread; ++j) {
const int item_idx = thread_offset + j;
values[j] = item_idx < cols ? item_idx : cols - 1;
keys[j] = item_idx < cols ? x[x_offset + item_idx * inner_dim] : init;
}
__syncthreads();
if (largest) {
BlockSort(storage).SortDescending(keys, values);
} else {
BlockSort(storage).Sort(keys, values);
}
#pragma unroll
for (int j = 0; j < ItemsPerThread; ++j) {
if (thread_offset + j < select_dim) {
y[y_offset + (thread_offset + j) * inner_dim] = keys[j];
index[y_offset + (thread_offset + j) * inner_dim] = values[j];
}
}
}
}
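// Device-wide path: thrust sorts the whole scratch buffer. A single row is
// sorted flat by key; multiple rows sort zipped (index, key) tuples with the
// wrappers above, which keeps each row's elements grouped together.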
template <typename T>
void _DeviceSort(
const int outer_dim,
const int inner_dim,
const int axis_dim,
const int largest,
T* key,
int64_t* value,
CUDAContext* ctx) {
const int rows = outer_dim * inner_dim, cols = axis_dim;
const int count = rows * cols;
// Fill value with global index
auto policy = thrust::cuda::par.on(ctx->cuda_stream());
thrust::sequence(policy, value, value + count);
if (rows == 1) {
// Sort a flattened array
if (largest > 0) {
thrust::sort_by_key(
policy, key, key + count, value, math::GreaterFunctor<T>());
} else {
thrust::sort_by_key(
policy, key, key + count, value, math::LessFunctor<T>());
}
} else {
// Sort a transposed array to handle multiple rows
auto iter = thrust::make_zip_iterator(thrust::make_tuple(value, key));
if (largest > 0) {
thrust::sort(
policy,
iter,
iter + count,
GreaterFunctorWrapper<T>(axis_dim * inner_dim, inner_dim));
} else {
thrust::sort(
policy,
iter,
iter + count,
LessFunctorWrapper<T>(axis_dim * inner_dim, inner_dim));
}
}
}
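// Gather the leading `select_dim` entries of every sorted row into the
// (outer, select_dim, inner) output layout; the axis coordinate is recovered
// from the stored flat index as (sorted_values[xi] / inner_dim) % axis_dim.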
template <typename T>
__global__ void _SelectViaDeviceSort(
const int nthreads,
const int axis_dim,
const int inner_dim,
const int select_dim,
const T* sorted_keys,
const int64_t* sorted_values,
T* y,
int64_t* index) {
CUDA_1D_KERNEL_LOOP(yi, nthreads) {
const int xi =
((yi / inner_dim / select_dim) * inner_dim + yi % inner_dim) *
axis_dim +
(yi / inner_dim) % select_dim;
y[yi] = sorted_keys[xi];
index[yi] = (sorted_values[xi] / inner_dim) % axis_dim;
}
}
} // namespace
/* ------------------- Launcher Separator ------------------- */
#define PLACE_BLOCK_SORT_CASE(T, items_per_thread) \
_SelectViaBlockSort<T, items_per_thread> \
<<<CUDA_2D_BLOCKS(rows), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
rows, \
cols, \
inner_dim, \
select_dim, \
largest > 0, \
init, \
reinterpret_cast<const T*>(x), \
reinterpret_cast<T*>(value), \
index)
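// Pick the smallest ItemsPerThread such that CUDA_THREADS * ItemsPerThread
// covers the axis dimension; larger axes are routed to the device-wide sort
// path by the launcher below.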
#define PLACE_BLOCK_SORT_CASES(T) \
if (cols <= CUDA_THREADS) { \
PLACE_BLOCK_SORT_CASE(T, 1); \
} else if (cols <= CUDA_THREADS * 2) { \
PLACE_BLOCK_SORT_CASE(T, 2); \
} else if (cols <= CUDA_THREADS * 4) { \
PLACE_BLOCK_SORT_CASE(T, 4); \
} else if (cols <= CUDA_THREADS * 8) { \
PLACE_BLOCK_SORT_CASE(T, 8); \
} else { \
LOG(FATAL) << "Too larger dimension (> " << CUDA_THREADS * 8 \
<< ") to launch the cuda kernel"; \
}
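// Launcher dispatch: a single row (a flat global sort) or an axis too large
// for the block-sort kernel takes the thrust device-wide path on a scratch
// copy of the input; otherwise each row is sorted by one thread block.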
#define DEFINE_KERNEL_LAUNCHER(T1, T2, kLowest, kMax) \
template <> \
void TopSelect<T1, CUDAContext>( \
const int outer_dim, \
const int inner_dim, \
const int axis_dim, \
const int select_dim, \
const int largest, \
const T1* x, \
T1* value, \
int64_t* index, \
CUDAContext* ctx) { \
const int rows = outer_dim * inner_dim; \
const int cols = axis_dim; \
if (rows == 1 || cols > CUDA_THREADS * 8) { \
const int input_count = outer_dim * inner_dim * axis_dim; \
const int output_count = outer_dim * inner_dim * select_dim; \
auto data = ctx->workspace()->template data<CUDAContext>( \
{input_count * sizeof(T1), input_count * sizeof(int64_t)}); \
math::Copy(input_count, x, (T1*)data[0], ctx); \
_DeviceSort( \
outer_dim, \
inner_dim, \
axis_dim, \
largest, \
(T1*)data[0], \
(int64_t*)data[1], \
ctx); \
if (rows == 1) { \
math::Copy(output_count, (T1*)data[0], value, ctx); \
math::Copy(output_count, (int64_t*)data[1], index, ctx); \
} else { \
_SelectViaDeviceSort<<< \
CUDA_BLOCKS(output_count), \
CUDA_THREADS, \
0, \
ctx->cuda_stream()>>>( \
output_count, \
axis_dim, \
inner_dim, \
select_dim, \
(T1*)data[0], \
(int64_t*)data[1], \
value, \
index); \
} \
return; \
} \
T2 init = largest > 0 ? kLowest : kMax; \
PLACE_BLOCK_SORT_CASES(T2); \
}
DEFINE_KERNEL_LAUNCHER(
int8_t,
int8_t,
std::numeric_limits<int8_t>::lowest(),
std::numeric_limits<int8_t>::max());
DEFINE_KERNEL_LAUNCHER(
uint8_t,
uint8_t,
std::numeric_limits<uint8_t>::lowest(),
std::numeric_limits<uint8_t>::max());
DEFINE_KERNEL_LAUNCHER(
int,
int,
std::numeric_limits<int>::lowest(),
std::numeric_limits<int>::max());
DEFINE_KERNEL_LAUNCHER(
int64_t,
int64_t,
std::numeric_limits<int64_t>::lowest(),
std::numeric_limits<int64_t>::max());
DEFINE_KERNEL_LAUNCHER(
float16,
half,
cub::Traits<half>::Lowest(),
cub::Traits<half>::Max());
DEFINE_KERNEL_LAUNCHER(
float,
float,
std::numeric_limits<float>::lowest(),
std::numeric_limits<float>::max());
DEFINE_KERNEL_LAUNCHER(
double,
double,
std::numeric_limits<double>::lowest(),
std::numeric_limits<double>::max());
#undef PLACE_BLOCK_SORT_CASE
#undef PLACE_BLOCK_SORT_CASES
#undef DEFINE_KERNEL_LAUNCHER
} // namespace kernel
} // namespace dragon
#endif // USE_CUDA
......@@ -40,18 +40,6 @@ __global__ void _ComputeCounts(
/* ------------------- Launcher Separator ------------------- */
template <>
void Unique<float16, CUDAContext>(
const int dim,
const float16* x,
float16* y,
int64_t* inverse_index,
int64_t* counts,
int* num,
CUDAContext* ctx) {
LOG(FATAL) << "FP16 is unsupported for CUDAContext.";
}
#define DEFINE_KERNEL_LAUNCHER(T) \
template <> \
void Unique<T, CUDAContext>( \
......@@ -67,8 +55,10 @@ void Unique<float16, CUDAContext>(
thrust::device_vector<int> order1(dim), order2(dim); \
thrust::sequence(policy, order1.begin(), order1.end()); \
thrust::sequence(policy, order2.begin(), order2.end()); \
thrust::sort_by_key(policy, y, y + dim, order1.begin()); \
auto last = thrust::unique_by_key(policy, y, y + dim, order2.begin()); \
thrust::sort_by_key( \
policy, y, y + dim, order1.begin(), math::LessFunctor<T>()); \
auto last = thrust::unique_by_key( \
policy, y, y + dim, order2.begin(), math::EqualFunctor<T>()); \
int n = num[0] = last.first - y; \
if (inverse_index) { \
_RemapInverse<<<CUDA_BLOCKS(n), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
......@@ -84,6 +74,7 @@ DEFINE_KERNEL_LAUNCHER(int8_t);
DEFINE_KERNEL_LAUNCHER(uint8_t);
DEFINE_KERNEL_LAUNCHER(int);
DEFINE_KERNEL_LAUNCHER(int64_t);
DEFINE_KERNEL_LAUNCHER(float16);
DEFINE_KERNEL_LAUNCHER(float);
DEFINE_KERNEL_LAUNCHER(double);
#undef DEFINE_KERNEL_LAUNCHER
......
#include "dragon/operators/array/sort_op.h"
#include "dragon/utils/op_kernels.h"
namespace dragon {
template <class Context>
template <typename T>
void SortOp<Context>::DoRunWithType() {
auto &X = Input(0), *Y_value = Output(0), *Y_index = Output(1);
CANONICALIZE_AXIS_WITH_TENSOR(X);
axis = (axis == INT_MAX ? X.ndim() - 1 : axis);
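// A full sort is a top-select over the entire axis, i.e. select_dim == axis_dim.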
kernel::TopSelect(
X.count(0, axis),
X.count(axis + 1),
X.dim(axis),
X.dim(axis),
descending_ > 0 ? 1 : 0,
X.template data<T, Context>(),
Y_value->ReshapeLike(X)->template mutable_data<T, Context>(),
Y_index->ReshapeLike(X)->template mutable_data<int64_t, Context>(),
ctx());
}
template <class Context>
void SortOp<Context>::RunOnDevice() {
DispatchHelper<NumericalTensorTypes>::Call(this, Input(0));
}
DEPLOY_CPU_OPERATOR(Sort);
#ifdef USE_CUDA
DEPLOY_CUDA_OPERATOR(Sort);
#endif
OPERATOR_SCHEMA(Sort)
/* X */
.NumInputs(1)
/* Value, Index */
.NumOutputs(2);
NO_GRADIENT(Sort);
} // namespace dragon
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_ARRAY_SORT_OP_H_
#define DRAGON_OPERATORS_ARRAY_SORT_OP_H_
#include "dragon/core/operator.h"
namespace dragon {
template <class Context>
class SortOp final : public Operator<Context> {
public:
SortOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
descending_(OP_SINGLE_ARG(int64_t, "descending", 0)) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T>
void DoRunWithType();
protected:
int64_t descending_;
};
} // namespace dragon
#endif // DRAGON_OPERATORS_ARRAY_SORT_OP_H_
......@@ -16,17 +16,16 @@ void TopKOp<Context>::DoRunWithType() {
auto Y_dims = X.dims();
Y_dims[axis] = k_;
CPUContext cpu_ctx;
kernel::TopK(
kernel::TopSelect(
X.count(0, axis),
X.count(axis + 1),
X.dim(axis),
k_,
largest_,
X.template data<T, CPUContext>(),
Y_value->Reshape(Y_dims)->template mutable_data<T, CPUContext>(),
Y_index->Reshape(Y_dims)->template mutable_data<int64_t, CPUContext>(),
&cpu_ctx);
X.template data<T, Context>(),
Y_value->Reshape(Y_dims)->template mutable_data<T, Context>(),
Y_index->Reshape(Y_dims)->template mutable_data<int64_t, Context>(),
ctx());
}
template <class Context>
......
......@@ -105,12 +105,10 @@ void L1LossGradientOp<Context>::DoRunWithType() {
// Gradient w.r.t. the second input
if (OutputSize() > 1 && Output(1)->has_name()) {
Output(1)->ReshapeLike(Input(1));
math::Scale(
math::Neg(
dX->count(),
-1.f,
dx,
Output(1)->template mutable_data<T, Context>(),
Output(1)->ReshapeLike(Input(1))->template mutable_data<T, Context>(),
ctx());
}
}
......
......@@ -103,12 +103,10 @@ void L2LossGradientOp<Context>::DoRunWithType() {
// Gradient w.r.t. the second input
if (OutputSize() > 1 && Output(1)->has_name()) {
Output(1)->ReshapeLike(Input(1));
math::Scale(
math::Neg(
dX->count(),
-1.f,
dx,
Output(1)->template mutable_data<T, Context>(),
Output(1)->ReshapeLike(Input(1))->template mutable_data<T, Context>(),
ctx());
}
}
......
......@@ -105,12 +105,10 @@ void SmoothL1LossGradientOp<Context>::DoRunWithType() {
// Gradient w.r.t. the second input
if (OutputSize() > 1 && Output(1)->has_name()) {
Output(1)->ReshapeLike(Input(1));
math::Scale(
math::Neg(
dX->count(),
-1.f,
dx,
Output(1)->template mutable_data<T, Context>(),
Output(1)->ReshapeLike(Input(1))->template mutable_data<T, Context>(),
ctx());
}
}
......
......@@ -179,9 +179,8 @@ void DivGradientOp<Context>::DoRunWithType() {
B.template data<T, Context>(),
dB->template mutable_data<T, Context>(),
ctx());
math::Scale(
math::Neg(
B_ref.count(),
-1.f,
dB->template data<T, Context>(),
dB->template mutable_data<T, Context>(),
ctx());
......
......@@ -7,9 +7,8 @@ template <class Context>
template <typename T>
void NegOp<Context>::DoRunWithType() {
auto &X = Input(0), *Y = Output(0);
math::Scale(
math::Neg(
X.count(),
-1.f,
X.template data<T, Context>(),
Y->ReshapeLike(X)->template mutable_data<T, Context>(),
ctx());
......@@ -25,9 +24,8 @@ template <class Context>
template <typename T>
void NegGradientOp<Context>::DoRunWithType() {
auto &dY = Input(0), *dX = Output(0);
math::Scale(
math::Neg(
dY.count(),
-1.f,
dY.template data<T, Context>(),
dX->ReshapeLike(dY)->template mutable_data<T, Context>(),
ctx());
......
......@@ -72,9 +72,8 @@ void SubGradientOp<Context>::DoRunWithType() {
if (dB->has_name()) {
if (B_broadcast_axes.empty()) {
math::Scale(
math::Neg(
B.count(),
-1.f,
dY.template data<T, Context>(),
dB->ReshapeLike(B)->template mutable_data<T, Context>(),
ctx());
......
......@@ -72,6 +72,7 @@ from dragon.core.ops.array_ops import repeat
from dragon.core.ops.array_ops import reshape
from dragon.core.ops.array_ops import shape
from dragon.core.ops.array_ops import slice
from dragon.core.ops.array_ops import sort
from dragon.core.ops.array_ops import split
from dragon.core.ops.array_ops import squeeze
from dragon.core.ops.array_ops import stack
......
......@@ -857,6 +857,20 @@ def softmax_loss_spec(args, inputs, outputs):
return outputs
@register('Sort')
def sort_spec(args, inputs, outputs):
_ = locals()
outputs[0].dtype = inputs[0].dtype
outputs[1].dtype = 'int64'
try:
out_shape = list(inputs[0].shape[:])
outputs[0].shape = out_shape[:]
outputs[1].shape = out_shape[:]
except (TypeError, IndexError):
pass
return outputs
@register('SpaceToDepth')
def space_to_depth_spec(args, inputs, outputs):
outputs[0].dtype = inputs[0].dtype
......@@ -1029,8 +1043,8 @@ def top_k_spec(args, inputs, outputs):
try:
out_shape = list(inputs[0].shape[:])
out_shape[axis] = k
outputs[0].shape = out_shape
outputs[1].shape = out_shape
outputs[0].shape = out_shape[:]
outputs[1].shape = out_shape[:]
except (TypeError, IndexError):
pass
return outputs
......
......@@ -1273,6 +1273,54 @@ def slice(inputs, starts, sizes, **kwargs):
@OpSchema.num_inputs(1)
def sort(inputs, axis=-1, descending=False, **kwargs):
"""Return the sorted elements along the given axis.
By default, the last axis is chosen:
```python
x = dragon.constant([[1, 2, 3], [3, 2, 1]])
value1, index1 = dragon.sort(x)
value2, index2 = dragon.sort(x, axis=1) # Equivalent
```
Sort in descending order if ``descending`` is **True**:
```python
x = dragon.constant([1, 2, 3])
_, index1 = dragon.sort(-x)
_, index2 = dragon.sort(x, descending=True) # Equivalent
```
Parameters
----------
inputs : dragon.Tensor
The input tensor.
axis : int, optional, default=-1
The axis to sort elements.
descending : bool, optional, default=False
Whether to sort in descending order.
Returns
-------
Sequence[dragon.Tensor]
The value and index tensor.
"""
args = parse_args(locals())
op_lib = array_ops_lib.Sort
if context.executing_eagerly():
return op_lib \
.instantiate(
axis=axis,
descending=descending,
).apply([inputs])
else:
args['num_outputs'] = 2
return op_lib.blend(**args)
@OpSchema.num_inputs(1)
def split(
inputs,
num_or_size_splits,
......@@ -1548,10 +1596,10 @@ def transpose(inputs, perm=None, **kwargs):
@OpSchema.num_inputs(1)
def top_k(inputs, k=1, axis=None, largest=True, sorted=True, **kwargs):
def top_k(inputs, k=1, axis=-1, largest=True, sorted=True, **kwargs):
"""Return the top-K largest or smallest elements along the given axis.
If ``axis`` is not given, the last axis is chosen:
By default, the last axis is chosen:
```python
x = dragon.constant([[1, 2, 3], [3, 2, 1]])
......@@ -1562,9 +1610,9 @@ def top_k(inputs, k=1, axis=None, largest=True, sorted=True, **kwargs):
If ``largest`` is **False**, the k smallest elements are returned:
```python
x = dragon.constant([[1, 2, 3], [3, 2, 1]])
_, index1 = dragon.math.top_k(x, largest=False)
_, index2 = dragon.math.top_k(-x, largest=True) # Equivalent
x = dragon.constant([1, 2, 3])
_, index1 = dragon.math.top_k(-x)
_, index2 = dragon.math.top_k(x, largest=False) # Equivalent
```
Parameters
......@@ -1573,11 +1621,11 @@ def top_k(inputs, k=1, axis=None, largest=True, sorted=True, **kwargs):
The input tensor.
k : int, optional, default=1
The number of top elements to select.
axis : int, optional
The axis to reduce.
axis : int, optional, default=-1
The axis to select elements.
largest : bool, optional, default=True
Return largest or smallest elements.
sorted : bool, optional
sorted : bool, optional, default=True
Whether to return in the sorted order.
Returns
......
......@@ -551,6 +551,25 @@ class Shape(Operator):
return self.dispatch(inputs, [self.alloc()], no_grad=True)
class Sort(Operator):
def __init__(self, key, dev, **kwargs):
super(Sort, self).__init__(key, dev, **kwargs)
self.axis = kwargs.get('axis', -1)
self.descending = kwargs.get('descending', False)
def attributes(self):
return {
'op_type': 'Sort',
'arguments': {
'axis': self.axis,
'descending': self.descending,
}
}
def forward(self, inputs):
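# Sort emits two outputs (value, index) and defines no gradient.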
return self.dispatch(inputs, [self.alloc(), self.alloc()], no_grad=True)
class Split(Operator):
def __init__(self, key, dev, **kwargs):
super(Split, self).__init__(key, dev, **kwargs)
......@@ -666,7 +685,7 @@ class TopK(Operator):
def __init__(self, key, dev, **kwargs):
super(TopK, self).__init__(key, dev, **kwargs)
self.k = kwargs.get('k', 1)
self.axis = kwargs.get('axis', None)
self.axis = kwargs.get('axis', -1)
self.largest = kwargs.get('largest', True)
self.sorted = kwargs.get('sorted', True)
......
......@@ -110,54 +110,28 @@ inline float to<float, float16>(float16 val) {
#ifdef USE_CUDA
template <>
inline float16 to<float16, half>(half val) {
return float16{__half_raw(val).x};
}
template <>
inline half to<half, float>(float val) {
#if CUDA_VERSION_MIN(9, 0, 0)
__half_raw fp16_raw;
fp16_raw.x = cast::to<float16>(val).x;
return half(fp16_raw);
#else
half fp16;
fp16.x = dragon_cast<float16, float>(val).x;
return fp16;
#endif
}

template <>
inline half to<half, float>(float val) {
return __float2half(val);
}

template <>
inline half to<half, float16>(float16 val) {
#if CUDA_VERSION_MIN(9, 0, 0)
__half_raw fp16_raw;
fp16_raw.x = val.x;
return fp16_raw;
#else
half fp16;
fp16.x = val.x;
return fp16;
#endif
}

template <>
inline half to<half, float16>(float16 val) {
return __half_raw{val.x};
}

template <>
inline half2 to<half2, float>(float val) {
#if CUDA_VERSION_MIN(9, 0, 0)
half fp16 = cast::to<half>(val);
return half2(fp16, fp16);
#else
half2 fp32;
fp32.x = cast::to<float32>(val).x;
return fp32;
#endif
}

template <>
inline half2 to<half2, float>(float val) {
return __float2half2_rn(val);
}

template <>
inline half2 to<half2, float16>(float16 val) {
#if CUDA_VERSION_MIN(9, 0, 0)
__half_raw fp16_raw;
fp16_raw.x = val.x;
return half2(half(fp16_raw), half(fp16_raw));
#else
half2 fp32;
fp32.x = dragon_cast<float32, float16>(val).x;
return fp32;
#endif
}

template <>
inline half2 to<half2, float16>(float16 val) {
return half2(__half2_raw{val.x, val.x});
}
#endif // USE_CUDA
......
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_UTILS_DEVICE_COMMON_CUB_H_
#define DRAGON_UTILS_DEVICE_COMMON_CUB_H_
#ifdef USE_CUDA
#include <cub/block/block_radix_sort.cuh>
#include <cub/block/block_reduce.cuh>
#include <cub/device/device_reduce.cuh>
#include <cub/device/device_select.cuh>
#include <cub/iterator/counting_input_iterator.cuh>
#include "dragon/utils/device/common_cuda.h"
namespace cub {
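// Binary half reductions with a float fallback for architectures below
// sm_53, usable as custom reduction operators with cub.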
struct SumHalf {
inline __device__ half operator()(const half& a, const half& b) const {
#if __CUDA_ARCH__ >= 530
return __hadd(a, b);
#else
return __float2half(__half2float(a) + __half2float(b));
#endif
}
};
struct MinHalf {
inline __device__ half operator()(const half& a, const half& b) const {
#if __CUDA_ARCH__ >= 530
return __hlt(a, b) ? a : b;
#else
return __half2float(a) < __half2float(b) ? a : b;
#endif
}
};
struct MaxHalf {
inline __device__ half operator()(const half& a, const half& b) const {
#if __CUDA_ARCH__ >= 530
return __hgt(a, b) ? a : b;
#else
return __half2float(a) > __half2float(b) ? a : b;
#endif
}
};
} // namespace cub
namespace dragon {
template <typename T>
......
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_UTILS_DEVICE_COMMON_NCCL_H_
#define DRAGON_UTILS_DEVICE_COMMON_NCCL_H_
......
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_UTILS_DEVICE_COMMON_THRUST_H_
#define DRAGON_UTILS_DEVICE_COMMON_THRUST_H_
......
......@@ -50,6 +50,31 @@ DEFINE_COPY_FUNC(float);
DEFINE_COPY_FUNC(double);
#undef DEFINE_COPY_FUNC
#define DEFINE_COPY_FUNC(T) \
template <> \
DRAGON_API void Copy<T, CPUContext>( \
const int n, \
const int incx, \
const int incy, \
const T* x, \
T* y, \
CPUContext* ctx) { \
if (x != y && n > 0) { \
EigenStridedVectorMap<T>(y, 1, n, EigenInnerStride(incy)) = \
ConstEigenStridedVectorMap<T>(x, 1, n, EigenInnerStride(incx)); \
} \
}
DEFINE_COPY_FUNC(bool);
DEFINE_COPY_FUNC(int8_t);
DEFINE_COPY_FUNC(uint8_t);
DEFINE_COPY_FUNC(int);
DEFINE_COPY_FUNC(int64_t);
DEFINE_COPY_FUNC(float16);
DEFINE_COPY_FUNC(float);
DEFINE_COPY_FUNC(double);
#undef DEFINE_COPY_FUNC
template <>
DRAGON_API void Axpy<float16, CPUContext>(
const int n,
......
......@@ -18,6 +18,14 @@ __global__ void _Scale(const int n, const T alpha, const T* x, T* y) {
}
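// Strided elementwise copy: y[i * incy] = x[i * incx] for i in [0, n).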
template <typename T>
__global__ void
_Copy(const int n, const int incx, const int incy, const T* x, T* y) {
CUDA_1D_KERNEL_LOOP(i, n) {
y[i * incy] = x[i * incx];
}
}
template <typename T>
__global__ void _Axpy(const int n, const T alpha, const T* x, T* y) {
CUDA_1D_KERNEL_LOOP(i, n) {
y[i] += (alpha * x[i]);
......@@ -200,6 +208,47 @@ DEFINE_COPY_FUNC(float);
DEFINE_COPY_FUNC(double);
#undef DEFINE_COPY_FUNC
#define DEFINE_COPY_FUNC(T) \
template <> \
DRAGON_API void Copy<T, CUDAContext>( \
const int n, \
const int incx, \
const int incy, \
const T* x, \
T* y, \
CUDAContext* ctx) { \
if (x != y && n > 0) { \
_Copy<<<CUDA_BLOCKS(n), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
n, incx, incy, x, y); \
} \
}
DEFINE_COPY_FUNC(bool);
DEFINE_COPY_FUNC(int8_t);
DEFINE_COPY_FUNC(uint8_t);
DEFINE_COPY_FUNC(int);
DEFINE_COPY_FUNC(int64_t);
DEFINE_COPY_FUNC(float16);
#undef DEFINE_COPY_FUNC
#define DEFINE_COPY_FUNC(T, cublas_func) \
template <> \
DRAGON_API void Copy<T, CUDAContext>( \
const int n, \
const int incx, \
const int incy, \
const T* x, \
T* y, \
CUDAContext* ctx) { \
if (x != y && n > 0) { \
cublas_func(ctx->cublas_handle(), n, x, incx, y, incy); \
} \
}
DEFINE_COPY_FUNC(float, cublasScopy);
DEFINE_COPY_FUNC(double, cublasDcopy);
#undef DEFINE_COPY_FUNC
#define DEFINE_AXPY_FUNC(T) \
template <> \
DRAGON_API void Axpy<T, CUDAContext>( \
......
......@@ -32,6 +32,15 @@ template <typename T, class Context>
DRAGON_API void Copy(const int n, const T* x, T* y, Context* ctx);
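// Strided variant mirroring BLAS *copy semantics: y[i * incy] = x[i * incx].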
template <typename T, class Context>
DRAGON_API void Copy(
const int n,
const int incx,
const int incy,
const T* x,
T* y,
Context* ctx);
template <typename T, class Context>
DRAGON_API void
Axpy(const int n, const float alpha, const T* x, T* y, Context* ctx);
......
......@@ -4,6 +4,7 @@
#include "dragon/utils/math/blas.h"
#include "dragon/utils/math/broadcast.h"
#include "dragon/utils/math/elementwise.h"
#include "dragon/utils/math/functional.h"
#include "dragon/utils/math/utils.h"
namespace dragon {
......@@ -12,116 +13,6 @@ namespace math {
namespace {
/*!
* Op Wrappers
*/
#define DEFINE_BINARY_OPERATOR(name, TOut, expr) \
template <typename T> \
struct name##Op { \
inline __device__ TOut operator()(const T& a, const T& b) const { \
return a expr b; \
} \
}
DEFINE_BINARY_OPERATOR(Add, T, +);
DEFINE_BINARY_OPERATOR(Sub, T, -);
DEFINE_BINARY_OPERATOR(Mul, T, *);
DEFINE_BINARY_OPERATOR(Div, T, /);
DEFINE_BINARY_OPERATOR(Equal, bool, ==);
DEFINE_BINARY_OPERATOR(NotEqual, bool, !=);
DEFINE_BINARY_OPERATOR(Less, bool, <);
DEFINE_BINARY_OPERATOR(LessEqual, bool, <=);
DEFINE_BINARY_OPERATOR(Greater, bool, >);
DEFINE_BINARY_OPERATOR(GreaterEqual, bool, >=);
#undef DEFINE_BINARY_OPERATOR
#define DEFINE_BINARY_OPERATOR(name, func) \
template <typename T> \
struct name##Op { \
inline __device__ T operator()(const T& a, const T& b) const { \
return func(a, b); \
} \
}
DEFINE_BINARY_OPERATOR(Pow, pow);
DEFINE_BINARY_OPERATOR(Min, min);
DEFINE_BINARY_OPERATOR(Max, max);
#undef DEFINE_BINARY_OPERATOR
#define DEFINE_BINARY_OPERATOR(name, TOut, func) \
template <typename T> \
struct name##Op { \
inline __device__ TOut operator()(const T& a, const T& b) const { \
return func(a, b); \
} \
}
#if __CUDA_ARCH__ >= 530
DEFINE_BINARY_OPERATOR(AddHalf, T, __hadd);
DEFINE_BINARY_OPERATOR(SubHalf, T, __hsub);
DEFINE_BINARY_OPERATOR(MulHalf, T, __hmul);
DEFINE_BINARY_OPERATOR(DivHalf, T, __hdiv);
DEFINE_BINARY_OPERATOR(EqualHalf, T, __heq);
DEFINE_BINARY_OPERATOR(NotEqualHalf, T, __hne);
DEFINE_BINARY_OPERATOR(LessHalf, T, __hlt);
DEFINE_BINARY_OPERATOR(LessEqualHalf, T, __hle);
DEFINE_BINARY_OPERATOR(GreaterHalf, T, __hgt);
DEFINE_BINARY_OPERATOR(GreaterEqualHalf, T, __hge);
#endif
#undef DEFINE_BINARY_OPERATOR
#define DEFINE_BINARY_OPERATOR(name, expr) \
template <typename T> \
struct name##Op { \
inline __device__ T operator()(const T& a, const T& b) const { \
return __float2half(__half2float(a) expr __half2float(b)); \
} \
}
#if __CUDA_ARCH__ < 530
DEFINE_BINARY_OPERATOR(AddHalf, +);
DEFINE_BINARY_OPERATOR(SubHalf, -);
DEFINE_BINARY_OPERATOR(MulHalf, *);
DEFINE_BINARY_OPERATOR(DivHalf, /);
#endif
#undef DEFINE_BINARY_OPERATOR
#define DEFINE_BINARY_OPERATOR(name, func) \
template <typename T> \
struct name##Op { \
inline __device__ T operator()(const T& a, const T& b) const { \
return __float2half(func(__half2float(a), __half2float(b))); \
} \
}
DEFINE_BINARY_OPERATOR(PowHalf, pow);
DEFINE_BINARY_OPERATOR(MinHalf, min);
DEFINE_BINARY_OPERATOR(MaxHalf, max);
#undef DEFINE_BINARY_OPERATOR
#define DEFINE_BINARY_OPERATOR(name, expr) \
template <typename T> \
struct name##Op { \
inline __device__ bool operator()(const T& a, const T& b) const { \
return __half2float(a) expr __half2float(b); \
} \
}
#if __CUDA_ARCH__ < 530
DEFINE_BINARY_OPERATOR(EqualHalf, ==);
DEFINE_BINARY_OPERATOR(NotEqualHalf, !=);
DEFINE_BINARY_OPERATOR(LessHalf, <);
DEFINE_BINARY_OPERATOR(LessEqualHalf, <=);
DEFINE_BINARY_OPERATOR(GreaterHalf, >);
DEFINE_BINARY_OPERATOR(GreaterEqualHalf, >=);
#endif
#undef DEFINE_BINARY_OPERATOR
/*!
* Op Kernels
*/
template <typename T>
__global__ void _RowwiseSet(const int n, const int cols, const T* x, T* y) {
CUDA_1D_KERNEL_LOOP(i, n) {
......@@ -349,7 +240,7 @@ DEFINE_SET_FUNC(float16, half);
DEFINE_SET_FUNC(double, double);
#undef DEFINE_SET_FUNC
#define DEFINE_BINARY_FUNC(name, TIn, TOut, Op) \
#define DEFINE_BINARY_FUNC(name, TIn, TOut, Functor) \
template <> \
DRAGON_API void name<TIn, CUDAContext>( \
const int a_ndim, \
......@@ -376,13 +267,13 @@ DEFINE_SET_FUNC(double, double);
A_dims, B_dims, &rows, &cols, &broadcast_1st)) { \
const auto nthreads = rows * cols; \
if (broadcast_1st > 0) { \
_RowwiseBinaryFunc<TIn, TOut, Op<TIn>, true> \
_RowwiseBinaryFunc<TIn, TOut, Functor<TIn>, true> \
<<<CUDA_BLOCKS(nthreads), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
nthreads, cols, Op<TIn>(), a, b, y); \
nthreads, cols, Functor<TIn>(), a, b, y); \
} else { \
_RowwiseBinaryFunc<TIn, TOut, Op<TIn>, false> \
_RowwiseBinaryFunc<TIn, TOut, Functor<TIn>, false> \
<<<CUDA_BLOCKS(nthreads), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
nthreads, cols, Op<TIn>(), a, b, y); \
nthreads, cols, Functor<TIn>(), a, b, y); \
} \
return; \
} \
......@@ -390,13 +281,13 @@ DEFINE_SET_FUNC(double, double);
A_dims, B_dims, &rows, &cols, &broadcast_1st)) { \
const auto nthreads = rows * cols; \
if (broadcast_1st > 0) { \
_ColwiseBinaryFunc<TIn, TOut, Op<TIn>, true> \
_ColwiseBinaryFunc<TIn, TOut, Functor<TIn>, true> \
<<<CUDA_BLOCKS(nthreads), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
nthreads, cols, Op<TIn>(), a, b, y); \
nthreads, cols, Functor<TIn>(), a, b, y); \
} else { \
_ColwiseBinaryFunc<TIn, TOut, Op<TIn>, false> \
_ColwiseBinaryFunc<TIn, TOut, Functor<TIn>, false> \
<<<CUDA_BLOCKS(nthreads), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
nthreads, cols, Op<TIn>(), a, b, y); \
nthreads, cols, Functor<TIn>(), a, b, y); \
} \
return; \
} \
......@@ -412,93 +303,93 @@ DEFINE_SET_FUNC(double, double);
b_strides.data[i] = B_broadcast_strides[i]; \
y_dims.data[i] = Y_dims[i]; \
} \
_BroadcastBinaryFunc<TIn, TOut, Op<TIn>, CUDA_TENSOR_MAX_DIMS> \
_BroadcastBinaryFunc<TIn, TOut, Functor<TIn>, CUDA_TENSOR_MAX_DIMS> \
<<<CUDA_BLOCKS(nthreads), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
nthreads, \
Y_dims.size(), \
a_strides, \
b_strides, \
y_dims, \
Op<TIn>(), \
Functor<TIn>(), \
a, \
b, \
y); \
}
DEFINE_BINARY_FUNC(Add, int8_t, int8_t, AddOp);
DEFINE_BINARY_FUNC(Add, uint8_t, uint8_t, AddOp);
DEFINE_BINARY_FUNC(Add, int, int, AddOp);
DEFINE_BINARY_FUNC(Add, int64_t, int64_t, AddOp);
DEFINE_BINARY_FUNC(Add, float, float, AddOp);
DEFINE_BINARY_FUNC(Add, double, double, AddOp);
DEFINE_BINARY_FUNC(Sub, int8_t, int8_t, SubOp);
DEFINE_BINARY_FUNC(Sub, uint8_t, uint8_t, SubOp);
DEFINE_BINARY_FUNC(Sub, int, int, SubOp);
DEFINE_BINARY_FUNC(Sub, int64_t, int64_t, SubOp);
DEFINE_BINARY_FUNC(Sub, float, float, SubOp);
DEFINE_BINARY_FUNC(Sub, double, double, SubOp);
DEFINE_BINARY_FUNC(Mul, int8_t, int8_t, MulOp);
DEFINE_BINARY_FUNC(Mul, uint8_t, uint8_t, MulOp);
DEFINE_BINARY_FUNC(Mul, int, int, MulOp);
DEFINE_BINARY_FUNC(Mul, int64_t, int64_t, MulOp);
DEFINE_BINARY_FUNC(Mul, float, float, MulOp);
DEFINE_BINARY_FUNC(Mul, double, double, MulOp);
DEFINE_BINARY_FUNC(Div, int8_t, int8_t, DivOp);
DEFINE_BINARY_FUNC(Div, uint8_t, uint8_t, DivOp);
DEFINE_BINARY_FUNC(Div, int, int, DivOp);
DEFINE_BINARY_FUNC(Div, int64_t, int64_t, DivOp);
DEFINE_BINARY_FUNC(Div, float, float, DivOp);
DEFINE_BINARY_FUNC(Div, double, double, DivOp);
DEFINE_BINARY_FUNC(Pow, float, float, PowOp);
DEFINE_BINARY_FUNC(Pow, double, double, PowOp);
DEFINE_BINARY_FUNC(Minimum, int8_t, int8_t, MinOp);
DEFINE_BINARY_FUNC(Minimum, uint8_t, uint8_t, MinOp);
DEFINE_BINARY_FUNC(Minimum, int, int, MinOp);
DEFINE_BINARY_FUNC(Minimum, int64_t, int64_t, MinOp);
DEFINE_BINARY_FUNC(Minimum, float, float, MinOp);
DEFINE_BINARY_FUNC(Minimum, double, double, MinOp);
DEFINE_BINARY_FUNC(Maximum, int8_t, int8_t, MaxOp);
DEFINE_BINARY_FUNC(Maximum, uint8_t, uint8_t, MaxOp);
DEFINE_BINARY_FUNC(Maximum, int, int, MaxOp);
DEFINE_BINARY_FUNC(Maximum, int64_t, int64_t, MaxOp);
DEFINE_BINARY_FUNC(Maximum, float, float, MaxOp);
DEFINE_BINARY_FUNC(Maximum, double, double, MaxOp);
DEFINE_BINARY_FUNC(Equal, int8_t, bool, EqualOp);
DEFINE_BINARY_FUNC(Equal, uint8_t, bool, EqualOp);
DEFINE_BINARY_FUNC(Equal, int, bool, EqualOp);
DEFINE_BINARY_FUNC(Equal, int64_t, bool, EqualOp);
DEFINE_BINARY_FUNC(Equal, float, bool, EqualOp);
DEFINE_BINARY_FUNC(Equal, double, bool, EqualOp);
DEFINE_BINARY_FUNC(NotEqual, int8_t, bool, NotEqualOp);
DEFINE_BINARY_FUNC(NotEqual, uint8_t, bool, NotEqualOp);
DEFINE_BINARY_FUNC(NotEqual, int, bool, NotEqualOp);
DEFINE_BINARY_FUNC(NotEqual, int64_t, bool, NotEqualOp);
DEFINE_BINARY_FUNC(NotEqual, float, bool, NotEqualOp);
DEFINE_BINARY_FUNC(NotEqual, double, bool, NotEqualOp);
DEFINE_BINARY_FUNC(Less, int8_t, bool, LessOp);
DEFINE_BINARY_FUNC(Less, uint8_t, bool, LessOp);
DEFINE_BINARY_FUNC(Less, int, bool, LessOp);
DEFINE_BINARY_FUNC(Less, int64_t, bool, LessOp);
DEFINE_BINARY_FUNC(Less, float, bool, LessOp);
DEFINE_BINARY_FUNC(Less, double, bool, LessOp);
DEFINE_BINARY_FUNC(LessEqual, int8_t, bool, LessEqualOp);
DEFINE_BINARY_FUNC(LessEqual, uint8_t, bool, LessEqualOp);
DEFINE_BINARY_FUNC(LessEqual, int, bool, LessEqualOp);
DEFINE_BINARY_FUNC(LessEqual, int64_t, bool, LessEqualOp);
DEFINE_BINARY_FUNC(LessEqual, float, bool, LessEqualOp);
DEFINE_BINARY_FUNC(LessEqual, double, bool, LessEqualOp);
DEFINE_BINARY_FUNC(Greater, int8_t, bool, GreaterOp);
DEFINE_BINARY_FUNC(Greater, uint8_t, bool, GreaterOp);
DEFINE_BINARY_FUNC(Greater, int, bool, GreaterOp);
DEFINE_BINARY_FUNC(Greater, int64_t, bool, GreaterOp);
DEFINE_BINARY_FUNC(Greater, float, bool, GreaterOp);
DEFINE_BINARY_FUNC(Greater, double, bool, GreaterOp);
DEFINE_BINARY_FUNC(GreaterEqual, int8_t, bool, GreaterEqualOp);
DEFINE_BINARY_FUNC(GreaterEqual, uint8_t, bool, GreaterEqualOp);
DEFINE_BINARY_FUNC(GreaterEqual, int, bool, GreaterEqualOp);
DEFINE_BINARY_FUNC(GreaterEqual, int64_t, bool, GreaterEqualOp);
DEFINE_BINARY_FUNC(GreaterEqual, float, bool, GreaterEqualOp);
DEFINE_BINARY_FUNC(GreaterEqual, double, bool, GreaterEqualOp);
DEFINE_BINARY_FUNC(Add, int8_t, int8_t, math::PlusFunctor);
DEFINE_BINARY_FUNC(Add, uint8_t, uint8_t, math::PlusFunctor);
DEFINE_BINARY_FUNC(Add, int, int, math::PlusFunctor);
DEFINE_BINARY_FUNC(Add, int64_t, int64_t, math::PlusFunctor);
DEFINE_BINARY_FUNC(Add, float, float, math::PlusFunctor);
DEFINE_BINARY_FUNC(Add, double, double, math::PlusFunctor);
DEFINE_BINARY_FUNC(Sub, int8_t, int8_t, math::MinusFunctor);
DEFINE_BINARY_FUNC(Sub, uint8_t, uint8_t, math::MinusFunctor);
DEFINE_BINARY_FUNC(Sub, int, int, math::MinusFunctor);
DEFINE_BINARY_FUNC(Sub, int64_t, int64_t, math::MinusFunctor);
DEFINE_BINARY_FUNC(Sub, float, float, math::MinusFunctor);
DEFINE_BINARY_FUNC(Sub, double, double, math::MinusFunctor);
DEFINE_BINARY_FUNC(Mul, int8_t, int8_t, math::MultipliesFunctor);
DEFINE_BINARY_FUNC(Mul, uint8_t, uint8_t, math::MultipliesFunctor);
DEFINE_BINARY_FUNC(Mul, int, int, math::MultipliesFunctor);
DEFINE_BINARY_FUNC(Mul, int64_t, int64_t, math::MultipliesFunctor);
DEFINE_BINARY_FUNC(Mul, float, float, math::MultipliesFunctor);
DEFINE_BINARY_FUNC(Mul, double, double, math::MultipliesFunctor);
DEFINE_BINARY_FUNC(Div, int8_t, int8_t, math::DividesFunctor);
DEFINE_BINARY_FUNC(Div, uint8_t, uint8_t, math::DividesFunctor);
DEFINE_BINARY_FUNC(Div, int, int, math::DividesFunctor);
DEFINE_BINARY_FUNC(Div, int64_t, int64_t, math::DividesFunctor);
DEFINE_BINARY_FUNC(Div, float, float, math::DividesFunctor);
DEFINE_BINARY_FUNC(Div, double, double, math::DividesFunctor);
DEFINE_BINARY_FUNC(Pow, float, float, math::PowFunctor);
DEFINE_BINARY_FUNC(Pow, double, double, math::PowFunctor);
DEFINE_BINARY_FUNC(Minimum, int8_t, int8_t, math::MinFunctor);
DEFINE_BINARY_FUNC(Minimum, uint8_t, uint8_t, math::MinFunctor);
DEFINE_BINARY_FUNC(Minimum, int, int, math::MinFunctor);
DEFINE_BINARY_FUNC(Minimum, int64_t, int64_t, math::MinFunctor);
DEFINE_BINARY_FUNC(Minimum, float, float, math::MinFunctor);
DEFINE_BINARY_FUNC(Minimum, double, double, math::MinFunctor);
DEFINE_BINARY_FUNC(Maximum, int8_t, int8_t, math::MaxFunctor);
DEFINE_BINARY_FUNC(Maximum, uint8_t, uint8_t, math::MaxFunctor);
DEFINE_BINARY_FUNC(Maximum, int, int, math::MaxFunctor);
DEFINE_BINARY_FUNC(Maximum, int64_t, int64_t, math::MaxFunctor);
DEFINE_BINARY_FUNC(Maximum, float, float, math::MaxFunctor);
DEFINE_BINARY_FUNC(Maximum, double, double, math::MaxFunctor);
DEFINE_BINARY_FUNC(Equal, int8_t, bool, math::EqualFunctor);
DEFINE_BINARY_FUNC(Equal, uint8_t, bool, math::EqualFunctor);
DEFINE_BINARY_FUNC(Equal, int, bool, math::EqualFunctor);
DEFINE_BINARY_FUNC(Equal, int64_t, bool, math::EqualFunctor);
DEFINE_BINARY_FUNC(Equal, float, bool, math::EqualFunctor);
DEFINE_BINARY_FUNC(Equal, double, bool, math::EqualFunctor);
DEFINE_BINARY_FUNC(NotEqual, int8_t, bool, math::NotEqualFunctor);
DEFINE_BINARY_FUNC(NotEqual, uint8_t, bool, math::NotEqualFunctor);
DEFINE_BINARY_FUNC(NotEqual, int, bool, math::NotEqualFunctor);
DEFINE_BINARY_FUNC(NotEqual, int64_t, bool, math::NotEqualFunctor);
DEFINE_BINARY_FUNC(NotEqual, float, bool, math::NotEqualFunctor);
DEFINE_BINARY_FUNC(NotEqual, double, bool, math::NotEqualFunctor);
DEFINE_BINARY_FUNC(Less, int8_t, bool, math::LessFunctor);
DEFINE_BINARY_FUNC(Less, uint8_t, bool, math::LessFunctor);
DEFINE_BINARY_FUNC(Less, int, bool, math::LessFunctor);
DEFINE_BINARY_FUNC(Less, int64_t, bool, math::LessFunctor);
DEFINE_BINARY_FUNC(Less, float, bool, math::LessFunctor);
DEFINE_BINARY_FUNC(Less, double, bool, math::LessFunctor);
DEFINE_BINARY_FUNC(LessEqual, int8_t, bool, math::LessEqualFunctor);
DEFINE_BINARY_FUNC(LessEqual, uint8_t, bool, math::LessEqualFunctor);
DEFINE_BINARY_FUNC(LessEqual, int, bool, math::LessEqualFunctor);
DEFINE_BINARY_FUNC(LessEqual, int64_t, bool, math::LessEqualFunctor);
DEFINE_BINARY_FUNC(LessEqual, float, bool, math::LessEqualFunctor);
DEFINE_BINARY_FUNC(LessEqual, double, bool, math::LessEqualFunctor);
DEFINE_BINARY_FUNC(Greater, int8_t, bool, math::GreaterFunctor);
DEFINE_BINARY_FUNC(Greater, uint8_t, bool, math::GreaterFunctor);
DEFINE_BINARY_FUNC(Greater, int, bool, math::GreaterFunctor);
DEFINE_BINARY_FUNC(Greater, int64_t, bool, math::GreaterFunctor);
DEFINE_BINARY_FUNC(Greater, float, bool, math::GreaterFunctor);
DEFINE_BINARY_FUNC(Greater, double, bool, math::GreaterFunctor);
DEFINE_BINARY_FUNC(GreaterEqual, int8_t, bool, math::GreaterEqualFunctor);
DEFINE_BINARY_FUNC(GreaterEqual, uint8_t, bool, math::GreaterEqualFunctor);
DEFINE_BINARY_FUNC(GreaterEqual, int, bool, math::GreaterEqualFunctor);
DEFINE_BINARY_FUNC(GreaterEqual, int64_t, bool, math::GreaterEqualFunctor);
DEFINE_BINARY_FUNC(GreaterEqual, float, bool, math::GreaterEqualFunctor);
DEFINE_BINARY_FUNC(GreaterEqual, double, bool, math::GreaterEqualFunctor);
#undef DEFINE_BINARY_FUNC
#define DEFINE_BINARY_FUNC(name, T, dtype) \
......@@ -528,7 +419,7 @@ DEFINE_BINARY_FUNC(Sub, bool, uint8_t); // Xor
DEFINE_BINARY_FUNC(Mul, bool, uint8_t); // And
#undef DEFINE_BINARY_FUNC
#define DEFINE_BINARY_FUNC(name, TOut1, TOut2, Op) \
#define DEFINE_BINARY_FUNC(name, TOut1, TOut2, Functor) \
template <> \
DRAGON_API void name<float16, CUDAContext>( \
const int a_ndim, \
......@@ -555,20 +446,20 @@ DEFINE_BINARY_FUNC(Mul, bool, uint8_t); // And
A_dims, B_dims, &rows, &cols, &broadcast_1st)) { \
auto nthreads = rows * cols; \
if (broadcast_1st > 0) { \
_RowwiseBinaryFunc<half, TOut2, Op<half>, true> \
_RowwiseBinaryFunc<half, TOut2, Functor<half>, true> \
<<<CUDA_BLOCKS(nthreads), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
nthreads, \
cols, \
Op<half>(), \
Functor<half>(), \
reinterpret_cast<const half*>(a), \
reinterpret_cast<const half*>(b), \
reinterpret_cast<TOut2*>(y)); \
} else { \
_RowwiseBinaryFunc<half, TOut2, Op<half>, false> \
_RowwiseBinaryFunc<half, TOut2, Functor<half>, false> \
<<<CUDA_BLOCKS(nthreads), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
nthreads, \
cols, \
Op<half>(), \
Functor<half>(), \
reinterpret_cast<const half*>(a), \
reinterpret_cast<const half*>(b), \
reinterpret_cast<TOut2*>(y)); \
......@@ -579,20 +470,20 @@ DEFINE_BINARY_FUNC(Mul, bool, uint8_t); // And
A_dims, B_dims, &rows, &cols, &broadcast_1st)) { \
auto nthreads = rows * cols; \
if (broadcast_1st > 0) { \
_ColwiseBinaryFunc<half, TOut2, Op<half>, true> \
_ColwiseBinaryFunc<half, TOut2, Functor<half>, true> \
<<<CUDA_BLOCKS(nthreads), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
nthreads, \
cols, \
Op<half>(), \
Functor<half>(), \
reinterpret_cast<const half*>(a), \
reinterpret_cast<const half*>(b), \
reinterpret_cast<TOut2*>(y)); \
} else { \
_ColwiseBinaryFunc<half, TOut2, Op<half>, false> \
_ColwiseBinaryFunc<half, TOut2, Functor<half>, false> \
<<<CUDA_BLOCKS(nthreads), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
nthreads, \
cols, \
Op<half>(), \
Functor<half>(), \
reinterpret_cast<const half*>(a), \
reinterpret_cast<const half*>(b), \
reinterpret_cast<TOut2*>(y)); \
......@@ -611,32 +502,32 @@ DEFINE_BINARY_FUNC(Mul, bool, uint8_t); // And
b_strides.data[i] = B_broadcast_strides[i]; \
y_dims.data[i] = Y_dims[i]; \
} \
_BroadcastBinaryFunc<half, TOut2, Op<half>, CUDA_TENSOR_MAX_DIMS> \
_BroadcastBinaryFunc<half, TOut2, Functor<half>, CUDA_TENSOR_MAX_DIMS> \
<<<CUDA_BLOCKS(nthreads), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
nthreads, \
Y_dims.size(), \
a_strides, \
b_strides, \
y_dims, \
Op<half>(), \
Functor<half>(), \
reinterpret_cast<const half*>(a), \
reinterpret_cast<const half*>(b), \
reinterpret_cast<TOut2*>(y)); \
}
DEFINE_BINARY_FUNC(Add, float16, half, AddHalfOp);
DEFINE_BINARY_FUNC(Sub, float16, half, SubHalfOp);
DEFINE_BINARY_FUNC(Mul, float16, half, MulHalfOp);
DEFINE_BINARY_FUNC(Div, float16, half, DivHalfOp);
DEFINE_BINARY_FUNC(Pow, float16, half, PowHalfOp);
DEFINE_BINARY_FUNC(Minimum, float16, half, MinHalfOp);
DEFINE_BINARY_FUNC(Maximum, float16, half, MaxHalfOp);
DEFINE_BINARY_FUNC(Equal, bool, bool, EqualHalfOp);
DEFINE_BINARY_FUNC(NotEqual, bool, bool, NotEqualHalfOp);
DEFINE_BINARY_FUNC(Less, bool, bool, LessHalfOp);
DEFINE_BINARY_FUNC(LessEqual, bool, bool, LessEqualHalfOp);
DEFINE_BINARY_FUNC(Greater, bool, bool, GreaterHalfOp);
DEFINE_BINARY_FUNC(GreaterEqual, bool, bool, GreaterEqualHalfOp);
DEFINE_BINARY_FUNC(Add, float16, half, math::PlusFunctor);
DEFINE_BINARY_FUNC(Sub, float16, half, math::MinusFunctor);
DEFINE_BINARY_FUNC(Mul, float16, half, math::MultipliesFunctor);
DEFINE_BINARY_FUNC(Div, float16, half, math::DividesFunctor);
DEFINE_BINARY_FUNC(Pow, float16, half, math::PowFunctor);
DEFINE_BINARY_FUNC(Minimum, float16, half, math::MinFunctor);
DEFINE_BINARY_FUNC(Maximum, float16, half, math::MaxFunctor);
DEFINE_BINARY_FUNC(Equal, bool, bool, math::EqualFunctor);
DEFINE_BINARY_FUNC(NotEqual, bool, bool, math::NotEqualFunctor);
DEFINE_BINARY_FUNC(Less, bool, bool, math::LessFunctor);
DEFINE_BINARY_FUNC(LessEqual, bool, bool, math::LessEqualFunctor);
DEFINE_BINARY_FUNC(Greater, bool, bool, math::GreaterFunctor);
DEFINE_BINARY_FUNC(GreaterEqual, bool, bool, math::GreaterEqualFunctor);
#undef DEFINE_BINARY_FUNC
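A minimal sketch (hypothetical names, not the exact dragon kernel) of what the `Functor` parameter threads into these launches: the kernel is written once against any callable, and the specialization selected by `Functor<half>()` decides whether fp16 intrinsics are used.

```cpp
// Hypothetical minimal analogue of _SimpleBinaryFunc: the kernel only sees a
// callable, so swapping the old Op wrappers for math::*Functor changes no
// kernel code, only the functor type passed at the launch site.
template <typename TIn, typename TOut, class Functor>
__global__ void SimpleBinaryKernelSketch(
    const int n, const Functor op, const TIn* a, const TIn* b, TOut* y) {
  const int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) y[i] = op(a[i], b[i]);  // functor encapsulates fp16 dispatch
}
```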
#define DEFINE_WHERE_FUNC(T1, T2) \
......
......@@ -102,6 +102,29 @@ DEFINE_UNARY_FUNC(Sign, double, [](double x) {
});
#undef DEFINE_UNARY_FUNC
#define DEFINE_NEG_FUNC(T) \
template <> \
DRAGON_API void Neg<T, CPUContext>( \
const int n, const T* x, T* y, CPUContext* ctx) { \
EigenVectorArrayMap<T>(y, n) = -ConstEigenVectorArrayMap<T>(x, n); \
}
template <>
DRAGON_API void Neg<float16, CPUContext>(
const int n,
const float16* x,
float16* y,
CPUContext* ctx) {
CPU_FP16_NOT_SUPPORTED;
}
DEFINE_NEG_FUNC(int8_t);
DEFINE_NEG_FUNC(int);
DEFINE_NEG_FUNC(int64_t);
DEFINE_NEG_FUNC(float);
DEFINE_NEG_FUNC(double);
#undef DEFINE_NEG_FUNC
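The CPU path above leans on Eigen's expression templates. A standalone sketch of the same mapping, assuming Eigen is on the include path (`NegSketch` is illustrative, not a dragon symbol):

```cpp
#include <Eigen/Core>

// Illustrative analogue of DEFINE_NEG_FUNC: map the raw buffers as 1-D arrays
// and let Eigen vectorize the element-wise negation.
void NegSketch(const int n, const float* x, float* y) {
  Eigen::Map<Eigen::ArrayXf>(y, n) = -Eigen::Map<const Eigen::ArrayXf>(x, n);
}
```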
/* y = value */
#define DEFINE_SET_FUNC(T) \
......
......@@ -3,6 +3,7 @@
#include "dragon/core/context_cuda.h"
#include "dragon/utils/cast.h"
#include "dragon/utils/math/elementwise.h"
#include "dragon/utils/math/functional.h"
#include "dragon/utils/math/utils.h"
namespace dragon {
......@@ -12,75 +13,79 @@ namespace math {
namespace {
/*!
* UnaryOp Wrappers
* Unary Functors
*/
#define DEFINE_UNARY_OPERATOR(name, func) \
#define DEFINE_UNARY_FUNCTOR(name, func) \
template <typename T> \
struct name##Op { \
struct name##Functor { \
inline __device__ T operator()(const T& x) const { \
return func(x); \
} \
}
DEFINE_UNARY_OPERATOR(Ceil, ceil);
DEFINE_UNARY_OPERATOR(Cos, cos);
DEFINE_UNARY_OPERATOR(Exp, exp);
DEFINE_UNARY_OPERATOR(Floor, floor);
DEFINE_UNARY_OPERATOR(Log, log);
DEFINE_UNARY_OPERATOR(Round, round);
DEFINE_UNARY_OPERATOR(Rsqrt, rsqrt);
DEFINE_UNARY_OPERATOR(Sin, sin);
DEFINE_UNARY_OPERATOR(Sqrt, sqrt);
DEFINE_UNARY_FUNCTOR(Neg, -);
DEFINE_UNARY_FUNCTOR(Ceil, ceil);
DEFINE_UNARY_FUNCTOR(Cos, cos);
DEFINE_UNARY_FUNCTOR(Exp, exp);
DEFINE_UNARY_FUNCTOR(Floor, floor);
DEFINE_UNARY_FUNCTOR(Log, log);
DEFINE_UNARY_FUNCTOR(Round, round);
DEFINE_UNARY_FUNCTOR(Rsqrt, rsqrt);
DEFINE_UNARY_FUNCTOR(Sin, sin);
DEFINE_UNARY_FUNCTOR(Sqrt, sqrt);
#if __CUDA_ARCH__ >= 530
DEFINE_UNARY_OPERATOR(CeilHalf, hceil);
DEFINE_UNARY_OPERATOR(CeilHalf2, h2ceil);
DEFINE_UNARY_OPERATOR(CosHalf, hcos);
DEFINE_UNARY_OPERATOR(CosHalf2, h2cos);
DEFINE_UNARY_OPERATOR(ExpHalf, hexp);
DEFINE_UNARY_OPERATOR(ExpHalf2, h2exp);
DEFINE_UNARY_OPERATOR(FloorHalf, hfloor);
DEFINE_UNARY_OPERATOR(FloorHalf2, h2floor);
DEFINE_UNARY_OPERATOR(InvHalf, hrcp);
DEFINE_UNARY_OPERATOR(InvHalf2, h2rcp);
DEFINE_UNARY_OPERATOR(LogHalf, hlog);
DEFINE_UNARY_OPERATOR(LogHalf2, h2log);
DEFINE_UNARY_OPERATOR(RoundHalf, hrint);
DEFINE_UNARY_OPERATOR(RoundHalf2, h2rint);
DEFINE_UNARY_OPERATOR(RsqrtHalf, hrsqrt);
DEFINE_UNARY_OPERATOR(RsqrtHalf2, h2rsqrt);
DEFINE_UNARY_OPERATOR(SinHalf, hsin);
DEFINE_UNARY_OPERATOR(SinHalf2, h2sin);
DEFINE_UNARY_OPERATOR(SqrtHalf, hsqrt);
DEFINE_UNARY_OPERATOR(SqrtHalf2, h2sqrt);
DEFINE_UNARY_FUNCTOR(NegHalf, __hneg);
DEFINE_UNARY_FUNCTOR(NegHalf2, __hneg2);
DEFINE_UNARY_FUNCTOR(CeilHalf, hceil);
DEFINE_UNARY_FUNCTOR(CeilHalf2, h2ceil);
DEFINE_UNARY_FUNCTOR(CosHalf, hcos);
DEFINE_UNARY_FUNCTOR(CosHalf2, h2cos);
DEFINE_UNARY_FUNCTOR(ExpHalf, hexp);
DEFINE_UNARY_FUNCTOR(ExpHalf2, h2exp);
DEFINE_UNARY_FUNCTOR(FloorHalf, hfloor);
DEFINE_UNARY_FUNCTOR(FloorHalf2, h2floor);
DEFINE_UNARY_FUNCTOR(InvHalf, hrcp);
DEFINE_UNARY_FUNCTOR(InvHalf2, h2rcp);
DEFINE_UNARY_FUNCTOR(LogHalf, hlog);
DEFINE_UNARY_FUNCTOR(LogHalf2, h2log);
DEFINE_UNARY_FUNCTOR(RoundHalf, hrint);
DEFINE_UNARY_FUNCTOR(RoundHalf2, h2rint);
DEFINE_UNARY_FUNCTOR(RsqrtHalf, hrsqrt);
DEFINE_UNARY_FUNCTOR(RsqrtHalf2, h2rsqrt);
DEFINE_UNARY_FUNCTOR(SinHalf, hsin);
DEFINE_UNARY_FUNCTOR(SinHalf2, h2sin);
DEFINE_UNARY_FUNCTOR(SqrtHalf, hsqrt);
DEFINE_UNARY_FUNCTOR(SqrtHalf2, h2sqrt);
#endif
#undef DEFINE_UNARY_OPERATOR
#undef DEFINE_UNARY_FUNCTOR
#define DEFINE_UNARY_OPERATOR(name, func) \
#define DEFINE_UNARY_FUNCTOR(name, func) \
template <typename T> \
struct name##Op { \
struct name##Functor { \
inline __device__ T operator()(const T& x) const { \
return __float2half(func(__half2float(x))); \
} \
}
#if __CUDA_ARCH__ < 530
DEFINE_UNARY_OPERATOR(CeilHalf, ceil);
DEFINE_UNARY_OPERATOR(CosHalf, cos);
DEFINE_UNARY_OPERATOR(ExpHalf, exp);
DEFINE_UNARY_OPERATOR(FloorHalf, floor);
DEFINE_UNARY_OPERATOR(InvHalf, __frcp_rn);
DEFINE_UNARY_OPERATOR(LogHalf, log);
DEFINE_UNARY_OPERATOR(RoundHalf, round);
DEFINE_UNARY_OPERATOR(RsqrtHalf, rsqrt);
DEFINE_UNARY_OPERATOR(SinHalf, sin);
DEFINE_UNARY_OPERATOR(SqrtHalf, sqrt);
DEFINE_UNARY_FUNCTOR(NegHalf, -);
DEFINE_UNARY_FUNCTOR(CeilHalf, ceil);
DEFINE_UNARY_FUNCTOR(CosHalf, cos);
DEFINE_UNARY_FUNCTOR(ExpHalf, exp);
DEFINE_UNARY_FUNCTOR(FloorHalf, floor);
DEFINE_UNARY_FUNCTOR(InvHalf, __frcp_rn);
DEFINE_UNARY_FUNCTOR(LogHalf, log);
DEFINE_UNARY_FUNCTOR(RoundHalf, round);
DEFINE_UNARY_FUNCTOR(RsqrtHalf, rsqrt);
DEFINE_UNARY_FUNCTOR(SinHalf, sin);
DEFINE_UNARY_FUNCTOR(SqrtHalf, sqrt);
#endif
#undef DEFINE_UNARY_OPERATOR
#undef DEFINE_UNARY_FUNCTOR
#define DEFINE_UNARY_OPERATOR(name, func) \
#define DEFINE_UNARY_FUNCTOR(name, func) \
template <typename T> \
struct name##Op { \
struct name##Functor { \
inline __device__ T operator()(const T& x) const { \
const float2 val = __half22float2(x); \
return __floats2half2_rn(func(val.x), func(val.y)); \
......@@ -88,152 +93,22 @@ DEFINE_UNARY_OPERATOR(SqrtHalf, sqrt);
}
#if __CUDA_ARCH__ < 530
DEFINE_UNARY_OPERATOR(CeilHalf2, ceil);
DEFINE_UNARY_OPERATOR(CosHalf2, cos);
DEFINE_UNARY_OPERATOR(ExpHalf2, exp);
DEFINE_UNARY_OPERATOR(FloorHalf2, floor);
DEFINE_UNARY_OPERATOR(InvHalf2, __frcp_rn);
DEFINE_UNARY_OPERATOR(LogHalf2, log);
DEFINE_UNARY_OPERATOR(RoundHalf2, round);
DEFINE_UNARY_OPERATOR(RsqrtHalf2, rsqrt);
DEFINE_UNARY_OPERATOR(SinHalf2, sin);
DEFINE_UNARY_OPERATOR(SqrtHalf2, sqrt);
DEFINE_UNARY_FUNCTOR(NegHalf2, -);
DEFINE_UNARY_FUNCTOR(CeilHalf2, ceil);
DEFINE_UNARY_FUNCTOR(CosHalf2, cos);
DEFINE_UNARY_FUNCTOR(ExpHalf2, exp);
DEFINE_UNARY_FUNCTOR(FloorHalf2, floor);
DEFINE_UNARY_FUNCTOR(InvHalf2, __frcp_rn);
DEFINE_UNARY_FUNCTOR(LogHalf2, log);
DEFINE_UNARY_FUNCTOR(RoundHalf2, round);
DEFINE_UNARY_FUNCTOR(RsqrtHalf2, rsqrt);
DEFINE_UNARY_FUNCTOR(SinHalf2, sin);
DEFINE_UNARY_FUNCTOR(SqrtHalf2, sqrt);
#endif
#undef DEFINE_UNARY_OPERATOR
#define DEFINE_BINARY_OPERATOR(name, TOut, expr) \
template <typename T> \
struct name##Op { \
inline __device__ TOut operator()(const T& a, const T& b) const { \
return a expr b; \
} \
}
#undef DEFINE_UNARY_FUNCTOR
/*!
* BinaryOp Wrappers
*/
DEFINE_BINARY_OPERATOR(Add, T, +);
DEFINE_BINARY_OPERATOR(Sub, T, -);
DEFINE_BINARY_OPERATOR(Mul, T, *);
DEFINE_BINARY_OPERATOR(Div, T, /);
DEFINE_BINARY_OPERATOR(Equal, bool, ==);
DEFINE_BINARY_OPERATOR(NotEqual, bool, !=);
DEFINE_BINARY_OPERATOR(Less, bool, <);
DEFINE_BINARY_OPERATOR(LessEqual, bool, <=);
DEFINE_BINARY_OPERATOR(Greater, bool, >);
DEFINE_BINARY_OPERATOR(GreaterEqual, bool, >=);
#undef DEFINE_BINARY_OPERATOR
#define DEFINE_BINARY_OPERATOR(name, TOut, func) \
template <typename T> \
struct name##Op { \
inline __device__ TOut operator()(const T& a, const T& b) const { \
return func(a, b); \
} \
}
DEFINE_BINARY_OPERATOR(Pow, T, pow);
DEFINE_BINARY_OPERATOR(Min, T, min);
DEFINE_BINARY_OPERATOR(Max, T, max);
#if __CUDA_ARCH__ >= 530
DEFINE_BINARY_OPERATOR(AddHalf, T, __hadd);
DEFINE_BINARY_OPERATOR(AddHalf2, T, __hadd2);
DEFINE_BINARY_OPERATOR(SubHalf, T, __hsub);
DEFINE_BINARY_OPERATOR(SubHalf2, T, __hsub2);
DEFINE_BINARY_OPERATOR(MulHalf, T, __hmul);
DEFINE_BINARY_OPERATOR(MulHalf2, T, __hmul2);
DEFINE_BINARY_OPERATOR(DivHalf, T, __hdiv);
DEFINE_BINARY_OPERATOR(EqualHalf, bool, __heq);
DEFINE_BINARY_OPERATOR(NotEqualHalf, bool, __hne);
DEFINE_BINARY_OPERATOR(LessHalf, bool, __hlt);
DEFINE_BINARY_OPERATOR(LessEqualHalf, bool, __hle);
DEFINE_BINARY_OPERATOR(GreaterHalf, bool, __hgt);
DEFINE_BINARY_OPERATOR(GreaterEqualHalf, bool, __hge);
#endif
#undef DEFINE_BINARY_OPERATOR
#define DEFINE_BINARY_OPERATOR(name, expr) \
template <typename T> \
struct name##Op { \
inline __device__ T operator()(const T& a, const T& b) const { \
return __float2half(__half2float(a) expr __half2float(b)); \
} \
}
#if __CUDA_ARCH__ < 530
DEFINE_BINARY_OPERATOR(AddHalf, +);
DEFINE_BINARY_OPERATOR(SubHalf, -);
DEFINE_BINARY_OPERATOR(MulHalf, *);
DEFINE_BINARY_OPERATOR(DivHalf, /);
#endif
#undef DEFINE_BINARY_OPERATOR
#define DEFINE_BINARY_OPERATOR(name, expr) \
template <typename T> \
struct name##Op { \
inline __device__ bool operator()(const T& a, const T& b) const { \
return __half2float(a) expr __half2float(b); \
} \
}
#if __CUDA_ARCH__ < 530
DEFINE_BINARY_OPERATOR(EqualHalf, ==);
DEFINE_BINARY_OPERATOR(NotEqualHalf, !=);
DEFINE_BINARY_OPERATOR(LessHalf, <);
DEFINE_BINARY_OPERATOR(LessEqualHalf, <=);
DEFINE_BINARY_OPERATOR(GreaterHalf, >);
DEFINE_BINARY_OPERATOR(GreaterEqualHalf, >=);
#endif
#undef DEFINE_BINARY_OPERATOR
#define DEFINE_BINARY_OPERATOR(name, func) \
template <typename T> \
struct name##Op { \
inline __device__ T operator()(const T& a, const T& b) const { \
return __float2half(func(__half2float(a), __half2float(b))); \
} \
}
DEFINE_BINARY_OPERATOR(PowHalf, pow);
DEFINE_BINARY_OPERATOR(MinHalf, min);
DEFINE_BINARY_OPERATOR(MaxHalf, max);
#undef DEFINE_BINARY_OPERATOR
#define DEFINE_BINARY_OPERATOR(name, expr) \
template <typename T> \
struct name##Op { \
inline __device__ T operator()(const T& a, const T& b) const { \
const float2 v1 = __half22float2(a); \
const float2 v2 = __half22float2(b); \
return __floats2half2_rn(v1.x expr v2.x, v1.y expr v2.y); \
} \
}
#if __CUDA_ARCH__ < 530
DEFINE_BINARY_OPERATOR(AddHalf2, +);
DEFINE_BINARY_OPERATOR(SubHalf2, -);
DEFINE_BINARY_OPERATOR(MulHalf2, *);
#endif
#undef DEFINE_BINARY_OPERATOR
#define DEFINE_BINARY_OPERATOR(name, func) \
template <typename T> \
struct name##Op { \
inline __device__ T operator()(const T& a, const T& b) const { \
const float2 v1 = __half22float2(a); \
const float2 v2 = __half22float2(b); \
return __floats2half2_rn(func(v1.x, v2.x), func(v1.y, v2.y)); \
} \
}
DEFINE_BINARY_OPERATOR(PowHalf2, pow);
DEFINE_BINARY_OPERATOR(MinHalf2, min);
DEFINE_BINARY_OPERATOR(MaxHalf2, max);
#undef DEFINE_BINARY_OPERATOR
/*!
* UnaryOp Kernels
* Unary Function Kernels
*/
template <typename T, class Operator>
......@@ -471,7 +346,7 @@ _Bias(const int n, const T beta, const Operator op, const T* x, T* y) {
}
/*!
* BinaryOp Kernels
* Binary Function Kernels
*/
template <typename TIn, typename TOut, class Operator>
......@@ -498,32 +373,37 @@ _Where(const int n, const T* a, const T* b, const bool* c, T* y) {
/* ------------------- Launcher Separator ------------------- */
#define DEFINE_UNARY_FUNC(name, T, Op) \
#define DEFINE_UNARY_FUNC(name, T, Functor) \
template <> \
DRAGON_API void name<T, CUDAContext>( \
const int n, const T* x, T* y, CUDAContext* ctx) { \
_SimpleUnaryFunc<<<CUDA_BLOCKS(n), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
n, Op<T>(), x, y); \
}
DEFINE_UNARY_FUNC(Ceil, float, CeilOp);
DEFINE_UNARY_FUNC(Ceil, double, CeilOp);
DEFINE_UNARY_FUNC(Cos, float, CosOp);
DEFINE_UNARY_FUNC(Cos, double, CosOp);
DEFINE_UNARY_FUNC(Exp, float, ExpOp);
DEFINE_UNARY_FUNC(Exp, double, ExpOp);
DEFINE_UNARY_FUNC(Floor, float, FloorOp);
DEFINE_UNARY_FUNC(Floor, double, FloorOp);
DEFINE_UNARY_FUNC(Log, float, LogOp);
DEFINE_UNARY_FUNC(Log, double, LogOp);
DEFINE_UNARY_FUNC(Round, float, RoundOp);
DEFINE_UNARY_FUNC(Round, double, RoundOp);
DEFINE_UNARY_FUNC(Rsqrt, float, RsqrtOp);
DEFINE_UNARY_FUNC(Rsqrt, double, RsqrtOp);
DEFINE_UNARY_FUNC(Sin, float, SinOp);
DEFINE_UNARY_FUNC(Sin, double, SinOp);
DEFINE_UNARY_FUNC(Sqrt, float, SqrtOp);
DEFINE_UNARY_FUNC(Sqrt, double, SqrtOp);
n, Functor<T>(), x, y); \
}
DEFINE_UNARY_FUNC(Neg, int8_t, NegFunctor);
DEFINE_UNARY_FUNC(Neg, int, NegFunctor);
DEFINE_UNARY_FUNC(Neg, int64_t, NegFunctor);
DEFINE_UNARY_FUNC(Neg, float, NegFunctor);
DEFINE_UNARY_FUNC(Neg, double, NegFunctor);
DEFINE_UNARY_FUNC(Ceil, float, CeilFunctor);
DEFINE_UNARY_FUNC(Ceil, double, CeilFunctor);
DEFINE_UNARY_FUNC(Cos, float, CosFunctor);
DEFINE_UNARY_FUNC(Cos, double, CosFunctor);
DEFINE_UNARY_FUNC(Exp, float, ExpFunctor);
DEFINE_UNARY_FUNC(Exp, double, ExpFunctor);
DEFINE_UNARY_FUNC(Floor, float, FloorFunctor);
DEFINE_UNARY_FUNC(Floor, double, FloorFunctor);
DEFINE_UNARY_FUNC(Log, float, LogFunctor);
DEFINE_UNARY_FUNC(Log, double, LogFunctor);
DEFINE_UNARY_FUNC(Round, float, RoundFunctor);
DEFINE_UNARY_FUNC(Round, double, RoundFunctor);
DEFINE_UNARY_FUNC(Rsqrt, float, RsqrtFunctor);
DEFINE_UNARY_FUNC(Rsqrt, double, RsqrtFunctor);
DEFINE_UNARY_FUNC(Sin, float, SinFunctor);
DEFINE_UNARY_FUNC(Sin, double, SinFunctor);
DEFINE_UNARY_FUNC(Sqrt, float, SqrtFunctor);
DEFINE_UNARY_FUNC(Sqrt, double, SqrtFunctor);
#undef DEFINE_UNARY_FUNC
#define DEFINE_UNARY_FUNC(name, T) \
......@@ -560,7 +440,7 @@ DEFINE_UNARY_FUNC(Square, float);
DEFINE_UNARY_FUNC(Square, double);
#undef DEFINE_UNARY_FUNC
#define DEFINE_UNARY_FUNC(name, HalfOp, Half2Op) \
#define DEFINE_UNARY_FUNC(name, HalfFunctor, Half2Functor) \
template <> \
DRAGON_API void name<float16, CUDAContext>( \
const int n, const float16* x, float16* y, CUDAContext* ctx) { \
......@@ -571,7 +451,7 @@ DEFINE_UNARY_FUNC(Square, double);
0, \
ctx->cuda_stream()>>>( \
n >> 1, \
Half2Op<half2>(), \
Half2Functor<half2>(), \
reinterpret_cast<const half2*>(x), \
reinterpret_cast<half2*>(y)); \
} else { \
......@@ -581,22 +461,23 @@ DEFINE_UNARY_FUNC(Square, double);
0, \
ctx->cuda_stream()>>>( \
n, \
HalfOp<half>(), \
HalfFunctor<half>(), \
reinterpret_cast<const half*>(x), \
reinterpret_cast<half*>(y)); \
} \
}
DEFINE_UNARY_FUNC(Ceil, CeilHalfOp, CeilHalf2Op);
DEFINE_UNARY_FUNC(Cos, CosHalfOp, CosHalf2Op);
DEFINE_UNARY_FUNC(Exp, ExpHalfOp, ExpHalf2Op);
DEFINE_UNARY_FUNC(Floor, FloorHalfOp, FloorHalf2Op);
DEFINE_UNARY_FUNC(Log, LogHalfOp, LogHalf2Op);
DEFINE_UNARY_FUNC(Inv, InvHalfOp, InvHalf2Op);
DEFINE_UNARY_FUNC(Round, RoundHalfOp, RoundHalf2Op);
DEFINE_UNARY_FUNC(Rsqrt, RsqrtHalfOp, RsqrtHalf2Op);
DEFINE_UNARY_FUNC(Sin, SinHalfOp, SinHalf2Op);
DEFINE_UNARY_FUNC(Sqrt, SqrtHalfOp, SqrtHalf2Op);
DEFINE_UNARY_FUNC(Neg, NegHalfFunctor, NegHalf2Functor);
DEFINE_UNARY_FUNC(Ceil, CeilHalfFunctor, CeilHalf2Functor);
DEFINE_UNARY_FUNC(Cos, CosHalfFunctor, CosHalf2Functor);
DEFINE_UNARY_FUNC(Exp, ExpHalfFunctor, ExpHalf2Functor);
DEFINE_UNARY_FUNC(Floor, FloorHalfFunctor, FloorHalf2Functor);
DEFINE_UNARY_FUNC(Log, LogHalfFunctor, LogHalf2Functor);
DEFINE_UNARY_FUNC(Inv, InvHalfFunctor, InvHalf2Functor);
DEFINE_UNARY_FUNC(Round, RoundHalfFunctor, RoundHalf2Functor);
DEFINE_UNARY_FUNC(Rsqrt, RsqrtHalfFunctor, RsqrtHalf2Functor);
DEFINE_UNARY_FUNC(Sin, SinHalfFunctor, SinHalf2Functor);
DEFINE_UNARY_FUNC(Sqrt, SqrtHalfFunctor, SqrtHalf2Functor);
#undef DEFINE_UNARY_FUNC
#define DEFINE_UNARY_FUNC(name) \
......@@ -843,7 +724,7 @@ DEFINE_REPLACE_NAN_FUNC(double);
const int n, const float beta, const T* x, T* y, CUDAContext* ctx) { \
if (beta == 0.f) return; \
_Bias<<<CUDA_BLOCKS(n), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
n, (T)beta, AddOp<T>(), x, y); \
n, (T)beta, math::PlusFunctor<T>(), x, y); \
}
template <>
......@@ -858,14 +739,14 @@ DRAGON_API void Bias<float16, CUDAContext>(
_Bias<<<CUDA_BLOCKS(n >> 1), CUDA_THREADS, 0, ctx->cuda_stream()>>>(
n >> 1,
cast::to<half2>(beta),
AddHalf2Op<half2>(),
math::PlusFunctor<half2>(),
reinterpret_cast<const half2*>(x),
reinterpret_cast<half2*>(y));
} else {
_Bias<<<CUDA_BLOCKS(n), CUDA_THREADS, 0, ctx->cuda_stream()>>>(
n,
cast::to<half>(beta),
AddHalfOp<half>(),
math::PlusFunctor<half>(),
reinterpret_cast<const half*>(x),
reinterpret_cast<half*>(y));
}
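The even-length branch works because one `half2` lane carries two fp16 values, so reinterpreting the buffer halves the thread count. A self-contained sketch of that packing (the kernel name is illustrative):

```cpp
#include <cuda_fp16.h>

// Illustrative half2 bias kernel: n2 = n / 2 packed elements, each holding
// two fp16 values that are added in a single instruction on sm_53+.
__global__ void BiasHalf2Sketch(const int n2, const half2 beta, half2* y) {
  const int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n2) {
#if __CUDA_ARCH__ >= 530
    y[i] = __hadd2(y[i], beta);
#else
    const float2 v = __half22float2(y[i]);
    const float2 b = __half22float2(beta);
    y[i] = __floats2half2_rn(v.x + b.x, v.y + b.y);
#endif
  }
}
```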
......@@ -890,90 +771,80 @@ DEFINE_BIAS_FUNC(double);
ctx->cuda_stream()>>>(n, Op<TIn>(), a, b, y); \
}
DEFINE_BINARY_FUNC(Add, int8_t, int8_t, AddOp);
DEFINE_BINARY_FUNC(Add, uint8_t, uint8_t, AddOp);
DEFINE_BINARY_FUNC(Add, int, int, AddOp);
DEFINE_BINARY_FUNC(Add, int64_t, int64_t, AddOp);
DEFINE_BINARY_FUNC(Add, float, float, AddOp);
DEFINE_BINARY_FUNC(Add, double, double, AddOp);
DEFINE_BINARY_FUNC(Sub, int8_t, int8_t, SubOp);
DEFINE_BINARY_FUNC(Sub, uint8_t, uint8_t, SubOp);
DEFINE_BINARY_FUNC(Sub, int, int, SubOp);
DEFINE_BINARY_FUNC(Sub, int64_t, int64_t, SubOp);
DEFINE_BINARY_FUNC(Sub, float, float, SubOp);
DEFINE_BINARY_FUNC(Sub, double, double, SubOp);
DEFINE_BINARY_FUNC(Mul, int8_t, int8_t, MulOp);
DEFINE_BINARY_FUNC(Mul, uint8_t, uint8_t, MulOp);
DEFINE_BINARY_FUNC(Mul, int, int, MulOp);
DEFINE_BINARY_FUNC(Mul, int64_t, int64_t, MulOp);
DEFINE_BINARY_FUNC(Mul, float, float, MulOp);
DEFINE_BINARY_FUNC(Mul, double, double, MulOp);
DEFINE_BINARY_FUNC(Div, int8_t, int8_t, DivOp);
DEFINE_BINARY_FUNC(Div, uint8_t, uint8_t, DivOp);
DEFINE_BINARY_FUNC(Div, int, int, DivOp);
DEFINE_BINARY_FUNC(Div, int64_t, int64_t, DivOp);
DEFINE_BINARY_FUNC(Div, float, float, DivOp);
DEFINE_BINARY_FUNC(Div, double, double, DivOp);
DEFINE_BINARY_FUNC(Pow, float, float, PowOp);
DEFINE_BINARY_FUNC(Pow, double, double, PowOp);
DEFINE_BINARY_FUNC(Minimum, int8_t, int8_t, MinOp);
DEFINE_BINARY_FUNC(Minimum, uint8_t, uint8_t, MinOp);
DEFINE_BINARY_FUNC(Minimum, int, int, MinOp);
DEFINE_BINARY_FUNC(Minimum, int64_t, int64_t, MinOp);
DEFINE_BINARY_FUNC(Minimum, float, float, MinOp);
DEFINE_BINARY_FUNC(Minimum, double, double, MinOp);
DEFINE_BINARY_FUNC(Maximum, int8_t, int8_t, MaxOp);
DEFINE_BINARY_FUNC(Maximum, uint8_t, uint8_t, MaxOp);
DEFINE_BINARY_FUNC(Maximum, int, int, MaxOp);
DEFINE_BINARY_FUNC(Maximum, int64_t, int64_t, MaxOp);
DEFINE_BINARY_FUNC(Maximum, float, float, MaxOp);
DEFINE_BINARY_FUNC(Maximum, double, double, MaxOp);
DEFINE_BINARY_FUNC(Equal, int8_t, bool, EqualOp);
DEFINE_BINARY_FUNC(Equal, uint8_t, bool, EqualOp);
DEFINE_BINARY_FUNC(Equal, int, bool, EqualOp);
DEFINE_BINARY_FUNC(Equal, int64_t, bool, EqualOp);
DEFINE_BINARY_FUNC(Equal, float, bool, EqualOp);
DEFINE_BINARY_FUNC(Equal, double, bool, EqualOp);
DEFINE_BINARY_FUNC(NotEqual, int8_t, bool, NotEqualOp);
DEFINE_BINARY_FUNC(NotEqual, uint8_t, bool, NotEqualOp);
DEFINE_BINARY_FUNC(NotEqual, int, bool, NotEqualOp);
DEFINE_BINARY_FUNC(NotEqual, int64_t, bool, NotEqualOp);
DEFINE_BINARY_FUNC(NotEqual, float, bool, NotEqualOp);
DEFINE_BINARY_FUNC(NotEqual, double, bool, NotEqualOp);
DEFINE_BINARY_FUNC(Less, int8_t, bool, LessOp);
DEFINE_BINARY_FUNC(Less, uint8_t, bool, LessOp);
DEFINE_BINARY_FUNC(Less, int, bool, LessOp);
DEFINE_BINARY_FUNC(Less, int64_t, bool, LessOp);
DEFINE_BINARY_FUNC(Less, float, bool, LessOp);
DEFINE_BINARY_FUNC(Less, double, bool, LessOp);
DEFINE_BINARY_FUNC(LessEqual, int8_t, bool, LessEqualOp);
DEFINE_BINARY_FUNC(LessEqual, uint8_t, bool, LessEqualOp);
DEFINE_BINARY_FUNC(LessEqual, int, bool, LessEqualOp);
DEFINE_BINARY_FUNC(LessEqual, int64_t, bool, LessEqualOp);
DEFINE_BINARY_FUNC(LessEqual, float, bool, LessEqualOp);
DEFINE_BINARY_FUNC(LessEqual, double, bool, LessEqualOp);
DEFINE_BINARY_FUNC(Greater, int8_t, bool, GreaterOp);
DEFINE_BINARY_FUNC(Greater, uint8_t, bool, GreaterOp);
DEFINE_BINARY_FUNC(Greater, int, bool, GreaterOp);
DEFINE_BINARY_FUNC(Greater, int64_t, bool, GreaterOp);
DEFINE_BINARY_FUNC(Greater, float, bool, GreaterOp);
DEFINE_BINARY_FUNC(Greater, double, bool, GreaterOp);
DEFINE_BINARY_FUNC(GreaterEqual, int8_t, bool, GreaterEqualOp);
DEFINE_BINARY_FUNC(GreaterEqual, uint8_t, bool, GreaterEqualOp);
DEFINE_BINARY_FUNC(GreaterEqual, int, bool, GreaterEqualOp);
DEFINE_BINARY_FUNC(GreaterEqual, int64_t, bool, GreaterEqualOp);
DEFINE_BINARY_FUNC(GreaterEqual, float, bool, GreaterEqualOp);
DEFINE_BINARY_FUNC(GreaterEqual, double, bool, GreaterEqualOp);
#undef DEFINE_BINARY_FUNC
#define DEFINE_BINARY_FUNC(name, T) \
template <> \
DRAGON_API void name<T, CUDAContext>( \
const int n, const T* a, const T* b, T* y, CUDAContext* ctx) { \
_##name<<<CUDA_BLOCKS(n), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
n, a, b, y); \
}
DEFINE_BINARY_FUNC(Add, int8_t, int8_t, math::PlusFunctor);
DEFINE_BINARY_FUNC(Add, uint8_t, uint8_t, math::PlusFunctor);
DEFINE_BINARY_FUNC(Add, int, int, math::PlusFunctor);
DEFINE_BINARY_FUNC(Add, int64_t, int64_t, math::PlusFunctor);
DEFINE_BINARY_FUNC(Add, float, float, math::PlusFunctor);
DEFINE_BINARY_FUNC(Add, double, double, math::PlusFunctor);
DEFINE_BINARY_FUNC(Sub, int8_t, int8_t, math::MinusFunctor);
DEFINE_BINARY_FUNC(Sub, uint8_t, uint8_t, math::MinusFunctor);
DEFINE_BINARY_FUNC(Sub, int, int, math::MinusFunctor);
DEFINE_BINARY_FUNC(Sub, int64_t, int64_t, math::MinusFunctor);
DEFINE_BINARY_FUNC(Sub, float, float, math::MinusFunctor);
DEFINE_BINARY_FUNC(Sub, double, double, math::MinusFunctor);
DEFINE_BINARY_FUNC(Mul, int8_t, int8_t, math::MultipliesFunctor);
DEFINE_BINARY_FUNC(Mul, uint8_t, uint8_t, math::MultipliesFunctor);
DEFINE_BINARY_FUNC(Mul, int, int, math::MultipliesFunctor);
DEFINE_BINARY_FUNC(Mul, int64_t, int64_t, math::MultipliesFunctor);
DEFINE_BINARY_FUNC(Mul, float, float, math::MultipliesFunctor);
DEFINE_BINARY_FUNC(Mul, double, double, math::MultipliesFunctor);
DEFINE_BINARY_FUNC(Div, int8_t, int8_t, math::DividesFunctor);
DEFINE_BINARY_FUNC(Div, uint8_t, uint8_t, math::DividesFunctor);
DEFINE_BINARY_FUNC(Div, int, int, math::DividesFunctor);
DEFINE_BINARY_FUNC(Div, int64_t, int64_t, math::DividesFunctor);
DEFINE_BINARY_FUNC(Div, float, float, math::DividesFunctor);
DEFINE_BINARY_FUNC(Div, double, double, math::DividesFunctor);
DEFINE_BINARY_FUNC(Pow, float, float, math::PowFunctor);
DEFINE_BINARY_FUNC(Pow, double, double, math::PowFunctor);
DEFINE_BINARY_FUNC(Minimum, int8_t, int8_t, math::MinFunctor);
DEFINE_BINARY_FUNC(Minimum, uint8_t, uint8_t, math::MinFunctor);
DEFINE_BINARY_FUNC(Minimum, int, int, math::MinFunctor);
DEFINE_BINARY_FUNC(Minimum, int64_t, int64_t, math::MinFunctor);
DEFINE_BINARY_FUNC(Minimum, float, float, math::MinFunctor);
DEFINE_BINARY_FUNC(Minimum, double, double, math::MinFunctor);
DEFINE_BINARY_FUNC(Maximum, int8_t, int8_t, math::MaxFunctor);
DEFINE_BINARY_FUNC(Maximum, uint8_t, uint8_t, math::MaxFunctor);
DEFINE_BINARY_FUNC(Maximum, int, int, math::MaxFunctor);
DEFINE_BINARY_FUNC(Maximum, int64_t, int64_t, math::MaxFunctor);
DEFINE_BINARY_FUNC(Maximum, float, float, math::MaxFunctor);
DEFINE_BINARY_FUNC(Maximum, double, double, math::MaxFunctor);
DEFINE_BINARY_FUNC(Equal, int8_t, bool, math::EqualFunctor);
DEFINE_BINARY_FUNC(Equal, uint8_t, bool, math::EqualFunctor);
DEFINE_BINARY_FUNC(Equal, int, bool, math::EqualFunctor);
DEFINE_BINARY_FUNC(Equal, int64_t, bool, math::EqualFunctor);
DEFINE_BINARY_FUNC(Equal, float, bool, math::EqualFunctor);
DEFINE_BINARY_FUNC(Equal, double, bool, math::EqualFunctor);
DEFINE_BINARY_FUNC(NotEqual, int8_t, bool, math::NotEqualFunctor);
DEFINE_BINARY_FUNC(NotEqual, uint8_t, bool, math::NotEqualFunctor);
DEFINE_BINARY_FUNC(NotEqual, int, bool, math::NotEqualFunctor);
DEFINE_BINARY_FUNC(NotEqual, int64_t, bool, math::NotEqualFunctor);
DEFINE_BINARY_FUNC(NotEqual, float, bool, math::NotEqualFunctor);
DEFINE_BINARY_FUNC(NotEqual, double, bool, math::NotEqualFunctor);
DEFINE_BINARY_FUNC(Less, int8_t, bool, math::LessFunctor);
DEFINE_BINARY_FUNC(Less, uint8_t, bool, math::LessFunctor);
DEFINE_BINARY_FUNC(Less, int, bool, math::LessFunctor);
DEFINE_BINARY_FUNC(Less, int64_t, bool, math::LessFunctor);
DEFINE_BINARY_FUNC(Less, float, bool, math::LessFunctor);
DEFINE_BINARY_FUNC(Less, double, bool, math::LessFunctor);
DEFINE_BINARY_FUNC(LessEqual, int8_t, bool, math::LessEqualFunctor);
DEFINE_BINARY_FUNC(LessEqual, uint8_t, bool, math::LessEqualFunctor);
DEFINE_BINARY_FUNC(LessEqual, int, bool, math::LessEqualFunctor);
DEFINE_BINARY_FUNC(LessEqual, int64_t, bool, math::LessEqualFunctor);
DEFINE_BINARY_FUNC(LessEqual, float, bool, math::LessEqualFunctor);
DEFINE_BINARY_FUNC(LessEqual, double, bool, math::LessEqualFunctor);
DEFINE_BINARY_FUNC(Greater, int8_t, bool, math::GreaterFunctor);
DEFINE_BINARY_FUNC(Greater, uint8_t, bool, math::GreaterFunctor);
DEFINE_BINARY_FUNC(Greater, int, bool, math::GreaterFunctor);
DEFINE_BINARY_FUNC(Greater, int64_t, bool, math::GreaterFunctor);
DEFINE_BINARY_FUNC(Greater, float, bool, math::GreaterFunctor);
DEFINE_BINARY_FUNC(Greater, double, bool, math::GreaterFunctor);
DEFINE_BINARY_FUNC(GreaterEqual, int8_t, bool, math::GreaterEqualFunctor);
DEFINE_BINARY_FUNC(GreaterEqual, uint8_t, bool, math::GreaterEqualFunctor);
DEFINE_BINARY_FUNC(GreaterEqual, int, bool, math::GreaterEqualFunctor);
DEFINE_BINARY_FUNC(GreaterEqual, int64_t, bool, math::GreaterEqualFunctor);
DEFINE_BINARY_FUNC(GreaterEqual, float, bool, math::GreaterEqualFunctor);
DEFINE_BINARY_FUNC(GreaterEqual, double, bool, math::GreaterEqualFunctor);
#undef DEFINE_BINARY_FUNC
#define DEFINE_BINARY_FUNC(name, T, dtype) \
......@@ -993,74 +864,74 @@ DEFINE_BINARY_FUNC(Sub, bool, uint8_t); // Xor
DEFINE_BINARY_FUNC(Mul, bool, uint8_t); // And
#undef DEFINE_BINARY_FUNC
#define DEFINE_BINARY_FUNC(name, HalfOp, Half2Op) \
template <> \
DRAGON_API void name<float16, CUDAContext>( \
const int n, \
const float16* a, \
const float16* b, \
float16* y, \
CUDAContext* ctx) { \
if ((n & 1) == 0) { \
_SimpleBinaryFunc<<< \
CUDA_BLOCKS(n >> 1), \
CUDA_THREADS, \
0, \
ctx->cuda_stream()>>>( \
n >> 1, \
Half2Op<half2>(), \
reinterpret_cast<const half2*>(a), \
reinterpret_cast<const half2*>(b), \
reinterpret_cast<half2*>(y)); \
} else { \
_SimpleBinaryFunc<<< \
CUDA_BLOCKS(n), \
CUDA_THREADS, \
0, \
ctx->cuda_stream()>>>( \
n, \
HalfOp<half>(), \
reinterpret_cast<const half*>(a), \
reinterpret_cast<const half*>(b), \
reinterpret_cast<half*>(y)); \
} \
}
DEFINE_BINARY_FUNC(Add, AddHalfOp, AddHalf2Op);
DEFINE_BINARY_FUNC(Sub, SubHalfOp, SubHalf2Op);
DEFINE_BINARY_FUNC(Mul, MulHalfOp, MulHalf2Op);
DEFINE_BINARY_FUNC(Pow, PowHalfOp, PowHalf2Op);
DEFINE_BINARY_FUNC(Minimum, MinHalfOp, MinHalf2Op);
DEFINE_BINARY_FUNC(Maximum, MaxHalfOp, MaxHalf2Op);
#define DEFINE_BINARY_FUNC(name, Functor) \
template <> \
DRAGON_API void name<float16, CUDAContext>( \
const int n, \
const float16* a, \
const float16* b, \
float16* y, \
CUDAContext* ctx) { \
if ((n & 1) == 0) { \
_SimpleBinaryFunc<<< \
CUDA_BLOCKS(n >> 1), \
CUDA_THREADS, \
0, \
ctx->cuda_stream()>>>( \
n >> 1, \
Functor<half2>(), \
reinterpret_cast<const half2*>(a), \
reinterpret_cast<const half2*>(b), \
reinterpret_cast<half2*>(y)); \
} else { \
_SimpleBinaryFunc<<< \
CUDA_BLOCKS(n), \
CUDA_THREADS, \
0, \
ctx->cuda_stream()>>>( \
n, \
Functor<half>(), \
reinterpret_cast<const half*>(a), \
reinterpret_cast<const half*>(b), \
reinterpret_cast<half*>(y)); \
} \
}
DEFINE_BINARY_FUNC(Add, math::PlusFunctor);
DEFINE_BINARY_FUNC(Sub, math::MinusFunctor);
DEFINE_BINARY_FUNC(Mul, math::MultipliesFunctor);
DEFINE_BINARY_FUNC(Div, math::DividesFunctor);
DEFINE_BINARY_FUNC(Pow, math::PowFunctor);
DEFINE_BINARY_FUNC(Minimum, math::MinFunctor);
DEFINE_BINARY_FUNC(Maximum, math::MaxFunctor);
#undef DEFINE_BINARY_FUNC
#define DEFINE_BINARY_FUNC(name, TOut1, TOut2, HalfOp) \
template <> \
DRAGON_API void name<float16, CUDAContext>( \
const int n, \
const float16* a, \
const float16* b, \
TOut1* y, \
CUDAContext* ctx) { \
_SimpleBinaryFunc<<< \
CUDA_BLOCKS(n), \
CUDA_THREADS, \
0, \
ctx->cuda_stream()>>>( \
n, \
HalfOp<half>(), \
reinterpret_cast<const half*>(a), \
reinterpret_cast<const half*>(b), \
reinterpret_cast<TOut2*>(y)); \
}
DEFINE_BINARY_FUNC(Div, float16, half, DivHalfOp);
DEFINE_BINARY_FUNC(Equal, bool, bool, EqualHalfOp);
DEFINE_BINARY_FUNC(NotEqual, bool, bool, NotEqualHalfOp);
DEFINE_BINARY_FUNC(Less, bool, bool, LessHalfOp);
DEFINE_BINARY_FUNC(LessEqual, bool, bool, LessEqualHalfOp);
DEFINE_BINARY_FUNC(Greater, bool, bool, GreaterHalfOp);
DEFINE_BINARY_FUNC(GreaterEqual, bool, bool, GreaterEqualHalfOp);
#define DEFINE_BINARY_FUNC(name, Functor) \
template <> \
DRAGON_API void name<float16, CUDAContext>( \
const int n, \
const float16* a, \
const float16* b, \
bool* y, \
CUDAContext* ctx) { \
_SimpleBinaryFunc<<< \
CUDA_BLOCKS(n), \
CUDA_THREADS, \
0, \
ctx->cuda_stream()>>>( \
n, \
Functor<half>(), \
reinterpret_cast<const half*>(a), \
reinterpret_cast<const half*>(b), \
y); \
}
DEFINE_BINARY_FUNC(Equal, math::EqualFunctor);
DEFINE_BINARY_FUNC(NotEqual, math::NotEqualFunctor);
DEFINE_BINARY_FUNC(Less, math::LessFunctor);
DEFINE_BINARY_FUNC(LessEqual, math::LessEqualFunctor);
DEFINE_BINARY_FUNC(Greater, math::GreaterFunctor);
DEFINE_BINARY_FUNC(GreaterEqual, math::GreaterEqualFunctor);
#undef DEFINE_BINARY_FUNC
#define DEFINE_WHERE_FUNC(T) \
......
......@@ -23,6 +23,9 @@ template <typename T, class Context>
DRAGON_API void Abs(const int n, const T* x, T* y, Context* ctx);
template <typename T, class Context>
DRAGON_API void Neg(const int n, const T* x, T* y, Context* ctx);
template <typename T, class Context>
DRAGON_API void Ceil(const int n, const T* x, T* y, Context* ctx);
template <typename T, class Context>
......
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_UTILS_MATH_FUNCTIONAL_H_
#define DRAGON_UTILS_MATH_FUNCTIONAL_H_
#include "dragon/core/types.h"
#include "dragon/utils/cast.h"
namespace dragon {
namespace math {
/*
* Binary Arithmetic Functors
*/
template <typename T>
struct MaxFunctor {
#if defined(__CUDACC__)
inline __host__ __device__ T operator()(const T& lhs, const T& rhs) const {
return lhs < rhs ? rhs : lhs;
}
#else
inline T operator()(const T& lhs, const T& rhs) const {
return lhs < rhs ? rhs : lhs;
}
#endif
};
template <>
struct MaxFunctor<float16> {
#if defined(__CUDACC__)
inline __host__ __device__ float16
operator()(const float16& lhs, const float16& rhs) const {
#if __CUDA_ARCH__ >= 530
return __hlt(
*reinterpret_cast<const half*>(&lhs),
*reinterpret_cast<const half*>(&rhs))
? rhs
: lhs;
#else
return __half2float(*reinterpret_cast<const half*>(&lhs)) <
__half2float(*reinterpret_cast<const half*>(&rhs))
? rhs
: lhs;
#endif
}
#else
inline float16 operator()(const float16& lhs, const float16& rhs) const {
return cast::to<float>(lhs) < cast::to<float>(rhs) ? rhs : lhs;
}
#endif
};
#if defined(__CUDACC__)
template <>
struct MaxFunctor<half> {
inline __host__ __device__ half
operator()(const half& lhs, const half& rhs) const {
#if __CUDA_ARCH__ >= 530
return __hlt(lhs, rhs) ? rhs : lhs;
#else
return __half2float(lhs) < __half2float(rhs) ? rhs : lhs;
#endif
}
};
template <>
struct MaxFunctor<half2> {
inline __host__ __device__ half2
operator()(const half2& lhs, const half2& rhs) const {
const float2 v1 = __half22float2(lhs);
const float2 v2 = __half22float2(rhs);
return __floats2half2_rn(
v1.x < v2.x ? v2.x : v1.x, v1.y < v2.y ? v2.y : v1.y);
}
};
#endif
template <typename T>
struct MinFunctor {
#if defined(__CUDACC__)
inline __host__ __device__ T operator()(const T& lhs, const T& rhs) const {
return lhs < rhs ? lhs : rhs;
}
#else
inline T operator()(const T& lhs, const T& rhs) const {
return lhs < rhs ? lhs : rhs;
}
#endif
};
template <>
struct MinFunctor<float16> {
#if defined(__CUDACC__)
inline __host__ __device__ float16
operator()(const float16& lhs, const float16& rhs) const {
#if __CUDA_ARCH__ >= 530
return __hlt(
*reinterpret_cast<const half*>(&lhs),
*reinterpret_cast<const half*>(&rhs))
? lhs
: rhs;
#else
return __half2float(*reinterpret_cast<const half*>(&lhs)) <
__half2float(*reinterpret_cast<const half*>(&rhs))
? lhs
: rhs;
#endif
}
#else
inline float16 operator()(const float16& lhs, const float16& rhs) const {
return cast::to<float>(lhs) < cast::to<float>(rhs) ? lhs : rhs;
}
#endif
};
#if defined(__CUDACC__)
template <>
struct MinFunctor<half> {
inline __host__ __device__ half
operator()(const half& lhs, const half& rhs) const {
#if __CUDA_ARCH__ >= 530
return __hlt(lhs, rhs) ? lhs : rhs;
#else
return __half2float(lhs) < __half2float(rhs) ? lhs : rhs;
#endif
}
};
template <>
struct MinFunctor<half2> {
inline __host__ __device__ half2
operator()(const half2& lhs, const half2& rhs) const {
const float2 v1 = __half22float2(lhs);
const float2 v2 = __half22float2(rhs);
return __floats2half2_rn(
v1.x < v2.x ? v1.x : v2.x, v1.y < v2.y ? v1.y : v2.y);
}
};
#endif
template <typename T>
struct PlusFunctor {
#if defined(__CUDACC__)
inline __host__ __device__ T operator()(const T& lhs, const T& rhs) const {
return lhs + rhs;
}
#else
inline T operator()(const T& lhs, const T& rhs) const {
return lhs + rhs;
}
#endif
};
template <>
struct PlusFunctor<float16> {
#if defined(__CUDACC__)
inline __host__ __device__ float16
operator()(const float16& lhs, const float16& rhs) const {
#if __CUDA_ARCH__ >= 530
half ret = __hadd(
*reinterpret_cast<const half*>(&lhs),
*reinterpret_cast<const half*>(&rhs));
#else
half ret = __float2half(
__half2float(*reinterpret_cast<const half*>(&lhs)) +
__half2float(*reinterpret_cast<const half*>(&rhs)));
#endif
return *reinterpret_cast<float16*>(&ret);
}
#else
inline float16 operator()(const float16& lhs, const float16& rhs) const {
return cast::to<float16>(cast::to<float>(lhs) + cast::to<float>(rhs));
}
#endif
};
#if defined(__CUDACC__)
template <>
struct PlusFunctor<half> {
inline __host__ __device__ half
operator()(const half& lhs, const half& rhs) const {
#if __CUDA_ARCH__ >= 530
return __hadd(lhs, rhs);
#else
return __float2half(__half2float(lhs) + __half2float(rhs));
#endif
}
};
template <>
struct PlusFunctor<half2> {
inline __host__ __device__ half2
operator()(const half2& lhs, const half2& rhs) const {
#if __CUDA_ARCH__ >= 530
return __hadd2(lhs, rhs);
#else
const float2 v1 = __half22float2(lhs);
const float2 v2 = __half22float2(rhs);
return __floats2half2_rn(v1.x + v2.x, v1.y + v2.y);
#endif
}
};
#endif
template <typename T>
struct MinusFunctor {
#if defined(__CUDACC__)
inline __host__ __device__ T operator()(const T& lhs, const T& rhs) const {
return lhs - rhs;
}
#else
inline T operator()(const T& lhs, const T& rhs) const {
return lhs - rhs;
}
#endif
};
template <>
struct MinusFunctor<float16> {
#if defined(__CUDACC__)
inline __host__ __device__ float16
operator()(const float16& lhs, const float16& rhs) const {
#if __CUDA_ARCH__ >= 530
half ret = __hsub(
*reinterpret_cast<const half*>(&lhs),
*reinterpret_cast<const half*>(&rhs));
#else
half ret = __float2half(
__half2float(*reinterpret_cast<const half*>(&lhs)) -
__half2float(*reinterpret_cast<const half*>(&rhs)));
#endif
return *reinterpret_cast<float16*>(&ret);
}
#else
inline float16 operator()(const float16& lhs, const float16& rhs) const {
return cast::to<float16>(cast::to<float>(lhs) - cast::to<float>(rhs));
}
#endif
};
#if defined(__CUDACC__)
template <>
struct MinusFunctor<half> {
inline __host__ __device__ half
operator()(const half& lhs, const half& rhs) const {
#if __CUDA_ARCH__ >= 530
return __hsub(lhs, rhs);
#else
return __float2half(__half2float(lhs) - __half2float(rhs));
#endif
}
};
template <>
struct MinusFunctor<half2> {
inline __host__ __device__ half2
operator()(const half2& lhs, const half2& rhs) const {
#if __CUDA_ARCH__ >= 530
return __hsub2(lhs, rhs);
#else
const float2 v1 = __half22float2(lhs);
const float2 v2 = __half22float2(rhs);
return __floats2half2_rn(v1.x - v2.x, v1.y - v2.y);
#endif
}
};
#endif
template <typename T>
struct MultipliesFunctor {
#if defined(__CUDACC__)
inline __host__ __device__ T operator()(const T& lhs, const T& rhs) const {
return lhs * rhs;
}
#else
inline T operator()(const T& lhs, const T& rhs) const {
return lhs * rhs;
}
#endif
};
template <>
struct MultipliesFunctor<float16> {
#if defined(__CUDACC__)
inline __host__ __device__ float16
operator()(const float16& lhs, const float16& rhs) const {
#if __CUDA_ARCH__ >= 530
half ret = __hmul(
*reinterpret_cast<const half*>(&lhs),
*reinterpret_cast<const half*>(&rhs));
#else
half ret = __float2half(
__half2float(*reinterpret_cast<const half*>(&lhs)) *
__half2float(*reinterpret_cast<const half*>(&rhs)));
#endif
return *reinterpret_cast<float16*>(&ret);
}
#else
inline float16 operator()(const float16& lhs, const float16& rhs) const {
return cast::to<float16>(cast::to<float>(lhs) * cast::to<float>(rhs));
}
#endif
};
#if defined(__CUDACC__)
template <>
struct MultipliesFunctor<half> {
inline __host__ __device__ half
operator()(const half& lhs, const half& rhs) const {
#if __CUDA_ARCH__ >= 530
return __hmul(lhs, rhs);
#else
return __float2half(__half2float(lhs) * __half2float(rhs));
#endif
}
};
template <>
struct MultipliesFunctor<half2> {
inline __host__ __device__ half2
operator()(const half2& lhs, const half2& rhs) const {
#if __CUDA_ARCH__ >= 530
return __hmul2(lhs, rhs);
#else
const float2 v1 = __half22float2(lhs);
const float2 v2 = __half22float2(rhs);
return __floats2half2_rn(v1.x * v2.x, v1.y * v2.y);
#endif
}
};
#endif
template <typename T>
struct DividesFunctor {
#if defined(__CUDACC__)
inline __host__ __device__ T operator()(const T& lhs, const T& rhs) const {
return lhs / rhs;
}
#else
inline T operator()(const T& lhs, const T& rhs) const {
return lhs / rhs;
}
#endif
};
template <>
struct DividesFunctor<float16> {
#if defined(__CUDACC__)
inline __host__ __device__ float16
operator()(const float16& lhs, const float16& rhs) const {
#if __CUDA_ARCH__ >= 530
half ret = __hdiv(
*reinterpret_cast<const half*>(&lhs),
*reinterpret_cast<const half*>(&rhs));
#else
half ret = __float2half(
__half2float(*reinterpret_cast<const half*>(&lhs)) /
__half2float(*reinterpret_cast<const half*>(&rhs)));
#endif
return *reinterpret_cast<float16*>(&ret);
}
#else
inline float16 operator()(const float16& lhs, const float16& rhs) const {
return cast::to<float16>(cast::to<float>(lhs) / cast::to<float>(rhs));
}
#endif
};
#if defined(__CUDACC__)
template <>
struct DividesFunctor<half> {
inline __host__ __device__ half
operator()(const half& lhs, const half& rhs) const {
#if __CUDA_ARCH__ >= 530
return __hdiv(lhs, rhs);
#else
return __float2half(__half2float(lhs) / __half2float(rhs));
#endif
}
};
template <>
struct DividesFunctor<half2> {
inline __host__ __device__ half2
operator()(const half2& lhs, const half2& rhs) const {
const float2 v1 = __half22float2(lhs);
const float2 v2 = __half22float2(rhs);
return __floats2half2_rn(v1.x / v2.x, v1.y / v2.y);
}
};
#endif
template <typename T>
struct PowFunctor {
#if defined(__CUDACC__)
inline __host__ __device__ T operator()(const T& lhs, const T& rhs) const {
return pow(lhs, rhs);
}
#else
inline T operator()(const T& lhs, const T& rhs) const {
return std::pow(lhs, rhs);
}
#endif
};
template <>
struct PowFunctor<float16> {
#if defined(__CUDACC__)
inline __host__ __device__ float16
operator()(const float16& lhs, const float16& rhs) const {
half ret = __float2half(
pow(__half2float(*reinterpret_cast<const half*>(&lhs)),
__half2float(*reinterpret_cast<const half*>(&rhs))));
return *reinterpret_cast<float16*>(&ret);
}
#else
inline float16 operator()(const float16& lhs, const float16& rhs) const {
return cast::to<float16>(
std::pow(cast::to<float>(lhs), cast::to<float>(rhs)));
}
#endif
};
#if defined(__CUDACC__)
template <>
struct PowFunctor<half> {
inline __host__ __device__ half
operator()(const half& lhs, const half& rhs) const {
return __float2half(pow(__half2float(lhs), __half2float(rhs)));
}
};
template <>
struct PowFunctor<half2> {
inline __host__ __device__ half2
operator()(const half2& lhs, const half2& rhs) const {
const float2 v1 = __half22float2(lhs);
const float2 v2 = __half22float2(rhs);
return __floats2half2_rn(pow(v1.x, v2.x), pow(v1.y, v2.y));
}
};
#endif
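These functors are plain callables, so generic code, host or device, can be written once per algorithm. A host-side sketch (the helper name is illustrative):

```cpp
#include <cstddef>

// Illustrative generic transform: the same call site works for PlusFunctor,
// MinusFunctor, PowFunctor, etc., and for any element type they specialize.
template <typename T, class Functor>
void ApplyBinary(std::size_t n, Functor op, const T* a, const T* b, T* y) {
  for (std::size_t i = 0; i < n; ++i) y[i] = op(a[i], b[i]);
}
// e.g. ApplyBinary(n, dragon::math::PlusFunctor<float>(), a, b, y);
```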
/*
* Compare Functors
*/
template <typename T>
struct EqualFunctor {
#if defined(__CUDACC__)
inline __host__ __device__ bool operator()(const T& lhs, const T& rhs) const {
return lhs == rhs;
}
#else
inline bool operator()(const T& lhs, const T& rhs) const {
return lhs == rhs;
}
#endif
};
template <>
struct EqualFunctor<float16> {
#if defined(__CUDACC__)
inline __host__ __device__ bool operator()(
const float16& lhs,
const float16& rhs) const {
#if __CUDA_ARCH__ >= 530
return __heq(
*reinterpret_cast<const half*>(&lhs),
*reinterpret_cast<const half*>(&rhs));
#else
return __half2float(*reinterpret_cast<const half*>(&lhs)) ==
__half2float(*reinterpret_cast<const half*>(&rhs));
#endif
}
#else
inline bool operator()(const float16& lhs, const float16& rhs) const {
return cast::to<float>(lhs) == cast::to<float>(rhs);
}
#endif
};
#if defined(__CUDACC__)
template <>
struct EqualFunctor<half> {
inline __host__ __device__ bool operator()(const half& lhs, const half& rhs)
const {
#if __CUDA_ARCH__ >= 530
return __heq(lhs, rhs);
#else
return __half2float(lhs) == __half2float(rhs);
#endif
}
};
#endif
template <typename T>
struct NotEqualFunctor {
#if defined(__CUDACC__)
inline __host__ __device__ bool operator()(const T& lhs, const T& rhs) const {
return lhs != rhs;
}
#else
inline bool operator()(const T& lhs, const T& rhs) const {
return lhs != rhs;
}
#endif
};
template <>
struct NotEqualFunctor<float16> {
#if defined(__CUDACC__)
inline __host__ __device__ bool operator()(
const float16& lhs,
const float16& rhs) const {
#if __CUDA_ARCH__ >= 530
return __hne(
*reinterpret_cast<const half*>(&lhs),
*reinterpret_cast<const half*>(&rhs));
#else
return __half2float(*reinterpret_cast<const half*>(&lhs)) !=
__half2float(*reinterpret_cast<const half*>(&rhs));
#endif
}
#else
inline bool operator()(const float16& lhs, const float16& rhs) const {
return cast::to<float>(lhs) != cast::to<float>(rhs);
}
#endif
};
#if defined(__CUDACC__)
template <>
struct NotEqualFunctor<half> {
inline __host__ __device__ bool operator()(const half& lhs, const half& rhs)
const {
#if __CUDA_ARCH__ >= 530
return __hne(lhs, rhs);
#else
return __half2float(lhs) != __half2float(rhs);
#endif
}
};
#endif
template <typename T>
struct GreaterFunctor {
#if defined(__CUDACC__)
inline __host__ __device__ bool operator()(const T& lhs, const T& rhs) const {
return lhs > rhs;
}
#else
inline bool operator()(const T& lhs, const T& rhs) const {
return lhs > rhs;
}
#endif
};
template <>
struct GreaterFunctor<float16> {
#if defined(__CUDACC__)
inline __host__ __device__ bool operator()(
const float16& lhs,
const float16& rhs) const {
#if __CUDA_ARCH__ >= 530
return __hgt(
*reinterpret_cast<const half*>(&lhs),
*reinterpret_cast<const half*>(&rhs));
#else
return __half2float(*reinterpret_cast<const half*>(&lhs)) >
__half2float(*reinterpret_cast<const half*>(&rhs));
#endif
}
#else
inline bool operator()(const float16& lhs, const float16& rhs) const {
return cast::to<float>(lhs) > cast::to<float>(rhs);
}
#endif
};
#if defined(__CUDACC__)
template <>
struct GreaterFunctor<half> {
inline __host__ __device__ bool operator()(const half& lhs, const half& rhs)
const {
#if __CUDA_ARCH__ >= 530
return __hgt(lhs, rhs);
#else
return __half2float(lhs) > __half2float(rhs);
#endif
}
};
#endif
template <typename T>
struct LessFunctor {
#if defined(__CUDACC__)
inline __host__ __device__ bool operator()(const T& lhs, const T& rhs) const {
return lhs < rhs;
}
#else
inline bool operator()(const T& lhs, const T& rhs) const {
return lhs < rhs;
}
#endif
};
template <>
struct LessFunctor<float16> {
#if defined(__CUDACC__)
inline __host__ __device__ bool operator()(
const float16& lhs,
const float16& rhs) const {
#if __CUDA_ARCH__ >= 530
return __hlt(
*reinterpret_cast<const half*>(&lhs),
*reinterpret_cast<const half*>(&rhs));
#else
return __half2float(*reinterpret_cast<const half*>(&lhs)) <
__half2float(*reinterpret_cast<const half*>(&rhs));
#endif
}
#else
inline bool operator()(const float16& lhs, const float16& rhs) const {
return cast::to<float>(lhs) < cast::to<float>(rhs);
}
#endif
};
#if defined(__CUDACC__)
template <>
struct LessFunctor<half> {
inline __host__ __device__ bool operator()(const half& lhs, const half& rhs)
const {
#if __CUDA_ARCH__ >= 530
return __hlt(lhs, rhs);
#else
return __half2float(lhs) < __half2float(rhs);
#endif
}
};
#endif
template <typename T>
struct GreaterEqualFunctor {
#if defined(__CUDACC__)
inline __host__ __device__ bool operator()(const T& lhs, const T& rhs) const {
return lhs >= rhs;
}
#else
inline bool operator()(const T& lhs, const T& rhs) const {
return lhs >= rhs;
}
#endif
};
template <>
struct GreaterEqualFunctor<float16> {
#if defined(__CUDACC__)
inline __host__ __device__ bool operator()(
const float16& lhs,
const float16& rhs) const {
#if __CUDA_ARCH__ >= 530
return __hge(
*reinterpret_cast<const half*>(&lhs),
*reinterpret_cast<const half*>(&rhs));
#else
return __half2float(*reinterpret_cast<const half*>(&lhs)) >=
__half2float(*reinterpret_cast<const half*>(&rhs));
#endif
}
#else
inline bool operator()(const float16& lhs, const float16& rhs) const {
return cast::to<float>(lhs) >= cast::to<float>(rhs);
}
#endif
};
#if defined(__CUDACC__)
template <>
struct GreaterEqualFunctor<half> {
inline __host__ __device__ bool operator()(const half& lhs, const half& rhs)
const {
#if __CUDA_ARCH__ >= 530
return __hge(lhs, rhs);
#else
return __half2float(lhs) >= __half2float(rhs);
#endif
}
};
#endif
template <typename T>
struct LessEqualFunctor {
#if defined(__CUDACC__)
inline __host__ __device__ bool operator()(const T& lhs, const T& rhs) const {
return lhs <= rhs;
}
#else
inline bool operator()(const T& lhs, const T& rhs) const {
return lhs <= rhs;
}
#endif
};
template <>
struct LessEqualFunctor<float16> {
#if defined(__CUDACC__)
inline __host__ __device__ bool operator()(
const float16& lhs,
const float16& rhs) const {
#if __CUDA_ARCH__ >= 530
return __hle(
*reinterpret_cast<const half*>(&lhs),
*reinterpret_cast<const half*>(&rhs));
#else
return __half2float(*reinterpret_cast<const half*>(&lhs)) <=
__half2float(*reinterpret_cast<const half*>(&rhs));
#endif
}
#else
inline bool operator()(const float16& lhs, const float16& rhs) const {
return cast::to<float>(lhs) <= cast::to<float>(rhs);
}
#endif
};
#if defined(__CUDACC__)
template <>
struct LessEqualFunctor<half> {
inline __host__ __device__ bool operator()(const half& lhs, const half& rhs)
const {
#if __CUDA_ARCH__ >= 530
return __hle(lhs, rhs);
#else
return __half2float(lhs) <= __half2float(rhs);
#endif
}
};
#endif
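On host, the `float16` comparisons route through `cast::to<float>`, so the same functors remain usable outside kernels. An illustrative check:

```cpp
// Illustrative host-side use: LessEqualFunctor<float16> falls back to float
// comparisons when __CUDACC__ is not defined (see the specialization above).
inline bool AllLessEqual(const float16* a, const float16* b, const int n) {
  const LessEqualFunctor<float16> le;
  for (int i = 0; i < n; ++i) {
    if (!le(a[i], b[i])) return false;
  }
  return true;
}
```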
} // namespace math
} // namespace dragon
#endif // DRAGON_UTILS_MATH_FUNCTIONAL_H_
......@@ -4,6 +4,7 @@
#include "dragon/utils/device/common_cub.h"
#include "dragon/utils/device/common_thrust.h"
#include "dragon/utils/math/blas.h"
#include "dragon/utils/math/functional.h"
#include "dragon/utils/math/reduce.h"
#include "dragon/utils/math/utils.h"
......@@ -30,33 +31,7 @@ __global__ void _RowwiseReduce(
}
val = BlockReduce<T>(storage).Reduce(val, reducer);
if (threadIdx.x == 0) {
y[i] = val * scale;
}
}
}
template <class Reducer>
__global__ void _RowwiseReduce(
const int rows,
const int cols,
const Reducer reducer,
const half init,
const half scale,
const half* x,
half* y) {
__shared__ typename BlockReduce<half>::TempStorage storage;
CUDA_2D_KERNEL_LOOP1(i, cols) {
half val = init;
CUDA_2D_KERNEL_LOOP2(j, rows) {
val = reducer(val, x[j * cols + i]);
}
val = BlockReduce<half>(storage).Reduce(val, reducer);
if (threadIdx.x == 0) {
#if __CUDA_ARCH__ >= 530
y[i] = __hmul(val, scale);
#else
y[i] = __float2half(__half2float(val) * __half2float(scale));
#endif
y[i] = math::MultipliesFunctor<T>()(val, scale);
}
}
}
......@@ -78,33 +53,7 @@ __global__ void _ColwiseReduce(
}
val = BlockReduce<T>(storage).Reduce(val, reducer);
if (threadIdx.x == 0) {
y[i] = val * scale;
}
}
}
template <class Reducer>
__global__ void _ColwiseReduce(
const int rows,
const int cols,
const Reducer reducer,
const half init,
const half scale,
const half* x,
half* y) {
__shared__ typename BlockReduce<half>::TempStorage storage;
CUDA_2D_KERNEL_LOOP1(i, rows) {
half val = init;
CUDA_2D_KERNEL_LOOP2(j, cols) {
val = reducer(val, x[i * cols + j]);
}
val = BlockReduce<half>(storage).Reduce(val, reducer);
if (threadIdx.x == 0) {
#if __CUDA_ARCH__ >= 530
y[i] = __hmul(val, scale);
#else
y[i] = __float2half(__half2float(val) * __half2float(scale));
#endif
y[i] = math::MultipliesFunctor<T>()(val, scale);
}
}
}
......@@ -135,48 +84,13 @@ __global__ void _GenericReduce(
}
val = BlockReduce<T>(storage).Reduce(val, reducer);
if (threadIdx.x == 0) {
y[i] = val * scale;
}
}
}
template <class Reducer, int D>
__global__ void _GenericReduce(
const int rows,
const int cols,
const int num_dims,
const SimpleArray<int, D> x_dims,
const SimpleArray<int, D> x_strides,
const Reducer reducer,
const half init,
const half scale,
const half* x,
half* y) {
__shared__ typename BlockReduce<half>::TempStorage storage;
CUDA_2D_KERNEL_LOOP1(i, rows) {
half val = init;
CUDA_2D_KERNEL_LOOP2(j, cols) {
int xi = 0, c = i * cols + j;
for (int d = num_dims - 1; d >= 0; --d) {
int r;
FIXED_DIVISOR_DIV_MOD(x_dims.data[d], c, &c, &r);
xi += r * x_strides.data[d];
}
val = reducer(val, x[xi]);
}
val = BlockReduce<half>(storage).Reduce(val, reducer);
if (threadIdx.x == 0) {
#if __CUDA_ARCH__ >= 530
y[i] = __hmul(val, scale);
#else
y[i] = __float2half(__half2float(val) * __half2float(scale));
#endif
y[i] = math::MultipliesFunctor<T>()(val, scale);
}
}
}
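Dropping the dedicated `half` reduce kernels works because the epilogue now goes through `MultipliesFunctor<T>`, which owns the precision dispatch. A host-side analogue of the rowwise case (illustrative name):

```cpp
// Illustrative host analogue of _RowwiseReduce: fold each column with the
// reducer, then apply the same precision-aware scale used on device.
template <typename T, class Reducer>
void RowwiseReduceSketch(
    const int rows, const int cols, const Reducer reducer,
    const T init, const T scale, const T* x, T* y) {
  for (int i = 0; i < cols; ++i) {
    T val = init;
    for (int j = 0; j < rows; ++j) val = reducer(val, x[j * cols + i]);
    y[i] = math::MultipliesFunctor<T>()(val, scale);
  }
}
```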
#define DEFINE_REDUCE_FUNCTION(name) \
template <typename T, class Reducer> \
template <typename T, typename Reducer> \
int _Reduce##name( \
const int num_dims, \
const int* dims, \
......@@ -279,110 +193,110 @@ DEFINE_REDUCE_FUNCTION(Sum);
/* ------------------- Launcher Separator ------------------- */
#define DEFINE_KERNEL_LAUNCHER(name, Reducer, kInit) \
template <> \
void Reduce##name<float16, CUDAContext>( \
const int num_dims, \
const int* dims, \
const int num_axes, \
const int* axes, \
const float scale, \
const float16* x, \
float16* y, \
CUDAContext* ctx) { \
auto kind = _Reduce##name( \
num_dims, \
dims, \
num_axes, \
axes, \
Reducer(), \
cast::to<half>(kInit), \
scale, \
reinterpret_cast<const half*>(x), \
reinterpret_cast<half*>(y), \
ctx); \
if (kind == 0) { \
math::Scale(1, scale, y, y, ctx); \
} \
}
DEFINE_KERNEL_LAUNCHER(Max, cub::MaxHalf, -HFLT_MAX);
DEFINE_KERNEL_LAUNCHER(Min, cub::MinHalf, HFLT_MAX);
DEFINE_KERNEL_LAUNCHER(Sum, cub::SumHalf, 0.f);
#undef DEFINE_KERNEL_LAUNCHER
#define DEFINE_KERNEL_LAUNCHER(name, T, Reducer, kInit) \
template <> \
void Reduce##name<T, CUDAContext>( \
const int num_dims, \
const int* dims, \
const int num_axes, \
const int* axes, \
const float scale, \
const T* x, \
T* y, \
CUDAContext* ctx) { \
auto kind = _Reduce##name( \
num_dims, dims, num_axes, axes, Reducer(), kInit, scale, x, y, ctx); \
if (kind == 0) { \
math::Scale(1, scale, y, y, ctx); \
} \
#define DEFINE_KERNEL_LAUNCHER(name, T, Reducer, kInit) \
template <> \
void Reduce##name<T, CUDAContext>( \
const int num_dims, \
const int* dims, \
const int num_axes, \
const int* axes, \
const float scale, \
const T* x, \
T* y, \
CUDAContext* ctx) { \
auto kind = _Reduce##name( \
num_dims, \
dims, \
num_axes, \
axes, \
Reducer<T>(), \
kInit, \
scale, \
x, \
y, \
ctx); \
if (kind == 0) { \
math::Scale(1, scale, y, y, ctx); \
} \
}
DEFINE_KERNEL_LAUNCHER(
Max,
int8_t,
cub::Max,
math::MaxFunctor,
std::numeric_limits<int8_t>::lowest());
DEFINE_KERNEL_LAUNCHER(
Max,
uint8_t,
cub::Max,
math::MaxFunctor,
std::numeric_limits<uint8_t>::lowest());
DEFINE_KERNEL_LAUNCHER(Max, int, cub::Max, std::numeric_limits<int>::lowest());
DEFINE_KERNEL_LAUNCHER(
Max,
int,
math::MaxFunctor,
std::numeric_limits<int>::lowest());
DEFINE_KERNEL_LAUNCHER(
Max,
int64_t,
cub::Max,
math::MaxFunctor,
std::numeric_limits<int64_t>::lowest());
DEFINE_KERNEL_LAUNCHER(
Max,
float16,
math::MaxFunctor,
cast::to<float16>(cub::Traits<half>::Lowest()));
DEFINE_KERNEL_LAUNCHER(
Max,
float,
cub::Max,
math::MaxFunctor,
std::numeric_limits<float>::lowest());
DEFINE_KERNEL_LAUNCHER(
Max,
double,
cub::Max,
math::MaxFunctor,
std::numeric_limits<double>::lowest());
DEFINE_KERNEL_LAUNCHER(
Min,
int8_t,
cub::Min,
math::MinFunctor,
std::numeric_limits<int8_t>::max());
DEFINE_KERNEL_LAUNCHER(
Min,
uint8_t,
cub::Min,
math::MinFunctor,
std::numeric_limits<uint8_t>::max());
DEFINE_KERNEL_LAUNCHER(Min, int, cub::Min, std::numeric_limits<int>::max());
DEFINE_KERNEL_LAUNCHER(
Min,
int,
math::MinFunctor,
std::numeric_limits<int>::max());
DEFINE_KERNEL_LAUNCHER(
Min,
int64_t,
cub::Min,
math::MinFunctor,
std::numeric_limits<int64_t>::max());
DEFINE_KERNEL_LAUNCHER(Min, float, cub::Min, std::numeric_limits<float>::max());
DEFINE_KERNEL_LAUNCHER(
Min,
float16,
math::MinFunctor,
cast::to<float16>(cub::Traits<half>::Max()));
DEFINE_KERNEL_LAUNCHER(
Min,
float,
math::MinFunctor,
std::numeric_limits<float>::max());
DEFINE_KERNEL_LAUNCHER(
Min,
double,
cub::Min,
math::MinFunctor,
std::numeric_limits<double>::max());
DEFINE_KERNEL_LAUNCHER(Sum, int8_t, cub::Sum, int8_t(0));
DEFINE_KERNEL_LAUNCHER(Sum, uint8_t, cub::Sum, uint8_t(0));
DEFINE_KERNEL_LAUNCHER(Sum, int, cub::Sum, int(0));
DEFINE_KERNEL_LAUNCHER(Sum, int64_t, cub::Sum, int64_t(0));
DEFINE_KERNEL_LAUNCHER(Sum, float, cub::Sum, 0.f);
DEFINE_KERNEL_LAUNCHER(Sum, double, cub::Sum, 0.);
DEFINE_KERNEL_LAUNCHER(Sum, int8_t, math::PlusFunctor, int8_t(0));
DEFINE_KERNEL_LAUNCHER(Sum, uint8_t, math::PlusFunctor, uint8_t(0));
DEFINE_KERNEL_LAUNCHER(Sum, int, math::PlusFunctor, int(0));
DEFINE_KERNEL_LAUNCHER(Sum, int64_t, math::PlusFunctor, int64_t(0));
DEFINE_KERNEL_LAUNCHER(Sum, float16, math::PlusFunctor, cast::to<float16>(0.f));
DEFINE_KERNEL_LAUNCHER(Sum, float, math::PlusFunctor, 0.f);
DEFINE_KERNEL_LAUNCHER(Sum, double, math::PlusFunctor, 0.);
#undef DEFINE_KERNEL_LAUNCHER
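Each launcher pairs a functor with that reducer's identity: `lowest()` for Max, `max()` for Min, and zero for Sum, with cub's half traits supplying the fp16 limits. A sketch of why the identity matters (function name illustrative):

```cpp
#include <limits>

// Illustrative fold: seeding with the reducer's identity keeps empty or
// partial tiles from perturbing the result.
template <typename T, class Reducer>
T FoldSketch(const T* x, const int n, const Reducer op, const T init) {
  T acc = init;
  for (int i = 0; i < n; ++i) acc = op(acc, x[i]);
  return acc;
}
// e.g. FoldSketch(x, n, math::MaxFunctor<float>(),
//                 std::numeric_limits<float>::lowest());
```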
#define DEFINE_SUM_FUNC(T) \
......
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_UTILS_MATH_SORT_H_
#define DRAGON_UTILS_MATH_SORT_H_
#endif // DRAGON_UTILS_MATH_SORT_H_
......@@ -16,6 +16,7 @@
#include "dragon/utils/math/blas.h"
#include "dragon/utils/math/broadcast.h"
#include "dragon/utils/math/elementwise.h"
#include "dragon/utils/math/functional.h"
#include "dragon/utils/math/random.h"
#include "dragon/utils/math/reduce.h"
#include "dragon/utils/math/utils.h"
......
......@@ -308,7 +308,7 @@ void IndexSelect(
const int outer_dim,
const int inner_dim,
const int axis_dim,
const int num_indices,
const int select_dim,
const int64_t* indices,
const T* x,
T* y,
......@@ -319,7 +319,7 @@ void IndexSelectGrad(
const int outer_dim,
const int inner_dim,
const int axis_dim,
const int num_indices,
const int select_dim,
const int64_t* index,
const T* dy,
T* dx,
......@@ -539,11 +539,11 @@ void TransposeGrad(
/* array.top_k */
template <typename T, class Context>
void TopK(
void TopSelect(
const int outer_dim,
const int inner_dim,
const int axis_dim,
const int top_k,
const int topk,
const int largest,
const T* x,
T* value,
......@@ -551,6 +551,7 @@ void TopK(
Context* ctx);
/* array.unique */
template <typename T, class Context>
void Unique(
const int dim,
......@@ -562,6 +563,7 @@ void Unique(
Context* ctx);
/* control_flow.assign */
template <typename T, class Context>
void Assign(
const int num_dims,
......
......@@ -112,6 +112,8 @@ from dragon.vm.tensorflow.core.ops.math_ops import square
from dragon.vm.tensorflow.core.ops.math_ops import subtract
from dragon.vm.tensorflow.core.ops.math_ops import tanh
from dragon.vm.tensorflow.core.ops.gradients_impl import gradients
from dragon.vm.tensorflow.core.ops.sort_ops import argsort
from dragon.vm.tensorflow.core.ops.sort_ops import sort
from dragon.vm.tensorflow.core.ops.variables import Variable
# Attributes
......
......@@ -12,7 +12,7 @@
# <https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/array_ops.py>
#
# ------------------------------------------------------------
"""The array ops."""
"""Array ops."""
from __future__ import absolute_import
from __future__ import division
......
......@@ -12,7 +12,7 @@
# <https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/bitwise_ops.py>
#
# ------------------------------------------------------------
"""The bitwise ops."""
"""Bitwise ops."""
from __future__ import absolute_import
from __future__ import division
......
......@@ -12,7 +12,7 @@
# <https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/clip_ops.py>
#
# ------------------------------------------------------------
"""The clip ops."""
"""Clip ops."""
from __future__ import absolute_import
from __future__ import division
......
......@@ -8,7 +8,7 @@
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Grad implementation."""
"""Gradient implementation."""
from __future__ import absolute_import
from __future__ import division
......
......@@ -12,7 +12,7 @@
# <https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/init_ops.py>
#
# ------------------------------------------------------------
"""The init ops."""
"""Init ops."""
from __future__ import absolute_import
from __future__ import division
......
......@@ -12,7 +12,7 @@
# <https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/linalg_ops.py>
#
# ------------------------------------------------------------
"""The linalg ops."""
"""Linalg ops."""
from __future__ import absolute_import
from __future__ import division
......
......@@ -12,7 +12,7 @@
# <https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/math_ops.py>
#
# ------------------------------------------------------------
"""The math ops."""
"""Math ops."""
from __future__ import absolute_import
from __future__ import division
......
......@@ -8,8 +8,7 @@
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""The nn components."""
"""NN components."""
from __future__ import absolute_import
from __future__ import division
......
......@@ -8,7 +8,7 @@
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""The nn ops implementation."""
"""NN implementation."""
from __future__ import absolute_import
from __future__ import division
......
......@@ -8,7 +8,7 @@
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""The nn ops."""
"""NN ops."""
from __future__ import absolute_import
from __future__ import division
......
......@@ -8,7 +8,7 @@
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""The random ops."""
"""Random ops."""
from __future__ import absolute_import
from __future__ import division
......
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, see
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Sort ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from dragon.core.ops import array_ops
def argsort(values, axis=-1, direction='ASCENDING', name=None):
"""Return the index of sorted elements along the given axis.
By default, the last axis is chosen:
```python
x = tf.constant([[1, 2, 3], [3, 2, 1]])
index1 = tf.argsort(x)
index2 = tf.argsort(x, axis=1) # Equivalent
```
Sort in descending order if ``direction`` is ``DESCENDING``:
```python
x = tf.constant([1, 2, 3])
index1 = tf.argsort(-x)
index2 = tf.argsort(x, direction='DESCENDING') # Equivalent
```
Parameters
----------
values : dragon.Tensor
The input tensor.
axis : int, optional, default=-1
The axis to sort elements.
direction : {'ASCENDING', 'DESCENDING'}, optional
The sorting direction.
name : str, optional
The operation name.
Returns
-------
dragon.Tensor
The index tensor.
"""
if direction not in ('ASCENDING', 'DESCENDING'):
raise ValueError('Unknown direction: ' + direction)
value_and_index = array_ops.sort(
values,
axis=axis,
descending=direction == 'DESCENDING',
name=name,
)
return value_and_index[1]
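The pairing is easy to sanity-check with numpy (a hedged sketch, not part of the commit; ``np.take_along_axis`` plays the role of a gather):

```python
import numpy as np

data = np.array([[1, 2, 3], [3, 2, 1]])
index = np.argsort(data, axis=-1)
# Gathering with the argsort indices reproduces the sorted values.
assert (np.take_along_axis(data, index, axis=-1)
        == np.sort(data, axis=-1)).all()
```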
def sort(values, axis=-1, direction='ASCENDING', name=None):
"""Return the sorted elements along the given axis.
By default, the last axis is chosen:
```python
x = tf.constant([[1, 2, 3], [3, 2, 1]])
value1 = tf.sort(x)
value2 = tf.sort(x, axis=1) # Equivalent
```
Sort in descending order if ``direction`` is ``DESCENDING``:
```python
x = tf.constant([1, 2, 3])
value1 = -tf.sort(-x)
value2 = tf.sort(x, direction='DESCENDING') # Equivalent
```
Parameters
----------
values : dragon.Tensor
The input tensor.
axis : int, optional, default=-1
The axis to sort elements.
direction : {'ASCENDING', 'DESCENDING'}, optional
The sorting direction.
name : str, optional
The operation name.
Returns
-------
dragon.Tensor
The value tensor.
"""
if direction not in ('ASCENDING', 'DESCENDING'):
raise ValueError('Unknown direction: ' + direction)
value_and_index = array_ops.sort(
values,
axis=axis,
descending=direction == 'DESCENDING',
name=name,
)
return value_and_index[0]
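Both wrappers dispatch to the same fused ``array_ops.sort``, which returns a ``(value, index)`` pair, so each is a single op that simply discards the output it does not need. A usage sketch against the underlying dragon API (tensor construction is an assumption here; ``dragon.constant`` is hypothetical, and the unit tests use a ``new_tensor`` helper instead):

```python
import numpy as np
import dragon

x = dragon.constant(np.array([[1, 2, 3], [3, 2, 1]]))  # hypothetical constructor
value, index = dragon.sort(x, axis=1, descending=True)
# tf.sort(x) returns ``value``; tf.argsort(x) returns ``index``.
```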
......@@ -8,7 +8,7 @@
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""The standard ops."""
"""Standard ops."""
from __future__ import absolute_import
from __future__ import division
......
......@@ -8,7 +8,7 @@
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""The Variable class."""
"""Variable class."""
from __future__ import absolute_import
from __future__ import division
......
......@@ -810,6 +810,28 @@ class TestArrayOps(OpTestCase):
with dragon.device('cuda'):
self.test_slice()
def test_sort(self):
entries = [(None, True),
(0, True),
(-1, True),
(0, False),
(-1, False)]
for execution in ('EAGER_MODE', 'GRAPH_MODE'):
with execution_context().mode(execution):
for axis, descending in entries:
data = uniform((5, 10))
x = new_tensor(data)
y = dragon.sort(x, axis=axis, descending=descending)
axis = axis if axis is not None else -1
result = np.argsort(-data if descending else data, axis=axis)
result = np.take(result, np.arange(data.shape[axis]), axis=axis)
self.assertEqual(y[1], result)
@unittest.skipIf(not TEST_CUDA, 'CUDA unavailable')
def test_sort_cuda(self):
with dragon.device('cuda'):
self.test_sort()
def test_split(self):
entries = [(2, 1, None), ((2, 1), 1, None), (2, 1, (2,))]
for execution in ('EAGER_MODE', 'GRAPH_MODE'):
......
......@@ -483,6 +483,21 @@ class TestTensorOps(OpTestCase):
x = new_tensor(data)
self.assertEqual(x.sin(), np.sin(data))
def test_sort(self):
entries = [(None, True),
(0, True),
(-1, True),
(0, False),
(-1, False)]
for axis, descending in entries:
data = uniform((5, 10))
x = new_tensor(data)
y = x.sort(axis, descending)[1]
axis = axis if axis is not None else -1
result = np.argsort(-data if descending else data, axis=axis)
result = np.take(result, np.arange(data.shape[axis]), axis=axis)
self.assertEqual(y, result)
def test_sqrt(self):
data = np.array([4., 9., 16], 'float32')
x = new_tensor(data)
......
Subproject commit c3cceac115c072fb63df1836ff46d8c60d9eb304
Subproject commit a3ee304a1f8e22f278df10600df2e4b333012592
......@@ -69,6 +69,7 @@ from dragon.vm.torch.core.ops.array.functional import one_hot
from dragon.vm.torch.core.ops.array.functional import permute
from dragon.vm.torch.core.ops.array.functional import repeat
from dragon.vm.torch.core.ops.array.functional import reshape
from dragon.vm.torch.core.ops.array.functional import sort
from dragon.vm.torch.core.ops.array.functional import split
from dragon.vm.torch.core.ops.array.functional import squeeze
from dragon.vm.torch.core.ops.array.functional import stack
......
......@@ -428,6 +428,26 @@ class Slice(function.Function):
)
class Sort(function.Function):
def __init__(self, key, dev, **kwargs):
super(Sort, self).__init__(key, dev, **kwargs)
self.axis = kwargs.get('axis', -1)
self.descending = kwargs.get('descending', False)
def attributes(self):
return {
'op_type': 'Sort',
'arguments': {
'axis': self.axis,
'descending': self.descending,
}
}
def forward(self, input, outputs=(None, None)):
outputs = [self.alloc(outputs[0]), self.alloc(outputs[1])]
return self.dispatch([input], outputs, no_grad=True)
class Split(function.Function):
def __init__(self, key, dev, **kwargs):
super(Split, self).__init__(key, dev, **kwargs)
......@@ -546,7 +566,7 @@ class TopK(function.Function):
def __init__(self, key, dev, **kwargs):
super(TopK, self).__init__(key, dev, **kwargs)
self.k = kwargs.get('k', 1)
self.axis = kwargs.get('axis', None)
self.axis = kwargs.get('axis', -1)
self.largest = kwargs.get('largest', True)
self.sorted = kwargs.get('sorted', True)
......
......@@ -795,6 +795,50 @@ def slice(input, starts, sizes):
.apply(input, starts, sizes)
def sort(input, dim=-1, descending=False, out=None):
"""Return the sorted elements along the given dimension.
By default, the last dimension is chosen:
```python
x = torch.tensor([[1, 2, 3], [3, 2, 1]])
value1, index1 = torch.sort(x)
value2, index2 = torch.sort(x, dim=1) # Equivalent
```
Sort in descending order if ``descending`` is ``True``:
```python
x = torch.tensor([1, 2, 3])
_, index1 = torch.sort(-x)
_, index2 = torch.sort(x, descending=True) # Equivalent
```
Parameters
----------
input : dragon.vm.torch.Tensor
The input tensor.
dim : int, optional, default=-1
The dimension to sort elements.
descending : bool, optional, default=False
Whether to sort in descending order.
out : Sequence[dragon.vm.torch.Tensor], optional
The optional output value and index.
Returns
-------
Sequence[dragon.vm.torch.Tensor]
The value and index tensor.
"""
return _functions.Sort \
.instantiate(
input.device,
axis=dim,
descending=descending,
).apply(input, out if out else (None, None))
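A short usage sketch (assuming tensors built with ``torch.tensor`` as in the docstring, and that ``out`` reuses preallocated tensors through the ``alloc`` path shown above):

```python
from dragon.vm import torch

x = torch.tensor([[1., 3., 2.], [3., 1., 2.]])
value, index = torch.sort(x, dim=1, descending=True)
# Reuse the outputs on a later call instead of allocating new ones.
torch.sort(x, dim=1, descending=True, out=(value, index))
```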
def split(tensor, split_size_or_sections, dim=0):
"""Split input into chunks along the given dimension.
......@@ -960,10 +1004,10 @@ def sum(input, dim=None, keepdim=False, out=None):
return _reduce(input, 'Sum', dim, keepdim, out)
def topk(input, k, dim=None, largest=True, sorted=True, out=None):
def topk(input, k, dim=-1, largest=True, sorted=True, out=None):
"""Return the top-K largest or smallest elements along the given dimension.
If ``dim`` is not given, the last dimension is chosen:
By default, the last dimension is chosen:
```python
x = torch.tensor([[1, 2, 3], [3, 2, 1]])
......@@ -974,9 +1018,9 @@ def topk(input, k, dim=None, largest=True, sorted=True, out=None):
If ``largest`` is ``False``, the k smallest elements are returned:
```python
x = torch.tensor([[1, 2, 3], [3, 2, 1]])
_, index1 = torch.topk(x, 1, largest=False)
_, index2 = torch.topk(-x, 1, largest=True) # Equivalent
x = torch.tensor([1, 2, 3])
_, index1 = torch.topk(-x, 1)
_, index2 = torch.topk(x, 1, largest=False) # Equivalent
```
Parameters
......@@ -985,8 +1029,8 @@ def topk(input, k, dim=None, largest=True, sorted=True, out=None):
The input tensor.
k : int
The number of top elements to select.
dim : int, optional
The dimension to reduce.
dim : int, optional, default=-1
The dimension to select elements.
largest : bool, optional
Return largest or smallest elements.
sorted : bool, optional
......@@ -1000,8 +1044,6 @@ def topk(input, k, dim=None, largest=True, sorted=True, out=None):
The value and index tensor.
"""
if dim is None:
dim = input.ndimension() - 1
return _functions.TopK \
.instantiate(
input.device,
......
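With ``dim`` now defaulting to the last dimension, ``topk`` composes naturally with the new ``sort``: up to ties, the top-k result is a prefix of the full descending sort. A hedged sketch:

```python
from dragon.vm import torch

x = torch.tensor([[1., 3., 2.]])
v1, i1 = torch.topk(x, 2)                # dim defaults to -1
v2, i2 = torch.sort(x, descending=True)  # full sort along dim=-1
# Up to ties, (v1, i1) matches the first two columns of (v2, i2).
```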
......@@ -1513,6 +1513,29 @@ def sin(self):
return math_funcs.sin(self)
def sort(self, dim=-1, descending=False):
"""Return the sorted elements.
Parameters
----------
dim : int, optional, default=-1
The dimension to sort elements.
descending : bool, optional, default=False
Whether to sort in descending order.
Returns
-------
Sequence[dragon.vm.torch.Tensor]
The value and index tensor.
See Also
--------
`torch.sort(...)`_
"""
return array_funcs.sort(self, dim, descending)
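The method form simply forwards to the functional form bound below; a minimal sketch:

```python
from dragon.vm import torch

x = torch.tensor([[1., 3., 2.]])
value, index = x.sort(dim=-1, descending=True)
# Equivalent to: torch.sort(x, dim=-1, descending=True)
```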
def sqrt(self):
r"""Compute the square root.
......@@ -1660,18 +1683,18 @@ def sub_(self, other):
return math_funcs.sub(self, other, self)
def topk(self, k, dim=None, largest=True, sorted=True):
def topk(self, k, dim=-1, largest=True, sorted=True):
"""Return the top-K largest or smallest elements.
Parameters
----------
k : int
The number of top elements to select.
dim : int, optional
The dimension to reduce.
largest : bool, optional
dim : int, optional, default=-1
The dimension to select elements.
largest : bool, optional, default=True
Return largest or smallest elements.
sorted : bool, optional
sorted : bool, optional, default=True
Whether to return in the sorted order.
Returns
......@@ -1939,6 +1962,7 @@ Tensor.rsqrt_ = rsqrt_
Tensor.sign = sign
Tensor.sign_ = sign_
Tensor.sin = sin
Tensor.sort = sort
Tensor.sqrt = sqrt
Tensor.sqrt_ = sqrt_
Tensor.squeeze = squeeze
......
......@@ -1668,6 +1668,27 @@ class Tensor(object):
s = cpp.Size(self._impl.dims)
return s[axis] if axis is not None else s
def sort(self, dim=-1, descending=False):
"""Return the sorted elements.
Parameters
----------
dim : int, optional, default=-1
The dimension to sort elements.
descending : bool, optional, default=False
Whether to sort in descending order.
Returns
-------
Sequence[dragon.vm.torch.Tensor]
The value and index tensor.
See Also
--------
`torch.sort(...)`_
"""
def sqrt(self):
r"""Compute the square root.
......@@ -1849,15 +1870,15 @@ class Tensor(object):
return self.type(dtype)
return self
def topk(self, k, dim=None, largest=True, sorted=True):
def topk(self, k, dim=-1, largest=True, sorted=True):
"""Return the top-K largest or smallest elements.
Parameters
----------
k : int
The number of top elements to select.
dim : int, optional
The dimension to reduce.
dim : int, optional, default=-1
The dimension to select elements.
largest : bool, optional
Return largest or smallest elements.
sorted : bool, optional
......