Commit 58284aa4 by Ting PAN

Refactor Vision Module

1 parent 771e3d5a
Showing with 3077 additions and 1727 deletions
......@@ -37,7 +37,7 @@ class CPUContext {
inline static void* New(size_t nbytes) {
void* data;
#ifdef WITH_CUDA_HOST_MEN
#ifdef WITH_CUDA_HOST_MEM
CUDA_CHECK(cudaMallocHost(&data, nbytes));
#else
data = malloc(nbytes);
......
......@@ -18,14 +18,14 @@ namespace dragon {
#define MAX_GPUS 8
/**************************************************************************
* cuXXX libraries wrapper "Context" as "Handle"
* it's well known that each "Context" binds to some "Devices" in OpenCL
* so, we must create different handles to associate different devices
* or the computations will be dispatched to the same GPU
* read more: http://docs.nvidia.com/cuda/cublas/, section 2.1.2
* also, "Handle" is thread safe
* it seems not necessary to create handles for different threads
/**************************************************************************
* cuXXX libraries wrap "Context" as "Handle".
* It's well known that each "Context" binds to some "Devices" in OpenCL.
* So, we must create different handles to associate different devices, or
the computations will be dispatched to the same GPU.
* Read more: http://docs.nvidia.com/cuda/cublas/, Sec 2.1.2.
* Also, "Handle" is thread safe;
it seems unnecessary to create handles for different threads.
*************************************************************************/
class CUDAObject {
......
......@@ -128,7 +128,7 @@ class Operator : public OperatorBase {
#ifndef WITH_MPI
return true;
#else
vector<int> allow_ranks = Operator::GetRepeatedArg<int>("mpi_rank");
vector<int> allow_ranks = Operator::GetRepeatedArg<int>("mpi_ranks");
if (allow_ranks.empty()) return true;
int cur_rank;
MPI_Comm_rank(MPI_COMM_WORLD, &cur_rank);
......
......@@ -105,7 +105,7 @@ class Tensor {
MixedMemory* memory() const { return own_mem_ ? memory_.get() : ex_memory_; }
MixedMemory::State memory_state() const {
MixedMemory* mem = memory();
CHECK(mem) << "Memory access before allowcating.";
CHECK(mem) << "\nMemory access before allowcating.";
return memory()->state();
}
......
......@@ -19,8 +19,7 @@ class BiasAddOp : public Operator<Context> {
data_format(OperatorBase::GetSingleArg<string>("data_format", "NCHW")) {}
void RunOnDevice() override;
template <typename T> void NCHWRunWithType();
template <typename T> void NHWCRunWithType();
template <typename T> void RunWithType();
protected:
TIndex outer_dim, dim, inner_dim;
......@@ -36,8 +35,7 @@ class BiasAddGradientOp final : public Operator<Context> {
data_format(OperatorBase::GetSingleArg<string>("data_format", "NCHW")) {}
void RunOnDevice() override;
template <typename T> void NCHWRunWithType();
template <typename T> void NHWCRunWithType();
template <typename T> void RunWithType();
protected:
int outer_dim, dim, inner_dim;
......
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
#ifndef DRAGON_OPERATORS_MISC_IMAGE_DATA_OP_H_
#define DRAGON_OPERATORS_MISC_IMAGE_DATA_OP_H_
#include "core/operator.h"
namespace dragon {
template <class Context>
class ImageDataOp final : public Operator<Context> {
public:
ImageDataOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
dtype(OperatorBase::GetSingleArg<string>("dtype", "FLOAT32")),
mean_values(OperatorBase::GetRepeatedArg<float>("mean_values")),
std_values(OperatorBase::GetRepeatedArg<float>("std_values")),
data_format(OperatorBase::GetSingleArg<string>("data_format", "NCHW")) {
if (mean_values.size() > 0) {
CHECK_EQ((int)mean_values.size(), 3)
<< "The mean values should be a list with length 3.";
mean.Reshape(vector<TIndex>(1, 3));
for (int i = 0; i < 3; i++)
mean.mutable_data<float, CPUContext>()[i] = mean_values[i];
}
if (std_values.size() > 0) {
CHECK_EQ((int)std_values.size(), 3)
<< "The std values should be a list with length 3.";
std.Reshape(vector<TIndex>(1, 3));
for (int i = 0; i < 3; i++)
std.mutable_data<float, CPUContext>()[i] = std_values[i];
}
}
void RunOnDevice() override;
template <typename Tx, typename Ty> void RunWithType();
protected:
string dtype, data_format;
vector<float> mean_values, std_values;
TIndex n, c, h, w;
Tensor mean, std;
};
} // namespace dragon
#endif // DRAGON_OPERATORS_MISC_IMAGE_DATA_OP_H_
\ No newline at end of file
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
#ifndef DRAGON_OPERATORS_MISC_MEMORY_DATA_OP_H_
#define DRAGON_OPERATORS_MISC_MEMORY_DATA_OP_H_
#include "core/operator.h"
namespace dragon {
template <class Context>
class MemoryDataOp final : public Operator<Context> {
public:
MemoryDataOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws) {
int DATA_TYPE = OperatorBase::GetSingleArg<int>("dtype", 1);
data_type = TensorProto_DataType(DATA_TYPE);
}
void RunOnDevice() override;
template <typename Tx, typename Ty> void RunWithType();
protected:
TensorProto_DataType data_type;
};
} // namespace dragon
#endif // DRAGON_OPERATORS_MISC_MEMORY_DATA_OP_H_
\ No newline at end of file
......@@ -19,8 +19,9 @@ class ModelMPIBase : public Operator<Context> {
public:
ModelMPIBase(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
comm((MPI_Comm)OperatorBase::GetSingleArg<int>("comm", 0)),
group((MPI_Group)OperatorBase::GetSingleArg<int>("group", 0)) {
comm((MPI_Comm)OperatorBase::GetSingleArg<int64_t>("comm", 0)),
group((MPI_Group)OperatorBase::GetSingleArg<int64_t>("group", 0)),
dtype(OperatorBase::GetSingleArg<string>("dtype", "FLOAT32")) {
if (comm == MPI_COMM_NULL) return;
MPI_Comm_size(MPI_COMM_WORLD, &world_size);
......@@ -36,11 +37,18 @@ class ModelMPIBase : public Operator<Context> {
CHECK(comm_root != MPI_UNDEFINED) << "MPI root is not included in layer group.";
}
MPI_Datatype mpi_dtype() {
if (dtype == "FLOAT32") return MPI_FLOAT;
else LOG(FATAL) << "Unsupported input type: " << dtype;
return MPI_DATATYPE_NULL;
}
protected:
MPI_Comm comm;
MPI_Group group;
int comm_size, comm_rank, comm_root;
int world_size, world_rank;
string dtype;
};
} // namespace dragon
......
......@@ -19,26 +19,37 @@ class BilinearResizeOp : public Operator<Context> {
static_dsize(OperatorBase::GetRepeatedArg<int>("static_dsize")),
dynamic_dsize(OperatorBase::GetRepeatedArg<string>("dynamic_dsize")),
fy(OperatorBase::GetSingleArg<float>("fy", -1.0)),
fx(OperatorBase::GetSingleArg<float>("fx", -1.0)) {}
fx(OperatorBase::GetSingleArg<float>("fx", -1.0)),
data_format(OperatorBase::GetSingleArg<string>("data_format", "NCHW")) {
if (data_format == "NCHW") spatial_axis = 2;
else if (data_format == "NHWC") spatial_axis = 1;
else LOG(FATAL) << "Unknown data format: " << data_format;
}
void RunOnDevice() override;
template <typename T> void RunWithType();
protected:
vector<int> static_dsize;
vector<string> dynamic_dsize;
float fy, fx;
string data_format;
TIndex n, c, h, w, out_h, out_w, spatial_axis;
vector<TIndex> dims;
float h_scale, w_scale, fy, fx;
};
template <class Context>
class BilinearResizeGradientOp : public Operator<Context> {
public:
BilinearResizeGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws) {}
: Operator<Context>(op_def, ws),
data_format(OperatorBase::GetSingleArg<string>("data_format", "NCHW")) {}
void RunOnDevice() override;
template <typename T> void RunWithType();
protected:
string data_format;
TIndex n, c, h, w, out_h, out_w;
};
} // namespace dragon
......
......@@ -12,23 +12,28 @@
namespace dragon {
template <class Context>
class ConvOp : public ConvOpBase<Context> {
class Conv2dOp : public ConvOpBase<Context> {
public:
ConvOp(const OperatorDef& def, Workspace* ws)
: ConvOpBase<Context>(def, ws) {}
Conv2dOp(const OperatorDef& def, Workspace* ws)
: ConvOpBase<Context>(def, ws) {
this->num_spatial_axes = 2;
Setup();
}
void ComputeOutputShape() override;
bool ReverseDimensions() override { return false; }
virtual bool HasBias() { return InputSize() > 2; }
void RunOnDevice() override;
template <typename T> void RunWithType();
};
template <class Context>
class ConvGradientOp : public ConvOp<Context> {
class Conv2dGradientOp : public Conv2dOp<Context> {
public:
ConvGradientOp(const OperatorDef& def, Workspace* ws)
: ConvOp<Context>(def, ws) {}
Conv2dGradientOp(const OperatorDef& def, Workspace* ws)
: Conv2dOp<Context>(def, ws) {}
bool HasBias() override { return output(2)->name() != "ignore"; }
void RunOnDevice() override;
template <typename T> void RunWithType();
......@@ -39,10 +44,10 @@ class ConvGradientOp : public ConvOp<Context> {
#include "utils/cudnn_device.h"
template <class Context>
class CuDNNConvOp : public ConvOp<Context> {
class CuDNNConv2dOp : public Conv2dOp<Context> {
public:
CuDNNConvOp(const OperatorDef& def, Workspace* ws)
: ConvOp<Context>(def, ws) {
CuDNNConv2dOp(const OperatorDef& def, Workspace* ws)
: Conv2dOp<Context>(def, ws) {
handle = new cudnnHandle_t[this->group];
stream = new cudaStream_t[this->group];
ctx().SwitchToDevice();
......@@ -55,8 +60,10 @@ class CuDNNConvOp : public ConvOp<Context> {
CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc));
CUDNN_CHECK(cudnnCreateTensorDescriptor(&output_desc));
CUDNN_CHECK(cudnnCreateConvolutionDescriptor(&conv_desc));
if (InputSize() > 2)
CUDNN_CHECK(cudnnCreateTensorDescriptor(&bias_desc));
if (HasBias()) CUDNN_CHECK(cudnnCreateTensorDescriptor(&bias_desc));
if (this->data_format == "NCHW") format = CUDNN_TENSOR_NCHW;
else if (this->data_format == "NHWC") format = CUDNN_TENSOR_NHWC;
else LOG(FATAL) << "Unknown data format: " << this->data_format;
}
void RunOnDevice() override;
......@@ -65,19 +72,20 @@ class CuDNNConvOp : public ConvOp<Context> {
protected:
cudnnHandle_t* handle;
cudaStream_t* stream;
cudnnTensorFormat_t format;
cudnnConvolutionFwdAlgo_t fwd_algo;
cudnnTensorDescriptor_t input_desc, output_desc, bias_desc;
cudnnConvolutionDescriptor_t conv_desc;
cudnnFilterDescriptor_t filter_desc;
size_t workspace_fwd_data_size;
int bias_offset;
TIndex bias_offset;
};
template <class Context>
class CuDNNConvGradientOp : public ConvGradientOp<Context> {
class CuDNNConv2dGradientOp : public Conv2dGradientOp<Context> {
public:
CuDNNConvGradientOp(const OperatorDef& def, Workspace* ws)
: ConvGradientOp<Context>(def, ws) {
CuDNNConv2dGradientOp(const OperatorDef& def, Workspace* ws)
: Conv2dGradientOp<Context>(def, ws) {
handle = new cudnnHandle_t[this->group * 3];
stream = new cudaStream_t[this->group * 3];
for (int g = 0; g < this->group * 3; g++) {
......@@ -89,8 +97,10 @@ class CuDNNConvGradientOp : public ConvGradientOp<Context> {
CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc));
CUDNN_CHECK(cudnnCreateTensorDescriptor(&output_desc));
CUDNN_CHECK(cudnnCreateConvolutionDescriptor(&conv_desc));
if (InputSize() > 2)
CUDNN_CHECK(cudnnCreateTensorDescriptor(&bias_desc));
if (HasBias()) CUDNN_CHECK(cudnnCreateTensorDescriptor(&bias_desc));
if (this->data_format == "NCHW") format = CUDNN_TENSOR_NCHW;
else if (this->data_format == "NHWC") format = CUDNN_TENSOR_NHWC;
else LOG(FATAL) << "Unknown data format: " << this->data_format;
}
void RunOnDevice() override;
......@@ -99,6 +109,7 @@ class CuDNNConvGradientOp : public ConvGradientOp<Context> {
protected:
cudnnHandle_t* handle;
cudaStream_t* stream;
cudnnTensorFormat_t format;
cudnnConvolutionBwdFilterAlgo_t bwd_filter_algo;
cudnnConvolutionBwdDataAlgo_t bwd_data_algo;
cudnnTensorDescriptor_t input_desc, output_desc, bias_desc;
......
......@@ -18,53 +18,38 @@ class ConvOpBase : public Operator<Context> {
public:
ConvOpBase(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
num_output(OperatorBase::GetSingleArg<int>("num_output", 1)),
group(OperatorBase::GetSingleArg<int>("group", 1)) {
channel_axis = 1, num_spatial_axes = 2; // Conv2D support only Now
vector<TIndex> spatial_shape(1, num_spatial_axes);
vector<int> ks = OperatorBase::GetRepeatedArg<int>("kernel_size");
for (int i = 0; i < num_spatial_axes; i++)
kernel_size.push_back(i < ks.size() ? ks[i]: ks[0]);
vector<int> s = OperatorBase::GetRepeatedArg<int>("stride");
for (int i = 0; i < num_spatial_axes; i++)
stride.push_back(i < s.size() ? s[i] : s[0]);
vector<int> p = OperatorBase::GetRepeatedArg<int>("pad");
for (int i = 0; i < num_spatial_axes; i++)
pad.push_back(i < p.size() ? p[i] : p[0]);
vector<int> d = OperatorBase::GetRepeatedArg<int>("dilation");
for (int i = 0; i < num_spatial_axes; i++)
dilation.push_back(i < d.size() ? d[i] : d[0]);
is_1x1 = true;
for (int i = 0; i < num_spatial_axes; i++) {
is_1x1 &= (kernel_size[i] == 1 &&
stride[i] == 1 &&
pad[i] == 0);
if (!is_1x1) break;
}
data_format(OperatorBase::GetSingleArg<string>("data_format", "NCHW")),
padding(OperatorBase::GetSingleArg<string>("padding", "VALID")),
num_output(OperatorBase::GetSingleArg<int>("num_output", 1)),
group(OperatorBase::GetSingleArg<int>("group", 1)),
static_dsize(OperatorBase::GetRepeatedArg<int>("static_dsize")),
dynamic_dsize(OperatorBase::GetRepeatedArg<string>("dynamic_dsize")) {
if (data_format == "NCHW") spatial_axis = 2;
else if (data_format == "NHWC") spatial_axis = 1;
else LOG(FATAL) << "Unknown data format: " << data_format;
num_spatial_axes = -1; // unknown
}
protected:
vector<TIndex> kernel_size, stride, pad, dilation;
vector<TIndex> input_shape, output_shape, bottom_shape, col_buffer_shape;
string data_format, padding;
vector<TIndex> input_shape, output_shape, bottom_shape, top_shape, col_shape;
vector<TIndex> weight_shape, bias_shape;
Tensor* col_buffer, *bias_multiplier;
TIndex num_output, group;
TIndex channel_axis, num_spatial_axes;
TIndex spatial_axis, num_spatial_axes;
TIndex channels, out_spatial_dim;
TIndex conv_in_channels, conv_out_channels;
TIndex conv_out_spatial_dim, kernel_dim;
TIndex col_offset, output_offset, weight_offset, x_offset, y_offset;
vector<int> static_dsize;
vector<string> dynamic_dsize;
bool is_1x1;
void Setup();
void Reshape();
void GradientReshape();
virtual void ComputeOutputShape() = 0;
virtual void ComputeOutputShape();
virtual bool ReverseDimensions() = 0;
template <typename T> void Wx(const T* x, const T* weights, T* y, bool skip_im2col = false);
......@@ -74,25 +59,33 @@ class ConvOpBase : public Operator<Context> {
template <typename T> void Db(const T* dy, T* db);
private:
template <typename T> void Im2Col(const T* im, T* col_buffer) {
kernel::Im2Col<T, Context>(conv_in_channels,
input_shape[0], input_shape[1],
kernel_size[0], kernel_size[1],
stride[0], stride[1],
pad[0], pad[1],
dilation[0], dilation[1],
im,
col_buffer);
template <typename T> void Im2Col(const T* im, T* col) {
if (input(0).ndim() == 4) {
kernel::Im2Col2d<T, Context>(conv_in_channels,
input_shape[0], input_shape[1],
output_shape[0], output_shape[1],
kernel_size[0], kernel_size[1],
stride[0], stride[1],
pad[0], pad[1],
dilation[0], dilation[1],
data_format,
im,
col);
} else LOG(FATAL) << "ConvNd has not been implemented yet";
}
template <typename T> void Col2Im(const T* col_buffer, T* im) {
kernel::Col2Im<T, Context>(conv_in_channels,
input_shape[0], input_shape[1],
kernel_size[0], kernel_size[1],
stride[0], stride[1],
pad[0], pad[1],
dilation[0], dilation[1],
col_buffer,
im);
template <typename T> void Col2Im(const T* col, T* im) {
if (input(0).ndim() == 4) {
kernel::Col2Im2d<T, Context>(conv_in_channels,
input_shape[0], input_shape[1],
output_shape[0], output_shape[1],
kernel_size[0], kernel_size[1],
stride[0], stride[1],
pad[0], pad[1],
dilation[0], dilation[1],
data_format,
col,
im);
} else LOG(FATAL) << "ConvNd has not been implemented yet";
}
};
......
......@@ -4,32 +4,40 @@
// Written by Ting Pan
// --------------------------------------------------------
#ifndef DRAGON_OPERATORS_VISION_DECONV_OP_H_
#define DRAGON_OPERATORS_VISION_DECONV_OP_H_
#ifndef DRAGON_OPERATORS_VISION_CONV_TRANSPOSE_OP_H_
#define DRAGON_OPERATORS_VISION_CONV_TRANSPOSE_OP_H_
#include "operators/vision/conv_op_base.h"
namespace dragon {
template <class Context>
class DeConvOp: public ConvOpBase<Context> {
class Conv2dTransposeOp: public ConvOpBase<Context> {
public:
DeConvOp(const OperatorDef& def, Workspace* ws)
: ConvOpBase<Context>(def, ws) {}
Conv2dTransposeOp(const OperatorDef& def, Workspace* ws)
: ConvOpBase<Context>(def, ws) {
this->num_spatial_axes = 2;
Setup();
}
void ComputeOutputShape() override;
bool ReverseDimensions() override { return true; }
virtual bool HasBias() { return InputSize() > 2; }
void RunOnDevice() override;
template <typename T> void RunWithType();
protected:
vector<int> static_dsize;
vector<string> dynamic_dsize;
};
template <class Context>
class DeConvGradientOp : public DeConvOp<Context> {
class Conv2dTransposeGradientOp : public Conv2dTransposeOp<Context> {
public:
DeConvGradientOp(const OperatorDef& def, Workspace* ws) :
DeConvOp<Context>(def, ws) {}
Conv2dTransposeGradientOp(const OperatorDef& def, Workspace* ws)
: Conv2dTransposeOp<Context>(def, ws) {}
bool HasBias() override { return output(2)->name() != "ignore"; }
void RunOnDevice() override;
template <typename T> void RunWithType();
......@@ -40,10 +48,10 @@ class DeConvGradientOp : public DeConvOp<Context> {
#include "utils/cudnn_device.h"
template <class Context>
class CuDNNDeConvOp : public DeConvOp<Context> {
class CuDNNConv2dTransposeOp : public Conv2dTransposeOp<Context> {
public:
CuDNNDeConvOp(const OperatorDef& def, Workspace* ws)
: DeConvOp<Context>(def, ws) {
CuDNNConv2dTransposeOp(const OperatorDef& def, Workspace* ws)
: Conv2dTransposeOp<Context>(def, ws) {
handle = new cudnnHandle_t[this->group];
stream = new cudaStream_t[this->group];
for (int g = 0; g < this->group; g++) {
......@@ -55,8 +63,10 @@ class CuDNNDeConvOp : public DeConvOp<Context> {
CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc));
CUDNN_CHECK(cudnnCreateTensorDescriptor(&output_desc));
CUDNN_CHECK(cudnnCreateConvolutionDescriptor(&conv_desc));
if (InputSize() > 2)
CUDNN_CHECK(cudnnCreateTensorDescriptor(&bias_desc));
if (HasBias()) CUDNN_CHECK(cudnnCreateTensorDescriptor(&bias_desc));
if (this->data_format == "NCHW") format = CUDNN_TENSOR_NCHW;
else if (this->data_format == "NHWC") format = CUDNN_TENSOR_NHWC;
else LOG(FATAL) << "Unknown data format: " << this->data_format;
}
void RunOnDevice() override;
template <typename T> void RunWithType();
......@@ -64,6 +74,7 @@ class CuDNNDeConvOp : public DeConvOp<Context> {
protected:
cudnnHandle_t* handle;
cudaStream_t* stream;
cudnnTensorFormat_t format;
cudnnConvolutionBwdDataAlgo_t fwd_algo;
cudnnTensorDescriptor_t input_desc, output_desc, bias_desc;
cudnnConvolutionDescriptor_t conv_desc;
......@@ -73,10 +84,10 @@ class CuDNNDeConvOp : public DeConvOp<Context> {
};
template <class Context>
class CuDNNDeConvGradientOp : public DeConvGradientOp<Context> {
class CuDNNConv2dTransposeGradientOp : public Conv2dTransposeGradientOp<Context> {
public:
CuDNNDeConvGradientOp(const OperatorDef& def, Workspace* ws)
: DeConvGradientOp<Context>(def, ws) {
CuDNNConv2dTransposeGradientOp(const OperatorDef& def, Workspace* ws)
: Conv2dTransposeGradientOp<Context>(def, ws) {
handle = new cudnnHandle_t[this->group * 3];
stream = new cudaStream_t[this->group * 3];
for (int g = 0; g < this->group * 3; g++) {
......@@ -88,8 +99,10 @@ public:
CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc));
CUDNN_CHECK(cudnnCreateTensorDescriptor(&output_desc));
CUDNN_CHECK(cudnnCreateConvolutionDescriptor(&conv_desc));
if (InputSize() > 2)
CUDNN_CHECK(cudnnCreateTensorDescriptor(&bias_desc));
if (HasBias()) CUDNN_CHECK(cudnnCreateTensorDescriptor(&bias_desc));
if (this->data_format == "NCHW") format = CUDNN_TENSOR_NCHW;
else if (this->data_format == "NHWC") format = CUDNN_TENSOR_NHWC;
else LOG(FATAL) << "Unknown data format: " << this->data_format;
}
void RunOnDevice() override;
template <typename T> void RunWithType();
......@@ -97,6 +110,7 @@ public:
protected:
cudnnHandle_t* handle;
cudaStream_t* stream;
cudnnTensorFormat_t format;
cudnnConvolutionBwdFilterAlgo_t bwd_filter_algo;
cudnnConvolutionFwdAlgo_t bwd_data_algo;
cudnnTensorDescriptor_t input_desc, output_desc, bias_desc;
......@@ -110,4 +124,4 @@ public:
} // namespace dragon
#endif // DRAGON_OPERATORS_VISION_DECONV_OP_H_
\ No newline at end of file
#endif // DRAGON_OPERATORS_VISION_CONV_TRANSPOSE_OP_H_
\ No newline at end of file
......@@ -19,7 +19,12 @@ class NNResizeOp : public Operator<Context> {
static_dsize(OperatorBase::GetRepeatedArg<int>("static_dsize")),
dynamic_dsize(OperatorBase::GetRepeatedArg<string>("dynamic_dsize")),
fy(OperatorBase::GetSingleArg<float>("fy", -1.0)),
fx(OperatorBase::GetSingleArg<float>("fx", -1.0)) {}
fx(OperatorBase::GetSingleArg<float>("fx", -1.0)),
data_format(OperatorBase::GetSingleArg<string>("data_format", "NCHW")) {
if (data_format == "NCHW") spatial_axis = 2;
else if (data_format == "NHWC") spatial_axis = 1;
else LOG(FATAL) << "Unknown data format: " << data_format;
}
void RunOnDevice() override;
template <typename T> void RunWithType();
......@@ -27,18 +32,24 @@ class NNResizeOp : public Operator<Context> {
protected:
vector<int> static_dsize;
vector<string> dynamic_dsize;
vector<TIndex> dims;
float h_scale, w_scale, fy, fx;
float fy, fx;
string data_format;
TIndex n, c, h, w, out_h, out_w, spatial_axis;
};
template <class Context>
class NNResizeGradientOp : public Operator<Context> {
public:
NNResizeGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws) {}
: Operator<Context>(op_def, ws),
data_format(OperatorBase::GetSingleArg<string>("data_format", "NCHW")) {}
void RunOnDevice() override;
template <typename T> void RunWithType();
protected:
string data_format;
TIndex n, c, h, w, out_h, out_w;
};
} // namespace dragon
......
......@@ -11,14 +11,14 @@
namespace dragon {
enum PoolingMode { MAX_POOLING, AVG_POOLING };
template <class Context>
class PoolingOp: public Operator <Context> {
class Pooling2dOp: public Operator <Context> {
public:
PoolingOp(const OperatorDef& op_def, Workspace* ws)
Pooling2dOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
mode(PoolingMode(OperatorBase::GetSingleArg<int>("mode", MAX_POOLING))),
mode(OperatorBase::GetSingleArg<string>("mode", "MAX")),
data_format(OperatorBase::GetSingleArg<string>("data_format", "NCHW")),
padding(OperatorBase::GetSingleArg<string>("padding", "VALID")),
global_pooling(OperatorBase::GetSingleArg<bool>("global_pooling", false)) {
vector<int> ks = OperatorBase::GetRepeatedArg<int>("kernel_size");
vector<int> s = OperatorBase::GetRepeatedArg<int>("stride");
......@@ -38,24 +38,25 @@ class PoolingOp: public Operator <Context> {
void Reshape();
void RunOnDevice() override;
template <typename T> void MaxRunWithType();
template <typename T> void AvgRunWithType();
template <typename T> void MAXRunWithType();
template <typename T> void AVGRunWithType();
protected:
vector<TIndex> kernel_size, stride, pad;
Tensor* mask;
PoolingMode mode;
TIndex num, channels, height, width;
TIndex pool_height, pool_width;
string mode, data_format, padding;
TIndex n, c, h, w, pool_h, pool_w;
bool global_pooling;
};
template <class Context>
class PoolingGradientOp: public Operator<Context> {
class Pooling2dGradientOp: public Operator<Context> {
public:
PoolingGradientOp(const OperatorDef& op_def, Workspace* ws)
Pooling2dGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
mode(PoolingMode(OperatorBase::GetSingleArg<int>("mode", MAX_POOLING))),
mode(OperatorBase::GetSingleArg<string>("mode", "MAX")),
data_format(OperatorBase::GetSingleArg<string>("data_format", "NCHW")),
padding(OperatorBase::GetSingleArg<string>("padding", "VALID")),
global_pooling(OperatorBase::GetSingleArg<bool>("global_pooling", false)) {
vector<int> ks = OperatorBase::GetRepeatedArg<int>("kernel_size");
vector<int> s = OperatorBase::GetRepeatedArg<int>("stride");
......@@ -75,46 +76,36 @@ class PoolingGradientOp: public Operator<Context> {
void Reshape();
void RunOnDevice() override;
template <typename T> void MaxRunWithType();
template <typename T> void AvgRunWithType();
template <typename T> void MAXRunWithType();
template <typename T> void AVGRunWithType();
protected:
vector<TIndex> kernel_size, stride, pad;
Tensor* mask;
PoolingMode mode;
TIndex num, channels, height, width;
TIndex pool_height, pool_width;
string mode, data_format, padding;
TIndex n, c, h, w, pool_h, pool_w;
bool global_pooling;
};
#ifdef WITH_CUDNN
template <class Context>
class CuDNNPoolingOp final : public PoolingOp<Context> {
class CuDNNPooling2dOp final : public Pooling2dOp<Context> {
public:
CuDNNPoolingOp(const OperatorDef& op_def, Workspace* ws)
: PoolingOp<Context>(op_def, ws) {
CuDNNPooling2dOp(const OperatorDef& op_def, Workspace* ws)
: Pooling2dOp<Context>(op_def, ws) {
CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc));
CUDNN_CHECK(cudnnCreateTensorDescriptor(&output_desc));
CUDNN_CHECK(cudnnCreatePoolingDescriptor(&pool_desc));
pool_mode = this->mode == MAX_POOLING ?
CUDNN_POOLING_MAX :
CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING;
#if CUDNN_VERSION_MIN(5, 0, 0)
CUDNN_CHECK(cudnnSetPooling2dDescriptor(pool_desc,
pool_mode,
CUDNN_PROPAGATE_NAN,
this->kernel_size[0], this->kernel_size[1],
this->pad[0], this->pad[1],
this->stride[0], this->stride[1]));
if (this->mode == "MAX") {
#if CUDNN_VERSION_MIN(6,0,0)
pool_mode = CUDNN_POOLING_MAX_DETERMINISTIC;
#else
CUDNN_CHECK(cudnnSetPooling2dDescriptor_v4(pool_desc,
pool_mode,
CUDNN_PROPAGATE_NAN,
this->kernel_size[0], this->kernel_size[1],
this->pad[0], this->pad[1],
this->stride[0], this->stride[1]));
pool_mode = CUDNN_POOLING_MAX;
#endif
} else if (this->mode == "AVG") {
pool_mode = CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING;
} else LOG(FATAL) << "Unsupported pooling mode: " << this->mode;
}
void RunOnDevice() override;
......@@ -122,34 +113,40 @@ class CuDNNPoolingOp final : public PoolingOp<Context> {
protected:
cudnnTensorDescriptor_t input_desc, output_desc;
cudnnPoolingDescriptor_t pool_desc;
cudnnPoolingMode_t pool_mode;
cudnnPoolingDescriptor_t pool_desc;
cudnnPoolingMode_t pool_mode;
};
template <class Context>
class CuDNNPoolingGradientOp final : public PoolingGradientOp<Context> {
class CuDNNPooling2dGradientOp final : public Pooling2dGradientOp<Context> {
public:
CuDNNPoolingGradientOp(const OperatorDef& op_def, Workspace* ws)
: PoolingGradientOp<Context>(op_def, ws) {
CuDNNPooling2dGradientOp(const OperatorDef& op_def, Workspace* ws)
: Pooling2dGradientOp<Context>(op_def, ws) {
CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc));
CUDNN_CHECK(cudnnCreateTensorDescriptor(&output_desc));
CUDNN_CHECK(cudnnCreatePoolingDescriptor(&pool_desc));
pool_mode = this->mode == MAX_POOLING ?
CUDNN_POOLING_MAX :
CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING;
if (this->mode == "MAX") {
#if CUDNN_VERSION_MIN(6,0,0)
pool_mode = CUDNN_POOLING_MAX_DETERMINISTIC;
#else
pool_mode = CUDNN_POOLING_MAX;
#endif
} else if (this->mode == "AVG") {
pool_mode = CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING;
} else LOG(FATAL) << "Unsupported pooling mode: " << this->mode;
#if CUDNN_VERSION_MIN(5, 0, 0)
CUDNN_CHECK(cudnnSetPooling2dDescriptor(pool_desc,
CUDNN_CHECK(cudnnSetPooling2dDescriptor(pool_desc,
pool_mode,
CUDNN_PROPAGATE_NAN,
CUDNN_PROPAGATE_NAN,
this->kernel_size[0], this->kernel_size[1],
this->pad[0], this->pad[1],
this->pad[0], this->pad[1],
this->stride[0], this->stride[1]));
#else
CUDNN_CHECK(cudnnSetPooling2dDescriptor_v4(pool_desc,
CUDNN_CHECK(cudnnSetPooling2dDescriptor_v4(pool_desc,
pool_mode,
CUDNN_PROPAGATE_NAN,
CUDNN_PROPAGATE_NAN,
this->kernel_size[0], this->kernel_size[1],
this->pad[0], this->pad[1],
this->pad[0], this->pad[1],
this->stride[0], this->stride[1]));
#endif
}
......@@ -159,8 +156,8 @@ class CuDNNPoolingGradientOp final : public PoolingGradientOp<Context> {
protected:
cudnnTensorDescriptor_t input_desc, output_desc;
cudnnPoolingDescriptor_t pool_desc;
cudnnPoolingMode_t pool_mode;
cudnnPoolingDescriptor_t pool_desc;
cudnnPoolingMode_t pool_mode;
};
#endif // WITH_CUDNN
......
......@@ -61,11 +61,29 @@ template <typename T>
void cudnnSetTensorDesc(cudnnTensorDescriptor_t* desc, Tensor* tensor);
template <typename T>
void cudnnSetTensor4dDesc(cudnnTensorDescriptor_t* desc, const string& data_format, Tensor* tensor);
template <typename T>
void cudnnSetTensor5dDesc(cudnnTensorDescriptor_t* desc, const string& data_format, Tensor* tensor);
template <typename T>
void cudnnSetTensor3dDesc(cudnnTensorDescriptor_t* desc, const string& data_format, Tensor* tensor);
template <typename T>
void cudnnSetTensorDesc(cudnnTensorDescriptor_t* desc, const std::vector<int64_t>& dims);
template <typename T>
void cudnnSetTensorDesc(cudnnTensorDescriptor_t* desc,
const std::vector<int64_t>& dims,
void cudnnSetTensor4dDesc(cudnnTensorDescriptor_t* desc, const string& data_format, const std::vector<int64_t>& dims);
template <typename T>
void cudnnSetTensor5dDesc(cudnnTensorDescriptor_t* desc, const string& data_format, const std::vector<int64_t>& dims);
template <typename T>
void cudnnSetTensor3dDesc(cudnnTensorDescriptor_t* desc, const string& data_format, const std::vector<int64_t>& dims);
template <typename T>
void cudnnSetTensorDesc(cudnnTensorDescriptor_t* desc,
const std::vector<int64_t>& dims,
const std::vector<int64_t>& strides);
}
......
......@@ -54,7 +54,7 @@ class TruncatedNormalFiller final : public Filler < T, Context > {
public:
TruncatedNormalFiller(const TensorFiller& filler): Filler<T, Context>(filler) {}
void Fill(Tensor* tensor) override {
// implement of gpu is diffcult
// implementing it on the gpu is difficult
math::RandomTruncatedNormal<T, CPUContext>(tensor->count(),
filler().mean(),
filler().std(),
......
......@@ -148,12 +148,12 @@ void TanhGrad(const int count, const T* dy, const T* y, T* dx);
/******************** arithmetic.bias_add ********************/
template <typename T, class Context>
void BiasAdd(const int count,
const int outer_dim,
const int dim,
void BiasAdd(const int count,
const int outer_dim,
const int dim,
const int inner_dim,
const string& format,
const T* bias,
const string& data_format,
const T* bias,
const T* bias_multiplier,
T* y);
......@@ -270,16 +270,19 @@ void SparseSoftmaxFocalLossGrad(const int count,
Tensor* ignore,
T* dx);
/******************** misc.memory_data ********************/
/******************** misc.image_data ********************/
template <typename Tx, typename Ty, class Context>
void MemoryData(const int count,
const int num,
const int channels,
const int height,
const int width,
const Tx* x,
Ty* y);
void ImageData(const int count,
const int N,
const int C,
const int H,
const int W,
const float* mean_values,
const float* std_values,
const string& data_format,
const Tx* x,
Ty* y);
/******************** ndarray.arange ********************/
......@@ -369,7 +372,8 @@ void Crop1D(const int count,
const int inner_dim,
const int start,
const T* x,
T* y);
T* y,
Context* context);
template <typename T, class Context>
void Crop1DGrad(const int count,
......@@ -379,7 +383,8 @@ void Crop1DGrad(const int count,
const int start,
const int end,
const T* dy,
T* dx);
T* dx,
Context* context);
/******************** ndarray.pad ********************/
......@@ -391,7 +396,8 @@ void ConstPad1D(const int count,
const int pad_l,
const float value,
const T* x,
T* y);
T* y,
Context* context);
template <typename T, class Context>
void ReflectPad1D(const int count,
......@@ -400,7 +406,8 @@ void ReflectPad1D(const int count,
const int inner_dim,
const int pad_l,
const T* x,
T* y);
T* y,
Context* context);
template <typename T, class Context>
void EdgePad1D(const int count,
......@@ -409,7 +416,8 @@ void EdgePad1D(const int count,
const int inner_dim,
const int pad_l,
const T* x,
T* y);
T* y,
Context* context);
template <typename T, class Context>
void ConstPad1DGrad(const int count,
......@@ -418,7 +426,8 @@ void ConstPad1DGrad(const int count,
const int inner_dim,
const int pad_l,
const T* dy,
T* dx);
T* dx,
Context* context);
template <typename T, class Context>
void ReflectPad1DGrad(const int count,
......@@ -436,7 +445,8 @@ void EdgePad1DGrad(const int count,
const int inner_dim,
const int pad_l,
const T* dy,
T* dx);
T* dx,
Context* context);
/******************** ndarray.one_hot ********************/
......@@ -612,154 +622,168 @@ void RMSPropUpdate(const int count,
/******************** vision.bilinear_resize ********************/
template <typename T, class Context>
void BilinearResize(const int count,
const int num,
const int channels,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
const T* x,
T* y);
void BilinearResize(const int count,
const int N,
const int C,
const int H,
const int W,
const int out_h,
const int out_w,
const string& data_format,
const T* x,
T* y);
template <typename T, class Context>
void BilinearResizeGrad(const int count,
const int num,
const int channels,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
const T* dy,
const int N,
const int C,
const int H,
const int W,
const int out_h,
const int out_w,
const string& data_format,
const T* dy,
T* dx);
/******************** vision.conv ********************/
template <typename T, class Context>
void Im2Col(const int channels,
const int height,
const int width,
const int kernel_h,
const int kernel_w,
const int stride_h,
const int stride_w,
const int pad_h,
const int pad_w,
const int dilation_h,
const int dilation_w,
const T* im,
T* col);
template <typename T, class Context>
void Col2Im(const int channels,
const int height,
const int width,
const int kernel_h,
const int kernel_w,
const int stride_h,
const int stride_w,
const int pad_h,
const int pad_w,
const int dilation_h,
const int dilation_w,
const T* col,
T* im);
void Im2Col2d(const int C,
const int H,
const int W,
const int col_h,
const int col_w,
const int kernel_h,
const int kernel_w,
const int stride_h,
const int stride_w,
const int pad_h,
const int pad_w,
const int dilation_h,
const int dilation_w,
const string& data_format,
const T* im,
T* col);
template <typename T, class Context>
void Col2Im2d(const int C,
const int H,
const int W,
const int col_h,
const int col_w,
const int kernel_h,
const int kernel_w,
const int stride_h,
const int stride_w,
const int pad_h,
const int pad_w,
const int dilation_h,
const int dilation_w,
const string& data_format,
const T* col,
T* im);
/******************** vision.nn_resize ********************/
template <typename T, class Context>
void NNResize(const int count,
const int num,
const int channels,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
void NNResize(const int count,
const int N,
const int C,
const int H,
const int W,
const int out_h,
const int out_w,
const string& data_format,
const T* x,
T* y);
template <typename T, class Context>
void NNResizeGrad(const int count,
const int num,
const int channels,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
void NNResizeGrad(const int count,
const int N,
const int C,
const int H,
const int W,
const int out_h,
const int out_w,
const string& data_format,
const T* dy,
T* dx);
/******************** vision.pooling ********************/
template <typename T, class Context>
void MAXPooling(const int count,
const int num,
const int channels,
const int height,
const int width,
const int pool_height,
const int pool_width,
const int kernel_h,
const int kernel_w,
const int stride_h,
const int stride_w,
const int pad_h,
const int pad_w,
const T* x,
int* mask,
T* y);
template <typename T, class Context>
void AVEPooling(const int count,
const int num,
const int channels,
const int height,
const int width,
const int pool_height,
const int pool_width,
const int kernel_h,
const int kernel_w,
const int stride_h,
const int stride_w,
const int pad_h,
const int pad_w,
const T* x,
T* y);
template <typename T, class Context>
void MAXPoolingGrad(const int count,
const int num,
const int channels,
const int height,
const int width,
const int pool_height,
const int pool_width,
const int kernel_h,
const int kernel_w,
const int stride_h,
const int stride_w,
const int pad_h,
const int pad_w,
const T* dy,
const int* mask,
T* dx);
template <typename T, class Context>
void AVEPoolingGrad(const int count,
const int num,
const int channels,
const int height,
const int width,
const int pool_height,
const int pool_width,
const int kernel_h,
const int kernel_w,
const int stride_h,
const int stride_w,
const int pad_h,
const int pad_w,
const T* dy,
T* dx);
void MAXPooling2d(const int count,
const int N,
const int C,
const int H,
const int W,
const int pool_h,
const int pool_w,
const int kernel_h,
const int kernel_w,
const int stride_h,
const int stride_w,
const int pad_h,
const int pad_w,
const string& data_format,
const T* x,
int* mask,
T* y);
template <typename T, class Context>
void AVGPooling2d(const int count,
const int N,
const int C,
const int H,
const int W,
const int pool_h,
const int pool_w,
const int kernel_h,
const int kernel_w,
const int stride_h,
const int stride_w,
const int pad_h,
const int pad_w,
const string& data_format,
const T* x,
T* y);
template <typename T, class Context>
void MAXPooling2dGrad(const int count,
const int N,
const int C,
const int H,
const int W,
const int pool_h,
const int pool_w,
const int kernel_h,
const int kernel_w,
const int stride_h,
const int stride_w,
const int pad_h,
const int pad_w,
const string& data_format,
const T* dy,
const int* mask,
T* dx);
template <typename T, class Context>
void AVGPooling2dGrad(const int count,
const int N,
const int C,
const int H,
const int W,
const int pool_h,
const int pool_w,
const int kernel_h,
const int kernel_w,
const int stride_h,
const int stride_w,
const int pad_h,
const int pad_w,
const string& data_format,
const T* dy,
T* dx);
/******************** vision.roi_pooling ********************/
......
......@@ -809,7 +809,7 @@ class Tensor(object):
if self.shape is not None:
output.shape = input_shape[:]
output.shape.insert(axis, 1L)
output.shape.insert(axis, np.long(1))
return output
......
......@@ -35,6 +35,7 @@ else:
argument.name = key
if type(value) is float: argument.f = value
elif type(value) is int: argument.i = value
elif type(value) is long: argument.i = value
elif type(value) is np.int64: argument.i64 = int(value)
elif type(value) is str: argument.s = value
elif type(value) is unicode: argument.s = value
......@@ -42,6 +43,7 @@ else:
elif isinstance(value, Message): argument.s = value.SerializeToString()
elif all(type(v) is float for v in value): argument.floats.extend(value)
elif all(type(v) is int for v in value): argument.ints.extend(value)
elif all(type(v) is long for v in value): argument.ints.extend(value)
elif all(type(v) is str for v in value): argument.strings.extend(value)
elif all(type(v) is unicode or type(v) is str for v in value):
argument.strings.extend(value)
......
......@@ -269,7 +269,6 @@ def FeedTensor(tensor, ndarray, force_cpu=False, dtype=None):
format(preset_dtype, dtype))
auto_dtype = preset_dtype
ndarray = np.array(ndarray, dtype=auto_dtype)
if hasattr(tensor, 'shape'): tensor.shape = list(ndarray.shape)
FeedTensorCC(name, ndarray, _stringify_proto(dev))
......
......@@ -11,7 +11,7 @@ Data
List Brief
============== ========================================================================
`LMDBData`_ Prefetch Image data with LMDB database.
`MemoryData`_ Perform ``NHWC <-> NCHW``, ``Mean Subtraction`` and ``Type Converting``.
`ImageData`_ Process the images from 4D raw data.
============== ========================================================================
Initializer
......@@ -185,7 +185,7 @@ List Brief
.. _LMDBData: operators/data.html#dragon.operators.data.LMDBData
.. _MemoryData: operators/data.html#dragon.operators.data.MemoryData
.. _ImageData: operators/data.html#dragon.operators.data.ImageData
.. _Fill: operators/initializer.html#dragon.operators.initializer.Fill
.. _RandomUniform: operators/initializer.html#dragon.operators.initializer.RandomUniform
......
......@@ -74,25 +74,39 @@ def LMDBData(**kwargs):
return Run([], param_str=str(kwargs), nout=2, **arguments)
def MemoryData(inputs, dtype=np.float32, **kwargs):
"""Perform ``NHWC <-> NCHW``, ``Mean Subtraction`` and ``Type Converting``.
def ImageData(inputs, mean_values=None, std_values=None,
dtype='FLOAT32', data_format='NCHW', **kwargs):
"""Process the images from 4D raw data.
Note that we assume the data format of raw data is **NHWC**.
Parameters
----------
inputs : Tensor
The input tensor, with type of uint8 or float32.
dtype : np.float32 or np.float16
The dtype of output tensor.
The input tensor, with type of **uint8** or **float32**.
mean_values : list of float or None
The optional mean values to subtract.
std_values : list of float or None
The optional std values to divide.
dtype : str
The type of output. ``FLOAT32`` or ``FLOAT16``.
data_format : str
The data format of output. ``NCHW`` or ``NHWC``.
Returns
-------
Tensor
The post-processing Tensor.
The output tensor.
"""
arguments = ParseArguments(locals())
if dtype is np.float32: arguments['dtype'] = 1
elif dtype is np.float16: arguments['dtype'] = 12
else: raise TypeError('Unsupported data type.')
return Tensor.CreateOperator(nout=1, op_type='MemoryData', **arguments)
\ No newline at end of file
if mean_values is not None:
if len(mean_values) != 3:
raise ValueError('The length of mean values should be 3.')
arguments['mean_values'] = [float(v) for v in mean_values]
if std_values is not None:
if len(std_values) != 3:
raise ValueError('The length of std values should be 3.')
arguments['std_values'] = [float(v) for v in std_values]
return Tensor.CreateOperator(nout=1, op_type='ImageData', **arguments)
\ No newline at end of file
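A minimal usage sketch of the new ``ImageData`` op (hedged: the tensor construction follows the docstring examples elsewhere in this diff; the mean/std numbers are illustrative, not taken from this commit):
>>> raw = Tensor().Variable()                     # 4D uint8/float32 data, assumed NHWC
>>> img = ImageData(raw, mean_values=[104.0, 116.7, 122.7],
...                 std_values=[1.0, 1.0, 1.0],
...                 dtype='FLOAT32', data_format='NCHW')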
......@@ -18,7 +18,7 @@ def SparseSoftmaxCrossEntropy(inputs, axis=1, normalization='VALID', ignore_labe
axis : int
The axis of softmax function.
normalization : str
The normalization, ``UINT``, ``FULL``, ``VALID``, ``BATCH_SIZE`` or ``NONE``.
The normalization, ``UNIT``, ``FULL``, ``VALID``, ``BATCH_SIZE`` or ``NONE``.
ignore_label : tuple or list
The label id to ignore. Default is ``empty``.
......@@ -29,7 +29,7 @@ def SparseSoftmaxCrossEntropy(inputs, axis=1, normalization='VALID', ignore_labe
Notes
-----
Set the normalization to ``UINT`` will return unreduced losses.
Set the normalization to ``UNIT`` will return unreduced losses.
"""
CheckInputs(inputs, 2)
......@@ -56,7 +56,7 @@ def SigmoidCrossEntropy(inputs, normalization='FULL', **kwargs):
inputs : list of Tensor
The inputs, represent [input, labels].
normalization : str
The normalization, ``UINT``, ``FULL``, ``BATCH_SIZE`` or ``NONE``.
The normalization, ``UNIT``, ``FULL``, ``BATCH_SIZE`` or ``NONE``.
Returns
-------
......@@ -65,7 +65,7 @@ def SigmoidCrossEntropy(inputs, normalization='FULL', **kwargs):
Notes
-----
Set the normalization to ``UINT`` will return unreduced losses.
Set the normalization to ``UNIT`` will return unreduced losses.
"""
CheckInputs(inputs, 2)
......@@ -90,7 +90,7 @@ def SoftmaxCrossEntropy(inputs, axis=1, normalization='FULL', **kwargs):
axis : int
The axis of softmax function.
normalization : str
The normalization, ``UINT``, ``FULL``, ``BATCH_SIZE`` or ``NONE``.
The normalization, ``UNIT``, ``FULL``, ``BATCH_SIZE`` or ``NONE``.
Returns
-------
......@@ -99,7 +99,7 @@ def SoftmaxCrossEntropy(inputs, axis=1, normalization='FULL', **kwargs):
Notes
-----
Set the normalization to ``UINT`` will return unreduced losses.
Set the normalization to ``UNIT`` will return unreduced losses.
"""
CheckInputs(inputs, 2)
......@@ -213,13 +213,13 @@ def SparseSoftmaxFocalLoss(inputs, axis=1, normalization='VALID', ignore_labels=
axis : int
The axis of softmax function.
normalization : str
The normalization, ``UINT``, ``FULL``, ``VALID``, ``BATCH_SIZE`` or ``NONE``.
The normalization, ``UNIT``, ``FULL``, ``VALID``, ``BATCH_SIZE`` or ``NONE``.
ignore_label : tuple or list
The label id to ignore. Default is ``empty``.
alpha : float
The scale factor on the rare class. Default is ``0.5``.
gamma : float
The exponetial decay factor on the easy examples. Default is ``2.0``.
The exponential decay factor on the easy examples. Default is ``2.0``.
eps : float
The eps.
neg_id : int
......@@ -232,7 +232,7 @@ def SparseSoftmaxFocalLoss(inputs, axis=1, normalization='VALID', ignore_labels=
Notes
-----
Set the normalization to ``UINT`` will return unreduced losses.
Set the normalization to ``UNIT`` will return unreduced losses.
"""
CheckInputs(inputs, 2)
......
......@@ -80,7 +80,7 @@ def MPIGather(inputs, root, mpi_ranks=None, **kwargs):
if mpi_ranks is None:
num_nodes = mpi.Size()
mpi_rank = [i for i in xrange(0, num_nodes)]
mpi_ranks = [i for i in xrange(0, num_nodes)]
if not isinstance(mpi_ranks, list): mpi_ranks = [mpi_ranks]
comm, group = mpi.CreateGroup(root, incl=mpi_ranks)
......
......@@ -9,8 +9,9 @@ from six.moves import range as xrange
from . import *
def Conv2D(inputs, num_output, kernel_size,
stride=1, pad=0, dilation=1, group=1, **kwargs):
def Conv2d(inputs, num_output, kernel_size,
stride=1, pad=0, dilation=1, group=1,
padding='VALID', data_format='NCHW', **kwargs):
"""2D Convolution.
The number of inputs varies from ``2`` to ``3`` (without or with ``bias``).
......@@ -19,6 +20,8 @@ def Conv2D(inputs, num_output, kernel_size,
|conv_output_dim|
Setting ``padding`` to **VALID** will use the value of ``pad``.
Parameters
----------
inputs : list of Tensor
......@@ -35,21 +38,25 @@ def Conv2D(inputs, num_output, kernel_size,
The dilation multiple(s) of convolution. Default is ``1``.
group : int
The group size of convolution. Default is ``1``.
padding : str
The padding algorithm. ``VALID`` or ``SAME``.
data_format : str
The data format. ``NCHW`` or ``NHWC``.
Returns
-------
Tensor
The tensor of 2d convolution.
The output tensor.
Examples
--------
>>> input = Tensor().Variable()
>>> weights = Tensor().Normal(std=0.001)
>>> biases = Tensor().Constant(value=0)
>>> conv1 = Conv2D([input, weights, biases], num_output=64, kernel_size=3)
>>> conv1 = Conv2d([input, weights, biases], num_output=64, kernel_size=3)
>>> weights = Tensor().Gaussian(std=0.001)
>>> conv2 = Conv2D([conv1, weights], num_output=128, kernel_size=3, stride=1)
>>> conv2 = Conv2d([conv1, weights], num_output=128, kernel_size=3, stride=1)
"""
CheckInputs(inputs, 2, 3)
......@@ -63,7 +70,7 @@ def Conv2D(inputs, num_output, kernel_size,
if not isinstance(arguments['dilation'], list):
arguments['dilation'] = [arguments['dilation']]
output = Tensor.CreateOperator(nout=1, op_type='Conv', **arguments)
output = Tensor.CreateOperator(nout=1, op_type='Conv2d', **arguments)
if inputs[0].shape is not None:
output.shape = inputs[0].shape[:]
......@@ -83,8 +90,9 @@ def Conv2D(inputs, num_output, kernel_size,
return output
def Deconv2D(inputs, num_output, kernel_size,
stride=1, pad=0, dilation=1, group=1, **kwargs):
def Conv2dTranspose(inputs, num_output, kernel_size,
stride=1, pad=0, dilation=1, group=1, output_shape=None,
padding='VALID', data_format='NCHW', **kwargs):
"""2D Deconvolution.
The number of inputs varies from ``2`` to ``3`` (without or with ``bias``).
......@@ -93,6 +101,10 @@ def Deconv2D(inputs, num_output, kernel_size,
|deconv_output_dim|
Setting ``padding`` to **VALID** will use the value of ``pad``.
Provide ``output_shape`` if using **SAME** padding.
Parameters
----------
inputs : list of Tensor
......@@ -109,26 +121,46 @@ def Deconv2D(inputs, num_output, kernel_size,
The dilation multiple(s) of deconvolution. Default is ``1``.
group : int
The group size of deconvolution. Default is ``1``.
output_shape : list of int or None
The deterministic output shape for **SAME** padding.
padding : str
The padding algorithm. ``VALID`` or ``SAME``.
data_format : str
The data format. ``NCHW`` or ``NHWC``.
Returns
-------
Tensor
The tensor of 2d deconvolution.
The output tensor.
Examples
--------
>>> input = Tensor().Variable()
>>> weights = Tensor().Normal(std=0.001)
>>> biases = Tensor().Constant(value=0)
>>> deconv1 = Deconv2D([input, weights, biases], num_output=64, kernel_size=3)
>>> deconv1 = Conv2dTranspose([input, weights, biases], num_output=64, kernel_size=3)
>>> weights = Tensor().Gaussian(std=0.001)
>>> deconv2 = Deconv2D([conv1, weights], num_output=128, kernel_size=3, stride=1)
>>> deconv2 = Conv2dTranspose([deconv1, weights], num_output=128, kernel_size=3, stride=1)
"""
CheckInputs(inputs, 2, 3)
arguments = ParseArguments(locals())
arguments['output_shape'] = None
if output_shape is not None:
if not isinstance(output_shape, list):
raise TypeError('The output shape should be a list.')
if isinstance(output_shape[0], Tensor):
arguments['dynamic_dsize'] = []
arguments['extra_inputs'] = list(output_shape)
for dim in output_shape:
arguments['dynamic_dsize'].append(dim)
else:
arguments['static_dsize'] = []
for dim in output_shape:
arguments['static_dsize'].append(int(dim))
if not isinstance(arguments['kernel_size'], list):
arguments['kernel_size'] = [arguments['kernel_size']]
......@@ -141,44 +173,48 @@ def Deconv2D(inputs, num_output, kernel_size,
if not isinstance(arguments['dilation'], list):
arguments['dilation'] = [arguments['dilation']]
return Tensor.CreateOperator(nout=1, op_type='DeConv', **arguments)
return Tensor.CreateOperator(nout=1, op_type='Conv2dTranspose', **arguments)
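A short sketch of how the two ``output_shape`` branches above are exercised (hedged: tensor construction mirrors the docstring examples, the ``ops.Shape`` indexing follows the BilinearResizeLayer change later in this diff, and the sizes are illustrative):
>>> x = Tensor().Variable()
>>> w = Tensor().Gaussian(std=0.001)
>>> # plain ints are routed into 'static_dsize'
>>> y1 = Conv2dTranspose([x, w], num_output=64, kernel_size=3,
...                      padding='SAME', output_shape=[1, 64, 32, 32])
>>> # Tensors are routed into 'dynamic_dsize' / 'extra_inputs'
>>> dims = ops.Shape(x)
>>> y2 = Conv2dTranspose([x, w], num_output=64, kernel_size=3,
...                      padding='SAME', output_shape=[dims[0], 64, dims[2], dims[3]])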
def Pool2D(inputs, kernel_size, stride, pad=0,
mode='MAX_POOLING', global_pooling=False, **kwargs):
def Pool2d(inputs, kernel_size, stride, pad=0, padding='VALID',
mode='MAX', data_format='NCHW', global_pooling=False, **kwargs):
"""2D Pooling, MAX or AVG.
The spatial output dimension of pooling can be computed as follows:
|pooling_output_dim|
Setting ``padding`` to **VALID** will use the value of ``pad``.
If ``global_pooling`` is used, the stride and pad will be set to ``1`` and ``0`` automatically.
Parameters
----------
inputs : Tensor
The tensor to down-sample.
The input tensor.
kernel_size : int or list
The kernel size(s) of pooling.
stride : int or list
The stride(s) of pooling.
pad : int or list
The zero padding size(s) of pooling. Default is ``0``.
padding : str
The padding algorithm. ``VALID`` or ``SAME``.
mode : str
The mode, ``MAX_POOLING`` or ``AVG_POOLING``.
The mode, ``MAX`` or ``AVG``.
data_format : str
The data format, ``NCHW`` or ``NHWC``.
global_pooling : boolean
Whether to use global pooling.
Returns
-------
Tensor
The down-sampled tensor.
The output tensor.
"""
CheckInputs(inputs, 1)
arguments = ParseArguments(locals())
SUPPORT_MODES = {'MAX_POOLING': 0, 'AVG_POOLING': 1}
arguments['mode'] = SUPPORT_MODES[mode]
if not isinstance(arguments['kernel_size'], list):
arguments['kernel_size'] = [arguments['kernel_size']]
if not isinstance(arguments['stride'], list):
......@@ -186,10 +222,11 @@ def Pool2D(inputs, kernel_size, stride, pad=0,
if not isinstance(arguments['pad'], list):
arguments['pad'] = [arguments['pad']]
output = Tensor.CreateOperator(nout=1, op_type='Pooling', **arguments)
output = Tensor.CreateOperator(nout=1, op_type='Pooling2d', **arguments)
if inputs.shape is not None:
output.shape = inputs.shape[:]
axis = 2 if data_format == 'NCHW' else 1
for i in xrange(2):
k = arguments['kernel_size'][i] if i < len(arguments['kernel_size']) \
else arguments['kernel_size'][-1]
......@@ -197,10 +234,17 @@ def Pool2D(inputs, kernel_size, stride, pad=0,
else arguments['stride'][-1]
p = arguments['pad'][i] if i < len(arguments['pad']) \
else arguments['pad'][-1]
if padding == 'SAME':
input_size = output.shape[i + axis]
output_size = (input_size + s - 1) / float(s)
padding_needed = max(0, (output_size - 1) * s + k - input_size)
p_l = padding_needed / 2
p_r = padding_needed - p_l
p = min(p_l, p_r)
if not global_pooling:
output.shape[i + 2] = int(math.ceil(float(output.shape[i + 2] + 2 * p - k) / s) + 1)
output.shape[i + axis] = int(math.ceil(float(output.shape[i + axis] + 2 * p - k) / s) + 1)
else:
output.shape[i + 2] = 1
output.shape[i + axis] = 1
return output
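A usage sketch for the renamed pooling op with the new string ``mode`` and the ``padding``/``data_format`` arguments (hedged: it mirrors the docstring examples and the sizes are illustrative). The shape inference above then adjusts ``output.shape`` along the spatial axes (axis 2 for ``NCHW``, axis 1 for ``NHWC``):
>>> x = Tensor().Variable()                       # e.g. an NCHW feature map
>>> p1 = Pool2d(x, kernel_size=3, stride=2, mode='MAX', padding='VALID')
>>> p2 = Pool2d(x, kernel_size=2, stride=2, mode='AVG',
...             padding='SAME', data_format='NCHW')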
......@@ -296,7 +340,7 @@ def LRN(inputs, local_size=5, alpha=0.0001, beta=0.75, k=2.0, mode='ACROSS_CHANN
return output
def NNResize(inputs, dsize, fy=-1.0, fx=-1.0, **kwargs):
def NNResize(inputs, dsize, fy=-1.0, fx=-1.0, data_format='NCHW', **kwargs):
"""Resize the image with Nearest-Neighbor method.
Set ``dsize`` to None if you want to use ``fy`` and ``fx``.
......@@ -306,16 +350,18 @@ def NNResize(inputs, dsize, fy=-1.0, fx=-1.0, **kwargs):
inputs : Tensor
The input tensor.
dsize : tuple, list, Tensor or None
The dest output size.
The output size.
fy : float
The scale factor based on src height. Default is ``-1.0`` (Discarded).
fx : float
The scale factor based on src width. Default is ``-1.0`` (Discarded).
data_format : str
The data_format. ``NCHW`` or ``NHWC``.
Returns
-------
Tensor
The resized tensor.
The output tensor.
"""
CheckInputs(inputs, 1)
......@@ -337,7 +383,7 @@ def NNResize(inputs, dsize, fy=-1.0, fx=-1.0, **kwargs):
return output
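A brief usage sketch for the resize op with the new ``data_format`` argument (hedged: the values are illustrative; ``BilinearResize`` below takes the same arguments):
>>> x = Tensor().Variable()
>>> y1 = NNResize(x, dsize=(64, 64), data_format='NCHW')              # fixed output size
>>> y2 = NNResize(x, dsize=None, fy=2.0, fx=2.0, data_format='NCHW')  # scale factors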
def BilinearResize(inputs, dsize, fy=-1.0, fx=-1.0, **kwargs):
def BilinearResize(inputs, dsize, fy=-1.0, fx=-1.0, data_format='NCHW', **kwargs):
"""Resize the image with Bi-linear method.
Set ``dsize`` to None if you want to use ``fy`` and ``fx``.
......@@ -352,11 +398,13 @@ def BilinearResize(inputs, dsize, fy=-1.0, fx=-1.0, **kwargs):
The scale factor based on src height. Default is ``-1.0`` (Discarded).
fx : float
The scale factor based on src width. Default is ``-1.0`` (Discarded).
data_format : str
The data_format. ``NCHW`` or ``NHWC``.
Returns
-------
Tensor
The resized tensor.
The output tensor.
"""
CheckInputs(inputs, 1)
......@@ -383,7 +431,7 @@ def BiasAdd(inputs, data_format='NCHW', **kwargs):
Parameters
----------
inputs : Tensor
inputs : list of Tensor
The inputs, represent [input, bias].
data_format : str
The data format, ``NCHW`` or ``NHWC``.
......@@ -394,7 +442,7 @@ def BiasAdd(inputs, data_format='NCHW', **kwargs):
The bias-added tensor.
"""
CheckInputs(inputs, 1)
CheckInputs(inputs, 2)
arguments = ParseArguments(locals())
output = Tensor.CreateOperator(nout=1, op_type='BiasAdd', **arguments)
......
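A one-line usage sketch for the corrected two-input ``BiasAdd`` (hedged: tensor construction mirrors the docstring examples):
>>> x = Tensor().Variable()
>>> b = Tensor().Constant(value=0)
>>> y = BiasAdd([x, b], data_format='NCHW')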
......@@ -20,7 +20,7 @@ from .operators import recurrent
# data
LMDBData = data.LMDBData
MemoryData = data.MemoryData
ImageData = data.ImageData
# init
Fill = init.Fill
......@@ -31,9 +31,10 @@ GlorotUniform = init.GlorotUniform
GlorotNormal = init.GlorotNormal
# vision
Conv2D = vision.Conv2D
Deconv2D = vision.Deconv2D
Pool2D = vision.Pool2D
Conv2d = vision.Conv2d
Conv2dTranspose = vision.Conv2dTranspose
Deconv2d = vision.Conv2dTranspose
Pool2d = vision.Pool2d
ROIPooling = vision.ROIPooling
ROIAlign = vision.ROIAlign
LRN = vision.LRN
......
......@@ -514,7 +514,7 @@ class NormalizeLayer(Layer):
scale = Tensor(LayerParameter.name + '@param0')
if param.HasField('scale_filler'):
self.Fill(scale, param, 'scale_filler')
else: scale.Contant(value=1.0)
else: scale.Constant(value=1.0)
self.scale_blobs = [{'data': scale, 'diff': Tensor(scale.name + '_grad')}]
self._blobs.extend(self.scale_blobs)
......
......@@ -48,22 +48,22 @@ class DataLayer(Layer):
super(DataLayer, self).__init__(LayerParameter)
param = LayerParameter.data_param
transformer_param = LayerParameter.transform_param
transform_param = LayerParameter.transform_param
parallel_param = LayerParameter.parallel_param
self._param = {'source': param.source,
'prefetch': param.prefetch,
'batch_size': param.batch_size,
'phase': {0: 'TRAIN', 1: 'TEST'}[int(LayerParameter.phase)],
'scale': transformer_param.scale,
'mirror': transformer_param.mirror,
'crop_size': transformer_param.crop_size,
'mean_values': [float(element) for element in transformer_param.mean_value],
'force_color': transformer_param.force_color,
'color_augmentation': transformer_param.color_augmentation,
'padding': transformer_param.padding,
'min_random_scale': transformer_param.min_random_scale,
'max_random_scale': transformer_param.max_random_scale,
'scale': transform_param.scale,
'mirror': transform_param.mirror,
'crop_size': transform_param.crop_size,
'mean_values': [float(element) for element in transform_param.mean_value],
'force_color': transform_param.force_color,
'color_augmentation': transform_param.color_augmentation,
'padding': transform_param.padding,
'min_random_scale': transform_param.min_random_scale,
'max_random_scale': transform_param.max_random_scale,
'shuffle': parallel_param.shuffle,
'node_step': parallel_param.node_step,
'partition': parallel_param.partition}
......@@ -76,20 +76,25 @@ class DataLayer(Layer):
class MemoryDataLayer(Layer):
"""The implementation of ``MemoryDataLayer``.
We extend it with ``FP16`` and ``NHWC <=> NCHW``.
We extend it with ``FP16`` and ``NHWC => NCHW``.
Parameters
----------
dtype : caffe_pb2.MemoryDataParameter.DataType
The dest data type. ``FLOAT32`` or ``FLOAT16``.
mean_value : list of float
The mean of each channel. Refer `TransformationParameter.mean_value`_.
"""
def __init__(self, LayerParameter):
super(MemoryDataLayer, self).__init__(LayerParameter)
param = LayerParameter.memory_data_param
import numpy as np
self._param = {'dtype': {0: np.float32, 1: np.float16}[param.dtype]}
transform_param = LayerParameter.transform_param
self._param = {'dtype': {0: 'FLOAT32', 1: 'FLOAT16'}[param.dtype]}
if len(transform_param.mean_value) > 0:
self._param['mean_values'] = \
[float(element) for element in transform_param.mean_value]
def Setup(self, bottom):
super(MemoryDataLayer, self).Setup(bottom)
return ops.MemoryData(bottom[0], **self._param)
\ No newline at end of file
return ops.ImageData(bottom[0], **self._param)
\ No newline at end of file
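A minimal sketch of what the rewritten layer now emits, assuming the ``dragon.ops`` import path: raw NHWC images go in, and ``ImageData`` folds the mean subtraction, the optional ``FLOAT16`` cast and the NHWC => NCHW transpose into one op.

import dragon.ops as ops                  # assumed import path
from dragon.core.tensor import Tensor     # assumed import path

raw = Tensor('raw_images')                # uint8 or float32, NHWC
data = ops.ImageData(raw,
                     dtype='FLOAT16',              # or 'FLOAT32'
                     mean_values=[104.0, 117.0, 123.0],
                     data_format='NCHW')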
......@@ -42,7 +42,9 @@ class ConvolutionLayer(Layer):
'stride': [int(element) for element in param.stride] if len(param.stride) > 0 else [1],
'pad': [int(element) for element in param.pad] if len(param.pad) > 0 else [0],
'dilation': [int(element) for element in param.dilation] if len(param.dilation) > 0 else [1],
'group': int(param.group)}
'group': int(param.group),
'padding': 'VALID',
'data_format': 'NCHW'}
if param.HasField('kernel_h'):
assert param.HasField('kernel_w')
self._param['kernel_size'] = [param.kernel_h, param.kernel_w]
......@@ -69,7 +71,7 @@ class ConvolutionLayer(Layer):
def Setup(self, bottom):
super(ConvolutionLayer, self).Setup(bottom)
return ops.Conv2D(bottom + [blob['data'] for blob in self._blobs], **self._param)
return ops.Conv2d(bottom + [blob['data'] for blob in self._blobs], **self._param)
class DeconvolutionLayer(ConvolutionLayer):
......@@ -102,7 +104,7 @@ class DeconvolutionLayer(ConvolutionLayer):
def Setup(self, bottom):
super(DeconvolutionLayer, self).Setup(bottom)
return ops.Deconv2D(bottom + [blob['data'] for blob in self._blobs], **self._param)
return ops.Deconv2d(bottom + [blob['data'] for blob in self._blobs], **self._param)
class PoolingLayer(Layer):
......@@ -135,7 +137,8 @@ class PoolingLayer(Layer):
def __init__(self, LayerParameter):
super(PoolingLayer, self).__init__(LayerParameter)
param = LayerParameter.pooling_param
self._param = {'mode': {0: 'MAX_POOLING', 1: 'AVG_POOLING'}[param.pool],
self._param = {'mode': {0: 'MAX', 1: 'AVG'}[param.pool],
'data_format': 'NCHW',
'global_pooling': param.global_pooling}
if not param.HasField('kernel_h'): self._param['kernel_size'] = [param.kernel_size]
......@@ -150,7 +153,7 @@ class PoolingLayer(Layer):
def Setup(self, bottom):
input = bottom[0] if isinstance(bottom, list) else bottom
super(PoolingLayer, self).Setup(bottom)
return ops.Pool2D(input, **self._param)
return ops.Pool2d(input, **self._param)
class ROIPoolingLayer(Layer):
......@@ -253,7 +256,8 @@ class NNResizeLayer(Layer):
if param.HasField('shape') else []
self._param = {'dsize': dsize,
'fx': float(param.fx),
'fy': float(param.fy)}
'fy': float(param.fy),
'data_format': 'NCHW'}
def Setup(self, bottom):
super(NNResizeLayer, self).Setup(bottom)
......@@ -284,7 +288,8 @@ class BilinearResizeLayer(Layer):
if param.HasField('shape') else []
self._param = {'dsize': dsize,
'fx': float(param.fx),
'fy': float(param.fy)}
'fy': float(param.fy),
'data_format': 'NCHW'}
def Setup(self, bottom):
super(BilinearResizeLayer, self).Setup(bottom)
......@@ -292,4 +297,4 @@ class BilinearResizeLayer(Layer):
if isinstance(bottom, list) and len(bottom) > 1:
dshape = ops.Shape(bottom[1])
self._param['dsize'] = (dshape[2], dshape[3])
return ops.BilinearResize(input, **self._param)
return ops.BilinearResize(input, **self._param)
\ No newline at end of file
......@@ -354,6 +354,25 @@ class Net(object):
return lambda net = self, net_outputs = self.outputs \
: GetOutputs(net, net_outputs)
def forward_v2(self, **kwargs):
"""Forward pass while silencing all net outputs.
Parameters
----------
inputs : dict or None
The blobs to feed before the forward pass.
Returns
-------
None
"""
if kwargs:
for name, blob in kwargs.items():
ws.FeedTensor(self._inputs_to_tensors[name], blob)
self.function()(return_outputs=False, stage='forward')
return None
def backward(self, **kwargs):
"""Backward pass. [**PyCaffe Style**]
......
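A minimal usage sketch of ``forward_v2`` above; the prototxt, the blob name ``data`` and the ``Net`` constructor arguments are assumptions.

import numpy as np
from dragon.vm.caffe import Net           # assumed import path

net = Net('deploy.prototxt', 'TEST')      # assumed constructor arguments
im = np.zeros((1, 3, 224, 224), dtype=np.float32)
net.forward_v2(data=im)                   # feeds 'data', runs forward, returns None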
......@@ -9,5 +9,4 @@ from .compile import (
scan,
shared)
from .configdefaults import config
import gradient
\ No newline at end of file
from .configdefaults import config
\ No newline at end of file
......@@ -17,6 +17,7 @@ from dragon.core.gradient_maker import GraphGradientMaker
from dragon.core.scope import GetOperatorName, GetTensorName
from dragon.core.tensor import Tensor
def GraphDef_Grad(meta_graph, targets):
"""Inject the gradient targets into GraphDef.
......@@ -67,7 +68,8 @@ def GraphDef_Phase(meta_graph, targets):
"""
phase = 'TEST'
from dragon.core.scope import _PHASE_SCOPE
if _PHASE_SCOPE != '': phase = _PHASE_SCOPE.upper()
if _PHASE_SCOPE != '':
phase = _PHASE_SCOPE.upper()
else:
for target in targets:
if len(target.grad_wrts) > 0:
......@@ -101,7 +103,7 @@ def GraphDef_Update(meta_graph, updater):
parallel_arguments = {}
# wrap hyper-parameters as Tensor for CC
for k,v in updater._hyper_params.items():
for k, v in updater._hyper_params.items():
ws.FeedTensor(updater._prefix + k, np.array([v], dtype=np.float32))
# check data parallel if necessary
......@@ -116,7 +118,8 @@ def GraphDef_Update(meta_graph, updater):
meta_graph.arg.add().CopyFrom(MakeArgument(k, v))
for tuple in updater._tuples:
tensors = tuple[0]; arguments = tuple[1]
tensors = tuple[0];
arguments = tuple[1]
kwargs = dict(arguments, **extra_arguments)
u_target = pb.UpdateTarget()
u_target.type = updater._type
......@@ -226,16 +229,21 @@ def function(inputs=None, outputs=None, givens=None, updater=None):
"""
if not isinstance(inputs, list):
if inputs is None: inputs = []
else: inputs = [inputs]
if inputs is None:
inputs = []
else:
inputs = [inputs]
if not isinstance(outputs, list):
if outputs is None: outputs = []
else: outputs = [outputs]
if outputs is None:
outputs = []
else:
outputs = [outputs]
if len(outputs) > 0 and updater is not None:
raise RuntimeError('You can specify either outputs or updater, not both.')
all_exprs = {}; all_extra_targets = set()
all_exprs = {};
all_extra_targets = set()
if not isinstance(outputs, list): outputs = [outputs]
meta_graph = pb.GraphDef()
......@@ -256,8 +264,8 @@ def function(inputs=None, outputs=None, givens=None, updater=None):
for extra_target in all_extra_targets: meta_graph.target.extend([extra_target])
# we should sort out the topology of these operators before using
all_exprs = sorted(all_exprs.items(), key=lambda d:d[0])
forward_ops = copy.deepcopy([v for k,v in all_exprs])
all_exprs = sorted(all_exprs.items(), key=lambda d: d[0])
forward_ops = copy.deepcopy([v for k, v in all_exprs])
# handle givens
if givens is not None:
......@@ -271,12 +279,13 @@ def function(inputs=None, outputs=None, givens=None, updater=None):
external_input_exprs = OrderedDict(external_input_exprs, **new_tensor.expressions)
else:
external_input_exprs = dict(external_input_exprs, **new_tensor.expressions)
elif isinstance(new_tensor, np.ndarray): ws.FeedTensor(new_tensor, GetTensorName())
external_input_ops = [v for k,v in external_input_exprs.items()]
elif isinstance(new_tensor, np.ndarray):
ws.FeedTensor(new_tensor, GetTensorName())
external_input_ops = [v for k, v in external_input_exprs.items()]
for op in forward_ops:
op.input.extend([name_dict[input] if input in name_dict
else input for input in op.input])
del op.input[:int(len(op.input)/2)]
else input for input in op.input])
del op.input[:int(len(op.input) / 2)]
forward_ops = external_input_ops + forward_ops
......@@ -285,7 +294,8 @@ def function(inputs=None, outputs=None, givens=None, updater=None):
targets = [output.name for output in outputs]
targets.extend(all_extra_targets)
forward_ops, grad_ops = GraphGradientMaker.Make(forward_ops, targets)
else: grad_ops = []
else:
grad_ops = []
meta_graph.op.extend(forward_ops + grad_ops)
if len(outputs) > 0:
......@@ -304,4 +314,36 @@ def function(inputs=None, outputs=None, givens=None, updater=None):
# return a lambda point to run this graph
return lambda *args, **kwargs: \
ws.RunGraph(meta_graph.name, (inputs, args), outputs, **kwargs)
\ No newline at end of file
ws.RunGraph(meta_graph.name, (inputs, args), outputs, **kwargs)
def eval(self, feed_dict=None):
if not hasattr(self, '_eval_func'):
if feed_dict is not None:
self._eval_func = function(inputs=feed_dict.keys(), outputs=self)
else:
self._eval_func = function(outputs=self)
# cond.1: run by feeding
if feed_dict is not None:
# checking
for key, value in feed_dict.items():
if not isinstance(key, Tensor):
raise TypeError('The key of feed_dict should be a Tensor.')
if key.shape is not None:
if len(key.shape) != len(value.shape):
raise RuntimeError('The Tensor({}) was limited to {} dimensions, \
but was fed a value with {} dimensions.'.
format(key.name, len(key.shape), len(value.shape)))
for i in xrange(len(key.shape)):
if key.shape[i] is None: continue
if key.shape[i] != value.shape[i]:
raise RuntimeError('The shape of Tensor({}) was limited to ('.format(key.name) +
','.join([str(dim) for dim in key.shape]) + '), ' +
'but was fed a value with (' + ','.join([str(dim) for dim in value.shape]) + ').')
return self._eval_func(*feed_dict.values())
else:
# cond.2: run without feeding
return self._eval_func()
Tensor.eval = eval
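A minimal sketch of the injected ``Tensor.eval``, assuming the usual import paths and an elementwise ``Add`` operator.

import numpy as np
import dragon.ops as ops                  # assumed import path
from dragon.core.tensor import Tensor     # assumed import path

a = Tensor('a')                           # shapes left unset, so the dim checks are skipped
b = Tensor('b')
c = ops.Add([a, b])                       # assumed elementwise operator

# cond.1: run by feeding both placeholders; the compiled function is cached.
print(c.eval(feed_dict={a: np.ones((2, 3), 'float32'),
                        b: np.full((2, 3), 2., 'float32')}))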
......@@ -37,7 +37,7 @@ def grad(cost, wrt, **kwargs):
if not isinstance(wrt, list): wrt = [wrt]
for w in wrt:
cost.grad_wrts.append(w.name)
w.grad_objs.append(cost.name)
w.grad_objs.append(cost)
w_grad = Tensor(w.name + '_grad')
w_grad.extra_targets.add(cost.name)
w_grad.expressions = cost.expressions
......
......@@ -34,7 +34,7 @@ void PReluOp<Context>::RunOnDevice() {
dim = input(0).count(2);
} else {
channels = input(0).dim(-1);
dim = input(0).count() / channels;
dim = input(0).count(1) / channels;
}
output(0)->ReshapeLike(input(0));
......@@ -95,7 +95,7 @@ void PReluGradientOp<Context>::RunOnDevice() {
dim = input(0).count(2);
} else {
channels = input(0).dim(-1);
dim = input(0).count() / channels;
dim = input(0).count(1) / channels;
}
output(0)->ReshapeLike(input(0));
......
......@@ -6,41 +6,35 @@
namespace dragon {
template <class Context> template <typename T>
void BiasAddOp<Context>::NCHWRunWithType() {
outer_dim = input(0).dim(0);
dim = input(0).dim(1);
inner_dim = input(0).count(2);
void BiasAddOp<Context>::RunWithType() {
TENSOR_FILL(input(1), vector<TIndex>(1, dim));
INIT_MULTIPLIER(bias_multiplier, inner_dim);
auto* Bdata = input(1).template data<T, Context>();
auto* BMul_data = bias_multiplier->template data<T, Context>();
auto* Ydata = output(0)->template mutable_data<T, Context>();
kernel::BiasAdd<T, Context>(output(0)->count(), outer_dim, input(1).count(),
inner_dim, data_format, Bdata, BMul_data, Ydata);
}
template <class Context> template <typename T>
void BiasAddOp<Context>::NHWCRunWithType() {
NOT_IMPLEMENTED;
kernel::BiasAdd<T, Context>(output(0)->count(), outer_dim, dim, inner_dim,
data_format,
Bdata,
BMul_data,
Ydata);
}
template <class Context>
void BiasAddOp<Context>::RunOnDevice() {
if (data_format == "NCHW") {
outer_dim = input(0).dim(0);
dim = input(0).dim(1);
inner_dim = input(0).count(2);
} else if (data_format == "NHWC") {
outer_dim = input(0).dim(0);
dim = input(0).dim(-1);
inner_dim = input(0).count(1) / dim;
} else LOG(FATAL) << "Unknown data format: " << data_format;
output(0)->ReshapeLike(input(0));
output(0)->Share(input(0));
if (data_format == "NCHW") {
if (input(0).template IsType<float>()) NCHWRunWithType<float>();
else LOG(FATAL) << "Unsupported input types.";
}
else if (data_format == "NHWC") {
if (input(0).template IsType<float>()) NHWCRunWithType<float>();
else LOG(FATAL) << "Unsupported input types.";
}
else {
LOG(FATAL) << "Unknown data format: " << data_format;
}
if (input(0).template IsType<float>()) RunWithType<float>();
else LOG(FATAL) << "Unsupported input types.";
}
DEPLOY_CPU(BiasAdd);
......@@ -50,49 +44,52 @@ DEPLOY_CUDA(BiasAdd);
OPERATOR_SCHEMA(BiasAdd).NumInputs(2).NumOutputs(1);
template <class Context> template <typename T>
void BiasAddGradientOp<Context>::NCHWRunWithType() {
void BiasAddGradientOp<Context>::RunWithType() {
if (output(1)->name() != "ignore") {
outer_dim = input(0).dim(0);
dim = input(0).dim(1);
inner_dim = input(0).count(2);
output(1)->ReshapeLike(input(1));
INIT_MULTIPLIER(bias_multiplier, inner_dim);
auto* BMul_data = this->bias_multiplier->template data<T, Context>();
auto* dYdata = input(-1).template data<T, Context>();
auto* dBias = output(1)->template mutable_data<T, Context>();
const int y_offset = dim * inner_dim;
for (int n = 0; n < outer_dim; n++) {
math::Gemv<T, Context>(CblasNoTrans, dim, inner_dim,
1.0, dYdata, BMul_data, 1.0, dBias);
if (data_format == "NCHW") {
math::Gemv<T, Context>(CblasNoTrans, dim, inner_dim,
1.0,
dYdata, BMul_data,
1.0,
dBias);
} else if (data_format == "NHWC") {
math::Gemv<T, Context>(CblasTrans, inner_dim, dim,
1.0,
dYdata, BMul_data,
1.0,
dBias);
}
dYdata += y_offset;
}
}
}
template <class Context> template <typename T>
void BiasAddGradientOp<Context>::NHWCRunWithType() {
NOT_IMPLEMENTED;
if (output(0)->name() != "ignore") {
output(0)->ReshapeLike(input(-1));
output(0)->Share(input(-1));
}
}
template <class Context>
void BiasAddGradientOp<Context>::RunOnDevice() {
if (data_format == "NCHW") {
if (input(0).template IsType<float>()) NCHWRunWithType<float>();
else LOG(FATAL) << "Unsupported input types.";
}
else if (data_format == "NHWC") {
if (input(0).template IsType<float>()) NHWCRunWithType<float>();
else LOG(FATAL) << "Unsupported input types.";
}
else {
LOG(FATAL) << "Unknown data format: " << data_format;
}
outer_dim = input(0).dim(0);
dim = input(0).dim(1);
inner_dim = input(0).count(2);
} else if (data_format == "NHWC") {
outer_dim = input(0).dim(0);
dim = input(0).dim(-1);
inner_dim = input(0).count(1) / dim;
} else LOG(FATAL) << "Unknown data format: " << data_format;
output(1)->ReshapeLike(input(1));
if (output(0)->name() != "ignore") {
output(0)->ReshapeLike(input(-1));
output(0)->Share(input(-1));
}
if (input(0).template IsType<float>()) RunWithType<float>();
else LOG(FATAL) << "Unsupported input types.";
}
DEPLOY_CPU(BiasAddGradient);
......
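A NumPy sketch of the dimension bookkeeping that the unified ``RunWithType`` above relies on; purely illustrative.

import numpy as np

def bias_add_dims(shape, data_format):
    # outer_dim: batch, dim: channels, inner_dim: remaining spatial size
    if data_format == 'NCHW':
        return shape[0], shape[1], int(np.prod(shape[2:]))
    elif data_format == 'NHWC':
        return shape[0], shape[-1], int(np.prod(shape[1:-1]))
    raise ValueError('Unknown data format: ' + data_format)

dY = np.random.randn(2, 3, 4, 5).astype('float32')           # NCHW
outer_dim, dim, inner_dim = bias_add_dims(dY.shape, 'NCHW')
# dBias reduces over batch and spatial axes, matching the Gemv loop above.
dBias = dY.reshape(outer_dim, dim, inner_dim).sum(axis=(0, 2))
assert dBias.shape == (dim,)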
#include "operators/cast/float2half_op.h"
#include "core/workspace.h"
#include "utils/op_kernel.h"
namespace dragon {
#ifdef WITH_CUDA_FP16
template <class Context>
void FloatToHalfOp<Context>::RunOnDevice() {
CHECK(input(0).template IsType<float>())
<< "The type of tensor should be float32.";
output(0)->ReshapeLike(input(0));
// cast
auto* Xdata = input(0).template data<float, Context>();
auto* Ydata = output(0)->template mutable_data<float16, Context>();
kernel::Float2Half<float, Context>(output(0)->count(), Xdata, Ydata);
// release & share
input(0).Reset();
input(0).ReshapeLike(*output(0));
input(0).Share(*output(0));
}
#ifdef WITH_CUDA
DEPLOY_CUDA(FloatToHalf);
#endif
OPERATOR_SCHEMA(FloatToHalf).NumInputs(1).NumOutputs(1);
NO_GRADIENT(FloatToHalf);
#endif
} // namespace dragon
\ No newline at end of file
......@@ -17,11 +17,11 @@ void SoftmaxCrossEntropyOp<Context>::RunWithType() {
if (normalization == "UNIT") {
output(0)->Reshape(vector<TIndex>(1, outer_dim * inner_dim));
auto* Ydata = output(0)->template mutable_data<T, Context>();
kernel::Sum<T, Context>(losses.count(),
input(0).dim(axis),
inner_dim,
Ldata,
Ydata);
kernel::Sum<T, Context>(outer_dim * inner_dim,
input(0).dim(axis),
inner_dim,
Ldata,
Ydata);
return;
}
......@@ -65,12 +65,12 @@ void SoftmaxCrossEntropyGradientOp<Context>::RunWithType() {
if (normalization == "UNIT") {
auto* dYdata = input(-1).template data<T, Context>();
kernel::SumGrad<T, Context>(input(0).count() / input(0).dim(axis),
input(0).dim(axis),
inner_dim,
1.0,
dYdata,
Pdata);
kernel::SumGrad<T, Context>(outer_dim * inner_dim,
input(0).dim(axis),
inner_dim,
1.0,
dYdata,
Pdata);
math::Mul<T, Context>(output(0)->count(), Pdata, dXdata, dXdata);
return;
}
......
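A NumPy sketch of the ``UNIT`` normalization path above: the per-element losses are summed over the class axis only, leaving one value per (outer, inner) position, which is why the new count is ``outer_dim * inner_dim``.

import numpy as np

outer_dim, num_classes, inner_dim = 2, 10, 4
losses = np.random.rand(outer_dim, num_classes, inner_dim).astype('float32')
unit_loss = losses.sum(axis=1)                    # shape: (outer_dim, inner_dim)
assert unit_loss.size == outer_dim * inner_dim    # the count passed to kernel::Sum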
#include "operators/misc/image_data_op.h"
#include "utils/op_kernel.h"
namespace dragon {
template <class Context> template <typename Tx, typename Ty>
void ImageDataOp<Context>::RunWithType() {
auto* Xdata = input(0).template data<Tx, Context>();
auto* Mdata = mean.count() > 0 ? mean.template data<float, Context>() : nullptr;
auto* Sdata = std.count() > 0 ? std.template data<float, Context>() : nullptr;
auto* Ydata = output(0)->template mutable_data<Ty, Context>();
kernel::ImageData<Tx, Ty, Context>(output(0)->count(),
n, c, h, w,
Mdata, Sdata,
data_format,
Xdata,
Ydata);
}
template <class Context>
void ImageDataOp<Context>::RunOnDevice() {
n = input(0).dim(0);
c = input(0).dim(3);
h = input(0).dim(1);
w = input(0).dim(2);
if (data_format == "NCHW") {
output(0)->Reshape(vector<TIndex>({ n, c, h, w }));
} else if (data_format == "NHWC") {
output(0)->ReshapeLike(input(0));
} else LOG(FATAL) << "Unknown data format: " << data_format;
if (input(0).template IsType<float>()) {
if (dtype == "FLOAT32") RunWithType<float, float>();
#ifdef WITH_CUDA_FP16
else if (dtype == "FLOAT16") RunWithType<float, float16>();
#endif
else LOG(FATAL) << "Unsupported output type: " << dtype;
} else if (input(0).template IsType<uint8_t>()) {
if (dtype == "FLOAT32") RunWithType<uint8_t, float>();
#ifdef WITH_CUDA_FP16
else if (dtype == "FLOAT16") RunWithType<uint8_t, float16>();
#endif
else LOG(FATAL) << "Unsupported output type: " << dtype;
}
else { LOG(FATAL) << "Unsupported input types."; }
}
DEPLOY_CPU(ImageData);
#ifdef WITH_CUDA
DEPLOY_CUDA(ImageData);
#endif
OPERATOR_SCHEMA(ImageData).NumInputs(1).NumOutputs(1);
NO_GRADIENT(ImageData);
} // namespace dragon
\ No newline at end of file
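A NumPy reference for the ``ImageData`` op above: NHWC input, optional mean/std normalization, optional transpose to NCHW, optional cast to half precision; purely illustrative.

import numpy as np

def image_data(x, mean=None, std=None, data_format='NCHW', dtype='FLOAT32'):
    y = x.astype('float32')
    if mean is not None: y -= np.asarray(mean, 'float32')
    if std is not None: y /= np.asarray(std, 'float32')
    if data_format == 'NCHW':
        y = y.transpose(0, 3, 1, 2)               # NHWC -> NCHW
    return y.astype(np.float16 if dtype == 'FLOAT16' else np.float32)

batch = (np.random.rand(2, 32, 32, 3) * 255).astype('uint8')
out = image_data(batch, mean=[104., 117., 123.], data_format='NCHW')
assert out.shape == (2, 3, 32, 32)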
#include "operators/misc/memory_data_op.h"
#include "utils/op_kernel.h"
namespace dragon {
template <class Context> template <typename Tx, typename Ty>
void MemoryDataOp<Context>::RunWithType() {
auto* Xdata = input(0).template data<Tx, Context>();
auto* Ydata = output(0)->template mutable_data<Ty, Context>();
kernel::MemoryData<Tx, Ty, Context>(output(0)->count(),
output(0)->dim(0),
output(0)->dim(1),
output(0)->dim(2),
output(0)->dim(3),
Xdata, Ydata);
}
template <class Context>
void MemoryDataOp<Context>::RunOnDevice() {
vector<TIndex> dims({ input(0).dim(0), input(0).dim(3),
input(0).dim(1), input(0).dim(2) });
output(0)->Reshape(dims);
if (input(0).template IsType<float>()) {
if (data_type == TensorProto_DataType_FLOAT) RunWithType<float, float>();
#ifdef WITH_CUDA_FP16
else if (data_type == TensorProto_DataType_FLOAT16) RunWithType<float, float16>();
#endif
else LOG(FATAL) << "Unsupported input types.";
}
else if (input(0).template IsType<uint8_t>()) {
if (data_type == TensorProto_DataType_FLOAT) RunWithType<uint8_t, float>();
#ifdef WITH_CUDA_FP16
else if (data_type == TensorProto_DataType_FLOAT16) RunWithType<uint8_t, float16>();
#endif
}
else { LOG(FATAL) << "Unsupported input types."; }
}
DEPLOY_CPU(MemoryData);
#ifdef WITH_CUDA
DEPLOY_CUDA(MemoryData);
#endif
OPERATOR_SCHEMA(MemoryData).NumInputs(1).NumOutputs(1);
NO_GRADIENT(MemoryData);
} // namespace dragon
\ No newline at end of file
......@@ -13,7 +13,7 @@ void MPIBroadcastOp<Context>::RunWithType() {
#else
auto* Xdata = input(0).template mutable_data<T, CPUContext>();
#endif
MPI_Bcast(Xdata, input(0).count(), MPI_FLOAT, this->comm_root, this->comm);
MPI_Bcast(Xdata, input(0).count(), mpi_dtype(), this->comm_root, this->comm);
output(0)->Share(input(0));
} else {
#ifdef WITH_MPI_CUDA
......@@ -21,7 +21,7 @@ void MPIBroadcastOp<Context>::RunWithType() {
#else
auto* Ydata = output(0)->template mutable_data<T, CPUContext>();
#endif
MPI_Bcast(Ydata, output(0)->count(), MPI_FLOAT, this->comm_root, this->comm);
MPI_Bcast(Ydata, output(0)->count(), mpi_dtype(), this->comm_root, this->comm);
}
}
......@@ -41,13 +41,13 @@ void MPIBroadcastOp<Context>::RunOnDevice() {
}
MPI_Bcast(ndim, 1, MPI_UNSIGNED_LONG_LONG, this->comm_root, this->comm);
if (dims == nullptr) dims = new TIndex[ndim[0]];
MPI_Bcast(dims, 4, MPI_LONG_LONG, this->comm_root, this->comm);
MPI_Bcast(dims, (int)ndim[0], MPI_LONG_LONG, this->comm_root, this->comm);
vector<TIndex> _dims;
for (int i = 0; i < ndim[0]; i++) _dims.push_back(dims[i]);
for (int i = 0; i < (int)ndim[0]; i++) _dims.push_back(dims[i]);
output(0)->Reshape(_dims);
if (input(0).template IsType<float>()) RunWithType<float>();
else LOG(FATAL) << "Unsupported input types.";
if (this->dtype == "FLOAT32") RunWithType<float>();
else LOG(FATAL) << "Unsupported input type: " << this->dtype;
}
DEPLOY_CPU(MPIBroadcast);
......@@ -71,7 +71,7 @@ void MPIBroadcastGradientOp<Context>::RunWithType() {
#endif
for (int i = 0; i < this->comm_size; i++) {
if (i == this->comm_root) continue;
MPI_Recv(dYdata, output(0)->count(), MPI_FLOAT, i, 0, this->comm, MPI_STATUS_IGNORE);
MPI_Recv(dYdata, output(0)->count(), mpi_dtype(), i, 0, this->comm, MPI_STATUS_IGNORE);
#ifdef WITH_MPI_CUDA
math::Add<T, Context>(output(0)->count(), dYdata, dXdata, dXdata);
#else
......@@ -85,7 +85,7 @@ void MPIBroadcastGradientOp<Context>::RunWithType() {
#else
auto* dYdata = input(-1).template data<T, CPUContext>();
#endif
MPI_Send(dYdata, input(-1).count(), MPI_FLOAT, this->comm_root, 0, this->comm);
MPI_Send(dYdata, input(-1).count(), mpi_dtype(), this->comm_root, 0, this->comm);
}
}
......@@ -93,10 +93,10 @@ template <class Context>
void MPIBroadcastGradientOp<Context>::RunOnDevice() {
output(0)->ReshapeLike(input(-1));
if (input(0).template IsType<float>()) RunWithType<float>();
else LOG(FATAL) << "Unsupported input types.";
if (this->dtype == "FLOAT32") RunWithType<float>();
else LOG(FATAL) << "Unsupported input type: " << this->dtype;
}
DEPLOY_CPU(MPIBroadcastGradient);
#ifdef WITH_CUDA
DEPLOY_CUDA(MPIBroadcastGradient);
......
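An mpi4py sketch of the broadcast protocol above, shown only to make the ordering explicit: the root sends ndim first, then all of the dims (no longer a hard-coded 4), then the payload, so the other ranks can allocate before receiving.

import numpy as np
from mpi4py import MPI

comm, root = MPI.COMM_WORLD, 0
if comm.Get_rank() == root:
    tensor = np.random.rand(2, 3, 4).astype('float32')
    dims = np.array(tensor.shape, dtype='int64')
    ndim = np.array([dims.size], dtype='uint64')
else:
    tensor, dims, ndim = None, None, np.empty(1, dtype='uint64')

comm.Bcast(ndim, root=root)                               # step 1: ndim
if comm.Get_rank() != root:
    dims = np.empty(int(ndim[0]), dtype='int64')
comm.Bcast(dims, root=root)                               # step 2: every dim
if comm.Get_rank() != root:
    tensor = np.empty(tuple(dims), dtype='float32')
comm.Bcast(tensor, root=root)                             # step 3: payload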
......@@ -16,29 +16,50 @@ void MPIGatherOp<Context>::RunWithType() {
#else
auto* Ydata = output(i)->template mutable_data<T, CPUContext>();
#endif
MPI_Recv(Ydata, output(i)->count(), MPI_FLOAT, i, 0, this->comm, MPI_STATUS_IGNORE);
MPI_Recv(Ydata, output(i)->count(), mpi_dtype(), i, 0, this->comm, MPI_STATUS_IGNORE);
}
}
else{
else {
#ifdef WITH_MPI_CUDA
auto* Xdata = input(0).template data<T, Context>();
#else
auto* Xdata = input(0).template data<T, CPUContext>();
#endif
MPI_Send(Xdata, input(0).count(), MPI_FLOAT, this->comm_root, 0, this->comm);
MPI_Send(Xdata, input(0).count(), mpi_dtype(), this->comm_root, 0, this->comm);
}
}
template <class Context>
void MPIGatherOp<Context>::RunOnDevice() {
if (this->comm_rank == this->comm_root) {
CHECK_EQ(this->comm_size, OutputSize());
for (int i = 0; i < OutputSize(); i++)
output(i)->ReshapeLike(input(0));
CHECK_EQ(this->comm_size, OutputSize());
// reshape from root
if (this->comm_rank == this->comm_root)
output(0)->ReshapeLike(input(0));
// reshape from others
size_t* all_ndim = new size_t[this->comm_size];
size_t ndim[1];
if (this->comm_rank != this->comm_root) {
ndim[0] = input(0).ndim();
MPI_Send(ndim, 1, MPI_UNSIGNED_LONG_LONG, this->comm_root, 0, this->comm);
} else {
for (int i = 1; i < this->comm_size; i++)
MPI_Recv(all_ndim + i, 1, MPI_UNSIGNED_LONG_LONG, i, 0, this->comm, MPI_STATUS_IGNORE);
}
if (this->comm_rank != this->comm_root) {
MPI_Send(input(0).dims().data(), (int)ndim[0], MPI_LONG_LONG, this->comm_root, 0, this->comm);
} else {
for (int i = 1; i < this->comm_size; i++) {
TIndex* dims = new TIndex[all_ndim[i]];
MPI_Recv(dims, (int)all_ndim[i], MPI_LONG_LONG, i, 0, this->comm, MPI_STATUS_IGNORE);
vector<TIndex> dims_;
for (int j = 0; j < (int)all_ndim[i]; j++) dims_.push_back(dims[j]);
output(i)->Reshape(dims_);
}
}
if (input(0).template IsType<float>()) RunWithType<float>();
else LOG(FATAL) << "Unsupported input types.";
if (this->dtype == "FLOAT32") RunWithType<float>();
else LOG(FATAL) << "Unsupported input type: " << this->dtype;
}
DEPLOY_CPU(MPIGather);
......@@ -58,7 +79,7 @@ void MPIGatherGradientOp<Context>::RunWithType() {
#else
auto* dYdata = input(this->comm_rank + 1).template data<T, CPUContext>();
#endif
MPI_Send(dYdata, input(this->comm_rank + 1).count(), MPI_FLOAT, i, 0, this->comm);
MPI_Send(dYdata, input(this->comm_rank + 1).count(), mpi_dtype(), i, 0, this->comm);
}
}
else{
......@@ -67,7 +88,7 @@ void MPIGatherGradientOp<Context>::RunWithType() {
#else
auto* dXdata = output(0)->template mutable_data<T, CPUContext>();
#endif
MPI_Recv(dXdata, output(0)->count(), MPI_FLOAT, this->comm_root, 0, this->comm, MPI_STATUS_IGNORE);
MPI_Recv(dXdata, output(0)->count(), mpi_dtype(), this->comm_root, 0, this->comm, MPI_STATUS_IGNORE);
}
}
......@@ -75,8 +96,8 @@ template <class Context>
void MPIGatherGradientOp<Context>::RunOnDevice() {
output(0)->ReshapeLike(input(0));
if (input(0).template IsType<float>()) RunWithType<float>();
else LOG(FATAL) << "Unsupported input types.";
if (this->dtype == "FLOAT32") RunWithType<float>();
else LOG(FATAL) << "Unsupported input type: " << this->dtype;
}
DEPLOY_CPU(MPIGatherGradient);
......
......@@ -15,7 +15,8 @@ void CropOp<Context>::RunWithType() {
inner_dim,
starts[axis],
Xdata,
Ydata);
Ydata,
&ctx());
}
template <class Context>
......@@ -219,7 +220,6 @@ template <class Context> template <typename T>
void CropGradientOp<Context>::RunWithType() {
auto* dYdata = source->template data<T, Context>();
auto* dXdata = dest->template mutable_data<T, Context>();
math::Set<T, Context>(dest->count(), 0, dXdata);
kernel::Crop1DGrad<T, Context>(dest->count(),
input(0).dim(axis),
dim,
......@@ -227,7 +227,8 @@ void CropGradientOp<Context>::RunWithType() {
starts[axis],
ends[axis],
dYdata,
dXdata);
dXdata,
&ctx());
}
template <class Context>
......
......@@ -16,7 +16,8 @@ void PadOp<Context>::ConstRunWithType() {
pad_l[axis],
value,
Xdata,
Ydata);
Ydata,
&ctx());
}
template <class Context> template <typename T>
......@@ -29,7 +30,8 @@ void PadOp<Context>::ReflectRunWithType() {
inner_dim,
pad_l[axis],
Xdata,
Ydata);
Ydata,
&ctx());
}
template <class Context> template <typename T>
......@@ -42,7 +44,8 @@ void PadOp<Context>::EdgeRunWithType() {
inner_dim,
pad_l[axis],
Xdata,
Ydata);
Ydata,
&ctx());
}
template <class Context>
......@@ -109,14 +112,14 @@ template <class Context> template <typename T>
void PadGradientOp<Context>::ConstRunWithType() {
auto* dYdata = source->template data<T, Context>();
auto* dXdata = dest->template mutable_data<T, Context>();
math::Set<T, Context>(dest->count(), 0, dXdata);
kernel::ConstPad1DGrad<T, Context>(dest->count(),
dim - pad_l[axis] - pad_r[axis],
dim,
inner_dim,
pad_l[axis],
dYdata,
dXdata);
dXdata,
&ctx());
}
template <class Context> template <typename T>
......@@ -144,7 +147,8 @@ void PadGradientOp<Context>::EdgeRunWithType() {
inner_dim,
pad_l[axis],
dYdata,
dXdata);
dXdata,
&ctx());
}
template <class Context>
......
......@@ -7,13 +7,28 @@ namespace dragon {
template <class Context> template <typename T>
void BilinearResizeOp<Context>::RunWithType() {
if (data_format == "NCHW") {
n = dims[0];
c = dims[1];
h = input(0).dim(2);
w = input(0).dim(3);
out_h = dims[2];
out_w = dims[3];
} else if (data_format == "NHWC") {
n = dims[0];
h = input(0).dim(1);
w = input(0).dim(2);
out_h = dims[1];
out_w = dims[2];
c = dims[3];
}
auto* Xdata = input(0).template data<T, Context>();
auto* Ydata = output(0)->template mutable_data<T, Context>();
kernel::BilinearResize<T, Context>(output(0)->count(), dims[0], dims[1],
input(0).dim(2), input(0).dim(3),
dims[2], dims[3],
Xdata,
Ydata);
kernel::BilinearResize<T, Context>(output(0)->count(), n, c, h, w,
out_h, out_w,
data_format,
Xdata,
Ydata);
}
template <class Context>
......@@ -25,9 +40,9 @@ void BilinearResizeOp<Context>::RunOnDevice() {
for (int i = 0; i < 2; i++) {
Tensor* t = ws()->GetTensor(dynamic_dsize[i]);
if (t->IsType<int>()) {
dims[2 + i] = t->template data<int, CPUContext>()[0];
dims[spatial_axis + i] = t->template data<int, CPUContext>()[0];
} else if (t->IsType<float>()) {
dims[2 + i] = t->template data<float, CPUContext>()[0];
dims[spatial_axis + i] = t->template data<float, CPUContext>()[0];
} else {
LOG(FATAL) << "Unsupported types of dsize.";
}
......@@ -35,12 +50,12 @@ void BilinearResizeOp<Context>::RunOnDevice() {
} else if (static_dsize.size() > 0) {
CHECK_EQ(static_dsize.size(), 2)
<< "\nThe dsize should be a scalar with 2 elements.";
for (int i = 0; i < 2; i++) dims[2 + i] = static_dsize[i];
for (int i = 0; i < 2; i++) dims[spatial_axis + i] = static_dsize[i];
} else {
CHECK(fy != -1.0 && fx != -1.0)
<< "\nThe fx and fy should be set.";
dims[2] = int(dims[2] * fy);
dims[3] = int(dims[3] * fx);
dims[spatial_axis] = int(dims[spatial_axis] * fy);
dims[spatial_axis + 1] = int(dims[spatial_axis + 1] * fx);
}
output(0)->Reshape(dims);
......@@ -56,14 +71,28 @@ OPERATOR_SCHEMA(BilinearResize).NumInputs(1).NumOutputs(1);
template <class Context> template <typename T>
void BilinearResizeGradientOp<Context>::RunWithType() {
if (data_format == "NCHW") {
n = input(0).dim(0);
c = input(0).dim(1);
h = input(0).dim(2);
w = input(0).dim(3);
out_h = input(-1).dim(2);
out_w = input(-1).dim(3);
} else if (data_format == "NHWC") {
n = input(0).dim(0);
h = input(0).dim(1);
w = input(0).dim(2);
c = input(0).dim(3);
out_h = input(-1).dim(1);
out_w = input(-1).dim(2);
}
auto* dYdata = input(-1).template data<T, Context>();
auto* dXdata = output(0)->template mutable_data<T, Context>();
math::Set<T, Context>(output(0)->count(), 0, dXdata);
kernel::BilinearResizeGrad<T, Context>(input(-1).count(), input(0).dim(0), input(0).dim(1),
input(-1).dim(2), input(-1).dim(3),
output(0)->dim(2), output(0)->dim(3),
dYdata,
dXdata);
kernel::BilinearResizeGrad<T, Context>(input(-1).count(), n, c, h, w,
out_h, out_w,
data_format,
dYdata,
dXdata);
}
template <class Context>
......
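A small sketch of the resize bookkeeping above: where the two spatial axes live depends on ``data_format``, and either ``dsize`` or ``(fy, fx)`` fills them in.

def resize_dims(in_dims, data_format, dsize=None, fy=-1.0, fx=-1.0):
    dims = list(in_dims)
    spatial_axis = 2 if data_format == 'NCHW' else 1
    if dsize is not None:
        dims[spatial_axis], dims[spatial_axis + 1] = dsize
    else:
        assert fy > 0 and fx > 0, 'The fx and fy should be set.'
        dims[spatial_axis] = int(dims[spatial_axis] * fy)
        dims[spatial_axis + 1] = int(dims[spatial_axis + 1] * fx)
    return dims

assert resize_dims([1, 3, 32, 32], 'NCHW', dsize=(64, 64)) == [1, 3, 64, 64]
assert resize_dims([1, 32, 32, 3], 'NHWC', fy=2.0, fx=2.0) == [1, 64, 64, 3]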
......@@ -4,35 +4,24 @@
namespace dragon {
template <class Context>
void ConvOp<Context>::ComputeOutputShape() {
this->output_shape.clear();
for (int i = 0; i < this->num_spatial_axes; i++) {
const int input_dim = this->bottom_shape[this->channel_axis + i + 1];
const int dilated_kernel = this->dilation[i] * (this->kernel_size[i] - 1) + 1;
const int output_dim = (input_dim + 2 * this->pad[i] - dilated_kernel) / this->stride[i] + 1;
this->output_shape.push_back(output_dim);
}
}
template <class Context> template <typename T>
void ConvOp<Context>::RunWithType() {
void Conv2dOp<Context>::RunWithType() {
// get buffer
this->col_buffer = ws()->GetBuffer();
this->col_buffer->Reshape(this->col_buffer_shape);
this->col_buffer->Reshape(this->col_shape);
auto* Xdata = input(0).template data<T, Context>();
auto* Ydata = output(0)->template mutable_data<T, Context>();
TENSOR_FILL(input(1), this->weight_shape);
auto* Wdata = input(1).template data<T, Context>();
if (InputSize() > 2) {
if (HasBias()) {
TENSOR_FILL(input(2), this->bias_shape);
INIT_MULTIPLIER(this->bias_multiplier, this->out_spatial_dim);
}
for (int n = 0; n < input(0).dim(0); n++) {
Wx(Xdata + n * this->x_offset, Wdata, Ydata + n * this->y_offset);
if (InputSize() > 2) {
if (HasBias()) {
auto* Bdata = input(2).template data<T, Context>();
Pb(Bdata, Ydata + n * this->y_offset);
}
......@@ -43,28 +32,28 @@ void ConvOp<Context>::RunWithType() {
}
template <class Context>
void ConvOp<Context>::RunOnDevice() {
void Conv2dOp<Context>::RunOnDevice() {
Reshape();
if (input(0).template IsType<float>()) RunWithType<float>();
else LOG(FATAL) << "Unsupported input types.";
}
DEPLOY_CPU(Conv);
DEPLOY_CPU(Conv2d);
#ifdef WITH_CUDA
DEPLOY_CUDA(Conv);
DEPLOY_CUDA(Conv2d);
#endif
OPERATOR_SCHEMA(Conv).NumInputs(2, 3).NumOutputs(1);
OPERATOR_SCHEMA(Conv2d).NumInputs(2, 3).NumOutputs(1);
template <class Context> template <typename T>
void ConvGradientOp<Context>::RunWithType() {
void Conv2dGradientOp<Context>::RunWithType() {
// get buffer
this->col_buffer = ws()->GetBuffer();
this->col_buffer->Reshape(this->col_buffer_shape);
this->col_buffer->Reshape(this->col_shape);
auto* dYdata = input(-1).template data<T, Context>();
if (output(2)->name() != "ignore") {
if (HasBias()) {
INIT_MULTIPLIER(this->bias_multiplier, this->out_spatial_dim);
T* dBdata = output(2)->template mutable_data<T, Context>();
for (int n = 0; n < input(2).dim(0); n++)
......@@ -89,28 +78,28 @@ void ConvGradientOp<Context>::RunWithType() {
}
template <class Context>
void ConvGradientOp<Context>::RunOnDevice() {
void Conv2dGradientOp<Context>::RunOnDevice() {
GradientReshape();
if (input(0).template IsType<float>()) RunWithType<float>();
else LOG(FATAL) << "Unsupported input types.";
}
DEPLOY_CPU(ConvGradient);
DEPLOY_CPU(Conv2dGradient);
#ifdef WITH_CUDA
DEPLOY_CUDA(ConvGradient);
DEPLOY_CUDA(Conv2dGradient);
#endif
OPERATOR_SCHEMA(ConvGradient).NumInputs(3).NumOutputs(3);
OPERATOR_SCHEMA(Conv2dGradient).NumInputs(3).NumOutputs(3);
class GetConvGradient final : public GradientMakerBase {
class GetConv2dGradient final : public GradientMakerBase {
public:
GRADIENT_MAKER_CTOR(GetConvGradient);
GRADIENT_MAKER_CTOR(GetConv2dGradient);
vector<OperatorDef> MakeDefs() override {
return SingleDef(def.type() + "Gradient", "",
vector<string> {I(0), I(1), GO(0)},
vector<string> {GI(0), GI(1), GI(2)});
}
};
REGISTER_GRADIENT(Conv, GetConvGradient);
REGISTER_GRADIENT(Conv2d, GetConv2dGradient);
} // namespace dragon
\ No newline at end of file
#include "operators/vision/deconv_op.h"
#include "operators/vision/conv_transpose_op.h"
#include "core/workspace.h"
#include "utils/filler.h"
namespace dragon {
template <class Context>
void DeConvOp<Context>::ComputeOutputShape() {
this->output_shape.clear();
for (int i = 0; i < this->num_spatial_axes; i++) {
const int input_dim = this->bottom_shape[this->channel_axis + i + 1];
const int dilated_kernel = this->dilation[i] * (this->kernel_size[i] - 1) + 1;
const int output_dim = this->stride[i] * (input_dim - 1) + dilated_kernel - 2 * this->pad[i];
this->output_shape.push_back(output_dim);
}
}
template <class Context> template <typename T>
void DeConvOp<Context>::RunWithType() {
void Conv2dTransposeOp<Context>::RunWithType() {
// get buffer
this->col_buffer = ws()->GetBuffer();
this->col_buffer->Reshape(this->col_buffer_shape);
this->col_buffer->Reshape(this->col_shape);
auto* Xdata = input(0).template data<T, Context>();
auto* Ydata = output(0)->template mutable_data<T, Context>();
......@@ -43,24 +32,27 @@ void DeConvOp<Context>::RunWithType() {
}
template <class Context>
void DeConvOp<Context>::RunOnDevice() {
void Conv2dTransposeOp<Context>::RunOnDevice() {
Reshape();
// fix the output shape for im2col/col2im
for (int i = 0; i < this->num_spatial_axes; i++)
this->output_shape[i] = input(0).dim(this->spatial_axis + i);
if (input(0).template IsType<float>()) RunWithType<float>();
else LOG(FATAL) << "Unsupported input types.";
}
DEPLOY_CPU(DeConv);
DEPLOY_CPU(Conv2dTranspose);
#ifdef WITH_CUDA
DEPLOY_CUDA(DeConv);
DEPLOY_CUDA(Conv2dTranspose);
#endif
OPERATOR_SCHEMA(DeConv).NumInputs(2, 3).NumOutputs(1);
OPERATOR_SCHEMA(Conv2dTranspose).NumInputs(2, 3).NumOutputs(1);
template <class Context> template <typename T>
void DeConvGradientOp<Context>::RunWithType() {
void Conv2dTransposeGradientOp<Context>::RunWithType() {
// get buffer
this->col_buffer = ws()->GetBuffer();
this->col_buffer->Reshape(this->col_buffer_shape);
this->col_buffer->Reshape(this->col_shape);
auto* dYdata = input(-1).template data<T, Context>();
......@@ -90,28 +82,31 @@ void DeConvGradientOp<Context>::RunWithType() {
}
template <class Context>
void DeConvGradientOp<Context>::RunOnDevice() {
void Conv2dTransposeGradientOp<Context>::RunOnDevice() {
GradientReshape();
// fix the output shape for im2col/col2im
for (int i = 0; i < this->num_spatial_axes; i++)
this->output_shape[i] = input(0).dim(this->spatial_axis + i);
if (input(0).template IsType<float>()) RunWithType<float>();
else LOG(FATAL) << "Unsupported input types.";
}
DEPLOY_CPU(DeConvGradient);
DEPLOY_CPU(Conv2dTransposeGradient);
#ifdef WITH_CUDA
DEPLOY_CUDA(DeConvGradient);
DEPLOY_CUDA(Conv2dTransposeGradient);
#endif
OPERATOR_SCHEMA(DeConvGradient).NumInputs(3).NumOutputs(3);
OPERATOR_SCHEMA(Conv2dTransposeGradient).NumInputs(3).NumOutputs(3);
class GetDeConvGradient final : public GradientMakerBase {
class GetConv2dTransposeGradient final : public GradientMakerBase {
public:
GRADIENT_MAKER_CTOR(GetDeConvGradient);
GRADIENT_MAKER_CTOR(GetConv2dTransposeGradient);
vector<OperatorDef> MakeDefs() override {
return SingleDef(def.type() + "Gradient", "",
vector<string> {I(0), I(1), GO(0)},
vector<string> {GI(0), GI(1), GI(2)});
}
};
REGISTER_GRADIENT(DeConv, GetDeConvGradient);
REGISTER_GRADIENT(Conv2dTranspose, GetConv2dTransposeGradient);
} // namespace dragon
\ No newline at end of file
#include "operators/vision/conv_op.h"
#include "operators/vision/conv_op_base.h"
#include "core/workspace.h"
#include "utils/filler.h"
namespace dragon {
template <class Context>
void ConvOpBase<Context>::ComputeOutputShape() {
output_shape.clear();
for (int i = 0; i < num_spatial_axes; i++) {
if (!ReverseDimensions()) {
const TIndex input_dim = bottom_shape[spatial_axis + i];
const TIndex dilated_kernel = dilation[i] * (kernel_size[i] - 1) + 1;
if (padding != "SAME") {
const TIndex output_dim = (input_dim + 2 * pad[i] - dilated_kernel) / stride[i] + 1;
output_shape.push_back(output_dim);
} else {
TIndex output_dim = (input_dim + stride[i] - 1) / (float)stride[i];
TIndex padding_needed = std::max(TIndex(0), (output_dim - 1) * stride[i] + dilated_kernel - input_dim);
TIndex pad_l = padding_needed / 2;
TIndex pad_r = padding_needed - pad_l;
pad[i] = pad_l;
output_shape.push_back(output_dim);
}
} else {
const TIndex input_dim = bottom_shape[spatial_axis + i];
const TIndex dilated_kernel = dilation[i] * (kernel_size[i] - 1) + 1;
if (padding != "SAME") {
const TIndex output_dim = stride[i] * (input_dim - 1) + dilated_kernel - 2 * pad[i];
output_shape.push_back(output_dim);
} else {
TIndex output_dim = -1;
if (dynamic_dsize.size() > 0) {
NOT_IMPLEMENTED;
} else if (static_dsize.size() > 0) {
if ((int)static_dsize.size() != num_spatial_axes + 2)
LOG(FATAL) << "The len of output shape should be " << num_spatial_axes + 2
<< ", but got " << static_dsize.size();
output_dim = static_dsize[spatial_axis + i];
} else LOG(FATAL) << "The output shape must be specified if using SAME padding algorithm.";
TIndex padding_needed = stride[i] * (input_dim - 1) + dilated_kernel - output_dim;
CHECK_GE(padding_needed, 0)
<< "\nThe output shape is incorrect."
<< "\nWith the given stride and kernel, dimension of axis " << spatial_axis + i
<< " can be at most " << stride[i] * (input_dim - 1) + dilated_kernel << ".";
TIndex pad_l = padding_needed / 2;
TIndex pad_r = padding_needed - pad_l;
pad[i] = pad_l;
output_shape.push_back(output_dim);
}
}
}
}
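A Python transcription of the forward (non-reversed) padding arithmetic above, kept as a sketch to show how ``SAME`` first fixes the output dim and then derives the left pad.

def conv_output_dim(input_dim, kernel, stride, pad, dilation, padding):
    dilated_kernel = dilation * (kernel - 1) + 1
    if padding != 'SAME':                         # explicit / VALID padding
        return (input_dim + 2 * pad - dilated_kernel) // stride + 1, pad
    output_dim = (input_dim + stride - 1) // stride
    padding_needed = max(0, (output_dim - 1) * stride + dilated_kernel - input_dim)
    return output_dim, padding_needed // 2        # pad_l; pad_r takes the remainder

assert conv_output_dim(224, 3, 2, 0, 1, 'SAME') == (112, 0)
assert conv_output_dim(224, 3, 2, 0, 1, 'VALID') == (111, 0)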
template <class Context> template <typename T>
void ConvOpBase<Context>::Wx(const T* x, const T* weights, T* y, bool skip_im2col) {
const T* col_buff_ = x;
......@@ -12,25 +60,45 @@ void ConvOpBase<Context>::Wx(const T* x, const T* weights, T* y, bool skip_im2co
col_buff_ = col_buffer->data<T, Context>();
}
for (int g = 0; g < group; g++) {
math::Gemm<T, Context>(CblasNoTrans, CblasNoTrans,
conv_out_channels / group,
conv_out_spatial_dim,
kernel_dim,
1.0, weights + weight_offset * g,
col_buff_ + col_offset * g,
0.0, y + output_offset * g);
if (data_format == "NCHW") {
math::Gemm<T, Context>(CblasNoTrans, CblasNoTrans,
conv_out_channels / group,
conv_out_spatial_dim,
kernel_dim,
1.0, weights + weight_offset * g,
col_buff_ + col_offset * g,
0.0, y + output_offset * g);
} else if (data_format == "NHWC") {
math::Gemm<T, Context>(CblasNoTrans, CblasNoTrans,
conv_out_spatial_dim,
conv_out_channels / group,
kernel_dim,
1.0, col_buff_ + col_offset * g,
weights + weight_offset * g,
0.0, y + output_offset * g);
}
}
}
template <class Context> template <typename T>
void ConvOpBase<Context>::Pb(const T* bias, T* y) {
math::Gemm<T, Context>(CblasNoTrans, CblasNoTrans,
num_output,
out_spatial_dim,
1,
1.0, bias,
bias_multiplier->template data<T, Context>(),
1.0, y);
if (data_format == "NCHW") {
math::Gemm<T, Context>(CblasNoTrans, CblasNoTrans,
num_output,
out_spatial_dim,
1,
1.0, bias,
bias_multiplier->template data<T, Context>(),
1.0, y);
} else if (data_format == "NHWC") {
math::Gemm<T, Context>(CblasNoTrans, CblasNoTrans,
out_spatial_dim,
num_output,
1,
1.0, bias_multiplier->template data<T, Context>(),
bias,
1.0, y);
}
}
template <class Context> template <typename T>
......@@ -38,13 +106,23 @@ void ConvOpBase<Context>::Dx(const T* dy, const T* weights, T* dx) {
T* col_buff_ = col_buffer->template mutable_data<T, Context>();
if (is_1x1) col_buff_ = dx;
for (int g = 0; g < group; g++) {
math::Gemm<T, Context>(CblasTrans, CblasNoTrans,
kernel_dim,
conv_out_spatial_dim,
conv_out_channels / group,
1.0, weights + weight_offset * g,
dy + output_offset * g,
0.0, col_buff_ + col_offset * g);
if (data_format == "NCHW") {
math::Gemm<T, Context>(CblasTrans, CblasNoTrans,
kernel_dim,
conv_out_spatial_dim,
conv_out_channels / group,
1.0, weights + weight_offset * g,
dy + output_offset * g,
0.0, col_buff_ + col_offset * g);
} else if (data_format == "NHWC") {
math::Gemm<T, Context>(CblasNoTrans, CblasTrans,
conv_out_spatial_dim,
kernel_dim,
conv_out_channels / group,
1.0, dy + output_offset * g,
weights + weight_offset * g,
0.0, col_buff_ + col_offset * g);
}
}
if (!is_1x1) Col2Im(col_buff_, dx);
}
......@@ -57,27 +135,71 @@ void ConvOpBase<Context>::Dw(const T* dy, const T* x, T *dw) {
col_buff_ = col_buffer->template data<T, Context>();
}
for (int g = 0; g < group; g++) {
math::Gemm<T, Context>(CblasNoTrans, CblasTrans,
conv_out_channels / group,
kernel_dim,
conv_out_spatial_dim,
1.0, dy + output_offset * g,
col_buff_ + col_offset * g,
1.0, dw + weight_offset * g);
if (data_format == "NCHW") {
math::Gemm<T, Context>(CblasNoTrans, CblasTrans,
conv_out_channels / group,
kernel_dim,
conv_out_spatial_dim,
1.0, dy + output_offset * g,
col_buff_ + col_offset * g,
1.0, dw + weight_offset * g);
} else if (data_format == "NHWC") {
math::Gemm<T, Context>(CblasTrans, CblasNoTrans,
kernel_dim,
conv_out_channels / group,
conv_out_spatial_dim,
1.0, col_buff_ + col_offset * g,
dy + output_offset * g,
1.0, dw + weight_offset * g);
}
}
}
template <class Context> template <typename T>
void ConvOpBase<Context>::Db(const T* dy, T* db) {
math::Gemv<T, Context>(CblasNoTrans, num_output, out_spatial_dim,
1.0, dy,
bias_multiplier->template data<T, Context>(),
1.0, db);
if (data_format == "NCHW") {
math::Gemv<T, Context>(CblasNoTrans, num_output, out_spatial_dim,
1.0, dy,
bias_multiplier->template data<T, Context>(),
1.0, db);
} else if (data_format == "NHWC") {
math::Gemv<T, Context>(CblasTrans, out_spatial_dim, num_output,
1.0, dy,
bias_multiplier->template data<T, Context>(),
1.0, db);
}
}
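A NumPy sketch of the layout-dependent GEMMs above (per group): NCHW computes W x col while NHWC computes col x W, so the output keeps its channels on the expected axis without an extra transpose.

import numpy as np

out_c, kernel_dim, out_spatial = 8, 27, 100
W = np.random.randn(out_c, kernel_dim).astype('float32')
col = np.random.randn(kernel_dim, out_spatial).astype('float32')

y_nchw = W.dot(col)                               # (out_c, out_spatial), NCHW layout
y_nhwc = col.T.dot(W.T)                           # (out_spatial, out_c), NHWC layout
assert np.allclose(y_nchw.T, y_nhwc, atol=1e-4)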
template <class Context>
void ConvOpBase<Context>::Setup() {
vector<int> ks = OperatorBase::GetRepeatedArg<int>("kernel_size");
for (int i = 0; i < num_spatial_axes; i++)
kernel_size.push_back(i < ks.size() ? ks[i] : ks[0]);
vector<int> s = OperatorBase::GetRepeatedArg<int>("stride");
for (int i = 0; i < num_spatial_axes; i++)
stride.push_back(i < s.size() ? s[i] : s[0]);
vector<int> p = OperatorBase::GetRepeatedArg<int>("pad");
for (int i = 0; i < num_spatial_axes; i++)
pad.push_back(i < p.size() ? p[i] : p[0]);
vector<int> d = OperatorBase::GetRepeatedArg<int>("dilation");
for (int i = 0; i < num_spatial_axes; i++)
dilation.push_back(i < d.size() ? d[i] : d[0]);
is_1x1 = true;
for (int i = 0; i < num_spatial_axes; i++) {
is_1x1 &= (kernel_size[i] == 1 &&
stride[i] == 1 &&
pad[i] == 0);
if (!is_1x1) break;
}
}
template <class Context>
void ConvOpBase<Context>::Reshape() {
channels = input(0).dim(channel_axis);
channels = data_format == "NCHW" ? input(0).dim(1) : input(0).dim(-1);
if (ReverseDimensions()) {
conv_out_channels = channels;
conv_in_channels = num_output;
......@@ -85,61 +207,91 @@ void ConvOpBase<Context>::Reshape() {
conv_out_channels = num_output;
conv_in_channels = channels;
}
weight_shape.assign({ conv_out_channels,
conv_in_channels / group,
kernel_size[0],
kernel_size[1]});
// determine the weight and bias shape
if (data_format == "NCHW") {
weight_shape.assign({ conv_out_channels,
conv_in_channels / group });
for (int i = 0; i < num_spatial_axes; i++)
weight_shape.push_back(kernel_size[i]);
} else if (data_format == "NHWC") {
weight_shape.clear();
for (int i = 0; i < num_spatial_axes; i++)
weight_shape.push_back(kernel_size[i]);
weight_shape.push_back(conv_in_channels / group);
weight_shape.push_back(conv_out_channels);
}
bias_shape.assign(1, num_output);
// compute bottom and top shape
// determine the bottom and top shape
bottom_shape = input(0).dims();
ComputeOutputShape();
vector<TIndex> top_shape({input(0).dim(0),
num_output,
output_shape[0],
output_shape[1]});
output(0)->Reshape(top_shape);
if (ReverseDimensions()) {
conv_out_spatial_dim = input(0).count(channel_axis + 1);
} else {
conv_out_spatial_dim = output(0)->count(channel_axis + 1);
if (data_format == "NCHW") {
top_shape.assign({ input(0).dim(0), num_output });
for (int i = 0; i < num_spatial_axes; i++)
top_shape.push_back(output_shape[i]);
} else if (data_format == "NHWC") {
top_shape.assign({ input(0).dim(0) });
for (int i = 0; i < num_spatial_axes; i++)
top_shape.push_back(output_shape[i]);
top_shape.push_back(num_output);
}
output(0)->Reshape(top_shape);
// compute input shape
// determine the input shape for im2col/col2im
input_shape.clear();
for (int i = 0; i < num_spatial_axes; i++) {
if (ReverseDimensions()) {
input_shape.push_back(output(0)->dim(channel_axis + i + 1));
input_shape.push_back(output(0)->dim(spatial_axis + i));
} else {
input_shape.push_back(input(0).dim(channel_axis + i + 1));
input_shape.push_back(input(0).dim(spatial_axis + i));
}
}
kernel_dim = conv_in_channels / group * kernel_size[0] * kernel_size[1];
out_spatial_dim = output(0)->count(channel_axis + 1);
// determine the out spatial dim
if (data_format == "NCHW") {
if (ReverseDimensions()) {
conv_out_spatial_dim = input(0).count(spatial_axis);
} else {
conv_out_spatial_dim = output(0)->count(spatial_axis);
}
out_spatial_dim = output(0)->count(spatial_axis);
} else if (data_format == "NHWC") {
if (ReverseDimensions()) {
conv_out_spatial_dim = input(0).count(spatial_axis, (int)input(0).ndim() - 1);
} else {
conv_out_spatial_dim = output(0)->count(spatial_axis, (int)output(0)->ndim() - 1);
}
out_spatial_dim = output(0)->count(spatial_axis, (int)output(0)->ndim() - 1);
}
x_offset = input(0).count(channel_axis);
y_offset = output(0)->count(channel_axis);
// determine the misc
x_offset = input(0).count(1);
y_offset = output(0)->count(1);
kernel_dim = conv_in_channels / group * kernel_size[0] * kernel_size[1];
weight_offset = conv_out_channels * kernel_dim / group;
col_offset = kernel_dim * conv_out_spatial_dim;
output_offset = conv_out_channels * conv_out_spatial_dim / group;
// compute col buffer shape
col_buffer_shape.clear();
col_buffer_shape.push_back(kernel_dim * group);
for (int i = 0; i < num_spatial_axes; i++) {
if (ReverseDimensions()) {
col_buffer_shape.push_back(bottom_shape[channel_axis + i + 1]);
} else {
col_buffer_shape.push_back(output_shape[i]);
// determine the col shape
col_shape.clear();
if (data_format == "NCHW") {
col_shape.push_back(kernel_dim * group);
for (int i = 0; i < num_spatial_axes; i++) {
if (ReverseDimensions()) col_shape.push_back(bottom_shape[spatial_axis + i]);
else col_shape.push_back(output_shape[i]);
}
} else if (data_format == "NHWC") {
for (int i = 0; i < num_spatial_axes; i++) {
if (ReverseDimensions()) col_shape.push_back(bottom_shape[spatial_axis + i]);
else col_shape.push_back(output_shape[i]);
}
col_shape.push_back(kernel_dim * group);
}
}
template <class Context>
void ConvOpBase<Context>::GradientReshape() {
channels = input(0).dim(channel_axis);
channels = data_format == "NCHW" ? input(0).dim(1) : input(0).dim(-1);
if (ReverseDimensions()) {
conv_out_channels = channels;
conv_in_channels = num_output;
......@@ -147,50 +299,66 @@ void ConvOpBase<Context>::GradientReshape() {
conv_out_channels = num_output;
conv_in_channels = channels;
}
// determine the bottom and top shape
bottom_shape = input(0).dims();
ComputeOutputShape();
output(0)->Reshape(bottom_shape);
output(1)->ReshapeLike(input(1));
output(2)->Reshape(vector<TIndex>(1, num_output));
if (ReverseDimensions()) {
conv_out_spatial_dim = input(0).count(channel_axis + 1);
} else {
conv_out_spatial_dim = input(2).count(channel_axis + 1);
}
// compute input shape
// determine the input shape for im2col/col2im
input_shape.clear();
for (int i = 0; i < num_spatial_axes; i++) {
if (ReverseDimensions()) {
input_shape.push_back(input(2).dim(channel_axis + i + 1));
input_shape.push_back(input(-1).dim(spatial_axis + i));
} else {
input_shape.push_back(input(0).dim(channel_axis + i + 1));
input_shape.push_back(input(0).dim(spatial_axis + i));
}
}
kernel_dim = input(1).count(1); // in * kh * kw
out_spatial_dim = input(2).count(channel_axis + 1);
// determine the out spatial dim
if (data_format == "NCHW") {
if (ReverseDimensions()) {
conv_out_spatial_dim = input(0).count(spatial_axis);
} else {
conv_out_spatial_dim = input(-1).count(spatial_axis);
}
out_spatial_dim = input(-1).count(spatial_axis);
} else if (data_format == "NHWC") {
if (ReverseDimensions()) {
conv_out_spatial_dim = input(0).count(spatial_axis, (int)input(0).ndim() - 1);
} else {
conv_out_spatial_dim = input(-1).count(spatial_axis, (int)input(-1).ndim() - 1);
}
out_spatial_dim = input(-1).count(spatial_axis, (int)input(-1).ndim() - 1);
}
x_offset = input(0).count(channel_axis);
y_offset = input(2).count(channel_axis);
// determine the misc
x_offset = input(0).count(1);
y_offset = input(-1).count(1);
kernel_dim = conv_in_channels / group * kernel_size[0] * kernel_size[1];
weight_offset = conv_out_channels * kernel_dim / group;
col_offset = kernel_dim * conv_out_spatial_dim;
output_offset = conv_out_channels * conv_out_spatial_dim / group;
// compute col buffer shape
col_buffer_shape.clear();
col_buffer_shape.push_back(kernel_dim * group);
for (int i = 0; i < num_spatial_axes; i++) {
if (ReverseDimensions()) {
col_buffer_shape.push_back(bottom_shape[channel_axis + i + 1]);
} else {
col_buffer_shape.push_back(output_shape[i]);
// determine the col shape
col_shape.clear();
if (data_format == "NCHW") {
col_shape.push_back(kernel_dim * group);
for (int i = 0; i < num_spatial_axes; i++) {
if (ReverseDimensions()) col_shape.push_back(bottom_shape[spatial_axis + i]);
else col_shape.push_back(output_shape[i]);
}
} else if (data_format == "NHWC") {
for (int i = 0; i < num_spatial_axes; i++) {
if (ReverseDimensions()) col_shape.push_back(bottom_shape[spatial_axis + i]);
else col_shape.push_back(output_shape[i]);
}
col_shape.push_back(kernel_dim * group);
}
}
template class ConvOpBase<CPUContext>;
template class ConvOpBase<CPUContext>;
template void ConvOpBase<CPUContext>::Wx(const float*, const float*, float*, bool);
template void ConvOpBase<CPUContext>::Pb(const float*, float*);
template void ConvOpBase<CPUContext>::Dx(const float*, const float*, float*);
......
......@@ -10,39 +10,58 @@ namespace dragon {
#define WORKSPACE_LIMIT_BYTES 64 * 1024 * 1024 // 64MB
template <class Context> template <typename T>
void CuDNNConvOp<Context>::RunWithType() {
void CuDNNConv2dOp<Context>::RunWithType() {
#if CUDNN_VERSION_MIN(5, 0, 0)
CUDNN_CHECK(cudnnSetFilter4dDescriptor(filter_desc,
CUDNNType<T>::type,
CUDNN_TENSOR_NCHW,
this->num_output / this->group,
CUDNN_CHECK(cudnnSetFilter4dDescriptor(filter_desc,
CUDNNType<T>::type,
format,
this->num_output / this->group,
this->channels / this->group,
this->kernel_size[0], this->kernel_size[1]));
#else
CUDNN_CHECK(cudnnSetFilter4dDescriptor_v4(filter_desc,
CUDNNType<T>::type,
CUDNN_TENSOR_NCHW,
this->num_output / this->group,
CUDNN_CHECK(cudnnSetFilter4dDescriptor_v4(filter_desc,
CUDNNType<T>::type,
format,
this->num_output / this->group,
this->channels / this->group,
this->kernel_size[0], this->kernel_size[1]));
#endif
cudnnSetTensorDesc<T>(&input_desc,
vector<TIndex>({ input(0).dim(0),
input(0).dim(1) / this->group,
input(0).dim(2),
input(0).dim(3) }),
vector<TIndex>({ input(0).count(1),
input(0).count(2),
input(0).count(3),
1 }));
cudnnSetTensorDesc<T>(&output_desc,
vector<TIndex>({ output(0)->dim(0),
output(0)->dim(1) / this->group,
output(0)->dim(2),
output(0)->dim(3) }),
vector<TIndex>({ output(0)->count(1),
output(0)->count(2),
output(0)->count(3), 1 }));
Tensor fake_tensor;
vector<TIndex> fake_dims;
if (this->data_format == "NCHW") {
// determine the input shape
fake_tensor.ReshapeLike(input(0));
fake_dims = fake_tensor.dims();
fake_dims[1] /= this->group;
cudnnSetTensor4dDesc<T>(&input_desc, this->data_format, fake_dims);
// determine the output shape
fake_tensor.ReshapeLike(*output(0));
fake_dims = fake_tensor.dims();
fake_dims[1] /= this->group;
cudnnSetTensor4dDesc<T>(&output_desc, this->data_format, fake_dims);
// determine the bias shape if necessary
if (HasBias()) {
bias_offset = this->num_output / this->group;
cudnnSetTensor4dDesc<T>(&bias_desc, this->data_format, vector<TIndex>({ 1, bias_offset, 1, 1 }));
}
} else if (this->data_format == "NHWC") {
// determine the input shape
fake_tensor.ReshapeLike(input(0));
fake_dims = fake_tensor.dims();
fake_dims[fake_dims.size() - 1] /= this->group;
cudnnSetTensor4dDesc<T>(&input_desc, this->data_format, fake_dims);
// determine the output shape
fake_tensor.ReshapeLike(*output(0));
fake_dims = fake_tensor.dims();
fake_dims[fake_dims.size() - 1] /= this->group;
cudnnSetTensor4dDesc<T>(&output_desc, this->data_format, fake_dims);
// determine the bias shape if necessary
if (HasBias()) {
bias_offset = this->num_output / this->group;
cudnnSetTensor4dDesc<T>(&bias_desc, this->data_format, vector<TIndex>({ 1, 1, 1, bias_offset }));
}
}
CUDNN_CHECK(cudnnGetConvolutionForwardAlgorithm(handle[0],
input_desc,
......@@ -64,45 +83,40 @@ void CuDNNConvOp<Context>::RunWithType() {
Tensor* buffer = ws()->GetBuffer();
if (workspace_fwd_data_size == 0) workspace_fwd_data_size += 1;
buffer->Reshape(vector<TIndex>(1, this->group * workspace_fwd_data_size));
if (InputSize() > 2) {
bias_offset = this->num_output / this->group;
cudnnSetTensorDesc<T>(&bias_desc, vector<TIndex>({ 1, bias_offset, 1, 1 }));
}
auto* Xdata = input(0).template data<T, Context>();
auto* Ydata = output(0)->template mutable_data<T, Context>();
TENSOR_FILL(input(1), this->weight_shape);
auto* Wdata = input(1).template data<T, Context>();
if (InputSize() > 2) TENSOR_FILL(input(2), this->bias_shape);
if (HasBias()) TENSOR_FILL(input(2), this->bias_shape);
for (int g = 0; g < this->group; g++) {
auto* workspace = buffer->template mutable_data<char, Context>();
CUDNN_CHECK(cudnnConvolutionForward(handle[g],
CUDNNType<T>::one, input_desc, Xdata + this->x_offset * g,
filter_desc, Wdata + this->weight_offset * g,
conv_desc,
fwd_algo,
workspace + g * workspace_fwd_data_size, workspace_fwd_data_size,
CUDNNType<T>::zero, output_desc, Ydata + this->y_offset * g));
if (InputSize() > 2) {
CUDNN_CHECK(cudnnConvolutionForward(handle[g],
CUDNNType<T>::one, input_desc, Xdata + this->x_offset * g,
filter_desc, Wdata + this->weight_offset * g,
conv_desc,
fwd_algo,
workspace + g * workspace_fwd_data_size, workspace_fwd_data_size,
CUDNNType<T>::zero, output_desc, Ydata + this->y_offset * g));
if (HasBias()) {
auto* bias = input(2).template data<T, Context>();
CUDNN_CHECK(cudnnAddTensor(handle[g],
CUDNNType<T>::one, bias_desc, bias + this->bias_offset * g,
CUDNNType<T>::one, output_desc, Ydata + this->y_offset * g));
CUDNN_CHECK(cudnnAddTensor(handle[g],
CUDNNType<T>::one, bias_desc, bias + this->bias_offset * g,
CUDNNType<T>::one, output_desc, Ydata + this->y_offset * g));
}
}
kernel::Empty<T, Context>();
ws()->ReleaseBuffer(buffer);
ws()->ReleaseBuffer(buffer);
}
template <class Context>
void CuDNNConvOp<Context>::RunOnDevice() {
void CuDNNConv2dOp<Context>::RunOnDevice() {
#if CUDNN_VERSION_MAX(6, 0, 0)
for (int i = 0; i < this->dilation.size(); i++)
if (this->dilation[i] != 1) return ConvOp<Context>::RunOnDevice();
if (this->dilation[i] != 1) return Conv2dOp<Context>::RunOnDevice();
#endif
ConvOp<Context>::Reshape();
Conv2dOp<Context>::Reshape();
this->x_offset /= this->group;
this->y_offset /= this->group;
......@@ -143,76 +157,94 @@ void CuDNNConvOp<Context>::RunOnDevice() {
} else { LOG(FATAL) << "Unsupported input types."; }
}
DEPLOY_CUDNN(Conv);
DEPLOY_CUDNN(Conv2d);
template <class Context> template <typename T>
void CuDNNConvGradientOp<Context>::RunWithType() {
void CuDNNConv2dGradientOp<Context>::RunWithType() {
#if CUDNN_VERSION_MIN(5, 0, 0)
CUDNN_CHECK(cudnnSetFilter4dDescriptor(filter_desc,
CUDNNType<T>::type,
CUDNN_TENSOR_NCHW,
this->num_output / this->group,
CUDNN_CHECK(cudnnSetFilter4dDescriptor(filter_desc,
CUDNNType<T>::type,
format,
this->num_output / this->group,
this->channels / this->group,
this->kernel_size[0], this->kernel_size[1]));
#else
CUDNN_CHECK(cudnnSetFilter4dDescriptor_v4(filter_desc,
CUDNNType<T>::type,
CUDNN_TENSOR_NCHW,
this->num_output / this->group,
CUDNN_CHECK(cudnnSetFilter4dDescriptor_v4(filter_desc,
CUDNNType<T>::type,
format,
this->num_output / this->group,
this->channels / this->group,
this->kernel_size[0], this->kernel_size[1]));
#endif
cudnnSetTensorDesc<T>(&input_desc,
vector<TIndex>({ input(-1).dim(0),
input(-1).dim(1) / this->group,
input(-1).dim(2),
input(-1).dim(3) }),
vector<TIndex>({ input(-1).count(1),
input(-1).count(2),
input(-1).count(3),
1 }));
cudnnSetTensorDesc<T>(&output_desc,
vector<TIndex>({ input(0).dim(0),
input(0).dim(1) / this->group,
input(0).dim(2),
input(0).dim(3) }),
vector<TIndex>({ input(0).count(1),
input(0).count(2),
input(0).count(3),
1 }));
Tensor fake_tensor;
vector<TIndex> fake_dims;
if (this->data_format == "NCHW") {
// determine the input shape
fake_tensor.ReshapeLike(input(-1));
fake_dims = fake_tensor.dims();
fake_dims[1] /= this->group;
cudnnSetTensor4dDesc<T>(&input_desc, this->data_format, fake_dims);
// determine the output shape
fake_tensor.ReshapeLike(input(0));
fake_dims = fake_tensor.dims();
fake_dims[1] /= this->group;
cudnnSetTensor4dDesc<T>(&output_desc, this->data_format, fake_dims);
// determine the bias shape if necessary
if (HasBias()) {
bias_offset = this->num_output / this->group;
cudnnSetTensor4dDesc<T>(&bias_desc, this->data_format, vector<TIndex>({ 1, bias_offset, 1, 1 }));
}
} else if (this->data_format == "NHWC") {
// determine the input shape
fake_tensor.ReshapeLike(input(-1));
fake_dims = fake_tensor.dims();
fake_dims[fake_dims.size() - 1] /= this->group;
cudnnSetTensor4dDesc<T>(&input_desc, this->data_format, fake_dims);
// determine the output shape
fake_tensor.ReshapeLike(input(0));
fake_dims = fake_tensor.dims();
fake_dims[fake_dims.size() - 1] /= this->group;
cudnnSetTensor4dDesc<T>(&output_desc, this->data_format, fake_dims);
// determine the bias shape if necessary
if (HasBias()) {
bias_offset = this->num_output / this->group;
cudnnSetTensor4dDesc<T>(&bias_desc, this->data_format, vector<TIndex>({ 1, 1, 1, bias_offset }));
}
}
CUDNN_CHECK(cudnnGetConvolutionBackwardFilterAlgorithm(handle[0],
output_desc,
input_desc,
conv_desc,
filter_desc,
CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT,
WORKSPACE_LIMIT_BYTES,
&bwd_filter_algo));
CUDNN_CHECK(cudnnGetConvolutionBackwardFilterWorkspaceSize(handle[0],
output_desc,
input_desc,
conv_desc,
filter_desc,
bwd_filter_algo,
&workspace_bwd_filter_size));
CUDNN_CHECK(cudnnGetConvolutionBackwardDataAlgorithm(handle[0],
filter_desc,
input_desc,
conv_desc,
output_desc,
CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT,
WORKSPACE_LIMIT_BYTES,
&bwd_data_algo));
CUDNN_CHECK(cudnnGetConvolutionBackwardDataWorkspaceSize(handle[0],
filter_desc,
input_desc,
conv_desc,
output_desc,
bwd_data_algo,
&workspace_bwd_data_size));
Tensor* buffer1 = ws()->GetBuffer();
......@@ -221,42 +253,38 @@ void CuDNNConvGradientOp<Context>::RunWithType() {
if (workspace_bwd_filter_size == 0) workspace_bwd_filter_size += 1;
buffer1->Reshape(vector<TIndex>(1, this->group * workspace_bwd_data_size));
buffer2->Reshape(vector<TIndex>(1, this->group * workspace_bwd_filter_size));
if (output(2)->name() != "ignore") {
bias_offset = this->num_output / this->group;
cudnnSetTensorDesc<T>(&bias_desc, vector<TIndex>({ 1, bias_offset, 1, 1 }));
}
const T* dYdata = input(2).template data<T, Context>();
for (int g = 0; g < this->group; g++) {
if (output(2)->name() != "ignore") {
T* dBdata = output(2)->template mutable_data<T, Context>();
CUDNN_CHECK(cudnnConvolutionBackwardBias(handle[g],
CUDNNType<T>::one, input_desc, dYdata + this->y_offset * g,
CUDNNType<T>::one, bias_desc, dBdata + bias_offset * g));
}
if (output(1)->name() != "ignore") {
auto* Xdata = input(0).template data<T, Context>();
auto* dWdata = output(1)->template mutable_data<T, Context>();
auto* workspace = buffer2->mutable_data<char, Context>();
CUDNN_CHECK(cudnnConvolutionBackwardFilter(handle[1 * this->group + g],
CUDNNType<T>::one, output_desc, Xdata + this->x_offset * g,
input_desc, dYdata + this->y_offset * g,
conv_desc,
bwd_filter_algo,
workspace + g * workspace_bwd_filter_size, workspace_bwd_filter_size,
CUDNNType<T>::one, filter_desc, dWdata + this->weight_offset * g));
}
if (output(0)->name() != "ignore") {
auto* Wdata = input(1).template data<T, Context>();
auto* dXdata = output(0)->template mutable_data<T, Context>();
auto* workspace = buffer1->mutable_data<char, Context>();
CUDNN_CHECK(cudnnConvolutionBackwardData(handle[2 * this->group + g],
CUDNNType<T>::one, filter_desc, Wdata + this->weight_offset * g,
input_desc, dYdata + this->y_offset * g,
conv_desc,
bwd_data_algo,
workspace + g * workspace_bwd_data_size, workspace_bwd_data_size,
CUDNNType<T>::zero, output_desc, dXdata + this->x_offset * g));
}
}
kernel::Empty<T, Context>();
......@@ -265,12 +293,12 @@ void CuDNNConvGradientOp<Context>::RunWithType() {
}
template <class Context>
void CuDNNConvGradientOp<Context>::RunOnDevice() {
void CuDNNConv2dGradientOp<Context>::RunOnDevice() {
#if CUDNN_VERSION_MAX(6, 0, 0)
for (int i = 0; i < this->dilation.size(); i++)
if (this->dilation[i] != 1) return ConvGradientOp<Context>::RunOnDevice();
if (this->dilation[i] != 1) return Conv2dGradientOp<Context>::RunOnDevice();
#endif
ConvGradientOp<Context>::GradientReshape();
Conv2dGradientOp<Context>::GradientReshape();
this->x_offset /= this->group;
this->y_offset /= this->group;
......@@ -310,7 +338,7 @@ void CuDNNConvGradientOp<Context>::RunOnDevice() {
} else { LOG(FATAL) << "Unsupported input types."; }
}
DEPLOY_CUDNN(ConvGradient);
DEPLOY_CUDNN(Conv2dGradient);
} // namespace dragon
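The grouped code paths in this file never materialize per-group tensors: each cuDNN descriptor simply gets its channel axis divided by `group`, and the raw pointers are then stepped by `x_offset`, `y_offset` and `weight_offset` per group. A minimal standalone sketch of that shape split (hypothetical helper, not taken from the Dragon headers):

// Sketch: per-group view of a 4-d shape, NCHW or NHWC.
// Only the channel axis shrinks; N, H, W stay untouched.
#include <string>
#include <vector>

std::vector<long> PerGroupDims(std::vector<long> dims,
                               const std::string& data_format,
                               long group) {
    const size_t axis = (data_format == "NCHW") ? 1 : dims.size() - 1;
    dims[axis] /= group;  // each handle[g] sees channels / group
    return dims;
}

// e.g. PerGroupDims({32, 64, 28, 28}, "NCHW", 2) -> {32, 32, 28, 28}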
......
#ifdef WITH_CUDNN
#include "operators/vision/deconv_op.h"
#include "operators/vision/conv_transpose_op.h"
#include "core/workspace.h"
#include "utils/filler.h"
#include "utils/op_kernel.h"
......@@ -10,40 +10,58 @@ namespace dragon {
#define WORKSPACE_LIMIT_BYTES 64 * 1024 * 1024 // 64MB
template <class Context> template <typename T>
void CuDNNDeConvOp<Context>::RunWithType() {
void CuDNNConv2dTransposeOp<Context>::RunWithType() {
#if CUDNN_VERSION_MIN(5, 0, 0)
CUDNN_CHECK(cudnnSetFilter4dDescriptor(filter_desc,
CUDNNType<T>::type,
CUDNN_TENSOR_NCHW,
this->channels / this->group,
CUDNN_CHECK(cudnnSetFilter4dDescriptor(filter_desc,
CUDNNType<T>::type,
format,
this->num_output / this->group,
this->channels / this->group,
this->kernel_size[0], this->kernel_size[1]));
#else
CUDNN_CHECK(cudnnSetFilter4dDescriptor_v4(filter_desc,
CUDNNType<T>::type,
CUDNN_TENSOR_NCHW,
this->channels / this->group,
CUDNN_CHECK(cudnnSetFilter4dDescriptor_v4(filter_desc,
CUDNNType<T>::type,
format,
this->num_output / this->group,
this->channels / this->group,
this->kernel_size[0], this->kernel_size[1]));
#endif
cudnnSetTensorDesc<T>(&input_desc,
vector<TIndex>({ input(0).dim(0),
input(0).dim(1) / this->group,
input(0).dim(2),
input(0).dim(3) }),
vector<TIndex>({ input(0).count(1),
input(0).count(2),
input(0).count(3),
1 }));
cudnnSetTensorDesc<T>(&output_desc,
vector<TIndex>({ output(0)->dim(0),
output(0)->dim(1) / this->group,
output(0)->dim(2),
output(0)->dim(3) }),
vector<TIndex>({ output(0)->count(1),
output(0)->count(2),
output(0)->count(3),
1 }));
Tensor fake_tensor;
vector<TIndex> fake_dims;
if (this->data_format == "NCHW") {
// determine the input shape
fake_tensor.ReshapeLike(input(0));
fake_dims = fake_tensor.dims();
fake_dims[1] /= this->group;
cudnnSetTensor4dDesc<T>(&input_desc, this->data_format, fake_dims);
// determine the output shape
fake_tensor.ReshapeLike(*output(0));
fake_dims = fake_tensor.dims();
fake_dims[1] /= this->group;
cudnnSetTensor4dDesc<T>(&output_desc, this->data_format, fake_dims);
// determine the bias shape if necessary
if (HasBias()) {
bias_offset = this->num_output / this->group;
cudnnSetTensor4dDesc<T>(&bias_desc, this->data_format, vector<TIndex>({ 1, bias_offset, 1, 1 }));
}
} else if (this->data_format == "NHWC") {
// determine the input shape
fake_tensor.ReshapeLike(input(0));
fake_dims = fake_tensor.dims();
fake_dims[fake_dims.size() - 1] /= this->group;
cudnnSetTensor4dDesc<T>(&input_desc, this->data_format, fake_dims);
// determine the output shape
fake_tensor.ReshapeLike(*output(0));
fake_dims = fake_tensor.dims();
fake_dims[fake_dims.size() - 1] /= this->group;
cudnnSetTensor4dDesc<T>(&output_desc, this->data_format, fake_dims);
// determine the bias shape if necessary
if (HasBias()) {
bias_offset = this->num_output / this->group;
cudnnSetTensor4dDesc<T>(&bias_desc, this->data_format, vector<TIndex>({ 1, 1, 1, bias_offset }));
}
}
CUDNN_CHECK(cudnnGetConvolutionBackwardDataAlgorithm(handle[0],
filter_desc,
......@@ -65,32 +83,28 @@ void CuDNNDeConvOp<Context>::RunWithType() {
Tensor* buffer = ws()->GetBuffer();
if (workspace_fwd_data_size == 0) workspace_fwd_data_size += 1;
buffer->Reshape(vector<TIndex>(1, this->group * workspace_fwd_data_size));
if (InputSize() > 2) {
bias_offset = this->num_output / this->group;
cudnnSetTensorDesc<T>(&bias_desc, vector<TIndex>({ 1, bias_offset, 1, 1 }));
}
auto* Xdata = input(0).template data<T, Context>();
auto* Ydata = output(0)->template mutable_data<T, Context>();
TENSOR_FILL(input(1), this->weight_shape);
auto* Wdata = input(1).template data<T, Context>();
if (InputSize() > 2) TENSOR_FILL(input(2), this->bias_shape);
if (HasBias()) TENSOR_FILL(input(2), this->bias_shape);
for (int g = 0; g < this->group; g++) {
auto* workspace = buffer->template mutable_data<char, Context>();
CUDNN_CHECK(cudnnConvolutionBackwardData(handle[g],
CUDNNType<T>::one, filter_desc, Wdata + this->weight_offset * g,
input_desc, Xdata + this->x_offset * g,
conv_desc,
fwd_algo,
workspace + g * workspace_fwd_data_size, workspace_fwd_data_size,
CUDNNType<T>::zero, output_desc, Ydata + this->y_offset * g));
if (InputSize() > 2) {
if (HasBias()) {
auto* bias = input(2).template data<T, Context>();
CUDNN_CHECK(cudnnAddTensor(handle[g],
CUDNNType<T>::one, bias_desc, bias + this->bias_offset * g,
CUDNNType<T>::one, output_desc, Ydata + this->y_offset * g));
}
}
kernel::Empty<T, Context>();
......@@ -98,28 +112,28 @@ void CuDNNDeConvOp<Context>::RunWithType() {
}
template <class Context>
void CuDNNDeConvOp<Context>::RunOnDevice() {
void CuDNNConv2dTransposeOp<Context>::RunOnDevice() {
#if CUDNN_VERSION_MAX(6, 0, 0)
for (int i = 0; i < this->dilation.size(); i++)
if (this->dilation[i] != 1) return DeConvOp<Context>::RunOnDevice();
if (this->dilation[i] != 1) return Conv2dTransposeOp<Context>::RunOnDevice();
#endif
DeConvOp<Context>::Reshape();
Conv2dTransposeOp<Context>::Reshape();
this->x_offset /= this->group;
this->y_offset /= this->group;
if (input(0).template IsType<float>()) {
#if CUDNN_VERSION_MIN(6, 0, 0)
CUDNN_CHECK(cudnnSetConvolution2dDescriptor(conv_desc,
this->pad[0], this->pad[1],
this->stride[0], this->stride[1],
this->dilation[0], this->dilation[1],
CUDNN_CROSS_CORRELATION,
CUDNN_DATA_FLOAT));
#else
CUDNN_CHECK(cudnnSetConvolution2dDescriptor(conv_desc,
this->pad[0], this->pad[1],
this->stride[0], this->stride[1],
1, 1,
CUDNN_CROSS_CORRELATION));
#endif
RunWithType<float>();
......@@ -127,16 +141,16 @@ void CuDNNDeConvOp<Context>::RunOnDevice() {
#ifdef WITH_CUDA_FP16
#if CUDNN_VERSION_MIN(6, 0, 0)
CUDNN_CHECK(cudnnSetConvolution2dDescriptor(conv_desc,
this->pad[0], this->pad[1],
this->stride[0], this->stride[1],
this->dilation[0], this->dilation[1],
CUDNN_CROSS_CORRELATION,
CUDNN_DATA_FLOAT));
#else
CUDNN_CHECK(cudnnSetConvolution2dDescriptor(conv_desc,
this->pad[0], this->pad[1],
this->stride[0], this->stride[1],
1, 1,
CUDNN_CROSS_CORRELATION));
#endif
RunWithType<float16>();
......@@ -144,76 +158,94 @@ void CuDNNDeConvOp<Context>::RunOnDevice() {
} else { LOG(FATAL) << "Unsupported input types."; }
}
DEPLOY_CUDNN(DeConv);
DEPLOY_CUDNN(Conv2dTranspose);
template <class Context> template <typename T>
void CuDNNDeConvGradientOp<Context>::RunWithType() {
void CuDNNConv2dTransposeGradientOp<Context>::RunWithType() {
#if CUDNN_VERSION_MIN(5, 0, 0)
CUDNN_CHECK(cudnnSetFilter4dDescriptor(filter_desc,
CUDNNType<T>::type,
CUDNN_TENSOR_NCHW,
this->channels / this->group,
CUDNN_CHECK(cudnnSetFilter4dDescriptor(filter_desc,
CUDNNType<T>::type,
format,
this->num_output / this->group,
this->channels / this->group,
this->kernel_size[0], this->kernel_size[1]));
#else
CUDNN_CHECK(cudnnSetFilter4dDescriptor_v4(filter_desc,
CUDNNType<T>::type,
CUDNN_TENSOR_NCHW,
this->channels / this->group,
CUDNN_CHECK(cudnnSetFilter4dDescriptor_v4(filter_desc,
CUDNNType<T>::type,
format,
this->num_output / this->group,
this->channels / this->group,
this->kernel_size[0], this->kernel_size[1]));
#endif
cudnnSetTensorDesc<T>(&input_desc,
vector<TIndex>({ input(-1).dim(0),
input(-1).dim(1) / this->group,
input(-1).dim(2),
input(-1).dim(3) }),
vector<TIndex>({ input(-1).count(1),
input(-1).count(2),
input(-1).count(3),
1 }));
cudnnSetTensorDesc<T>(&output_desc,
vector<TIndex>({ input(0).dim(0),
input(0).dim(1) / this->group,
input(0).dim(2),
input(0).dim(3) }),
vector<TIndex>({ input(0).count(1),
input(0).count(2),
input(0).count(3),
1 }));
Tensor fake_tensor;
vector<TIndex> fake_dims;
if (this->data_format == "NCHW") {
// determine the input shape
fake_tensor.ReshapeLike(input(-1));
fake_dims = fake_tensor.dims();
fake_dims[1] /= this->group;
cudnnSetTensor4dDesc<T>(&input_desc, this->data_format, fake_dims);
// determine the output shape
fake_tensor.ReshapeLike(input(0));
fake_dims = fake_tensor.dims();
fake_dims[1] /= this->group;
cudnnSetTensor4dDesc<T>(&output_desc, this->data_format, fake_dims);
// determine the bias shape if necessary
if (HasBias()) {
bias_offset = this->num_output / this->group;
cudnnSetTensor4dDesc<T>(&bias_desc, this->data_format, vector<TIndex>({ 1, bias_offset, 1, 1 }));
}
} else if (this->data_format == "NHWC") {
// determine the input shape
fake_tensor.ReshapeLike(input(-1));
fake_dims = fake_tensor.dims();
fake_dims[fake_dims.size() - 1] /= this->group;
cudnnSetTensor4dDesc<T>(&input_desc, this->data_format, fake_dims);
// determine the output shape
fake_tensor.ReshapeLike(input(0));
fake_dims = fake_tensor.dims();
fake_dims[fake_dims.size() - 1] /= this->group;
cudnnSetTensor4dDesc<T>(&output_desc, this->data_format, fake_dims);
// determine the bias shape if necessary
if (HasBias()) {
bias_offset = this->num_output / this->group;
cudnnSetTensor4dDesc<T>(&bias_desc, this->data_format, vector<TIndex>({ 1, 1, 1, bias_offset }));
}
}
CUDNN_CHECK(cudnnGetConvolutionBackwardFilterAlgorithm(handle[0],
input_desc,
output_desc,
conv_desc,
filter_desc,
CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT,
WORKSPACE_LIMIT_BYTES,
&bwd_filter_algo));
CUDNN_CHECK(cudnnGetConvolutionBackwardFilterWorkspaceSize(handle[0],
input_desc,
output_desc,
conv_desc,
filter_desc,
bwd_filter_algo,
&workspace_bwd_filter_size));
CUDNN_CHECK(cudnnGetConvolutionForwardAlgorithm(handle[0],
input_desc,
filter_desc,
conv_desc,
output_desc,
CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT,
WORKSPACE_LIMIT_BYTES,
&bwd_data_algo));
CUDNN_CHECK(cudnnGetConvolutionForwardWorkspaceSize(handle[0],
input_desc,
filter_desc,
conv_desc,
output_desc,
bwd_data_algo,
&workspace_bwd_data_size));
Tensor* buffer1 = ws()->GetBuffer();
......@@ -222,42 +254,38 @@ void CuDNNDeConvGradientOp<Context>::RunWithType() {
if (workspace_bwd_filter_size == 0) workspace_bwd_filter_size += 1;
buffer1->Reshape(vector<TIndex>(1, this->group * workspace_bwd_data_size));
buffer2->Reshape(vector<TIndex>(1, this->group * workspace_bwd_filter_size));
if (output(2)->name() != "ignore") {
bias_offset = this->num_output / this->group;
cudnnSetTensorDesc<T>(&bias_desc, vector<TIndex>({ 1, bias_offset, 1, 1 }));
}
const T* dYdata = input(2).template data<T, Context>();
for (int g = 0; g < this->group; g++) {
if (output(2)->name() != "ignore") {
T* dBdata = output(2)->template mutable_data<T, Context>();
CUDNN_CHECK(cudnnConvolutionBackwardBias(handle[g],
CUDNNType<T>::one, input_desc, dYdata + this->y_offset * g,
CUDNNType<T>::one, bias_desc, dBdata + bias_offset * g));
}
if (output(1)->name() != "ignore") {
auto* Xdata = input(0).template data<T, Context>();
auto* dWdata = output(1)->template mutable_data<T, Context>();
auto* workspace = buffer2->mutable_data<char, Context>();
CUDNN_CHECK(cudnnConvolutionBackwardFilter(handle[1 * this->group + g],
CUDNNType<T>::one, input_desc, dYdata + this->y_offset * g,
output_desc, Xdata + this->x_offset * g,
conv_desc,
bwd_filter_algo,
workspace + g * workspace_bwd_filter_size, workspace_bwd_filter_size,
CUDNNType<T>::one, filter_desc, dWdata + this->weight_offset * g));
}
if (output(0)->name() != "ignore") {
auto* Wdata = input(1).template data<T, Context>();
auto* dXdata = output(0)->template mutable_data<T, Context>();
auto* workspace = buffer1->mutable_data<char, Context>();
CUDNN_CHECK(cudnnConvolutionForward(handle[2 * this->group + g],
CUDNNType<T>::one, input_desc, dYdata + this->y_offset * g,
filter_desc, Wdata + this->weight_offset * g,
conv_desc,
bwd_data_algo,
workspace + g * workspace_bwd_data_size, workspace_bwd_data_size,
CUDNNType<T>::zero, output_desc, dXdata + this->x_offset * g));
}
}
kernel::Empty<T, Context>();
......@@ -266,26 +294,26 @@ void CuDNNDeConvGradientOp<Context>::RunWithType() {
}
template <class Context>
void CuDNNDeConvGradientOp<Context>::RunOnDevice() {
void CuDNNConv2dTransposeGradientOp<Context>::RunOnDevice() {
#if CUDNN_VERSION_MAX(6, 0, 0)
for (int i = 0; i < this->dilation.size(); i++)
if (this->dilation[i] != 1) return DeConvGradientOp<Context>::RunOnDevice();
if (this->dilation[i] != 1) return Conv2dTransposeGradientOp<Context>::RunOnDevice();
#endif
DeConvGradientOp<Context>::GradientReshape();
Conv2dTransposeGradientOp<Context>::GradientReshape();
this->x_offset /= this->group;
this->y_offset /= this->group;
if (input(0).template IsType<float>()) {
#if CUDNN_VERSION_MIN(6, 0, 0)
CUDNN_CHECK(cudnnSetConvolution2dDescriptor(conv_desc,
this->pad[0], this->pad[1],
this->stride[0], this->stride[1],
this->dilation[0], this->dilation[1],
CUDNN_CROSS_CORRELATION,
CUDNN_DATA_FLOAT));
#else
CUDNN_CHECK(cudnnSetConvolution2dDescriptor(conv_desc,
this->pad[0], this->pad[1],
this->stride[0], this->stride[1],
1, 1,
CUDNN_CROSS_CORRELATION));
......@@ -295,16 +323,16 @@ void CuDNNDeConvGradientOp<Context>::RunOnDevice() {
#ifdef WITH_CUDA_FP16
#if CUDNN_VERSION_MIN(6, 0, 0)
CUDNN_CHECK(cudnnSetConvolution2dDescriptor(conv_desc,
this->pad[0], this->pad[1],
this->stride[0], this->stride[1],
this->dilation[0], this->dilation[1],
CUDNN_CROSS_CORRELATION,
CUDNN_DATA_FLOAT));
#else
CUDNN_CHECK(cudnnSetConvolution2dDescriptor(conv_desc,
this->pad[0], this->pad[1],
this->stride[0], this->stride[1],
1, 1,
CUDNN_CROSS_CORRELATION));
#endif
RunWithType<float16>();
......@@ -312,7 +340,7 @@ void CuDNNDeConvGradientOp<Context>::RunOnDevice() {
} else { LOG(FATAL) << "Unsupported input types."; }
}
DEPLOY_CUDNN(DeConvGradient);
DEPLOY_CUDNN(Conv2dTransposeGradient);
} // namespace dragon
......
......@@ -5,38 +5,36 @@
namespace dragon {
template <class Context> template <typename T>
void CuDNNPoolingOp<Context>::RunWithType() {
cudnnSetTensorDesc<T>(&input_desc, &input(0));
cudnnSetTensorDesc<T>(&output_desc, output(0));
if (this->global_pooling) {
void CuDNNPooling2dOp<Context>::RunWithType() {
cudnnSetTensor4dDesc<T>(&input_desc, this->data_format, &input(0));
cudnnSetTensor4dDesc<T>(&output_desc, this->data_format, output(0));
#if CUDNN_VERSION_MIN(5, 0, 0)
CUDNN_CHECK(cudnnSetPooling2dDescriptor(pool_desc,
pool_mode,
CUDNN_PROPAGATE_NAN,
input(0).dim(2), input(0).dim(3),
0, 0,
1, 1));
CUDNN_CHECK(cudnnSetPooling2dDescriptor(pool_desc,
pool_mode,
CUDNN_PROPAGATE_NAN,
this->kernel_size[0], this->kernel_size[1],
this->pad[0], this->pad[1],
this->stride[0], this->stride[1]));
#else
CUDNN_CHECK(cudnnSetPooling2dDescriptor_v4(pool_desc,
pool_mode,
CUDNN_PROPAGATE_NAN,
input(0).dim(2), input(0).dim(3),
0, 0,
1, 1));
CUDNN_CHECK(cudnnSetPooling2dDescriptor_v4(pool_desc,
pool_mode,
CUDNN_PROPAGATE_NAN,
this->kernel_size[0], this->kernel_size[1],
this->pad[0], this->pad[1],
this->stride[0], this->stride[1]));
#endif
}
auto* Xdata = input(0).template data<T, Context>();
auto* Ydata = output(0)->template mutable_data<T, Context>();
CUDNN_CHECK(cudnnPoolingForward(cudnn_handle(),
pool_desc,
CUDNNType<T>::one, input_desc, Xdata,
CUDNNType<T>::zero, output_desc, Ydata));
}
template <class Context>
void CuDNNPoolingOp<Context>::RunOnDevice() {
PoolingOp<Context>::Reshape();
void CuDNNPooling2dOp<Context>::RunOnDevice() {
Pooling2dOp<Context>::Reshape();
if (input(0).template IsType<float>()) return RunWithType<float>();
#ifdef WITH_CUDA_FP16
......@@ -45,29 +43,27 @@ void CuDNNPoolingOp<Context>::RunOnDevice() {
else LOG(FATAL) << "Unsupported input types.";
}
DEPLOY_CUDNN(Pooling);
DEPLOY_CUDNN(Pooling2d);
template <class Context> template <typename T>
void CuDNNPoolingGradientOp<Context>::RunWithType() {
cudnnSetTensorDesc<T>(&input_desc, &input(-1));
cudnnSetTensorDesc<T>(&output_desc, output(0));
if (this->global_pooling) {
void CuDNNPooling2dGradientOp<Context>::RunWithType() {
cudnnSetTensor4dDesc<T>(&input_desc, this->data_format, &input(-1));
cudnnSetTensor4dDesc<T>(&output_desc, this->data_format, output(0));
#if CUDNN_VERSION_MIN(5, 0, 0)
CUDNN_CHECK(cudnnSetPooling2dDescriptor(pool_desc,
pool_mode,
CUDNN_PROPAGATE_NAN,
input(0).dim(2), input(0).dim(3),
0, 0,
1, 1));
CUDNN_CHECK(cudnnSetPooling2dDescriptor(pool_desc,
pool_mode,
CUDNN_PROPAGATE_NAN,
this->kernel_size[0], this->kernel_size[1],
this->pad[0], this->pad[1],
this->stride[0], this->stride[1]));
#else
CUDNN_CHECK(cudnnSetPooling2dDescriptor_v4(pool_desc,
pool_mode,
CUDNN_PROPAGATE_NAN,
input(0).dim(2), input(0).dim(3),
0, 0,
1, 1));
CUDNN_CHECK(cudnnSetPooling2dDescriptor_v4(pool_desc,
pool_mode,
CUDNN_PROPAGATE_NAN,
this->kernel_size[0], this->kernel_size[1],
this->pad[0], this->pad[1],
this->stride[0], this->stride[1]));
#endif
}
auto* dYdata = input(-1).template data<T, Context>();
auto* Xdata = input(0).template data<T, Context>();
auto* Ydata = input(1).template data<T, Context>();
......@@ -82,8 +78,8 @@ void CuDNNPoolingGradientOp<Context>::RunWithType() {
}
template <class Context>
void CuDNNPoolingGradientOp<Context>::RunOnDevice() {
PoolingGradientOp<Context>::Reshape();
void CuDNNPooling2dGradientOp<Context>::RunOnDevice() {
Pooling2dGradientOp<Context>::Reshape();
if (input(0).template IsType<float>()) return RunWithType<float>();
#ifdef WITH_CUDA_FP16
......@@ -92,7 +88,7 @@ void CuDNNPoolingGradientOp<Context>::RunOnDevice() {
else LOG(FATAL) << "Unsupported input types.";
}
DEPLOY_CUDNN(PoolingGradient);
DEPLOY_CUDNN(Pooling2dGradient);
} // namespace dragon
......
......@@ -49,7 +49,7 @@ void LRNOp<Context>::PoolRunWithType() {
ks.set_name("kernel_size"); ks.add_ints(local_size);
s.set_name("stride"); s.add_ints(1);
p.set_name("pad"); p.add_ints((local_size - 1) / 2);
mode.set_name("mode"); mode.set_i(AVG_POOLING);
mode.set_name("mode"); mode.set_s("AVG");
OperatorDef pool_op_def = MakeOperatorDef("Pooling", "",
vector<string>({ sqr_out->name() }),
vector<string>({ pool_out->name() }),
......@@ -177,7 +177,7 @@ void LRNGradientOp<Context>::PoolRunWithType() {
ks.set_name("kernel_size"); ks.add_ints(local_size);
s.set_name("stride"); s.add_ints(1);
p.set_name("pad"); p.add_ints((local_size - 1) / 2);
mode.set_name("mode"); mode.set_i(AVG_POOLING);
mode.set_name("mode"); mode.set_s("AVG");
OperatorDef pool_op_def = MakeOperatorDef("PoolingGradient", "",
vector<string>({ sqr_out->name(),
pool_out->name(),
......
......@@ -7,27 +7,42 @@ namespace dragon {
template <class Context> template <typename T>
void NNResizeOp<Context>::RunWithType() {
if (data_format == "NCHW") {
n = input(0).dim(0);
c = input(0).dim(1);
h = input(0).dim(2);
w = input(0).dim(3);
out_h = output(0)->dim(2);
out_w = output(0)->dim(3);
} else if (data_format == "NHWC") {
n = input(0).dim(0);
h = input(0).dim(1);
w = input(0).dim(2);
c = input(0).dim(3);
out_h = output(0)->dim(1);
out_w = output(0)->dim(2);
}
auto* Xdata = input(0).template data<T, Context>();
auto* Ydata = output(0)->template mutable_data<T, Context>();
kernel::NNResize<T, Context>(output(0)->count(), dims[0], dims[1],
input(0).dim(2), input(0).dim(3),
dims[2], dims[3],
Xdata,
Ydata);
kernel::NNResize<T, Context>(output(0)->count(), n, c, h, w,
out_h, out_w,
data_format,
Xdata,
Ydata);
}
template <class Context>
void NNResizeOp<Context>::RunOnDevice() {
dims = input(0).dims();
vector<TIndex> dims = input(0).dims();
if (dynamic_dsize.size() > 0) {
CHECK_EQ(dynamic_dsize.size(), 2)
<< "\nThe dsize should be a scalar with 2 elements.";
for (int i = 0; i < 2; i++) {
Tensor* t = ws()->GetTensor(dynamic_dsize[i]);
if (t->IsType<int>()) {
dims[2 + i] = t->template data<int, CPUContext>()[0];
dims[spatial_axis + i] = t->template data<int, CPUContext>()[0];
} else if (t->IsType<float>()) {
dims[2 + i] = t->template data<float, CPUContext>()[0];
dims[spatial_axis + i] = t->template data<float, CPUContext>()[0];
} else {
LOG(FATAL) << "Unsupported types of dsize.";
}
......@@ -35,15 +50,15 @@ void NNResizeOp<Context>::RunOnDevice() {
} else if (static_dsize.size() > 0) {
CHECK_EQ(static_dsize.size(), 2)
<< "\nThe dsize should be a scalar with 2 elements.";
for (int i = 0; i < 2; i++) dims[2 + i] = static_dsize[i];
for (int i = 0; i < 2; i++) dims[spatial_axis + i] = static_dsize[i];
} else {
CHECK(fy != -1.0 && fx != -1.0)
<< "\nThe fx and fy should be set.";
dims[2] = int(dims[2] * fy);
dims[3] = int(dims[3] * fx);
dims[spatial_axis] = int(dims[spatial_axis] * fy);
dims[spatial_axis + 1] = int(dims[spatial_axis + 1] * fx);
}
output(0)->Reshape(dims);
if (input(0).template IsType<float>()) return RunWithType<float>();
else LOG(FATAL) << "Unsupported input types.";
}
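The block above resolves the output size in a fixed order: dynamic dsize tensors first, then the static dsize argument, and finally the fy/fx scale factors, always writing dims[spatial_axis] and dims[spatial_axis + 1]. A condensed sketch with plain integers standing in for the workspace tensors (hypothetical helper, not the operator itself):

// Sketch of the resolution order used by NNResizeOp::RunOnDevice.
// dsize is empty when unset; fy/fx are treated as unset when not positive.
#include <stdexcept>
#include <vector>

std::vector<long> ResolveResizeDims(std::vector<long> dims, int spatial_axis,
                                    const std::vector<long>& dsize,
                                    float fy, float fx) {
    if (dsize.size() == 2) {                       // dynamic or static dsize
        dims[spatial_axis]     = dsize[0];
        dims[spatial_axis + 1] = dsize[1];
    } else if (fy > 0.f && fx > 0.f) {             // scale factors
        dims[spatial_axis]     = long(dims[spatial_axis] * fy);
        dims[spatial_axis + 1] = long(dims[spatial_axis + 1] * fx);
    } else {
        throw std::runtime_error("either dsize or (fy, fx) must be given");
    }
    return dims;   // spatial_axis is 2 for NCHW, 1 for NHWC
}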
......@@ -56,14 +71,28 @@ OPERATOR_SCHEMA(NNResize).NumInputs(1).NumOutputs(1);
template <class Context> template <typename T>
void NNResizeGradientOp<Context>::RunWithType() {
if (data_format == "NCHW") {
n = input(0).dim(0);
c = input(0).dim(1);
h = input(0).dim(2);
w = input(0).dim(3);
out_h = input(-1).dim(2);
out_w = input(-1).dim(3);
} else if (data_format == "NHWC") {
n = input(0).dim(0);
h = input(0).dim(1);
w = input(0).dim(2);
c = input(0).dim(3);
out_h = input(-1).dim(1);
out_w = input(-1).dim(2);
}
auto* dYdata = input(-1).template data<T, Context>();
auto* dXdata = output(0)->template mutable_data<T, Context>();
math::Set<T, Context>(output(0)->count(), 0, dXdata);
kernel::NNResizeGrad<T, Context>(input(-1).count(), input(0).dim(0), input(0).dim(1),
input(-1).dim(2), input(-1).dim(3),
output(0)->dim(2), output(0)->dim(3),
dYdata,
dXdata);
kernel::NNResizeGrad<T, Context>(input(-1).count(), n, c, h, w,
out_h, out_w,
data_format,
dYdata,
dXdata);
}
template <class Context>
......
#include "operators/vision/pooling_op.h"
#include "core/workspace.h"
#include "utils/math_functions.h"
#include "utils/op_kernel.h"
namespace dragon {
template <class Context> template <typename T>
void Pooling2dOp<Context>::MAXRunWithType() {
mask = ws()->CreateTensor("_t_" + anchor() + "_pool_mask");
mask->ReshapeLike(*output(0));
auto* Xdata = input(0).template data<T, Context>();
auto* Ydata = output(0)->template mutable_data<T, Context>();
auto* Mdata = mask->template mutable_data<int, Context>();
kernel::MAXPooling2d<T, Context>(output(0)->count(),
n, c, h, w,
pool_h, pool_w,
kernel_size[0], kernel_size[1],
stride[0], stride[1],
pad[0], pad[1],
data_format,
Xdata,
Mdata,
Ydata);
}
template <class Context> template <typename T>
void Pooling2dOp<Context>::AVGRunWithType() {
auto* Xdata = input(0).template data<T, Context>();
auto* Ydata = output(0)->template mutable_data<T, Context>();
kernel::AVGPooling2d<T, Context>(output(0)->count(),
n, c, h, w,
pool_h, pool_w,
kernel_size[0], kernel_size[1],
stride[0], stride[1],
pad[0], pad[1],
data_format,
Xdata,
Ydata);
}
template <class Context>
void Pooling2dOp<Context>::Reshape() {
if (data_format == "NCHW") {
n = input(0).dim(0);
c = input(0).dim(1);
h = input(0).dim(2);
w = input(0).dim(3);
if (global_pooling) {
for (int i = 0; i < 2; i++)
kernel_size[i] = input(0).dim(i + 2);
}
if (padding == "SAME") {
for (int i = 0; i < 2; i++) {
TIndex input_size = input(0).dim(i + 2);
TIndex output_size = (input_size + stride[i] - 1) / (float)stride[i];
TIndex padding_needed = std::max(TIndex(0), (output_size - 1) * stride[i] + kernel_size[i] - input_size);
TIndex pad_l = padding_needed / 2;
TIndex pad_r = padding_needed - pad_l;
pad[i] = pad_l;
}
}
} else if (data_format == "NHWC") {
n = input(0).dim(0);
h = input(0).dim(1);
w = input(0).dim(2);
c = input(0).dim(3);
if (global_pooling) {
for (int i = 0; i < 2; i++)
kernel_size[i] = input(0).dim(i + 1);
}
if (padding == "SAME") {
for (int i = 0; i < 2; i++) {
TIndex input_size = input(0).dim(i + 1);
TIndex output_size = (input_size + stride[i] - 1) / (float)stride[i];
TIndex padding_needed = std::max(TIndex(0), (output_size - 1) * stride[i] + kernel_size[i] - input_size);
TIndex pad_l = padding_needed / 2;
TIndex pad_r = padding_needed - pad_l;
pad[i] = pad_l;
}
}
} else LOG(FATAL) << "Unknown data format: " << data_format;
if (padding != "SAME") {
// case 1: infer output shape with symmetric pad size
pool_h = ceil((h + 2 * pad[0] - kernel_size[0]) / (float)stride[0]) + 1;
pool_w = ceil((w + 2 * pad[1] - kernel_size[1]) / (float)stride[1]) + 1;
if ((pool_h - 1) * stride[0] >= (h + pad[0])) pool_h--;
if ((pool_w - 1) * stride[1] >= (w + pad[1])) pool_w--;
} else {
// case 2: infer output shape with adaptive pad size
pool_h = (h + stride[0] - 1) / (float)stride[0];
pool_w = (w + stride[1] - 1) / (float)stride[1];
}
if (data_format == "NCHW") output(0)->Reshape(vector<TIndex>({ n, c, pool_h, pool_w }));
else if (data_format == "NHWC") output(0)->Reshape(vector<TIndex>({ n, pool_h, pool_w, c }));
}
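For reference, the two branches above implement the usual pooling shape rules: with explicit padding, pool = ceil((in + 2 * pad - kernel) / stride) + 1, clipped so the last window still starts inside the padded input; with "SAME" padding, pool = ceil(in / stride) and the required padding is split between the two sides. A small numeric sketch of the "SAME" case:

// Sketch of the "SAME" rule above; numbers are only for illustration.
// in = 224, kernel = 3, stride = 2  ->  pool = 112, total pad = 1, pad_l = 0.
#include <algorithm>

long SamePoolDim(long in, long kernel, long stride, long* pad_l) {
    const long out = (in + stride - 1) / stride;                   // ceil(in / stride)
    const long needed = std::max(0L, (out - 1) * stride + kernel - in);
    *pad_l = needed / 2;                                           // left/top half
    return out;
}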
template <class Context>
void Pooling2dOp<Context>::RunOnDevice() {
Reshape();
if (mode == "MAX") {
if (input(0).template IsType<float>()) MAXRunWithType<float>();
else LOG(FATAL) << "Unsupported input types.";
} else if (mode == "AVG") {
if (input(0).template IsType<float>()) AVGRunWithType<float>();
else LOG(FATAL) << "Unsupported input types.";
} else {
LOG(FATAL) << "Unsupported pooling mode: " << mode;
}
}
DEPLOY_CPU(Pooling2d);
#ifdef WITH_CUDA
DEPLOY_CUDA(Pooling2d);
#endif
OPERATOR_SCHEMA(Pooling2d).NumInputs(1).NumOutputs(1);
template <class Context> template <typename T>
void Pooling2dGradientOp<Context>::MAXRunWithType() {
mask = ws()->GetTensor("_t_" + anchor() + "_pool_mask");
auto* dYdata = input(-1).template data<T, Context>();
auto* dXdata = output(0)->template mutable_data<T, Context>();
auto* Mdata = mask->template data<int, Context>();
kernel::MAXPooling2dGrad<T, Context>(output(0)->count(),
n, c, h, w,
pool_h, pool_w,
kernel_size[0], kernel_size[1],
stride[0], stride[1],
pad[0], pad[1],
data_format,
dYdata,
Mdata,
dXdata);
}
template <class Context> template <typename T>
void Pooling2dGradientOp<Context>::AVGRunWithType() {
auto* dYdata = input(-1).template data<T, Context>();
auto* dXdata = output(0)->template mutable_data<T, Context>();
kernel::AVGPooling2dGrad<T, Context>(output(0)->count(),
n, c, h, w,
pool_h, pool_w,
kernel_size[0], kernel_size[1],
stride[0], stride[1],
pad[0], pad[1],
data_format,
dYdata,
dXdata);
}
template <class Context>
void Pooling2dGradientOp<Context>::Reshape() {
if (data_format == "NCHW") {
n = input(0).dim(0);
c = input(0).dim(1);
h = input(0).dim(2);
w = input(0).dim(3);
if (global_pooling) {
for (int i = 0; i < 2; i++)
kernel_size[i] = input(0).dim(i + 2);
}
if (padding == "SAME") {
for (int i = 0; i < 2; i++) {
TIndex input_size = input(0).dim(i + 2);
TIndex output_size = (input_size + stride[i] - 1) / (float)stride[i];
TIndex padding_needed = std::max(TIndex(0), (output_size - 1) * stride[i] + kernel_size[i] - input_size);
TIndex pad_l = padding_needed / 2;
TIndex pad_r = padding_needed - pad_l;
pad[i] = pad_l;
}
}
} else if (data_format == "NHWC") {
n = input(0).dim(0);
h = input(0).dim(1);
w = input(0).dim(2);
c = input(0).dim(3);
if (global_pooling) {
for (int i = 0; i < 2; i++)
kernel_size[i] = input(0).dim(i + 1);
}
if (padding == "SAME") {
for (int i = 0; i < 2; i++) {
TIndex input_size = input(0).dim(i + 1);
TIndex output_size = (input_size + stride[i] - 1) / (float)stride[i];
TIndex padding_needed = std::max(TIndex(0), (output_size - 1) * stride[i] + kernel_size[i] - input_size);
TIndex pad_l = padding_needed / 2;
TIndex pad_r = padding_needed - pad_l;
pad[i] = pad_l;
}
}
} else LOG(FATAL) << "Unknown data format: " << data_format;
if (padding != "SAME") {
// case 1: infer output shape with symmetric pad size
pool_h = ceil((h + 2 * pad[0] - kernel_size[0]) / (float)stride[0]) + 1;
pool_w = ceil((w + 2 * pad[1] - kernel_size[1]) / (float)stride[1]) + 1;
if ((pool_h - 1) * stride[0] >= (h + pad[0])) pool_h--;
if ((pool_w - 1) * stride[1] >= (w + pad[1])) pool_w--;
} else {
// case 2: infer output shape with adaptive pad size
pool_h = (h + stride[0] - 1) / (float)stride[0];
pool_w = (w + stride[1] - 1) / (float)stride[1];
}
output(0)->ReshapeLike(input(0));
}
template <class Context>
void Pooling2dGradientOp<Context>::RunOnDevice() {
Reshape();
if (mode == "MAX") {
if (input(0).template IsType<float>()) MAXRunWithType<float>();
else LOG(FATAL) << "Unsupported input types.";
} else if (mode == "AVG") {
if (input(0).template IsType<float>()) AVGRunWithType<float>();
else LOG(FATAL) << "Unsupported input types.";
} else {
LOG(FATAL) << "Unsupported pooling mode: " << mode;
}
}
DEPLOY_CPU(Pooling2dGradient);
#ifdef WITH_CUDA
DEPLOY_CUDA(Pooling2dGradient);
#endif
OPERATOR_SCHEMA(Pooling2dGradient).NumInputs(3).NumOutputs(1);
class GetPooling2dGradient final : public GradientMakerBase {
public:
GRADIENT_MAKER_CTOR(GetPooling2dGradient);
vector<OperatorDef> MakeDefs() override {
return SingleDef(def.type() + "Gradient", "",
vector<string> {I(0), O(0), GO(0)},
vector<string> {GI(0)});
}
};
REGISTER_GRADIENT(Pooling2d, GetPooling2dGradient);
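The gradient maker above only declares the wiring: the generated Pooling2dGradient op reads the original input I(0), the forward output O(0) and the incoming gradient GO(0), and writes GI(0), matching the NumInputs(3)/NumOutputs(1) schema. A hedged, framework-free sketch of that wiring (plain struct, not the GradientMakerBase API):

// Sketch only: what GetPooling2dGradient::MakeDefs amounts to,
// written without the gradient-maker helpers.
#include <string>
#include <vector>

struct SimpleOpDef {
    std::string type;
    std::vector<std::string> inputs, outputs;
};

SimpleOpDef MakePooling2dGradientDef(const std::string& x,     // I(0)
                                     const std::string& y,     // O(0)
                                     const std::string& dy,    // GO(0)
                                     const std::string& dx) {  // GI(0)
    return { "Pooling2dGradient", { x, y, dy }, { dx } };
}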
} // namespace dragon
\ No newline at end of file
#include "operators/vision/pooling_op.h"
#include "core/workspace.h"
#include "utils/math_functions.h"
#include "utils/op_kernel.h"
namespace dragon {
template <class Context> template <typename T>
void PoolingOp<Context>::MaxRunWithType() {
mask = ws()->CreateTensor("_t_" + anchor() + "_pool_mask");
mask->ReshapeLike(*output(0));
auto* Xdata = input(0).template data<T, Context>();
auto* Ydata = output(0)->template mutable_data<T, Context>();
auto* Mdata = mask->template mutable_data<int, Context>();
kernel::MAXPooling<T, Context>(output(0)->count(),
num, channels, height, width,
pool_height, pool_width,
kernel_size[0], kernel_size[1],
stride[0], stride[1],
pad[0], pad[1],
Xdata,
Mdata,
Ydata);
}
template <class Context> template <typename T>
void PoolingOp<Context>::AvgRunWithType() {
auto* Xdata = input(0).template data<T, Context>();
auto* Ydata = output(0)->template mutable_data<T, Context>();
kernel::AVEPooling<T, Context>(output(0)->count(),
num, channels, height, width,
pool_height, pool_width,
kernel_size[0], kernel_size[1],
stride[0], stride[1],
pad[0], pad[1],
Xdata,
Ydata);
}
template <class Context>
void PoolingOp<Context>::Reshape() {
num = input(0).dim(0);
channels = input(0).dim(1);
height = input(0).dim(2);
width = input(0).dim(3);
if (global_pooling) {
for (int i = 0; i < 2; i++)
kernel_size[i] = input(0).dim(i + 2);
}
pool_height = ceil((height + 2 * pad[0] - kernel_size[0]) / (float)stride[0]) + 1;
pool_width = ceil((width + 2 * pad[1] - kernel_size[1]) / (float)stride[1]) + 1;
if ((pool_height - 1) * stride[0] >= (height + pad[0])) pool_height--;
if ((pool_width - 1) * stride[1] >= (width + pad[1])) pool_width--;
vector<TIndex> top_shape({ num, channels, pool_height, pool_width });
if (input(0).ndim() == 3) top_shape.erase(top_shape.begin());
output(0)->Reshape(top_shape);
}
template <class Context>
void PoolingOp<Context>::RunOnDevice() {
Reshape();
if (mode == MAX_POOLING) {
if (input(0).template IsType<float>()) MaxRunWithType<float>();
else LOG(FATAL) << "Unsupported input types.";
}
else if (mode == AVG_POOLING) {
if (input(0).template IsType<float>()) AvgRunWithType<float>();
else LOG(FATAL) << "Unsupported input types.";
}
else {
LOG(FATAL) << "Unsupported pooling mode.";
}
}
DEPLOY_CPU(Pooling);
#ifdef WITH_CUDA
DEPLOY_CUDA(Pooling);
#endif
OPERATOR_SCHEMA(Pooling).NumInputs(1).NumOutputs(1);
template <class Context> template <typename T>
void PoolingGradientOp<Context>::MaxRunWithType() {
mask = ws()->GetTensor("_t_" + anchor() + "_pool_mask");
auto* dYdata = input(-1).template data<T, Context>();
auto* dXdata = output(0)->template mutable_data<T, Context>();
auto* Mdata = mask->template data<int, Context>();
kernel::MAXPoolingGrad<T, Context>(output(0)->count(),
num, channels, height, width,
pool_height, pool_width,
kernel_size[0], kernel_size[1],
stride[0], stride[1],
pad[0], pad[1],
dYdata,
Mdata,
dXdata);
}
template <class Context> template <typename T>
void PoolingGradientOp<Context>::AvgRunWithType() {
auto* dYdata = input(-1).template data<T, Context>();
auto* dXdata = output(0)->template mutable_data<T, Context>();
kernel::AVEPoolingGrad<T, Context>(output(0)->count(),
num, channels, height, width,
pool_height, pool_width,
kernel_size[0], kernel_size[1],
stride[0], stride[1],
pad[0], pad[1],
dYdata,
dXdata);
}
template <class Context>
void PoolingGradientOp<Context>::Reshape() {
num = input(0).dim(0);
channels = input(0).dim(1);
height = input(0).dim(2);
width = input(0).dim(3);
if (global_pooling) {
for (int i = 0; i < 2; i++)
kernel_size[i] = input(0).dim(i + 2);
}
pool_height = ceil((height + 2 * pad[0] - kernel_size[0]) / (float)stride[0]) + 1;
pool_width = ceil((width + 2 * pad[1] - kernel_size[1]) / (float)stride[1]) + 1;
if ((pool_height - 1) * stride[0] >= (height + pad[0])) pool_height--;
if ((pool_width - 1)* stride[1] >= (width + pad[1])) pool_width--;
output(0)->ReshapeLike(input(0));
}
template <class Context>
void PoolingGradientOp<Context>::RunOnDevice() {
Reshape();
if (mode == MAX_POOLING) {
if (input(0).template IsType<float>()) MaxRunWithType<float>();
else LOG(FATAL) << "Unsupported input types.";
}
else if (mode == AVG_POOLING) {
if (input(0).template IsType<float>()) AvgRunWithType<float>();
else LOG(FATAL) << "Unsupported input types.";
}
else {
LOG(FATAL) << "Unsupported pooling mode.";
}
}
DEPLOY_CPU(PoolingGradient);
#ifdef WITH_CUDA
DEPLOY_CUDA(PoolingGradient);
#endif
OPERATOR_SCHEMA(PoolingGradient).NumInputs(3).NumOutputs(1);
class GetPoolingGradient final : public GradientMakerBase {
public:
GRADIENT_MAKER_CTOR(GetPoolingGradient);
vector<OperatorDef> MakeDefs() override {
return SingleDef(def.type() + "Gradient", "",
vector<string> {I(0), O(0), GO(0)},
vector<string> {GI(0)});
}
};
REGISTER_GRADIENT(Pooling, GetPoolingGradient);
} // namespace dragon
\ No newline at end of file
......@@ -48,8 +48,69 @@ void cudnnSetTensorDesc(cudnnTensorDescriptor_t* desc, const vector<TIndex>& dim
}
template <typename T>
void cudnnSetTensorDesc(cudnnTensorDescriptor_t* desc,
const vector<TIndex>& dims, const vector<TIndex>& strides) {
void cudnnSetTensor4dDesc(cudnnTensorDescriptor_t* desc,
const string& data_format,
const vector<TIndex>& dims) {
if (data_format == "NCHW") {
CUDNN_CHECK(cudnnSetTensor4dDescriptor(*desc, CUDNN_TENSOR_NCHW,
CUDNNType<T>::type,
dims[0],
dims[1],
dims[2],
dims[3]));
} else if (data_format == "NHWC") {
CUDNN_CHECK(cudnnSetTensor4dDescriptor(*desc, CUDNN_TENSOR_NHWC,
CUDNNType<T>::type,
dims[0],
dims[3],
dims[1],
dims[2]));
} else LOG(FATAL) << "Unknown data format: " << data_format;
}
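Note that cudnnSetTensor4dDescriptor always takes its extents in (n, c, h, w) order regardless of the layout enum, which is why the NHWC branch above passes dims[3] as the channel count. A minimal usage sketch (requires cuDNN; status checking trimmed for brevity):

// Usage sketch: describe a 2x3x4x5 tensor stored channels-last.
#include <cudnn.h>

cudnnTensorDescriptor_t MakeNHWCDesc() {
    cudnnTensorDescriptor_t desc;
    cudnnCreateTensorDescriptor(&desc);
    // dims as stored by the caller: {N=2, H=3, W=4, C=5};
    // the API still wants (n, c, h, w), so pass 2, 5, 3, 4.
    cudnnSetTensor4dDescriptor(desc, CUDNN_TENSOR_NHWC, CUDNN_DATA_FLOAT,
                               2, 5, 3, 4);
    return desc;
}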
template <typename T>
void cudnnSetTensor5dDesc(cudnnTensorDescriptor_t* desc,
const string& data_format,
const vector<TIndex>& dims) {
if (data_format == "NCHW") {
cudnnSetTensorDesc<T>(desc, dims);
} else if (data_format == "NHWC") {
const int N = (int)dims[0];
const int C = (int)dims[4];
const int H = (int)dims[1];
const int W = (int)dims[2];
const int D = (int)dims[3];
vector<int> fake_dims = { N, C, H, W, D };
vector<int> fake_strides = { H * W * D * C, 1, W * D * C, D * C, C };
CUDNN_CHECK(cudnnSetTensorNdDescriptor(*desc,
CUDNNType<T>::type,
5,
fake_dims.data(),
fake_strides.data()));
} else LOG(FATAL) << "Unknown data format: " << data_format;
}
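The hand-written fake_strides in the NHWC branch are just the strides of a buffer stored as N, H, W, D, C but read back in N, C, H, W, D order, which lets cudnnSetTensorNdDescriptor describe a channels-last 5-d tensor without moving any data. A small sketch that reproduces the same numbers:

// Sketch: derive the N,C,H,W,D-order strides of a tensor stored as N,H,W,D,C.
// For H=4, W=5, D=6, C=3 this gives {360, 1, 90, 18, 3}, matching fake_strides.
#include <vector>

std::vector<int> NhwdcStridesInNchwdOrder(int C, int H, int W, int D) {
    return { H * W * D * C,  // step one image (N)
             1,              // step one channel (C): channels are innermost
             W * D * C,      // step one row (H)
             D * C,          // step one column (W)
             C };            // step one depth slice (D)
}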
template <typename T>
void cudnnSetTensor3dDesc(cudnnTensorDescriptor_t* desc,
const string& data_format,
const vector<TIndex>& dims) {
vector<TIndex> fake_dims = dims;
if (data_format == "NCHW") {
// NCH -> NCH11: append two singleton dims
fake_dims.push_back(1);
fake_dims.push_back(1);
} else if (data_format == "NHWC") {
// NHC -> NH11C: insert two singleton spatial dims before the channel axis
fake_dims.insert(fake_dims.begin() + 2, 1);
fake_dims.insert(fake_dims.begin() + 2, 1);
} else LOG(FATAL) << "Unknown data format: " << data_format;
cudnnSetTensor5dDesc<T>(desc, data_format, fake_dims);
}
template <typename T>
void cudnnSetTensorDesc(cudnnTensorDescriptor_t* desc,
const vector<TIndex>& dims,
const vector<TIndex>& strides) {
CHECK_EQ(dims.size(), strides.size());
CHECK(dims.size() >= 3 && dims.size() <= 8);
int ndim = (int)dims.size();
......@@ -76,22 +137,64 @@ void cudnnSetTensorDesc(cudnnTensorDescriptor_t* desc, Tensor* tensor) {
cudnnSetTensorDesc<T>(desc, fake_dims);
}
template <typename T>
void cudnnSetTensor4dDesc(cudnnTensorDescriptor_t* desc, const string& data_format, Tensor* tensor) {
CHECK_EQ((int)tensor->ndim(), 4)
<< "\nThe num of dimensions of Tensor(" << tensor->name() << ") "
<< "should be 4, but got " << tensor->ndim() << ".";
cudnnSetTensor4dDesc<T>(desc, data_format, tensor->dims());
}
template <typename T>
void cudnnSetTensor5dDesc(cudnnTensorDescriptor_t* desc, const string& data_format, Tensor* tensor) {
CHECK_EQ((int)tensor->ndim(), 5)
<< "\nThe num of dimensions of Tensor(" << tensor->name() << ") "
<< "should be 5, but got " << tensor->ndim() << ".";
cudnnSetTensor5dDesc<T>(desc, data_format, tensor->dims());
}
template <typename T>
void cudnnSetTensor3dDesc(cudnnTensorDescriptor_t* desc, const string& data_format, Tensor* tensor) {
CHECK_EQ((int)tensor->ndim(), 3)
<< "\nThe num of dimensions of Tensor(" << tensor->name() << ") "
<< "should be 3, but got " << tensor->ndim() << ".";
cudnnSetTensor3dDesc<T>(desc, data_format, tensor->dims());
}
template void cudnnSetTensorDesc<float>(cudnnTensorDescriptor_t*, Tensor*);
template void cudnnSetTensor4dDesc<float>(cudnnTensorDescriptor_t*, const string&, Tensor*);
template void cudnnSetTensor5dDesc<float>(cudnnTensorDescriptor_t*, const string&, Tensor*);
template void cudnnSetTensor3dDesc<float>(cudnnTensorDescriptor_t*, const string&, Tensor*);
template void cudnnSetTensorDesc<float>(cudnnTensorDescriptor_t*, const vector<TIndex>&);
template void cudnnSetTensor4dDesc<float>(cudnnTensorDescriptor_t*, const string&, const vector<TIndex>&);
template void cudnnSetTensor5dDesc<float>(cudnnTensorDescriptor_t*, const string&, const vector<TIndex>&);
template void cudnnSetTensor3dDesc<float>(cudnnTensorDescriptor_t*, const string&, const vector<TIndex>&);
template void cudnnSetTensorDesc<float>(cudnnTensorDescriptor_t*, const vector<TIndex>&, const vector<TIndex>&);
template void cudnnSetTensorDesc<double>(cudnnTensorDescriptor_t*, Tensor*);
template void cudnnSetTensor4dDesc<double>(cudnnTensorDescriptor_t*, const string&, Tensor*);
template void cudnnSetTensor5dDesc<double>(cudnnTensorDescriptor_t*, const string&, Tensor*);
template void cudnnSetTensor3dDesc<double>(cudnnTensorDescriptor_t*, const string&, Tensor*);
template void cudnnSetTensorDesc<double>(cudnnTensorDescriptor_t*, const vector<TIndex>&);
template void cudnnSetTensor4dDesc<double>(cudnnTensorDescriptor_t*, const string&, const vector<TIndex>&);
template void cudnnSetTensor5dDesc<double>(cudnnTensorDescriptor_t*, const string&, const vector<TIndex>&);
template void cudnnSetTensor3dDesc<double>(cudnnTensorDescriptor_t*, const string&, const vector<TIndex>&);
template void cudnnSetTensorDesc<double>(cudnnTensorDescriptor_t*, const vector<TIndex>&, const vector<TIndex>&);
#ifdef WITH_CUDA_FP16
template void cudnnSetTensorDesc<float16>(cudnnTensorDescriptor_t*, Tensor*);
template void cudnnSetTensor4dDesc<float16>(cudnnTensorDescriptor_t*, const string&, Tensor*);
template void cudnnSetTensor5dDesc<float16>(cudnnTensorDescriptor_t*, const string&, Tensor*);
template void cudnnSetTensor3dDesc<float16>(cudnnTensorDescriptor_t*, const string&, Tensor*);
template void cudnnSetTensorDesc<float16>(cudnnTensorDescriptor_t*, const vector<TIndex>&);
template void cudnnSetTensor4dDesc<float16>(cudnnTensorDescriptor_t*, const string&, const vector<TIndex>&);
template void cudnnSetTensor5dDesc<float16>(cudnnTensorDescriptor_t*, const string&, const vector<TIndex>&);
template void cudnnSetTensor3dDesc<float16>(cudnnTensorDescriptor_t*, const string&, const vector<TIndex>&);
template void cudnnSetTensorDesc<float16>(cudnnTensorDescriptor_t*, const vector<TIndex>&, const vector<TIndex>&);
#endif
} // namespace dragon
#endif // WITH_CUDNN
#endif // WITH_CUDNN
\ No newline at end of file
......@@ -178,13 +178,13 @@ template<> void PReluWGrad<float, CPUContext>(const int rows,
1.0,
bcast_dw, multiplier,
1.0,
dw);
dw);
} else if (data_format == "NHWC") {
math::Gemv<float, CPUContext>(CblasTrans, dim, channels,
1.0,
bcast_dw, multiplier,
1.0,
dw);
dw);
} else LOG(FATAL) << "Unknown data format: " << data_format;
}
......@@ -285,18 +285,18 @@ template<> void Softmax<float, CPUContext>(const int count,
for (int k = 0; k < inner_dim; k++)
scale[k] = std::max(scale[k], x[i * dim + j * inner_dim + k]);
}
math::Gemm<float, CPUContext>(CblasNoTrans, CblasNoTrans,
classes, inner_dim, 1,
-1.0,
sum_multiplier, scale,
1.0,
y);
math::Exp<float, CPUContext>(dim, y, y);
math::Gemv<float, CPUContext>(CblasTrans, classes, inner_dim,
1.0,
y, sum_multiplier,
0.0,
scale);
for (int j = 0; j < classes; ++j) {
math::Div<float, CPUContext>(inner_dim, y, scale, y);
y += inner_dim;
......@@ -316,15 +316,15 @@ template<> void SoftmaxGrad<float, CPUContext>(const int count,
const int dim = count / outer_dim;
for (int i = 0; i < outer_dim; ++i) {
for (int k = 0; k < inner_dim; ++k)
scale[k] = math::StridedDot<float, CPUContext>(classes,
dx + i * dim + k, inner_dim,
y + i*dim + k, inner_dim);
math::Gemm<float, CPUContext>(CblasNoTrans, CblasNoTrans,
classes, inner_dim, 1,
-1.0,
sum_multiplier, scale,
1.0,
dx + i*dim);
}
math::Mul<float, CPUContext>(count, dx, y, dx);
}
......@@ -358,23 +358,28 @@ template<> void BiasAdd<float, CPUContext>(const int count,
const int outer_dim,
const int dim,
const int inner_dim,
const string& format,
const string& data_format,
const float* bias,
const float* bias_multiplier,
float* y) {
if (format == "NCHW") {
const int y_offset = dim * inner_dim;
for (int n = 0; n < outer_dim; ++n) {
const int y_offset = dim * inner_dim;
for (int n = 0; n < outer_dim; ++n) {
if (data_format == "NCHW") {
math::Gemm<float, CPUContext>(CblasNoTrans, CblasNoTrans,
dim, inner_dim, 1,
1.0,
bias, bias_multiplier,
1.0,
y);
y += y_offset;
}
} else {
NOT_IMPLEMENTED;
y);
} else if (data_format == "NHWC") {
math::Gemm<float, CPUContext>(CblasNoTrans, CblasNoTrans,
inner_dim, dim, 1,
1.0,
bias_multiplier, bias,
1.0,
y);
} else LOG(FATAL) << "Unknown data format: " << data_format;
y += y_offset;
}
}
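Both Gemm calls above are rank-1 updates against a multiplier vector of ones: in NCHW the bias of length dim is broadcast across the inner_dim spatial positions, in NHWC the ones of length inner_dim are broadcast across the dim channels, so either way every position of channel d receives bias[d]. A plain-loop sketch of the same effect:

// Reference sketch of the bias broadcast; loops instead of Gemm, illustration only.
#include <string>

void BiasAddRef(int dim, int inner_dim, const std::string& data_format,
                const float* bias, float* y) {
    for (int d = 0; d < dim; ++d)
        for (int i = 0; i < inner_dim; ++i) {
            // NCHW: channel-major slab; NHWC: channel is the fastest axis.
            const int idx = (data_format == "NCHW") ? d * inner_dim + i
                                                    : i * dim + d;
            y[idx] += bias[d];
        }
}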
......@@ -428,12 +433,12 @@ template<> void Scale<float, CPUContext>(const int axis,
int dim = scale_dim * inner_dim;
Ydata = y->mutable_data<float, CPUContext>();
for (int n = 0; n < outer_dim; ++n) {
math::Gemm<float, CPUContext>(CblasNoTrans, CblasNoTrans,
scale_dim, inner_dim, 1,
1.0,
Bdata, BMul_data,
1.0,
Ydata);
Ydata += dim;
}
}
......@@ -520,7 +525,7 @@ template<> void SmoothL1<float, CPUContext>(const int count,
for (int i = 0; i < count; ++i) {
const float val = x[i];
const float abs_val = abs(val);
if (abs_val < 1.0 / sigma2) y[i] = 0.5 * val * val *sigma2;
if (abs_val < 1.0 / sigma2) y[i] = 0.5 * val * val * sigma2;
else y[i] = abs_val - 0.5 / sigma2;
}
}
......@@ -713,69 +718,116 @@ template<> void SparseSoftmaxFocalLossGrad<float, CPUContext>(const int count,
}
}
/******************** misc.memory_data ********************/
/******************** misc.image_data ********************/
template <typename Tx, typename Ty>
void _ImageData_NCHW(const int N, const int C,
const int H, const int W,
const float* mean_values,
const float* std_values,
const Tx* x,
Ty* y) {
for (int n = 0; n < N; ++n) {
for (int c = 0; c < C; ++c) {
for (int h = 0; h < H; ++h) {
const int NH = n * H + h;
for (int w = 0; w < W; ++w) {
Ty raw_value = x[(NH * W + w) * C + c];
if (mean_values != nullptr) raw_value -= mean_values[c];
if (std_values != nullptr) raw_value /= std_values[c];
*(y++) = raw_value;
}
}
}
}
}
template <typename Tx, typename Ty>
void _ImageData_NHWC(const int N, const int C,
const int H, const int W,
const float* mean_values,
const float* std_values,
const Tx* x,
Ty* y) {
for (int n = 0; n < N; ++n) {
for (int h = 0; h < H; ++h) {
for (int w = 0; w < W; ++w) {
for (int c = 0; c < C; ++c) {
Ty raw_value = *(x++);
if (mean_values != nullptr) raw_value -= mean_values[c];
if (std_values != nullptr) raw_value /= std_values[c];
*(y++) = raw_value;
}
}
}
}
}
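Both kernels apply the same per-channel transform, y = (x - mean[c]) / std[c]; the NCHW variant additionally permutes the packed HWC input into planar CHW as it streams. A tiny numeric sketch of the normalization step (values made up for illustration):

// Normalization sketch for a single element of channel c.
#include <cassert>

void ImageDataNormalizeExample() {
    const float x = 128.f;                    // raw pixel value of channel c
    const float mean_c = 104.f, std_c = 2.f;  // per-channel statistics
    const float y = (x - mean_c) / std_c;     // -> 12.f
    assert(y == 12.f);
}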
template <> void MemoryData<float, float, CPUContext>(const int count,
const int num,
const int channels,
const int height,
const int width,
const float* x,
float* y) {
#ifdef WITH_OMP
#pragma omp parallel for num_threads(GET_OMP_THREADS(count))
#endif
for (int i = 0; i < count; ++i) {
const int w = i % width;
const int h = (i / width) % height;
const int c = (i / width / height) % channels;
const int n = i / width / height / channels;
const int x_idx = ((n * height + h) * width + w) * channels + c;
if (c == 0) y[i] = x[x_idx] - 102.9801;
else if (c == 1) y[i] = x[x_idx] - 115.9465;
else y[i] = x[x_idx] - 122.7717;
}
}
template <> void MemoryData<uint8_t, float, CPUContext>(const int count,
const int num,
const int channels,
const int height,
const int width,
const uint8_t* x,
float* y) {
#ifdef WITH_OMP
#pragma omp parallel for num_threads(GET_OMP_THREADS(count))
#endif
for (int i = 0; i < count; ++i) {
const int w = i % width;
const int h = (i / width) % height;
const int c = (i / width / height) % channels;
const int n = i / width / height / channels;
const int x_idx = ((n * height + h) * width + w) * channels + c;
if (c == 0) y[i] = x[x_idx] - 102.9801;
else if (c == 1) y[i] = x[x_idx] - 115.9465;
else y[i] = x[x_idx] - 122.7717;
}
}
template <> void MemoryData<float, float16, CPUContext>(const int count,
const int num,
const int channels,
const int height,
const int width,
const float* x,
float16* y) {
template <> void ImageData<float, float, CPUContext>(const int count,
const int N, const int C,
const int H, const int W,
const float* mean_values,
const float* std_values,
const string& data_format,
const float* x,
float* y) {
if (data_format == "NCHW") {
_ImageData_NCHW<float, float>(N, C, H, W,
mean_values,
std_values,
x,
y);
} else if (data_format == "NHWC") {
_ImageData_NHWC<float, float>(N, C, H, W,
mean_values,
std_values,
x,
y);
} else LOG(FATAL) << "Unknown data format: " << data_format;
}
template <> void ImageData<uint8_t, float, CPUContext>(const int count,
const int N, const int C,
const int H, const int W,
const float* mean_values,
const float* std_values,
const string& data_format,
const uint8_t* x,
float* y) {
if (data_format == "NCHW") {
_ImageData_NCHW<uint8_t, float>(N, C, H, W,
mean_values,
std_values,
x,
y);
} else if (data_format == "NHWC") {
_ImageData_NHWC<uint8_t, float>(N, C, H, W,
mean_values,
std_values,
x,
y);
} else LOG(FATAL) << "Unknown data format: " << data_format;
}
template <> void ImageData<float, float16, CPUContext>(const int count,
const int N, const int C,
const int H, const int W,
const float* mean_values,
const float* std_values,
const string& data_format,
const float* x,
float16* y) {
LOG(FATAL) << "float16 is unsupported for CPUContext.";
}
template <> void MemoryData<uint8_t, float16, CPUContext>(const int count,
const int num,
const int channels,
const int height,
const int width,
const uint8_t* x,
float16* y) {
template <> void ImageData<uint8_t, float16, CPUContext>(const int count,
const int N, const int C,
const int H, const int W,
const float* mean_values,
const float* std_values,
const string& data_format,
const uint8_t* x,
float16* y) {
LOG(FATAL) << "float16 is unsupported for CPUContext.";
}
......@@ -875,8 +927,8 @@ template <> void At<float, CPUContext>(const int count,
for (int n = 0; n < outer_dim; ++n) {
x_offset = (n * x_slice_dim + x_idx_offset) * inner_dim;
y_offset = (n * y_slice_dim + y_idx_offset) * inner_dim;
context->Copy<float, CPUContext, CPUContext>(inner_dim,
y + y_offset,
x + x_offset);
}
}
......@@ -898,9 +950,9 @@ template <> void AtGrad<float, CPUContext>(const int count,
for (int n = 0; n < outer_dim; ++n) {
x_offset = (n * x_slice_dim + x_idx_offset) * inner_dim;
y_offset = (n * y_slice_dim + y_idx_offset) * inner_dim;
math::Add<float, CPUContext>(inner_dim,
dy + y_offset,
dx + x_offset,
dx + x_offset);
}
}
......@@ -921,20 +973,20 @@ template <> void Concat<float, CPUContext>(const int count,
for (int n = 0; n < outer_dim; ++n) {
x_offset = n * x_concat_dim * inner_dim;
y_offset = (n * y_concat_dim + concat_offset) * inner_dim;
context->Copy<float, CPUContext, CPUContext>(x_concat_dim * inner_dim,
y + y_offset,
x + x_offset);
}
}
template <> void Concat<float16, CPUContext>(const int count,
const int outer_dim,
const int inner_dim,
const int x_concat_dim,
const int y_concat_dim,
const int concat_offset,
const float16* x,
float16* y,
CPUContext* context) {
TIndex x_offset, y_offset;
for (int n = 0; n < outer_dim; ++n) {
......@@ -959,8 +1011,8 @@ template <> void ConcatGrad<float, CPUContext>(const int count,
for (int n = 0; n < outer_dim; ++n) {
x_offset = n * x_concat_dim * inner_dim;
y_offset = (n * y_concat_dim + concat_offset) * inner_dim;
context->Copy<float, CPUContext, CPUContext>(x_concat_dim * inner_dim,
dx + x_offset,
dy + y_offset);
}
}
......@@ -978,8 +1030,8 @@ template <> void ConcatGrad<float16, CPUContext>(const int count,
for (int n = 0; n < outer_dim; ++n) {
x_offset = n * x_concat_dim * inner_dim;
y_offset = (n * y_concat_dim + concat_offset) * inner_dim;
        context->Copy<float16, CPUContext, CPUContext>(x_concat_dim * inner_dim,
                                                       dx + x_offset,
dy + y_offset);
}
}
......@@ -992,15 +1044,18 @@ template<> void Crop1D<float, CPUContext>(const int count,
const int inner_dim,
const int start,
const float* x,
float* y) {
float* y,
CPUContext* context) {
const int count_v2 = count / inner_dim;
#ifdef WITH_OMP
#pragma omp parallel for num_threads(GET_OMP_THREADS(count))
#pragma omp parallel for num_threads(GET_OMP_THREADS(count_v2))
#endif
for (int idx = 0; idx < count; idx++) {
const int i = idx % inner_dim;
const int ex_d = (idx / inner_dim) % ex_dim;
const int o = idx / inner_dim / ex_dim;
y[idx] = x[(o * dim + ex_d + start) * inner_dim + i];
for (int idx = 0; idx < count_v2; ++idx) {
const int ex_d = idx % ex_dim;
const int o = idx / ex_dim;
const float* x_ptr = x + (o * dim + ex_d + start) * inner_dim;
float* y_ptr = y + (o * ex_dim + ex_d) * inner_dim;
context->Copy<float, CPUContext, CPUContext>(inner_dim, y_ptr, x_ptr);
}
}
......@@ -1011,18 +1066,23 @@ template<> void Crop1DGrad<float, CPUContext>(const int count,
const int start,
const int end,
const float* dy,
float* dx) {
float* dx,
CPUContext* context) {
const int count_v2 = count / inner_dim;
#ifdef WITH_OMP
#pragma omp parallel for num_threads(GET_OMP_THREADS(count))
#pragma omp parallel for num_threads(GET_OMP_THREADS(count_v2))
#endif
for (int idx = 0; idx < count; idx++) {
const int i = idx % inner_dim;
const int d = (idx / inner_dim) % dim;
const int o = idx / inner_dim / dim;
if (d >= start && d < end)
dx[idx] = dy[(o * ex_dim + d - start) * inner_dim + i];
for (int idx = 0; idx < count_v2; ++idx) {
const int d = idx % dim;
const int o = idx / dim;
float* dx_ptr = dx + (o * dim + d) * inner_dim;
if (d < start || d >= end) {
for (int i = 0; i < inner_dim; ++i) dx_ptr[i] = 0;
} else {
const float* dy_ptr = dy + (o * ex_dim + d - start) * inner_dim;
context->Copy<float, CPUContext, CPUContext>(inner_dim, dx_ptr, dy_ptr);
}
}
}
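The reworked Crop1D/Crop1DGrad kernels above copy whole inner_dim-contiguous slices instead of indexing element by element. A tiny standalone sketch of the same index mapping (illustrative helper, not part of the kernels):

#include <cassert>
#include <cstring>
#include <vector>

// Crop a 1-D axis: dim -> ex_dim, keeping [start, start + ex_dim).
// Mirrors the index math of Crop1D<float, CPUContext> above.
void crop1d(int outer_dim, int dim, int ex_dim, int inner_dim, int start,
            const float* x, float* y) {
    for (int o = 0; o < outer_dim; ++o)
      for (int ex_d = 0; ex_d < ex_dim; ++ex_d)
        std::memcpy(y + (o * ex_dim + ex_d) * inner_dim,
                    x + (o * dim + ex_d + start) * inner_dim,
                    sizeof(float) * inner_dim);
}

int main() {
    // outer_dim = 1, dim = 5, crop to ex_dim = 3 starting at 1, inner_dim = 2.
    std::vector<float> x = {0,1, 2,3, 4,5, 6,7, 8,9};
    std::vector<float> y(6);
    crop1d(1, 5, 3, 2, 1, x.data(), y.data());
    assert(y[0] == 2 && y[5] == 7);  // slices 1..3 of the input
    return 0;
}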
/******************** ndarray.pad ********************/
......@@ -1034,37 +1094,52 @@ template <> void ConstPad1D<float, CPUContext>(const int count,
const int pad_l,
const float value,
const float* x,
float* y) {
float* y,
CPUContext* context) {
const int count_v2 = count / inner_dim;
#ifdef WITH_OMP
#pragma omp parallel for num_threads(GET_OMP_THREADS(count))
#pragma omp parallel for num_threads(GET_OMP_THREADS(count_v2))
#endif
for (int idx = 0; idx < count; idx++) {
const int i = idx % inner_dim;
const int ex_d = (idx / inner_dim) % ex_dim;
const int o = idx / inner_dim / ex_dim;
for (int idx = 0; idx < count_v2; ++idx) {
const int ex_d = idx % ex_dim;
const int o = idx / ex_dim;
const int d = ex_d - pad_l;
y[idx] = (d < 0 || d >= dim) ? value : x[(o * dim + d) * inner_dim + i];
float* y_ptr = y + (o * ex_dim + ex_d) * inner_dim;
if (d < 0 || d >= dim) {
for (int i = 0; i < inner_dim; ++i) y_ptr[i] = value;
} else {
const float* x_ptr = x + (o * dim + d) * inner_dim;
context->Copy<float, CPUContext, CPUContext>(inner_dim, y_ptr, x_ptr);
}
}
}
template <> void ReflectPad1D<float, CPUContext>(const int count,
const int dim,
const int ex_dim,
const int inner_dim,
const int pad_l,
const float* x,
float* y) {
const int dim,
const int ex_dim,
const int inner_dim,
const int pad_l,
const float* x,
float* y,
CPUContext* context) {
const int count_v2 = count / inner_dim;
#ifdef WITH_OMP
#pragma omp parallel for num_threads(GET_OMP_THREADS(count))
#pragma omp parallel for num_threads(GET_OMP_THREADS(count_v2))
#endif
for (int idx = 0; idx < count; idx++) {
const int i = idx % inner_dim;
const int ex_d = (idx / inner_dim) % ex_dim;
const int o = idx / inner_dim / ex_dim;
for (int idx = 0; idx < count_v2; ++idx) {
const int ex_d = idx % ex_dim;
const int o = idx / ex_dim;
int d = ex_d - pad_l;
d = std::max(d, -d);
d = std::min(d, 2 * dim - d - 2);
y[idx] = x[(o * dim + d) * inner_dim + i];
        float* y_ptr = y + (o * ex_dim + ex_d) * inner_dim;
        const float* x_ptr = x + (o * dim + d) * inner_dim;
        context->Copy<float, CPUContext, CPUContext>(inner_dim, y_ptr, x_ptr);
}
}
......@@ -1074,16 +1149,24 @@ template <> void EdgePad1D<float, CPUContext>(const int count,
const int inner_dim,
const int pad_l,
const float* x,
float* y) {
float* y,
CPUContext* context) {
const int count_v2 = count / inner_dim;
#ifdef WITH_OMP
#pragma omp parallel for num_threads(GET_OMP_THREADS(count))
#pragma omp parallel for num_threads(GET_OMP_THREADS(count_v2))
#endif
for (int idx = 0; idx < count; idx++) {
const int i = idx % inner_dim;
const int ex_d = (idx / inner_dim) % ex_dim;
const int o = idx / inner_dim / ex_dim;
for (int idx = 0; idx < count_v2; ++idx) {
const int ex_d = idx % ex_dim;
const int o = idx / ex_dim;
const int d = std::min(dim - 1, std::max(ex_d - pad_l, 0));
y[idx] = x[(o * dim + d) * inner_dim + i];
        float* y_ptr = y + (o * ex_dim + ex_d) * inner_dim;
        const float* x_ptr = x + (o * dim + d) * inner_dim;
        context->Copy<float, CPUContext, CPUContext>(inner_dim, y_ptr, x_ptr);
}
}
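All three 1-D pad kernels above reduce to an index mapping from the padded axis back into the source axis. A small worked example of those mappings for dim = 4 and pad_l = pad_r = 2 (so ex_dim = 8); -1 marks positions where ConstPad1D writes the fill value:

#include <algorithm>
#include <cstdio>

int main() {
    const int dim = 4, pad_l = 2, ex_dim = 8;  // pad_r = 2
    for (int ex_d = 0; ex_d < ex_dim; ++ex_d) {
        const int d = ex_d - pad_l;
        const int d_const = (d < 0 || d >= dim) ? -1 : d;   // -1 => fill value
        int d_reflect = std::max(d, -d);                    // same rule as ReflectPad1D
        d_reflect = std::min(d_reflect, 2 * dim - d_reflect - 2);
        const int d_edge = std::min(dim - 1, std::max(d, 0));  // same rule as EdgePad1D
        std::printf("ex_d=%d  const=%2d  reflect=%d  edge=%d\n",
                    ex_d, d_const, d_reflect, d_edge);
    }
    return 0;
}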
......@@ -1093,15 +1176,19 @@ template <> void ConstPad1DGrad<float, CPUContext>(const int count,
const int inner_dim,
const int pad_l,
const float* dy,
float* dx) {
float* dx,
CPUContext* context) {
const int count_v2 = count / inner_dim;
#ifdef WITH_OMP
#pragma omp parallel for num_threads(GET_OMP_THREADS(count))
#pragma omp parallel for num_threads(GET_OMP_THREADS(count_v2))
#endif
for (int idx = 0; idx < count; idx++) {
const int i = idx % inner_dim;
const int ex_d = (idx / inner_dim) % dim + pad_l;
const int o = idx / inner_dim / dim;
dx[idx] = dy[(o * ex_dim + ex_d) * inner_dim + i];
for (int idx = 0; idx < count_v2; ++idx) {
const int d = idx % dim;
const int o = idx / dim;
const int ex_d = d + pad_l;
const float* dy_ptr = dy + (o * ex_dim + ex_d) * inner_dim;
float* dx_ptr = dx + (o * dim + d) * inner_dim;
context->Copy<float, CPUContext, CPUContext>(inner_dim, dx_ptr, dy_ptr);
}
}
......@@ -1112,7 +1199,7 @@ template <> void ReflectPad1DGrad<float, CPUContext>(const int count,
const int pad_l,
const float* dy,
float* dx) {
for (int idx = 0; idx < count; idx++) {
for (int idx = 0; idx < count; ++idx) {
const int i = idx % inner_dim;
const int ex_d = (idx / inner_dim) % ex_dim;
const int o = idx / inner_dim / ex_dim;
......@@ -1129,13 +1216,21 @@ template <> void EdgePad1DGrad<float, CPUContext>(const int count,
const int inner_dim,
const int pad_l,
const float* dy,
float* dx) {
for (int idx = 0; idx < count; idx++) {
const int i = idx % inner_dim;
const int ex_d = (idx / inner_dim) % ex_dim;
const int o = idx / inner_dim / ex_dim;
float* dx,
CPUContext* context) {
const int count_v2 = count / inner_dim;
for (int idx = 0; idx < count_v2; ++idx) {
const int ex_d = idx % ex_dim;
const int o = idx / ex_dim;
const int d = std::min(dim - 1, std::max(ex_d - pad_l, 0));
dx[(o * dim + d) * inner_dim + i] += dy[idx];
const float* dy_ptr = dy + (o * ex_dim + ex_d) * inner_dim;
if (d == 0 || d == dim - 1) {
for (int i = 0; i < inner_dim; ++i)
dx[(o * dim + d) * inner_dim + i] += dy_ptr[i];
} else {
float* dx_ptr = dx + (o * dim + d) * inner_dim;
context->Copy<float, CPUContext, CPUContext>(inner_dim, dx_ptr, dy_ptr);
}
}
}
......@@ -1404,7 +1499,7 @@ template <> void LSTMUnit<float, CPUContext>(const int count,
x_act[f_offset + ch] = f;
x_act[o_offset + ch] = o;
x_act[g_offset + ch] = g;
} // end ch
}
c_1 += channels;
c += channels;
h += channels;
......@@ -1448,7 +1543,7 @@ template <> void LSTMUnitGrad<float, CPUContext>(const int count,
*p_df = dc_1_sum_term * c_1[ch] * f * (1 - f);
*p_do = dh[ch] * tanh_c_t * o * (1 - o);
*p_dg = dc_1_sum_term * i * (1 - g * g);
} // end ch
}
c_1 += channels;
c += channels;
x_act += x_offset;
......@@ -1520,367 +1615,944 @@ template <> void RMSPropUpdate<float, CPUContext>(const int count,
math::Axpby<float, CPUContext>(count, lr, Tdata, 0.0, x);
}
/******************** vision.nn_resize ********************/
/******************** vision.bilinear_resize ********************/
template <> void BilinearResize<float, CPUContext>(const int count,
const int num, const int channels,
const int h_in, const int w_in,
const int h_out, const int w_out,
const float* x,
float* y) {
const float h_scale = (float)h_in / h_out;
const float w_scale = (float)w_in / w_out;
#ifdef WITH_OMP
#pragma omp parallel for num_threads(GET_OMP_THREADS(count))
#endif
for (int i = 0; i < count; ++i) {
const int w = i % w_out;
const int h = (i / w_out) % h_out;
const int c = (i / w_out / h_out) % channels;
const int n = i / w_out / h_out / channels;
template <typename T>
void _BilinearResize_NCHW(const int N, const int C,
const int H, const int W,
const int out_h, const int out_w,
const float scale_h, const float scale_w,
const T* x,
T* y) {
for (int n = 0; n < N; ++n) {
for (int c = 0; c < C; ++c) {
const int NC = n * C + c;
for (int h = 0; h < out_h; ++h) {
const float h_in = h * scale_h;
const int top_y_idx = floorf(h_in);
const int bottom_y_idx = (h_in < H - 1) ? ceilf(h_in) : H - 1;
const int NCHT = NC * H + top_y_idx;
const int NCHB = NC * H + bottom_y_idx;
const float y_lerp = h_in - top_y_idx;
for (int w = 0; w < out_w; ++w) {
const float w_in = w * scale_w;
const int left_x_idx = floorf(w_in);
const int right_x_idx = (w_in < W - 1) ? ceilf(w_in) : W - 1;
const float x_lerp = w_in - left_x_idx;
const float top_left(x[NCHT * W + left_x_idx]);
const float top_right(x[NCHT * W + right_x_idx]);
const float bottom_left(x[NCHB * W + left_x_idx]);
const float bottom_right(x[NCHB * W + right_x_idx]);
const float top = top_left + (top_right - top_left) * x_lerp;
const float bottom = bottom_left + (bottom_right - bottom_left) * x_lerp;
*(y++) = top + (bottom - top) * y_lerp;
}
}
}
}
}
const float in_h = h * h_scale;
const int top_y_idx = floorf(in_h);
const int bottom_y_idx = (in_h < h_in - 1) ? ceilf(in_h) : h_in - 1;
const float y_lerp = in_h - top_y_idx;
template <typename T>
void _BilinearResize_NHWC(const int N, const int C,
const int H, const int W,
const int out_h, const int out_w,
const float scale_h, const float scale_w,
const T* x,
T* y) {
for (int n = 0; n < N; ++n) {
for (int h = 0; h < out_h; ++h) {
const float h_in = h * scale_h;
const int top_y_idx = floorf(h_in);
const int bottom_y_idx = (h_in < H - 1) ? ceilf(h_in) : H - 1;
const int NHT = n * H + top_y_idx;
const int NHB = n * H + bottom_y_idx;
const float y_lerp = h_in - top_y_idx;
for (int w = 0; w < out_w; ++w) {
const float w_in = w * scale_w;
const int left_x_idx = floorf(w_in);
const int right_x_idx = (w_in < W - 1) ? ceilf(w_in) : W - 1;
const float x_lerp = w_in - left_x_idx;
for (int c = 0; c < C; ++c) {
const float top_left(x[(NHT * W + left_x_idx) * C + c]);
const float top_right(x[(NHT * W + right_x_idx) * C + c]);
const float bottom_left(x[(NHB * W + left_x_idx) * C + c]);
const float bottom_right(x[(NHB * W + right_x_idx) * C + c]);
const float top = top_left + (top_right - top_left) * x_lerp;
const float bottom = bottom_left + (bottom_right - bottom_left) * x_lerp;
*(y++) = top + (bottom - top) * y_lerp;
}
}
}
}
}
const float in_w = w * w_scale;
const int left_x_idx = floorf(in_w);
const int right_x_idx = (in_w < w_in - 1) ? ceilf(in_w) : w_in - 1;
const float x_lerp = in_w - left_x_idx;
template <> void BilinearResize<float, CPUContext>(const int count,
const int N, const int C,
const int H, const int W,
const int out_h, const int out_w,
const string& data_format,
const float* x,
float* y) {
const float scale_h = (float)H / out_h;
const float scale_w = (float)W / out_w;
if (data_format == "NCHW") {
_BilinearResize_NCHW<float>(N, C, H, W,
out_h, out_w,
scale_h, scale_w,
x,
y);
} else if (data_format == "NHWC"){
_BilinearResize_NHWC<float>(N, C, H, W,
out_h, out_w,
scale_h, scale_w,
x,
y);
} else LOG(FATAL) << "Unknown data format: " << data_format;
}
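The forward resize above interpolates with the usual left/right (top/bottom) lerp weights; a quick standalone check of those weights for a single row of width 2 resized to width 4 (scale_w = W / out_w = 0.5, values are illustrative):

#include <cmath>
#include <cstdio>

int main() {
    const int W = 2, out_w = 4;
    const float scale_w = (float)W / out_w;          // 0.5, as in BilinearResize above
    const float row[W] = {0.f, 10.f};
    for (int w = 0; w < out_w; ++w) {
        const float w_in = w * scale_w;
        const int left = (int)std::floor(w_in);
        const int right = (w_in < W - 1) ? (int)std::ceil(w_in) : W - 1;
        const float x_lerp = w_in - left;
        const float v = row[left] + (row[right] - row[left]) * x_lerp;
        std::printf("w=%d -> w_in=%.2f left=%d right=%d lerp=%.2f value=%.1f\n",
                    w, w_in, left, right, x_lerp, v);
    }
    return 0;
}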
const float top_left(x[((n * channels + c) * h_in + top_y_idx) * w_in + left_x_idx]);
const float top_right(x[((n * channels + c) * h_in + top_y_idx) * w_in + right_x_idx]);
const float bottom_left(x[((n * channels + c) * h_in + bottom_y_idx) * w_in + left_x_idx]);
const float bottom_right(x[((n * channels + c) * h_in + bottom_y_idx) * w_in + right_x_idx]);
template <typename T>
void _BilinearResizeGrad_NCHW(const int N, const int C,
const int H, const int W,
const int out_h, const int out_w,
const float scale_h, const float scale_w,
const T* dy,
T* dx) {
for (int n = 0; n < N; ++n) {
for (int c = 0; c < C; ++c) {
const int NC = n * C + c;
for (int h = 0; h < out_h; ++h) {
const float h_in = h * scale_h;
const int top_y_idx = floorf(h_in);
const int bottom_y_idx = (h_in < H - 1) ? ceilf(h_in) : H - 1;
const int NCHT = NC * H + top_y_idx;
const int NCHB = NC * H + bottom_y_idx;
const float y_lerp = h_in - top_y_idx;
for (int w = 0; w < out_w; ++w) {
const float w_in = w * scale_w;
const int left_x_idx = floorf(w_in);
const int right_x_idx = (w_in < W - 1) ? ceilf(w_in) : W - 1;
const float x_lerp = w_in - left_x_idx;
const float dtop = (1 - y_lerp) * (*(dy));
const float dbottom = y_lerp * (*(dy++));
dx[NCHT * W + left_x_idx] += static_cast<T>((1 - x_lerp) * dtop);
dx[NCHT * W + right_x_idx] += static_cast<T>(x_lerp * dtop);
dx[NCHB * W + left_x_idx] += static_cast<T>((1 - x_lerp) * dbottom);
dx[NCHB * W + right_x_idx] += static_cast<T>(x_lerp * dbottom);
}
}
}
}
}
const float top = top_left + (top_right - top_left) * x_lerp;
const float bottom = bottom_left + (bottom_right - bottom_left) * x_lerp;
y[i] = top + (bottom - top) * y_lerp;
template <typename T>
void _BilinearResizeGrad_NHWC(const int N, const int C,
const int H, const int W,
const int out_h, const int out_w,
const float scale_h, const float scale_w,
const T* dy,
T* dx) {
for (int n = 0; n < N; ++n) {
for (int h = 0; h < out_h; ++h) {
const float h_in = h * scale_h;
const int top_y_idx = floorf(h_in);
const int bottom_y_idx = (h_in < H - 1) ? ceilf(h_in) : H - 1;
const int NHT = n * H + top_y_idx;
const int NHB = n * H + bottom_y_idx;
const float y_lerp = h_in - top_y_idx;
for (int w = 0; w < out_w; ++w) {
const float w_in = w * scale_w;
const int left_x_idx = floorf(w_in);
const int right_x_idx = (w_in < W - 1) ? ceilf(w_in) : W - 1;
const float x_lerp = w_in - left_x_idx;
const float dtop = (1 - y_lerp) * (*(dy));
const float dbottom = y_lerp * (*(dy++));
for (int c = 0; c < C; ++c) {
dx[(NHT * W + left_x_idx) * C + c] += static_cast<T>((1 - x_lerp) * dtop);
dx[(NHT * W + right_x_idx) * C + c] += static_cast<T>(x_lerp * dtop);
dx[(NHB * W + left_x_idx) * C + c] += static_cast<T>((1 - x_lerp) * dbottom);
dx[(NHB * W + right_x_idx) * C + c] += static_cast<T>(x_lerp * dbottom);
}
}
}
}
}
template <> void BilinearResizeGrad<float, CPUContext>(const int count,
const int num, const int channels,
const int h_in, const int w_in,
const int h_out, const int w_out,
const float* dy,
const int N, const int C,
const int H, const int W,
const int out_h, const int out_w,
const string& data_format,
const float* dy,
float* dx) {
const float h_scale = (float)h_out / h_in;
const float w_scale = (float)w_out / w_in;
for (int i = 0; i < count; i++) {
const int w = i % w_in;
const int h = (i / w_in) % h_in;
const int c = (i / w_in / h_in) % channels;
const int n = i / w_in / h_in / channels;
const float original_h = h * h_scale;
const int top_y_idx = floorf(original_h);
const int bottom_y_idx = (original_h < h_out - 1) ? ceilf(original_h) : h_out - 1;
const float y_lerp = original_h - top_y_idx;
const float original_w = w * w_scale;
const int left_x_idx = floorf(original_w);
const int right_x_idx = (original_w < w_out - 1) ? ceilf(original_w) : w_out - 1;
const float x_lerp = original_w - left_x_idx;
const float scale_h = (float)H / out_h;
const float scale_w = (float)W / out_w;
math::Set<float, CPUContext>(N * C * H * W, 0, dx);
if (data_format == "NCHW") {
_BilinearResizeGrad_NCHW<float>(N, C, H, W,
out_h, out_w,
scale_h, scale_w,
dy,
dx);
} else if (data_format == "NHWC"){
_BilinearResizeGrad_NHWC<float>(N, C, H, W,
out_h, out_w,
scale_h, scale_w,
dy,
dx);
} else LOG(FATAL) << "Unknown data format: " << data_format;
}
const float dtop = (1 - y_lerp) * dy[i];
*(dx + ((n * channels + c) * h_out + top_y_idx) * w_out + left_x_idx)
+= static_cast<float>((1 - x_lerp) * dtop);
*(dx + ((n * channels + c) * h_out + top_y_idx) * w_out + right_x_idx)
+= static_cast<float>(x_lerp * dtop);
/******************** vision.conv ********************/
const float dbottom = y_lerp * dy[i];
*(dx + ((n * channels + c) * h_out + bottom_y_idx) * w_out + left_x_idx)
+= static_cast<float>((1 - x_lerp) * dbottom);
*(dx + ((n * channels + c) * h_out + bottom_y_idx) * w_out + right_x_idx)
+= static_cast<float>(x_lerp * dbottom);
template<typename T>
void _Im2Col2d_NCHW(const int C, const int H, const int W,
const int col_h, const int col_w,
const int kernel_h, const int kernel_w,
const int stride_h, const int stride_w,
const int pad_h, const int pad_w,
const int dilation_h, const int dilation_w,
const T* im,
T* col) {
const int im_offset = H * W;
for (int c = 0; c < C; ++c, im += im_offset) {
for (int kh = 0; kh < kernel_h; ++kh) {
for (int kw = 0; kw < kernel_w; ++kw) {
int h = -pad_h + kh * dilation_h;
for (int output_h = 0; output_h < col_h; ++output_h) {
if (!judge(h, H)) {
for (int output_w = 0; output_w < col_w; ++output_w) *(col++) = 0;
} else {
int w = -pad_w + kw * dilation_w;
for (int output_w = 0; output_w < col_w; ++output_w) {
if (!judge(w, W)) *(col++) = 0;
else *(col++) = im[h * W + w];
w += stride_w;
}
}
h += stride_h;
}
}
}
}
}
/******************** vision.conv ********************/
template <> void Im2Col<float, CPUContext>(const int channels,
const int height, const int width,
const int kernel_h, const int kernel_w,
const int stride_h, const int stride_w,
const int pad_h, const int pad_w,
const int dilation_h, const int dilation_w,
const float* im,
float* col) {
const int col_h = (height + 2 * pad_h - (dilation_h*(kernel_h - 1) + 1)) / stride_h + 1;
const int col_w = (width + 2 * pad_w - (dilation_w*(kernel_w - 1) + 1)) / stride_w + 1;
const int input_spatial = height * width;
    // for each element in kernel, create a row-col-map for an input feature map
for (int channel = 0; channel < channels; ++channel, im += input_spatial) {
for (int kh_off = 0; kh_off < kernel_h; ++kh_off) {
for (int kw_off = 0; kw_off < kernel_w; ++kw_off) {
int input_row = -pad_h + kh_off * dilation_h;
// scan all output pixels and find the corresponding input pixels
for (int output_row = 0; output_row < col_h; ++output_row ) {
                // set '0' for all output pixels outside the input map
if (!judge(input_row, height)) {
for (int output_col = 0; output_col < col_w; ++output_col) *(col++) = 0;
} else { // find the corresponding input pixels
int input_col = -pad_w + kw_off * dilation_w;
for (int output_col = 0; output_col < col_w; ++output_col) {
if (!judge(input_col, width)) *(col++) = 0;
else *(col++) = im[input_row * width + input_col];
input_col += stride_w;
template<typename T>
void _Im2Col2d_NHWC(const int C, const int H, const int W,
const int col_h, const int col_w,
const int kernel_h, const int kernel_w,
const int stride_h, const int stride_w,
const int pad_h, const int pad_w,
const int dilation_h, const int dilation_w,
const T* im,
T* col) {
for (int output_h = 0; output_h < col_h; ++output_h) {
const int base_h = -pad_h + stride_h * output_h;
for (int output_w = 0; output_w < col_w; ++output_w) {
const int base_w = -pad_w + stride_w * output_w;
for (int kh = 0; kh < kernel_h; ++kh) {
int h = base_h + kh * dilation_h;
if (!judge(h, H)) {
for (int kw = 0; kw < kernel_w; ++kw)
for (int c = 0; c < C; ++c) *(col++) = 0;
} else {
for (int kw = 0; kw < kernel_w; ++kw) {
int w = base_w + kw * dilation_w;
for (int c = 0; c < C; ++c) {
if (!judge(w, W)) *(col++) = 0;
else *(col++) = im[(h * W + w) * C + c];
}
}
input_row += stride_h;
} // end output_row
} // end kw_off
} // end kh_off
} // end channel
}
template<> void Col2Im<float, CPUContext>(const int channels,
const int height, const int width,
const int kernel_h, const int kernel_w,
const int stride_h, const int stride_w,
const int pad_h, const int pad_w,
const int dilation_h, const int dilation_w,
const float* col,
float* im) {
    // must memset before using '+='
math::Set<float, CPUContext>(channels * height * width, 0, im);
const int col_h = (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
const int col_w = (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
const int input_spatial = height * width;
    // for each element in kernel, create a row-col-map for an input feature map
for (int channel = 0; channel < channels; ++channel, im += input_spatial) {
for (int kh_off = 0; kh_off < kernel_h; ++kh_off) {
for (int kw_off = 0; kw_off < kernel_w; ++kw_off) {
int input_row = -pad_h + kh_off * dilation_h;
// scan all output pixels and find the corresponding input pixels
for (int output_row = 0; output_row < col_h; ++output_row) {
                // skip a run of col_w pixels
if (!judge(input_row, height)) {
}
}
}
}
}
template <> void Im2Col2d<float, CPUContext>(const int C, const int H, const int W,
const int col_h, const int col_w,
const int kernel_h, const int kernel_w,
const int stride_h, const int stride_w,
const int pad_h, const int pad_w,
const int dilation_h, const int dilation_w,
const string& data_format,
const float* im,
float* col) {
if (data_format == "NCHW") {
const int count = (C * col_h * col_w);
_Im2Col2d_NCHW<float>(C, H, W, col_h, col_w,
kernel_h, kernel_w,
stride_h, stride_w,
pad_h, pad_w,
dilation_h, dilation_w,
im,
col);
} else if (data_format == "NHWC") {
const int count = (col_h * col_w * C);
_Im2Col2d_NHWC<float>(C, H, W, col_h, col_w,
kernel_h, kernel_w,
stride_h, stride_w,
pad_h, pad_w,
dilation_h, dilation_w,
im,
col);
} else LOG(FATAL) << "Unknown data format: " << data_format;
}
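The col_h/col_w arguments received by the dispatcher above are the standard convolution output extents; a quick standalone computation using the same formula as the old Im2Col code (illustrative 3x3, stride-2, pad-1 case):

#include <cstdio>

int main() {
    const int H = 224, W = 224;
    const int kernel_h = 3, kernel_w = 3, stride_h = 2, stride_w = 2;
    const int pad_h = 1, pad_w = 1, dilation_h = 1, dilation_w = 1;
    // Same formula as in the old Im2Col<float, CPUContext> above.
    const int col_h = (H + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
    const int col_w = (W + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
    std::printf("col_h=%d col_w=%d\n", col_h, col_w);   // 112 x 112
    // In NCHW the column buffer holds C * kernel_h * kernel_w runs of col_h * col_w values,
    // matching the loop nest of _Im2Col2d_NCHW above.
    return 0;
}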
template<typename T>
void _Col2Im2d_NCHW(const int C, const int H, const int W,
const int col_h, const int col_w,
const int kernel_h, const int kernel_w,
const int stride_h, const int stride_w,
const int pad_h, const int pad_w,
const int dilation_h, const int dilation_w,
const T* col,
T* im) {
math::Set<float, CPUContext>(C * H * W, 0, im);
const int im_offset = H * W;
for (int c = 0; c < C; ++c, im += im_offset) {
for (int kh = 0; kh < kernel_h; ++kh) {
for (int kw = 0; kw < kernel_w; ++kw) {
int h = -pad_h + kh * dilation_h;
for (int output_h = 0; output_h < col_h; ++output_h) {
if (!judge(h, H)) {
col += col_w;
} else { // find the corresponding input pixels
int input_col = -pad_w + kw_off * dilation_w;
for (int output_col = 0; output_col < col_w; output_col++) {
if (judge(input_col, width)) im[input_row * width + input_col] += *col;
} else {
int w = -pad_w + kw * dilation_w;
for (int output_w = 0; output_w < col_w; ++output_w) {
if (judge(w, W)) im[h * W + w] += *col;
++col;
w += stride_w;
}
}
h += stride_h;
}
}
}
}
}
template<typename T>
void _Col2Im2d_NHWC(const int C, const int H, const int W,
const int col_h, const int col_w,
const int kernel_h, const int kernel_w,
const int stride_h, const int stride_w,
const int pad_h, const int pad_w,
const int dilation_h, const int dilation_w,
const T* col,
T* im) {
math::Set<float, CPUContext>(C * H * W, 0, im);
for (int output_h = 0; output_h < col_h; ++output_h) {
const int base_h = -pad_h + stride_h * output_h;
for (int output_w = 0; output_w < col_w; ++output_w) {
const int base_w = -pad_w + stride_w * output_w;
for (int kh = 0; kh < kernel_h; ++kh) {
int h = base_h + kh * dilation_h;
if (!judge(h, H)) {
col += (kernel_w * C);
} else {
for (int kw = 0; kw < kernel_w; ++kw) {
int w = base_w + kw * dilation_w;
for (int c = 0; c < C; ++c) {
if (judge(w, W)) im[(h * W + w) * C + c] += *(col);
++col;
input_col += stride_w;
}
}
input_row += stride_h;
} // end output_row
} // end kw_off
} // end kh_off
} // end channel
}
}
}
}
}
template<> void Col2Im2d<float, CPUContext>(const int C, const int H, const int W,
const int col_h, const int col_w,
const int kernel_h, const int kernel_w,
const int stride_h, const int stride_w,
const int pad_h, const int pad_w,
const int dilation_h, const int dilation_w,
const string& data_format,
const float* col,
float* im) {
if (data_format == "NCHW") {
const int count = (C * H * W);
_Col2Im2d_NCHW<float>(C, H, W, col_h, col_w,
kernel_h, kernel_w,
stride_h, stride_w,
pad_h, pad_w,
dilation_h, dilation_w,
col,
im);
} else if (data_format == "NHWC") {
const int count = (H * W * C);
_Col2Im2d_NHWC<float>(C, H, W, col_h, col_w,
kernel_h, kernel_w,
stride_h, stride_w,
pad_h, pad_w,
dilation_h, dilation_w,
col,
im);
} else LOG(FATAL) << "Unknown data format: " << data_format;
}
/******************** vision.nn_resize ********************/
template <typename T>
void _NNResize_NCHW(const int N, const int C,
const int H, const int W,
const int out_h, const int out_w,
const float scale_h, const float scale_w,
const T* x,
T* y) {
for (int n = 0; n < N; ++n) {
for (int c = 0; c < C; ++c) {
const int NC = n * C + c;
for (int h = 0; h < out_h; ++h) {
const int h_in = std::min(int(floorf(h * scale_h)), H - 1);
const int NCH = NC * H + h_in;
for (int w = 0; w < out_w; ++w) {
const int w_in = std::min(int(floorf(w * scale_w)), W - 1);
*(y++) = x[NCH * W + w_in];
}
}
}
}
}
template <typename T>
void _NNResize_NHWC(const int N, const int C,
const int H, const int W,
const int out_h, const int out_w,
const float scale_h, const float scale_w,
const T* x,
T* y) {
for (int n = 0; n < N; ++n) {
for (int h = 0; h < out_h; ++h) {
const int h_in = std::min(int(floorf(h * scale_h)), H - 1);
const int NH = n * H + h_in;
for (int w = 0; w < out_w; ++w) {
const int w_in = std::min(int(floorf(w * scale_w)), W - 1);
const int NHW = NH * W + w_in;
for (int c = 0; c < C; ++c) *(y++) = x[NHW * C + c];
}
}
}
}
template <> void NNResize<float, CPUContext>(const int count,
const int num, const int channels,
const int h_in, const int w_in,
const int h_out, const int w_out,
const float* x,
const int N, const int C,
const int H, const int W,
const int out_h, const int out_w,
const string& data_format,
const float* x,
float* y) {
const float h_scale = (float)h_in / h_out;
const float w_scale = (float)w_in / w_out;
#ifdef WITH_OMP
#pragma omp parallel for num_threads(GET_OMP_THREADS(count))
#endif
for (int i = 0; i < count; ++i) {
const int w = i % w_out;
const int h = (i / w_out) % h_out;
const int in_h = std::min(int(floorf(h * h_scale)), h_in - 1);
const int in_w = std::min(int(floorf(w * w_scale)), w_in - 1);
const int c = (i / w_out / h_out) % channels;
const int n = i / w_out / h_out / channels;
const int x_idx = ((n * channels + c) * h_in + in_h) * w_in + in_w;
y[i] = x[x_idx];
}
}
template <> void NNResizeGrad<float, CPUContext>(const int count,
const int num, const int channels,
const int h_in, const int w_in,
const int h_out, const int w_out,
const float* dy,
float* dx) {
const float h_scale = (float)h_out / h_in;
const float w_scale = (float)w_out / w_in;
for (int n = 0; n < num; n++) {
for (int c = 0; c < channels; ++c) {
for (int h = 0; h < h_in; ++h) {
const int out_h = std::min(int(floorf(h * h_scale)), (h_out - 1));
for (int w = 0; w < w_in; ++w) {
const int out_w = std::min(int(floorf(w * w_scale)), (w_out - 1));
const int y_idx = ((n * channels + c) * h_in + h) * w_in + w;
const int x_idx = ((n * channels + c) * h_out + out_h) * w_out + out_w;
dx[x_idx] += dy[y_idx];
const float scale_h = (float)H / out_h;
const float scale_w = (float)W / out_w;
if (data_format == "NCHW") {
_NNResize_NCHW<float>(N, C, H, W, out_h, out_w,
scale_h, scale_w,
x,
y);
} else if (data_format == "NHWC"){
_NNResize_NHWC<float>(N, C, H, W, out_h, out_w,
scale_h, scale_w,
x,
y);
} else LOG(FATAL) << "Unknown data format: " << data_format;
}
template <typename T>
void _NNResizeGrad_NCHW(const int N, const int C,
const int H, const int W,
const int out_h, const int out_w,
const float scale_h, const float scale_w,
const T* dy,
T* dx) {
for (int n = 0; n < N; ++n) {
for (int c = 0; c < C; ++c) {
const int NC = n * C + c;
for (int h = 0; h < out_h; ++h) {
const int h_in = std::min(int(floorf(h * scale_h)), H - 1);
const int NCH = NC * H + h_in;
for (int w = 0; w < out_w; ++w) {
const int w_in = std::min(int(floorf(w * scale_w)), W - 1);
dx[NCH * W + w_in] += *(dy++);
}
}
}
}
}
template <typename T>
void _NNResizeGrad_NHWC(const int N, const int C,
const int H, const int W,
const int out_h, const int out_w,
const float scale_h, const float scale_w,
const T* dy,
T* dx) {
for (int n = 0; n < N; ++n) {
for (int h = 0; h < out_h; ++h) {
const int h_in = std::min(int(floorf(h * scale_h)), H - 1);
const int NH = n * H + h_in;
for (int w = 0; w < out_w; ++w) {
const int w_in = std::min(int(floorf(w * scale_w)), W - 1);
const int NHW = NH * W + w_in;
for (int c = 0; c < C; ++c) dx[NHW * C + c] += *(dy++);
}
}
}
}
template <> void NNResizeGrad<float, CPUContext>(const int count,
const int N, const int C,
const int H, const int W,
const int out_h, const int out_w,
const string& data_format,
const float* dy,
float* dx) {
const float scale_h = (float)H / out_h;
const float scale_w = (float)W / out_w;
math::Set<float, CPUContext>(N * C * H * W, 0, dx);
if (data_format == "NCHW") {
_NNResizeGrad_NCHW<float>(N, C, H, W, out_h, out_w,
scale_h, scale_w,
dy,
dx);
} else if (data_format == "NHWC"){
_NNResizeGrad_NHWC<float>(N, C, H, W, out_h, out_w,
scale_h, scale_w,
dy,
dx);
} else LOG(FATAL) << "Unknown data format: " << data_format;
}
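Nearest-neighbor resize above maps each output coordinate to min(floor(i * scale), in_dim - 1). A quick standalone check of that mapping for width 3 resized to width 5 (illustrative sizes):

#include <algorithm>
#include <cmath>
#include <cstdio>

int main() {
    const int W = 3, out_w = 5;
    const float scale_w = (float)W / out_w;   // 0.6, same rule as NNResize above
    for (int w = 0; w < out_w; ++w) {
        const int w_in = std::min((int)std::floor(w * scale_w), W - 1);
        std::printf("out w=%d -> in w=%d\n", w, w_in);   // mapping: 0,0,1,1,2
    }
    return 0;
}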
/******************** vision.pooling ********************/
template<> void MAXPooling<float, CPUContext>(const int count,
const int num, const int channels,
const int height, const int width,
const int pool_height, const int pool_width,
const int kernel_h, const int kernel_w,
const int stride_h, const int stride_w,
const int pad_h, const int pad_w,
const float* x,
int* mask,
float* y) {
int x_offset = height * width;
int y_offset = pool_height * pool_width;
for (int n = 0; n < num; ++n) {
for (int c = 0; c < channels; ++c) {
for (int ph = 0; ph < pool_height; ++ph) {
for (int pw = 0; pw < pool_width; ++pw) {
template <typename T>
void _MAXPooling2d_NCHW(const int N, const int C,
const int H, const int W,
const int pool_h, const int pool_w,
const int kernel_h, const int kernel_w,
const int stride_h, const int stride_w,
const int pad_h, const int pad_w,
const float* x,
int* mask,
float* y) {
int x_offset = H * W;
int y_offset = pool_h * pool_w;
for (int n = 0; n < N; ++n) {
for (int c = 0; c < C; ++c) {
for (int ph = 0; ph < pool_h; ++ph) {
for (int pw = 0; pw < pool_w; ++pw) {
int start_h = ph * stride_h - pad_h;
int start_w = pw * stride_w - pad_w;
int end_h = std::min(start_h + kernel_h, height);
int end_w = std::min(start_w + kernel_w, width);
int end_h = std::min(start_h + kernel_h, H);
int end_w = std::min(start_w + kernel_w, W);
start_h = std::max(start_h, 0);
start_w = std::max(start_w, 0);
const int pool_idx = ph * pool_width + pw;
const int pool_idx = ph * pool_w + pw;
float max_val = -FLT_MAX;
int max_idx = -1;
for (int h = start_h; h < end_h; ++h) {
for (int w = start_w; w < end_w; ++w) {
const int idx = h * width + w;
if (x[idx]>max_val) {
const int idx = h * W + w;
if (x[idx] > max_val) {
max_val = x[idx];
max_idx = idx;
}
} // end w
} // end h
}
}
y[pool_idx] = max_val;
mask[pool_idx] = max_idx;
} // end pw
} // end ph
// offset a channel
}
}
x += x_offset;
y += y_offset;
mask += y_offset;
} // end c
} // end n
}
}
}
template<> void AVEPooling<float, CPUContext>(const int count,
const int num, const int channels,
const int height, const int width,
const int pool_height, const int pool_width,
const int kernel_h, const int kernel_w,
const int stride_h, const int stride_w,
const int pad_h, const int pad_w,
const float* x,
float* y) {
int x_offset = height * width;
int y_offset = pool_height * pool_width;
math::Set<float, CPUContext>(count, 0, y);
for (int n = 0; n < num; ++n) {
for (int c = 0; c < channels; ++c) {
for (int ph = 0; ph < pool_height; ++ph) {
for (int pw = 0; pw < pool_width; ++pw) {
int start_h = ph * stride_h - pad_h;
int start_w = pw * stride_w - pad_w;
int end_h = std::min(start_h + kernel_h, height + pad_h);
int end_w = std::min(start_w + kernel_w, width + pad_w);
int pool_size = (end_h - start_h) * (end_w - start_w);
end_h = std::min(end_h, height);
end_w = std::min(end_w, width);
start_h = std::max(start_h, 0);
start_w = std::max(start_w, 0);
const int pool_idx = ph * pool_width + pw;
template <typename T>
void _MAXPooling2d_NHWC(const int N, const int C,
const int H, const int W,
const int pool_h, const int pool_w,
const int kernel_h, const int kernel_w,
const int stride_h, const int stride_w,
const int pad_h, const int pad_w,
const float* x,
int* mask,
float* y) {
int x_offset = H * W * C;
int y_offset = pool_h * pool_w * C;
for (int n = 0; n < N; ++n) {
for (int ph = 0; ph < pool_h; ph++) {
for (int pw = 0; pw < pool_w; ++pw) {
int start_h = ph * stride_h - pad_h;
int start_w = pw * stride_w - pad_w;
int end_h = std::min(start_h + kernel_h, H);
int end_w = std::min(start_w + kernel_w, W);
start_h = std::max(start_h, 0);
start_w = std::max(start_w, 0);
const int base_pool_idx = ph * pool_w + pw;
for (int c = 0; c < C; ++c) {
const int pool_idx = base_pool_idx * C + c;
float max_val = -FLT_MAX;
int max_idx = -1;
for (int h = start_h; h < end_h; ++h) {
for (int w = start_w; w < end_w; ++w) {
const int idx = h * width + w;
y[pool_idx] += x[idx];
const int idx = (h * W + w) * C + c;
if (x[idx] > max_val) {
max_val = x[idx];
max_idx = idx;
}
}
}
y[pool_idx] /= pool_size;
} //end pw
} //end ph
y[pool_idx] = max_val;
mask[pool_idx] = max_idx;
}
}
}
x += x_offset;
y += y_offset;
mask += y_offset;
}
}
template<> void MAXPooling2d<float, CPUContext>(const int count,
const int N, const int C,
const int H, const int W,
const int pool_h, const int pool_w,
const int kernel_h, const int kernel_w,
const int stride_h, const int stride_w,
const int pad_h, const int pad_w,
const string& data_format,
const float* x,
int* mask,
float* y) {
if (data_format == "NCHW") {
_MAXPooling2d_NCHW<float>(N, C, H, W, pool_h, pool_w,
kernel_h, kernel_w,
stride_h, stride_w,
pad_h, pad_w,
x,
mask,
y);
} else if (data_format == "NHWC") {
_MAXPooling2d_NHWC<float>(N, C, H, W, pool_h, pool_w,
kernel_h, kernel_w,
stride_h, stride_w,
pad_h, pad_w,
x,
mask,
y);
} else LOG(FATAL) << "Unknown data format: " << data_format;
}
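The mask written by the max-pooling kernels above holds the flat argmax index relative to the current x slice (a channel plane for NCHW, one image for NHWC); the backward kernels later in this file simply scatter dy through it. A tiny standalone sketch of that forward/backward pair on one 4x4 plane with a 2x2, stride-2 window (illustrative only):

#include <cfloat>
#include <cstdio>
#include <vector>

int main() {
    // One NCHW plane: H = W = 4, kernel = stride = 2, no padding -> 2x2 output.
    const int H = 4, W = 4, K = 2, S = 2, PH = 2, PW = 2;
    std::vector<float> x = { 1, 2, 3, 4,
                             5, 6, 7, 8,
                             9,10,11,12,
                            13,14,15,16 };
    std::vector<float> y(PH * PW);
    std::vector<int> mask(PH * PW);
    for (int ph = 0; ph < PH; ++ph)
      for (int pw = 0; pw < PW; ++pw) {
        float max_val = -FLT_MAX; int max_idx = -1;
        for (int h = ph * S; h < ph * S + K; ++h)
          for (int w = pw * S; w < pw * S + K; ++w)
            if (x[h * W + w] > max_val) { max_val = x[h * W + w]; max_idx = h * W + w; }
        y[ph * PW + pw] = max_val;
        mask[ph * PW + pw] = max_idx;          // flat index into this plane
      }
    // Backward: scatter dy through the stored argmax, as MAXPooling2dGrad does.
    std::vector<float> dy = {1, 1, 1, 1}, dx(H * W, 0.f);
    for (int i = 0; i < PH * PW; ++i) dx[mask[i]] += dy[i];
    std::printf("y = %.0f %.0f %.0f %.0f, dx[5]=%.0f dx[15]=%.0f\n",
                y[0], y[1], y[2], y[3], dx[5], dx[15]);
    return 0;
}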
template<typename T>
void _AVGPooling2d_NCHW(const int N, const int C,
const int H, const int W,
const int pool_h, const int pool_w,
const int kernel_h, const int kernel_w,
const int stride_h, const int stride_w,
const int pad_h, const int pad_w,
const float* x,
float* y) {
int x_offset = H * W;
int y_offset = pool_h * pool_w;
for (int n = 0; n < N; ++n) {
for (int c = 0; c < C; ++c) {
for (int ph = 0; ph < pool_h; ++ph) {
for (int pw = 0; pw < pool_w; ++pw) {
int start_h = ph * stride_h - pad_h;
int start_w = pw * stride_w - pad_w;
int end_h = std::min(start_h + kernel_h, H + pad_h);
int end_w = std::min(start_w + kernel_w, W + pad_w);
int pool_area = (end_h - start_h) * (end_w - start_w);
end_h = std::min(end_h, H);
end_w = std::min(end_w, W);
start_h = std::max(start_h, 0);
start_w = std::max(start_w, 0);
const int pool_idx = ph * pool_w + pw;
T sum_val = 0;
for (int h = start_h; h < end_h; ++h)
for (int w = start_w; w < end_w; ++w)
sum_val += x[h * W + w];
y[pool_idx] = sum_val / pool_area;
}
}
x += x_offset;
y += y_offset;
} //end c
} //end n
}
}
}
template<> void MAXPoolingGrad<float, CPUContext>(const int count,
const int num, const int channels,
const int height, const int width,
const int pool_height, const int pool_width,
const int kernel_h, const int kernel_w,
const int stride_h, const int stride_w,
const int pad_h, const int pad_w,
const float* dy,
const int* mask,
float* dx) {
int x_offset = height * width;
int y_offset = pool_height * pool_width;
math::Set<float, CPUContext>(count, 0, dx);
for (int n = 0; n < num; ++n) {
for (int c = 0; c < channels; ++c) {
for (int ph = 0; ph < pool_height; ++ph) {
for (int pw = 0; pw < pool_width; ++pw) {
const int pool_idx = ph * pool_width + pw;
const int idx = mask[pool_idx];
dx[idx] += dy[pool_idx];
} // end pw
} // end ph
template<typename T>
void _AVGPooling2d_NHWC(const int N, const int C,
const int H, const int W,
const int pool_h, const int pool_w,
const int kernel_h, const int kernel_w,
const int stride_h, const int stride_w,
const int pad_h, const int pad_w,
const float* x,
float* y) {
int x_offset = H * W * C;
int y_offset = pool_h * pool_w * C;
for (int n = 0; n < N; ++n) {
for (int ph = 0; ph < pool_h; ph++) {
for (int pw = 0; pw < pool_w; ++pw) {
int start_h = ph * stride_h - pad_h;
int start_w = pw * stride_w - pad_w;
int end_h = std::min(start_h + kernel_h, H + pad_h);
int end_w = std::min(start_w + kernel_w, W + pad_w);
int pool_area = (end_h - start_h) * (end_w - start_w);
end_h = std::min(end_h, H);
end_w = std::min(end_w, W);
start_h = std::max(start_h, 0);
start_w = std::max(start_w, 0);
const int base_pool_idx = ph * pool_w + pw;
for (int c = 0; c < C; ++c) {
const int pool_idx = base_pool_idx * C + c;
T sum_val = 0;
for (int h = start_h; h < end_h; ++h)
for (int w = start_w; w < end_w; ++w)
sum_val += x[(h * W + w) * C + c];
y[pool_idx] = sum_val / pool_area;
}
}
}
x += x_offset;
y += y_offset;
}
}
template<> void AVGPooling2d<float, CPUContext>(const int count,
const int N, const int C,
const int H, const int W,
const int pool_h, const int pool_w,
const int kernel_h, const int kernel_w,
const int stride_h, const int stride_w,
const int pad_h, const int pad_w,
const string& data_format,
const float* x,
float* y) {
if (data_format == "NCHW") {
_AVGPooling2d_NCHW<float>(N, C, H, W, pool_h, pool_w,
kernel_h, kernel_w,
stride_h, stride_w,
pad_h, pad_w,
x,
y);
} else if (data_format == "NHWC") {
_AVGPooling2d_NHWC<float>(N, C, H, W, pool_h, pool_w,
kernel_h, kernel_w,
stride_h, stride_w,
pad_h, pad_w,
x,
y);
} else LOG(FATAL) << "Unknown data format: " << data_format;
}
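Note that pool_area in the average-pooling kernels above is computed before the window is clamped to the image, so padded cells count toward the divisor. A quick one-axis check of that arithmetic (H = 4, kernel 3, stride 2, pad 1; pool_h = 2 is assumed):

#include <algorithm>
#include <cstdio>

int main() {
    const int H = 4, kernel_h = 3, stride_h = 2, pad_h = 1;
    const int pool_h = 2;   // assumed output extent along this axis
    for (int ph = 0; ph < pool_h; ++ph) {
        int start_h = ph * stride_h - pad_h;
        int end_h = std::min(start_h + kernel_h, H + pad_h);
        const int span = end_h - start_h;   // contributes to pool_area (includes padding)
        end_h = std::min(end_h, H);
        start_h = std::max(start_h, 0);
        std::printf("ph=%d window=[%d,%d) valid_rows=%d divisor_rows=%d\n",
                    ph, start_h, end_h, end_h - start_h, span);
    }
    return 0;
}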
template <typename T>
void _MAXPooling2dGrad_NCHW(const int N, const int C,
const int H, const int W,
const int pool_h, const int pool_w,
const int kernel_h, const int kernel_w,
const int stride_h, const int stride_w,
const int pad_h, const int pad_w,
const float* dy,
const int* mask,
float* dx) {
int x_offset = H * W;
int y_offset = pool_h * pool_w;
math::Set<float, CPUContext>(N * C * H * W, 0, dx);
for (int n = 0; n < N; ++n) {
for (int c = 0; c < C; ++c) {
for (int ph = 0; ph < pool_h; ++ph) {
for (int pw = 0; pw < pool_w; ++pw) {
const int pool_idx = ph * pool_w + pw;
const int idx = mask[pool_idx];
dx[idx] += dy[pool_idx];
}
}
dx += x_offset;
dy += y_offset;
mask += y_offset;
} // end c
} // end n
}
}
}
template<> void AVEPoolingGrad<float, CPUContext>(const int count,
const int num, const int channels,
const int height, const int width,
const int pool_height, const int pool_width,
const int kernel_h, const int kernel_w,
const int stride_h, const int stride_w,
const int pad_h, const int pad_w,
const float* dy,
float* dx) {
int x_offset = height * width;
int y_offset = pool_height * pool_width;
math::Set<float, CPUContext>(count, 0, dx);
for (int n = 0; n < num; ++n) {
for (int c = 0; c < channels; ++c) {
for (int ph = 0; ph < pool_height; ++ph) {
for (int pw = 0; pw < pool_width; ++pw) {
template <typename T>
void _MAXPooling2dGrad_NHWC(const int N, const int C,
const int H, const int W,
const int pool_h, const int pool_w,
const int kernel_h, const int kernel_w,
const int stride_h, const int stride_w,
const int pad_h, const int pad_w,
const float* dy,
const int* mask,
float* dx) {
int x_offset = H * W * C;
int y_offset = pool_h * pool_w * C;
math::Set<float, CPUContext>(N * H * W * C, 0, dx);
for (int n = 0; n < N; ++n) {
for (int ph = 0; ph < pool_h; ph++) {
for (int pw = 0; pw < pool_w; ++pw) {
const int base_pool_idx = ph * pool_w + pw;
for (int c = 0; c < C; ++c) {
const int pool_idx = base_pool_idx * C + c;
const int idx = mask[pool_idx];
dx[idx] += dy[pool_idx];
}
}
}
dx += x_offset;
dy += y_offset;
}
}
template<> void MAXPooling2dGrad<float, CPUContext>(const int count,
const int N, const int C,
const int H, const int W,
const int pool_h, const int pool_w,
const int kernel_h, const int kernel_w,
const int stride_h, const int stride_w,
const int pad_h, const int pad_w,
const string& data_format,
const float* dy,
const int* mask,
float* dx) {
if (data_format == "NCHW") {
_MAXPooling2dGrad_NCHW<float>(N, C, H, W, pool_h, pool_w,
kernel_h, kernel_w,
stride_h, stride_w,
pad_h, pad_w,
dy,
mask,
dx);
} else if (data_format == "NHWC") {
_MAXPooling2dGrad_NHWC<float>(N, C, H, W, pool_h, pool_w,
kernel_h, kernel_w,
stride_h, stride_w,
pad_h, pad_w,
dy,
mask,
dx);
} else LOG(FATAL) << "Unknown data format: " << data_format;
}
template <typename T>
void _AVGPooling2dGrad_NCHW(const int N, const int C,
const int H, const int W,
const int pool_h, const int pool_w,
const int kernel_h, const int kernel_w,
const int stride_h, const int stride_w,
const int pad_h, const int pad_w,
const float* dy,
float* dx) {
int x_offset = H * W;
int y_offset = pool_h * pool_w;
math::Set<float, CPUContext>(N * C * H * W, 0, dx);
for (int n = 0; n < N; ++n) {
for (int c = 0; c < C; ++c) {
for (int ph = 0; ph < pool_h; ++ph) {
for (int pw = 0; pw < pool_w; ++pw) {
int start_h = ph * stride_h - pad_h;
int start_w = pw * stride_w - pad_w;
int end_h = std::min(start_h + kernel_h, height + pad_h);
int end_w = std::min(start_w + kernel_w, width + pad_w);
int pool_size = (end_h - start_h)*(end_w - start_w);
end_h = std::min(end_h, height);
end_w = std::min(end_w, width);
int end_h = std::min(start_h + kernel_h, H + pad_h);
int end_w = std::min(start_w + kernel_w, W + pad_w);
int pool_area = (end_h - start_h) * (end_w - start_w);
end_h = std::min(end_h, H);
end_w = std::min(end_w, W);
start_h = std::max(start_h, 0);
start_w = std::max(start_w, 0);
const int pool_idx = ph * pool_width + pw;
const int pool_idx = ph * pool_w + pw;
for (int h = start_h; h < end_h; ++h) {
for (int w = start_w; w < end_w; ++w) {
const int idx = h * width + w;
dx[idx] += (dy[pool_idx] / pool_size);
const int idx = h * W + w;
dx[idx] += (dy[pool_idx] / pool_area);
}
}
} // end pw
} // end ph
}
}
dx += x_offset;
dy += y_offset;
} // end c
} // end n
}
}
}
template <typename T>
void _AVGPooling2dGrad_NHWC(const int N, const int C,
const int H, const int W,
const int pool_h, const int pool_w,
const int kernel_h, const int kernel_w,
const int stride_h, const int stride_w,
const int pad_h, const int pad_w,
const float* dy,
float* dx) {
int x_offset = H * W * C;
int y_offset = pool_h * pool_w * C;
math::Set<float, CPUContext>(N * H * W * C, 0, dx);
for (int n = 0; n < N; ++n) {
for (int ph = 0; ph < pool_h; ph++) {
for (int pw = 0; pw < pool_w; ++pw) {
int start_h = ph * stride_h - pad_h;
int start_w = pw * stride_w - pad_w;
int end_h = std::min(start_h + kernel_h, H + pad_h);
int end_w = std::min(start_w + kernel_w, W + pad_w);
int pool_area = (end_h - start_h) * (end_w - start_w);
end_h = std::min(end_h, H);
end_w = std::min(end_w, W);
start_h = std::max(start_h, 0);
start_w = std::max(start_w, 0);
const int base_pool_idx = ph * pool_w + pw;
for (int c = 0; c < C; ++c) {
const int pool_idx = base_pool_idx * C + c;
for (int h = start_h; h < end_h; ++h)
for (int w = start_w; w < end_w; ++w)
dx[(h * W + w) * C + c] += (dy[pool_idx] / pool_area);
}
}
}
dx += x_offset;
dy += y_offset;
}
}
template<> void AVGPooling2dGrad<float, CPUContext>(const int count,
const int N, const int C,
const int H, const int W,
const int pool_h, const int pool_w,
const int kernel_h, const int kernel_w,
const int stride_h, const int stride_w,
const int pad_h, const int pad_w,
const string& data_format,
const float* dy,
float* dx) {
if (data_format == "NCHW") {
_AVGPooling2dGrad_NCHW<float>(N, C, H, W, pool_h, pool_w,
kernel_h, kernel_w,
stride_h, stride_w,
pad_h, pad_w,
dy,
dx);
} else if (data_format == "NHWC") {
_AVGPooling2dGrad_NHWC<float>(N, C, H, W, pool_h, pool_w,
kernel_h, kernel_w,
stride_h, stride_w,
pad_h, pad_w,
dy,
dx);
} else LOG(FATAL) << "Unknown data format: " << data_format;
}
/******************** vision.roi_pooling ********************/
......
This diff could not be displayed because it is too large.