Commit b7e2298f by Ting PAN

Add Im2col operator

Summary:
This commit adds the im2col operator, which unfolds sliding input patches into the depth dimension.
1 parent 9f583556
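For readers new to the operator: im2col copies every sliding window of the input into the channel ("depth") dimension, which is why convolution can then be expressed as a single matrix multiply. A minimal NumPy sketch of the idea, as a hypothetical helper for illustration only, not the kernel added below:

```python
import numpy as np

def im2col_2d(x, kh, kw, stride=1):
    """Unfold (C, H, W) into (C * kh * kw, out_h, out_w)."""
    c, h, w = x.shape
    out_h = (h - kh) // stride + 1
    out_w = (w - kw) // stride + 1
    col = np.zeros((c, kh, kw, out_h, out_w), x.dtype)
    for i in range(kh):
        for j in range(kw):
            # Each (i, j) offset inside the kernel contributes one strided view.
            col[:, i, j] = x[:, i:i + stride * out_h:stride,
                             j:j + stride * out_w:stride]
    return col.reshape(c * kh * kw, out_h, out_w)

x = np.arange(16, dtype='float32').reshape(1, 4, 4)
assert im2col_2d(x, 2, 2).shape == (4, 3, 3)  # 1 * 2 * 2 channels, 3x3 windows
```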
...@@ -12,6 +12,9 @@ dragon.vision
Functions
---------
`extract_patches(...) <vision/extract_patches.html>`_
: Extract the sliding patches from input.
`resize(...) <vision/resize.html>`_
: Resize input via interpolating neighborhoods.
...@@ -27,6 +30,7 @@ dragon.vision
:hidden:
vision/DataIterator
vision/extract_patches
vision/resize
vision/roi_align
vision/roi_pool
......
extract_patches
===============
.. autofunction:: dragon.vision.extract_patches
.. raw:: html
<style>
h1:before {
content: "dragon.vision.";
color: #103d3e;
}
</style>
...@@ -279,6 +279,9 @@ vm.torch.nn
: Apply the sync batch normalization over input.
`[Ioffe & Szegedy, 2015] <https://arxiv.org/abs/1502.03167>`_.
`class Unfold <nn/Unfold.html>`_
: Extract the sliding blocks.
`class Upsample <nn/Upsample.html>`_
: Upsample input via interpolating neighborhoods.
...@@ -374,6 +377,7 @@ vm.torch.nn
nn/TransformerEncoder
nn/TransformerEncoderLayer
nn/SyncBatchNorm
nn/Unfold
nn/Upsample
nn/UpsamplingBilinear2d
nn/UpsamplingNearest2d
......
Unfold
======
.. autoclass:: dragon.vm.torch.nn.Unfold
__init__
--------
.. automethod:: dragon.vm.torch.nn.Unfold.__init__
.. _torch.nn.functional.unfold(...): functional/unfold.html
.. raw:: html
<style>
h1:before {
content: "torch.nn.";
color: #103d3e;
}
</style>
...@@ -196,6 +196,9 @@ vm.torch.nn.functional
`tanh(...) <functional/tanh.html>`_
: Apply the tanh function to input.
`unfold(...) <functional/unfold.html>`_
: Extract the sliding blocks from input.
`upsample(...) <functional/upsample.html>`_
: Upsample input via interpolating neighborhoods.
...@@ -264,6 +267,7 @@ vm.torch.nn.functional
functional/softmax
functional/sync_batch_norm
functional/tanh
functional/unfold
functional/upsample
functional/upsample_bilinear
functional/upsample_nearest
......
unfold
======
.. autofunction:: dragon.vm.torch.nn.functional.unfold
.. _torch.nn.Unfold(...): ../Unfold.html
.. raw:: html
<style>
h1:before {
content: "torch.nn.functional.";
color: #103d3e;
}
</style>
...@@ -46,7 +46,7 @@ __global__ void _ClipGrad(
const T* x, \
T* y, \
CUDAContext* ctx) { \
- _Clip<T, AccT><<<CUDA_BLOCKS(N), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
+ _Clip<<<CUDA_BLOCKS(N), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
N, low, high, x, y); \
}
...@@ -60,8 +60,7 @@ __global__ void _ClipGrad(
const T* x, \
T* dx, \
CUDAContext* ctx) { \
- _ClipGrad<T, AccT> \
-     <<<CUDA_BLOCKS(N), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
+ _ClipGrad<<<CUDA_BLOCKS(N), CUDA_THREADS, 0, ctx->cuda_stream()>>>( \
N, low, high, dy, x, dx); \
}
......
...@@ -36,8 +36,9 @@ void _Im2Col2dNCHW(
} else {
int w = -pad_w + w_k * dilation_w;
for (int w_out = 0; w_out < out_w; ++w_out) {
- *(col++) =
-     !math::utils::IsAGeZeroAndALtB(w, W) ? T(0) : im[h * W + w];
+ *(col++) = !math::utils::IsAGeZeroAndALtB(w, W)
+     ? convert::To<T>(0.f)
+     : im[h * W + w];
w += stride_w;
}
}
...@@ -299,7 +300,7 @@ void _Im2ColNdNHWC(
T* y, \
CPUContext* ctx) { \
if (kTransposed) { \
- math::Set(C* H* W, T(0), y, ctx); \
+ math::Set(C* H* W, convert::To<T>(0.f), y, ctx); \
} \
DISPATCH_CONV_KERNEL( \
_##name, \
...@@ -320,6 +321,29 @@ void _Im2ColNdNHWC(
y); \
}
template <>
void Col2Im2d<float16, CPUContext>(
const int C,
const int H,
const int W,
const int out_h,
const int out_w,
const int kernel_h,
const int kernel_w,
const int stride_h,
const int stride_w,
const int pad_h,
const int pad_w,
const int dilation_h,
const int dilation_w,
const string& data_format,
const float16* x,
float16* y,
CPUContext* ctx) {
CPU_FP16_NOT_SUPPORTED;
}
DEFINE_KERNEL_LAUNCHER(Im2Col2d, false, float16);
DEFINE_KERNEL_LAUNCHER(Im2Col2d, false, float);
DEFINE_KERNEL_LAUNCHER(Im2Col2d, false, double);
DEFINE_KERNEL_LAUNCHER(Col2Im2d, true, float);
...@@ -379,6 +403,28 @@ DEFINE_KERNEL_LAUNCHER(Col2ImNd, true, double);
#undef DEFINE_KERNEL_LAUNCHER
#undef DISPATCH_CONV_KERNEL
#define DEFINE_KERNEL_LAUNCHER(name, kTransposed, T) \
template <> \
void name<T, CPUContext>( \
const int num_dims, \
const int channels, \
const int* in_shape, \
const int* out_shape, \
const int* kshape, \
const int* strides, \
const int* pads, \
const int* dilations, \
const string& data_format, \
const T* x, \
T* y, \
CPUContext* ctx) { \
CPU_FP16_NOT_SUPPORTED; \
}
DEFINE_KERNEL_LAUNCHER(Im2ColNd, false, float16);
DEFINE_KERNEL_LAUNCHER(Col2ImNd, true, float16);
#undef DEFINE_KERNEL_LAUNCHER
} // namespace kernels
} // namespace dragon
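Note the `CPU_FP16_NOT_SUPPORTED` stubs above: the float16 kernels are registered on CPU only to fail loudly, while the CUDA file below provides real half-precision launchers. A hedged usage sketch, assuming the standard `dragon.ones` and `dragon.device` APIs:

```python
import dragon

# CPU path: keep float32, since the float16 col2im kernel is a stub here.
x = dragon.ones((1, 2, 4, 4), dtype='float32')
y = dragon.vision.extract_patches(x, kernel_shape=(2, 2))

# CUDA path: float16 is expected to work via the launchers registered below.
# with dragon.device('cuda'):
#     x = dragon.ones((1, 2, 4, 4), dtype='float16')
#     y = dragon.vision.extract_patches(x, kernel_shape=(2, 2))
```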
...@@ -283,7 +283,7 @@ __global__ void _Im2ColNdNHWC(
}
im_idx = im_idx * channel_dim + i % channel_dim;
if (!Transposed) {
- y[col_idx] = is_padding ? T(0) : __ldg(x + im_idx);
+ y[col_idx] = is_padding ? convert::To<T>(0.f) : __ldg(x + im_idx);
} else if (!is_padding) {
math::utils::AtomicAdd(y + im_idx, x[col_idx]);
}
...@@ -343,12 +343,14 @@ __global__ void _Im2ColNdNHWC(
pad_w, \
dilation_h, \
dilation_w, \
- x, \
- y); \
+ reinterpret_cast<const math::ScalarType<T>::type*>(x), \
+ reinterpret_cast<math::ScalarType<T>::type*>(y)); \
}
DEFINE_KERNEL_LAUNCHER(Im2Col2d, false, float16);
DEFINE_KERNEL_LAUNCHER(Im2Col2d, false, float);
DEFINE_KERNEL_LAUNCHER(Im2Col2d, false, double);
DEFINE_KERNEL_LAUNCHER(Col2Im2d, true, float16);
DEFINE_KERNEL_LAUNCHER(Col2Im2d, true, float);
DEFINE_KERNEL_LAUNCHER(Col2Im2d, true, double);
#undef DEFINE_KERNEL_LAUNCHER
...@@ -407,12 +409,12 @@ DEFINE_KERNEL_LAUNCHER(Col2Im2d, true, double);
if (kTransposed) { \
const auto in_dim = std::accumulate( \
in_shape, in_shape + num_dims, 1, std::multiplies<int>()); \
- math::Set(channels* in_dim, T(0), y, ctx); \
+ math::Set(channels* in_dim, convert::To<T>(0.f), y, ctx); \
} \
DISPATCH_CONV_KERNEL( \
_Im2ColNd, \
kTransposed, \
- T, \
+ math::ScalarType<T>::type, \
CUDA_2D_BLOCKS(outer_dim), \
CUDA_THREADS, \
channels, \
...@@ -426,12 +428,14 @@ DEFINE_KERNEL_LAUNCHER(Col2Im2d, true, double);
strides_arr, \
pads_arr, \
dilations_arr, \
- x, \
- y); \
+ reinterpret_cast<const math::ScalarType<T>::type*>(x), \
+ reinterpret_cast<math::ScalarType<T>::type*>(y)); \
}
DEFINE_KERNEL_LAUNCHER(Im2ColNd, false, float16);
DEFINE_KERNEL_LAUNCHER(Im2ColNd, false, float);
DEFINE_KERNEL_LAUNCHER(Im2ColNd, false, double);
DEFINE_KERNEL_LAUNCHER(Col2ImNd, true, float16);
DEFINE_KERNEL_LAUNCHER(Col2ImNd, true, float);
DEFINE_KERNEL_LAUNCHER(Col2ImNd, true, double);
#undef DEFINE_KERNEL_LAUNCHER
......
...@@ -118,6 +118,7 @@ OPERATOR_SCHEMA(GroupNormGradient)
.NumOutputs(3);
namespace {
class GradientMaker final : public GradientMakerBase {
public:
GRADIENT_MAKER_CTOR(GradientMaker);
......
...@@ -31,7 +31,7 @@ void ConvOp<Context>::RunOnDevice() {
// You really need the CuDNN to help you -:)
LOG(FATAL) << "GroupConv(NHWC) is not supported.";
}
- DispatchHelper<dtypes::TypesBase<float, double>>::Call(this, Input(0));
+ DispatchHelper<dtypes::Floating>::Call(this, Input(0));
}
template <class Context>
...@@ -72,7 +72,7 @@ void ConvGradientOp<Context>::RunOnDevice() {
// You really need the CuDNN to help you -:)
LOG(FATAL) << "GroupConv(NHWC) is not supported.";
}
- DispatchHelper<dtypes::TypesBase<float, double>>::Call(this, Input(0));
+ DispatchHelper<dtypes::Floating>::Call(this, Input(0));
}
DEPLOY_CPU_OPERATOR(Conv);
......
...@@ -24,7 +24,7 @@ void ConvOpBase<Context>::GetBaseArguments() {
num_axes_ = (int64_t)kshape.size();
CHECK_GT(num_axes_, 0) << "\nInvalid size of <kernel_shape>.";
- for (int i = 0; i < num_axes_; i++) {
+ for (int i = 0; i < num_axes_; ++i) {
kshape_.push_back(i < kshape.size() ? kshape[i] : kshape[0]);
dilations_.push_back(i < dilations.size() ? dilations[i] : dilations[0]);
strides_.push_back(i < strides.size() ? strides[i] : strides[0]);
...@@ -32,7 +32,7 @@ void ConvOpBase<Context>::GetBaseArguments() {
}
if ((int64_t)pads.size() == (num_axes_ * 2)) {
- for (int i = 0; i < num_axes_; i++) {
+ for (int i = 0; i < num_axes_; ++i) {
pads_end_.push_back(pads[num_axes_ + i]);
}
} else {
...@@ -40,7 +40,7 @@ void ConvOpBase<Context>::GetBaseArguments() {
}
bool skip_flag = true;
- for (int i = 0; i < num_axes_; i++) {
+ for (int i = 0; i < num_axes_; ++i) {
skip_flag &= (kshape_[i] == 1 && strides_[i] == 1);
skip_flag &= (pads_begin_[i] == 0 && pads_end_[i] == 0);
if (!skip_flag) break;
...@@ -54,7 +54,7 @@ void ConvOpBase<Context>::ComputeOutShape() {
vec64_t X_dims = Input(0).dims();
int64_t in_size, out_size, k_size, pad_size;
if (!Transposed()) {
- for (int i = 0; i < num_axes_; i++) {
+ for (int i = 0; i < num_axes_; ++i) {
in_size = X_dims[axis_ + i];
k_size = dilations_[i] * (kshape_[i] - 1) + 1;
if (!str::find(padding_, "SAME")) { // Explicit pads
...@@ -74,7 +74,7 @@ void ConvOpBase<Context>::ComputeOutShape() {
CHECK(num_output_padding == 0 || num_output_padding == num_axes_)
<< "\nExcepted 0 or " << num_axes_ << " ints for <output_padding>.";
if (!str::find(padding_, "SAME")) { // Explicit pads
- for (int i = 0; i < num_axes_; i++) {
+ for (int i = 0; i < num_axes_; ++i) {
in_size = X_dims[axis_ + i];
k_size = dilations_[i] * (kshape_[i] - 1) + 1;
pad_size = pads_begin_[i] + pads_end_[i];
...@@ -88,7 +88,7 @@ void ConvOpBase<Context>::ComputeOutShape() {
output_shape(0, &num_output_shape);
CHECK(num_output_shape == num_axes_)
<< "\nExcepted " << num_axes_ << " ints for <output_shape>.";
- for (int i = 0; i < num_axes_; i++) {
+ for (int i = 0; i < num_axes_; ++i) {
in_size = X_dims[axis_ + i];
k_size = dilations_[i] * (kshape_[i] - 1) + 1;
out_size = output_shape(i);
...@@ -118,7 +118,7 @@ void ConvOpBase<Context>::Reshape(bool backward) {
if (out_channels_ <= 0) {
// Infer the output channels from the weights shape
out_channels_ = W.count() / (in_channels_ / group_);
- for (int i = 0; i < num_axes_; i++) {
+ for (int i = 0; i < num_axes_; ++i) {
out_channels_ /= kshape_[i];
}
CHECK_GT(out_channels_, 0) << "\nFailed to infer the out channels "
...@@ -136,7 +136,7 @@ void ConvOpBase<Context>::Reshape(bool backward) {
// Weight shape is assumed as NCHW format
// whatever to compute the fans correctly
w_shape_ = {conv_out_channels_, conv_in_channels_ / group_};
- for (int i = 0; i < num_axes_; i++) {
+ for (int i = 0; i < num_axes_; ++i) {
w_shape_.push_back(kshape_[i]);
}
b_shape_ = {out_channels_};
...@@ -151,11 +151,11 @@ void ConvOpBase<Context>::Reshape(bool backward) {
vec64_t Y_dims{X.dim(0)};
if (data_format() == "NCHW") {
Y_dims.push_back(out_channels_);
- for (int i = 0; i < num_axes_; i++) {
+ for (int i = 0; i < num_axes_; ++i) {
Y_dims.push_back(out_shape_[i]);
}
} else if (data_format() == "NHWC") {
- for (int i = 0; i < num_axes_; i++) {
+ for (int i = 0; i < num_axes_; ++i) {
Y_dims.push_back(out_shape_[i]);
}
Y_dims.push_back(out_channels_);
...@@ -185,7 +185,7 @@ void ConvOpBase<Context>::Reshape(bool backward) {
X_stride_ = X.stride(0);
Y_stride_ = Y_ref->stride(0);
kernel_dim_ = conv_in_channels_ / group_;
- for (int i = 0; i < num_axes_; i++) {
+ for (int i = 0; i < num_axes_; ++i) {
kernel_dim_ *= kshape_[i];
}
col_stride_ = kernel_dim_ * conv_out_dim_;
...@@ -194,7 +194,7 @@ void ConvOpBase<Context>::Reshape(bool backward) {
// Compute im2col arguments
in_shape_.clear();
- for (int i = 0; i < num_axes_; i++) {
+ for (int i = 0; i < num_axes_; ++i) {
if (Transposed()) {
in_shape_.push_back(Y_ref->dim(axis_ + i));
out_shape_[i] = X.dim(axis_ + i);
...@@ -203,7 +203,7 @@ void ConvOpBase<Context>::Reshape(bool backward) {
}
}
col_dim_ = kernel_dim_ * group_;
- for (int i = 0; i < num_axes_; i++) {
+ for (int i = 0; i < num_axes_; ++i) {
col_dim_ *= out_shape_[i];
}
}
......
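The explicit-padding branch above computes the output extent per axis as `out_size = (in_size + pad_size - k_size) / stride + 1`, where the effective kernel extent is `k_size = dilation * (kernel - 1) + 1`; the same arithmetic reappears in the Python shape spec later in this commit. A small self-check of that formula (the helper name is ours):

```python
def conv_out_size(in_size, kernel, stride=1, pad_begin=0, pad_end=0, dilation=1):
    # Effective kernel extent grows with dilation.
    k_size = dilation * (kernel - 1) + 1
    return (in_size + pad_begin + pad_end - k_size) // stride + 1

assert conv_out_size(4, 2) == 3                          # plain 2x window
assert conv_out_size(4, 3, pad_begin=1, pad_end=1) == 4  # padding keeps size
assert conv_out_size(5, 3, stride=2, dilation=2) == 1    # k_size = 5
```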
...@@ -39,7 +39,9 @@ class ConvOpBase : public Operator<Context> {
USE_OPERATOR_FUNCTIONS;
protected:
- virtual bool HasBias() = 0;
+ virtual bool HasBias() {
+   return false;
+ }
virtual bool Transposed() {
return false;
...@@ -47,9 +49,17 @@ class ConvOpBase : public Operator<Context> {
void GetBaseArguments();
void ComputeOutShape();
void Reshape(bool backward = false);
template <typename T>
void Im2Col(const T* im, T* col);
template <typename T>
void Col2Im(const T* col, T* im);
template <typename T>
void WeightedX(const T* x, const T* w, T* y);
template <typename T>
...@@ -74,22 +84,14 @@ class ConvOpBase : public Operator<Context> {
int64_t axis_, num_axes_;
int64_t in_channels_, out_channels_;
int64_t conv_in_channels_, conv_out_channels_;
int64_t kernel_dim_;
int64_t X_stride_, W_stride_, Y_stride_;
DECLARE_OP_REPEATED_ARG(int64_t, output_shape);
DECLARE_OP_REPEATED_ARG(int64_t, output_padding);
private:
- void ComputeOutShape();
- template <typename T>
- void Im2Col(const T* im, T* col);
- template <typename T>
- void Col2Im(const T* col, T* im);
int64_t skip_im2col_;
- int64_t kernel_dim_;
int64_t col_dim_, col_stride_;
int64_t out_dim_, conv_out_dim_;
int64_t Y_stride1_;
......
...@@ -62,6 +62,9 @@ template <class Context>
template <typename T>
void ConvOpBase<Context>::Col2Im(const T* col, T* im) {
if (num_axes_ == 1 || num_axes_ == 2) {
kernels::Col2Im2d(
conv_in_channels_,
in_shape_[0],
......
...@@ -31,7 +31,7 @@ void ConvTransposeOp<Context>::RunOnDevice() {
// You really need the CuDNN to help you -:)
LOG(FATAL) << "GroupConv(NHWC) is not supported.";
}
- DispatchHelper<dtypes::TypesBase<float, double>>::Call(this, Input(0));
+ DispatchHelper<dtypes::Floating>::Call(this, Input(0));
}
template <class Context>
...@@ -72,7 +72,7 @@ void ConvTransposeGradientOp<Context>::RunOnDevice() {
// You really need the CuDNN to help you -:)
LOG(FATAL) << "GroupConv(NHWC) is not supported.";
}
- DispatchHelper<dtypes::TypesBase<float, double>>::Call(this, Input(0));
+ DispatchHelper<dtypes::Floating>::Call(this, Input(0));
}
DEPLOY_CPU_OPERATOR(ConvTranspose);
......
#include "dragon/operators/vision/im2col_op.h"
#include "dragon/core/workspace.h"
#include "dragon/operators/vision/conv_op_impl.h"
namespace dragon {
template <class Context>
template <typename T>
void Im2ColOp<Context>::DoRunWithType() {
auto &X = Input(0), *Y = Output(0);
ConvOpBase<Context>::ComputeOutShape();
if (str::find(this->padding_, "SAME")) SET_INPUT_SPEC(0);
vec64_t Y_dims(X.dims());
auto im_channels_iter = Y_dims.begin() + 1;
if (data_format() == "NHWC") im_channels_iter += num_axes_;
conv_in_channels_ = *im_channels_iter;
in_shape_.clear();
for (int i = 0; i < num_axes_; ++i) {
in_shape_.push_back(X.dim(axis_ + i));
Y_dims[axis_ + i] = out_shape_[i];
*im_channels_iter *= kshape_[i];
}
auto* x = X.template data<T, Context>();
auto* y = Y->Reshape(Y_dims)->template mutable_data<T, Context>();
for (int i = 0; i < X.dim(0); ++i) {
this->Im2Col(x + i * X.stride(0), y + i * Y->stride(0));
}
}
template <class Context>
template <typename T>
void Col2ImOp<Context>::DoRunWithType() {
auto &X = Input(0), *Y = Output(0);
if (str::find(this->padding_, "SAME")) {
auto* Y_ref = workspace()->TryGetTensor(handle() + "/X_spec:0");
if (Y_ref != nullptr) {
// Get output shape from the spec if computed.
this->output_shape_.resize(num_axes_);
for (int i = 0; i < num_axes_; ++i) {
this->output_shape_[i] = Y_ref->dim(axis_ + i);
}
}
}
ConvOpBase<Context>::ComputeOutShape();
vec64_t Y_dims(X.dims());
auto im_channels_iter = Y_dims.begin() + 1;
if (data_format() == "NHWC") im_channels_iter += num_axes_;
in_shape_.clear();
for (int i = 0; i < num_axes_; ++i) {
in_shape_.push_back(X.dim(axis_ + i));
Y_dims[axis_ + i] = out_shape_[i];
*im_channels_iter /= kshape_[i];
}
conv_in_channels_ = *im_channels_iter;
std::swap(in_shape_, out_shape_);
auto* x = X.template data<T, Context>();
auto* y = Y->Reshape(Y_dims)->template mutable_data<T, Context>();
for (int i = 0; i < X.dim(0); ++i) {
this->Col2Im(x + i * X.stride(0), y + i * Y->stride(0));
}
}
DEPLOY_CPU_OPERATOR(Im2Col);
REGISTER_CPU_OPERATOR(Im2ColGradient, Col2ImOp<CPUContext>);
#ifdef USE_CUDA
DEPLOY_CUDA_OPERATOR(Im2Col);
REGISTER_CUDA_OPERATOR(Im2ColGradient, Col2ImOp<CUDAContext>);
#endif
DEPLOY_CPU_OPERATOR(Col2Im);
REGISTER_CPU_OPERATOR(Col2ImGradient, Im2ColOp<CPUContext>);
#ifdef USE_CUDA
DEPLOY_CUDA_OPERATOR(Col2Im);
REGISTER_CUDA_OPERATOR(Col2ImGradient, Im2ColOp<CUDAContext>);
#endif
OPERATOR_SCHEMA(Im2Col).NumInputs(1).NumOutputs(1);
OPERATOR_SCHEMA(Im2ColGradient).NumInputs(1).NumOutputs(1);
OPERATOR_SCHEMA(Col2Im).NumInputs(1).NumOutputs(1);
OPERATOR_SCHEMA(Col2ImGradient).NumInputs(1).NumOutputs(1);
REGISTER_GRADIENT(Im2Col, SimpleGradientMaker);
REGISTER_GRADIENT(Col2Im, SimpleGradientMaker);
} // namespace dragon
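The registrations above make `Im2Col` and `Col2Im` each other's gradient, which is exactly right because col2im is the linear adjoint of im2col: unfolding with im2col and scatter-adding with col2im satisfy `<im2col(x), dy> == <x, col2im(dy)>`. A self-contained NumPy check of that property (hypothetical helpers; single channel, 2x2 kernel, stride 1):

```python
import numpy as np

def im2col(x, k):  # (H, W) -> (k * k, out_h, out_w)
    h, w = x.shape
    oh, ow = h - k + 1, w - k + 1
    return np.stack([x[i:i + oh, j:j + ow]
                     for i in range(k) for j in range(k)])

def col2im(col, k, h, w):  # adjoint: scatter-add each window back
    x = np.zeros((h, w), col.dtype)
    oh, ow = h - k + 1, w - k + 1
    for n, (i, j) in enumerate([(i, j) for i in range(k) for j in range(k)]):
        x[i:i + oh, j:j + ow] += col[n]
    return x

rng = np.random.default_rng(0)
x, dy = rng.random((4, 4)), rng.random((4, 3, 3))
# <im2col(x), dy> == <x, col2im(dy)> up to float rounding.
assert np.isclose((im2col(x, 2) * dy).sum(), (x * col2im(dy, 2, 4, 4)).sum())
```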
/*!
* Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
*
* Licensed under the BSD 2-Clause License.
* You should have received a copy of the BSD 2-Clause License
* along with the software. If not, See,
*
* <https://opensource.org/licenses/BSD-2-Clause>
*
* ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_VISION_IM2COL_OP_H_
#define DRAGON_OPERATORS_VISION_IM2COL_OP_H_
#include "dragon/operators/vision/conv_op_base.h"
#include "dragon/operators/vision/conv_op_cache.h"
namespace dragon {
template <class Context>
class Im2ColOp final : public ConvOpBase<Context> {
public:
Im2ColOp(const OperatorDef& def, Workspace* ws)
: ConvOpBase<Context>(def, ws) {
GetBaseArguments();
}
USE_OPERATOR_FUNCTIONS;
USE_CONV_FUNCTIONS;
void RunOnDevice() override {
DispatchHelper<dtypes::Floating>::Call(this, Input(0));
}
template <typename T>
void DoRunWithType();
};
template <class Context>
class Col2ImOp final : public ConvOpBase<Context> {
public:
Col2ImOp(const OperatorDef& def, Workspace* ws)
: ConvOpBase<Context>(def, ws) {
GetBaseArguments();
}
USE_OPERATOR_FUNCTIONS;
USE_CONV_FUNCTIONS;
void RunOnDevice() override {
DispatchHelper<dtypes::Floating>::Call(this, Input(0));
}
template <typename T>
void DoRunWithType();
protected:
bool Transposed() override {
return true;
}
};
} // namespace dragon
#endif // DRAGON_OPERATORS_VISION_IM2COL_OP_H_
...@@ -17,6 +17,7 @@ from __future__ import print_function as _print_function
from dragon.utils.vision import DataIterator
# Functions
from dragon.core.ops.vision_ops import extract_patches
from dragon.core.ops.vision_ops import resize
from dragon.core.ops.vision_ops import roi_align
from dragon.core.ops.vision_ops import roi_pool
......
...@@ -125,7 +125,7 @@ def concat_args(**kwargs):
return {'axis': kwargs.get('axis', 0)}
- @register(['Conv', 'DepthwiseConv'])
+ @register(['Conv', 'DepthwiseConv', 'Im2Col'])
def conv_args(**kwargs):
return {
'kernel_shape': kwargs.get('kernel_shape', 1),
...@@ -138,7 +138,7 @@ def conv_args(**kwargs):
}
- @register('ConvTranspose')
+ @register(['Col2Im', 'ConvTranspose'])
def conv_transpose_args(**kwargs):
return {**conv_args(**kwargs), **{
'output_padding': kwargs.get('output_padding', None),
......
...@@ -309,6 +309,37 @@ def expand_spec(args, inputs, outputs):
return outputs
@register('Im2Col')
def im2col_spec(args, inputs, outputs):
outputs[0]._dtype = inputs[0].dtype
try:
out_shape = list(inputs[0].shape[:])
num_axes = len(out_shape) - 2
channel_axis = 1 if args['data_format'] == 'NCHW' else -1
spatial_axis = 2 if args['data_format'] == 'NCHW' else 1
for i in range(num_axes):
try:
k = args['kernel_shape'][i]
s = args['strides'][i]
d = args['dilations'][i]
in_size = out_shape[i + spatial_axis]
k_size = d * (k - 1) + 1
if 'SAME' not in args['padding']:
pad_size = args['pads'][i] + args['pads'][i + num_axes]
out_size = (in_size + pad_size - k_size) // s + 1
else:
out_size = (in_size + s - 1) // s
if out_shape[channel_axis] is not None:
out_shape[channel_axis] *= k
except (IndexError, TypeError):
out_size = None
out_shape[i + spatial_axis] = out_size
outputs[0]._shape = tuple(out_shape)
except (TypeError, IndexError):
outputs[0]._shape = None
return outputs
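To make the spec above concrete, here is its arithmetic traced for an `NCHW` input of shape `(2, 3, 8, 8)` with a 3x3 kernel, stride 1, no padding, dilation 1: the channel axis is multiplied by each kernel extent, and the spatial axes follow the convolution formula.

```python
# Mirrors the arithmetic in im2col_spec above for one concrete case.
n, c, h, w = 2, 3, 8, 8
k, s, d, pads = 3, 1, 1, (0, 0)               # per-axis kernel/stride/dilation
k_size = d * (k - 1) + 1                      # effective kernel extent: 3
out_size = (h + sum(pads) - k_size) // s + 1  # (8 - 3) // 1 + 1 = 6
out_channels = c * k * k                      # 3 * 3 * 3 = 27
assert (n, out_channels, out_size, out_size) == (2, 27, 6, 6)
```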
@register([
'Eye',
'Fill',
......
...@@ -48,7 +48,10 @@ class Workspace(object):
def release(self, handle):
"""Release a created handle."""
key, index = handle.rsplit('_', 1)
try:
    heapq.heappush(self._handles[key], int(index))
except AttributeError:
    pass
def __init__(self):
"""Create a ``Workspace``."""
......
...@@ -871,6 +871,89 @@ def depth_to_space(inputs, block_size, data_format='NCHW', **kwargs):
@OpSchema.num_inputs(1)
def extract_patches(
inputs,
kernel_shape=(3, 3),
strides=1,
pads=0,
dilations=1,
padding='VALID',
data_format='NCHW',
**kwargs
):
r"""Extract the sliding patches from input.
* If :attr:`data_format` is ``'NCHW'``, expects input shape
:math:`(N, C, D1, D2, ...)`, and output shape is
:math:`(N, C \times \prod(\text{kernel\_shape}), D1_{\text{out}}, D2_{\text{out}}, ...)`.
* If :attr:`data_format` is ``'NHWC'``, expects input shape
:math:`(N, D1, D2, ..., C)`, and output shape is
:math:`(N, D1_{\text{out}}, D2_{\text{out}}, ..., \prod(\text{kernel\_shape}) \times C)`.
* If :attr:`padding` is ``'VALID'``, :attr:`pads` controls the explicit padding sizes.
Otherwise, the sizes are computed automatically using the given algorithm.
Examples:
```python
for i in range(3):
ndim = i + 1
x = dragon.ones((1, 2) + (2,) * ndim)
y = dragon.vision.extract_patches(x, kernel_shape=(2,) * ndim)
assert y.shape == (1, 2 * (2 ** ndim)) + (1,) * ndim
```
Parameters
----------
inputs : dragon.Tensor
The input tensor.
kernel_shape : Sequence[int], optional, default=(3, 3)
The shape of sliding window.
strides : Union[int, Sequence[int]], optional, default=1
The stride of sliding window.
pads : Union[int, Sequence[int]], optional, default=0
The zero padding size.
dilations : Union[int, Sequence[int]], optional, default=1
The dilated rate of sliding window.
padding : str, optional, default='VALID'
``'VALID'``, ``'SAME'``, ``'SAME_UPPER'`` or ``'SAME_LOWER'``.
data_format : str, optional, default='NCHW'
``'NCHW'`` or ``'NHWC'``.
Returns
-------
dragon.Tensor
The output tensor.
"""
if padding not in ('VALID', 'SAME', 'SAME_UPPER', 'SAME_LOWER'):
raise ValueError('Unsupported padding algorithm: %s' % padding)
if data_format not in ('NCHW', 'NHWC'):
raise ValueError('Unsupported data format: %s' % data_format)
kernel_shape = nest.flatten(kernel_shape)
if context.executing_eagerly():
return OpLib.execute(
'Im2Col',
inputs,
kernel_shape=kernel_shape,
strides=_normalize_tuple(strides, len(kernel_shape)),
pads=_normalize_pads(pads, len(kernel_shape)),
dilations=_normalize_tuple(dilations, len(kernel_shape)),
padding=padding,
data_format=data_format,
)
return OpLib.add('Im2Col', inputs,
kernel_shape=kernel_shape,
strides=_normalize_tuple(strides, len(kernel_shape)),
pads=_normalize_pads(pads, len(kernel_shape)),
dilations=_normalize_tuple(dilations, len(kernel_shape)),
padding=padding,
data_format=data_format,
**kwargs)
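A short eager-mode usage sketch for the function above; the shape follows the `'NCHW'` rule in the docstring:

```python
import dragon

x = dragon.ones((1, 2, 4, 4))
y = dragon.vision.extract_patches(x, kernel_shape=(2, 2))
# Channels: 2 * (2 * 2) = 8; spatial: (4 - 2) // 1 + 1 = 3 per axis.
assert y.shape == (1, 8, 3, 3)
```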
@OpSchema.num_inputs(1)
def pool(
inputs,
kernel_shape,
......
...@@ -217,6 +217,12 @@ class TestOpSpec(unittest.TestCase):
self.assertEqual(dragon.expand_dims(
self.sym3, axis=(0, 3, 5)).shape, (1, 1, None, 1))
def test_extract_patches(self):
with dragon.graph_mode():
self.assertEqual(dragon.vision.extract_patches(self.sym1).shape, None)
self.assertEqual(dragon.vision.extract_patches(self.sym4).shape,
(self.sym4.shape[0], None, None, None))
def test_init_ops(self):
init_funcs_v1 = [dragon.fill,
dragon.ones,
......
...@@ -3894,6 +3894,51 @@ class TestVisionOps(OpTestCase):
with dragon.device('cuda'):
self.test_depth_to_space()
def test_extract_patches(self):
entries = [((2, 1, 2, 2), (2, 2), 1, 1, 1, 'NCHW'),
((2, 2, 2, 1), (2, 2), 1, 1, 1, 'NHWC')]
results = [[[[[0., 0., 0.], [0., 0., 0.1], [0., 0.2, 0.3]],
[[0., 0., 0.], [0., 0.1, 0.], [0.2, 0.3, 0.]],
[[0., 0., 0.1], [0., 0.2, 0.3], [0., 0., 0.]],
[[0., 0.1, 0.], [0.2, 0.3, 0.], [0., 0., 0.]]],
[[[0., 0., 0.], [0., 0.4, 0.5], [0., 0.6, 0.7]],
[[0., 0., 0.], [0.4, 0.5, 0.], [0.6, 0.7, 0.]],
[[0., 0.4, 0.5], [0., 0.6, 0.7], [0., 0., 0.]],
[[0.4, 0.5, 0.], [0.6, 0.7, 0.], [0., 0., 0.]]]],
[[[[0., 0., 0., 0.], [0., 0., 0., 0.1], [0., 0., 0.1, 0.]],
[[0., 0., 0., 0.2], [0., 0.1, 0.2, 0.3], [0.1, 0., 0.3, 0.]],
[[0., 0.2, 0., 0.], [0.2, 0.3, 0., 0.], [0.3, 0., 0., 0.]]],
[[[0., 0., 0., 0.4], [0., 0., 0.4, 0.5], [0., 0., 0.5, 0.]],
[[0., 0.4, 0., 0.6], [0.4, 0.5, 0.6, 0.7], [0.5, 0., 0.7, 0.]],
[[0., 0.6, 0., 0.], [0.6, 0.7, 0., 0.], [0.7, 0., 0., 0.]]]]]
grads = [[[[[6.2, 6.6], [7.4, 7.8]]], [[[20.6, 21.], [21.8, 22.2]]]],
[[[[3.8], [5.4]], [[8.6], [10.2]]], [[[18.2], [19.8]], [[23.], [24.6]]]]]
for execution in ('EAGER_MODE', 'GRAPH_MODE'):
with execution_context().mode(execution):
for (x_shape, kernel_shape, strides, pads, dilations, data_format), \
result, grad in zip(entries, results, grads):
data = arange(x_shape) * .1
x = new_tensor(data)
with dragon.GradientTape() as tape:
tape.watch(x)
y = dragon.vision.extract_patches(
x,
kernel_shape=kernel_shape,
strides=strides,
pads=pads,
dilations=dilations,
data_format=data_format,
)
data2 = arange(y.shape) * .1
dy = new_tensor(data2)
dx = tape.gradient(y, x, output_gradients=[dy])
self.assertEqual([y, dx], [np.array(result), np.array(grad)])
@unittest.skipIf(not TEST_CUDA, 'CUDA unavailable')
def test_extract_patches_cuda(self):
with dragon.device('cuda'):
self.test_extract_patches()
def test_pool1d(self):
entries = [((2, 2, 2), (2,), 2, 1, 'max', 'NCHW'),
((2, 2, 2), (2,), 2, 1, 'avg', 'NCHW'),
......
...@@ -828,6 +828,24 @@ class TestModules(OpTestCase):
y, _ = m(x), repr(m)
self.assertEqual(y, np.tanh(data))
def test_unfold(self):
entries = [((2, 1, 2, 2), 2, 1, 1, 1)]
results = [[[[0., 0., 0., 0., 0., 0.1, 0., 0.2, 0.3],
[0., 0., 0., 0., 0.1, 0., 0.2, 0.3, 0.],
[0., 0., 0.1, 0., 0.2, 0.3, 0., 0., 0.],
[0., 0.1, 0., 0.2, 0.3, 0., 0., 0., 0.]],
[[0., 0., 0., 0., 0.4, 0.5, 0., 0.6, 0.7],
[0., 0., 0., 0.4, 0.5, 0., 0.6, 0.7, 0.],
[0., 0.4, 0.5, 0., 0.6, 0.7, 0., 0., 0.],
[0.4, 0.5, 0., 0.6, 0.7, 0., 0., 0., 0.]]]]
for (x_shape, kernel_shape, strides, pads, dilations), \
result in zip(entries, results):
data = arange(x_shape) * .1
x = new_tensor(data)
m = torch.nn.Unfold(kernel_shape, dilations, pads, strides)
y, _ = m(x), repr(m)
self.assertEqual(y, np.array(result))
def test_upsample(self):
entries = [((2, 2, 1, 1), (2, 2), 'nearest'),
((2, 2, 1, 1), (2, 2), 'bilinear'),
......
...@@ -54,6 +54,7 @@ from dragon.vm.torch.core.nn.modules.dropout import DropBlock2d
from dragon.vm.torch.core.nn.modules.dropout import Dropout
from dragon.vm.torch.core.nn.modules.dropout import DropPath
from dragon.vm.torch.core.nn.modules.flatten import Flatten
from dragon.vm.torch.core.nn.modules.fold import Unfold
from dragon.vm.torch.core.nn.modules.linear import Identity
from dragon.vm.torch.core.nn.modules.linear import Linear
from dragon.vm.torch.core.nn.modules.loss import CTCLoss
......
...@@ -71,6 +71,7 @@ from dragon.vm.torch.core.nn.functional import smooth_l1_loss
from dragon.vm.torch.core.nn.functional import softmax
from dragon.vm.torch.core.nn.functional import sync_batch_norm
from dragon.vm.torch.core.nn.functional import tanh
from dragon.vm.torch.core.nn.functional import unfold
from dragon.vm.torch.core.nn.functional import upsample
from dragon.vm.torch.core.nn.functional import upsample_bilinear
from dragon.vm.torch.core.nn.functional import upsample_nearest
......
...@@ -2164,6 +2164,44 @@ def tanh(input, inplace=False):
outputs=[input if inplace else None])
def unfold(input, kernel_size, dilation=1, padding=0, stride=1):
"""Extract the sliding blocks from input.
Parameters
----------
input : dragon.vm.torch.Tensor
The input tensor.
kernel_size : Union[int, Sequence[int]]
The size of sliding window.
dilation : Union[int, Sequence[int]], optional, default=1
The dilated rate of sliding window.
padding : Union[int, Sequence[int]], optional, default=0
The zero padding size.
stride : Union[int, Sequence[int]], optional, default=1
The stride of sliding window.
Returns
-------
dragon.vm.torch.Tensor
The output tensor.
See Also
--------
`torch.nn.Unfold(...)`_
"""
nd_util = utils._ntuple(input.ndimension() - 2)
out = FunctionLib.apply(
'Im2Col',
input.device,
[input],
kernel_shape=nd_util(kernel_size),
strides=nd_util(stride),
pads=nd_util(padding),
dilations=nd_util(dilation))
return out.flatten_(2)
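Since the function above flattens the `Im2Col` output from dimension 2, `unfold` always returns a 3-D tensor of shape `(N, C * prod(kernel_size), L)`. A hedged usage sketch, assuming the usual `dragon.vm.torch` entry point:

```python
from dragon.vm import torch
from dragon.vm.torch.nn import functional as F

x = torch.ones(2, 1, 4, 4)
y = F.unfold(x, kernel_size=2)
# Blocks: 1 * 2 * 2 = 4; locations: L = 3 * 3 = 9.
assert tuple(y.shape) == (2, 4, 9)
```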
def upsample(
input,
size=None,
......
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Fold modules."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from dragon.vm.torch.core.nn import functional as F
from dragon.vm.torch.core.nn.modules.module import Module
class Unfold(Module):
r"""Extract the sliding blocks.
This module expects an input of size :math:`(N, C, D1, D2, ...)`,
and the output size is :math:`(N, C \times \prod(\text{kernel\_size}), L)`,
where :math:`N` is the batch size, :math:`C` is the number of channels,
and :math:`L` is :math:`\prod(D_{\text{out}})`.
Examples:
```python
m = torch.nn.Unfold(3, padding=1)
x = torch.ones(2, 2, 2, 2)
y = m(x)
```
See Also
--------
`torch.nn.functional.unfold(...)`_
"""
def __init__(self, kernel_size, dilation=1, padding=0, stride=1):
"""Create a ``Unfold`` module.
Parameters
----------
kernel_size : Union[int, Sequence[int]]
The size of sliding window.
dilation : Union[int, Sequence[int]], optional, default=1
The dilated rate of sliding window.
padding : Union[int, Sequence[int]], optional, default=0
The zero padding size.
stride : Union[int, Sequence[int]], optional, default=1
The stride of sliding window.
"""
super(Unfold, self).__init__()
self.kernel_size = kernel_size
self.dilation = dilation
self.padding = padding
self.stride = stride
def extra_repr(self):
return 'kernel_size={kernel_size}, ' \
'dilation={dilation}, ' \
'padding={padding}, ' \
'stride={stride}' \
.format(**self.__dict__)
def forward(self, input):
return F.unfold(input, self.kernel_size, self.dilation,
self.padding, self.stride)
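Extending the docstring example above with the resulting shape: with `kernel_size=3` and `padding=1` on a 2x2 input, each spatial axis yields `(2 + 2 - 3) // 1 + 1 = 2` locations, so `L = 4` and the block dimension is `C * 3 * 3 = 18`. A hedged sketch:

```python
from dragon.vm import torch

m = torch.nn.Unfold(3, padding=1)
x = torch.ones(2, 2, 2, 2)
y = m(x)  # expected shape: (N, C * 9, L) = (2, 18, 4)
assert tuple(y.shape) == (2, 18, 4)
```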