Commit 1dd8aeef by Ting PAN

Add Unique Operator

Summary:
This commit adds the unique op for dragon, torch, tensorflow and onnx.
It also fixes a bug that cached the wrong workspace size for cudnn convolution: the size was queried only for the gradients requested on the first run, so the cached value could be too small for later runs that need the other gradients.
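A minimal usage sketch of the new operator (mirroring the docstring examples added below; the combined three-output form follows from ``num_outputs = 1 + return_inverse + return_counts``):

```python
import dragon

x = dragon.constant([1, 2, 3, 2])
y, index, counts = dragon.unique(x, return_inverse=True, return_counts=True)
print(y)       # [1, 2, 3]
print(index)   # [0, 1, 2, 1] -> position of each input element in y
print(counts)  # [1, 2, 1]    -> occurrences of each unique value
```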
1 parent 80267d8f
Showing with 1033 additions and 202 deletions
......@@ -302,7 +302,7 @@ class RandomBBoxCrop(object):
thresholds=(0.0, 0.1, 0.3, 0.5, 0.7, 0.9),
allow_no_crop=True,
num_attempts=10,
bbox_layout=None,
bbox_layout='xyXY',
**kwargs
):
"""Create a ``RandomBBoxCrop`` operator.
......@@ -316,10 +316,10 @@ class RandomBBoxCrop(object):
thresholds : Sequence[float], optional
The minimum IoU(s) to satisfy.
allow_no_crop : bool, optional, default=True
**True** to include the no-cropping as a option.
**True** to include the no-cropping as an option.
num_attempts : int, optional, default=10
The max number of sampling trials.
bbox_layout : str, optional
bbox_layout : str, optional, default='xyXY'
The optional bbox layout.
Returns
......@@ -437,7 +437,7 @@ class Resize(object):
resize_shorter=None,
resize_longer=None,
max_size=None,
interp_type='LINEAR',
interp_type=None,
mag_filter=None,
min_filter=None,
**kwargs
......
......@@ -153,6 +153,9 @@ dragon
`transpose(...) <dragon/transpose.html>`_
: Permute the dimensions of input.
`unique(...) <dragon/unique.html>`_
: Return the unique elements of input.
`where(...) <dragon/where.html>`_
: Select the elements from two branches under the condition.
......@@ -212,6 +215,7 @@ dragon
dragon/Tensor
dragon/tile
dragon/transpose
dragon/unique
dragon/where
dragon/Workspace
dragon/zeros
......
unique
======
.. autofunction:: dragon.unique
.. raw:: html
<style>
h1:before {
content: "dragon.";
color: #103d3e;
}
</style>
......@@ -161,7 +161,7 @@ Name Supported Reference
`Tile`_ |v| :func:`dragon.tile`
`TopK`_ |v| :func:`dragon.math.top_k`
`Transpose`_ |v| :func:`dragon.transpose`
`Unique`_
`Unique`_ |v| :func:`dragon.unique`
`Unsqueeze`_ |v| :func:`dragon.unsqueeze`
`Upsample`_ |v| :func:`dragon.vision.resize`
`Where`_ |v| :func:`dragon.where`
......
......@@ -93,6 +93,12 @@ vm.tensorflow
`transpose(...) <tensorflow/transpose.html>`_
: Permute the dimensions of input.
`unique(...) <tensorflow/unique.html>`_
: Return the unique elements of input.
`unique_with_counts(...) <tensorflow/unique_with_counts.html>`_
: Return the unique elements of input with counts.
`zeros(...) <tensorflow/zeros.html>`_
: Return a tensor filled with zeros.
......@@ -130,6 +136,8 @@ vm.tensorflow
tensorflow/TensorShape
tensorflow/TensorSpec
tensorflow/transpose
tensorflow/unique
tensorflow/unique_with_counts
tensorflow/zeros
tensorflow/zeros_like
......
unique
======
.. autofunction:: dragon.vm.tensorflow.unique
.. raw:: html
<style>
h1:before {
content: "tf.";
color: #103d3e;
}
</style>
unique_with_counts
==================
.. autofunction:: dragon.vm.tensorflow.unique_with_counts
.. raw:: html
<style>
h1:before {
content: "tf.";
color: #103d3e;
}
</style>
......@@ -235,6 +235,9 @@ vm.torch
`topk(...) <torch/topk.html>`_
: Return the top-K largest or smallest elements along the given dimension.
`unique(...) <torch/unique.html>`_
: Return the unique elements of input.
`unsqueeze(...) <torch/unsqueeze.html>`_
: Expand the dimensions of input with size 1.
......@@ -325,6 +328,7 @@ vm.torch
torch/Tensor_
torch/tensor
torch/topk
torch/unique
torch/unsqueeze
torch/where
torch/zeros_like
......
......@@ -425,6 +425,10 @@ sub\_
#####
.. automethod:: dragon.vm.torch.Tensor.sub_
to
##
.. automethod:: dragon.vm.torch.Tensor.to
topk
####
.. automethod:: dragon.vm.torch.Tensor.topk
......@@ -437,6 +441,10 @@ uniform\_
#########
.. automethod:: dragon.vm.torch.Tensor.uniform_
unique
######
.. automethod:: dragon.vm.torch.Tensor.unique
unsqueeze
#########
.. automethod:: dragon.vm.torch.Tensor.unsqueeze
......@@ -500,6 +508,7 @@ zero\_
.. _torch.sub(...): sub.html
.. _torch.sum(...): sum.html
.. _torch.topk(...): topk.html
.. _torch.unique(...): unique.html
.. _torch.unsqueeze(...): unsqueeze.html
.. _torch.where(...): where.html
......
......@@ -18,6 +18,10 @@ apply
#####
.. automethod:: dragon.vm.torch.nn.Module.apply
buffers
#######
.. automethod:: dragon.vm.torch.nn.Module.buffers
children
########
.. automethod:: dragon.vm.torch.nn.Module.children
......@@ -58,6 +62,10 @@ modules
#######
.. automethod:: dragon.vm.torch.nn.Module.modules
named_buffers
#############
.. automethod:: dragon.vm.torch.nn.Module.named_buffers
named_children
##############
.. automethod:: dragon.vm.torch.nn.Module.named_children
......
unique
======
.. autofunction:: dragon.vm.torch.unique
.. raw:: html
<style>
h1:before {
content: "torch.";
color: #103d3e;
}
</style>
......@@ -16,6 +16,7 @@ void GraphOptimizer::BuildDAG(const GraphDef& graph) {
reference_count_[in] += 1;
}
for (const auto& out : op.output()) {
if (out.empty()) continue;
if (op.input().empty()) {
nodes_[""].childs.push_back(out);
nodes_[out].parents.push_back("");
......
#ifdef USE_CUDA
#include "dragon/core/context_cuda.h"
#include "dragon/utils/op_kernels.h"
namespace dragon {
......
#include "dragon/utils/math_functions.h"
#include "dragon/utils/op_kernels.h"
namespace dragon {
namespace kernel {
namespace {
template <typename T>
void _Unique(
const int dim,
const T* x,
T* y,
int64_t* inverse_index,
int64_t* counts,
int* num) {
vec32_t order(dim);
std::iota(order.begin(), order.end(), 0);
std::sort(order.begin(), order.end(), [x](const int i, const int j) {
return x[i] < x[j];
});
  // Single pass over the sorted order to emit values, counts and inverse.
  int n = 0, m = 0;
T prev = -1;
for (int i = 0; i < dim; ++i) {
if (i == 0 || prev != x[order[i]]) {
if (counts && i > 0) counts[n - 1] = m;
prev = y[n++] = x[order[i]];
m = 1;
} else {
m += 1;
}
if (inverse_index) {
inverse_index[order[i]] = n - 1;
}
}
num[0] = n;
if (counts) counts[n - 1] = m;
}
} // namespace
template <>
void Unique<float16, CPUContext>(
const int dim,
const float16* x,
float16* y,
int64_t* inverse_index,
int64_t* counts,
int* num,
CPUContext* ctx) {
CPU_FP16_NOT_SUPPORTED;
}
#define DEFINE_KERNEL_LAUNCHER(T) \
template <> \
void Unique<T, CPUContext>( \
const int dim, \
const T* x, \
T* y, \
int64_t* inverse_index, \
int64_t* counts, \
int* num, \
CPUContext* ctx) { \
_Unique(dim, x, y, inverse_index, counts, num); \
}
DEFINE_KERNEL_LAUNCHER(int8_t);
DEFINE_KERNEL_LAUNCHER(uint8_t);
DEFINE_KERNEL_LAUNCHER(int);
DEFINE_KERNEL_LAUNCHER(int64_t);
DEFINE_KERNEL_LAUNCHER(float);
DEFINE_KERNEL_LAUNCHER(double);
} // namespace kernel
} // namespace dragon
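As a reading aid, a NumPy sketch of the single-pass algorithm used by the CPU kernel above (an illustration, not part of the commit):

```python
import numpy as np

def unique_cpu(x):
    """Mimic kernel::Unique<T, CPUContext>: sort indices, then one pass."""
    order = np.argsort(x, kind='stable')  # duplicates become adjacent
    y, counts = [], []
    inverse = np.empty(len(x), np.int64)
    for i in order:
        if not y or y[-1] != x[i]:        # a new unique value starts here
            y.append(x[i])
            counts.append(0)
        counts[-1] += 1
        inverse[i] = len(y) - 1           # map original position -> unique id
    return np.array(y), inverse, np.array(counts)

print(unique_cpu(np.array([1, 1, 3, 5, 5, 7, 9])))
# (array([1, 3, 5, 7, 9]), array([0, 0, 1, 2, 2, 3, 4]), array([2, 1, 2, 1, 1]))
```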
#ifdef USE_CUDA
#include "dragon/core/context_cuda.h"
#include "dragon/utils/device/common_thrust.h"
#include "dragon/utils/math_functions.h"
#include "dragon/utils/op_kernels.h"
namespace dragon {
namespace kernel {
namespace {
__global__ void _RemapInverse(
const int dim,
const int num,
thrust::device_ptr<int> order1,
thrust::device_ptr<int> order2,
int64_t* inverse_index) {
const int yi = blockDim.x * blockIdx.x + threadIdx.x;
if (yi >= num) return;
int xi = order2[yi];
inverse_index[order1[xi]] = yi;
for (xi++; xi < dim && (yi == num - 1 || xi != order2[yi + 1]); xi++) {
inverse_index[order1[xi]] = yi;
}
}
__global__ void _ComputeCounts(
const int dim,
const int num,
thrust::device_ptr<int> order2,
int64_t* counts) {
const int yi = blockDim.x * blockIdx.x + threadIdx.x;
if (yi >= num) return;
counts[yi] = (yi == num - 1 ? dim : order2[yi + 1]) - order2[yi];
}
} // namespace
/* ------------------- Launcher Separator ------------------- */
template <>
void Unique<float16, CUDAContext>(
const int dim,
const float16* x,
float16* y,
int64_t* inverse_index,
int64_t* counts,
int* num,
CUDAContext* ctx) {
LOG(FATAL) << "FP16 is unsupported for CUDAContext.";
}
#define DEFINE_KERNEL_LAUNCHER(T) \
template <> \
void Unique<T, CUDAContext>( \
const int dim, \
const T* x, \
T* y, \
int64_t* inverse_index, \
int64_t* counts, \
int* num, \
CUDAContext* ctx) { \
math::Copy(dim, x, y, ctx); \
auto policy = thrust::cuda::par.on(ctx->cuda_stream()); \
thrust::device_vector<int> order1(dim), order2(dim); \
thrust::sequence(policy, order1.begin(), order1.end()); \
thrust::sequence(policy, order2.begin(), order2.end()); \
thrust::sort_by_key(policy, y, y + dim, order1.begin()); \
auto last = thrust::unique_by_key(policy, y, y + dim, order2.begin()); \
int n = num[0] = last.first - y; \
if (inverse_index) { \
_RemapInverse<<<CUDA_BLOCKS(n), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
dim, n, order1.data(), order2.data(), inverse_index); \
} \
if (counts) { \
_ComputeCounts<<<CUDA_BLOCKS(n), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
dim, n, order2.data(), counts); \
} \
}
DEFINE_KERNEL_LAUNCHER(int8_t);
DEFINE_KERNEL_LAUNCHER(uint8_t);
DEFINE_KERNEL_LAUNCHER(int);
DEFINE_KERNEL_LAUNCHER(int64_t);
DEFINE_KERNEL_LAUNCHER(float);
DEFINE_KERNEL_LAUNCHER(double);
#undef DEFINE_KERNEL_LAUNCHER
} // namespace kernel
} // namespace dragon
#endif // USE_CUDA
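To see what the two small kernels recover from thrust's bookkeeping, a NumPy re-enactment of the launcher pipeline (illustrative; after the sort ``order1`` holds original indices, and after ``unique_by_key`` ``order2`` holds each group's start offset in the sorted array):

```python
import numpy as np

x = np.array([1, 2, 3, 2])
order1 = np.argsort(x, kind='stable')               # [0, 1, 3, 2]
y_sorted = x[order1]                                # [1, 2, 2, 3]
y, order2 = np.unique(y_sorted, return_index=True)  # y=[1,2,3], order2=[0,1,3]
starts = np.append(order2, len(x))
counts = np.diff(starts)                            # _ComputeCounts: [1, 2, 1]
inverse = np.empty(len(x), np.int64)
for j in range(len(y)):                             # _RemapInverse, one group per j
    inverse[order1[starts[j]:starts[j + 1]]] = j
print(y, inverse, counts)                           # [1 2 3] [0 1 2 1] [1 2 1]
```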
#ifdef USE_CUDA
#include "dragon/core/context_cuda.h"
#include "dragon/utils/device/common_cub.h"
#include "dragon/utils/op_kernels.h"
namespace dragon {
......
......@@ -2,6 +2,7 @@
#include "dragon/core/context_cuda.h"
#include "dragon/utils/device/common_cub.h"
#include "dragon/utils/math_functions.h"
#include "dragon/utils/op_kernels.h"
namespace dragon {
......
#include "dragon/operators/array/unique_op.h"
#include "dragon/core/workspace.h"
#include "dragon/utils/op_kernels.h"
namespace dragon {
template <class Context>
template <typename T>
void UniqueOp<Context>::DoRunWithType() {
auto &X = Input(0), *Y = Output(0);
int num_values;
Tensor *X_index = nullptr, *Y_counts = nullptr;
int64_t *inverse_index = nullptr, *counts = nullptr;
if (OutputSize() == 2) {
if (return_inverse_) {
X_index = Output(1)->ReshapeLike(X);
inverse_index = X_index->template mutable_data<int64_t, Context>();
} else if (return_counts_) {
Y_counts = Output(1)->ReshapeLike(X);
counts = Y_counts->template mutable_data<int64_t, Context>();
}
} else if (OutputSize() == 3) {
X_index = Output(1)->ReshapeLike(X);
Y_counts = Output(2)->ReshapeLike(X);
inverse_index = X_index->template mutable_data<int64_t, Context>();
counts = Y_counts->template mutable_data<int64_t, Context>();
}
kernel::Unique(
X.count(),
X.template data<T, Context>(),
Y->ReshapeLike(X)->template mutable_data<T, Context>(),
inverse_index,
counts,
&num_values,
ctx());
// Shrink to match the number of values
Y->Reshape({num_values});
if (Y_counts) Y_counts->Reshape({num_values});
}
template <class Context>
void UniqueOp<Context>::RunOnDevice() {
DispatchHelper<NumericalTensorTypes>::Call(this, Input(0));
}
DEPLOY_CPU_OPERATOR(Unique);
#ifdef USE_CUDA
DEPLOY_CUDA_OPERATOR(Unique);
#endif
OPERATOR_SCHEMA(Unique)
/* X */
.NumInputs(1)
/* Y, InverseIndex, Counts */
.NumOutputs(1, 3);
NO_GRADIENT(Unique);
} // namespace dragon
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_ARRAY_UNIQUE_OP_H_
#define DRAGON_OPERATORS_ARRAY_UNIQUE_OP_H_
#include "dragon/core/operator.h"
namespace dragon {
template <class Context>
class UniqueOp final : public Operator<Context> {
public:
UniqueOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
return_inverse_(OP_SINGLE_ARG(int64_t, "return_inverse", 0)),
return_counts_(OP_SINGLE_ARG(int64_t, "return_counts", 0)) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
template <typename T>
void DoRunWithType();
protected:
int64_t return_inverse_, return_counts_;
};
} // namespace dragon
#endif // DRAGON_OPERATORS_ARRAY_UNIQUE_OP_H_
......@@ -479,7 +479,6 @@ void CuDNNConv2dGradientOp<Context>::DoRunWithType() {
// Determine the workspace size for selected algorithm
if (cudnn_ws_nbytes_ == SIZE_MAX) {
size_t bwd_filter_size = 0, bwd_data_size = 0;
if (dW->has_name()) {
CUDNN_CHECK(cudnnGetConvolutionBackwardFilterWorkspaceSize(
ctx()->cudnn_handle(),
output_desc_,
......@@ -488,8 +487,6 @@ void CuDNNConv2dGradientOp<Context>::DoRunWithType() {
filter_desc_,
bwd_filter_algo_,
&bwd_filter_size));
}
if (dX->has_name()) {
CUDNN_CHECK(cudnnGetConvolutionBackwardDataWorkspaceSize(
ctx()->cudnn_handle(),
filter_desc_,
......@@ -498,7 +495,6 @@ void CuDNNConv2dGradientOp<Context>::DoRunWithType() {
output_desc_,
bwd_data_algo_,
&bwd_data_size));
}
cudnn_ws_nbytes_ = std::max(bwd_filter_size, bwd_data_size);
}
......
......@@ -474,7 +474,6 @@ void CuDNNConvTranspose2dGradientOp<Context>::DoRunWithType() {
// Determine the workspace size for selected algorithm
if (cudnn_ws_nbytes_ == SIZE_MAX) {
size_t bwd_filter_size = 0, bwd_data_size = 0;
if (dW->has_name()) {
CUDNN_CHECK(cudnnGetConvolutionBackwardFilterWorkspaceSize(
ctx()->cudnn_handle(),
input_desc_,
......@@ -483,8 +482,6 @@ void CuDNNConvTranspose2dGradientOp<Context>::DoRunWithType() {
filter_desc_,
bwd_filter_algo_,
&bwd_filter_size));
}
if (dX->has_name()) {
CUDNN_CHECK(cudnnGetConvolutionForwardWorkspaceSize(
ctx()->cudnn_handle(),
input_desc_,
......@@ -493,7 +490,6 @@ void CuDNNConvTranspose2dGradientOp<Context>::DoRunWithType() {
output_desc_,
bwd_data_algo_,
&bwd_data_size));
}
cudnn_ws_nbytes_ = std::max(bwd_filter_size, bwd_data_size);
}
......@@ -514,7 +510,7 @@ void CuDNNConvTranspose2dGradientOp<Context>::DoRunWithType() {
db));
}
if (Output(1)->has_name()) {
if (dW->has_name()) {
x = X.template data<T, Context>();
dw = dW->template mutable_data<T, Context>();
for (int g = 0; g < cudnn_group_; g++) {
......@@ -535,7 +531,7 @@ void CuDNNConvTranspose2dGradientOp<Context>::DoRunWithType() {
}
}
if (Output(0)->has_name()) {
if (dX->has_name()) {
auto* w = W.template data<T, Context>();
auto* dx = dX->template mutable_data<T, Context>();
for (int g = 0; g < cudnn_group_; g++) {
......
......@@ -76,6 +76,7 @@ from dragon.core.ops.array_ops import squeeze
from dragon.core.ops.array_ops import stack
from dragon.core.ops.array_ops import tile
from dragon.core.ops.array_ops import transpose
from dragon.core.ops.array_ops import unique
from dragon.core.ops.array_ops import where
from dragon.core.ops.control_flow_ops import assign
from dragon.core.ops.control_flow_ops import copy
......
......@@ -1045,3 +1045,28 @@ def unchanged_spec(args, inputs, outputs):
except TypeError:
pass
return outputs
@register('Unique')
def unique_spec(args, inputs, outputs):
return_inverse = args['return_inverse']
return_counts = args['return_counts']
outputs[0].dtype = inputs[0].dtype
for i in range(1, len(outputs)):
outputs[i].dtype = 'int64'
outputs[0].shape = (None,)
if len(outputs) == 2:
if return_inverse:
try:
outputs[1].shape = inputs[0].shape[:]
except TypeError:
pass
elif return_counts:
outputs[1].shape = (None,)
elif len(outputs) == 3:
try:
outputs[1].shape = inputs[0].shape[:]
except TypeError:
pass
outputs[2].shape = (None,)
return outputs
......@@ -962,8 +962,7 @@ def pad(inputs, pads, mode='constant', value=0, **kwargs):
if len(pad) != 2:
raise ValueError(
'The tuple length of <pads> '
'should be 2, got {}.'.format(len(pad))
)
'should be 2, got {}.'.format(len(pad)))
pads_begin.append(pad[0])
pads_end.append(pad[1])
args['pads'] = pads_begin + pads_end
......@@ -1562,6 +1561,59 @@ def top_k(inputs, k=1, axis=None, largest=True, sorted=True, **kwargs):
return op_lib.blend(**args)
@OpSchema.num_inputs(1)
def unique(inputs, return_inverse=False, return_counts=False, **kwargs):
"""Return the unique elements of input.
If ``return_inverse``, also return the index mapping each input element to the unique values:
```python
x = dragon.constant([1, 2, 3, 2])
y, index = dragon.unique(x, return_inverse=True)
print(y) # [1, 2, 3]
print(index) # [0, 1, 2, 1]
```
If ``return_counts``, also return the counts of each unique value:
```python
x = dragon.constant([1, 2, 3, 2])
y, counts = dragon.unique(x, return_counts=True)
print(y) # [1, 2, 3]
print(counts) # [1, 2, 1]
```
Parameters
----------
inputs : dragon.Tensor
The input tensor.
return_inverse : bool, optional, default=False
Return the inverse index or not.
return_counts : bool, optional, default=False
Return the counts or not.
Returns
-------
dragon.Tensor
The output tensor.
dragon.Tensor, optional
The inverse index tensor.
dragon.Tensor, optional
The counts tensor.
"""
args = parse_args(locals())
op_lib = array_ops_lib.Unique
if context.executing_eagerly():
return op_lib.instantiate(
return_inverse=return_inverse,
return_counts=return_counts,
).apply([inputs])
else:
num_outputs = 1 + return_inverse + return_counts
return op_lib.blend(num_outputs=num_outputs, **args)
@OpSchema.num_inputs(1, 3)
def where(inputs, **kwargs):
r"""Select the elements from two branches under the condition.
......
......@@ -665,6 +665,27 @@ class TopK(Operator):
return self.dispatch(inputs, [self.alloc(), self.alloc()], no_grad=True)
class Unique(Operator):
def __init__(self, key, dev, **kwargs):
super(Unique, self).__init__(key, dev, **kwargs)
self.return_inverse = kwargs.get('return_inverse', False)
self.return_counts = kwargs.get('return_counts', False)
self.num_outputs = 1 + self.return_inverse + self.return_counts
def attributes(self):
return {
'op_type': 'Unique',
'arguments': {
'return_inverse': self.return_inverse,
'return_counts': self.return_counts,
}
}
def forward(self, inputs):
outputs = [self.alloc() for _ in range(self.num_outputs)]
return self.dispatch(inputs, outputs)
class Where(Operator):
def __init__(self, key, dev, **kwargs):
super(Where, self).__init__(key, dev, **kwargs)
......
......@@ -322,13 +322,11 @@ def _get_cuda_arch_flags(cflags=None):
for flag in cflags:
if 'arch' in flag:
return []
supported_arches = [
'3.5', '3.7',
supported_arches = ['3.5', '3.7',
'5.0', '5.2', '5.3',
'6.0', '6.1', '6.2',
'7.0', '7.2', '7.5',
'8.0',
]
'8.0']
valid_arch_strings = supported_arches + [s + "+PTX" for s in supported_arches]
capability = _cuda.get_device_capability()
arch_list = ['{}.{}'.format(capability[0], capability[1])]
......
......@@ -89,7 +89,7 @@ class DataIterator(object):
cutout_size : int, optional, default=0
The square size for the cutout algorithm.
mirror : bool, optional, default=False
Whether to mirror(flip horizontally) images.
Whether to apply the mirror (flip horizontally).
random_scales : Sequence[float], optional, default=(0.08, 1.)
The range of scales to sample a crop randomly.
random_aspect_ratios : Sequence[float], optional, default=(0.75, 1.33)
......@@ -133,7 +133,7 @@ class DataIterator(object):
if kwargs.get('random_crop_size', 0) > 0:
self._num_transformers += 1
# Add a transformer for distortion.
if kwargs.get('augment_color', False):
if kwargs.get('distort_color', False):
self._num_transformers += 1
# Initialize queues.
......
......@@ -13,19 +13,13 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import io
import math
import multiprocessing
import numpy
try:
import cv2
except ImportError:
cv2 = None
try:
import PIL.Image
import PIL.ImageEnhance
except ImportError:
PIL = None
import PIL.Image
import PIL.ImageEnhance
from dragon.core.framework import config
......@@ -51,7 +45,7 @@ class DataTransformer(multiprocessing.Process):
cutout_size : int, optional, default=0
The square size for the cutout algorithm.
mirror : bool, optional, default=False
Whether to mirror(flip horizontally) images.
Whether to apply the mirror (flip horizontally).
random_scales : Sequence[float], optional, default=(0.08, 1.)
The range of scales to sample a crop randomly.
random_aspect_ratios : Sequence[float], optional, default=(0.75, 1.33)
......@@ -82,10 +76,6 @@ class DataTransformer(multiprocessing.Process):
self._seed = kwargs.get('seed', config.config().random_seed)
self.q_in = self.q_out = None
self.daemon = True
if cv2 is None:
raise ImportError('Failed to import package <cv2>.')
if self._distort_color and PIL is None:
raise ImportError('Failed to import package <PIL>.')
def get(self, example):
"""Return image and labels from a serialized str.
......@@ -104,24 +94,26 @@ class DataTransformer(multiprocessing.Process):
"""
# Decode.
img = numpy.frombuffer(example['data'], numpy.uint8)
if example.get('encoded', 0) > 0:
img = cv2.imdecode(img, 1)
if example['encoded'] > 0:
img = PIL.Image.open(io.BytesIO(example['data']))
else:
img = numpy.frombuffer(example['data'], numpy.uint8)
img = img.reshape(example['shape'])
# Resizing.
if self._resize > 0:
(h, w), size = img.shape[:2], self._resize
(w, h), size = img.size, self._resize
if (w <= h and w == size) or (h <= w and h == size):
pass
else:
if w < h:
ow, oh, im_scale = size, size * h // w, float(size) / w
ow, oh = size, size * h // w
else:
oh, ow, im_scale = size, size * w // h, float(size) / h
interp = cv2.INTER_AREA if im_scale < 1 else cv2.INTER_LINEAR
img = cv2.resize(img, (ow, oh), interpolation=interp)
oh, ow = size, size * w // h
img = img.resize((ow, oh), PIL.Image.BILINEAR)
# ToArray.
img = numpy.asarray(img)
# Padding.
if self._padding > 0:
......@@ -152,12 +144,9 @@ class DataTransformer(multiprocessing.Process):
area = height * width
i = j = h = w = None
for attempt in range(10):
target_area = numpy.random.uniform(
*self._random_scales) * area
log_ratio = (
math.log(self._random_ratios[0]),
math.log(self._random_ratios[1]),
)
target_area = numpy.random.uniform(*self._random_scales) * area
log_ratio = (math.log(self._random_ratios[0]),
math.log(self._random_ratios[1]))
aspect_ratio = math.exp(numpy.random.uniform(*log_ratio))
w = int(round(math.sqrt(target_area * aspect_ratio)))
h = int(round(math.sqrt(target_area / aspect_ratio)))
......@@ -179,9 +168,8 @@ class DataTransformer(multiprocessing.Process):
j = (width - w) // 2
img = img[i:i + h, j:j + w, :]
new_size = (self._random_crop_size, self._random_crop_size)
min_scale = self._random_crop_size / max(img.shape[:2])
interp = cv2.INTER_AREA if min_scale < 1 else cv2.INTER_LINEAR
img = cv2.resize(img, new_size, interpolation=interp)
img = PIL.Image.fromarray(img)
img = numpy.asarray(img.resize(new_size, PIL.Image.BILINEAR))
# CutOut.
if self._cutout_size > 0:
......@@ -202,16 +190,14 @@ class DataTransformer(multiprocessing.Process):
# Color distortion.
if self._distort_color:
img = PIL.Image.fromarray(img)
transforms = [
PIL.ImageEnhance.Brightness,
transforms = [PIL.ImageEnhance.Brightness,
PIL.ImageEnhance.Contrast,
PIL.ImageEnhance.Color,
]
PIL.ImageEnhance.Color]
numpy.random.shuffle(transforms)
for transform in transforms:
img = transform(img)
img = img.enhance(1. + numpy.random.uniform(-.4, .4))
img = numpy.array(img)
img = numpy.asarray(img)
# Color transformation.
if self._inverse_color:
......
......@@ -251,6 +251,7 @@ class DragonFrontend(object):
for e in op_def.output:
outputs.append(e + '_%d' % blob_versions[e]
if blob_versions[e] > 0 else e)
if e != '':
blob_versions[e] += 1
blob_names[e] = outputs[-1]
op_def.ClearField('input')
......
......@@ -72,7 +72,6 @@ def softmax_exporter(op_def, shape_dict, ws):
if axis != (ndim - 1):
raise ValueError(
'Softmax axis could only be the last one.\n'
'Use Exp(LogSoftmax) to compute the softmax instead.'
)
'Use Exp(LogSoftmax) to compute the softmax instead.')
helper.add_attribute(node, 'axis', arg.i)
return node, const_tensors
......@@ -494,3 +494,27 @@ def top_k_exporter_v11(op_def, shape_dict, ws):
)
node.input.extend([k.name])
return node, [k]
@exporter.register('Unique')
def unique_exporter(op_def, shape_dict, ws):
node, const_tensors = exporter.translate(**locals())
helper.add_attribute(node, 'sorted', 1)
return_inverse = return_counts = 0
for arg in op_def.arg:
if arg.name == 'return_inverse':
return_inverse = arg.i
elif arg.name == 'return_counts':
return_counts = arg.i
outputs = [op_def.output[0]]
if len(op_def.output) > 1:
outputs.append('')
if len(op_def.output) == 2:
if return_inverse:
outputs.append(op_def.output[1])
elif return_counts:
outputs.extend(['', op_def.output[1]])
elif len(op_def.output) == 3:
outputs.extend([op_def.output[1], op_def.output[2]])
node.output[:] = outputs
return node, const_tensors
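For reference, the output slots this exporter produces per case; ONNX ``Unique`` declares *(Y, indices, inverse_indices, counts)* and Dragon never computes *indices*, so that slot is exported as the empty name (sketch; output names hypothetical):

```python
# unique(x)                                          -> [Y]
# unique(x, return_inverse=True)                     -> [Y, '', InverseIndex]
# unique(x, return_counts=True)                      -> [Y, '', '', Counts]
# unique(x, return_inverse=True, return_counts=True) -> [Y, '', InverseIndex, Counts]
```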
......@@ -4,6 +4,7 @@
#ifdef USE_CUDA
#include <thrust/device_ptr.h>
#include <thrust/device_vector.h>
#include <thrust/execution_policy.h>
#include <thrust/sequence.h>
#include <thrust/sort.h>
......
......@@ -544,8 +544,18 @@ void TopK(
int64_t* index,
Context* ctx);
/* control_flow.assign */
/* array.unique */
template <typename T, class Context>
void Unique(
const int dim,
const T* x,
T* y,
int64_t* inverse_index,
int64_t* counts,
int* num,
Context* ctx);
/* control_flow.assign */
template <typename T, class Context>
void Assign(
const int num_dims,
......
......@@ -85,6 +85,8 @@ from dragon.vm.tensorflow.core.ops.array_ops import split
from dragon.vm.tensorflow.core.ops.array_ops import squeeze
from dragon.vm.tensorflow.core.ops.array_ops import tile
from dragon.vm.tensorflow.core.ops.array_ops import transpose
from dragon.vm.tensorflow.core.ops.array_ops import unique
from dragon.vm.tensorflow.core.ops.array_ops import unique_with_counts
from dragon.vm.tensorflow.core.ops.array_ops import zeros
from dragon.vm.tensorflow.core.ops.array_ops import zeros_like
from dragon.vm.tensorflow.core.ops.clip_ops import clip_by_value
......
......@@ -31,7 +31,7 @@ class Loss(object):
reduction : {'none', 'sum', 'mean', 'valid'}, optional
The reduction method.
name : str, optional
A optional name for the operation.
The operation name.
"""
losses_utils.Reduction.validate(reduction)
......@@ -112,7 +112,7 @@ class BinaryCrossentropy(LossFunctionWrapper):
reduction : {'none', 'sum', 'mean', 'valid'}, optional
The reduction method.
name : str, optional
A optional name for the operation.
The operation name.
"""
super(BinaryCrossentropy, self).__init__(
......@@ -155,7 +155,7 @@ class CategoricalCrossentropy(LossFunctionWrapper):
reduction : {'none', 'sum', 'mean', 'valid'}, optional
The reduction method.
name : str, optional
A optional name for the operation.
The operation name.
"""
super(CategoricalCrossentropy, self).__init__(
......@@ -196,7 +196,7 @@ class MeanAbsoluteError(LossFunctionWrapper):
reduction : {'none', 'sum', 'mean'}, optional
The reduction method.
name : str, optional
A optional name for the operation.
The operation name.
"""
super(MeanAbsoluteError, self).__init__(
......@@ -236,7 +236,7 @@ class MeanSquaredError(LossFunctionWrapper):
reduction : {'none', 'sum', 'mean'}, optional
The reduction method.
name : str, optional
A optional name for the operation.
The operation name.
"""
super(MeanSquaredError, self).__init__(
......@@ -282,7 +282,7 @@ class SparseCategoricalCrossentropy(LossFunctionWrapper):
reduction : {'none', 'sum', 'mean', 'valid'}, optional
The reduction method.
name : str, optional
A optional name for the operation.
The operation name.
"""
super(SparseCategoricalCrossentropy, self).__init__(
......
......@@ -58,7 +58,7 @@ def broadcast_to(input, shape, name=None):
shape : Sequence[Union[int, dragon.Tensor]]
The output shape to broadcast to.
name : str, optional
A optional name for the operation.
The operation name.
Returns
-------
......@@ -95,7 +95,7 @@ def concat(values, axis, name='concat'):
axis : int
The axis to concatenate
name : str, optional
A optional name for the operation.
The operation name.
Returns
-------
......@@ -129,7 +129,7 @@ def depth_to_space(input, block_size, data_format='NHWC', name=None):
data_format : {'NCHW', 'NHWC'}, optional
The optional data format.
name : str, optional
A optional name for the operation.
The operation name.
Returns
-------
......@@ -172,7 +172,7 @@ def expand_dims(input, axis, name=None):
axis : Union[int, Sequence[int]]
The axis to insert the new dimension(s).
name : str, optional
A optional name for the operation.
The operation name.
Returns
-------
......@@ -203,7 +203,7 @@ def fill(dims, value=0, dtype=None, name=None):
dtype : str, optional
The optional data type.
name : str, optional
A optional name for the operation.
The operation name.
Returns
-------
......@@ -248,7 +248,7 @@ def gather(params, indices, axis=0, name=None):
axis : Union[int, Sequence[int]], optional, default=0
The axis where the indices aligned.
name : str, optional
A optional name for the operation.
The operation name.
Returns
-------
......@@ -280,7 +280,7 @@ def identity(input, name=None):
input : dragon.Tensor
The input tensor.
name : str, optional
A optional name for the operation.
The operation name.
Returns
-------
......@@ -307,7 +307,7 @@ def ones(shape, dtype='float32', name=None):
dtype : str, optional, default='float32'
The optional data type.
name : str, optional
A optional name for the operation.
The operation name.
"""
return init_ops.fill(shape, value=1, dtype=dtype, name=name)
......@@ -332,7 +332,7 @@ def ones_like(input, dtype='float32', name=None):
dtype : str, optional, default='float32'
The optional data type.
name : str, optional
A optional name for the operation.
The operation name.
"""
return init_ops.ones_like(input, dtype=dtype, name=name)
......@@ -378,7 +378,7 @@ def one_hot(
off_value : int, optional, default=0
The value for not-equal branch.
name : str, optional
A optional name for the operation.
The operation name.
Returns
-------
......@@ -439,7 +439,7 @@ def pad(
constant_values : int, optional, default=0
The constant value in ``CONSTANT`` mode.
name : str, optional
A optional name for the operation.
The operation name.
Returns
-------
......@@ -515,7 +515,7 @@ def reshape(tensor, shape, name=None):
shape : Union[Sequence[int], dragon.Tensor]
The output shape.
name : str, optional
A optional name for the operation.
The operation name.
"""
return array_ops.reshape(tensor, shape=shape, name=name)
......@@ -537,7 +537,7 @@ def shape(input, name=None):
input : dragon.Tensor
The input tensor.
name : str, optional
A optional name for the operation.
The operation name.
Returns
-------
......@@ -580,7 +580,7 @@ def slice(input_, begin, size, name=None):
size : Union[Sequence[int], dragon.Tensor]
The number of elements sliced from start.
name : str, optional
A optional name for the operation.
The operation name.
Returns
-------
......@@ -614,7 +614,7 @@ def space_to_depth(input, block_size, data_format='NHWC', name=None):
data_format : {'NCHW', 'NHWC'}, optional
The optional data format.
name : str, optional
A optional name for the operation.
The operation name.
Returns
-------
......@@ -665,7 +665,7 @@ def split(
axis : int, optional, default=0
The axis to split.
name : str, optional
A optional name for the operation.
The operation name.
Returns
-------
......@@ -704,7 +704,7 @@ def squeeze(input, axis=None, name=None):
axis : Union[int, Sequence[int]], optional
The axis to remove.
name : str, optional
A optional name for the operation.
The operation name.
Returns
-------
......@@ -740,7 +740,7 @@ def transpose(a, perm=None, name=None):
perm : Sequence[Union[int, dragon.Tensor]]
The output permutation.
name : str, optional
A optional name for the operation.
The operation name.
Returns
-------
......@@ -751,6 +751,78 @@ def transpose(a, perm=None, name=None):
return array_ops.transpose(a, perm=perm, name=name)
def unique(x, name=None, **kwargs):
"""Return the unique elements of input.
The unique elements, and the index mapping each input element to them, are returned:
```python
x = tf.constant([1, 2, 3, 2])
y, index = tf.unique(x)
print(y) # [1, 2, 3]
print(index) # [0, 1, 2, 1]
```
Parameters
----------
x : dragon.Tensor
The input tensor.
name : str, optional
The operation name.
Returns
-------
dragon.Tensor
The output tensor.
dragon.Tensor
The inverse index tensor.
"""
if 'out_idx' in kwargs:
kwargs.pop('out_idx')
return array_ops.unique(x, return_inverse=True, name=name)
def unique_with_counts(x, name=None, **kwargs):
"""Return the unique elements of input with counts.
The unique elements, the remapping index, and the counts are returned:
```python
x = tf.constant([1, 2, 3, 2])
y, index, counts = tf.unique_with_counts(x)
print(y) # [1, 2, 3]
print(index) # [0, 1, 2, 1]
print(counts) # [1, 2, 1]
```
Parameters
----------
x : dragon.Tensor
The input tensor.
name : str, optional
The operation name.
Returns
-------
dragon.Tensor
The output tensor.
dragon.Tensor
The inverse index tensor.
dragon.Tensor
The counts tensor.
"""
if 'out_idx' in kwargs:
kwargs.pop('out_idx')
return array_ops.unique(
x,
return_inverse=True,
return_counts=True,
name=name,
)
def zeros(shape, dtype='float32', name=None):
r"""Return a tensor filled with zeros.
......@@ -767,7 +839,7 @@ def zeros(shape, dtype='float32', name=None):
dtype : str, optional, default='float32'
The optional data type.
name : str, optional
A optional name for the operation.
The operation name.
"""
return init_ops.fill(shape, value=0., dtype=dtype, name=name)
......@@ -792,7 +864,7 @@ def zeros_like(input, dtype='float32', name=None):
dtype : str, optional, default='float32'
The optional data type.
name : str, optional
A optional name for the operation.
The operation name.
"""
return init_ops.zeros_like(input, dtype=dtype, name=name)
......@@ -42,7 +42,7 @@ def bitwise_and(x, y, name=None):
y : dragon.Tensor
The input2 tensor.
name : str, optional
A optional name for the operation.
The operation name.
Returns
-------
......@@ -74,7 +74,7 @@ def bitwise_or(x, y, name=None):
y : dragon.Tensor
The input2 tensor.
name : str, optional
A optional name for the operation.
The operation name.
Returns
-------
......@@ -106,7 +106,7 @@ def bitwise_xor(x, y, name=None):
y : dragon.Tensor
The input2 tensor.
name : str, optional
A optional name for the operation.
The operation name.
Returns
-------
......@@ -139,7 +139,7 @@ def invert(x, name=None):
x : dragon.Tensor
The input tensor.
name : str, optional
A optional name for the operation.
The operation name.
Returns
-------
......
......@@ -47,7 +47,7 @@ def clip_by_value(
clip_value_max : number, optional
The value to :math:`\text{high}`.
name : str, optional
A optional name for the operation.
The operation name.
Returns
-------
......
......@@ -42,7 +42,7 @@ def eye(num_rows, num_columns=None, dtype='float32', name=None):
dtype : str, optional, default='float32'
The optional data type.
name : str, optional
A optional name for the operation.
The operation name.
Returns
-------
......
......@@ -39,7 +39,7 @@ def abs(x, name=None):
x : dragon.Tensor
The input tensor.
name : str, optional
A optional name for the operation.
The operation name.
Returns
-------
......@@ -69,7 +69,7 @@ def add(x, y, name=None):
y : dragon.Tensor
The input2 tensor.
name : str, optional
A optional name for the operation.
The operation name.
Returns
-------
......@@ -97,7 +97,7 @@ def add_n(inputs, name=None):
inputs : Sequence[dragon.Tensor]
The input tensors.
name : str, optional
A optional name for the operation.
The operation name.
Returns
-------
......@@ -138,7 +138,7 @@ def argmax(input, axis=None, name=None):
axis : int, optional
The axis to reduce.
name : str, optional
A optional name for the operation.
The operation name.
Returns
-------
......@@ -173,7 +173,7 @@ def argmin(input, axis=None, name=None):
axis : int, optional
The axis to reduce.
name : str, optional
A optional name for the operation.
The operation name.
Returns
-------
......@@ -201,7 +201,7 @@ def cast(x, dtype, name=None):
dtype : str
The data type to cast to.
name : str, optional
A optional name for the operation.
The operation name.
Returns
-------
......@@ -229,7 +229,7 @@ def ceil(x, name=None):
x : dragon.Tensor
The input tensor.
name : str, optional
A optional name for the operation.
The operation name.
Returns
-------
......@@ -257,7 +257,7 @@ def cos(x, name=None):
x : dragon.Tensor
The input tensor.
name : str, optional
A optional name for the operation.
The operation name.
Returns
-------
......@@ -308,7 +308,7 @@ def cumsum(x, axis=0, exclusive=False, reverse=False, name=None):
reverse : bool, optional, default=False
**True** to compute in the reverse direction.
name : str, optional
A optional name for the operation.
The operation name.
Returns
-------
......@@ -346,7 +346,7 @@ def divide(x, y, name=None):
y : dragon.Tensor
The input2 tensor.
name : str, optional
A optional name for the operation.
The operation name.
Returns
-------
......@@ -379,7 +379,7 @@ def equal(x, y, name=None):
y : dragon.Tensor
The input2 tensor.
name : str, optional
A optional name for the operation.
The operation name.
Returns
-------
......@@ -407,7 +407,7 @@ def exp(x, name=None):
x : dragon.Tensor
The input tensor.
name : str, optional
A optional name for the operation.
The operation name.
Returns
-------
......@@ -435,7 +435,7 @@ def floor(x, name=None):
x : dragon.Tensor
The input tensor.
name : str, optional
A optional name for the operation.
The operation name.
Returns
-------
......@@ -468,7 +468,7 @@ def greater(x, y, name=None):
y : dragon.Tensor
The input2 tensor.
name : str, optional
A optional name for the operation.
The operation name.
Returns
-------
......@@ -501,7 +501,7 @@ def greater_equal(x, y, name=None):
y : dragon.Tensor
The input2 tensor.
name : str, optional
A optional name for the operation.
The operation name.
Returns
-------
......@@ -529,7 +529,7 @@ def is_inf(x, name=None):
x : dragon.Tensor
The input tensor.
name : str, optional
A optional name for the operation.
The operation name.
Returns
-------
......@@ -557,7 +557,7 @@ def is_nan(x, name=None):
x : dragon.Tensor
The input tensor.
name : str, optional
A optional name for the operation.
The operation name.
Returns
-------
......@@ -590,7 +590,7 @@ def less(x, y, name=None):
y : dragon.Tensor
The input2 tensor.
name : str, optional
A optional name for the operation.
The operation name.
Returns
-------
......@@ -623,7 +623,7 @@ def less_equal(x, y, name=None):
y : dragon.Tensor
The input2 tensor.
name : str, optional
A optional name for the operation.
The operation name.
Returns
-------
......@@ -651,7 +651,7 @@ def log(x, name=None):
x : dragon.Tensor
The input tensor.
name : str, optional
A optional name for the operation.
The operation name.
Returns
-------
......@@ -707,7 +707,7 @@ def matmul(
transpose_b : bool, optional, default=False
**True** to transpose :math:`b` before computing.
name : str, optional
A optional name for the operation.
The operation name.
Returns
-------
......@@ -744,7 +744,7 @@ def multiply(x, y, name=None):
y : dragon.Tensor
The input2 tensor.
name : str, optional
A optional name for the operation.
The operation name.
Returns
-------
......@@ -770,7 +770,7 @@ def negative(x, name=None):
x : dragon.Tensor
The input tensor.
name : str, optional
A optional name for the operation.
The operation name.
Returns
-------
......@@ -803,7 +803,7 @@ def not_equal(x, y, name=None):
y : dragon.Tensor
The input2 tensor.
name : str, optional
A optional name for the operation.
The operation name.
Returns
-------
......@@ -835,7 +835,7 @@ def pow(x, y, name=None):
y : Union[dragon.Tensor, number]
The exponent tensor.
name : str, optional
A optional name for the operation.
The operation name.
Returns
-------
......@@ -878,7 +878,7 @@ def range(start, limit=None, delta=1, dtype='int64', name=None):
dtype : str, optional, default='int64'
The optional data type.
name : str, optional
A optional name for the operation.
The operation name.
Returns
-------
......@@ -912,7 +912,7 @@ def reciprocal(x, name=None):
x : dragon.Tensor
The input tensor.
name : str, optional
A optional name for the operation.
The operation name.
Returns
-------
......@@ -949,7 +949,7 @@ def reduce_max(input_tensor, axis=None, keepdims=False, name=None):
keepdims : bool, optional, default=False
Keep the reduced dimensions or not.
name : str, optional
A optional name for the operation.
The operation name.
Returns
-------
......@@ -994,7 +994,7 @@ def reduce_mean(input_tensor, axis=None, keepdims=False, name=None):
keepdims : bool, optional, default=False
Keep the reduced dimensions or not.
name : str, optional
A optional name for the operation.
The operation name.
Returns
-------
......@@ -1036,7 +1036,7 @@ def reduce_min(input_tensor, axis=None, keepdims=False, name=None):
keepdims : bool, optional, default=False
Keep the reduced dimensions or not.
name : str, optional
A optional name for the operation.
The operation name.
Returns
-------
......@@ -1081,7 +1081,7 @@ def reduce_sum(input_tensor, axis=None, keepdims=False, name=None):
keepdims : bool, optional, default=False
Keep the reduced dimensions or not.
name : str, optional
A optional name for the operation.
The operation name.
Returns
-------
......@@ -1114,7 +1114,7 @@ def round(x, name=None):
x : dragon.Tensor
The input tensor.
name : str, optional
A optional name for the operation.
The operation name.
Returns
-------
......@@ -1142,7 +1142,7 @@ def rsqrt(x, name=None):
x : dragon.Tensor
The input tensor.
name : str, optional
A optional name for the operation.
The operation name.
Returns
-------
......@@ -1170,7 +1170,7 @@ def sigmoid(x, name=None, **kwargs):
x : dragon.Tensor
The input tensor.
name : str, optional
A optional name for the operation.
The operation name.
Returns
-------
......@@ -1204,7 +1204,7 @@ def sign(x, name=None):
x : dragon.Tensor
The input tensor.
name : str, optional
A optional name for the operation.
The operation name.
Returns
-------
......@@ -1232,7 +1232,7 @@ def sin(x, name=None):
x : dragon.Tensor
The input tensor.
name : str, optional
A optional name for the operation.
The operation name.
Returns
-------
......@@ -1260,7 +1260,7 @@ def sqrt(x, name=None):
x : dragon.Tensor
The input tensor.
name : str, optional
A optional name for the operation.
The operation name.
Returns
-------
......@@ -1288,7 +1288,7 @@ def square(x, name=None):
x : dragon.Tensor
The input tensor.
name : str, optional
A optional name for the operation.
The operation name.
Returns
-------
......@@ -1320,7 +1320,7 @@ def subtract(x, y, name=None):
y : dragon.Tensor
The input2 tensor.
name : str, optional
A optional name for the operation.
The operation name.
Returns
-------
......@@ -1349,7 +1349,7 @@ def tanh(x, name=None, **kwargs):
x : dragon.Tensor
The input tensor.
name : str, optional
A optional name for the operation.
The operation name.
Returns
-------
......
......@@ -64,7 +64,7 @@ def batch_normalization(
trainable : bool, optional, default=False
The optional training flag.
name : str, optional
A optional name for the operation.
The operation name.
Returns
-------
......@@ -119,7 +119,7 @@ def l2_normalize(x, axis=None, epsilon=1e-12, name=None):
epsilon : float, optional, default=1e-5
The value to :math:`\epsilon`.
name : str, optional
A optional name for the operation.
The operation name.
Returns
-------
......@@ -171,7 +171,7 @@ def moments(x, axes=None, keepdims=False, name=None):
keepdims : bool, optional, default=False
Keep the reduced dimensions or not.
name : str, optional
A optional name for the operation.
The operation name.
Returns
-------
......
......@@ -49,7 +49,7 @@ def avg_pool(
data_format : {'NCHW', 'NCDHW', 'NHWC', 'NDHWC'}, optional
The optional data format.
name : str, optional
A optional name for the operation.
The operation name.
Returns
-------
......@@ -109,7 +109,7 @@ def avg_pool2d(
data_format : {'NCHW', 'NCDHW', 'NHWC', 'NDHWC'}, optional
The optional data format.
name : str, optional
A optional name for the operation.
The operation name.
Returns
-------
......@@ -164,7 +164,7 @@ def convolution(
dilations : Sequence[int], optional
The rate(s) of dilated kernel.
name : str, optional
A optional name for the operation.
The operation name.
Returns
-------
......@@ -230,7 +230,7 @@ def conv_transpose(
dilations : Sequence[int], optional
The rate(s) of dilated kernel.
name : str, optional
A optional name for the operation.
The operation name.
Returns
-------
......@@ -299,7 +299,7 @@ def conv2d(
dilations : Sequence[int], optional
The rate(s) of dilated kernel.
name : str, optional
A optional name for the operation.
The operation name.
Returns
-------
......@@ -339,7 +339,7 @@ def conv2d_transpose(
dilations : Sequence[int], optional
The rate(s) of dilated kernel.
name : str, optional
A optional name for the operation.
The operation name.
Returns
-------
......@@ -377,7 +377,7 @@ def depthwise_conv2d(
dilations : Sequence[int], optional
The rate(s) of dilated kernel.
name : str, optional
A optional name for the operation.
The operation name.
Returns
-------
......@@ -410,7 +410,7 @@ def dropout(x, rate, name=None, **kwargs):
rate : Union[float, dragon.Tensor]
The dropping probability.
name : str, optional
A optional name for the operation.
The operation name.
Returns
-------
......@@ -441,7 +441,7 @@ def elu(features, alpha=1., name=None, **kwargs):
alpha : float, optional, default=1.
The value to :math:`\alpha`.
name : str, optional
A optional name for the operation.
The operation name.
Returns
-------
......@@ -475,7 +475,7 @@ def leaky_relu(features, alpha=0.2, name=None, **kwargs):
alpha : number, optional, default=0.2
The value to :math:`\alpha`.
name : str, optional
A optional name for the operation.
The operation name.
Returns
-------
......@@ -520,7 +520,7 @@ def local_response_normalization(
data_format : {'NCHW', 'NHWC'}, optional
The optional data format.
name : str, optional
A optional name for the operation.
The operation name.
Returns
-------
......@@ -561,7 +561,7 @@ def log_softmax(logits, axis=-1, name=None):
axis : int, optional, default=1
The axis to reduce.
name : str, optional
A optional name for the operation.
The operation name.
Returns
-------
......@@ -595,7 +595,7 @@ def max_pool(
data_format : {'NCHW', 'NCDHW', 'NHWC', 'NDHWC'}, optional
The optional data format.
name : str, optional
A optional name for the operation.
The operation name.
Returns
-------
......@@ -655,7 +655,7 @@ def max_pool2d(
data_format : {'NCHW', 'NCDHW', 'NHWC', 'NDHWC'}, optional
The optional data format.
name : str, optional
A optional name for the operation.
The operation name.
Returns
-------
......@@ -693,7 +693,7 @@ def relu(features, name=None, **kwargs):
features : dragon.Tensor
The input tensor.
name : str, optional
A optional name for the operation.
The operation name.
Returns
-------
......@@ -722,7 +722,7 @@ def relu6(features, name=None, **kwargs):
features : dragon.Tensor
The input tensor.
name : str, optional
A optional name for the operation.
The operation name.
Returns
-------
......@@ -756,7 +756,7 @@ def selu(features, name=None, **kwargs):
features : dragon.Tensor
The tensor :math:`x`.
name : str, optional
A optional name for the operation.
The operation name.
Returns
-------
......@@ -797,7 +797,7 @@ def softmax(logits, axis=-1, name=None, **kwargs):
axis : int, optional, default=-1
The axis to reduce.
name : str, optional
A optional name for the operation.
The operation name.
Returns
-------
......@@ -826,7 +826,7 @@ def softmax_cross_entropy_with_logits(labels, logits, name=None):
logits : dragon.Tensor
The logit tensor.
name : str, optional
A optional name for the operation.
The operation name.
Returns
-------
......@@ -860,7 +860,7 @@ def sparse_softmax_cross_entropy_with_logits(labels, logits, name=None):
logits : dragon.Tensor
The logit tensor.
name : str, optional
A optional name for the operation.
The operation name.
Returns
-------
......@@ -888,7 +888,7 @@ def top_k(input, k=1, sorted=True, name=None):
sorted : bool, optional
Whether to return in the sorted order.
name : str, optional
A optional name for the operation.
The operation name.
Returns
-------
......
......@@ -42,7 +42,7 @@ def random_normal(
seed : int, optional
The optional random seed.
name : str, optional
A optional name for the operation.
The operation name.
Returns
-------
......@@ -79,7 +79,7 @@ def random_uniform(
seed : int, optional
The optional random seed.
name : str, optional
A optional name for the operation.
The operation name.
Returns
-------
......@@ -117,7 +117,7 @@ def truncated_normal(
seed : int, optional
The optional random seed.
name : str, optional
A optional name for the operation.
The operation name.
Returns
-------
......
......@@ -84,7 +84,7 @@ def Input(shape, dtype='float32', name=None):
dtype : str, optional, default='float32'
The optional data type.
name : str, optional
A optional name for the operation.
The operation name.
Returns
-------
......
......@@ -47,18 +47,11 @@ class TestKPLRecord(unittest.TestCase):
self.assertEqual(len(dataset), 1)
dataset.redirect(0)
self.assertEqual(example, dataset.get())
for num_chunks, shuffle in [(0, False), (0, True), (1, True)]:
for shuffle, initial_fill in [(False, 1), (True, 1), (True, 1024)]:
reader = dragon.io.DataReader(
dataset=dragon.io.KPLRecordDataset,
source=path, num_chunks=num_chunks, shuffle=shuffle)
source=path, shuffle=shuffle, initial_fill=initial_fill)
reader._init_dataset()
for i in range(8):
self.assertEqual(reader.next_example(), example)
if reader._example_cursor >= reader._end:
if reader._num_parts > 1 or reader._shuffle:
reader.next_chunk()
else:
reader.reset()
except (OSError, PermissionError):
pass
......
......@@ -908,6 +908,31 @@ class TestArrayOps(OpTestCase):
with dragon.device('cuda'):
self.test_top_k()
def test_unique(self):
data = np.array([1, 1, 3, 5, 5, 7, 9])
entries = [(False, False),
(True, False),
(False, True),
(True, True)]
for execution in ('EAGER_MODE', 'GRAPH_MODE'):
with execution_context().mode(execution):
for return_inverse, return_counts in entries:
x = new_tensor(data)
y = dragon.unique(
x,
return_inverse=return_inverse,
return_counts=return_counts)
result = np.unique(
data,
return_inverse=return_inverse,
return_counts=return_counts)
self.assertEqual(y, result, test_symbols=False)
@unittest.skipIf(not TEST_CUDA, 'CUDA unavailable')
def test_unique_cuda(self):
with dragon.device('cuda'):
self.test_unique()
def test_where(self):
entries = [((6,), (6,))]
for execution in ('EAGER_MODE', 'GRAPH_MODE'):
......
......@@ -112,8 +112,12 @@ class TestModule(unittest.TestCase):
pass
for _, _ in m.named_parameters():
pass
for _, _ in m.named_buffers():
pass
for _ in m.parameters():
pass
for _ in m.buffers():
pass
_, _ = repr(m), repr(m.weight)
def test_sequential(self):
......
......@@ -537,14 +537,31 @@ class TestTensorOps(OpTestCase):
self.assertEqual(getattr(x, name)(), data.astype(dtype))
getattr(x, name + '_')()
self.assertEqual(x, data.astype(dtype))
x.type(dtype)
self.assertEqual(x.type(), dtype)
y = x.type(dtype)
self.assertEqual(y.type(), dtype)
def test_uniform(self):
data = arange((2, 3))
x = new_tensor(data)
x.uniform_()
def test_unique(self):
data = np.array([1, 1, 3, 5, 5, 7, 9])
entries = [(False, False),
(True, False),
(False, True),
(True, True)]
for return_inverse, return_counts in entries:
x = new_tensor(data)
y = x.unique(return_inverse=return_inverse,
return_counts=return_counts,
sorted=True)
result = np.unique(
data,
return_inverse=return_inverse,
return_counts=return_counts)
self.assertEqual(y, result)
def test_unsqueeze(self):
entries = [1, -1]
for axis in entries:
......
......@@ -83,6 +83,25 @@ class TestTensor(unittest.TestCase):
self.assertEqual(x_from_dlpack.dtype, str(data.dtype))
self.assertLessEqual(np.abs(x_from_dlpack.numpy() - data).max(), 1e-5)
def test_internal_converter(self):
data = np.array([0., 1., 2.], 'float32')
x = torch.tensor(data)
y = x.to(torch.int32)
self.assertEqual(y.dtype, 'int32')
y = x.to(torch.device('cpu'))
self.assertEqual(y.device, torch.device('cpu'))
y = x.to(torch.FloatTensor(1))
self.assertEqual(y.dtype, 'float32')
self.assertEqual(y.device, torch.device('cpu'))
try:
_ = x.to(data)
except ValueError:
pass
try:
_ = x.to(torch.device('gpu'))
except ValueError:
pass
def test_numpy_converter(self):
data = np.array([0., 1., 2.], 'float32')
x = torch.from_numpy(data)
......
......@@ -73,6 +73,7 @@ from dragon.vm.torch.core.ops.array.functional import squeeze
from dragon.vm.torch.core.ops.array.functional import stack
from dragon.vm.torch.core.ops.array.functional import sum
from dragon.vm.torch.core.ops.array.functional import topk
from dragon.vm.torch.core.ops.array.functional import unique
from dragon.vm.torch.core.ops.array.functional import unsqueeze
from dragon.vm.torch.core.ops.array.functional import where
from dragon.vm.torch.core.ops.init.functional import arange
......
......@@ -98,6 +98,23 @@ class Module(object):
fn(self)
return self
def buffers(self, recurse=True):
"""Return an iterator over all buffers.
Parameters
----------
recurse : bool, optional, default=True
Yield buffers recursively or not.
Returns
-------
Iterator
The iterator of buffer.
"""
for name, buffer in self.named_buffers(recurse=recurse):
yield buffer
def children(self):
"""Return an iterator over immediate modules.
......@@ -294,6 +311,28 @@ class Module(object):
for name, module in self.named_modules():
yield module
def named_buffers(self, prefix='', recurse=True):
"""Return an iterator over all buffers.
Parameters
----------
prefix : str, optional, default=''
The prefix added to the name.
recurse : bool, optional, default=True
Yield buffers recursively or not.
Returns
-------
Iterator
The iterator of (name, buffer).
"""
gen = self._named_members(
lambda module: module._buffers.items(),
prefix=prefix, recurse=recurse)
for name, buffer in gen:
yield name, buffer
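A brief usage sketch for the new buffer iterators (assuming BatchNorm registers its running statistics as buffers, as in stock PyTorch):

```python
import dragon.vm.torch as torch

m = torch.nn.Sequential(
    torch.nn.Conv2d(3, 8, kernel_size=3),
    torch.nn.BatchNorm2d(8),
)
for name, buf in m.named_buffers():
    print(name, buf.shape)  # e.g. '1.running_mean' (8,), '1.running_var' (8,)
print(len(list(m.buffers(recurse=False))))  # 0: buffers live on the children
```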
def named_children(self):
"""Return an iterator over immediate modules, yield as *(name, module)*.
......@@ -320,7 +359,7 @@ class Module(object):
Returns
-------
Iterator
The iterator of module.
The iterator of (name, module).
"""
if memo is None:
......@@ -335,43 +374,43 @@ class Module(object):
for m in module.named_modules(memo, submodule_prefix):
yield m
def named_parameters(self, memo=None, prefix=''):
def named_parameters(self, prefix='', recurse=True):
"""Return an iterator over all parameters.
Parameters
----------
memo : Set, optional
The optional set to collect parameters.
prefix : str, optional, default=''
The prefix added to the name.
recurse : bool, optional, default=True
Yield parameters recursively or not.
Returns
-------
Iterator
The iterator of parameter.
The iterator of (name, param).
"""
if memo is None:
memo = set()
for name, p in self._parameters.items():
if p is not None and p not in memo:
memo.add(p)
yield prefix + ('.' if prefix else '') + name, p
for mname, module in self.named_children():
submodule_prefix = prefix + ('.' if prefix else '') + mname
for name, p in module.named_parameters(memo, submodule_prefix):
yield name, p
def parameters(self):
gen = self._named_members(
lambda module: module._parameters.items(),
prefix=prefix, recurse=recurse)
for name, param in gen:
yield name, param
def parameters(self, recurse=True):
"""Return an iterator over all parameters.
Parameters
----------
recurse : bool, optional, default=True
Yield parameters recursively or not.
Returns
-------
Iterator
The iterator of parameter.
The iterator of param.
"""
for name, param in self.named_parameters():
for name, param in self.named_parameters(recurse=recurse):
yield param
def register_buffer(self, name, tensor):
......@@ -535,8 +574,22 @@ class Module(object):
return self
def _get_name(self):
"""Return the class name."""
return self.__class__.__name__
def _named_members(self, getter, prefix='', recurse=True):
"""Return the named members."""
memo = set()
modules = self.named_modules(prefix=prefix) if recurse else [(prefix, self)]
for module_prefix, module in modules:
members = getter(module)
for k, v in members:
if v is None or v in memo:
continue
memo.add(v)
name = module_prefix + ('.' if module_prefix else '') + k
yield name, v
def __call__(self, *args, **kwargs):
"""Run the forward pipeline."""
outputs = self.forward(*args, **kwargs)
......
......@@ -546,6 +546,27 @@ class TopK(function.Function):
return self.dispatch([input], outputs, no_grad=True)
class Unique(function.Function):
def __init__(self, key, dev, **kwargs):
super(Unique, self).__init__(key, dev, **kwargs)
self.return_inverse = kwargs.get('return_inverse', False)
self.return_counts = kwargs.get('return_counts', False)
self.num_outputs = 1 + self.return_inverse + self.return_counts
def attributes(self):
return {
'op_type': 'Unique',
'arguments': {
'return_inverse': self.return_inverse,
'return_counts': self.return_counts,
}
}
def forward(self, input):
outputs = [self.alloc() for _ in range(self.num_outputs)]
return self.dispatch([input], outputs)
class UnSqueeze(function.Function):
def __init__(self, key, dev, **kwargs):
super(UnSqueeze, self).__init__(key, dev, **kwargs)
......
......@@ -982,6 +982,56 @@ def topk(input, k, dim=None, largest=True, sorted=True, out=None):
).apply(input, out if out else (None, None))
def unique(input, return_inverse=False, return_counts=False, **kwargs):
"""Return the unique elements of input.
If ``return_inverse``, also return the index mapping each input element to the unique values:
```python
x = torch.tensor([1, 2, 3, 2])
y, index = torch.unique(x, return_inverse=True)
print(y) # [1, 2, 3]
print(index) # [0, 1, 2, 1]
```
If ``return_counts``, also return the counts of each unique value:
```python
x = torch.tensor([1, 2, 3, 2])
y, counts = torch.unique(x, return_counts=True)
print(y) # [1, 2, 3]
print(counts) # [1, 2, 1]
```
Parameters
----------
input : dragon.vm.torch.Tensor
The input tensor.
return_inverse : bool, optional, default=False
Return the inverse index or not.
return_counts : bool, optional, default=False
Return the counts or not.
Returns
-------
dragon.vm.torch.Tensor
The output tensor.
dragon.vm.torch.Tensor, optional
The inverse index tensor.
dragon.vm.torch.Tensor, optional
The counts tensor.
"""
if 'sorted' in kwargs:
kwargs.pop('sorted')
return _functions.Unique \
.instantiate(
input.device,
return_inverse=return_inverse,
return_counts=return_counts,
).apply(input)
def unsqueeze(input, dim, out=None):
"""Expand the dimensions of input with size 1.
......
......@@ -1705,7 +1705,7 @@ def _type(self, dtype=None):
"""
if dtype is None:
return self.dtype
return array_funcs.cast(self, dtype, True)
return array_funcs.cast(self, dtype, False)
def uniform_(self, low=0, high=1):
......@@ -1729,6 +1729,33 @@ def uniform_(self, low=0, high=1):
return init_funcs.uniform_fill(self, low, high)
def unique(self, return_inverse=False, return_counts=False, **kwargs):
"""Return the unique elements.
Parameters
----------
return_inverse : bool, optional, default=False
Return the inverse index or not.
return_counts : bool, optional, default=False
Return the counts or not.
Returns
-------
dragon.vm.torch.Tensor
The output tensor.
dragon.vm.torch.Tensor, optional
The inverse index tensor.
dragon.vm.torch.Tensor, optional
The counts tensor.
See Also
--------
`torch.unique(...)`_
"""
return array_funcs.unique(self, return_inverse, return_counts, **kwargs)
def unsqueeze(self, dim):
"""Return a tensor with dimensions of size 1 inserted.
......@@ -1922,6 +1949,7 @@ Tensor.sub_ = sub_
Tensor.topk = topk
Tensor.type = _type
Tensor.uniform_ = uniform_
Tensor.unique = unique
Tensor.unsqueeze = unsqueeze
Tensor.unsqueeze_ = unsqueeze_
Tensor.where = where
......
......@@ -564,12 +564,10 @@ class Tensor(object):
src._impl,
proto_util.get_device_option(
self._device.type,
self._device.index
).SerializeToString(),
self._device.index).SerializeToString(),
proto_util.get_device_option(
src._device.type,
src._device.index
).SerializeToString(),
src._device.index).SerializeToString(),
)
return self
......@@ -1803,6 +1801,54 @@ class Tensor(object):
"""
def to(self, *args, **kwargs):
"""Convert to the specified data type or device.
The arguments could be ``torch.dtype`` or ``torch.device``:
```python
x = torch.FloatTensor(1)
x.to(torch.int32) # Equivalent to ``x.int()``
x.to(torch.device('cpu')) # Equivalent to ``x.cpu()``
x.to(torch.device('cuda'), torch.float32) # Convert both
```
Or ``torch.Tensor`` to provide both ``dtype`` and ``device``:
```python
a, b = torch.tensor(1.), torch.tensor(2)
print(a.to(b)) # 1
```
Returns
-------
dragon.vm.torch.Tensor
The output tensor.
"""
dtype = kwargs.get('dtype', None)
device = kwargs.get('device', None)
for arg in args:
if isinstance(arg, cpp.dtype):
dtype = arg
elif isinstance(arg, cpp.device):
device = arg
elif isinstance(arg, Tensor):
dtype, device = arg.dtype, arg.device
break
else:
raise ValueError('Unsupported conversion target.')
if device is not None:
if device.type == 'cpu':
self.cpu()
elif device.type == 'cuda':
self.cuda(device.index)
else:
raise ValueError('Unsupported device type: ' + device.type)
if dtype is not None:
return self.type(dtype)
return self
def topk(self, k, dim=None, largest=True, sorted=True):
"""Return the top-K largest or smallest elements.
......@@ -1829,19 +1875,17 @@ class Tensor(object):
"""
def type(self, dtype=None):
"""Return the data type.
If ``dtype`` is not **None**, cast ``self`` to the new tensor.
"""Return the data type or copied tensor with specified type.
Parameters
----------
dtype : str, optional
The specified type.
The specified type to convert to.
Returns
-------
Union[str, dragon.vm.torch.Tensor]
The data type or new tensor.
The data type or copied tensor.
"""
......@@ -1864,6 +1908,31 @@ class Tensor(object):
"""
def unique(self, return_inverse=False, return_counts=False, **kwargs):
"""Return the unique elements.
Parameters
----------
return_inverse : bool, optional, default=False
Return the inverse index or not.
return_counts : bool, optional, default=False
Return the counts or not.
Returns
-------
dragon.vm.torch.Tensor
The output tensor.
dragon.vm.torch.Tensor, optional
The inverse index tensor.
dragon.vm.torch.Tensor, optional
The counts tensor.
See Also
--------
`torch.unique(...)`_
"""
def unsqueeze(self, dim):
"""Return a tensor with dimensions of size 1 inserted.
......