Commit 1dd8aeef by Ting PAN

Add Unique Operator

Summary:
This commit adds the Unique operator to the dragon, torch and tensorflow front ends, together with its ONNX export.
It also fixes a bug in the cached cuDNN convolution that queried the wrong workspace size.
1 parent 80267d8f
Showing with 1033 additions and 202 deletions
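For orientation, a minimal usage sketch of the new operator, assembled from the docstring examples added in the diff below (eager execution assumed):

```python
import dragon

x = dragon.constant([1, 2, 3, 2])
# Values only.
y = dragon.unique(x)                               # [1, 2, 3]
# Values plus the index mapping each element to its unique value.
y, index = dragon.unique(x, return_inverse=True)   # index: [0, 1, 2, 1]
# Values plus the occurrence count of each unique value.
y, counts = dragon.unique(x, return_counts=True)   # counts: [1, 2, 1]
# All three outputs at once.
y, index, counts = dragon.unique(
    x, return_inverse=True, return_counts=True)
```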
@@ -302,7 +302,7 @@ class RandomBBoxCrop(object):
         thresholds=(0.0, 0.1, 0.3, 0.5, 0.7, 0.9),
         allow_no_crop=True,
         num_attempts=10,
-        bbox_layout=None,
+        bbox_layout='xyXY',
         **kwargs
     ):
         """Create a ``RandomBBoxCrop`` operator.
@@ -316,10 +316,10 @@ class RandomBBoxCrop(object):
         thresholds : Sequence[float], optional
             The minimum IoU(s) to satisfy.
         allow_no_crop : bool, optional, default=True
-            **True** to include the no-cropping as a option.
+            **True** to include no-cropping as an option.
         num_attempts : int, optional, default=10
             The max number of sampling trials.
-        bbox_layout : str, optional
+        bbox_layout : str, optional, default='xyXY'
             The optional bbox layout.

         Returns
@@ -437,7 +437,7 @@ class Resize(object):
         resize_shorter=None,
         resize_longer=None,
         max_size=None,
-        interp_type='LINEAR',
+        interp_type=None,
         mag_filter=None,
         min_filter=None,
         **kwargs
......
@@ -153,6 +153,9 @@ dragon
 `transpose(...) <dragon/transpose.html>`_
 : Permute the dimensions of input.

+`unique(...) <dragon/unique.html>`_
+: Return the unique elements of input.
+
 `where(...) <dragon/where.html>`_
 : Select the elements from two branches under the condition.
@@ -212,6 +215,7 @@ dragon
     dragon/Tensor
     dragon/tile
     dragon/transpose
+    dragon/unique
     dragon/where
     dragon/Workspace
     dragon/zeros
......
unique
======

.. autofunction:: dragon.unique

.. raw:: html

    <style>
    h1:before {
      content: "dragon.";
      color: #103d3e;
    }
    </style>
@@ -161,7 +161,7 @@ Name Supported Reference
 `Tile`_ |v| :func:`dragon.tile`
 `TopK`_ |v| :func:`dragon.math.top_k`
 `Transpose`_ |v| :func:`dragon.transpose`
-`Unique`_
+`Unique`_ |v| :func:`dragon.unique`
 `Unsqueeze`_ |v| :func:`dragon.unsqueeze`
 `Upsample`_ |v| :func:`dragon.vision.resize`
 `Where`_ |v| :func:`dragon.where`
......
@@ -93,6 +93,12 @@ vm.tensorflow
 `transpose(...) <tensorflow/transpose.html>`_
 : Permute the dimensions of input.

+`unique(...) <tensorflow/unique.html>`_
+: Return the unique elements of input.
+
+`unique_with_counts(...) <tensorflow/unique_with_counts.html>`_
+: Return the unique elements of input with counts.
+
 `zeros(...) <tensorflow/zeros.html>`_
 : Return a tensor filled with zeros.
@@ -130,6 +136,8 @@ vm.tensorflow
     tensorflow/TensorShape
     tensorflow/TensorSpec
     tensorflow/transpose
+    tensorflow/unique
+    tensorflow/unique_with_counts
     tensorflow/zeros
     tensorflow/zeros_like
......
unique
======

.. autofunction:: dragon.vm.tensorflow.unique

.. raw:: html

    <style>
    h1:before {
      content: "tf.";
      color: #103d3e;
    }
    </style>

unique_with_counts
==================

.. autofunction:: dragon.vm.tensorflow.unique_with_counts

.. raw:: html

    <style>
    h1:before {
      content: "tf.";
      color: #103d3e;
    }
    </style>
@@ -235,6 +235,9 @@ vm.torch
 `topk(...) <torch/topk.html>`_
 : Return the top-K largest or smallest elements along the given dimension.

+`unique(...) <torch/unique.html>`_
+: Return the unique elements of input.
+
 `unsqueeze(...) <torch/unsqueeze.html>`_
 : Expand the dimensions of input with size 1.
@@ -325,6 +328,7 @@ vm.torch
     torch/Tensor_
     torch/tensor
     torch/topk
+    torch/unique
     torch/unsqueeze
     torch/where
     torch/zeros_like
......
@@ -425,6 +425,10 @@ sub\_
 #####
 .. automethod:: dragon.vm.torch.Tensor.sub_

+to
+##
+.. automethod:: dragon.vm.torch.Tensor.to
+
 topk
 ####
 .. automethod:: dragon.vm.torch.Tensor.topk
@@ -437,6 +441,10 @@ uniform\_
 #########
 .. automethod:: dragon.vm.torch.Tensor.uniform_

+unique
+######
+.. automethod:: dragon.vm.torch.Tensor.unique
+
 unsqueeze
 #########
 .. automethod:: dragon.vm.torch.Tensor.unsqueeze
@@ -500,6 +508,7 @@ zero\_
 .. _torch.sub(...): sub.html
 .. _torch.sum(...): sum.html
 .. _torch.topk(...): topk.html
+.. _torch.unique(...): unique.html
 .. _torch.unsqueeze(...): unsqueeze.html
 .. _torch.where(...): where.html
......
@@ -18,6 +18,10 @@ apply
 #####
 .. automethod:: dragon.vm.torch.nn.Module.apply

+buffers
+#######
+.. automethod:: dragon.vm.torch.nn.Module.buffers
+
 children
 ########
 .. automethod:: dragon.vm.torch.nn.Module.children
@@ -58,6 +62,10 @@ modules
 #######
 .. automethod:: dragon.vm.torch.nn.Module.modules

+named_buffers
+#############
+.. automethod:: dragon.vm.torch.nn.Module.named_buffers
+
 named_children
 ##############
 .. automethod:: dragon.vm.torch.nn.Module.named_children
......
unique
======

.. autofunction:: dragon.vm.torch.unique

.. raw:: html

    <style>
    h1:before {
      content: "torch.";
      color: #103d3e;
    }
    </style>
@@ -16,6 +16,7 @@ void GraphOptimizer::BuildDAG(const GraphDef& graph) {
       reference_count_[in] += 1;
     }
     for (const auto& out : op.output()) {
+      if (out.empty()) continue;
       if (op.input().empty()) {
         nodes_[""].childs.push_back(out);
         nodes_[out].parents.push_back("");
......
 #ifdef USE_CUDA

+#include "dragon/core/context_cuda.h"
 #include "dragon/utils/op_kernels.h"

 namespace dragon {
......
#include "dragon/utils/math_functions.h"
#include "dragon/utils/op_kernels.h"
namespace dragon {
namespace kernel {
namespace {
template <typename T>
void _Unique(
const int dim,
const T* x,
T* y,
int64_t* inverse_index,
int64_t* counts,
int* num) {
vec32_t order(dim);
std::iota(order.begin(), order.end(), 0);
std::sort(order.begin(), order.end(), [x](const int i, const int j) {
return x[i] < x[j];
});
int n = dim, m;
for (int i = 1; i < dim; ++i) {
n -= x[order[i]] == x[order[i - 1]];
}
n = 0;
T prev = -1;
for (int i = 0; i < dim; ++i) {
if (i == 0 || prev != x[order[i]]) {
if (counts && i > 0) counts[n - 1] = m;
prev = y[n++] = x[order[i]];
m = 1;
} else {
m += 1;
}
if (inverse_index) {
inverse_index[order[i]] = n - 1;
}
}
num[0] = n;
if (counts) counts[n - 1] = m;
}
} // namespace
template <>
void Unique<float16, CPUContext>(
const int dim,
const float16* x,
float16* y,
int64_t* inverse_index,
int64_t* counts,
int* num,
CPUContext* ctx) {
CPU_FP16_NOT_SUPPORTED;
}
#define DEFINE_KERNEL_LAUNCHER(T) \
template <> \
void Unique<T, CPUContext>( \
const int dim, \
const T* x, \
T* y, \
int64_t* inverse_index, \
int64_t* counts, \
int* num, \
CPUContext* ctx) { \
_Unique(dim, x, y, inverse_index, counts, num); \
}
DEFINE_KERNEL_LAUNCHER(int8_t);
DEFINE_KERNEL_LAUNCHER(uint8_t);
DEFINE_KERNEL_LAUNCHER(int);
DEFINE_KERNEL_LAUNCHER(int64_t);
DEFINE_KERNEL_LAUNCHER(float);
DEFINE_KERNEL_LAUNCHER(double);
} // namespace kernel
} // namespace dragon
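The CPU kernel above is a sort-based unique: it argsorts an index vector, then makes one pass over the sorted view, opening a new output slot whenever the value changes. A NumPy sketch of the same algorithm, for illustration only (not part of the commit):

```python
import numpy as np

def unique_reference(x):
    """Mirror of the sort-then-scan CPU kernel (assumes a non-empty 1-d array)."""
    order = np.argsort(x, kind='stable')  # like the kernel's <order> vector
    xs = x[order]
    # Positions in the sorted view where a new run of equal values starts.
    starts = np.flatnonzero(np.r_[True, xs[1:] != xs[:-1]])
    y = xs[starts]
    counts = np.diff(np.r_[starts, len(xs)])
    # Scatter the run ids back through the sort order -> inverse index.
    inverse = np.empty(len(x), dtype=np.int64)
    inverse[order] = np.repeat(np.arange(len(starts)), counts)
    return y, inverse, counts

print(unique_reference(np.array([1, 2, 3, 2])))
# (array([1, 2, 3]), array([0, 1, 2, 1]), array([1, 2, 1]))
```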
#ifdef USE_CUDA

#include "dragon/core/context_cuda.h"
#include "dragon/utils/device/common_thrust.h"
#include "dragon/utils/math_functions.h"
#include "dragon/utils/op_kernels.h"

namespace dragon {

namespace kernel {

namespace {

// Scatter the run id <yi> back to the original positions of its elements.
// <order1> maps sorted position -> original index; <order2> holds the
// sorted position where each unique run starts.
__global__ void _RemapInverse(
    const int dim,
    const int num,
    thrust::device_ptr<int> order1,
    thrust::device_ptr<int> order2,
    int64_t* inverse_index) {
  const int yi = blockDim.x * blockIdx.x + threadIdx.x;
  if (yi >= num) return;
  int xi = order2[yi];
  inverse_index[order1[xi]] = yi;
  for (xi++; xi < dim && (yi == num - 1 || xi != order2[yi + 1]); xi++) {
    inverse_index[order1[xi]] = yi;
  }
}

// The count of run <yi> is the distance to the next run start.
__global__ void _ComputeCounts(
    const int dim,
    const int num,
    thrust::device_ptr<int> order2,
    int64_t* counts) {
  const int yi = blockDim.x * blockIdx.x + threadIdx.x;
  if (yi >= num) return;
  counts[yi] = (yi == num - 1 ? dim : order2[yi + 1]) - order2[yi];
}

} // namespace

/* ------------------- Launcher Separator ------------------- */

template <>
void Unique<float16, CUDAContext>(
    const int dim,
    const float16* x,
    float16* y,
    int64_t* inverse_index,
    int64_t* counts,
    int* num,
    CUDAContext* ctx) {
  LOG(FATAL) << "FP16 is unsupported for CUDAContext.";
}

#define DEFINE_KERNEL_LAUNCHER(T)                                             \
  template <>                                                                 \
  void Unique<T, CUDAContext>(                                                \
      const int dim,                                                          \
      const T* x,                                                             \
      T* y,                                                                   \
      int64_t* inverse_index,                                                 \
      int64_t* counts,                                                        \
      int* num,                                                               \
      CUDAContext* ctx) {                                                     \
    math::Copy(dim, x, y, ctx);                                               \
    auto policy = thrust::cuda::par.on(ctx->cuda_stream());                   \
    thrust::device_vector<int> order1(dim), order2(dim);                      \
    thrust::sequence(policy, order1.begin(), order1.end());                   \
    thrust::sequence(policy, order2.begin(), order2.end());                   \
    thrust::sort_by_key(policy, y, y + dim, order1.begin());                  \
    auto last = thrust::unique_by_key(policy, y, y + dim, order2.begin());    \
    int n = num[0] = last.first - y;                                          \
    if (inverse_index) {                                                      \
      _RemapInverse<<<CUDA_BLOCKS(n), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
          dim, n, order1.data(), order2.data(), inverse_index);               \
    }                                                                         \
    if (counts) {                                                             \
      _ComputeCounts<<<CUDA_BLOCKS(n), CUDA_THREADS, 0, ctx->cuda_stream()>>>(\
          dim, n, order2.data(), counts);                                     \
    }                                                                         \
  }

DEFINE_KERNEL_LAUNCHER(int8_t);
DEFINE_KERNEL_LAUNCHER(uint8_t);
DEFINE_KERNEL_LAUNCHER(int);
DEFINE_KERNEL_LAUNCHER(int64_t);
DEFINE_KERNEL_LAUNCHER(float);
DEFINE_KERNEL_LAUNCHER(double);
#undef DEFINE_KERNEL_LAUNCHER

} // namespace kernel

} // namespace dragon

#endif // USE_CUDA
 #ifdef USE_CUDA

 #include "dragon/core/context_cuda.h"
+#include "dragon/utils/device/common_cub.h"
 #include "dragon/utils/op_kernels.h"

 namespace dragon {
......
@@ -2,6 +2,7 @@
 #include "dragon/core/context_cuda.h"
 #include "dragon/utils/device/common_cub.h"
+#include "dragon/utils/math_functions.h"
 #include "dragon/utils/op_kernels.h"

 namespace dragon {
......
#include "dragon/operators/array/unique_op.h"
#include "dragon/core/workspace.h"
#include "dragon/utils/op_kernels.h"
namespace dragon {
template <class Context>
template <typename T>
void UniqueOp<Context>::DoRunWithType() {
auto &X = Input(0), *Y = Output(0);
int num_values;
Tensor *X_index = nullptr, *Y_counts = nullptr;
int64_t *inverse_index = nullptr, *counts = nullptr;
if (OutputSize() == 2) {
if (return_inverse_) {
X_index = Output(1)->ReshapeLike(X);
inverse_index = X_index->template mutable_data<int64_t, Context>();
} else if (return_counts_) {
Y_counts = Output(1)->ReshapeLike(X);
counts = Y_counts->template mutable_data<int64_t, Context>();
}
} else if (OutputSize() == 3) {
X_index = Output(1)->ReshapeLike(X);
Y_counts = Output(2)->ReshapeLike(X);
inverse_index = X_index->template mutable_data<int64_t, Context>();
counts = Y_counts->template mutable_data<int64_t, Context>();
}
kernel::Unique(
X.count(),
X.template data<T, Context>(),
Y->ReshapeLike(X)->template mutable_data<T, Context>(),
inverse_index,
counts,
&num_values,
ctx());
// Shrink to match the number of values
Y->Reshape({num_values});
if (Y_counts) Y_counts->Reshape({num_values});
}
template <class Context>
void UniqueOp<Context>::RunOnDevice() {
DispatchHelper<NumericalTensorTypes>::Call(this, Input(0));
}
DEPLOY_CPU_OPERATOR(Unique);
#ifdef USE_CUDA
DEPLOY_CUDA_OPERATOR(Unique);
#endif
OPERATOR_SCHEMA(Unique)
/* X */
.NumInputs(1)
/* Y, InverseIndex, Counts */
.NumOutputs(1, 3);
NO_GRADIENT(Unique);
} // namespace dragon
/*!
 * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
 *
 * Licensed under the BSD 2-Clause License.
 * You should have received a copy of the BSD 2-Clause License
 * along with the software. If not, See,
 *
 *     <https://opensource.org/licenses/BSD-2-Clause>
 *
 * ------------------------------------------------------------
 */

#ifndef DRAGON_OPERATORS_ARRAY_UNIQUE_OP_H_
#define DRAGON_OPERATORS_ARRAY_UNIQUE_OP_H_

#include "dragon/core/operator.h"

namespace dragon {

template <class Context>
class UniqueOp final : public Operator<Context> {
 public:
  UniqueOp(const OperatorDef& def, Workspace* ws)
      : Operator<Context>(def, ws),
        return_inverse_(OP_SINGLE_ARG(int64_t, "return_inverse", 0)),
        return_counts_(OP_SINGLE_ARG(int64_t, "return_counts", 0)) {}
  USE_OPERATOR_FUNCTIONS;

  void RunOnDevice() override;

  template <typename T>
  void DoRunWithType();

 protected:
  int64_t return_inverse_, return_counts_;
};

} // namespace dragon

#endif // DRAGON_OPERATORS_ARRAY_UNIQUE_OP_H_
@@ -479,7 +479,6 @@ void CuDNNConv2dGradientOp<Context>::DoRunWithType() {
   // Determine the workspace size for selected algorithm
   if (cudnn_ws_nbytes_ == SIZE_MAX) {
     size_t bwd_filter_size = 0, bwd_data_size = 0;
-    if (dW->has_name()) {
     CUDNN_CHECK(cudnnGetConvolutionBackwardFilterWorkspaceSize(
         ctx()->cudnn_handle(),
         output_desc_,
@@ -488,8 +487,6 @@ void CuDNNConv2dGradientOp<Context>::DoRunWithType() {
         filter_desc_,
         bwd_filter_algo_,
         &bwd_filter_size));
-    }
-    if (dX->has_name()) {
     CUDNN_CHECK(cudnnGetConvolutionBackwardDataWorkspaceSize(
         ctx()->cudnn_handle(),
         filter_desc_,
@@ -498,7 +495,6 @@ void CuDNNConv2dGradientOp<Context>::DoRunWithType() {
         output_desc_,
         bwd_data_algo_,
         &bwd_data_size));
-    }
     cudnn_ws_nbytes_ = std::max(bwd_filter_size, bwd_data_size);
   }
......
@@ -474,7 +474,6 @@ void CuDNNConvTranspose2dGradientOp<Context>::DoRunWithType() {
   // Determine the workspace size for selected algorithm
   if (cudnn_ws_nbytes_ == SIZE_MAX) {
     size_t bwd_filter_size = 0, bwd_data_size = 0;
-    if (dW->has_name()) {
     CUDNN_CHECK(cudnnGetConvolutionBackwardFilterWorkspaceSize(
         ctx()->cudnn_handle(),
         input_desc_,
@@ -483,8 +482,6 @@ void CuDNNConvTranspose2dGradientOp<Context>::DoRunWithType() {
         filter_desc_,
         bwd_filter_algo_,
         &bwd_filter_size));
-    }
-    if (dX->has_name()) {
     CUDNN_CHECK(cudnnGetConvolutionForwardWorkspaceSize(
         ctx()->cudnn_handle(),
         input_desc_,
@@ -493,7 +490,6 @@ void CuDNNConvTranspose2dGradientOp<Context>::DoRunWithType() {
         output_desc_,
         bwd_data_algo_,
         &bwd_data_size));
-    }
     cudnn_ws_nbytes_ = std::max(bwd_filter_size, bwd_data_size);
   }
@@ -514,7 +510,7 @@ void CuDNNConvTranspose2dGradientOp<Context>::DoRunWithType() {
         db));
   }
-  if (Output(1)->has_name()) {
+  if (dW->has_name()) {
     x = X.template data<T, Context>();
     dw = dW->template mutable_data<T, Context>();
     for (int g = 0; g < cudnn_group_; g++) {
@@ -535,7 +531,7 @@ void CuDNNConvTranspose2dGradientOp<Context>::DoRunWithType() {
     }
   }
-  if (Output(0)->has_name()) {
+  if (dX->has_name()) {
     auto* w = W.template data<T, Context>();
     auto* dx = dX->template mutable_data<T, Context>();
     for (int g = 0; g < cudnn_group_; g++) {
......
@@ -76,6 +76,7 @@ from dragon.core.ops.array_ops import squeeze
 from dragon.core.ops.array_ops import stack
 from dragon.core.ops.array_ops import tile
 from dragon.core.ops.array_ops import transpose
+from dragon.core.ops.array_ops import unique
 from dragon.core.ops.array_ops import where
 from dragon.core.ops.control_flow_ops import assign
 from dragon.core.ops.control_flow_ops import copy
......
@@ -1045,3 +1045,28 @@ def unchanged_spec(args, inputs, outputs):
     except TypeError:
         pass
     return outputs
+
+
+@register('Unique')
+def unique_spec(args, inputs, outputs):
+    return_inverse = args['return_inverse']
+    return_counts = args['return_counts']
+    outputs[0].dtype = inputs[0].dtype
+    for i in range(1, len(outputs)):
+        outputs[i].dtype = 'int64'
+    outputs[0].shape = (None,)
+    if len(outputs) == 2:
+        if return_inverse:
+            try:
+                outputs[1].shape = inputs[0].shape[:]
+            except TypeError:
+                pass
+        elif return_counts:
+            outputs[1].shape = (None,)
+    elif len(outputs) == 3:
+        try:
+            outputs[1].shape = inputs[0].shape[:]
+        except TypeError:
+            pass
+        outputs[2].shape = (None,)
+    return outputs
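Since the number of unique values is only known at run time, the spec above leaves value-count dimensions undetermined; only the inverse index inherits the input shape. A hedged illustration (shapes hypothetical):

```python
# Input: shape (4,), dtype 'int64'; Unique with both extra outputs (3 outputs):
#   outputs[0] -> dtype 'int64', shape (None,)  # unique values, count unknown
#   outputs[1] -> dtype 'int64', shape (4,)     # inverse index, same shape as input
#   outputs[2] -> dtype 'int64', shape (None,)  # per-value counts
```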
@@ -962,8 +962,7 @@ def pad(inputs, pads, mode='constant', value=0, **kwargs):
         if len(pad) != 2:
             raise ValueError(
                 'The tuple length of <pads> '
-                'should be 2, got {}.'.format(len(pad))
-            )
+                'should be 2, got {}.'.format(len(pad)))
         pads_begin.append(pad[0])
         pads_end.append(pad[1])
     args['pads'] = pads_begin + pads_end
@@ -1562,6 +1561,59 @@ def top_k(inputs, k=1, axis=None, largest=True, sorted=True, **kwargs):
     return op_lib.blend(**args)


+@OpSchema.num_inputs(1)
+def unique(inputs, return_inverse=False, return_counts=False, **kwargs):
+    """Return the unique elements of input.
+
+    If ``return_inverse``, also return an index tensor that maps
+    each input element to its unique value:
+
+    ```python
+    x = dragon.constant([1, 2, 3, 2])
+    y, index = dragon.unique(x, return_inverse=True)
+    print(y)  # [1, 2, 3]
+    print(index)  # [0, 1, 2, 1]
+    ```
+
+    If ``return_counts``, also return the count of each unique value:
+
+    ```python
+    x = dragon.constant([1, 2, 3, 2])
+    y, counts = dragon.unique(x, return_counts=True)
+    print(y)  # [1, 2, 3]
+    print(counts)  # [1, 2, 1]
+    ```
+
+    Parameters
+    ----------
+    inputs : dragon.Tensor
+        The input tensor.
+    return_inverse : bool, optional, default=False
+        Return the inverse index or not.
+    return_counts : bool, optional, default=False
+        Return the counts or not.
+
+    Returns
+    -------
+    dragon.Tensor
+        The output tensor.
+    dragon.Tensor, optional
+        The inverse index tensor.
+    dragon.Tensor, optional
+        The counts tensor.
+
+    """
+    args = parse_args(locals())
+    op_lib = array_ops_lib.Unique
+    if context.executing_eagerly():
+        return op_lib.instantiate(
+            return_inverse=return_inverse,
+            return_counts=return_counts,
+        ).apply([inputs])
+    else:
+        num_outputs = 1 + return_inverse + return_counts
+        return op_lib.blend(num_outputs=num_outputs, **args)
+
+
 @OpSchema.num_inputs(1, 3)
 def where(inputs, **kwargs):
     r"""Select the elements from two branches under the condition.
......
@@ -665,6 +665,27 @@ class TopK(Operator):
         return self.dispatch(inputs, [self.alloc(), self.alloc()], no_grad=True)


+class Unique(Operator):
+    def __init__(self, key, dev, **kwargs):
+        super(Unique, self).__init__(key, dev, **kwargs)
+        self.return_inverse = kwargs.get('return_inverse', False)
+        self.return_counts = kwargs.get('return_counts', False)
+        self.num_outputs = 1 + self.return_inverse + self.return_counts
+
+    def attributes(self):
+        return {
+            'op_type': 'Unique',
+            'arguments': {
+                'return_inverse': self.return_inverse,
+                'return_counts': self.return_counts,
+            }
+        }
+
+    def forward(self, inputs):
+        outputs = [self.alloc() for _ in range(self.num_outputs)]
+        return self.dispatch(inputs, outputs)
+
+
 class Where(Operator):
     def __init__(self, key, dev, **kwargs):
         super(Where, self).__init__(key, dev, **kwargs)
......
@@ -322,13 +322,11 @@ def _get_cuda_arch_flags(cflags=None):
     for flag in cflags:
         if 'arch' in flag:
             return []
-    supported_arches = [
-        '3.5', '3.7',
-        '5.0', '5.2', '5.3',
-        '6.0', '6.1', '6.2',
-        '7.0', '7.2', '7.5',
-        '8.0',
-    ]
+    supported_arches = ['3.5', '3.7',
+                        '5.0', '5.2', '5.3',
+                        '6.0', '6.1', '6.2',
+                        '7.0', '7.2', '7.5',
+                        '8.0']
     valid_arch_strings = supported_arches + [s + "+PTX" for s in supported_arches]
     capability = _cuda.get_device_capability()
     arch_list = ['{}.{}'.format(capability[0], capability[1])]
......
@@ -89,7 +89,7 @@ class DataIterator(object):
     cutout_size : int, optional, default=0
         The square size for the cutout algorithm.
     mirror : bool, optional, default=False
-        Whether to mirror(flip horizontally) images.
+        Whether to apply the mirror (flip horizontally).
     random_scales : Sequence[float], optional, default=(0.08, 1.)
         The range of scales to sample a crop randomly.
     random_aspect_ratios : Sequence[float], optional, default=(0.75, 1.33)
@@ -133,7 +133,7 @@ class DataIterator(object):
         if kwargs.get('random_crop_size', 0) > 0:
             self._num_transformers += 1
         # Add a transformer for distortion.
-        if kwargs.get('augment_color', False):
+        if kwargs.get('distort_color', False):
             self._num_transformers += 1
         # Initialize queues.
......
@@ -13,19 +13,13 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

+import io
 import math
 import multiprocessing

 import numpy
-try:
-    import cv2
-except ImportError:
-    cv2 = None
-try:
-    import PIL.Image
-    import PIL.ImageEnhance
-except ImportError:
-    PIL = None
+import PIL.Image
+import PIL.ImageEnhance

 from dragon.core.framework import config
@@ -51,7 +45,7 @@ class DataTransformer(multiprocessing.Process):
     cutout_size : int, optional, default=0
         The square size for the cutout algorithm.
     mirror : bool, optional, default=False
-        Whether to mirror(flip horizontally) images.
+        Whether to apply the mirror (flip horizontally).
     random_scales : Sequence[float], optional, default=(0.08, 1.)
         The range of scales to sample a crop randomly.
     random_aspect_ratios : Sequence[float], optional, default=(0.75, 1.33)
@@ -82,10 +76,6 @@ class DataTransformer(multiprocessing.Process):
         self._seed = kwargs.get('seed', config.config().random_seed)
         self.q_in = self.q_out = None
         self.daemon = True
-        if cv2 is None:
-            raise ImportError('Failed to import package <cv2>.')
-        if self._distort_color and PIL is None:
-            raise ImportError('Failed to import package <PIL>.')

     def get(self, example):
         """Return image and labels from a serialized str.
@@ -104,24 +94,26 @@ class DataTransformer(multiprocessing.Process):
         """
         # Decode.
-        img = numpy.frombuffer(example['data'], numpy.uint8)
-        if example.get('encoded', 0) > 0:
-            img = cv2.imdecode(img, 1)
+        if example['encoded'] > 0:
+            img = PIL.Image.open(io.BytesIO(example['data']))
         else:
+            img = numpy.frombuffer(example['data'], numpy.uint8)
             img = img.reshape(example['shape'])

         # Resizing.
         if self._resize > 0:
-            (h, w), size = img.shape[:2], self._resize
+            (w, h), size = img.size, self._resize
             if (w <= h and w == size) or (h <= w and h == size):
                 pass
             else:
                 if w < h:
-                    ow, oh, im_scale = size, size * h // w, float(size) / w
+                    ow, oh = size, size * h // w
                 else:
-                    oh, ow, im_scale = size, size * w // h, float(size) / h
-                interp = cv2.INTER_AREA if im_scale < 1 else cv2.INTER_LINEAR
-                img = cv2.resize(img, (ow, oh), interpolation=interp)
+                    oh, ow = size, size * w // h
+                img = img.resize((ow, oh), PIL.Image.BILINEAR)
+
+        # ToArray.
+        img = numpy.asarray(img)

         # Padding.
         if self._padding > 0:
@@ -152,12 +144,9 @@ class DataTransformer(multiprocessing.Process):
             area = height * width
             i = j = h = w = None
             for attempt in range(10):
-                target_area = numpy.random.uniform(
-                    *self._random_scales) * area
-                log_ratio = (
-                    math.log(self._random_ratios[0]),
-                    math.log(self._random_ratios[1]),
-                )
+                target_area = numpy.random.uniform(*self._random_scales) * area
+                log_ratio = (math.log(self._random_ratios[0]),
+                             math.log(self._random_ratios[1]))
                 aspect_ratio = math.exp(numpy.random.uniform(*log_ratio))
                 w = int(round(math.sqrt(target_area * aspect_ratio)))
                 h = int(round(math.sqrt(target_area / aspect_ratio)))
@@ -179,9 +168,8 @@ class DataTransformer(multiprocessing.Process):
                 j = (width - w) // 2
             img = img[i:i + h, j:j + w, :]
             new_size = (self._random_crop_size, self._random_crop_size)
-            min_scale = self._random_crop_size / max(img.shape[:2])
-            interp = cv2.INTER_AREA if min_scale < 1 else cv2.INTER_LINEAR
-            img = cv2.resize(img, new_size, interpolation=interp)
+            img = PIL.Image.fromarray(img)
+            img = numpy.asarray(img.resize(new_size, PIL.Image.BILINEAR))

         # CutOut.
         if self._cutout_size > 0:
@@ -202,16 +190,14 @@ class DataTransformer(multiprocessing.Process):
         # Color distortion.
         if self._distort_color:
             img = PIL.Image.fromarray(img)
-            transforms = [
-                PIL.ImageEnhance.Brightness,
-                PIL.ImageEnhance.Contrast,
-                PIL.ImageEnhance.Color,
-            ]
+            transforms = [PIL.ImageEnhance.Brightness,
+                          PIL.ImageEnhance.Contrast,
+                          PIL.ImageEnhance.Color]
             numpy.random.shuffle(transforms)
             for transform in transforms:
                 img = transform(img)
                 img = img.enhance(1. + numpy.random.uniform(-.4, .4))
-            img = numpy.array(img)
+            img = numpy.asarray(img)

         # Color transformation.
         if self._inverse_color:
......
@@ -251,6 +251,7 @@ class DragonFrontend(object):
             for e in op_def.output:
                 outputs.append(e + '_%d' % blob_versions[e]
                                if blob_versions[e] > 0 else e)
+                if e != '':
                     blob_versions[e] += 1
                     blob_names[e] = outputs[-1]
             op_def.ClearField('input')
......
@@ -72,7 +72,6 @@ def softmax_exporter(op_def, shape_dict, ws):
     if axis != (ndim - 1):
         raise ValueError(
             'Softmax axis could only be the last one.\n'
-            'Use Exp(LogSoftmax) to compute the softmax instead.'
-        )
+            'Use Exp(LogSoftmax) to compute the softmax instead.')
     helper.add_attribute(node, 'axis', arg.i)
     return node, const_tensors
@@ -494,3 +494,27 @@ def top_k_exporter_v11(op_def, shape_dict, ws):
     )
     node.input.extend([k.name])
     return node, [k]
+
+
+@exporter.register('Unique')
+def unique_exporter(op_def, shape_dict, ws):
+    node, const_tensors = exporter.translate(**locals())
+    helper.add_attribute(node, 'sorted', 1)
+    return_inverse = return_counts = 0
+    for arg in op_def.arg:
+        if arg.name == 'return_inverse':
+            return_inverse = arg.i
+        elif arg.name == 'return_counts':
+            return_counts = arg.i
+    outputs = [op_def.output[0]]
+    if len(op_def.output) > 1:
+        outputs.append('')
+    if len(op_def.output) == 2:
+        if return_inverse:
+            outputs.append(op_def.output[1])
+        elif return_counts:
+            outputs.extend(['', op_def.output[1]])
+    elif len(op_def.output) == 3:
+        outputs.extend([op_def.output[1], op_def.output[2]])
+    node.output[:] = outputs
+    return node, const_tensors
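ONNX's ``Unique`` (opset 11) orders its optional outputs as ``(Y, indices, inverse_indices, counts)``, while the dragon op emits only ``(Y, inverse, counts)``; the exporter above therefore pads the skipped slots with empty names. The resulting ``node.output`` lists look like this (illustrative):

```python
# op_def.output == ['Y', 'inverse', 'counts']  ->  ['Y', '', 'inverse', 'counts']
# op_def.output == ['Y', 'inverse']            ->  ['Y', '', 'inverse']
# op_def.output == ['Y', 'counts']             ->  ['Y', '', '', 'counts']
```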
@@ -4,6 +4,7 @@
 #ifdef USE_CUDA

 #include <thrust/device_ptr.h>
+#include <thrust/device_vector.h>
 #include <thrust/execution_policy.h>
 #include <thrust/sequence.h>
 #include <thrust/sort.h>
......
@@ -544,8 +544,18 @@ void TopK(
     int64_t* index,
     Context* ctx);

-/* control_flow.assgin */
+/* array.unique */
+
+template <typename T, class Context>
+void Unique(
+    const int dim,
+    const T* x,
+    T* y,
+    int64_t* inverse_index,
+    int64_t* counts,
+    int* num,
+    Context* ctx);
+
+/* control_flow.assign */

 template <typename T, class Context>
 void Assign(
     const int num_dims,
......
@@ -85,6 +85,8 @@ from dragon.vm.tensorflow.core.ops.array_ops import split
 from dragon.vm.tensorflow.core.ops.array_ops import squeeze
 from dragon.vm.tensorflow.core.ops.array_ops import tile
 from dragon.vm.tensorflow.core.ops.array_ops import transpose
+from dragon.vm.tensorflow.core.ops.array_ops import unique
+from dragon.vm.tensorflow.core.ops.array_ops import unique_with_counts
 from dragon.vm.tensorflow.core.ops.array_ops import zeros
 from dragon.vm.tensorflow.core.ops.array_ops import zeros_like
 from dragon.vm.tensorflow.core.ops.clip_ops import clip_by_value
......
@@ -31,7 +31,7 @@ class Loss(object):
     reduction : {'none', 'sum', 'mean', 'valid'}, optional
         The reduction method.
     name : str, optional
-        A optional name for the operation.
+        The operation name.

     """
     losses_utils.Reduction.validate(reduction)
@@ -112,7 +112,7 @@ class BinaryCrossentropy(LossFunctionWrapper):
     reduction : {'none', 'sum', 'mean', 'valid'}, optional
         The reduction method.
     name : str, optional
-        A optional name for the operation.
+        The operation name.

     """
     super(BinaryCrossentropy, self).__init__(
@@ -155,7 +155,7 @@ class CategoricalCrossentropy(LossFunctionWrapper):
     reduction : {'none', 'sum', 'mean', 'valid'}, optional
         The reduction method.
     name : str, optional
-        A optional name for the operation.
+        The operation name.

     """
     super(CategoricalCrossentropy, self).__init__(
@@ -196,7 +196,7 @@ class MeanAbsoluteError(LossFunctionWrapper):
     reduction : {'none', 'sum', 'mean'}, optional
         The reduction method.
     name : str, optional
-        A optional name for the operation.
+        The operation name.

     """
     super(MeanAbsoluteError, self).__init__(
@@ -236,7 +236,7 @@ class MeanSquaredError(LossFunctionWrapper):
     reduction : {'none', 'sum', 'mean'}, optional
         The reduction method.
     name : str, optional
-        A optional name for the operation.
+        The operation name.

     """
     super(MeanSquaredError, self).__init__(
@@ -282,7 +282,7 @@ class SparseCategoricalCrossentropy(LossFunctionWrapper):
     reduction : {'none', 'sum', 'mean', 'valid'}, optional
         The reduction method.
     name : str, optional
-        A optional name for the operation.
+        The operation name.

     """
     super(SparseCategoricalCrossentropy, self).__init__(
......
@@ -58,7 +58,7 @@ def broadcast_to(input, shape, name=None):
     shape : Sequence[Union[int, dragon.Tensor]]
         The output shape to broadcast to.
     name : str, optional
-        A optional name for the operation.
+        The operation name.

     Returns
     -------
@@ -95,7 +95,7 @@ def concat(values, axis, name='concat'):
     axis : int
         The axis to concatenate.
     name : str, optional
-        A optional name for the operation.
+        The operation name.

     Returns
     -------
@@ -129,7 +129,7 @@ def depth_to_space(input, block_size, data_format='NHWC', name=None):
     data_format : {'NCHW', 'NHWC'}, optional
         The optional data format.
     name : str, optional
-        A optional name for the operation.
+        The operation name.

     Returns
     -------
@@ -172,7 +172,7 @@ def expand_dims(input, axis, name=None):
     axis : Union[int, Sequence[int]]
         The axis to insert the new dimension(s).
     name : str, optional
-        A optional name for the operation.
+        The operation name.

     Returns
     -------
@@ -203,7 +203,7 @@ def fill(dims, value=0, dtype=None, name=None):
     dtype : str, optional
         The optional data type.
     name : str, optional
-        A optional name for the operation.
+        The operation name.

     Returns
     -------
@@ -248,7 +248,7 @@ def gather(params, indices, axis=0, name=None):
     axis : Union[int, Sequence[int]], optional, default=0
         The axis where the indices aligned.
     name : str, optional
-        A optional name for the operation.
+        The operation name.

     Returns
     -------
@@ -280,7 +280,7 @@ def identity(input, name=None):
     input : dragon.Tensor
         The input tensor.
     name : str, optional
-        A optional name for the operation.
+        The operation name.

     Returns
     -------
@@ -307,7 +307,7 @@ def ones(shape, dtype='float32', name=None):
     dtype : str, optional, default='float32'
         The optional data type.
     name : str, optional
-        A optional name for the operation.
+        The operation name.

     """
     return init_ops.fill(shape, value=1, dtype=dtype, name=name)
@@ -332,7 +332,7 @@ def ones_like(input, dtype='float32', name=None):
     dtype : str, optional, default='float32'
         The optional data type.
     name : str, optional
-        A optional name for the operation.
+        The operation name.

     """
     return init_ops.ones_like(input, dtype=dtype, name=name)
@@ -378,7 +378,7 @@ def one_hot(
     off_value : int, optional, default=0
         The value for not-equal branch.
     name : str, optional
-        A optional name for the operation.
+        The operation name.

     Returns
     -------
@@ -439,7 +439,7 @@ def pad(
     constant_values : int, optional, default=0
         The constant value in ``CONSTANT`` mode.
     name : str, optional
-        A optional name for the operation.
+        The operation name.

     Returns
     -------
@@ -515,7 +515,7 @@ def reshape(tensor, shape, name=None):
     shape : Union[Sequence[int], dragon.Tensor]
         The output shape.
     name : str, optional
-        A optional name for the operation.
+        The operation name.

     """
     return array_ops.reshape(tensor, shape=shape, name=name)
@@ -537,7 +537,7 @@ def shape(input, name=None):
     input : dragon.Tensor
         The input tensor.
     name : str, optional
-        A optional name for the operation.
+        The operation name.

     Returns
     -------
@@ -580,7 +580,7 @@ def slice(input_, begin, size, name=None):
     size : Union[Sequence[int], dragon.Tensor]
         The number of elements sliced from start.
     name : str, optional
-        A optional name for the operation.
+        The operation name.

     Returns
     -------
@@ -614,7 +614,7 @@ def space_to_depth(input, block_size, data_format='NHWC', name=None):
     data_format : {'NCHW', 'NHWC'}, optional
         The optional data format.
     name : str, optional
-        A optional name for the operation.
+        The operation name.

     Returns
     -------
@@ -665,7 +665,7 @@ def split(
     axis : int, optional, default=0
         The axis to split.
     name : str, optional
-        A optional name for the operation.
+        The operation name.

     Returns
     -------
@@ -704,7 +704,7 @@ def squeeze(input, axis=None, name=None):
     axis : Union[int, Sequence[int]], optional
         The axis to remove.
     name : str, optional
-        A optional name for the operation.
+        The operation name.

     Returns
     -------
@@ -740,7 +740,7 @@ def transpose(a, perm=None, name=None):
     perm : Sequence[Union[int, dragon.Tensor]]
         The output permutation.
     name : str, optional
-        A optional name for the operation.
+        The operation name.

     Returns
     -------
@@ -751,6 +751,78 @@ def transpose(a, perm=None, name=None):
     return array_ops.transpose(a, perm=perm, name=name)
+
+
+def unique(x, name=None, **kwargs):
+    """Return the unique elements of input.
+
+    The unique elements and the index that maps each input element
+    to its unique value are returned:
+
+    ```python
+    x = tf.constant([1, 2, 3, 2])
+    y, index = tf.unique(x)
+    print(y)  # [1, 2, 3]
+    print(index)  # [0, 1, 2, 1]
+    ```
+
+    Parameters
+    ----------
+    x : dragon.Tensor
+        The input tensor.
+    name : str, optional
+        The operation name.
+
+    Returns
+    -------
+    dragon.Tensor
+        The output tensor.
+    dragon.Tensor
+        The inverse index tensor.
+
+    """
+    if 'out_idx' in kwargs:
+        kwargs.pop('out_idx')
+    return array_ops.unique(x, return_inverse=True, name=name)
+
+
+def unique_with_counts(x, name=None, **kwargs):
+    """Return the unique elements of input with counts.
+
+    The unique elements, the inverse index and the counts are returned:
+
+    ```python
+    x = tf.constant([1, 2, 3, 2])
+    y, index, counts = tf.unique_with_counts(x)
+    print(y)  # [1, 2, 3]
+    print(index)  # [0, 1, 2, 1]
+    print(counts)  # [1, 2, 1]
+    ```
+
+    Parameters
+    ----------
+    x : dragon.Tensor
+        The input tensor.
+    name : str, optional
+        The operation name.
+
+    Returns
+    -------
+    dragon.Tensor
+        The output tensor.
+    dragon.Tensor
+        The inverse index tensor.
+    dragon.Tensor
+        The counts tensor.
+
+    """
+    if 'out_idx' in kwargs:
+        kwargs.pop('out_idx')
+    return array_ops.unique(
+        x,
+        return_inverse=True,
+        return_counts=True,
+        name=name,
+    )
 def zeros(shape, dtype='float32', name=None):
     r"""Return a tensor filled with zeros.
@@ -767,7 +839,7 @@ def zeros(shape, dtype='float32', name=None):
     dtype : str, optional, default='float32'
         The optional data type.
     name : str, optional
-        A optional name for the operation.
+        The operation name.

     """
     return init_ops.fill(shape, value=0., dtype=dtype, name=name)
@@ -792,7 +864,7 @@ def zeros_like(input, dtype='float32', name=None):
     dtype : str, optional, default='float32'
         The optional data type.
     name : str, optional
-        A optional name for the operation.
+        The operation name.

     """
     return init_ops.zeros_like(input, dtype=dtype, name=name)
@@ -42,7 +42,7 @@ def bitwise_and(x, y, name=None):
     y : dragon.Tensor
         The input2 tensor.
     name : str, optional
-        A optional name for the operation.
+        The operation name.

     Returns
     -------
@@ -74,7 +74,7 @@ def bitwise_or(x, y, name=None):
     y : dragon.Tensor
         The input2 tensor.
     name : str, optional
-        A optional name for the operation.
+        The operation name.

     Returns
     -------
@@ -106,7 +106,7 @@ def bitwise_xor(x, y, name=None):
     y : dragon.Tensor
         The input2 tensor.
     name : str, optional
-        A optional name for the operation.
+        The operation name.

     Returns
     -------
@@ -139,7 +139,7 @@ def invert(x, name=None):
     x : dragon.Tensor
         The input tensor.
     name : str, optional
-        A optional name for the operation.
+        The operation name.

     Returns
     -------
......
@@ -47,7 +47,7 @@ def clip_by_value(
     clip_value_max : number, optional
         The value to :math:`\text{high}`.
     name : str, optional
-        A optional name for the operation.
+        The operation name.

     Returns
     -------
......
@@ -42,7 +42,7 @@ def eye(num_rows, num_columns=None, dtype='float32', name=None):
     dtype : str, optional, default='float32'
         The optional data type.
     name : str, optional
-        A optional name for the operation.
+        The operation name.

     Returns
     -------
......
@@ -39,7 +39,7 @@ def abs(x, name=None):
     x : dragon.Tensor
         The input tensor.
     name : str, optional
-        A optional name for the operation.
+        The operation name.

     Returns
     -------
@@ -69,7 +69,7 @@ def add(x, y, name=None):
     y : dragon.Tensor
         The input2 tensor.
     name : str, optional
-        A optional name for the operation.
+        The operation name.

     Returns
     -------
@@ -97,7 +97,7 @@ def add_n(inputs, name=None):
     inputs : Sequence[dragon.Tensor]
         The input tensors.
     name : str, optional
-        A optional name for the operation.
+        The operation name.

     Returns
     -------
@@ -138,7 +138,7 @@ def argmax(input, axis=None, name=None):
     axis : int, optional
         The axis to reduce.
     name : str, optional
-        A optional name for the operation.
+        The operation name.

     Returns
     -------
@@ -173,7 +173,7 @@ def argmin(input, axis=None, name=None):
     axis : int, optional
         The axis to reduce.
     name : str, optional
-        A optional name for the operation.
+        The operation name.

     Returns
     -------
@@ -201,7 +201,7 @@ def cast(x, dtype, name=None):
     dtype : str
         The data type to cast to.
     name : str, optional
-        A optional name for the operation.
+        The operation name.

     Returns
     -------
@@ -229,7 +229,7 @@ def ceil(x, name=None):
     x : dragon.Tensor
         The input tensor.
     name : str, optional
-        A optional name for the operation.
+        The operation name.

     Returns
     -------
@@ -257,7 +257,7 @@ def cos(x, name=None):
     x : dragon.Tensor
         The input tensor.
     name : str, optional
-        A optional name for the operation.
+        The operation name.

     Returns
     -------
@@ -308,7 +308,7 @@ def cumsum(x, axis=0, exclusive=False, reverse=False, name=None):
     reverse : bool, optional, default=False
         **True** to compute in the reverse direction.
     name : str, optional
-        A optional name for the operation.
+        The operation name.

     Returns
     -------
@@ -346,7 +346,7 @@ def divide(x, y, name=None):
     y : dragon.Tensor
         The input2 tensor.
     name : str, optional
-        A optional name for the operation.
+        The operation name.

     Returns
     -------
@@ -379,7 +379,7 @@ def equal(x, y, name=None):
     y : dragon.Tensor
         The input2 tensor.
     name : str, optional
-        A optional name for the operation.
+        The operation name.

     Returns
     -------
@@ -407,7 +407,7 @@ def exp(x, name=None):
     x : dragon.Tensor
         The input tensor.
     name : str, optional
-        A optional name for the operation.
+        The operation name.

     Returns
     -------
@@ -435,7 +435,7 @@ def floor(x, name=None):
     x : dragon.Tensor
         The input tensor.
     name : str, optional
-        A optional name for the operation.
+        The operation name.

     Returns
     -------
@@ -468,7 +468,7 @@ def greater(x, y, name=None):
     y : dragon.Tensor
         The input2 tensor.
     name : str, optional
-        A optional name for the operation.
+        The operation name.

     Returns
     -------
@@ -501,7 +501,7 @@ def greater_equal(x, y, name=None):
     y : dragon.Tensor
         The input2 tensor.
     name : str, optional
-        A optional name for the operation.
+        The operation name.

     Returns
     -------
@@ -529,7 +529,7 @@ def is_inf(x, name=None):
     x : dragon.Tensor
         The input tensor.
     name : str, optional
-        A optional name for the operation.
+        The operation name.

     Returns
     -------
@@ -557,7 +557,7 @@ def is_nan(x, name=None):
     x : dragon.Tensor
         The input tensor.
     name : str, optional
-        A optional name for the operation.
+        The operation name.

     Returns
     -------
@@ -590,7 +590,7 @@ def less(x, y, name=None):
     y : dragon.Tensor
         The input2 tensor.
     name : str, optional
-        A optional name for the operation.
+        The operation name.

     Returns
     -------
@@ -623,7 +623,7 @@ def less_equal(x, y, name=None):
     y : dragon.Tensor
         The input2 tensor.
     name : str, optional
-        A optional name for the operation.
+        The operation name.

     Returns
     -------
@@ -651,7 +651,7 @@ def log(x, name=None):
     x : dragon.Tensor
         The input tensor.
     name : str, optional
-        A optional name for the operation.
+        The operation name.

     Returns
     -------
@@ -707,7 +707,7 @@ def matmul(
     transpose_b : bool, optional, default=False
         **True** to transpose :math:`b` before computing.
     name : str, optional
-        A optional name for the operation.
+        The operation name.

     Returns
     -------
@@ -744,7 +744,7 @@ def multiply(x, y, name=None):
     y : dragon.Tensor
         The input2 tensor.
     name : str, optional
-        A optional name for the operation.
+        The operation name.

     Returns
     -------
@@ -770,7 +770,7 @@ def negative(x, name=None):
     x : dragon.Tensor
         The input tensor.
     name : str, optional
-        A optional name for the operation.
+        The operation name.

     Returns
     -------
@@ -803,7 +803,7 @@ def not_equal(x, y, name=None):
     y : dragon.Tensor
         The input2 tensor.
     name : str, optional
-        A optional name for the operation.
+        The operation name.

     Returns
     -------
@@ -835,7 +835,7 @@ def pow(x, y, name=None):
     y : Union[dragon.Tensor, number]
         The exponent tensor.
     name : str, optional
-        A optional name for the operation.
+        The operation name.

     Returns
     -------
@@ -878,7 +878,7 @@ def range(start, limit=None, delta=1, dtype='int64', name=None):
     dtype : str, optional, default='int64'
         The optional data type.
     name : str, optional
-        A optional name for the operation.
+        The operation name.

     Returns
     -------
@@ -912,7 +912,7 @@ def reciprocal(x, name=None):
     x : dragon.Tensor
         The input tensor.
     name : str, optional
-        A optional name for the operation.
+        The operation name.

     Returns
     -------
@@ -949,7 +949,7 @@ def reduce_max(input_tensor, axis=None, keepdims=False, name=None):
     keepdims : bool, optional, default=False
         Keep the reduced dimensions or not.
     name : str, optional
-        A optional name for the operation.
+        The operation name.

     Returns
     -------
@@ -994,7 +994,7 @@ def reduce_mean(input_tensor, axis=None, keepdims=False, name=None):
     keepdims : bool, optional, default=False
         Keep the reduced dimensions or not.
     name : str, optional
-        A optional name for the operation.
+        The operation name.

     Returns
     -------
@@ -1036,7 +1036,7 @@ def reduce_min(input_tensor, axis=None, keepdims=False, name=None):
     keepdims : bool, optional, default=False
         Keep the reduced dimensions or not.
     name : str, optional
-        A optional name for the operation.
+        The operation name.

     Returns
     -------
@@ -1081,7 +1081,7 @@ def reduce_sum(input_tensor, axis=None, keepdims=False, name=None):
keepdims : bool, optional, default=False keepdims : bool, optional, default=False
Keep the reduced dimensions or not. Keep the reduced dimensions or not.
name : str, optional name : str, optional
A optional name for the operation. The operation name.
Returns Returns
------- -------
...@@ -1114,7 +1114,7 @@ def round(x, name=None): ...@@ -1114,7 +1114,7 @@ def round(x, name=None):
x : dragon.Tensor x : dragon.Tensor
The input tensor. The input tensor.
name : str, optional name : str, optional
A optional name for the operation. The operation name.
Returns Returns
------- -------
...@@ -1142,7 +1142,7 @@ def rsqrt(x, name=None): ...@@ -1142,7 +1142,7 @@ def rsqrt(x, name=None):
x : dragon.Tensor x : dragon.Tensor
The input tensor. The input tensor.
name : str, optional name : str, optional
A optional name for the operation. The operation name.
Returns Returns
------- -------
...@@ -1170,7 +1170,7 @@ def sigmoid(x, name=None, **kwargs): ...@@ -1170,7 +1170,7 @@ def sigmoid(x, name=None, **kwargs):
x : dragon.Tensor x : dragon.Tensor
The input tensor. The input tensor.
name : str, optional name : str, optional
A optional name for the operation. The operation name.
Returns Returns
------- -------
...@@ -1204,7 +1204,7 @@ def sign(x, name=None): ...@@ -1204,7 +1204,7 @@ def sign(x, name=None):
x : dragon.Tensor x : dragon.Tensor
The input tensor. The input tensor.
name : str, optional name : str, optional
A optional name for the operation. The operation name.
Returns Returns
------- -------
...@@ -1232,7 +1232,7 @@ def sin(x, name=None): ...@@ -1232,7 +1232,7 @@ def sin(x, name=None):
x : dragon.Tensor x : dragon.Tensor
The input tensor. The input tensor.
name : str, optional name : str, optional
A optional name for the operation. The operation name.
Returns Returns
------- -------
...@@ -1260,7 +1260,7 @@ def sqrt(x, name=None): ...@@ -1260,7 +1260,7 @@ def sqrt(x, name=None):
x : dragon.Tensor x : dragon.Tensor
The input tensor. The input tensor.
name : str, optional name : str, optional
A optional name for the operation. The operation name.
Returns Returns
------- -------
...@@ -1288,7 +1288,7 @@ def square(x, name=None): ...@@ -1288,7 +1288,7 @@ def square(x, name=None):
x : dragon.Tensor x : dragon.Tensor
The input tensor. The input tensor.
name : str, optional name : str, optional
A optional name for the operation. The operation name.
Returns Returns
------- -------
...@@ -1320,7 +1320,7 @@ def subtract(x, y, name=None): ...@@ -1320,7 +1320,7 @@ def subtract(x, y, name=None):
y : dragon.Tensor y : dragon.Tensor
The input2 tensor. The input2 tensor.
name : str, optional name : str, optional
A optional name for the operation. The operation name.
Returns Returns
------- -------
...@@ -1349,7 +1349,7 @@ def tanh(x, name=None, **kwargs): ...@@ -1349,7 +1349,7 @@ def tanh(x, name=None, **kwargs):
x : dragon.Tensor x : dragon.Tensor
The input tensor. The input tensor.
name : str, optional name : str, optional
A optional name for the operation. The operation name.
Returns Returns
------- -------
......
...@@ -64,7 +64,7 @@ def batch_normalization( ...@@ -64,7 +64,7 @@ def batch_normalization(
trainable : bool, optional, default=False trainable : bool, optional, default=False
The optional training flag. The optional training flag.
name : str, optional name : str, optional
A optional name for the operation. The operation name.
Returns Returns
------- -------
...@@ -119,7 +119,7 @@ def l2_normalize(x, axis=None, epsilon=1e-12, name=None): ...@@ -119,7 +119,7 @@ def l2_normalize(x, axis=None, epsilon=1e-12, name=None):
epsilon : float, optional, default=1e-12 epsilon : float, optional, default=1e-12
The value of :math:`\epsilon`. The value of :math:`\epsilon`.
name : str, optional name : str, optional
A optional name for the operation. The operation name.
Returns Returns
------- -------
...@@ -171,7 +171,7 @@ def moments(x, axes=None, keepdims=False, name=None): ...@@ -171,7 +171,7 @@ def moments(x, axes=None, keepdims=False, name=None):
keepdims : bool, optional, default=False keepdims : bool, optional, default=False
Keep the reduced dimensions or not. Keep the reduced dimensions or not.
name : str, optional name : str, optional
A optional name for the operation. The operation name.
Returns Returns
------- -------
......
...@@ -49,7 +49,7 @@ def avg_pool( ...@@ -49,7 +49,7 @@ def avg_pool(
data_format : {'NCHW', 'NCDHW', 'NHWC', 'NDHWC'}, optional data_format : {'NCHW', 'NCDHW', 'NHWC', 'NDHWC'}, optional
The optional data format. The optional data format.
name : str, optional name : str, optional
A optional name for the operation. The operation name.
Returns Returns
------- -------
...@@ -109,7 +109,7 @@ def avg_pool2d( ...@@ -109,7 +109,7 @@ def avg_pool2d(
data_format : {'NCHW', 'NCDHW', 'NHWC', 'NDHWC'}, optional data_format : {'NCHW', 'NCDHW', 'NHWC', 'NDHWC'}, optional
The optional data format. The optional data format.
name : str, optional name : str, optional
A optional name for the operation. The operation name.
Returns Returns
------- -------
...@@ -164,7 +164,7 @@ def convolution( ...@@ -164,7 +164,7 @@ def convolution(
dilations : Sequence[int], optional dilations : Sequence[int], optional
The rate(s) of dilated kernel. The rate(s) of dilated kernel.
name : str, optional name : str, optional
A optional name for the operation. The operation name.
Returns Returns
------- -------
...@@ -230,7 +230,7 @@ def conv_transpose( ...@@ -230,7 +230,7 @@ def conv_transpose(
dilations : Sequence[int], optional dilations : Sequence[int], optional
The rate(s) of dilated kernel. The rate(s) of dilated kernel.
name : str, optional name : str, optional
A optional name for the operation. The operation name.
Returns Returns
------- -------
...@@ -299,7 +299,7 @@ def conv2d( ...@@ -299,7 +299,7 @@ def conv2d(
dilations : Sequence[int], optional dilations : Sequence[int], optional
The rate(s) of dilated kernel. The rate(s) of dilated kernel.
name : str, optional name : str, optional
A optional name for the operation. The operation name.
Returns Returns
------- -------
...@@ -339,7 +339,7 @@ def conv2d_transpose( ...@@ -339,7 +339,7 @@ def conv2d_transpose(
dilations : Sequence[int], optional dilations : Sequence[int], optional
The rate(s) of dilated kernel. The rate(s) of dilated kernel.
name : str, optional name : str, optional
A optional name for the operation. The operation name.
Returns Returns
------- -------
...@@ -377,7 +377,7 @@ def depthwise_conv2d( ...@@ -377,7 +377,7 @@ def depthwise_conv2d(
dilations : Sequence[int], optional dilations : Sequence[int], optional
The rate(s) of dilated kernel. The rate(s) of dilated kernel.
name : str, optional name : str, optional
A optional name for the operation. The operation name.
Returns Returns
------- -------
...@@ -410,7 +410,7 @@ def dropout(x, rate, name=None, **kwargs): ...@@ -410,7 +410,7 @@ def dropout(x, rate, name=None, **kwargs):
rate : Union[float, dragon.Tensor] rate : Union[float, dragon.Tensor]
The dropping probability. The dropping probability.
name : str, optional name : str, optional
A optional name for the operation. The operation name.
Returns Returns
------- -------
...@@ -441,7 +441,7 @@ def elu(features, alpha=1., name=None, **kwargs): ...@@ -441,7 +441,7 @@ def elu(features, alpha=1., name=None, **kwargs):
alpha : float, optional, default=1. alpha : float, optional, default=1.
The value of :math:`\alpha`. The value of :math:`\alpha`.
name : str, optional name : str, optional
A optional name for the operation. The operation name.
Returns Returns
------- -------
...@@ -475,7 +475,7 @@ def leaky_relu(features, alpha=0.2, name=None, **kwargs): ...@@ -475,7 +475,7 @@ def leaky_relu(features, alpha=0.2, name=None, **kwargs):
alpha : number, optional, default=0.2 alpha : number, optional, default=0.2
The value of :math:`\alpha`. The value of :math:`\alpha`.
name : str, optional name : str, optional
A optional name for the operation. The operation name.
Returns Returns
------- -------
...@@ -520,7 +520,7 @@ def local_response_normalization( ...@@ -520,7 +520,7 @@ def local_response_normalization(
data_format : {'NCHW', 'NHWC'}, optional data_format : {'NCHW', 'NHWC'}, optional
The optional data format. The optional data format.
name : str, optional name : str, optional
A optional name for the operation. The operation name.
Returns Returns
------- -------
...@@ -561,7 +561,7 @@ def log_softmax(logits, axis=-1, name=None): ...@@ -561,7 +561,7 @@ def log_softmax(logits, axis=-1, name=None):
axis : int, optional, default=-1 axis : int, optional, default=-1
The axis to reduce. The axis to reduce.
name : str, optional name : str, optional
A optional name for the operation. The operation name.
Returns Returns
------- -------
...@@ -595,7 +595,7 @@ def max_pool( ...@@ -595,7 +595,7 @@ def max_pool(
data_format : {'NCHW', 'NCDHW', 'NHWC', 'NDHWC'}, optional data_format : {'NCHW', 'NCDHW', 'NHWC', 'NDHWC'}, optional
The optional data format. The optional data format.
name : str, optional name : str, optional
A optional name for the operation. The operation name.
Returns Returns
------- -------
...@@ -655,7 +655,7 @@ def max_pool2d( ...@@ -655,7 +655,7 @@ def max_pool2d(
data_format : {'NCHW', 'NCDHW', 'NHWC', 'NDHWC'}, optional data_format : {'NCHW', 'NCDHW', 'NHWC', 'NDHWC'}, optional
The optional data format. The optional data format.
name : str, optional name : str, optional
A optional name for the operation. The operation name.
Returns Returns
------- -------
...@@ -693,7 +693,7 @@ def relu(features, name=None, **kwargs): ...@@ -693,7 +693,7 @@ def relu(features, name=None, **kwargs):
features : dragon.Tensor features : dragon.Tensor
The input tensor. The input tensor.
name : str, optional name : str, optional
A optional name for the operation. The operation name.
Returns Returns
------- -------
...@@ -722,7 +722,7 @@ def relu6(features, name=None, **kwargs): ...@@ -722,7 +722,7 @@ def relu6(features, name=None, **kwargs):
features : dragon.Tensor features : dragon.Tensor
The input tensor. The input tensor.
name : str, optional name : str, optional
A optional name for the operation. The operation name.
Returns Returns
------- -------
...@@ -756,7 +756,7 @@ def selu(features, name=None, **kwargs): ...@@ -756,7 +756,7 @@ def selu(features, name=None, **kwargs):
features : dragon.Tensor features : dragon.Tensor
The tensor :math:`x`. The tensor :math:`x`.
name : str, optional name : str, optional
A optional name for the operation. The operation name.
Returns Returns
------- -------
...@@ -797,7 +797,7 @@ def softmax(logits, axis=-1, name=None, **kwargs): ...@@ -797,7 +797,7 @@ def softmax(logits, axis=-1, name=None, **kwargs):
axis : int, optional, default=-1 axis : int, optional, default=-1
The axis to reduce. The axis to reduce.
name : str, optional name : str, optional
A optional name for the operation. The operation name.
Returns Returns
------- -------
...@@ -826,7 +826,7 @@ def softmax_cross_entropy_with_logits(labels, logits, name=None): ...@@ -826,7 +826,7 @@ def softmax_cross_entropy_with_logits(labels, logits, name=None):
logits : dragon.Tensor logits : dragon.Tensor
The logit tensor. The logit tensor.
name : str, optional name : str, optional
A optional name for the operation. The operation name.
Returns Returns
------- -------
...@@ -860,7 +860,7 @@ def sparse_softmax_cross_entropy_with_logits(labels, logits, name=None): ...@@ -860,7 +860,7 @@ def sparse_softmax_cross_entropy_with_logits(labels, logits, name=None):
logits : dragon.Tensor logits : dragon.Tensor
The logit tensor. The logit tensor.
name : str, optional name : str, optional
A optional name for the operation. The operation name.
Returns Returns
------- -------
...@@ -888,7 +888,7 @@ def top_k(input, k=1, sorted=True, name=None): ...@@ -888,7 +888,7 @@ def top_k(input, k=1, sorted=True, name=None):
sorted : bool, optional sorted : bool, optional
Whether to return in the sorted order. Whether to return in the sorted order.
name : str, optional name : str, optional
A optional name for the operation. The operation name.
Returns Returns
------- -------
......
...@@ -42,7 +42,7 @@ def random_normal( ...@@ -42,7 +42,7 @@ def random_normal(
seed : int, optional seed : int, optional
The optional random seed. The optional random seed.
name : str, optional name : str, optional
A optional name for the operation. The operation name.
Returns Returns
------- -------
...@@ -79,7 +79,7 @@ def random_uniform( ...@@ -79,7 +79,7 @@ def random_uniform(
seed : int, optional seed : int, optional
The optional random seed. The optional random seed.
name : str, optional name : str, optional
A optional name for the operation. The operation name.
Returns Returns
------- -------
...@@ -117,7 +117,7 @@ def truncated_normal( ...@@ -117,7 +117,7 @@ def truncated_normal(
seed : int, optional seed : int, optional
The optional random seed. The optional random seed.
name : str, optional name : str, optional
A optional name for the operation. The operation name.
Returns Returns
------- -------
......
...@@ -84,7 +84,7 @@ def Input(shape, dtype='float32', name=None): ...@@ -84,7 +84,7 @@ def Input(shape, dtype='float32', name=None):
dtype : str, optional, default='float32' dtype : str, optional, default='float32'
The optional data type. The optional data type.
name : str, optional name : str, optional
A optional name for the operation. The operation name.
Returns Returns
------- -------
......
...@@ -47,18 +47,11 @@ class TestKPLRecord(unittest.TestCase): ...@@ -47,18 +47,11 @@ class TestKPLRecord(unittest.TestCase):
self.assertEqual(len(dataset), 1) self.assertEqual(len(dataset), 1)
dataset.redirect(0) dataset.redirect(0)
self.assertEqual(example, dataset.get()) self.assertEqual(example, dataset.get())
for num_chunks, shuffle in [(0, False), (0, True), (1, True)]: for shuffle, initial_fill in [(False, 1), (True, 1), (True, 1024)]:
reader = dragon.io.DataReader( reader = dragon.io.DataReader(
dataset=dragon.io.KPLRecordDataset, dataset=dragon.io.KPLRecordDataset,
source=path, num_chunks=num_chunks, shuffle=shuffle) source=path, shuffle=shuffle, initial_fill=initial_fill)
reader._init_dataset() reader._init_dataset()
for i in range(8):
self.assertEqual(reader.next_example(), example)
if reader._example_cursor >= reader._end:
if reader._num_parts > 1 or reader._shuffle:
reader.next_chunk()
else:
reader.reset()
except (OSError, PermissionError): except (OSError, PermissionError):
pass pass
......
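The updated test reflects the refactored reader: chunk-based shuffling (`num_chunks` plus explicit `next_chunk()` calls) is replaced by a prefetch buffer sized by `initial_fill`. A minimal hedged sketch of the new construction, assuming an existing record `source` path and that `next_example()` keeps its place in the public surface:

```python
import dragon

# Hedged sketch of the refactored DataReader: shuffling now draws from a
# buffer pre-filled with `initial_fill` examples, instead of walking
# explicit chunks via `num_chunks` / `next_chunk()`.
reader = dragon.io.DataReader(
    dataset=dragon.io.KPLRecordDataset,
    source='/path/to/record',  # placeholder for a real KPLRecord directory
    shuffle=True,
    initial_fill=1024,
)
reader._init_dataset()
example = reader.next_example()  # assumed to survive the refactor, as in the old loop
```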
...@@ -908,6 +908,31 @@ class TestArrayOps(OpTestCase): ...@@ -908,6 +908,31 @@ class TestArrayOps(OpTestCase):
with dragon.device('cuda'): with dragon.device('cuda'):
self.test_top_k() self.test_top_k()
def test_unique(self):
data = np.array([1, 1, 3, 5, 5, 7, 9])
entries = [(False, False),
(True, False),
(False, True),
(True, True)]
for execution in ('EAGER_MODE', 'GRAPH_MODE'):
with execution_context().mode(execution):
for return_inverse, return_counts in entries:
x = new_tensor(data)
y = dragon.unique(
x,
return_inverse=return_inverse,
return_counts=return_counts)
result = np.unique(
data,
return_inverse=return_inverse,
return_counts=return_counts)
self.assertEqual(y, result, test_symbols=False)
@unittest.skipIf(not TEST_CUDA, 'CUDA unavailable')
def test_unique_cuda(self):
with dragon.device('cuda'):
self.test_unique()
def test_where(self): def test_where(self):
entries = [((6,), (6,))] entries = [((6,), (6,))]
for execution in ('EAGER_MODE', 'GRAPH_MODE'): for execution in ('EAGER_MODE', 'GRAPH_MODE'):
......
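For readers unfamiliar with the new op, a hedged eager-mode sketch of what the test compares against `np.unique` (`dragon.constant` is assumed as the tensor constructor; the test itself uses a local helper, and the output ordering is inferred from the torch-style convention used elsewhere in this commit):

```python
import numpy as np
import dragon

data = np.array([1, 1, 3, 5, 5, 7, 9])
x = dragon.constant(data)  # assumed constructor
y, inverse, counts = dragon.unique(x, return_inverse=True, return_counts=True)
print(y.numpy())        # expected: [1 3 5 7 9]
print(inverse.numpy())  # expected: [0 0 1 2 2 3 4]
print(counts.numpy())   # expected: [2 1 2 1 2]
```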
...@@ -112,8 +112,12 @@ class TestModule(unittest.TestCase): ...@@ -112,8 +112,12 @@ class TestModule(unittest.TestCase):
pass pass
for _, _ in m.named_parameters(): for _, _ in m.named_parameters():
pass pass
for _, _ in m.named_buffers():
pass
for _ in m.parameters(): for _ in m.parameters():
pass pass
for _ in m.buffers():
pass
_, _ = repr(m), repr(m.weight) _, _ = repr(m), repr(m.weight)
def test_sequential(self): def test_sequential(self):
......
...@@ -537,14 +537,31 @@ class TestTensorOps(OpTestCase): ...@@ -537,14 +537,31 @@ class TestTensorOps(OpTestCase):
self.assertEqual(getattr(x, name)(), data.astype(dtype)) self.assertEqual(getattr(x, name)(), data.astype(dtype))
getattr(x, name + '_')() getattr(x, name + '_')()
self.assertEqual(x, data.astype(dtype)) self.assertEqual(x, data.astype(dtype))
x.type(dtype) y = x.type(dtype)
self.assertEqual(x.type(), dtype) self.assertEqual(y.type(), dtype)
def test_uniform(self): def test_uniform(self):
data = arange((2, 3)) data = arange((2, 3))
x = new_tensor(data) x = new_tensor(data)
x.uniform_() x.uniform_()
def test_unique(self):
data = np.array([1, 1, 3, 5, 5, 7, 9])
entries = [(False, False),
(True, False),
(False, True),
(True, True)]
for return_inverse, return_counts in entries:
x = new_tensor(data)
y = x.unique(return_inverse=return_inverse,
return_counts=return_counts,
sorted=True)
result = np.unique(
data,
return_inverse=return_inverse,
return_counts=return_counts)
self.assertEqual(y, result)
def test_unsqueeze(self): def test_unsqueeze(self):
entries = [1, -1] entries = [1, -1]
for axis in entries: for axis in entries:
......
...@@ -83,6 +83,25 @@ class TestTensor(unittest.TestCase): ...@@ -83,6 +83,25 @@ class TestTensor(unittest.TestCase):
self.assertEqual(x_from_dlpack.dtype, str(data.dtype)) self.assertEqual(x_from_dlpack.dtype, str(data.dtype))
self.assertLessEqual(np.abs(x_from_dlpack.numpy() - data).max(), 1e-5) self.assertLessEqual(np.abs(x_from_dlpack.numpy() - data).max(), 1e-5)
def test_internal_converter(self):
data = np.array([0., 1., 2.], 'float32')
x = torch.tensor(data)
y = x.to(torch.int32)
self.assertEqual(y.dtype, 'int32')
y = x.to(torch.device('cpu'))
self.assertEqual(y.device, torch.device('cpu'))
y = x.to(torch.FloatTensor(1))
self.assertEqual(y.dtype, 'float32')
self.assertEqual(y.device, torch.device('cpu'))
try:
_ = x.to(data)
except ValueError:
pass
try:
_ = x.to(torch.device('gpu'))
except ValueError:
pass
def test_numpy_converter(self): def test_numpy_converter(self):
data = np.array([0., 1., 2.], 'float32') data = np.array([0., 1., 2.], 'float32')
x = torch.from_numpy(data) x = torch.from_numpy(data)
......
...@@ -73,6 +73,7 @@ from dragon.vm.torch.core.ops.array.functional import squeeze ...@@ -73,6 +73,7 @@ from dragon.vm.torch.core.ops.array.functional import squeeze
from dragon.vm.torch.core.ops.array.functional import stack from dragon.vm.torch.core.ops.array.functional import stack
from dragon.vm.torch.core.ops.array.functional import sum from dragon.vm.torch.core.ops.array.functional import sum
from dragon.vm.torch.core.ops.array.functional import topk from dragon.vm.torch.core.ops.array.functional import topk
from dragon.vm.torch.core.ops.array.functional import unique
from dragon.vm.torch.core.ops.array.functional import unsqueeze from dragon.vm.torch.core.ops.array.functional import unsqueeze
from dragon.vm.torch.core.ops.array.functional import where from dragon.vm.torch.core.ops.array.functional import where
from dragon.vm.torch.core.ops.init.functional import arange from dragon.vm.torch.core.ops.init.functional import arange
......
...@@ -98,6 +98,23 @@ class Module(object): ...@@ -98,6 +98,23 @@ class Module(object):
fn(self) fn(self)
return self return self
def buffers(self, recurse=True):
"""Return an iterator over all buffers.
Parameters
----------
recurse : bool, optional, default=True
Yield buffers recursively or not.
Returns
-------
Iterator
The iterator of buffer.
"""
for name, buffer in self.named_buffers(recurse=recurse):
yield buffer
def children(self): def children(self):
"""Return an iterator over immediate modules. """Return an iterator over immediate modules.
...@@ -294,6 +311,28 @@ class Module(object): ...@@ -294,6 +311,28 @@ class Module(object):
for name, module in self.named_modules(): for name, module in self.named_modules():
yield module yield module
def named_buffers(self, prefix='', recurse=True):
"""Return an iterator over all buffers.
Parameters
----------
prefix : str, optional, default=''
The prefix added to the name.
recurse : bool, optional, default=True
Yield buffers recursively or not.
Returns
-------
Iterator
The iterator of (name, buffer).
"""
gen = self._named_members(
lambda module: module._buffers.items(),
prefix=prefix, recurse=recurse)
for name, buffer in gen:
yield name, buffer
def named_children(self): def named_children(self):
"""Return an iterator over immediate modules, yield as *(name, module)*. """Return an iterator over immediate modules, yield as *(name, module)*.
...@@ -320,7 +359,7 @@ class Module(object): ...@@ -320,7 +359,7 @@ class Module(object):
Returns Returns
------- -------
Iterator Iterator
The iterator of module. The iterator of (name, module).
""" """
if memo is None: if memo is None:
...@@ -335,43 +374,43 @@ class Module(object): ...@@ -335,43 +374,43 @@ class Module(object):
for m in module.named_modules(memo, submodule_prefix): for m in module.named_modules(memo, submodule_prefix):
yield m yield m
def named_parameters(self, memo=None, prefix=''): def named_parameters(self, prefix='', recurse=True):
"""Return an iterator over all parameters. """Return an iterator over all parameters.
Parameters Parameters
---------- ----------
memo : Set, optional
The optional set to collect parameters.
prefix : str, optional, default='' prefix : str, optional, default=''
The prefix added to the name. The prefix added to the name.
recurse : bool, optional, default=True
Yield parameters recursively or not.
Returns Returns
------- -------
Iterator Iterator
The iterator of parameter. The iterator of (name, param).
""" """
if memo is None: gen = self._named_members(
memo = set() lambda module: module._parameters.items(),
for name, p in self._parameters.items(): prefix=prefix, recurse=recurse)
if p is not None and p not in memo: for name, param in gen:
memo.add(p) yield name, param
yield prefix + ('.' if prefix else '') + name, p
for mname, module in self.named_children(): def parameters(self, recurse=True):
submodule_prefix = prefix + ('.' if prefix else '') + mname
for name, p in module.named_parameters(memo, submodule_prefix):
yield name, p
def parameters(self):
"""Return an iterator over all parameters. """Return an iterator over all parameters.
Parameters
----------
recurse : bool, optional, default=True
Yield parameters recursively or not.
Returns Returns
------- -------
Iterator Iterator
The iterator of parameter. The iterator of param.
""" """
for name, param in self.named_parameters(): for name, param in self.named_parameters(recurse=recurse):
yield param yield param
def register_buffer(self, name, tensor): def register_buffer(self, name, tensor):
...@@ -535,8 +574,22 @@ class Module(object): ...@@ -535,8 +574,22 @@ class Module(object):
return self return self
def _get_name(self): def _get_name(self):
"""Return the class name."""
return self.__class__.__name__ return self.__class__.__name__
def _named_members(self, getter, prefix='', recurse=True):
"""Return the named members."""
memo = set()
modules = self.named_modules(prefix=prefix) if recurse else [(prefix, self)]
for module_prefix, module in modules:
members = getter(module)
for k, v in members:
if v is None or v in memo:
continue
memo.add(v)
name = module_prefix + ('.' if module_prefix else '') + k
yield name, v
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
"""Run the forward pipeline.""" """Run the forward pipeline."""
outputs = self.forward(*args, **kwargs) outputs = self.forward(*args, **kwargs)
......
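Since `buffers()` and `named_buffers()` are new, a short usage sketch; `nn.BatchNorm2d` is assumed here to register its running statistics as buffers, as torch-style modules conventionally do:

```python
from dragon.vm.torch import nn

# Hedged sketch: iterate the new buffer accessors on a normalization layer,
# which conventionally registers 'running_mean' / 'running_var' as buffers.
m = nn.BatchNorm2d(8)
for name, buf in m.named_buffers():
    print(name, buf.shape)
flat = list(m.buffers(recurse=False))  # same tensors, without names
```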
...@@ -546,6 +546,27 @@ class TopK(function.Function): ...@@ -546,6 +546,27 @@ class TopK(function.Function):
return self.dispatch([input], outputs, no_grad=True) return self.dispatch([input], outputs, no_grad=True)
class Unique(function.Function):
def __init__(self, key, dev, **kwargs):
super(Unique, self).__init__(key, dev, **kwargs)
self.return_inverse = kwargs.get('return_inverse', False)
self.return_counts = kwargs.get('return_counts', False)
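# The booleans promote to 0/1 here, so between 1 and 3 outputs are allocated.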
self.num_outputs = 1 + self.return_inverse + self.return_counts
def attributes(self):
return {
'op_type': 'Unique',
'arguments': {
'return_inverse': self.return_inverse,
'return_counts': self.return_counts,
}
}
def forward(self, input):
outputs = [self.alloc() for _ in range(self.num_outputs)]
return self.dispatch([input], outputs)
class UnSqueeze(function.Function): class UnSqueeze(function.Function):
def __init__(self, key, dev, **kwargs): def __init__(self, key, dev, **kwargs):
super(UnSqueeze, self).__init__(key, dev, **kwargs) super(UnSqueeze, self).__init__(key, dev, **kwargs)
......
...@@ -982,6 +982,56 @@ def topk(input, k, dim=None, largest=True, sorted=True, out=None): ...@@ -982,6 +982,56 @@ def topk(input, k, dim=None, largest=True, sorted=True, out=None):
).apply(input, out if out else (None, None)) ).apply(input, out if out else (None, None))
def unique(input, return_inverse=False, return_counts=False, **kwargs):
"""Return the unique elements of input.
If ``return_inverse``, also return the index of each input element in the output:
```python
x = torch.tensor([1, 2, 3, 2])
y, index = torch.unique(x, return_inverse=True)
print(y) # [1, 2, 3]
print(index) # [0, 1, 2, 1]
```
If ``return_counts``, also return the count of each output element:
```python
x = torch.tensor([1, 2, 3, 2])
y, counts = torch.unique(x, return_counts=True)
print(y) # [1, 2, 3]
print(counts) # [1, 2, 1]
```
Parameters
----------
input : dragon.vm.torch.Tensor
The input tensor.
return_inverse : bool, optional, default=False
Return the inverse index or not.
return_counts : bool, optional, default=False
Return the counts or not.
Returns
-------
dragon.vm.torch.Tensor
The output tensor.
dragon.vm.torch.Tensor, optional
The inverse index tensor.
dragon.vm.torch.Tensor, optional
The counts tensor.
"""
kwargs.pop('sorted', None)  # 'sorted' is accepted for torch-API compatibility; results are always sorted.
return _functions.Unique \
.instantiate(
input.device,
return_inverse=return_inverse,
return_counts=return_counts,
).apply(input)
def unsqueeze(input, dim, out=None): def unsqueeze(input, dim, out=None):
"""Expand the dimensions of input with size 1. """Expand the dimensions of input with size 1.
......
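With both flags set, the extra outputs follow the allocation order in `Unique.num_outputs`; a combined example, consistent with the two docstring snippets above (ordering inferred, not documented):

```python
import dragon.vm.torch as torch

x = torch.tensor([1, 2, 3, 2])
# Output order (values, inverse, counts) is inferred from the
# `1 + return_inverse + return_counts` allocation in the Unique function.
y, index, counts = torch.unique(x, return_inverse=True, return_counts=True)
print(y)       # [1, 2, 3]
print(index)   # [0, 1, 2, 1]
print(counts)  # [1, 2, 1]
```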
...@@ -1705,7 +1705,7 @@ def _type(self, dtype=None): ...@@ -1705,7 +1705,7 @@ def _type(self, dtype=None):
""" """
if dtype is None: if dtype is None:
return self.dtype return self.dtype
return array_funcs.cast(self, dtype, True) return array_funcs.cast(self, dtype, False)
def uniform_(self, low=0, high=1): def uniform_(self, low=0, high=1):
...@@ -1729,6 +1729,33 @@ def uniform_(self, low=0, high=1): ...@@ -1729,6 +1729,33 @@ def uniform_(self, low=0, high=1):
return init_funcs.uniform_fill(self, low, high) return init_funcs.uniform_fill(self, low, high)
def unique(self, return_inverse=False, return_counts=False, **kwargs):
"""Return the unique elements.
Parameters
----------
return_inverse : bool, optional, default=False
Return the inverse index or not.
return_counts : bool, optional, default=False
Return the counts or not.
Returns
-------
dragon.vm.torch.Tensor
The output tensor.
dragon.vm.torch.Tensor, optional
The inverse index tensor.
dragon.vm.torch.Tensor, optional
The counts tensor.
See Also
--------
`torch.unique(...)`_
"""
return array_funcs.unique(self, return_inverse, return_counts, **kwargs)
def unsqueeze(self, dim): def unsqueeze(self, dim):
"""Return a tensor with dimensions of size 1 inserted. """Return a tensor with dimensions of size 1 inserted.
...@@ -1922,6 +1949,7 @@ Tensor.sub_ = sub_ ...@@ -1922,6 +1949,7 @@ Tensor.sub_ = sub_
Tensor.topk = topk Tensor.topk = topk
Tensor.type = _type Tensor.type = _type
Tensor.uniform_ = uniform_ Tensor.uniform_ = uniform_
Tensor.unique = unique
Tensor.unsqueeze = unsqueeze Tensor.unsqueeze = unsqueeze
Tensor.unsqueeze_ = unsqueeze_ Tensor.unsqueeze_ = unsqueeze_
Tensor.where = where Tensor.where = where
......
...@@ -564,12 +564,10 @@ class Tensor(object): ...@@ -564,12 +564,10 @@ class Tensor(object):
src._impl, src._impl,
proto_util.get_device_option( proto_util.get_device_option(
self._device.type, self._device.type,
self._device.index self._device.index).SerializeToString(),
).SerializeToString(),
proto_util.get_device_option( proto_util.get_device_option(
src._device.type, src._device.type,
src._device.index src._device.index).SerializeToString(),
).SerializeToString(),
) )
return self return self
...@@ -1803,6 +1801,54 @@ class Tensor(object): ...@@ -1803,6 +1801,54 @@ class Tensor(object):
""" """
def to(self, *args, **kwargs):
"""Convert to the specified data type or device.
The arguments can be a ``torch.dtype`` or a ``torch.device``:
```python
x = torch.FloatTensor(1)
x.to(torch.int32) # Equivalent to ``x.int()``
x.to(torch.device('cpu')) # Equivalent to ``x.cpu()``
x.to(torch.device('cuda'), torch.float32) # Convert both
```
Or a ``torch.Tensor``, which provides both ``dtype`` and ``device``:
```python
a, b = torch.tensor(1.), torch.tensor(2)
print(a.to(b)) # 1
```
Returns
-------
dragon.vm.torch.Tensor
The output tensor.
"""
dtype = kwargs.get('dtype', None)
device = kwargs.get('device', None)
for arg in args:
if isinstance(arg, cpp.dtype):
dtype = arg
elif isinstance(arg, cpp.device):
device = arg
elif isinstance(arg, Tensor):
dtype, device = arg.dtype, arg.device
break
else:
raise ValueError('Unsupported conversion target.')
if device is not None:
if device.type == 'cpu':
self.cpu()
elif device.type == 'cuda':
self.cuda(device.index)
else:
raise ValueError('Unsupported device type: ' + device.type)
if dtype is not None:
return self.type(dtype)
return self
def topk(self, k, dim=None, largest=True, sorted=True): def topk(self, k, dim=None, largest=True, sorted=True):
"""Return the top-K largest or smallest elements. """Return the top-K largest or smallest elements.
...@@ -1829,19 +1875,17 @@ class Tensor(object): ...@@ -1829,19 +1875,17 @@ class Tensor(object):
""" """
def type(self, dtype=None): def type(self, dtype=None):
"""Return the data type. """Return the data type or copied tensor with specified type.
If ``dtype`` is not **None**, cast ``self`` to the new tensor.
Parameters Parameters
---------- ----------
dtype : str, optional dtype : str, optional
The specified type. The specified type to convert to.
Returns Returns
------- -------
Union[str, dragon.vm.torch.Tensor] Union[str, dragon.vm.torch.Tensor]
The data type or new tensor. The data type or a copied tensor.
""" """
...@@ -1864,6 +1908,31 @@ class Tensor(object): ...@@ -1864,6 +1908,31 @@ class Tensor(object):
""" """
def unique(self, return_inverse=False, return_counts=False, **kwargs):
"""Return the unique elements.
Parameters
----------
return_inverse : bool, optional, default=False
Return the inverse index or not.
return_counts : bool, optional, default=False
Return the counts or not.
Returns
-------
dragon.vm.torch.Tensor
The output tensor.
dragon.vm.torch.Tensor, optional
The inverse index tensor.
dragon.vm.torch.Tensor, optional
The counts tensor.
See Also
--------
`torch.unique(...)`_
"""
def unsqueeze(self, dim): def unsqueeze(self, dim):
"""Return a tensor with dimensions of size 1 inserted. """Return a tensor with dimensions of size 1 inserted.
......
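Finally, the casting fix above changes `Tensor.type(dtype)` from an in-place mutation to returning a casted copy, which the updated test relies on; a hedged sketch of the resulting behavior:

```python
import dragon.vm.torch as torch

x = torch.tensor([1.5, 2.5])
y = x.type('int32')   # now returns a casted copy ...
print(y.type())       # 'int32'
print(x.type())       # 'float32' -- ... and leaves ``self`` untouched
z = x.to(torch.device('cpu'), torch.int32)  # .to() combines device and dtype moves
```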