Commit e83c407a by Ting PAN

Add LinSpace Operator

Summary:
This commit adds the LinSpace operator for the dragon, torch, and tensorflow APIs.
It also adds a workaround for truncated integer intervals in range/linspace (up to 2**57).
1 parent 5cbbef4b
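For context on the workaround: range/linspace endpoints previously travelled through float32 argument slots, which cannot hold large integers exactly; the commit widens them to float64. A rough, framework-independent sketch of the truncation (NumPy only; the 2**57 bound above is the commit's own figure):

```python
import numpy as np

big = 2**40 + 3                      # an integer endpoint that float32 cannot hold exactly
print(int(np.float32(big)) == big)   # False: the 24-bit float32 significand rounds it away
print(int(np.float64(big)) == big)   # True: float64 keeps much larger integers exact
```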
...@@ -21,6 +21,9 @@ dragon
Functions
---------
`argsort(...) <dragon/argsort.html>`_
: Return the index of sorted elements along the given axis.
`assign(...) <dragon/assign.html>`_
: Assign the value to input.
...@@ -90,6 +93,9 @@ dragon
`index_select(...) <dragon/index_select.html>`_
: Select the elements according to the index along the given axis.
`linspace(...) <dragon/linspace.html>`_
: Generate evenly spaced values within intervals along the given axis.
`load_library(...) <dragon/load_library.html>`_
: Load a shared library.
...@@ -178,6 +184,7 @@ dragon
dragon/GradientTape
dragon/Tensor
dragon/Workspace
dragon/argsort
dragon/assign
dragon/broadcast_to
dragon/cast
...@@ -201,6 +208,7 @@ dragon
dragon/gradients
dragon/graph_mode
dragon/index_select
dragon/linspace
dragon/load_library
dragon/masked_assign
dragon/masked_select
......
argsort
=======
.. autofunction:: dragon.argsort
.. raw:: html
<style>
h1:before {
content: "dragon.";
color: #103d3e;
}
</style>
...@@ -39,6 +39,7 @@ dragon.cuda
.. toctree::
:hidden:
cuda/Stream
cuda/current_device
cuda/enable_cudnn
cuda/get_device_capability
...@@ -46,7 +47,6 @@ dragon.cuda
cuda/memory_allocated
cuda/set_default_device
cuda/set_device
cuda/Stream
cuda/synchronize
.. raw:: html
......
linspace
========
.. autofunction:: dragon.linspace
.. raw:: html
<style>
h1:before {
content: "dragon.";
color: #103d3e;
}
</style>
...@@ -60,6 +60,9 @@ vm.tensorflow
`identity(...) <tensorflow/identity.html>`_
: Return a new tensor copying the content of input.
`linspace(...) <tensorflow/linspace.html>`_
: Generate evenly spaced values within intervals along the given axis.
`name_scope(...) <tensorflow/name_scope.html>`_
: Context-manager to nest the name as prefix for operations.
...@@ -131,6 +134,7 @@ vm.tensorflow
tensorflow/gather
tensorflow/gradients
tensorflow/identity
tensorflow/linspace
tensorflow/name_scope
tensorflow/ones
tensorflow/ones_like
......
linspace
========
.. autofunction:: dragon.vm.tensorflow.linspace
.. raw:: html
<style>
h1:before {
content: "tf.";
color: #103d3e;
}
</style>
...@@ -45,6 +45,9 @@ vm.torch
`argmin(...) <torch/argmin.html>`_
: Return the index of minimum elements along the given dimension.
`argsort(...) <torch/argsort.html>`_
: Return the index of sorted elements along the given dimension.
`axpby(...) <torch/axpby.html>`_
: Compute the element-wise addition from input to output.
...@@ -124,6 +127,9 @@ vm.torch
`le(...) <torch/le.html>`_
: Compute the element-wise less-equal comparison.
`linspace(...) <torch/linspace.html>`_
: Generate evenly spaced values within intervals along the given axis.
`log(...) <torch/log.html>`_
: Compute the natural logarithm of input.
...@@ -241,6 +247,9 @@ vm.torch
`topk(...) <torch/topk.html>`_
: Return the top-K largest or smallest elements along the given dimension.
`transpose(...) <torch/transpose.html>`_
: Return a new tensor with two dimensions swapped.
`unique(...) <torch/unique.html>`_
: Return the unique elements of input.
...@@ -266,6 +275,7 @@ vm.torch
torch/arange
torch/argmax
torch/argmin
torch/argsort
torch/axpby
torch/bitwise_not
torch/bitwise_xor
...@@ -295,6 +305,7 @@ vm.torch
torch/isinf
torch/isnan
torch/le
torch/linspace
torch/log
torch/logsumexp
torch/lt
...@@ -336,6 +347,7 @@ vm.torch
torch/sum
torch/tensor
torch/topk
torch/transpose
torch/unique
torch/unsqueeze
torch/where
......
...@@ -61,6 +61,10 @@ argmin
######
.. automethod:: dragon.vm.torch.Tensor.argmin
argsort
#######
.. automethod:: dragon.vm.torch.Tensor.argsort
backward
########
.. automethod:: dragon.vm.torch.Tensor.backward
...@@ -437,6 +441,10 @@ topk
####
.. automethod:: dragon.vm.torch.Tensor.topk
transpose
#########
.. automethod:: dragon.vm.torch.Tensor.transpose
type
####
.. automethod:: dragon.vm.torch.Tensor.type
...@@ -481,6 +489,7 @@ zero\_
.. _torch.add(...): add.html
.. _torch.argmax(...): argmax.html
.. _torch.argmin(...): argmin.html
.. _torch.argsort(...): argsort.html
.. _torch.bitwise_not(...): bitwise_not.html
.. _torch.bitwise_xor(...): bitwise_xor.html
.. _torch.ceil(...): ceil.html
...@@ -513,6 +522,7 @@ zero\_
.. _torch.sub(...): sub.html
.. _torch.sum(...): sum.html
.. _torch.topk(...): topk.html
.. _torch.transpose(...): transpose.html
.. _torch.unique(...): unique.html
.. _torch.unsqueeze(...): unsqueeze.html
.. _torch.where(...): where.html
......
argsort
=======
.. autofunction:: dragon.vm.torch.argsort
.. raw:: html
<style>
h1:before {
content: "torch.";
color: #103d3e;
}
</style>
linspace
========
.. autofunction:: dragon.vm.torch.linspace
.. raw:: html
<style>
h1:before {
content: "torch.";
color: #103d3e;
}
</style>
transpose
=========
.. autofunction:: dragon.vm.torch.transpose
.. raw:: html
<style>
h1:before {
content: "torch.";
color: #103d3e;
}
</style>
#include "dragon/utils/cast.h"
#include "dragon/utils/op_kernels.h"
namespace dragon {
namespace kernel {
namespace {
template <typename T>
void _RowwiseLinSpace(
const int rows,
const int cols,
const double* start,
const double* stop,
T* y) {
for (int i = 0; i < cols; ++i) {
const auto delta = (stop[i] - start[i]) / double(rows - 1);
y[i] = cast::to<T>(start[i]);
if (rows > 1) {
y[i + (rows - 1) * cols] = cast::to<T>(stop[i]);
}
for (int j = 1; j < rows - 1; ++j) {
y[i + j * cols] = cast::to<T>(start[i] + double(j) * delta);
}
}
}
template <typename T>
void _ColwiseLinSpace(
const int rows,
const int cols,
const double* start,
const double* stop,
T* y) {
for (int i = 0; i < rows; ++i) {
const auto delta = (stop[i] - start[i]) / double(cols - 1);
auto* offset_y = y + i * cols;
offset_y[0] = cast::to<T>(start[i]);
if (cols > 1) {
offset_y[cols - 1] = cast::to<T>(stop[i]);
}
for (int j = 1; j < cols - 1; ++j) {
offset_y[j] = cast::to<T>(start[i] + double(j) * delta);
}
}
}
} // namespace
/* ------------------- Launcher Separator ------------------- */
#define DEFINE_KERNEL_LAUNCHER(T) \
template <> \
void LinSpace<T, CPUContext>( \
const int rows, \
const int cols, \
const int axis, \
const double* start, \
const double* end, \
T* y, \
CPUContext* ctx) { \
if (axis == 0) { \
_RowwiseLinSpace(rows, cols, start, end, y); \
} else { \
_ColwiseLinSpace(rows, cols, start, end, y); \
} \
}
DEFINE_KERNEL_LAUNCHER(float16);
DEFINE_KERNEL_LAUNCHER(int8_t);
DEFINE_KERNEL_LAUNCHER(uint8_t);
DEFINE_KERNEL_LAUNCHER(int);
DEFINE_KERNEL_LAUNCHER(int64_t);
DEFINE_KERNEL_LAUNCHER(float);
DEFINE_KERNEL_LAUNCHER(double);
#undef DEFINE_KERNEL_LAUNCHER
} // namespace kernel
} // namespace dragon
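The CPU kernel above pins the first and last value of every interval to the exact start/stop and interpolates the interior with a double-precision delta. A minimal NumPy sketch of the same row-wise fill (illustration only, not the framework API):

```python
import numpy as np

def rowwise_linspace(rows, cols, start, stop, dtype=np.float32):
    """Mirror _RowwiseLinSpace: one (start, stop) pair per column."""
    y = np.empty((rows, cols), dtype)
    for i in range(cols):
        delta = (stop[i] - start[i]) / (rows - 1) if rows > 1 else 0.0
        y[0, i] = start[i]                  # first endpoint written exactly
        if rows > 1:
            y[rows - 1, i] = stop[i]        # last endpoint written exactly
        for j in range(1, rows - 1):
            y[j, i] = start[i] + j * delta  # interior values interpolated
    return y

print(rowwise_linspace(3, 2, [1.0, 2.0], [3.0, 4.0]))  # [[1. 2.] [2. 3.] [3. 4.]]
```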
#ifdef USE_CUDA
#include "dragon/core/context_cuda.h"
#include "dragon/utils/op_kernels.h"
namespace dragon {
namespace kernel {
namespace {
template <typename T, int D>
__global__ void _RowwiseLinSpace(
const int nthreads,
const int rows,
const int cols,
const SimpleArray<double, D> start,
const SimpleArray<double, D> stop,
T* y) {
CUDA_1D_KERNEL_LOOP(yi, nthreads) {
const int i = yi % cols;
const int j = yi / cols;
if (j == rows - 1 && j > 0) {
y[yi] = stop.data[i];
} else {
y[yi] = start.data[i] +
j * ((stop.data[i] - start.data[i]) / double(rows - 1));
}
}
}
template <int D>
__global__ void _RowwiseLinSpace(
const int nthreads,
const int rows,
const int cols,
const SimpleArray<double, D> start,
const SimpleArray<double, D> stop,
half* y) {
CUDA_1D_KERNEL_LOOP(yi, nthreads) {
const int i = yi % cols;
const int j = yi / cols;
if (j == rows - 1 && j > 0) {
y[yi] = __float2half(float(stop.data[i]));
} else {
y[yi] = __float2half(float(
start.data[i] +
j * ((stop.data[i] - start.data[i]) / double(rows - 1))));
}
}
}
template <typename T, int D>
__global__ void _ColwiseLinSpace(
const int nthreads,
const int cols,
const SimpleArray<double, D> start,
const SimpleArray<double, D> stop,
T* y) {
CUDA_1D_KERNEL_LOOP(yi, nthreads) {
const int i = yi / cols;
const int j = yi % cols;
if (j == cols - 1 && j > 0) {
y[yi] = stop.data[i];
} else {
y[yi] = start.data[i] +
j * ((stop.data[i] - start.data[i]) / double(cols - 1));
}
}
}
template <int D>
__global__ void _ColwiseLinSpace(
const int nthreads,
const int cols,
const SimpleArray<double, D> start,
const SimpleArray<double, D> stop,
half* y) {
CUDA_1D_KERNEL_LOOP(yi, nthreads) {
const int i = yi / cols;
const int j = yi % cols;
if (j == cols - 1 && j > 0) {
y[yi] = __float2half(float(stop.data[i]));
} else {
y[yi] = __float2half(float(
start.data[i] +
j * ((stop.data[i] - start.data[i]) / double(cols - 1))));
}
}
}
} // namespace
/* ------------------- Launcher Separator ------------------- */
template <>
void LinSpace<float16, CUDAContext>(
const int rows,
const int cols,
const int axis,
const double* start,
const double* stop,
float16* y,
CUDAContext* ctx) {
CUDA_TENSOR_DIMS_CHECK((axis == 0 ? cols : rows));
const auto nthreads = rows * cols;
SimpleArray<double, CUDA_TENSOR_MAX_DIMS> Y_start;
SimpleArray<double, CUDA_TENSOR_MAX_DIMS> Y_stop;
for (int i = 0; i < (axis == 0 ? cols : rows); ++i) {
Y_start.data[i] = start[i];
Y_stop.data[i] = stop[i];
}
if (axis == 0) {
_RowwiseLinSpace<<<
CUDA_BLOCKS(nthreads),
CUDA_THREADS,
0,
ctx->cuda_stream()>>>(
nthreads, rows, cols, Y_start, Y_stop, reinterpret_cast<half*>(y));
} else {
_ColwiseLinSpace<<<
CUDA_BLOCKS(nthreads),
CUDA_THREADS,
0,
ctx->cuda_stream()>>>(
nthreads, cols, Y_start, Y_stop, reinterpret_cast<half*>(y));
}
}
#define DEFINE_KERNEL_LAUNCHER(T) \
template <> \
void LinSpace<T, CUDAContext>( \
const int rows, \
const int cols, \
const int axis, \
const double* start, \
const double* stop, \
T* y, \
CUDAContext* ctx) { \
CUDA_TENSOR_DIMS_CHECK((axis == 0 ? cols : rows)); \
const auto nthreads = rows * cols; \
SimpleArray<double, CUDA_TENSOR_MAX_DIMS> Y_start; \
SimpleArray<double, CUDA_TENSOR_MAX_DIMS> Y_stop; \
for (int i = 0; i < (axis == 0 ? cols : rows); ++i) { \
Y_start.data[i] = start[i]; \
Y_stop.data[i] = stop[i]; \
} \
if (axis == 0) { \
_RowwiseLinSpace<<< \
CUDA_BLOCKS(nthreads), \
CUDA_THREADS, \
0, \
ctx->cuda_stream()>>>(nthreads, rows, cols, Y_start, Y_stop, y); \
} else { \
_ColwiseLinSpace<<< \
CUDA_BLOCKS(nthreads), \
CUDA_THREADS, \
0, \
ctx->cuda_stream()>>>(nthreads, cols, Y_start, Y_stop, y); \
} \
}
DEFINE_KERNEL_LAUNCHER(int8_t);
DEFINE_KERNEL_LAUNCHER(uint8_t);
DEFINE_KERNEL_LAUNCHER(int);
DEFINE_KERNEL_LAUNCHER(int64_t);
DEFINE_KERNEL_LAUNCHER(float);
DEFINE_KERNEL_LAUNCHER(double);
#undef DEFINE_KERNEL_LAUNCHER
} // namespace kernel
} // namespace dragon
#endif // USE_CUDA
...@@ -9,12 +9,12 @@ namespace kernel {
namespace {
template <typename T>
void _Range(const int count, const float start, const float delta, T* y) { void _Range(const int count, const double start, const double delta, T* y) {
#ifdef USE_OPENMP
#pragma omp parallel for num_threads(OMP_THREADS(count))
#endif
for (int i = 0; i < count; ++i) {
y[i] = static_cast<T>(start + i * delta); y[i] = cast::to<T>(start + double(i) * delta);
}
}
...@@ -22,27 +22,12 @@ void _Range(const int count, const float start, const float delta, T* y) {
/* ------------------- Launcher Separator ------------------- */
template <>
void Range<float16, CPUContext>(
const int count,
const float start,
const float delta,
float16* y,
CPUContext* ctx) {
#ifdef USE_OPENMP
#pragma omp parallel for num_threads(OMP_THREADS(count))
#endif
for (int i = 0; i < count; ++i) {
y[i] = cast::to<float16>(start + (float)i * delta);
}
}
#define DEFINE_KERNEL_LAUNCHER(T) \
template <> \
void Range<T, CPUContext>( \
const int count, \
const float start, \ const double start, \
const float delta, \ const double delta, \
T* y, \
CPUContext* ctx) { \
_Range(count, start, delta, y); \
...@@ -52,6 +37,7 @@ DEFINE_KERNEL_LAUNCHER(int8_t);
DEFINE_KERNEL_LAUNCHER(uint8_t);
DEFINE_KERNEL_LAUNCHER(int);
DEFINE_KERNEL_LAUNCHER(int64_t);
DEFINE_KERNEL_LAUNCHER(float16);
DEFINE_KERNEL_LAUNCHER(float);
DEFINE_KERNEL_LAUNCHER(double);
#undef DEFINE_KERNEL_LAUNCHER
......
...@@ -11,20 +11,20 @@ namespace {
template <typename T>
__global__ void
_Range(const int nthreads, const float start, const float delta, T* y) { _Range(const int nthreads, const double start, const double delta, T* y) {
CUDA_1D_KERNEL_LOOP(i, nthreads) {
y[i] = T(start + (float)i * delta); y[i] = T(start + double(i) * delta);
}
}
template <>
__global__ void _Range<half>(
const int nthreads,
const float start, const double start,
const float delta, const double delta,
half* y) {
CUDA_1D_KERNEL_LOOP(i, nthreads) {
y[i] = __float2half(start + (float)i * delta); y[i] = __float2half(float(start + double(i) * delta));
}
}
...@@ -35,8 +35,8 @@ __global__ void _Range<half>(
template <>
void Range<float16, CUDAContext>(
const int count,
const float start, const double start,
const float delta, const double delta,
float16* y,
CUDAContext* ctx) {
_Range<<<CUDA_BLOCKS(count), CUDA_THREADS, 0, ctx->cuda_stream()>>>(
...@@ -47,8 +47,8 @@ void Range<float16, CUDAContext>(
template <> \
void Range<T, CUDAContext>( \
const int count, \
const float start, \ const double start, \
const float delta, \ const double delta, \
T* y, \
CUDAContext* ctx) { \
_Range<<<CUDA_BLOCKS(count), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
......
...@@ -61,8 +61,7 @@ void ChannelAffineGradientOp<Context>::DoRunWithType() {
CANONICALIZE_AXES_WITH_TENSOR(X);
// Reduce parameters for weight and bias
vec32_t dims = { vec32_t dims = {(int)X.count(0, axis),
(int)X.count(0, axis),
(int)X.count(axis, axis + num_axes),
(int)X.count(axis + num_axes)};
vec32_t axes = {0, 2};
......
...@@ -55,7 +55,7 @@ template <class Context>
class RangeOp final : public Operator<Context> {
public:
RangeOp(const OperatorDef& def, Workspace* ws) : Operator<Context>(def, ws) {
INIT_OP_REPEATED_ARG_WITH_DESC(float, slice); INIT_OP_REPEATED_ARG_WITH_DESC(double, slice);
}
USE_OPERATOR_FUNCTIONS;
...@@ -65,7 +65,27 @@ class RangeOp final : public Operator<Context> {
void DoRunWithType();
protected:
DECLARE_OP_REPEATED_ARG_WITH_DESC(float, slice); DECLARE_OP_REPEATED_ARG_WITH_DESC(double, slice);
};
template <class Context>
class LinSpaceOp final : public InitializeOp<Context> {
public:
LinSpaceOp(const OperatorDef& def, Workspace* ws)
: InitializeOp<Context>(def, ws) {
INIT_OP_REPEATED_ARG_WITH_DESC(double, start);
INIT_OP_REPEATED_ARG_WITH_DESC(double, stop);
}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T>
void DoRunWithType();
protected:
DECLARE_OP_REPEATED_ARG_WITH_DESC(double, start);
DECLARE_OP_REPEATED_ARG_WITH_DESC(double, stop);
};
template <class Context>
...@@ -265,7 +285,9 @@ class GlorotUniformOp final : public InitializeOp<Context> {
DEFINE_OP_SINGLE_ARG_WITH_DESC(int64_t, PermutationOp, limit);
DEFINE_OP_REPEATED_ARG_WITH_DESC(int64_t, InitializeOp, dims);
DEFINE_OP_REPEATED_ARG_WITH_DESC(float, RangeOp, slice); DEFINE_OP_REPEATED_ARG_WITH_DESC(double, RangeOp, slice);
DEFINE_OP_REPEATED_ARG_WITH_DESC(double, LinSpaceOp, start);
DEFINE_OP_REPEATED_ARG_WITH_DESC(double, LinSpaceOp, stop);
} // namespace dragon
......
#include "dragon/core/workspace.h"
#include "dragon/operators/array/initialize_ops.h"
#include "dragon/utils/op_kernels.h"
namespace dragon {
template <class Context>
template <typename T>
void LinSpaceOp<Context>::DoRunWithType() {
auto* Y = Output(0);
CANONICALIZE_AXIS_WITH_TENSOR((*Y));
// Determine the generating range
// Values are in an interval: [start, stop]
int num_starts;
start(0, &num_starts);
vector<double> starts(num_starts), stops(num_starts);
for (int i = 0; i < num_starts; ++i) {
starts[i] = start(i);
stops[i] = stop(i);
CHECK_GT(stops[i], starts[i])
<< "\nInvalid generating range: "
<< "[" << starts[i] << ", " << stops[i] << "].";
}
kernel::LinSpace(
Y->dim(0),
Y->ndim() > 1 ? Y->dim(1) : 1,
axis,
starts.data(),
stops.data(),
Y->template mutable_data<T, Context>(),
ctx());
}
template <class Context>
void LinSpaceOp<Context>::RunOnDevice() {
InitializeOp<Context>::RunOnDevice();
DispatchHelper<NumericalTensorTypes>::Call(this);
}
DEPLOY_CPU_OPERATOR(LinSpace);
#ifdef USE_CUDA
DEPLOY_CUDA_OPERATOR(LinSpace);
#endif
OPERATOR_SCHEMA(LinSpace).NumInputs(0).NumOutputs(1);
NO_GRADIENT(LinSpace);
} // namespace dragon
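Note the CHECK_GT above: every stop must be strictly greater than its corresponding start, so descending or zero-length intervals are rejected at run time. A small illustration (the error behavior is inferred from the check, not separately documented):

```python
import dragon

ok = dragon.linspace(2, 4, num=3)   # [2, 3, 4]: stop > start, accepted
# dragon.linspace(4, 2, num=3)      # stop <= start: would trip CHECK_GT in LinSpaceOp
```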
...@@ -9,7 +9,7 @@ template <typename T>
void RangeOp<Context>::DoRunWithType() {
// Determine the slice arguments
int num_args;
float start = 0.f, limit, delta; double start = 0., limit, delta;
slice(0, &num_args);
if (num_args == 2) {
limit = slice(0), delta = slice(1);
......
...@@ -54,6 +54,7 @@ from dragon.core.framework.workspace import get_workspace
from dragon.core.framework.workspace import reset_workspace
from dragon.core.ops import tensorbind_eager as _
from dragon.core.ops import tensorbind_symbol as _
from dragon.core.ops.array_ops import argsort
from dragon.core.ops.array_ops import broadcast_to
from dragon.core.ops.array_ops import cast
from dragon.core.ops.array_ops import channel_affine
...@@ -63,6 +64,7 @@ from dragon.core.ops.array_ops import concat
from dragon.core.ops.array_ops import expand_dims
from dragon.core.ops.array_ops import flatten
from dragon.core.ops.array_ops import index_select
from dragon.core.ops.array_ops import linspace
from dragon.core.ops.array_ops import masked_select
from dragon.core.ops.array_ops import nonzero
from dragon.core.ops.array_ops import one_hot
......
...@@ -491,6 +491,14 @@ def is_spec(args, inputs, outputs):
return outputs
@register('LinSpace')
def linspace_spec(args, inputs, outputs):
_ = locals()
outputs[0].dtype = args['dtype']
outputs[0].shape = args['dims']
return outputs
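The spec above fills the output's dtype and shape straight from the op arguments: `dims` is `[num_intervals]` with `num` inserted at `axis`. A short sketch of the resulting shape for a two-interval call, following the docstring examples further down (assuming the returned tensor exposes `dtype`/`shape` as usual):

```python
import dragon

y = dragon.linspace([1, 2], [3, 4], num=3, axis=1, dtype='float32')
print(y.dtype)   # float32, taken from args['dtype']
print(y.shape)   # (2, 3): two intervals, with the three generated values along axis=1
```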
@register('MaskedSelect')
def masked_select_spec(args, inputs, outputs):
_ = locals()
......
...@@ -118,6 +118,54 @@ def argmin(inputs, axis=None, keep_dims=False, **kwargs):
@OpSchema.num_inputs(1)
def argsort(inputs, axis=-1, descending=False, **kwargs):
"""Return the index of sorted elements along the given axis.
By default, the last axis is chosen:
```python
x = dragon.constant([[1, 2, 3], [3, 2, 1]])
index1 = dragon.argsort(x)
index2 = dragon.argsort(x, axis=1) # Equivalent
```
Sort in the inverse order if ``descending`` is **True**:
```python
x = dragon.constant([1, 2, 3])
index1 = dragon.argsort(-x)
index2 = dragon.argsort(x, descending=True) # Equivalent
```
Parameters
----------
inputs : dragon.Tensor
The input tensor.
axis : int, optional, default=-1
The axis to sort elements.
descending : bool, optional, default=False
Sort in the descending order or not.
Returns
-------
dragon.Tensor
The output tensor.
"""
args = parse_args(locals())
op_lib = array_ops_lib.Sort
if context.executing_eagerly():
return op_lib \
.instantiate(
axis=axis,
descending=descending,
).apply([inputs])[1]
else:
args['num_outputs'] = 2
return op_lib.blend(**args)[1]
@OpSchema.num_inputs(1)
@ArgHelper.repeated_desc(name='shape', name_v2='dims')
def broadcast_to(inputs, shape, **kwargs):
"""Broadcast input according to a given shape.
...@@ -597,6 +645,65 @@ def index_select(inputs, index, axis=0, **kwargs):
return op_lib.blend(**args)
def linspace(start, stop, num, dtype='int64', axis=0, **kwargs):
r"""Generate evenly spaced values within intervals along the given axis.
Interval :math:`[\text{start}, \text{stop}]` is determined for ``num`` values:
```python
x = dragon.linspace(2, 4, num=3) # [2, 3, 4]
```
More than one interval is accepted to generate N-d coordinates:
```python
x = dragon.linspace([1, 2], [3, 4], num=3, axis=0) # [[1, 2], [2, 3], [3, 4]]
y = dragon.linspace([1, 2], [3, 4], num=3, axis=1) # [[1, 2, 3], [2, 3, 4]]
```
Parameters
----------
start : Union[number, Sequence[number]]
The start(s) of interval.
stop : Union[number, Sequence[number]]
The stop(s) of interval.
num : int
The number of values to generate.
dtype : str, optional, default='int64'
The optional data type.
axis : int, optional, default=0
The axis to generate values.
Returns
-------
dragon.Tensor
The output tensor.
"""
args = parse_args(locals())
args['dtype'] = args['dtype'].lower()
args['start'] = nest.flatten(start)
args['stop'] = nest.flatten(stop)
args.pop('num')
args['dims'] = []
if len(args['start']) > 1 or args['start'] == start:
args['dims'] = [len(args['start'])]
axis = axis if axis >= 0 else axis + len(args['dims']) + 1
args['dims'].insert(axis, num)
op_lib = array_ops_lib.LinSpace
trainable = args.pop('trainable') if 'trainable' in args else False
if context.executing_eagerly():
return op_lib \
.instantiate(
ndim=len(args['dims']),
num_intervals=len(args['start']),
dtype=dtype,
axis=axis,
).apply(args['dims'], args['start'], args['stop'], trainable=trainable)
else:
return op_lib.blend(**args)
@OpSchema.num_inputs(2)
def masked_select(inputs, **kwargs):
"""Select the elements of input where mask is 1.
...@@ -1303,7 +1410,7 @@ def sort(inputs, axis=-1, descending=False, **kwargs):
Returns
-------
Sequence[dragon.vm.torch.Tensor] Sequence[dragon.Tensor]
The value and index tensor.
"""
......
...@@ -258,6 +258,56 @@ class IndexSelect(Operator):
return self.dispatch(inputs, [self.alloc()])
class LinSpace(Operator):
def __init__(self, key, dev, **kwargs):
super(LinSpace, self).__init__(key, dev, **kwargs)
self.ndim = kwargs.get('ndim', 0)
self.num_intervals = kwargs.get('num_intervals', 1)
self.dtype = kwargs.get('dtype', 'int64')
self.axis = kwargs.get('axis', 0)
def attributes(self):
return {
'op_type': 'LinSpace',
'arguments': {
'axis': self.axis,
'dtype': self.dtype,
'dims_descs': [
'${{HANDLE}}/dims[{}]'.format(n)
for n in range(self.ndim)],
'start_descs': [
'${{HANDLE}}/start[{}]'
.format(n) for n in range(self.num_intervals)],
'stop_descs': [
'${{HANDLE}}/stop[{}]'
.format(n) for n in range(self.num_intervals)],
}
}
def feed(self, ws, handle, shape, starts, stops):
for i, dim in enumerate(shape):
self.feed_arg(
ws, '{}/dims[{}]'.format(handle, i),
dim, 'int64')
for i in range(len(starts)):
self.feed_arg(
ws, '{}/start[{}]'.format(handle, i),
starts[i], 'float64')
self.feed_arg(
ws, '{}/stop[{}]'.format(handle, i),
stops[i], 'float64')
def forward(self, shape, starts, stops, trainable=False):
out = self.dispatch(
[], [self.alloc()],
callback=lambda ws, handle:
self.feed(ws, handle, shape, starts, stops),
no_grad=True,
)
out._requires_grad = trainable
return out
class MaskedSelect(Operator):
def __init__(self, key, dev, **kwargs):
super(MaskedSelect, self).__init__(key, dev, **kwargs)
...@@ -423,7 +473,7 @@ class Range(Operator):
for i in range(len(slice_args)):
self.feed_arg(
ws, '{}/slice[{}]'.format(handle, i),
slice_args[i], 'float32') slice_args[i], 'float64')
def forward(self, slice_args, trainable=False):
out = self.dispatch(
......
...@@ -262,7 +262,7 @@ class CppExtension(object):
libraries = kwargs.get('libraries', [])
libraries.extend(COMMON_LINK_LIBRARIES + ['dragon'])
kwargs['libraries'] = libraries
define_macros = kwargs.get('define_marcos', []) define_macros = kwargs.get('define_macros', [])
define_macros.append(('DRAGON_API=' + DLLIMPORT_STR, None))
kwargs['define_macros'] = define_macros
kwargs['language'] = 'c++'
...@@ -282,7 +282,7 @@ class CUDAExtension(object):
libraries = kwargs.get('libraries', [])
libraries.extend(COMMON_LINK_LIBRARIES + ['cudart', 'dragon'])
kwargs['libraries'] = libraries
define_macros = kwargs.get('define_marcos', []) define_macros = kwargs.get('define_macros', [])
define_macros.append(('USE_CUDA', None))
define_macros.append(('DRAGON_API=' + DLLIMPORT_STR, None))
kwargs['define_macros'] = define_macros
......
...@@ -107,6 +107,11 @@ inline float to<float, float16>(float16 val) {
return ret;
}
template <>
inline float16 to<float16, double>(double val) {
return to<float16>(static_cast<float>(val));
}
#ifdef USE_CUDA
template <>
......
...@@ -325,6 +325,18 @@ void IndexSelectGrad(
T* dx,
Context* ctx);
/* array.linspace */
template <typename T, class Context>
void LinSpace(
const int rows,
const int cols,
const int axis,
const double* start,
const double* end,
T* y,
Context* ctx);
/* array.masked_select */
template <typename IndexType, typename ValueType, class Context>
...@@ -420,8 +432,8 @@ void Permutation(const int count, T* y, uint32_t* r, Context* ctx);
template <typename T, class Context>
void Range(
const int count,
const float start, const double start,
const float delta, const double delta,
T* y,
Context* ctx);
......
...@@ -100,6 +100,7 @@ from dragon.vm.tensorflow.core.ops.math_ops import divide
from dragon.vm.tensorflow.core.ops.math_ops import equal
from dragon.vm.tensorflow.core.ops.math_ops import exp
from dragon.vm.tensorflow.core.ops.math_ops import less
from dragon.vm.tensorflow.core.ops.math_ops import linspace
from dragon.vm.tensorflow.core.ops.math_ops import matmul
from dragon.vm.tensorflow.core.ops.math_ops import multiply
from dragon.vm.tensorflow.core.ops.math_ops import pow
......
...@@ -634,6 +634,47 @@ def less_equal(x, y, name=None):
return math_ops.less_equal([x, y], name=name)
def linspace(start, stop, num, dtype='int64', name=None, axis=0):
r"""Generate evenly spaced values within intervals along the given axis.
Interval :math:`[\text{start}, \text{stop}]` is determined for ``num`` values:
```python
x = tf.linspace(2, 4, num=3) # [2, 3, 4]
```
More than one interval is accepted to generate N-d coordinates:
```python
x = tf.linspace([1, 2], [3, 4], num=3, axis=0) # [[1, 2], [2, 3], [3, 4]]
y = tf.linspace([1, 2], [3, 4], num=3, axis=1) # [[1, 2, 3], [2, 3, 4]]
```
Parameters
----------
start : Union[number, Sequence[number]]
The start(s) of interval.
stop : Union[number, Sequence[number]]
The stop(s) of interval.
num : int
The number of values to generate.
dtype : str, optional, default='int64'
The optional data type.
name : str, optional
The operation name.
axis : int, optional, default=0
The axis to generate values.
Returns
-------
dragon.Tensor
The output tensor.
"""
return array_ops.linspace(
start, stop, num, dtype=dtype, name=name, axis=axis)
def log(x, name=None):
r"""Compute the logarithm of input.
......
...@@ -50,7 +50,7 @@ def argsort(values, axis=-1, direction='ASCENDING', name=None):
Returns
-------
dragon.Tensor
The index tensor. The output tensor.
""" """
if direction not in ('ASCENDING', 'DESCENDING'): if direction not in ('ASCENDING', 'DESCENDING'):
...@@ -97,7 +97,7 @@ def sort(values, axis=-1, direction='ASCENDING', name=None): ...@@ -97,7 +97,7 @@ def sort(values, axis=-1, direction='ASCENDING', name=None):
Returns Returns
------- -------
dragon.Tensor dragon.Tensor
The value tensor. The output tensor.
""" """
if direction not in ('ASCENDING', 'DESCENDING'): if direction not in ('ASCENDING', 'DESCENDING'):
......
...@@ -630,6 +630,26 @@ class TestArrayOps(OpTestCase):
with dragon.device('cuda'):
self.test_index_select()
def test_linspace(self):
entries = [([[0., 5.], [10., 40.], 5], {'axis': 0, 'dtype': 'float32'}),
([[0., 5.], [10., 40.], 5], {'axis': 1, 'dtype': 'float32'}),
([[0., 5.], [10., 40.], 5], {'axis': -1, 'dtype': 'float32'}),
([[0.], [10.], 5], {'axis': 0, 'dtype': 'float32'}),
([[0.], [10.], 5], {'axis': -1, 'dtype': 'float32'}),
([0., 10., 5], {'axis': 0, 'dtype': 'float32'}),
([0., 10., 5], {'axis': 0, 'dtype': 'int64'})]
for execution in ('EAGER_MODE', 'GRAPH_MODE'):
with execution_context().mode(execution):
for (args, kwargs) in entries:
data = np.linspace(*args, **kwargs)
x = dragon.linspace(*args, **kwargs)
self.assertEqual(x, data)
@unittest.skipIf(not TEST_CUDA, 'CUDA unavailable')
def test_linspace_cuda(self):
with dragon.device('cuda'):
self.test_linspace()
def test_masked_select(self):
for execution in ('EAGER_MODE', 'GRAPH_MODE'):
with execution_context().mode(execution):
...@@ -821,11 +841,16 @@ class TestArrayOps(OpTestCase):
for axis, descending in entries:
data = uniform((5, 10))
x = new_tensor(data)
y = dragon.sort(x, axis=axis, descending=descending) val, idx1 = dragon.sort(x, axis=axis, descending=descending)
idx2 = dragon.argsort(x, axis=axis, descending=descending)
axis = axis if axis is not None else -1
result = np.argsort(-data if descending else data, axis=axis) result_val = np.sort(-data if descending else data, axis=axis)
result = np.take(result, np.arange(data.shape[axis]), axis=axis) result_val = -result_val if descending else result_val
self.assertEqual(y[1], result) result_idx = np.argsort(-data if descending else data, axis=axis)
result_idx = np.take(result_idx, np.arange(data.shape[axis]), axis=axis)
self.assertEqual(val, result_val)
self.assertEqual(idx1, result_idx)
self.assertEqual(idx2, result_idx)
@unittest.skipIf(not TEST_CUDA, 'CUDA unavailable')
def test_sort_cuda(self):
......
...@@ -397,6 +397,13 @@ class TestTensorOps(OpTestCase):
self.assertEqual(x.permute(), np.transpose(data))
else:
self.assertEqual(x.permute(*perm), np.transpose(data, perm))
entries = [(0, 1), (0, 2), (1, 2)]
for dim0, dim1 in entries:
data = arange((2, 3, 4))
x = new_tensor(data)
perm = list(range(len(data.shape)))
perm[dim0], perm[dim1] = perm[dim1], perm[dim0]
self.assertEqual(x.transpose(dim0, dim1), np.transpose(data, perm))
def test_pow(self):
for a_shape, b_shape in self.binary_test_shapes:
...@@ -492,11 +499,16 @@ class TestTensorOps(OpTestCase):
for axis, descending in entries:
data = uniform((5, 10))
x = new_tensor(data)
y = x.sort(axis, descending)[1] val, idx1 = x.sort(axis, descending)
idx2 = x.argsort(axis, descending)
axis = axis if axis is not None else -1
result = np.argsort(-data if descending else data, axis=axis) result_val = np.sort(-data if descending else data, axis=axis)
result = np.take(result, np.arange(data.shape[axis]), axis=axis) result_val = -result_val if descending else result_val
self.assertEqual(y, result) result_idx = np.argsort(-data if descending else data, axis=axis)
result_idx = np.take(result_idx, np.arange(data.shape[axis]), axis=axis)
self.assertEqual(val, result_val)
self.assertEqual(idx1, result_idx)
self.assertEqual(idx2, result_idx)
def test_sqrt(self):
data = np.array([4., 9., 16], 'float32')
...@@ -599,6 +611,30 @@ class TestTensorOps(OpTestCase):
class TestTorchOps(OpTestCase):
"""Test the builtin torch ops."""
def test_arange(self):
entries = [([5], {'dtype': 'int64'}),
([0, 5], {'dtype': 'int64'}),
([0, 5, 2], {'dtype': 'int64'}),
([0., 1., 0.2], {'dtype': 'float32'})]
for (args, kwargs) in entries:
data = np.arange(*args, **kwargs)
x = torch.arange(*args, **kwargs)
self.assertEqual(x, data)
def test_linspace(self):
entries = [([[0., 5.], [10., 40.], 5], {'dim': 0, 'dtype': 'float32'}),
([[0., 5.], [10., 40.], 5], {'dim': 1, 'dtype': 'float32'}),
([[0., 5.], [10., 40.], 5], {'dim': -1, 'dtype': 'float32'}),
([[0.], [10.], 5], {'dim': 0, 'dtype': 'float32'}),
([[0.], [10.], 5], {'dim': -1, 'dtype': 'float32'}),
([0., 10., 5], {'dim': 0, 'dtype': 'float32'}),
([0., 10., 5], {'dim': 0, 'dtype': 'int64'})]
for (args, kwargs) in entries:
x = torch.linspace(*args, **kwargs)
kwargs['axis'] = kwargs.pop('dim')
data = np.linspace(*args, **kwargs)
self.assertEqual(x, data)
def test_ones_like(self):
data = np.ones((2, 3), dtype='float32')
x = new_tensor(data)
......
...@@ -48,6 +48,7 @@ from dragon.vm.torch.core.cpp import from_numpy
from dragon.vm.torch.core.ops import tensorbind as _
from dragon.vm.torch.core.ops.array.functional import argmax
from dragon.vm.torch.core.ops.array.functional import argmin
from dragon.vm.torch.core.ops.array.functional import argsort
from dragon.vm.torch.core.ops.array.functional import assign
from dragon.vm.torch.core.ops.array.functional import cat
from dragon.vm.torch.core.ops.array.functional import channel_affine
...@@ -75,11 +76,13 @@ from dragon.vm.torch.core.ops.array.functional import squeeze
from dragon.vm.torch.core.ops.array.functional import stack
from dragon.vm.torch.core.ops.array.functional import sum
from dragon.vm.torch.core.ops.array.functional import topk
from dragon.vm.torch.core.ops.array.functional import transpose
from dragon.vm.torch.core.ops.array.functional import unique
from dragon.vm.torch.core.ops.array.functional import unsqueeze
from dragon.vm.torch.core.ops.array.functional import where
from dragon.vm.torch.core.ops.init.functional import arange
from dragon.vm.torch.core.ops.init.functional import eye
from dragon.vm.torch.core.ops.init.functional import linspace
from dragon.vm.torch.core.ops.init.functional import ones
from dragon.vm.torch.core.ops.init.functional import ones_like
from dragon.vm.torch.core.ops.init.functional import rand
......
...@@ -92,6 +92,43 @@ def argmin(input, dim=None, keepdim=False, out=None):
return _arg_reduce(input, 'ArgMin', dim, keepdim, out)
def argsort(input, dim=-1, descending=False):
"""Return the index of sorted elements along the given dimension.
By default, the last dimension is chosen:
```python
x = torch.tensor([[1, 2, 3], [3, 2, 1]])
index1 = torch.argsort(x)
index2 = torch.argsort(x, dim=1) # Equivalent
```
Sort in the descending order if ``descending`` is ``True``:
```python
x = torch.tensor([1, 2, 3])
index1 = torch.argsort(-x)
index2 = torch.argsort(x, descending=True) # Equivalent
```
Parameters
----------
input : dragon.vm.torch.Tensor
The input tensor.
dim : int, optional, default=-1
The dimension to sort elements.
descending : bool, optional, default=False
Sort in the descending order or not.
Returns
-------
dragon.vm.torch.Tensor
The output tensor.
"""
return sort(input, dim, descending)[1]
def assign(out, starts, sizes, input):
if not isinstance(input, Tensor):
input = utils.scalar_to_tensor(
...@@ -1054,6 +1091,31 @@ def topk(input, k, dim=-1, largest=True, sorted=True, out=None):
).apply(input, out if out else (None, None))
def transpose(input, dim0, dim1):
"""Return a new tensor with two dimensions swapped.
Parameters
----------
input : dragon.vm.torch.Tensor
The input tensor.
dim0 : int
The first dimension to be transposed.
dim1 : int
The second dimension to be transposed.
Returns
-------
dragon.vm.torch.Tensor
The output tensor.
"""
dims = list(range(input.ndimension()))
dims[dim0], dims[dim1] = dims[dim1], dims[dim0]
return _functions.Transpose \
.instantiate(input.device, ndim=len(dims)) \
.apply(input, dims)
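A brief usage sketch of the new transpose: it builds a full permutation with the two given dimensions swapped and dispatches the existing Transpose function. Assumes the usual `from dragon.vm import torch` import, NumPy interop via `from_numpy`, and a `shape` attribute on the result:

```python
import numpy as np
from dragon.vm import torch

x = torch.from_numpy(np.zeros((2, 3, 4), 'float32'))
y = torch.transpose(x, 0, 2)   # permutation (2, 1, 0): dims 0 and 2 swapped
z = x.transpose(0, 2)          # the equivalent tensor method added in this commit
print(y.shape, z.shape)        # (4, 3, 2) (4, 3, 2)
```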
def unique(input, return_inverse=False, return_counts=False, **kwargs):
"""Return the unique elements of input.
......
...@@ -74,6 +74,54 @@ class Fill(_Initializer):
}
class LinSpace(function.Function):
def __init__(self, key, dev, **kwargs):
super(LinSpace, self).__init__(key, dev, **kwargs)
self.ndim = kwargs.get('ndim', 0)
self.num_intervals = kwargs.get('num_intervals', 1)
self.dtype = kwargs.get('dtype', 'int64')
self.axis = kwargs.get('axis', 0)
def attributes(self):
return {
'op_type': 'LinSpace',
'arguments': {
'axis': self.axis,
'dtype': self.dtype,
'dims_descs': [
'${{HANDLE}}/dims[{}]'.format(n)
for n in range(self.ndim)],
'start_descs': [
'${{HANDLE}}/start[{}]'
.format(n) for n in range(self.num_intervals)],
'stop_descs': [
'${{HANDLE}}/stop[{}]'
.format(n) for n in range(self.num_intervals)],
}
}
def feed(self, ws, handle, shape, starts, stops):
for i, dim in enumerate(shape):
self.feed_arg(
ws, '{}/dims[{}]'.format(handle, i),
dim, 'int64')
for i in range(len(starts)):
self.feed_arg(
ws, '{}/start[{}]'.format(handle, i),
starts[i], 'float64')
self.feed_arg(
ws, '{}/stop[{}]'.format(handle, i),
stops[i], 'float64')
def forward(self, shape, starts, stops, out=None):
return self.dispatch(
[], [self.alloc(out)],
callback=lambda ws, handle:
self.feed(ws, handle, shape, starts, stops),
no_grad=True,
)
class Permutation(function.Function):
def __init__(self, key, dev, **kwargs):
super(Permutation, self).__init__(key, dev, **kwargs)
...@@ -161,7 +209,7 @@ class Range(function.Function):
for i in range(len(slice_args)):
self.feed_arg(
ws, '{}/slice[{}]'.format(handle, i),
slice_args[i], 'float32') slice_args[i], 'float64')
def forward(self, slice_args, out=None):
return self.dispatch(
......
...@@ -14,6 +14,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from dragon.core.util import nest
from dragon.vm.torch.core import cpp
from dragon.vm.torch.core.ops import utils
from dragon.vm.torch.core.ops.init import _functions
...@@ -148,6 +150,75 @@ def fill_like(out, shape_like, value):
.apply(out, [], shape_like)
def linspace(
start,
end,
steps=100,
out=None,
dtype='int64',
dim=0,
device=None,
requires_grad=False,
):
r"""Generate evenly spaced values within intervals along the given axis.
Interval :math:`[\text{start}, \text{end}]` is determined for ``steps`` values:
```python
x = torch.linspace(2, 4, steps=3) # [2, 3, 4]
```
More than one interval is accepted to generate N-d coordinates:
```python
x = torch.linspace([1, 2], [3, 4], steps=3, dim=0) # [[1, 2], [2, 3], [3, 4]]
y = torch.linspace([1, 2], [3, 4], steps=3, dim=1) # [[1, 2, 3], [2, 3, 4]]
```
Parameters
----------
start : Union[number, Sequence[number]]
The start(s) of interval.
end : Union[number, Sequence[number]]
The end(s) of interval.
steps : int, optional, default=100
The number of values to generate.
out : dragon.vm.torch.Tensor, optional
The optional output tensor.
dtype : str, optional, default='int64'
The optional data type.
dim : int, optional, default=0
The dimension to generate values.
device : dragon.vm.torch.device, optional
The optional device of returned tensor.
requires_grad : bool, optional, default=False
**True** to record gradient for returned tensor.
Returns
-------
dragon.vm.torch.Tensor
The output tensor.
"""
starts = nest.flatten(start)
ends = nest.flatten(end)
sizes = []
if len(starts) > 1 or starts == start:
sizes = [len(starts)]
dim = dim if dim >= 0 else dim + len(sizes) + 1
sizes.insert(dim, steps)
out = _functions.LinSpace \
.instantiate(
device if device else cpp.device(),
ndim=len(sizes),
num_intervals=len(starts),
dtype=dtype.lower(),
axis=dim,
).apply(sizes, starts, ends, out)
out.requires_grad = requires_grad
return out
def normal_fill(input, mean=0, std=1):
"""Fill input from the normal distribution."""
shape = input.shape
......
...@@ -131,6 +131,29 @@ def argmin(self, dim=None, keepdim=False):
return array_funcs.argmin(self, dim, keepdim)
def argsort(self, dim=-1, descending=False):
"""Return the index of sorted elements.
Parameters
----------
dim : int, optional, default=-1
The dimension to sort elements.
descending : bool, optional, default=False
Sort in the descending order or not.
Returns
-------
dragon.vm.torch.Tensor
The output tensor.
See Also
--------
`torch.argsort(...)`_
"""
return array_funcs.argsort(self, dim, descending)
def backward(self, gradient=None, retain_graph=False):
"""Compute the derivatives of this tensor w.r.t. graph leaves.
...@@ -1710,6 +1733,29 @@ def topk(self, k, dim=-1, largest=True, sorted=True):
return array_funcs.topk(self, k, dim, largest, sorted)
def transpose(self, dim0, dim1):
"""Return a new tensor with two dimensions swapped.
Parameters
----------
dim0 : int
The first dimension to be transposed.
dim1 : int
The second dimension to be transposed.
Returns
-------
dragon.vm.torch.Tensor
The output tensor.
See Also
--------
`torch.transpose(...)`_
"""
return array_funcs.transpose(self, dim0, dim1)
def _type(self, dtype=None):
"""Return the data type.
...@@ -1890,6 +1936,7 @@ Tensor.add = add
Tensor.add_ = add_
Tensor.argmax = argmax
Tensor.argmin = argmin
Tensor.argsort = argsort
Tensor.backward = backward
Tensor.bitwise_not = bitwise_not
Tensor.bitwise_not_ = bitwise_not_
...@@ -1971,6 +2018,7 @@ Tensor.sum = sum
Tensor.sub = sub
Tensor.sub_ = sub_
Tensor.topk = topk
Tensor.transpose = transpose
Tensor.type = _type
Tensor.uniform_ = uniform_
Tensor.unique = unique
......
...@@ -309,6 +309,27 @@ class Tensor(object):
"""
def argsort(self, dim=-1, descending=False):
"""Return the index of sorted elements.
Parameters
----------
dim : int, optional, default=-1
The dimension to sort elements.
descending : bool, optional, default=False
Sort in the descending order or not.
Returns
-------
dragon.vm.torch.Tensor
The output tensor.
See Also
--------
`torch.argsort(...)`_
"""
def backward(self, gradient=None, retain_graph=False):
"""Compute the derivatives of this tensor w.r.t. graph leaves.
...@@ -1895,6 +1916,27 @@ class Tensor(object):
"""
def transpose(self, dim0, dim1):
"""Return a new tensor with two dimensions swapped.
Parameters
----------
dim0 : int
The first dimension to be transposed.
dim1 : int
The second dimension to be transposed.
Returns
-------
dragon.vm.torch.Tensor
The output tensor.
See Also
--------
`torch.transpose(...)`_
"""
def type(self, dtype=None):
"""Return the data type or copied tensor with specified type.
......