Commit 58284aa4 by Ting PAN

Refactor Vision Module

1 parent 771e3d5a
Showing with 1225 additions and 795 deletions
...@@ -37,7 +37,7 @@ class CPUContext { ...@@ -37,7 +37,7 @@ class CPUContext {
inline static void* New(size_t nbytes) { inline static void* New(size_t nbytes) {
void* data; void* data;
#ifdef WITH_CUDA_HOST_MEN #ifdef WITH_CUDA_HOST_MEM
CUDA_CHECK(cudaMallocHost(&data, nbytes)); CUDA_CHECK(cudaMallocHost(&data, nbytes));
#else #else
data = malloc(nbytes); data = malloc(nbytes);
......
...@@ -18,14 +18,14 @@ namespace dragon { ...@@ -18,14 +18,14 @@ namespace dragon {
#define MAX_GPUS 8 #define MAX_GPUS 8
/************************************************************************** /**************************************************************************
* cuXXX libraries wrapper "Context" as "Handle" * cuXXX libraries wrapper "Context" as "Handle".
* it's well known that each "Context" binds to some "Devices" in OpenCL * It's well known that each "Context" binds to some "Devices" in OpenCL.
* so, we must create different handles to associate different devices * So, we must create different handles to associate different devices or
* or the computations will be dispatched to the same GPU the computations will be dispatched to the same GPU.
* read more: http://docs.nvidia.com/cuda/cublas/, section 2.1.2 * Read more: http://docs.nvidia.com/cuda/cublas/, Sec 2.1.2.
* also, "Handle" is thread safe * Also, "Handle" is thread safe,
* it seems not necessary to create handles for different threads it seems not necessary to create handles for different threads
*************************************************************************/ *************************************************************************/
class CUDAObject { class CUDAObject {
......
...@@ -128,7 +128,7 @@ class Operator : public OperatorBase { ...@@ -128,7 +128,7 @@ class Operator : public OperatorBase {
#ifndef WITH_MPI #ifndef WITH_MPI
return true; return true;
#else #else
vector<int> allow_ranks = Operator::GetRepeatedArg<int>("mpi_rank"); vector<int> allow_ranks = Operator::GetRepeatedArg<int>("mpi_ranks");
if (allow_ranks.empty()) return true; if (allow_ranks.empty()) return true;
int cur_rank; int cur_rank;
MPI_Comm_rank(MPI_COMM_WORLD, &cur_rank); MPI_Comm_rank(MPI_COMM_WORLD, &cur_rank);
......
...@@ -105,7 +105,7 @@ class Tensor { ...@@ -105,7 +105,7 @@ class Tensor {
MixedMemory* memory() const { return own_mem_ ? memory_.get() : ex_memory_; } MixedMemory* memory() const { return own_mem_ ? memory_.get() : ex_memory_; }
MixedMemory::State memory_state() const { MixedMemory::State memory_state() const {
MixedMemory* mem = memory(); MixedMemory* mem = memory();
CHECK(mem) << "Memory access before allowcating."; CHECK(mem) << "\nMemory access before allowcating.";
return memory()->state(); return memory()->state();
} }
......
...@@ -19,8 +19,7 @@ class BiasAddOp : public Operator<Context> { ...@@ -19,8 +19,7 @@ class BiasAddOp : public Operator<Context> {
data_format(OperatorBase::GetSingleArg<string>("data_format", "NCHW")) {} data_format(OperatorBase::GetSingleArg<string>("data_format", "NCHW")) {}
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void NCHWRunWithType(); template <typename T> void RunWithType();
template <typename T> void NHWCRunWithType();
protected: protected:
TIndex outer_dim, dim, inner_dim; TIndex outer_dim, dim, inner_dim;
...@@ -36,8 +35,7 @@ class BiasAddGradientOp final : public Operator<Context> { ...@@ -36,8 +35,7 @@ class BiasAddGradientOp final : public Operator<Context> {
data_format(OperatorBase::GetSingleArg<string>("data_format", "NCHW")) {} data_format(OperatorBase::GetSingleArg<string>("data_format", "NCHW")) {}
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void NCHWRunWithType(); template <typename T> void RunWithType();
template <typename T> void NHWCRunWithType();
protected: protected:
int outer_dim, dim, inner_dim; int outer_dim, dim, inner_dim;
......
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
#ifndef DRAGON_OPERATORS_MISC_IMAGE_DATA_OP_H_
#define DRAGON_OPERATORS_MISC_IMAGE_DATA_OP_H_
#include "core/operator.h"
namespace dragon {
template <class Context>
class ImageDataOp final : public Operator<Context> {
public:
ImageDataOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws),
dtype(OperatorBase::GetSingleArg<string>("dtype", "FLOAT32")),
mean_values(OperatorBase::GetRepeatedArg<float>("mean_values")),
std_values(OperatorBase::GetRepeatedArg<float>("std_values")),
data_format(OperatorBase::GetSingleArg<string>("data_format", "NCHW")) {
if (mean_values.size() > 0) {
CHECK_EQ((int)mean_values.size(), 3)
<< "The mean values should be a list with length 3.";
mean.Reshape(vector<TIndex>(1, 3));
for (int i = 0; i < 3; i++)
mean.mutable_data<float, CPUContext>()[i] = mean_values[i];
}
if (std_values.size() > 0) {
CHECK_EQ((int)std_values.size(), 3)
<< "The std values should be a list with length 3.";
std.Reshape(vector<TIndex>(1, 3));
for (int i = 0; i < 3; i++)
std.mutable_data<float, CPUContext>()[i] = std_values[i];
}
}
void RunOnDevice() override;
template <typename Tx, typename Ty> void RunWithType();
protected:
string dtype, data_format;
vector<float> mean_values, std_values;
TIndex n, c, h, w;
Tensor mean, std;
};
} // namespace dragon
#endif // DRAGON_OPERATORS_MISC_IMAGE_DATA_OP_H_
\ No newline at end of file
// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------
#ifndef DRAGON_OPERATORS_MISC_MEMORY_DATA_OP_H_
#define DRAGON_OPERATORS_MISC_MEMORY_DATA_OP_H_
#include "core/operator.h"
namespace dragon {
template <class Context>
class MemoryDataOp final : public Operator<Context> {
public:
MemoryDataOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws) {
int DATA_TYPE = OperatorBase::GetSingleArg<int>("dtype", 1);
data_type = TensorProto_DataType(DATA_TYPE);
}
void RunOnDevice() override;
template <typename Tx, typename Ty> void RunWithType();
protected:
TensorProto_DataType data_type;
};
} // namespace dragon
#endif // DRAGON_OPERATORS_MISC_MEMORY_DATA_OP_H_
\ No newline at end of file
...@@ -19,8 +19,9 @@ class ModelMPIBase : public Operator<Context> { ...@@ -19,8 +19,9 @@ class ModelMPIBase : public Operator<Context> {
public: public:
ModelMPIBase(const OperatorDef& op_def, Workspace* ws) ModelMPIBase(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(op_def, ws),
comm((MPI_Comm)OperatorBase::GetSingleArg<int>("comm", 0)), comm((MPI_Comm)OperatorBase::GetSingleArg<int64_t>("comm", 0)),
group((MPI_Group)OperatorBase::GetSingleArg<int>("group", 0)) { group((MPI_Group)OperatorBase::GetSingleArg<int64_t>("group", 0)),
dtype(OperatorBase::GetSingleArg<string>("dtype", "FLOAT32")) {
if (comm == MPI_COMM_NULL) return; if (comm == MPI_COMM_NULL) return;
MPI_Comm_size(MPI_COMM_WORLD, &world_size); MPI_Comm_size(MPI_COMM_WORLD, &world_size);
...@@ -36,11 +37,18 @@ class ModelMPIBase : public Operator<Context> { ...@@ -36,11 +37,18 @@ class ModelMPIBase : public Operator<Context> {
CHECK(comm_root != MPI_UNDEFINED) << "MPI root is not included in layer group."; CHECK(comm_root != MPI_UNDEFINED) << "MPI root is not included in layer group.";
} }
MPI_Datatype mpi_dtype() {
if (dtype == "FLOAT32") return MPI_FLOAT;
else LOG(FATAL) << "Unsupported input type: " << dtype;
return MPI_DATATYPE_NULL;
}
protected: protected:
MPI_Comm comm; MPI_Comm comm;
MPI_Group group; MPI_Group group;
int comm_size, comm_rank, comm_root; int comm_size, comm_rank, comm_root;
int world_size, world_rank; int world_size, world_rank;
string dtype;
}; };
} // namespace dragon } // namespace dragon
......
...@@ -19,26 +19,37 @@ class BilinearResizeOp : public Operator<Context> { ...@@ -19,26 +19,37 @@ class BilinearResizeOp : public Operator<Context> {
static_dsize(OperatorBase::GetRepeatedArg<int>("static_dsize")), static_dsize(OperatorBase::GetRepeatedArg<int>("static_dsize")),
dynamic_dsize(OperatorBase::GetRepeatedArg<string>("dynamic_dsize")), dynamic_dsize(OperatorBase::GetRepeatedArg<string>("dynamic_dsize")),
fy(OperatorBase::GetSingleArg<float>("fy", -1.0)), fy(OperatorBase::GetSingleArg<float>("fy", -1.0)),
fx(OperatorBase::GetSingleArg<float>("fx", -1.0)) {} fx(OperatorBase::GetSingleArg<float>("fx", -1.0)),
data_format(OperatorBase::GetSingleArg<string>("data_format", "NCHW")) {
if (data_format == "NCHW") spatial_axis = 2;
else if (data_format == "NHWC") spatial_axis = 1;
else LOG(FATAL) << "Unknown data format: " << data_format;
}
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
protected: protected:
vector<int> static_dsize; vector<int> static_dsize;
vector<string> dynamic_dsize; vector<string> dynamic_dsize;
float fy, fx;
string data_format;
TIndex n, c, h, w, out_h, out_w, spatial_axis;
vector<TIndex> dims; vector<TIndex> dims;
float h_scale, w_scale, fy, fx;
}; };
template <class Context> template <class Context>
class BilinearResizeGradientOp : public Operator<Context> { class BilinearResizeGradientOp : public Operator<Context> {
public: public:
BilinearResizeGradientOp(const OperatorDef& op_def, Workspace* ws) BilinearResizeGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws) {} : Operator<Context>(op_def, ws),
data_format(OperatorBase::GetSingleArg<string>("data_format", "NCHW")) {}
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
protected:
string data_format;
TIndex n, c, h, w, out_h, out_w;
}; };
} // namespace dragon } // namespace dragon
......
...@@ -12,23 +12,28 @@ ...@@ -12,23 +12,28 @@
namespace dragon { namespace dragon {
template <class Context> template <class Context>
class ConvOp : public ConvOpBase<Context> { class Conv2dOp : public ConvOpBase<Context> {
public: public:
ConvOp(const OperatorDef& def, Workspace* ws) Conv2dOp(const OperatorDef& def, Workspace* ws)
: ConvOpBase<Context>(def, ws) {} : ConvOpBase<Context>(def, ws) {
this->num_spatial_axes = 2;
Setup();
}
void ComputeOutputShape() override;
bool ReverseDimensions() override { return false; } bool ReverseDimensions() override { return false; }
virtual bool HasBias() { return InputSize() > 2; }
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
}; };
template <class Context> template <class Context>
class ConvGradientOp : public ConvOp<Context> { class Conv2dGradientOp : public Conv2dOp<Context> {
public: public:
ConvGradientOp(const OperatorDef& def, Workspace* ws) Conv2dGradientOp(const OperatorDef& def, Workspace* ws)
: ConvOp<Context>(def, ws) {} : Conv2dOp<Context>(def, ws) {}
bool HasBias() override { return output(2)->name() != "ignore"; }
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
...@@ -39,10 +44,10 @@ class ConvGradientOp : public ConvOp<Context> { ...@@ -39,10 +44,10 @@ class ConvGradientOp : public ConvOp<Context> {
#include "utils/cudnn_device.h" #include "utils/cudnn_device.h"
template <class Context> template <class Context>
class CuDNNConvOp : public ConvOp<Context> { class CuDNNConv2dOp : public Conv2dOp<Context> {
public: public:
CuDNNConvOp(const OperatorDef& def, Workspace* ws) CuDNNConv2dOp(const OperatorDef& def, Workspace* ws)
: ConvOp<Context>(def, ws) { : Conv2dOp<Context>(def, ws) {
handle = new cudnnHandle_t[this->group]; handle = new cudnnHandle_t[this->group];
stream = new cudaStream_t[this->group]; stream = new cudaStream_t[this->group];
ctx().SwitchToDevice(); ctx().SwitchToDevice();
...@@ -55,8 +60,10 @@ class CuDNNConvOp : public ConvOp<Context> { ...@@ -55,8 +60,10 @@ class CuDNNConvOp : public ConvOp<Context> {
CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc)); CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc));
CUDNN_CHECK(cudnnCreateTensorDescriptor(&output_desc)); CUDNN_CHECK(cudnnCreateTensorDescriptor(&output_desc));
CUDNN_CHECK(cudnnCreateConvolutionDescriptor(&conv_desc)); CUDNN_CHECK(cudnnCreateConvolutionDescriptor(&conv_desc));
if (InputSize() > 2) if (HasBias()) CUDNN_CHECK(cudnnCreateTensorDescriptor(&bias_desc));
CUDNN_CHECK(cudnnCreateTensorDescriptor(&bias_desc)); if (this->data_format == "NCHW") format = CUDNN_TENSOR_NCHW;
else if (this->data_format == "NHWC") format = CUDNN_TENSOR_NHWC;
else LOG(FATAL) << "Unknown data format: " << this->data_format;
} }
void RunOnDevice() override; void RunOnDevice() override;
...@@ -65,19 +72,20 @@ class CuDNNConvOp : public ConvOp<Context> { ...@@ -65,19 +72,20 @@ class CuDNNConvOp : public ConvOp<Context> {
protected: protected:
cudnnHandle_t* handle; cudnnHandle_t* handle;
cudaStream_t* stream; cudaStream_t* stream;
cudnnTensorFormat_t format;
cudnnConvolutionFwdAlgo_t fwd_algo; cudnnConvolutionFwdAlgo_t fwd_algo;
cudnnTensorDescriptor_t input_desc, output_desc, bias_desc; cudnnTensorDescriptor_t input_desc, output_desc, bias_desc;
cudnnConvolutionDescriptor_t conv_desc; cudnnConvolutionDescriptor_t conv_desc;
cudnnFilterDescriptor_t filter_desc; cudnnFilterDescriptor_t filter_desc;
size_t workspace_fwd_data_size; size_t workspace_fwd_data_size;
int bias_offset; TIndex bias_offset;
}; };
template <class Context> template <class Context>
class CuDNNConvGradientOp : public ConvGradientOp<Context> { class CuDNNConv2dGradientOp : public Conv2dGradientOp<Context> {
public: public:
CuDNNConvGradientOp(const OperatorDef& def, Workspace* ws) CuDNNConv2dGradientOp(const OperatorDef& def, Workspace* ws)
: ConvGradientOp<Context>(def, ws) { : Conv2dGradientOp<Context>(def, ws) {
handle = new cudnnHandle_t[this->group * 3]; handle = new cudnnHandle_t[this->group * 3];
stream = new cudaStream_t[this->group * 3]; stream = new cudaStream_t[this->group * 3];
for (int g = 0; g < this->group * 3; g++) { for (int g = 0; g < this->group * 3; g++) {
...@@ -89,8 +97,10 @@ class CuDNNConvGradientOp : public ConvGradientOp<Context> { ...@@ -89,8 +97,10 @@ class CuDNNConvGradientOp : public ConvGradientOp<Context> {
CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc)); CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc));
CUDNN_CHECK(cudnnCreateTensorDescriptor(&output_desc)); CUDNN_CHECK(cudnnCreateTensorDescriptor(&output_desc));
CUDNN_CHECK(cudnnCreateConvolutionDescriptor(&conv_desc)); CUDNN_CHECK(cudnnCreateConvolutionDescriptor(&conv_desc));
if (InputSize() > 2) if (HasBias()) CUDNN_CHECK(cudnnCreateTensorDescriptor(&bias_desc));
CUDNN_CHECK(cudnnCreateTensorDescriptor(&bias_desc)); if (this->data_format == "NCHW") format = CUDNN_TENSOR_NCHW;
else if (this->data_format == "NHWC") format = CUDNN_TENSOR_NHWC;
else LOG(FATAL) << "Unknown data format: " << this->data_format;
} }
void RunOnDevice() override; void RunOnDevice() override;
...@@ -99,6 +109,7 @@ class CuDNNConvGradientOp : public ConvGradientOp<Context> { ...@@ -99,6 +109,7 @@ class CuDNNConvGradientOp : public ConvGradientOp<Context> {
protected: protected:
cudnnHandle_t* handle; cudnnHandle_t* handle;
cudaStream_t* stream; cudaStream_t* stream;
cudnnTensorFormat_t format;
cudnnConvolutionBwdFilterAlgo_t bwd_filter_algo; cudnnConvolutionBwdFilterAlgo_t bwd_filter_algo;
cudnnConvolutionBwdDataAlgo_t bwd_data_algo; cudnnConvolutionBwdDataAlgo_t bwd_data_algo;
cudnnTensorDescriptor_t input_desc, output_desc, bias_desc; cudnnTensorDescriptor_t input_desc, output_desc, bias_desc;
......
...@@ -18,53 +18,38 @@ class ConvOpBase : public Operator<Context> { ...@@ -18,53 +18,38 @@ class ConvOpBase : public Operator<Context> {
public: public:
ConvOpBase(const OperatorDef& op_def, Workspace* ws) ConvOpBase(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(op_def, ws),
num_output(OperatorBase::GetSingleArg<int>("num_output", 1)), data_format(OperatorBase::GetSingleArg<string>("data_format", "NCHW")),
group(OperatorBase::GetSingleArg<int>("group", 1)) { padding(OperatorBase::GetSingleArg<string>("padding", "VALID")),
num_output(OperatorBase::GetSingleArg<int>("num_output", 1)),
channel_axis = 1, num_spatial_axes = 2; // Conv2D support only Now group(OperatorBase::GetSingleArg<int>("group", 1)),
vector<TIndex> spatial_shape(1, num_spatial_axes); static_dsize(OperatorBase::GetRepeatedArg<int>("static_dsize")),
dynamic_dsize(OperatorBase::GetRepeatedArg<string>("dynamic_dsize")) {
vector<int> ks = OperatorBase::GetRepeatedArg<int>("kernel_size"); if (data_format == "NCHW") spatial_axis = 2;
for (int i = 0; i < num_spatial_axes; i++) else if (data_format == "NHWC") spatial_axis = 1;
kernel_size.push_back(i < ks.size() ? ks[i]: ks[0]); else LOG(FATAL) << "Unknown data format: " << data_format;
num_spatial_axes = -1; // unknown
vector<int> s = OperatorBase::GetRepeatedArg<int>("stride");
for (int i = 0; i < num_spatial_axes; i++)
stride.push_back(i < s.size() ? s[i] : s[0]);
vector<int> p = OperatorBase::GetRepeatedArg<int>("pad");
for (int i = 0; i < num_spatial_axes; i++)
pad.push_back(i < p.size() ? p[i] : p[0]);
vector<int> d = OperatorBase::GetRepeatedArg<int>("dilation");
for (int i = 0; i < num_spatial_axes; i++)
dilation.push_back(i < d.size() ? d[i] : d[0]);
is_1x1 = true;
for (int i = 0; i < num_spatial_axes; i++) {
is_1x1 &= (kernel_size[i] == 1 &&
stride[i] == 1 &&
pad[i] == 0);
if (!is_1x1) break;
}
} }
protected: protected:
vector<TIndex> kernel_size, stride, pad, dilation; vector<TIndex> kernel_size, stride, pad, dilation;
vector<TIndex> input_shape, output_shape, bottom_shape, col_buffer_shape; string data_format, padding;
vector<TIndex> input_shape, output_shape, bottom_shape, top_shape, col_shape;
vector<TIndex> weight_shape, bias_shape; vector<TIndex> weight_shape, bias_shape;
Tensor* col_buffer, *bias_multiplier; Tensor* col_buffer, *bias_multiplier;
TIndex num_output, group; TIndex num_output, group;
TIndex channel_axis, num_spatial_axes; TIndex spatial_axis, num_spatial_axes;
TIndex channels, out_spatial_dim; TIndex channels, out_spatial_dim;
TIndex conv_in_channels, conv_out_channels; TIndex conv_in_channels, conv_out_channels;
TIndex conv_out_spatial_dim, kernel_dim; TIndex conv_out_spatial_dim, kernel_dim;
TIndex col_offset, output_offset, weight_offset, x_offset, y_offset; TIndex col_offset, output_offset, weight_offset, x_offset, y_offset;
vector<int> static_dsize;
vector<string> dynamic_dsize;
bool is_1x1; bool is_1x1;
void Setup();
void Reshape(); void Reshape();
void GradientReshape(); void GradientReshape();
virtual void ComputeOutputShape() = 0; virtual void ComputeOutputShape();
virtual bool ReverseDimensions() = 0; virtual bool ReverseDimensions() = 0;
template <typename T> void Wx(const T* x, const T* weights, T* y, bool skip_im2col = false); template <typename T> void Wx(const T* x, const T* weights, T* y, bool skip_im2col = false);
...@@ -74,25 +59,33 @@ class ConvOpBase : public Operator<Context> { ...@@ -74,25 +59,33 @@ class ConvOpBase : public Operator<Context> {
template <typename T> void Db(const T* dy, T* db); template <typename T> void Db(const T* dy, T* db);
private: private:
template <typename T> void Im2Col(const T* im, T* col_buffer) { template <typename T> void Im2Col(const T* im, T* col) {
kernel::Im2Col<T, Context>(conv_in_channels, if (input(0).ndim() == 4) {
input_shape[0], input_shape[1], kernel::Im2Col2d<T, Context>(conv_in_channels,
kernel_size[0], kernel_size[1], input_shape[0], input_shape[1],
stride[0], stride[1], output_shape[0], output_shape[1],
pad[0], pad[1], kernel_size[0], kernel_size[1],
dilation[0], dilation[1], stride[0], stride[1],
im, pad[0], pad[1],
col_buffer); dilation[0], dilation[1],
data_format,
im,
col);
} else LOG(FATAL) << "ConvNd has not been implemented yet";
} }
template <typename T> void Col2Im(const T* col_buffer, T* im) { template <typename T> void Col2Im(const T* col, T* im) {
kernel::Col2Im<T, Context>(conv_in_channels, if (input(0).ndim() == 4) {
input_shape[0], input_shape[1], kernel::Col2Im2d<T, Context>(conv_in_channels,
kernel_size[0], kernel_size[1], input_shape[0], input_shape[1],
stride[0], stride[1], output_shape[0], output_shape[1],
pad[0], pad[1], kernel_size[0], kernel_size[1],
dilation[0], dilation[1], stride[0], stride[1],
col_buffer, pad[0], pad[1],
im); dilation[0], dilation[1],
data_format,
col,
im);
} else LOG(FATAL) << "ConvNd has not been implemented yet";
} }
}; };
......
...@@ -4,32 +4,40 @@ ...@@ -4,32 +4,40 @@
// Written by Ting Pan // Written by Ting Pan
// -------------------------------------------------------- // --------------------------------------------------------
#ifndef DRAGON_OPERATORS_VISION_DECONV_OP_H_ #ifndef DRAGON_OPERATORS_VISION_CONV_TRANSPOSE_OP_H_
#define DRAGON_OPERATORS_VISION_DECONV_OP_H_ #define DRAGON_OPERATORS_VISION_CONV_TRANSPOSE_OP_H_
#include "operators/vision/conv_op_base.h" #include "operators/vision/conv_op_base.h"
namespace dragon { namespace dragon {
template <class Context> template <class Context>
class DeConvOp: public ConvOpBase<Context> { class Conv2dTransposeOp: public ConvOpBase<Context> {
public: public:
DeConvOp(const OperatorDef& def, Workspace* ws) Conv2dTransposeOp(const OperatorDef& def, Workspace* ws)
: ConvOpBase<Context>(def, ws) {} : ConvOpBase<Context>(def, ws) {
this->num_spatial_axes = 2;
Setup();
}
void ComputeOutputShape() override;
bool ReverseDimensions() override { return true; } bool ReverseDimensions() override { return true; }
virtual bool HasBias() { return InputSize() > 2; }
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
protected:
vector<int> static_dsize;
vector<string> dynamic_dsize;
}; };
template <class Context> template <class Context>
class DeConvGradientOp : public DeConvOp<Context> { class Conv2dTransposeGradientOp : public Conv2dTransposeOp<Context> {
public: public:
DeConvGradientOp(const OperatorDef& def, Workspace* ws) : Conv2dTransposeGradientOp(const OperatorDef& def, Workspace* ws)
DeConvOp<Context>(def, ws) {} : Conv2dTransposeOp<Context>(def, ws) {}
bool HasBias() override { return output(2)->name() != "ignore"; }
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
...@@ -40,10 +48,10 @@ class DeConvGradientOp : public DeConvOp<Context> { ...@@ -40,10 +48,10 @@ class DeConvGradientOp : public DeConvOp<Context> {
#include "utils/cudnn_device.h" #include "utils/cudnn_device.h"
template <class Context> template <class Context>
class CuDNNDeConvOp : public DeConvOp<Context> { class CuDNNConv2dTransposeOp : public Conv2dTransposeOp<Context> {
public: public:
CuDNNDeConvOp(const OperatorDef& def, Workspace* ws) CuDNNConv2dTransposeOp(const OperatorDef& def, Workspace* ws)
: DeConvOp<Context>(def, ws) { : Conv2dTransposeOp<Context>(def, ws) {
handle = new cudnnHandle_t[this->group]; handle = new cudnnHandle_t[this->group];
stream = new cudaStream_t[this->group]; stream = new cudaStream_t[this->group];
for (int g = 0; g < this->group; g++) { for (int g = 0; g < this->group; g++) {
...@@ -55,8 +63,10 @@ class CuDNNDeConvOp : public DeConvOp<Context> { ...@@ -55,8 +63,10 @@ class CuDNNDeConvOp : public DeConvOp<Context> {
CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc)); CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc));
CUDNN_CHECK(cudnnCreateTensorDescriptor(&output_desc)); CUDNN_CHECK(cudnnCreateTensorDescriptor(&output_desc));
CUDNN_CHECK(cudnnCreateConvolutionDescriptor(&conv_desc)); CUDNN_CHECK(cudnnCreateConvolutionDescriptor(&conv_desc));
if (InputSize() > 2) if (HasBias()) CUDNN_CHECK(cudnnCreateTensorDescriptor(&bias_desc));
CUDNN_CHECK(cudnnCreateTensorDescriptor(&bias_desc)); if (this->data_format == "NCHW") format = CUDNN_TENSOR_NCHW;
else if (this->data_format == "NHWC") format = CUDNN_TENSOR_NHWC;
else LOG(FATAL) << "Unknown data format: " << this->data_format;
} }
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
...@@ -64,6 +74,7 @@ class CuDNNDeConvOp : public DeConvOp<Context> { ...@@ -64,6 +74,7 @@ class CuDNNDeConvOp : public DeConvOp<Context> {
protected: protected:
cudnnHandle_t* handle; cudnnHandle_t* handle;
cudaStream_t* stream; cudaStream_t* stream;
cudnnTensorFormat_t format;
cudnnConvolutionBwdDataAlgo_t fwd_algo; cudnnConvolutionBwdDataAlgo_t fwd_algo;
cudnnTensorDescriptor_t input_desc, output_desc, bias_desc; cudnnTensorDescriptor_t input_desc, output_desc, bias_desc;
cudnnConvolutionDescriptor_t conv_desc; cudnnConvolutionDescriptor_t conv_desc;
...@@ -73,10 +84,10 @@ class CuDNNDeConvOp : public DeConvOp<Context> { ...@@ -73,10 +84,10 @@ class CuDNNDeConvOp : public DeConvOp<Context> {
}; };
template <class Context> template <class Context>
class CuDNNDeConvGradientOp : public DeConvGradientOp<Context> { class CuDNNConv2dTransposeGradientOp : public Conv2dTransposeGradientOp<Context> {
public: public:
CuDNNDeConvGradientOp(const OperatorDef& def, Workspace* ws) CuDNNConv2dTransposeGradientOp(const OperatorDef& def, Workspace* ws)
: DeConvGradientOp<Context>(def, ws) { : Conv2dTransposeGradientOp<Context>(def, ws) {
handle = new cudnnHandle_t[this->group * 3]; handle = new cudnnHandle_t[this->group * 3];
stream = new cudaStream_t[this->group * 3]; stream = new cudaStream_t[this->group * 3];
for (int g = 0; g < this->group * 3; g++) { for (int g = 0; g < this->group * 3; g++) {
...@@ -88,8 +99,10 @@ public: ...@@ -88,8 +99,10 @@ public:
CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc)); CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc));
CUDNN_CHECK(cudnnCreateTensorDescriptor(&output_desc)); CUDNN_CHECK(cudnnCreateTensorDescriptor(&output_desc));
CUDNN_CHECK(cudnnCreateConvolutionDescriptor(&conv_desc)); CUDNN_CHECK(cudnnCreateConvolutionDescriptor(&conv_desc));
if (InputSize() > 2) if (HasBias()) CUDNN_CHECK(cudnnCreateTensorDescriptor(&bias_desc));
CUDNN_CHECK(cudnnCreateTensorDescriptor(&bias_desc)); if (this->data_format == "NCHW") format = CUDNN_TENSOR_NCHW;
else if (this->data_format == "NHWC") format = CUDNN_TENSOR_NHWC;
else LOG(FATAL) << "Unknown data format: " << this->data_format;
} }
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
...@@ -97,6 +110,7 @@ public: ...@@ -97,6 +110,7 @@ public:
protected: protected:
cudnnHandle_t* handle; cudnnHandle_t* handle;
cudaStream_t* stream; cudaStream_t* stream;
cudnnTensorFormat_t format;
cudnnConvolutionBwdFilterAlgo_t bwd_filter_algo; cudnnConvolutionBwdFilterAlgo_t bwd_filter_algo;
cudnnConvolutionFwdAlgo_t bwd_data_algo; cudnnConvolutionFwdAlgo_t bwd_data_algo;
cudnnTensorDescriptor_t input_desc, output_desc, bias_desc; cudnnTensorDescriptor_t input_desc, output_desc, bias_desc;
...@@ -110,4 +124,4 @@ public: ...@@ -110,4 +124,4 @@ public:
} // namespace dragon } // namespace dragon
#endif // DRAGON_OPERATORS_VISION_DECONV_OP_H_ #endif // DRAGON_OPERATORS_VISION_CONV_TRANSPOSE_OP_H_
\ No newline at end of file \ No newline at end of file
...@@ -19,7 +19,12 @@ class NNResizeOp : public Operator<Context> { ...@@ -19,7 +19,12 @@ class NNResizeOp : public Operator<Context> {
static_dsize(OperatorBase::GetRepeatedArg<int>("static_dsize")), static_dsize(OperatorBase::GetRepeatedArg<int>("static_dsize")),
dynamic_dsize(OperatorBase::GetRepeatedArg<string>("dynamic_dsize")), dynamic_dsize(OperatorBase::GetRepeatedArg<string>("dynamic_dsize")),
fy(OperatorBase::GetSingleArg<float>("fy", -1.0)), fy(OperatorBase::GetSingleArg<float>("fy", -1.0)),
fx(OperatorBase::GetSingleArg<float>("fx", -1.0)) {} fx(OperatorBase::GetSingleArg<float>("fx", -1.0)),
data_format(OperatorBase::GetSingleArg<string>("data_format", "NCHW")) {
if (data_format == "NCHW") spatial_axis = 2;
else if (data_format == "NHWC") spatial_axis = 1;
else LOG(FATAL) << "Unknown data format: " << data_format;
}
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
...@@ -27,18 +32,24 @@ class NNResizeOp : public Operator<Context> { ...@@ -27,18 +32,24 @@ class NNResizeOp : public Operator<Context> {
protected: protected:
vector<int> static_dsize; vector<int> static_dsize;
vector<string> dynamic_dsize; vector<string> dynamic_dsize;
vector<TIndex> dims; float fy, fx;
float h_scale, w_scale, fy, fx; string data_format;
TIndex n, c, h, w, out_h, out_w, spatial_axis;
}; };
template <class Context> template <class Context>
class NNResizeGradientOp : public Operator<Context> { class NNResizeGradientOp : public Operator<Context> {
public: public:
NNResizeGradientOp(const OperatorDef& op_def, Workspace* ws) NNResizeGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws) {} : Operator<Context>(op_def, ws),
data_format(OperatorBase::GetSingleArg<string>("data_format", "NCHW")) {}
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void RunWithType(); template <typename T> void RunWithType();
protected:
string data_format;
TIndex n, c, h, w, out_h, out_w;
}; };
} // namespace dragon } // namespace dragon
......
...@@ -11,14 +11,14 @@ ...@@ -11,14 +11,14 @@
namespace dragon { namespace dragon {
enum PoolingMode { MAX_POOLING, AVG_POOLING };
template <class Context> template <class Context>
class PoolingOp: public Operator <Context> { class Pooling2dOp: public Operator <Context> {
public: public:
PoolingOp(const OperatorDef& op_def, Workspace* ws) Pooling2dOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(op_def, ws),
mode(PoolingMode(OperatorBase::GetSingleArg<int>("mode", MAX_POOLING))), mode(OperatorBase::GetSingleArg<string>("mode", "MAX")),
data_format(OperatorBase::GetSingleArg<string>("data_format", "NCHW")),
padding(OperatorBase::GetSingleArg<string>("padding", "VALID")),
global_pooling(OperatorBase::GetSingleArg<bool>("global_pooling", false)) { global_pooling(OperatorBase::GetSingleArg<bool>("global_pooling", false)) {
vector<int> ks = OperatorBase::GetRepeatedArg<int>("kernel_size"); vector<int> ks = OperatorBase::GetRepeatedArg<int>("kernel_size");
vector<int> s = OperatorBase::GetRepeatedArg<int>("stride"); vector<int> s = OperatorBase::GetRepeatedArg<int>("stride");
...@@ -38,24 +38,25 @@ class PoolingOp: public Operator <Context> { ...@@ -38,24 +38,25 @@ class PoolingOp: public Operator <Context> {
void Reshape(); void Reshape();
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void MaxRunWithType(); template <typename T> void MAXRunWithType();
template <typename T> void AvgRunWithType(); template <typename T> void AVGRunWithType();
protected: protected:
vector<TIndex> kernel_size, stride, pad; vector<TIndex> kernel_size, stride, pad;
Tensor* mask; Tensor* mask;
PoolingMode mode; string mode, data_format, padding;
TIndex num, channels, height, width; TIndex n, c, h, w, pool_h, pool_w;
TIndex pool_height, pool_width;
bool global_pooling; bool global_pooling;
}; };
template <class Context> template <class Context>
class PoolingGradientOp: public Operator<Context> { class Pooling2dGradientOp: public Operator<Context> {
public: public:
PoolingGradientOp(const OperatorDef& op_def, Workspace* ws) Pooling2dGradientOp(const OperatorDef& op_def, Workspace* ws)
: Operator<Context>(op_def, ws), : Operator<Context>(op_def, ws),
mode(PoolingMode(OperatorBase::GetSingleArg<int>("mode", MAX_POOLING))), mode(OperatorBase::GetSingleArg<string>("mode", "MAX")),
data_format(OperatorBase::GetSingleArg<string>("data_format", "NCHW")),
padding(OperatorBase::GetSingleArg<string>("padding", "VALID")),
global_pooling(OperatorBase::GetSingleArg<bool>("global_pooling", false)) { global_pooling(OperatorBase::GetSingleArg<bool>("global_pooling", false)) {
vector<int> ks = OperatorBase::GetRepeatedArg<int>("kernel_size"); vector<int> ks = OperatorBase::GetRepeatedArg<int>("kernel_size");
vector<int> s = OperatorBase::GetRepeatedArg<int>("stride"); vector<int> s = OperatorBase::GetRepeatedArg<int>("stride");
...@@ -75,46 +76,36 @@ class PoolingGradientOp: public Operator<Context> { ...@@ -75,46 +76,36 @@ class PoolingGradientOp: public Operator<Context> {
void Reshape(); void Reshape();
void RunOnDevice() override; void RunOnDevice() override;
template <typename T> void MaxRunWithType(); template <typename T> void MAXRunWithType();
template <typename T> void AvgRunWithType(); template <typename T> void AVGRunWithType();
protected: protected:
vector<TIndex> kernel_size, stride, pad; vector<TIndex> kernel_size, stride, pad;
Tensor* mask; Tensor* mask;
PoolingMode mode; string mode, data_format, padding;
TIndex num, channels, height, width; TIndex n, c, h, w, pool_h, pool_w;
TIndex pool_height, pool_width;
bool global_pooling; bool global_pooling;
}; };
#ifdef WITH_CUDNN #ifdef WITH_CUDNN
template <class Context> template <class Context>
class CuDNNPoolingOp final : public PoolingOp<Context> { class CuDNNPooling2dOp final : public Pooling2dOp<Context> {
public: public:
CuDNNPoolingOp(const OperatorDef& op_def, Workspace* ws) CuDNNPooling2dOp(const OperatorDef& op_def, Workspace* ws)
: PoolingOp<Context>(op_def, ws) { : Pooling2dOp<Context>(op_def, ws) {
CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc)); CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc));
CUDNN_CHECK(cudnnCreateTensorDescriptor(&output_desc)); CUDNN_CHECK(cudnnCreateTensorDescriptor(&output_desc));
CUDNN_CHECK(cudnnCreatePoolingDescriptor(&pool_desc)); CUDNN_CHECK(cudnnCreatePoolingDescriptor(&pool_desc));
pool_mode = this->mode == MAX_POOLING ? if (this->mode == "MAX") {
CUDNN_POOLING_MAX : #if CUDNN_VERSION_MIN(6,0,0)
CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING; pool_mode = CUDNN_POOLING_MAX_DETERMINISTIC;
#if CUDNN_VERSION_MIN(5, 0, 0)
CUDNN_CHECK(cudnnSetPooling2dDescriptor(pool_desc,
pool_mode,
CUDNN_PROPAGATE_NAN,
this->kernel_size[0], this->kernel_size[1],
this->pad[0], this->pad[1],
this->stride[0], this->stride[1]));
#else #else
CUDNN_CHECK(cudnnSetPooling2dDescriptor_v4(pool_desc, pool_mode = CUDNN_POOLING_MAX;
pool_mode,
CUDNN_PROPAGATE_NAN,
this->kernel_size[0], this->kernel_size[1],
this->pad[0], this->pad[1],
this->stride[0], this->stride[1]));
#endif #endif
} else if (this->mode == "AVG") {
pool_mode = CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING;
} else LOG(FATAL) << "Unsupported pooling mode: " << this->mode;
} }
void RunOnDevice() override; void RunOnDevice() override;
...@@ -122,34 +113,40 @@ class CuDNNPoolingOp final : public PoolingOp<Context> { ...@@ -122,34 +113,40 @@ class CuDNNPoolingOp final : public PoolingOp<Context> {
protected: protected:
cudnnTensorDescriptor_t input_desc, output_desc; cudnnTensorDescriptor_t input_desc, output_desc;
cudnnPoolingDescriptor_t pool_desc; cudnnPoolingDescriptor_t pool_desc;
cudnnPoolingMode_t pool_mode; cudnnPoolingMode_t pool_mode;
}; };
template <class Context> template <class Context>
class CuDNNPoolingGradientOp final : public PoolingGradientOp<Context> { class CuDNNPooling2dGradientOp final : public Pooling2dGradientOp<Context> {
public: public:
CuDNNPoolingGradientOp(const OperatorDef& op_def, Workspace* ws) CuDNNPooling2dGradientOp(const OperatorDef& op_def, Workspace* ws)
: PoolingGradientOp<Context>(op_def, ws) { : Pooling2dGradientOp<Context>(op_def, ws) {
CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc)); CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc));
CUDNN_CHECK(cudnnCreateTensorDescriptor(&output_desc)); CUDNN_CHECK(cudnnCreateTensorDescriptor(&output_desc));
CUDNN_CHECK(cudnnCreatePoolingDescriptor(&pool_desc)); CUDNN_CHECK(cudnnCreatePoolingDescriptor(&pool_desc));
pool_mode = this->mode == MAX_POOLING ? if (this->mode == "MAX") {
CUDNN_POOLING_MAX : #if CUDNN_VERSION_MIN(6,0,0)
CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING; pool_mode = CUDNN_POOLING_MAX_DETERMINISTIC;
#else
pool_mode = CUDNN_POOLING_MAX;
#endif
} else if (this->mode == "AVG") {
pool_mode = CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING;
} else LOG(FATAL) << "Unsupported pooling mode: " << this->mode;
#if CUDNN_VERSION_MIN(5, 0, 0) #if CUDNN_VERSION_MIN(5, 0, 0)
CUDNN_CHECK(cudnnSetPooling2dDescriptor(pool_desc, CUDNN_CHECK(cudnnSetPooling2dDescriptor(pool_desc,
pool_mode, pool_mode,
CUDNN_PROPAGATE_NAN, CUDNN_PROPAGATE_NAN,
this->kernel_size[0], this->kernel_size[1], this->kernel_size[0], this->kernel_size[1],
this->pad[0], this->pad[1], this->pad[0], this->pad[1],
this->stride[0], this->stride[1])); this->stride[0], this->stride[1]));
#else #else
CUDNN_CHECK(cudnnSetPooling2dDescriptor_v4(pool_desc, CUDNN_CHECK(cudnnSetPooling2dDescriptor_v4(pool_desc,
pool_mode, pool_mode,
CUDNN_PROPAGATE_NAN, CUDNN_PROPAGATE_NAN,
this->kernel_size[0], this->kernel_size[1], this->kernel_size[0], this->kernel_size[1],
this->pad[0], this->pad[1], this->pad[0], this->pad[1],
this->stride[0], this->stride[1])); this->stride[0], this->stride[1]));
#endif #endif
} }
...@@ -159,8 +156,8 @@ class CuDNNPoolingGradientOp final : public PoolingGradientOp<Context> { ...@@ -159,8 +156,8 @@ class CuDNNPoolingGradientOp final : public PoolingGradientOp<Context> {
protected: protected:
cudnnTensorDescriptor_t input_desc, output_desc; cudnnTensorDescriptor_t input_desc, output_desc;
cudnnPoolingDescriptor_t pool_desc; cudnnPoolingDescriptor_t pool_desc;
cudnnPoolingMode_t pool_mode; cudnnPoolingMode_t pool_mode;
}; };
#endif // WITH_CUDNN #endif // WITH_CUDNN
......
...@@ -61,11 +61,29 @@ template <typename T> ...@@ -61,11 +61,29 @@ template <typename T>
void cudnnSetTensorDesc(cudnnTensorDescriptor_t* desc, Tensor* tensor); void cudnnSetTensorDesc(cudnnTensorDescriptor_t* desc, Tensor* tensor);
template <typename T> template <typename T>
void cudnnSetTensor4dDesc(cudnnTensorDescriptor_t* desc, const string& data_format, Tensor* tensor);
template <typename T>
void cudnnSetTensor5dDesc(cudnnTensorDescriptor_t* desc, const string& data_format, Tensor* tensor);
template <typename T>
void cudnnSetTensor3dDesc(cudnnTensorDescriptor_t* desc, const string& data_format, Tensor* tensor);
template <typename T>
void cudnnSetTensorDesc(cudnnTensorDescriptor_t* desc, const std::vector<int64_t>& dims); void cudnnSetTensorDesc(cudnnTensorDescriptor_t* desc, const std::vector<int64_t>& dims);
template <typename T> template <typename T>
void cudnnSetTensorDesc(cudnnTensorDescriptor_t* desc, void cudnnSetTensor4dDesc(cudnnTensorDescriptor_t* desc, const string& data_format, const std::vector<int64_t>& dims);
const std::vector<int64_t>& dims,
template <typename T>
void cudnnSetTensor5dDesc(cudnnTensorDescriptor_t* desc, const string& data_format, const std::vector<int64_t>& dims);
template <typename T>
void cudnnSetTensor3dDesc(cudnnTensorDescriptor_t* desc, const string& data_format, const std::vector<int64_t>& dims);
template <typename T>
void cudnnSetTensorDesc(cudnnTensorDescriptor_t* desc,
const std::vector<int64_t>& dims,
const std::vector<int64_t>& strides); const std::vector<int64_t>& strides);
} }
......
...@@ -54,7 +54,7 @@ class TruncatedNormalFiller final : public Filler < T, Context > { ...@@ -54,7 +54,7 @@ class TruncatedNormalFiller final : public Filler < T, Context > {
public: public:
TruncatedNormalFiller(const TensorFiller& filler): Filler<T, Context>(filler) {} TruncatedNormalFiller(const TensorFiller& filler): Filler<T, Context>(filler) {}
void Fill(Tensor* tensor) override { void Fill(Tensor* tensor) override {
// implement of gpu is diffcult // implement it on gpu is difficult
math::RandomTruncatedNormal<T, CPUContext>(tensor->count(), math::RandomTruncatedNormal<T, CPUContext>(tensor->count(),
filler().mean(), filler().mean(),
filler().std(), filler().std(),
......
...@@ -809,7 +809,7 @@ class Tensor(object): ...@@ -809,7 +809,7 @@ class Tensor(object):
if self.shape is not None: if self.shape is not None:
output.shape = input_shape[:] output.shape = input_shape[:]
output.shape.insert(axis, 1L) output.shape.insert(axis, np.long(1))
return output return output
......
...@@ -35,6 +35,7 @@ else: ...@@ -35,6 +35,7 @@ else:
argument.name = key argument.name = key
if type(value) is float: argument.f = value if type(value) is float: argument.f = value
elif type(value) is int: argument.i = value elif type(value) is int: argument.i = value
elif type(value) is long: argument.i = value
elif type(value) is np.int64: argument.i64 = int(value) elif type(value) is np.int64: argument.i64 = int(value)
elif type(value) is str: argument.s = value elif type(value) is str: argument.s = value
elif type(value) is unicode: argument.s = value elif type(value) is unicode: argument.s = value
...@@ -42,6 +43,7 @@ else: ...@@ -42,6 +43,7 @@ else:
elif isinstance(value, Message): argument.s = value.SerializeToString() elif isinstance(value, Message): argument.s = value.SerializeToString()
elif all(type(v) is float for v in value): argument.floats.extend(value) elif all(type(v) is float for v in value): argument.floats.extend(value)
elif all(type(v) is int for v in value): argument.ints.extend(value) elif all(type(v) is int for v in value): argument.ints.extend(value)
elif all(type(v) is long for v in value): argument.ints.extend(value)
elif all(type(v) is str for v in value): argument.strings.extend(value) elif all(type(v) is str for v in value): argument.strings.extend(value)
elif all(type(v) is unicode or type(v) is str for v in value): elif all(type(v) is unicode or type(v) is str for v in value):
argument.strings.extend(value) argument.strings.extend(value)
......
...@@ -269,7 +269,6 @@ def FeedTensor(tensor, ndarray, force_cpu=False, dtype=None): ...@@ -269,7 +269,6 @@ def FeedTensor(tensor, ndarray, force_cpu=False, dtype=None):
format(preset_dtype, dtype)) format(preset_dtype, dtype))
auto_dtype = preset_dtype auto_dtype = preset_dtype
ndarray = np.array(ndarray, dtype=auto_dtype) ndarray = np.array(ndarray, dtype=auto_dtype)
if hasattr(tensor, 'shape'): tensor.shape = list(ndarray.shape)
FeedTensorCC(name, ndarray, _stringify_proto(dev)) FeedTensorCC(name, ndarray, _stringify_proto(dev))
......
...@@ -11,7 +11,7 @@ Data ...@@ -11,7 +11,7 @@ Data
List Brief List Brief
============== ======================================================================== ============== ========================================================================
`LMDBData`_ Prefetch Image data with LMDB database. `LMDBData`_ Prefetch Image data with LMDB database.
`MemoryData`_ Perform ``NHWC <-> NCHW``, ``Mean Subtraction`` and ``Type Converting``. `ImageData`_ Process the images from 4D raw data.
============== ======================================================================== ============== ========================================================================
Initializer Initializer
...@@ -185,7 +185,7 @@ List Brief ...@@ -185,7 +185,7 @@ List Brief
.. _LMDBData: operators/data.html#dragon.operators.data.LMDBData .. _LMDBData: operators/data.html#dragon.operators.data.LMDBData
.. _MemoryData: operators/data.html#dragon.operators.data.MemoryData .. _ImageData: operators/data.html#dragon.operators.data.ImageData
.. _Fill: operators/initializer.html#dragon.operators.initializer.Fill .. _Fill: operators/initializer.html#dragon.operators.initializer.Fill
.. _RandomUniform: operators/initializer.html#dragon.operators.initializer.RandomUniform .. _RandomUniform: operators/initializer.html#dragon.operators.initializer.RandomUniform
......
...@@ -74,25 +74,39 @@ def LMDBData(**kwargs): ...@@ -74,25 +74,39 @@ def LMDBData(**kwargs):
return Run([], param_str=str(kwargs), nout=2, **arguments) return Run([], param_str=str(kwargs), nout=2, **arguments)
def MemoryData(inputs, dtype=np.float32, **kwargs): def ImageData(inputs, mean_values=None, std_values=None,
"""Perform ``NHWC <-> NCHW``, ``Mean Subtraction`` and ``Type Converting``. dtype='FLOAT32', data_format='NCHW', **kwargs):
"""Process the images from 4D raw data.
Note that we assume the data format of raw data is **NHWC**.
Parameters Parameters
---------- ----------
inputs : Tensor inputs : Tensor
The input tensor, with type of uint8 or float32. The input tensor, with type of **uint8** or **float32**.
dtype : np.float32 or np.float16 mean_values : list of float or None
The dtype of output tensor. The optional mean values to subtract.
std_values : list of float or None
The optional std values to divide.
dtype : str
The type of output. ``FLOAT32`` or ``FLOAT16``.
data_format : str
The data format of output. ``NCHW`` or ``NHWC``.
Returns Returns
------- -------
Tensor Tensor
The post-processing Tensor. The output tensor.
""" """
arguments = ParseArguments(locals()) arguments = ParseArguments(locals())
if dtype is np.float32: arguments['dtype'] = 1 if mean_values is not None:
elif dtype is np.float16: arguments['dtype'] = 12 if len(mean_values) != 3:
else: raise TypeError('Unsupported data type.') raise ValueError('The length of mean values should be 3.')
arguments['mean_values'] = [float(v) for v in mean_values]
return Tensor.CreateOperator(nout=1, op_type='MemoryData', **arguments) if std_values is not None:
\ No newline at end of file if len(std_values) != 3:
raise ValueError('The length of std values should be 3.')
arguments['std_values'] = [float(v) for v in std_values]
return Tensor.CreateOperator(nout=1, op_type='ImageData', **arguments)
\ No newline at end of file
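For reference, a minimal usage sketch of the new ``ImageData`` op, written in the style of the other docstring examples in this module (the mean/std values below are illustrative, not taken from this commit):

>>> data = Tensor().Variable()  # raw uint8 images, assumed NHWC
>>> image = ImageData(data, mean_values=[104., 117., 123.], dtype='FLOAT32', data_format='NCHW')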
...@@ -18,7 +18,7 @@ def SparseSoftmaxCrossEntropy(inputs, axis=1, normalization='VALID', ignore_labe ...@@ -18,7 +18,7 @@ def SparseSoftmaxCrossEntropy(inputs, axis=1, normalization='VALID', ignore_labe
axis : int axis : int
The axis of softmax function. The axis of softmax function.
normalization : str normalization : str
The normalization, ``UINT``, ``FULL``, ``VALID``, ``BATCH_SIZE`` or ``NONE``. The normalization, ``UNIT``, ``FULL``, ``VALID``, ``BATCH_SIZE`` or ``NONE``.
ignore_label : tuple or list ignore_label : tuple or list
The label id to ignore. Default is ``empty``. The label id to ignore. Default is ``empty``.
...@@ -29,7 +29,7 @@ def SparseSoftmaxCrossEntropy(inputs, axis=1, normalization='VALID', ignore_labe ...@@ -29,7 +29,7 @@ def SparseSoftmaxCrossEntropy(inputs, axis=1, normalization='VALID', ignore_labe
Notes Notes
----- -----
Set the normalization to ``UINT`` will return unreduced losses. Set the normalization to ``UNIT`` will return unreduced losses.
""" """
CheckInputs(inputs, 2) CheckInputs(inputs, 2)
...@@ -56,7 +56,7 @@ def SigmoidCrossEntropy(inputs, normalization='FULL', **kwargs): ...@@ -56,7 +56,7 @@ def SigmoidCrossEntropy(inputs, normalization='FULL', **kwargs):
inputs : list of Tensor inputs : list of Tensor
The inputs, represent [input, labels]. The inputs, represent [input, labels].
normalization : str normalization : str
The normalization, ``UINT``, ``FULL``, ``BATCH_SIZE`` or ``NONE``. The normalization, ``UNIT``, ``FULL``, ``BATCH_SIZE`` or ``NONE``.
Returns Returns
------- -------
...@@ -65,7 +65,7 @@ def SigmoidCrossEntropy(inputs, normalization='FULL', **kwargs): ...@@ -65,7 +65,7 @@ def SigmoidCrossEntropy(inputs, normalization='FULL', **kwargs):
Notes Notes
----- -----
Set the normalization to ``UINT`` will return unreduced losses. Set the normalization to ``UNIT`` will return unreduced losses.
""" """
CheckInputs(inputs, 2) CheckInputs(inputs, 2)
...@@ -90,7 +90,7 @@ def SoftmaxCrossEntropy(inputs, axis=1, normalization='FULL', **kwargs): ...@@ -90,7 +90,7 @@ def SoftmaxCrossEntropy(inputs, axis=1, normalization='FULL', **kwargs):
axis : int axis : int
The axis of softmax function. The axis of softmax function.
normalization : str normalization : str
The normalization, ``UINT``, ``FULL``, ``BATCH_SIZE`` or ``NONE``. The normalization, ``UNIT``, ``FULL``, ``BATCH_SIZE`` or ``NONE``.
Returns Returns
------- -------
...@@ -99,7 +99,7 @@ def SoftmaxCrossEntropy(inputs, axis=1, normalization='FULL', **kwargs): ...@@ -99,7 +99,7 @@ def SoftmaxCrossEntropy(inputs, axis=1, normalization='FULL', **kwargs):
Notes Notes
----- -----
Set the normalization to ``UINT`` will return unreduced losses. Set the normalization to ``UNIT`` will return unreduced losses.
""" """
CheckInputs(inputs, 2) CheckInputs(inputs, 2)
...@@ -213,13 +213,13 @@ def SparseSoftmaxFocalLoss(inputs, axis=1, normalization='VALID', ignore_labels= ...@@ -213,13 +213,13 @@ def SparseSoftmaxFocalLoss(inputs, axis=1, normalization='VALID', ignore_labels=
axis : int axis : int
The axis of softmax function. The axis of softmax function.
normalization : str normalization : str
The normalization, ``UINT``, ``FULL``, ``VALID``, ``BATCH_SIZE`` or ``NONE``. The normalization, ``UNIT``, ``FULL``, ``VALID``, ``BATCH_SIZE`` or ``NONE``.
ignore_label : tuple or list ignore_label : tuple or list
The label id to ignore. Default is ``empty``. The label id to ignore. Default is ``empty``.
alpha : float alpha : float
The scale factor on the rare class. Default is ``0.5``. The scale factor on the rare class. Default is ``0.5``.
gamma : float gamma : float
The exponetial decay factor on the easy examples. Default is ``2.0``. The exponential decay factor on the easy examples. Default is ``2.0``.
eps : float eps : float
The eps. The eps.
neg_id : int neg_id : int
...@@ -232,7 +232,7 @@ def SparseSoftmaxFocalLoss(inputs, axis=1, normalization='VALID', ignore_labels= ...@@ -232,7 +232,7 @@ def SparseSoftmaxFocalLoss(inputs, axis=1, normalization='VALID', ignore_labels=
Notes Notes
----- -----
Set the normalization to ``UINT`` will return unreduced losses. Set the normalization to ``UNIT`` will return unreduced losses.
""" """
CheckInputs(inputs, 2) CheckInputs(inputs, 2)
......
...@@ -80,7 +80,7 @@ def MPIGather(inputs, root, mpi_ranks=None, **kwargs): ...@@ -80,7 +80,7 @@ def MPIGather(inputs, root, mpi_ranks=None, **kwargs):
if mpi_ranks is None: if mpi_ranks is None:
num_nodes = mpi.Size() num_nodes = mpi.Size()
mpi_rank = [i for i in xrange(0, num_nodes)] mpi_ranks = [i for i in xrange(0, num_nodes)]
if not isinstance(mpi_ranks, list): mpi_ranks = [mpi_ranks] if not isinstance(mpi_ranks, list): mpi_ranks = [mpi_ranks]
comm, group = mpi.CreateGroup(root, incl=mpi_ranks) comm, group = mpi.CreateGroup(root, incl=mpi_ranks)
......
...@@ -9,8 +9,9 @@ from six.moves import range as xrange ...@@ -9,8 +9,9 @@ from six.moves import range as xrange
from . import * from . import *
def Conv2D(inputs, num_output, kernel_size, def Conv2d(inputs, num_output, kernel_size,
stride=1, pad=0, dilation=1, group=1, **kwargs): stride=1, pad=0, dilation=1, group=1,
padding='VALID', data_format='NCHW', **kwargs):
"""2D Convolution. """2D Convolution.
The number of inputs varies from ``2`` to ``3`` (without or with ``bias``). The number of inputs varies from ``2`` to ``3`` (without or with ``bias``).
...@@ -19,6 +20,8 @@ def Conv2D(inputs, num_output, kernel_size, ...@@ -19,6 +20,8 @@ def Conv2D(inputs, num_output, kernel_size,
|conv_output_dim| |conv_output_dim|
Setting ``padding`` to **VALID** will use the value of ``pad``.
Parameters Parameters
---------- ----------
inputs : list of Tensor inputs : list of Tensor
...@@ -35,21 +38,25 @@ def Conv2D(inputs, num_output, kernel_size, ...@@ -35,21 +38,25 @@ def Conv2D(inputs, num_output, kernel_size,
The dilation multiple(s) of convolution. Default is ``1``. The dilation multiple(s) of convolution. Default is ``1``.
group : int group : int
The group size of convolution. Default is ``1``. The group size of convolution. Default is ``1``.
padding : str
The padding algorithm. ``VALID`` or ``SAME``.
data_format : str
The data format. ``NCHW`` or ``NHWC``.
Returns Returns
------- -------
Tensor Tensor
The tensor of 2d convolution. The output tensor.
Examples Examples
-------- --------
>>> input = Tensor().Variable() >>> input = Tensor().Variable()
>>> weights = Tensor().Normal(std=0.001) >>> weights = Tensor().Normal(std=0.001)
>>> biases = Tensor().Constant(value=0) >>> biases = Tensor().Constant(value=0)
>>> conv1 = Conv2D([input, weights, biases], num_output=64, kernel_size=3) >>> conv1 = Conv2d([input, weights, biases], num_output=64, kernel_size=3)
>>> weights = Tensor().Gaussian(std=0.001) >>> weights = Tensor().Gaussian(std=0.001)
>>> conv2 = Conv2D([conv1, weights], num_output=128, kernel_size=3, stride=1) >>> conv2 = Conv2d([conv1, weights], num_output=128, kernel_size=3, stride=1)
""" """
CheckInputs(inputs, 2, 3) CheckInputs(inputs, 2, 3)
...@@ -63,7 +70,7 @@ def Conv2D(inputs, num_output, kernel_size, ...@@ -63,7 +70,7 @@ def Conv2D(inputs, num_output, kernel_size,
if not isinstance(arguments['dilation'], list): if not isinstance(arguments['dilation'], list):
arguments['dilation'] = [arguments['dilation']] arguments['dilation'] = [arguments['dilation']]
output = Tensor.CreateOperator(nout=1, op_type='Conv', **arguments) output = Tensor.CreateOperator(nout=1, op_type='Conv2d', **arguments)
if inputs[0].shape is not None: if inputs[0].shape is not None:
output.shape = inputs[0].shape[:] output.shape = inputs[0].shape[:]
...@@ -83,8 +90,9 @@ def Conv2D(inputs, num_output, kernel_size, ...@@ -83,8 +90,9 @@ def Conv2D(inputs, num_output, kernel_size,
return output return output
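The new ``padding`` and ``data_format`` arguments are not covered by the docstring example above; a hedged usage sketch in the same style (shapes and values illustrative):

>>> input = Tensor().Variable()
>>> weights = Tensor().Normal(std=0.001)
>>> conv3 = Conv2d([input, weights], num_output=32, kernel_size=3, stride=2, padding='SAME', data_format='NHWC')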
def Deconv2D(inputs, num_output, kernel_size, def Conv2dTranspose(inputs, num_output, kernel_size,
stride=1, pad=0, dilation=1, group=1, **kwargs): stride=1, pad=0, dilation=1, group=1, output_shape=None,
padding='VALID', data_format='NCHW', **kwargs):
"""2D Deconvolution. """2D Deconvolution.
The number of inputs varies from ``2`` to ``3`` (without or with ``bias``). The number of inputs varies from ``2`` to ``3`` (without or with ``bias``).
...@@ -93,6 +101,10 @@ def Deconv2D(inputs, num_output, kernel_size, ...@@ -93,6 +101,10 @@ def Deconv2D(inputs, num_output, kernel_size,
|deconv_output_dim| |deconv_output_dim|
Setting ``padding`` to **VALID** will use the value of ``pad``.
Provide ``output_shape`` if using **SAME** padding.
Parameters Parameters
---------- ----------
inputs : list of Tensor inputs : list of Tensor
...@@ -109,26 +121,46 @@ def Deconv2D(inputs, num_output, kernel_size, ...@@ -109,26 +121,46 @@ def Deconv2D(inputs, num_output, kernel_size,
The dilation multiple(s) of deconvolution. Default is ``1``. The dilation multiple(s) of deconvolution. Default is ``1``.
group : int group : int
The group size of deconvolution. Default is ``1``. The group size of deconvolution. Default is ``1``.
output_shape : list of int or None
The deterministic output shape for **SAME** padding.
padding : str
The padding algorithm. ``VALID`` or ``SAME``.
data_format : str
The data format. ``NCHW`` or ``NHWC``.
Returns Returns
------- -------
Tensor Tensor
The tensor of 2d deconvolution. The output tensor.
Examples Examples
-------- --------
>>> input = Tensor().Variable() >>> input = Tensor().Variable()
>>> weights = Tensor().Normal(std=0.001) >>> weights = Tensor().Normal(std=0.001)
>>> biases = Tensor().Constant(value=0) >>> biases = Tensor().Constant(value=0)
>>> deconv1 = Deconv2D([input, weights, biases], num_output=64, kernel_size=3) >>> deconv1 = Conv2dTranspose([input, weights, biases], num_output=64, kernel_size=3)
>>> weights = Tensor().Gaussian(std=0.001) >>> weights = Tensor().Gaussian(std=0.001)
>>> deconv2 = Deconv2D([conv1, weights], num_output=128, kernel_size=3, stride=1) >>> deconv2 = Conv2dTranspose([deconv1, weights], num_output=128, kernel_size=3, stride=1)
""" """
CheckInputs(inputs, 2, 3) CheckInputs(inputs, 2, 3)
arguments = ParseArguments(locals()) arguments = ParseArguments(locals())
arguments['output_shape'] = None
if output_shape is not None:
if not isinstance(output_shape, list):
raise TypeError('The output shape should be a list.')
if isinstance(output_shape[0], Tensor):
arguments['dynamic_dsize'] = []
arguments['extra_inputs'] = list(output_shape)
for dim in output_shape:
arguments['dynamic_dsize'].append(dim)
else:
arguments['static_dsize'] = []
for dim in output_shape:
arguments['static_dsize'].append(int(dim))
if not isinstance(arguments['kernel_size'], list): if not isinstance(arguments['kernel_size'], list):
arguments['kernel_size'] = [arguments['kernel_size']] arguments['kernel_size'] = [arguments['kernel_size']]
...@@ -141,44 +173,48 @@ def Deconv2D(inputs, num_output, kernel_size, ...@@ -141,44 +173,48 @@ def Deconv2D(inputs, num_output, kernel_size,
if not isinstance(arguments['dilation'], list): if not isinstance(arguments['dilation'], list):
arguments['dilation'] = [arguments['dilation']] arguments['dilation'] = [arguments['dilation']]
return Tensor.CreateOperator(nout=1, op_type='DeConv', **arguments) return Tensor.CreateOperator(nout=1, op_type='Conv2dTranspose', **arguments)
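The ``output_shape`` branch above stores plain integers as ``static_dsize`` and symbolic ``Tensor`` dims as ``dynamic_dsize`` (plus ``extra_inputs``). A hedged usage sketch for the **SAME** padding case, mirroring the docstring examples (shapes illustrative):

>>> input = Tensor().Variable()
>>> weights = Tensor().Normal(std=0.001)
>>> deconv = Conv2dTranspose([input, weights], num_output=64, kernel_size=4, stride=2, padding='SAME', output_shape=[1, 64, 32, 32])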
def Pool2D(inputs, kernel_size, stride, pad=0, def Pool2d(inputs, kernel_size, stride, pad=0, padding='VALID',
mode='MAX_POOLING', global_pooling=False, **kwargs): mode='MAX', data_format='NCHW', global_pooling=False, **kwargs):
"""2D Pooling, MAX or AVG. """2D Pooling, MAX or AVG.
The spatial output dimension of pooling can be computed as follows: The spatial output dimension of pooling can be computed as follows:
|pooling_output_dim| |pooling_output_dim|
Setting ``padding`` to **VALID** will use the value of ``pad``.
If ``global_pooling`` is used, the stride and pad will be set to ``1`` and ``0`` automatically. If ``global_pooling`` is used, the stride and pad will be set to ``1`` and ``0`` automatically.
Parameters Parameters
---------- ----------
inputs : Tensor inputs : Tensor
The tensor to down-sample. The input tensor.
kernel_size : int or list kernel_size : int or list
The kernel size(s) of pooling. The kernel size(s) of pooling.
stride : int or list stride : int or list
The stride(s) of pooling. The stride(s) of pooling.
pad : int or list pad : int or list
The zero padding size(s) of pooling. Default is ``0``. The zero padding size(s) of pooling. Default is ``0``.
padding : str
The padding algorithm. ``VALID`` or ``SAME``.
mode : str mode : str
The mode, ``MAX_POOLING`` or ``AVG_POOLING``. The mode, ``MAX`` or ``AVG``.
data_format : str
The data format, ``NCHW`` or ``NHWC``.
global_pooling : boolean global_pooling : boolean
Whether to use global pooling. Whether to use global pooling.
Returns Returns
------- -------
Tensor Tensor
The down-sampled tensor. The output tensor.
""" """
CheckInputs(inputs, 1) CheckInputs(inputs, 1)
arguments = ParseArguments(locals()) arguments = ParseArguments(locals())
SUPPORT_MODES = {'MAX_POOLING': 0, 'AVG_POOLING': 1}
arguments['mode'] = SUPPORT_MODES[mode]
if not isinstance(arguments['kernel_size'], list): if not isinstance(arguments['kernel_size'], list):
arguments['kernel_size'] = [arguments['kernel_size']] arguments['kernel_size'] = [arguments['kernel_size']]
if not isinstance(arguments['stride'], list): if not isinstance(arguments['stride'], list):
...@@ -186,10 +222,11 @@ def Pool2D(inputs, kernel_size, stride, pad=0, ...@@ -186,10 +222,11 @@ def Pool2D(inputs, kernel_size, stride, pad=0,
if not isinstance(arguments['pad'], list): if not isinstance(arguments['pad'], list):
arguments['pad'] = [arguments['pad']] arguments['pad'] = [arguments['pad']]
output = Tensor.CreateOperator(nout=1, op_type='Pooling', **arguments) output = Tensor.CreateOperator(nout=1, op_type='Pooling2d', **arguments)
if inputs.shape is not None: if inputs.shape is not None:
output.shape = inputs.shape[:] output.shape = inputs.shape[:]
axis = 2 if data_format == 'NCHW' else 1
for i in xrange(2): for i in xrange(2):
k = arguments['kernel_size'][i] if i < len(arguments['kernel_size']) \ k = arguments['kernel_size'][i] if i < len(arguments['kernel_size']) \
else arguments['kernel_size'][-1] else arguments['kernel_size'][-1]
...@@ -197,10 +234,17 @@ def Pool2D(inputs, kernel_size, stride, pad=0, ...@@ -197,10 +234,17 @@ def Pool2D(inputs, kernel_size, stride, pad=0,
else arguments['stride'][-1] else arguments['stride'][-1]
p = arguments['pad'][i] if i < len(arguments['pad']) \ p = arguments['pad'][i] if i < len(arguments['pad']) \
else arguments['pad'][-1] else arguments['pad'][-1]
if padding == 'SAME':
input_size = output.shape[i + axis]
output_size = (input_size + s - 1) / float(s)
padding_needed = max(0, (output_size - 1) * s + k - input_size)
p_l = padding_needed / 2
p_r = padding_needed - p_l
p = min(p_l, p_r)
if not global_pooling: if not global_pooling:
output.shape[i + 2] = int(math.ceil(float(output.shape[i + 2] + 2 * p - k) / s) + 1) output.shape[i + axis] = int(math.ceil(float(output.shape[i + axis] + 2 * p - k) / s) + 1)
else: else:
output.shape[i + 2] = 1 output.shape[i + axis] = 1
return output return output
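For clarity, the shape inference above can be read as the following standalone sketch (plain Python, not part of the commit; the SAME branch approximates the symmetric-padding rule used in the new code):

import math

def pool2d_out_dim(input_size, kernel, stride, pad=0, padding='VALID'):
    if padding == 'SAME':
        # target roughly ceil(input / stride) outputs, then keep the smaller half of the padding needed
        out = int(math.ceil(input_size / float(stride)))
        needed = max(0, (out - 1) * stride + kernel - input_size)
        pad = min(needed // 2, needed - needed // 2)
    # Caffe-style rounding applied by the layer afterwards
    return int(math.ceil(float(input_size + 2 * pad - kernel) / stride) + 1)

# e.g. pool2d_out_dim(224, 3, 2, padding='SAME') -> 112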
...@@ -296,7 +340,7 @@ def LRN(inputs, local_size=5, alpha=0.0001, beta=0.75, k=2.0, mode='ACROSS_CHANN ...@@ -296,7 +340,7 @@ def LRN(inputs, local_size=5, alpha=0.0001, beta=0.75, k=2.0, mode='ACROSS_CHANN
return output return output
def NNResize(inputs, dsize, fy=-1.0, fx=-1.0, **kwargs): def NNResize(inputs, dsize, fy=-1.0, fx=-1.0, data_format='NCHW', **kwargs):
"""Resize the image with Nearest-Neighbor method. """Resize the image with Nearest-Neighbor method.
Set ``dsize`` to None if you want to use ``fy`` and ``fx``. Set ``dsize`` to None if you want to use ``fy`` and ``fx``.
...@@ -306,16 +350,18 @@ def NNResize(inputs, dsize, fy=-1.0, fx=-1.0, **kwargs): ...@@ -306,16 +350,18 @@ def NNResize(inputs, dsize, fy=-1.0, fx=-1.0, **kwargs):
inputs : Tensor inputs : Tensor
The input tensor. The input tensor.
dsize : tuple, list, Tensor or None dsize : tuple, list, Tensor or None
The dest output size. The output size.
fy : float fy : float
The scale factor based on src height. Default is ``-1.0`` (Discarded). The scale factor based on src height. Default is ``-1.0`` (Discarded).
fx : float fx : float
The scale factor based on src width. Default is ``-1.0`` (Discarded). The scale factor based on src width. Default is ``-1.0`` (Discarded).
data_format : str
The data format. ``NCHW`` or ``NHWC``.
Returns Returns
------- -------
Tensor Tensor
The resized tensor. The output tensor.
""" """
CheckInputs(inputs, 1) CheckInputs(inputs, 1)
...@@ -337,7 +383,7 @@ def NNResize(inputs, dsize, fy=-1.0, fx=-1.0, **kwargs): ...@@ -337,7 +383,7 @@ def NNResize(inputs, dsize, fy=-1.0, fx=-1.0, **kwargs):
return output return output
def BilinearResize(inputs, dsize, fy=-1.0, fx=-1.0, **kwargs): def BilinearResize(inputs, dsize, fy=-1.0, fx=-1.0, data_format='NCHW', **kwargs):
"""Resize the image with Bi-linear method. """Resize the image with Bi-linear method.
Set ``dsize`` to None if you want to use ``fy`` and ``fx``. Set ``dsize`` to None if you want to use ``fy`` and ``fx``.
...@@ -352,11 +398,13 @@ def BilinearResize(inputs, dsize, fy=-1.0, fx=-1.0, **kwargs): ...@@ -352,11 +398,13 @@ def BilinearResize(inputs, dsize, fy=-1.0, fx=-1.0, **kwargs):
The scale factor based on src height. Default is ``-1.0`` (Discarded). The scale factor based on src height. Default is ``-1.0`` (Discarded).
fx : float fx : float
The scale factor based on src width. Default is ``-1.0`` (Discarded). The scale factor based on src width. Default is ``-1.0`` (Discarded).
data_format : str
The data format. ``NCHW`` or ``NHWC``.
Returns Returns
------- -------
Tensor Tensor
The resized tensor. The output tensor.
""" """
CheckInputs(inputs, 1) CheckInputs(inputs, 1)
...@@ -383,7 +431,7 @@ def BiasAdd(inputs, data_format='NCHW', **kwargs): ...@@ -383,7 +431,7 @@ def BiasAdd(inputs, data_format='NCHW', **kwargs):
Parameters Parameters
---------- ----------
inputs : Tensor inputs : list of Tensor
The inputs, represent [input, bias]. The inputs, represent [input, bias].
data_format : str data_format : str
The data format, ``NCHW`` or ``NHWC``. The data format, ``NCHW`` or ``NHWC``.
...@@ -394,7 +442,7 @@ def BiasAdd(inputs, data_format='NCHW', **kwargs): ...@@ -394,7 +442,7 @@ def BiasAdd(inputs, data_format='NCHW', **kwargs):
The bias-added tensor. The bias-added tensor.
""" """
CheckInputs(inputs, 1) CheckInputs(inputs, 2)
arguments = ParseArguments(locals()) arguments = ParseArguments(locals())
output = Tensor.CreateOperator(nout=1, op_type='BiasAdd', **arguments) output = Tensor.CreateOperator(nout=1, op_type='BiasAdd', **arguments)
......
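A minimal usage sketch for the corrected two-input contract of ``BiasAdd`` (names illustrative, not part of the commit):
>>> input = Tensor().Variable()
>>> bias = Tensor().Constant(value=0)
>>> y = BiasAdd([input, bias], data_format='NCHW')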
...@@ -20,7 +20,7 @@ from .operators import recurrent ...@@ -20,7 +20,7 @@ from .operators import recurrent
# data # data
LMDBData = data.LMDBData LMDBData = data.LMDBData
MemoryData = data.MemoryData ImageData = data.ImageData
# init # init
Fill = init.Fill Fill = init.Fill
...@@ -31,9 +31,10 @@ GlorotUniform = init.GlorotUniform ...@@ -31,9 +31,10 @@ GlorotUniform = init.GlorotUniform
GlorotNormal = init.GlorotNormal GlorotNormal = init.GlorotNormal
# vision # vision
Conv2D = vision.Conv2D Conv2d = vision.Conv2d
Deconv2D = vision.Deconv2D Conv2dTranspose = vision.Conv2dTranspose
Pool2D = vision.Pool2D Deconv2d = vision.Conv2dTranspose
Pool2d = vision.Pool2d
ROIPooling = vision.ROIPooling ROIPooling = vision.ROIPooling
ROIAlign = vision.ROIAlign ROIAlign = vision.ROIAlign
LRN = vision.LRN LRN = vision.LRN
......
...@@ -514,7 +514,7 @@ class NormalizeLayer(Layer): ...@@ -514,7 +514,7 @@ class NormalizeLayer(Layer):
scale = Tensor(LayerParameter.name + '@param0') scale = Tensor(LayerParameter.name + '@param0')
if param.HasField('scale_filler'): if param.HasField('scale_filler'):
self.Fill(scale, param, 'scale_filler') self.Fill(scale, param, 'scale_filler')
else: scale.Contant(value=1.0) else: scale.Constant(value=1.0)
self.scale_blobs = [{'data': scale, 'diff': Tensor(scale.name + '_grad')}] self.scale_blobs = [{'data': scale, 'diff': Tensor(scale.name + '_grad')}]
self._blobs.extend(self.scale_blobs) self._blobs.extend(self.scale_blobs)
......
...@@ -48,22 +48,22 @@ class DataLayer(Layer): ...@@ -48,22 +48,22 @@ class DataLayer(Layer):
super(DataLayer, self).__init__(LayerParameter) super(DataLayer, self).__init__(LayerParameter)
param = LayerParameter.data_param param = LayerParameter.data_param
transformer_param = LayerParameter.transform_param transform_param = LayerParameter.transform_param
parallel_param = LayerParameter.parallel_param parallel_param = LayerParameter.parallel_param
self._param = {'source': param.source, self._param = {'source': param.source,
'prefetch': param.prefetch, 'prefetch': param.prefetch,
'batch_size': param.batch_size, 'batch_size': param.batch_size,
'phase': {0: 'TRAIN', 1: 'TEST'}[int(LayerParameter.phase)], 'phase': {0: 'TRAIN', 1: 'TEST'}[int(LayerParameter.phase)],
'scale': transformer_param.scale, 'scale': transform_param.scale,
'mirror': transformer_param.mirror, 'mirror': transform_param.mirror,
'crop_size': transformer_param.crop_size, 'crop_size': transform_param.crop_size,
'mean_values': [float(element) for element in transformer_param.mean_value], 'mean_values': [float(element) for element in transform_param.mean_value],
'force_color': transformer_param.force_color, 'force_color': transform_param.force_color,
'color_augmentation': transformer_param.color_augmentation, 'color_augmentation': transform_param.color_augmentation,
'padding': transformer_param.padding, 'padding': transform_param.padding,
'min_random_scale': transformer_param.min_random_scale, 'min_random_scale': transform_param.min_random_scale,
'max_random_scale': transformer_param.max_random_scale, 'max_random_scale': transform_param.max_random_scale,
'shuffle': parallel_param.shuffle, 'shuffle': parallel_param.shuffle,
'node_step': parallel_param.node_step, 'node_step': parallel_param.node_step,
'partition': parallel_param.partition} 'partition': parallel_param.partition}
...@@ -76,20 +76,25 @@ class DataLayer(Layer): ...@@ -76,20 +76,25 @@ class DataLayer(Layer):
class MemoryDataLayer(Layer): class MemoryDataLayer(Layer):
"""The implementation of ``MemoryDataLayer``. """The implementation of ``MemoryDataLayer``.
We extend it with ``FP16`` and ``NHWC <=> NCHW``. We extend it with ``FP16`` and ``NHWC => NCHW``.
Parameters Parameters
---------- ----------
dtype : caffe_pb2.MemoryDataParameter.DataType dtype : caffe_pb2.MemoryDataParameter.DataType
The dest data type. ``FLOAT32`` or ``FLOAT16``. The dest data type. ``FLOAT32`` or ``FLOAT16``.
mean_value : list of float
The mean of each channel. Refer `TransformationParameter.mean_value`_.
""" """
def __init__(self, LayerParameter): def __init__(self, LayerParameter):
super(MemoryDataLayer, self).__init__(LayerParameter) super(MemoryDataLayer, self).__init__(LayerParameter)
param = LayerParameter.memory_data_param param = LayerParameter.memory_data_param
import numpy as np transform_param = LayerParameter.transform_param
self._param = {'dtype': {0: np.float32, 1: np.float16}[param.dtype]} self._param = {'dtype': {0: 'FLOAT32', 1: 'FLOAT16'}[param.dtype]}
if len(transform_param.mean_value) > 0:
self._param['mean_values'] = \
[float(element) for element in transform_param.mean_value]
def Setup(self, bottom): def Setup(self, bottom):
super(MemoryDataLayer, self).Setup(bottom) super(MemoryDataLayer, self).Setup(bottom)
return ops.MemoryData(bottom[0], **self._param) return ops.ImageData(bottom[0], **self._param)
\ No newline at end of file \ No newline at end of file
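For orientation, a hedged sketch of calling the renamed op directly; the signature is inferred from this layer and from the C++ ``ImageData`` operator later in the diff, and the ``data_format`` argument and mean values here are assumptions:
>>> raw = Tensor().Variable()  # an NHWC uint8/float32 batch
>>> image = ops.ImageData(raw, dtype='FLOAT32', mean_values=[104.0, 116.7, 122.7], data_format='NCHW')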
...@@ -42,7 +42,9 @@ class ConvolutionLayer(Layer): ...@@ -42,7 +42,9 @@ class ConvolutionLayer(Layer):
'stride': [int(element) for element in param.stride] if len(param.stride) > 0 else [1], 'stride': [int(element) for element in param.stride] if len(param.stride) > 0 else [1],
'pad': [int(element) for element in param.pad] if len(param.pad) > 0 else [0], 'pad': [int(element) for element in param.pad] if len(param.pad) > 0 else [0],
'dilation': [int(element) for element in param.dilation] if len(param.dilation) > 0 else [1], 'dilation': [int(element) for element in param.dilation] if len(param.dilation) > 0 else [1],
'group': int(param.group)} 'group': int(param.group),
'padding': 'VALID',
'data_format': 'NCHW'}
if param.HasField('kernel_h'): if param.HasField('kernel_h'):
assert param.HasField('kernel_w') assert param.HasField('kernel_w')
self._param['kernel_size'] = [param.kernel_h, param.kernel_w] self._param['kernel_size'] = [param.kernel_h, param.kernel_w]
...@@ -69,7 +71,7 @@ class ConvolutionLayer(Layer): ...@@ -69,7 +71,7 @@ class ConvolutionLayer(Layer):
def Setup(self, bottom): def Setup(self, bottom):
super(ConvolutionLayer, self).Setup(bottom) super(ConvolutionLayer, self).Setup(bottom)
return ops.Conv2D(bottom + [blob['data'] for blob in self._blobs], **self._param) return ops.Conv2d(bottom + [blob['data'] for blob in self._blobs], **self._param)
class DeconvolutionLayer(ConvolutionLayer): class DeconvolutionLayer(ConvolutionLayer):
...@@ -102,7 +104,7 @@ class DeconvolutionLayer(ConvolutionLayer): ...@@ -102,7 +104,7 @@ class DeconvolutionLayer(ConvolutionLayer):
def Setup(self, bottom): def Setup(self, bottom):
super(DeconvolutionLayer, self).Setup(bottom) super(DeconvolutionLayer, self).Setup(bottom)
return ops.Deconv2D(bottom + [blob['data'] for blob in self._blobs], **self._param) return ops.Deconv2d(bottom + [blob['data'] for blob in self._blobs], **self._param)
class PoolingLayer(Layer): class PoolingLayer(Layer):
...@@ -135,7 +137,8 @@ class PoolingLayer(Layer): ...@@ -135,7 +137,8 @@ class PoolingLayer(Layer):
def __init__(self, LayerParameter): def __init__(self, LayerParameter):
super(PoolingLayer, self).__init__(LayerParameter) super(PoolingLayer, self).__init__(LayerParameter)
param = LayerParameter.pooling_param param = LayerParameter.pooling_param
self._param = {'mode': {0: 'MAX_POOLING', 1: 'AVG_POOLING'}[param.pool], self._param = {'mode': {0: 'MAX', 1: 'AVG'}[param.pool],
'data_format': 'NCHW',
'global_pooling': param.global_pooling} 'global_pooling': param.global_pooling}
if not param.HasField('kernel_h'): self._param['kernel_size'] = [param.kernel_size] if not param.HasField('kernel_h'): self._param['kernel_size'] = [param.kernel_size]
...@@ -150,7 +153,7 @@ class PoolingLayer(Layer): ...@@ -150,7 +153,7 @@ class PoolingLayer(Layer):
def Setup(self, bottom): def Setup(self, bottom):
input = bottom[0] if isinstance(bottom, list) else bottom input = bottom[0] if isinstance(bottom, list) else bottom
super(PoolingLayer, self).Setup(bottom) super(PoolingLayer, self).Setup(bottom)
return ops.Pool2D(input, **self._param) return ops.Pool2d(input, **self._param)
class ROIPoolingLayer(Layer): class ROIPoolingLayer(Layer):
...@@ -253,7 +256,8 @@ class NNResizeLayer(Layer): ...@@ -253,7 +256,8 @@ class NNResizeLayer(Layer):
if param.HasField('shape') else [] if param.HasField('shape') else []
self._param = {'dsize': dsize, self._param = {'dsize': dsize,
'fx': float(param.fx), 'fx': float(param.fx),
'fy': float(param.fy)} 'fy': float(param.fy),
'data_format': 'NCHW'}
def Setup(self, bottom): def Setup(self, bottom):
super(NNResizeLayer, self).Setup(bottom) super(NNResizeLayer, self).Setup(bottom)
...@@ -284,7 +288,8 @@ class BilinearResizeLayer(Layer): ...@@ -284,7 +288,8 @@ class BilinearResizeLayer(Layer):
if param.HasField('shape') else [] if param.HasField('shape') else []
self._param = {'dsize': dsize, self._param = {'dsize': dsize,
'fx': float(param.fx), 'fx': float(param.fx),
'fy': float(param.fy)} 'fy': float(param.fy),
'data_format': 'NCHW'}
def Setup(self, bottom): def Setup(self, bottom):
super(BilinearResizeLayer, self).Setup(bottom) super(BilinearResizeLayer, self).Setup(bottom)
...@@ -292,4 +297,4 @@ class BilinearResizeLayer(Layer): ...@@ -292,4 +297,4 @@ class BilinearResizeLayer(Layer):
if isinstance(bottom, list) and len(bottom) > 1: if isinstance(bottom, list) and len(bottom) > 1:
dshape = ops.Shape(bottom[1]) dshape = ops.Shape(bottom[1])
self._param['dsize'] = (dshape[2], dshape[3]) self._param['dsize'] = (dshape[2], dshape[3])
return ops.BilinearResize(input, **self._param) return ops.BilinearResize(input, **self._param)
\ No newline at end of file
...@@ -354,6 +354,25 @@ class Net(object): ...@@ -354,6 +354,25 @@ class Net(object):
return lambda net = self, net_outputs = self.outputs \ return lambda net = self, net_outputs = self.outputs \
: GetOutputs(net, net_outputs) : GetOutputs(net, net_outputs)
def forward_v2(self, **kwargs):
"""Forward pass while silencing all net outputs.
Parameters
----------
inputs : dict or None
The blobs to feed before the forward pass.
Returns
-------
None
"""
if kwargs:
for name, blob in kwargs.items():
ws.FeedTensor(self._inputs_to_tensors[name], blob)
self.function()(return_outputs=False, stage='forward')
return None
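A usage sketch, assuming ``data`` and ``label`` are input blobs declared on the net (names illustrative):
>>> net.forward_v2(data=im_blob, label=gt_blob)  # feeds the blobs and runs forward; no outputs are fetched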
def backward(self, **kwargs): def backward(self, **kwargs):
"""Backward pass. [**PyCaffe Style**] """Backward pass. [**PyCaffe Style**]
......
...@@ -9,5 +9,4 @@ from .compile import ( ...@@ -9,5 +9,4 @@ from .compile import (
scan, scan,
shared) shared)
from .configdefaults import config from .configdefaults import config
import gradient \ No newline at end of file
\ No newline at end of file
...@@ -17,6 +17,7 @@ from dragon.core.gradient_maker import GraphGradientMaker ...@@ -17,6 +17,7 @@ from dragon.core.gradient_maker import GraphGradientMaker
from dragon.core.scope import GetOperatorName, GetTensorName from dragon.core.scope import GetOperatorName, GetTensorName
from dragon.core.tensor import Tensor from dragon.core.tensor import Tensor
def GraphDef_Grad(meta_graph, targets): def GraphDef_Grad(meta_graph, targets):
"""Inject the gradient targets into GraphDef. """Inject the gradient targets into GraphDef.
...@@ -67,7 +68,8 @@ def GraphDef_Phase(meta_graph, targets): ...@@ -67,7 +68,8 @@ def GraphDef_Phase(meta_graph, targets):
""" """
phase = 'TEST' phase = 'TEST'
from dragon.core.scope import _PHASE_SCOPE from dragon.core.scope import _PHASE_SCOPE
if _PHASE_SCOPE != '': phase = _PHASE_SCOPE.upper() if _PHASE_SCOPE != '':
phase = _PHASE_SCOPE.upper()
else: else:
for target in targets: for target in targets:
if len(target.grad_wrts) > 0: if len(target.grad_wrts) > 0:
...@@ -101,7 +103,7 @@ def GraphDef_Update(meta_graph, updater): ...@@ -101,7 +103,7 @@ def GraphDef_Update(meta_graph, updater):
parallel_arguments = {} parallel_arguments = {}
# wrap hyper-parameters as Tensor for CC # wrap hyper-parameters as Tensor for CC
for k,v in updater._hyper_params.items(): for k, v in updater._hyper_params.items():
ws.FeedTensor(updater._prefix + k, np.array([v], dtype=np.float32)) ws.FeedTensor(updater._prefix + k, np.array([v], dtype=np.float32))
# check data parallel if necessary # check data parallel if necessary
...@@ -116,7 +118,8 @@ def GraphDef_Update(meta_graph, updater): ...@@ -116,7 +118,8 @@ def GraphDef_Update(meta_graph, updater):
meta_graph.arg.add().CopyFrom(MakeArgument(k, v)) meta_graph.arg.add().CopyFrom(MakeArgument(k, v))
for tuple in updater._tuples: for tuple in updater._tuples:
tensors = tuple[0]; arguments = tuple[1] tensors = tuple[0];
arguments = tuple[1]
kwargs = dict(arguments, **extra_arguments) kwargs = dict(arguments, **extra_arguments)
u_target = pb.UpdateTarget() u_target = pb.UpdateTarget()
u_target.type = updater._type u_target.type = updater._type
...@@ -226,16 +229,21 @@ def function(inputs=None, outputs=None, givens=None, updater=None): ...@@ -226,16 +229,21 @@ def function(inputs=None, outputs=None, givens=None, updater=None):
""" """
if not isinstance(inputs, list): if not isinstance(inputs, list):
if inputs is None: inputs = [] if inputs is None:
else: inputs = [inputs] inputs = []
else:
inputs = [inputs]
if not isinstance(outputs, list): if not isinstance(outputs, list):
if outputs is None: outputs = [] if outputs is None:
else: outputs = [outputs] outputs = []
else:
outputs = [outputs]
if len(outputs) > 0 and updater is not None: if len(outputs) > 0 and updater is not None:
raise RuntimeError('You can specify either outputs or updater, not both.') raise RuntimeError('You can specify either outputs or updater, not both.')
all_exprs = {}; all_extra_targets = set() all_exprs = {};
all_extra_targets = set()
if not isinstance(outputs, list): outputs = [outputs] if not isinstance(outputs, list): outputs = [outputs]
meta_graph = pb.GraphDef() meta_graph = pb.GraphDef()
...@@ -256,8 +264,8 @@ def function(inputs=None, outputs=None, givens=None, updater=None): ...@@ -256,8 +264,8 @@ def function(inputs=None, outputs=None, givens=None, updater=None):
for extra_target in all_extra_targets: meta_graph.target.extend([extra_target]) for extra_target in all_extra_targets: meta_graph.target.extend([extra_target])
# we should sort out the topology of these operators before using them # we should sort out the topology of these operators before using them
all_exprs = sorted(all_exprs.items(), key=lambda d:d[0]) all_exprs = sorted(all_exprs.items(), key=lambda d: d[0])
forward_ops = copy.deepcopy([v for k,v in all_exprs]) forward_ops = copy.deepcopy([v for k, v in all_exprs])
# handle givens # handle givens
if givens is not None: if givens is not None:
...@@ -271,12 +279,13 @@ def function(inputs=None, outputs=None, givens=None, updater=None): ...@@ -271,12 +279,13 @@ def function(inputs=None, outputs=None, givens=None, updater=None):
external_input_exprs = OrderedDict(external_input_exprs, **new_tensor.expressions) external_input_exprs = OrderedDict(external_input_exprs, **new_tensor.expressions)
else: else:
external_input_exprs = dict(external_input_exprs, **new_tensor.expressions) external_input_exprs = dict(external_input_exprs, **new_tensor.expressions)
elif isinstance(new_tensor, np.ndarray): ws.FeedTensor(new_tensor, GetTensorName()) elif isinstance(new_tensor, np.ndarray):
external_input_ops = [v for k,v in external_input_exprs.items()] ws.FeedTensor(new_tensor, GetTensorName())
external_input_ops = [v for k, v in external_input_exprs.items()]
for op in forward_ops: for op in forward_ops:
op.input.extend([name_dict[input] if input in name_dict op.input.extend([name_dict[input] if input in name_dict
else input for input in op.input]) else input for input in op.input])
del op.input[:int(len(op.input)/2)] del op.input[:int(len(op.input) / 2)]
forward_ops = external_input_ops + forward_ops forward_ops = external_input_ops + forward_ops
...@@ -285,7 +294,8 @@ def function(inputs=None, outputs=None, givens=None, updater=None): ...@@ -285,7 +294,8 @@ def function(inputs=None, outputs=None, givens=None, updater=None):
targets = [output.name for output in outputs] targets = [output.name for output in outputs]
targets.extend(all_extra_targets) targets.extend(all_extra_targets)
forward_ops, grad_ops = GraphGradientMaker.Make(forward_ops, targets) forward_ops, grad_ops = GraphGradientMaker.Make(forward_ops, targets)
else: grad_ops = [] else:
grad_ops = []
meta_graph.op.extend(forward_ops + grad_ops) meta_graph.op.extend(forward_ops + grad_ops)
if len(outputs) > 0: if len(outputs) > 0:
...@@ -304,4 +314,36 @@ def function(inputs=None, outputs=None, givens=None, updater=None): ...@@ -304,4 +314,36 @@ def function(inputs=None, outputs=None, givens=None, updater=None):
# return a lambda that runs this graph # return a lambda that runs this graph
return lambda *args, **kwargs: \ return lambda *args, **kwargs: \
ws.RunGraph(meta_graph.name, (inputs, args), outputs, **kwargs) ws.RunGraph(meta_graph.name, (inputs, args), outputs, **kwargs)
\ No newline at end of file
def eval(self, feed_dict=None):
if not hasattr(self, '_eval_func'):
if feed_dict is not None:
self._eval_func = function(inputs=feed_dict.keys(), outputs=self)
else:
self._eval_func = function(outputs=self)
# cond.1: run by feeding
if feed_dict is not None:
# checking
for key, value in feed_dict.items():
if not isinstance(key, Tensor):
raise TypeError('The key of feed_dict should be a Tensor.')
if key.shape is not None:
if len(key.shape) != len(value.shape):
raise RuntimeError('The Tensor({}) was limited to {} dimensions, \
while feeding a value with {} dimensions.'.
format(key.name, len(key.shape), len(value.shape)))
for i in xrange(len(key.shape)):
if key.shape[i] is None: continue
if key.shape[i] != value.shape[i]:
raise RuntimeError('The shape of Tensor({}) was limited to ('.format(key.name) +
','.join([str(dim) for dim in key.shape]) + '), ' +
'while feeding a value with (' + ','.join([str(dim) for dim in value.shape]) + ').')
return self._eval_func(*feed_dict.values())
else:
# cond.2: run without feeding
return self._eval_func()
Tensor.eval = eval
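A minimal sketch of the monkey-patched ``Tensor.eval``, assuming the vision ops above are importable as shown (shapes illustrative, not part of the commit):
>>> import numpy as np
>>> x = Tensor().Variable()
>>> y = Pool2d(x, kernel_size=2, stride=2)
>>> y.eval(feed_dict={x: np.ones((1, 3, 4, 4), dtype=np.float32)})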
...@@ -37,7 +37,7 @@ def grad(cost, wrt, **kwargs): ...@@ -37,7 +37,7 @@ def grad(cost, wrt, **kwargs):
if not isinstance(wrt, list): wrt = [wrt] if not isinstance(wrt, list): wrt = [wrt]
for w in wrt: for w in wrt:
cost.grad_wrts.append(w.name) cost.grad_wrts.append(w.name)
w.grad_objs.append(cost.name) w.grad_objs.append(cost)
w_grad = Tensor(w.name + '_grad') w_grad = Tensor(w.name + '_grad')
w_grad.extra_targets.add(cost.name) w_grad.extra_targets.add(cost.name)
w_grad.expressions = cost.expressions w_grad.expressions = cost.expressions
......
...@@ -34,7 +34,7 @@ void PReluOp<Context>::RunOnDevice() { ...@@ -34,7 +34,7 @@ void PReluOp<Context>::RunOnDevice() {
dim = input(0).count(2); dim = input(0).count(2);
} else { } else {
channels = input(0).dim(-1); channels = input(0).dim(-1);
dim = input(0).count() / channels; dim = input(0).count(1) / channels;
} }
output(0)->ReshapeLike(input(0)); output(0)->ReshapeLike(input(0));
...@@ -95,7 +95,7 @@ void PReluGradientOp<Context>::RunOnDevice() { ...@@ -95,7 +95,7 @@ void PReluGradientOp<Context>::RunOnDevice() {
dim = input(0).count(2); dim = input(0).count(2);
} else { } else {
channels = input(0).dim(-1); channels = input(0).dim(-1);
dim = input(0).count() / channels; dim = input(0).count(1) / channels;
} }
output(0)->ReshapeLike(input(0)); output(0)->ReshapeLike(input(0));
......
...@@ -6,41 +6,35 @@ ...@@ -6,41 +6,35 @@
namespace dragon { namespace dragon {
template <class Context> template <typename T> template <class Context> template <typename T>
void BiasAddOp<Context>::NCHWRunWithType() { void BiasAddOp<Context>::RunWithType() {
outer_dim = input(0).dim(0);
dim = input(0).dim(1);
inner_dim = input(0).count(2);
TENSOR_FILL(input(1), vector<TIndex>(1, dim)); TENSOR_FILL(input(1), vector<TIndex>(1, dim));
INIT_MULTIPLIER(bias_multiplier, inner_dim); INIT_MULTIPLIER(bias_multiplier, inner_dim);
auto* Bdata = input(1).template data<T, Context>(); auto* Bdata = input(1).template data<T, Context>();
auto* BMul_data = bias_multiplier->template data<T, Context>(); auto* BMul_data = bias_multiplier->template data<T, Context>();
auto* Ydata = output(0)->template mutable_data<T, Context>(); auto* Ydata = output(0)->template mutable_data<T, Context>();
kernel::BiasAdd<T, Context>(output(0)->count(), outer_dim, input(1).count(), kernel::BiasAdd<T, Context>(output(0)->count(), outer_dim, dim, inner_dim,
inner_dim, data_format, Bdata, BMul_data, Ydata); data_format,
} Bdata,
BMul_data,
template <class Context> template <typename T> Ydata);
void BiasAddOp<Context>::NHWCRunWithType() {
NOT_IMPLEMENTED;
} }
template <class Context> template <class Context>
void BiasAddOp<Context>::RunOnDevice() { void BiasAddOp<Context>::RunOnDevice() {
if (data_format == "NCHW") {
outer_dim = input(0).dim(0);
dim = input(0).dim(1);
inner_dim = input(0).count(2);
} else if (data_format == "NHWC") {
outer_dim = input(0).dim(0);
dim = input(0).dim(-1);
inner_dim = input(0).count(1) / dim;
} else LOG(FATAL) << "Unknown data format: " << data_format;
output(0)->ReshapeLike(input(0)); output(0)->ReshapeLike(input(0));
output(0)->Share(input(0)); output(0)->Share(input(0));
if (data_format == "NCHW") { if (input(0).template IsType<float>()) RunWithType<float>();
if (input(0).template IsType<float>()) NCHWRunWithType<float>(); else LOG(FATAL) << "Unsupported input types.";
else LOG(FATAL) << "Unsupported input types.";
}
else if (data_format == "NHWC") {
if (input(0).template IsType<float>()) NHWCRunWithType<float>();
else LOG(FATAL) << "Unsupported input types.";
}
else {
LOG(FATAL) << "Unknown data format: " << data_format;
}
} }
DEPLOY_CPU(BiasAdd); DEPLOY_CPU(BiasAdd);
...@@ -50,49 +44,52 @@ DEPLOY_CUDA(BiasAdd); ...@@ -50,49 +44,52 @@ DEPLOY_CUDA(BiasAdd);
OPERATOR_SCHEMA(BiasAdd).NumInputs(2).NumOutputs(1); OPERATOR_SCHEMA(BiasAdd).NumInputs(2).NumOutputs(1);
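What the forward kernel computes, sketched in NumPy under the (outer_dim, dim, inner_dim) decomposition above (not the actual kernel):

import numpy as np

def bias_add(x, bias, data_format='NCHW'):
    # NCHW: a length-C bias broadcasts over N and the spatial axes
    if data_format == 'NCHW':
        return x + bias.reshape(1, -1, 1, 1)
    # NHWC: the bias aligns with the trailing channel axis
    return x + bias.reshape(1, 1, 1, -1)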
template <class Context> template <typename T> template <class Context> template <typename T>
void BiasAddGradientOp<Context>::NCHWRunWithType() { void BiasAddGradientOp<Context>::RunWithType() {
if (output(1)->name() != "ignore") { if (output(1)->name() != "ignore") {
outer_dim = input(0).dim(0);
dim = input(0).dim(1);
inner_dim = input(0).count(2);
output(1)->ReshapeLike(input(1));
INIT_MULTIPLIER(bias_multiplier, inner_dim); INIT_MULTIPLIER(bias_multiplier, inner_dim);
auto* BMul_data = this->bias_multiplier->template data<T, Context>(); auto* BMul_data = this->bias_multiplier->template data<T, Context>();
auto* dYdata = input(-1).template data<T, Context>(); auto* dYdata = input(-1).template data<T, Context>();
auto* dBias = output(1)->template mutable_data<T, Context>(); auto* dBias = output(1)->template mutable_data<T, Context>();
const int y_offset = dim * inner_dim; const int y_offset = dim * inner_dim;
for (int n = 0; n < outer_dim; n++) { for (int n = 0; n < outer_dim; n++) {
math::Gemv<T, Context>(CblasNoTrans, dim, inner_dim, if (data_format == "NCHW") {
1.0, dYdata, BMul_data, 1.0, dBias); math::Gemv<T, Context>(CblasNoTrans, dim, inner_dim,
1.0,
dYdata, BMul_data,
1.0,
dBias);
} else if (data_format == "NHWC") {
math::Gemv<T, Context>(CblasTrans, inner_dim, dim,
1.0,
dYdata, BMul_data,
1.0,
dBias);
}
dYdata += y_offset; dYdata += y_offset;
} }
} }
}
template <class Context> template <typename T> if (output(0)->name() != "ignore") {
void BiasAddGradientOp<Context>::NHWCRunWithType() { output(0)->ReshapeLike(input(-1));
NOT_IMPLEMENTED; output(0)->Share(input(-1));
}
} }
template <class Context> template <class Context>
void BiasAddGradientOp<Context>::RunOnDevice() { void BiasAddGradientOp<Context>::RunOnDevice() {
if (data_format == "NCHW") { if (data_format == "NCHW") {
if (input(0).template IsType<float>()) NCHWRunWithType<float>(); outer_dim = input(0).dim(0);
else LOG(FATAL) << "Unsupported input types."; dim = input(0).dim(1);
} inner_dim = input(0).count(2);
else if (data_format == "NHWC") { } else if (data_format == "NHWC") {
if (input(0).template IsType<float>()) NHWCRunWithType<float>(); outer_dim = input(0).dim(0);
else LOG(FATAL) << "Unsupported input types."; dim = input(0).dim(-1);
} inner_dim = input(0).count(1) / dim;
else { } else LOG(FATAL) << "Unknown data format: " << data_format;
LOG(FATAL) << "Unknown data format: " << data_format; output(1)->ReshapeLike(input(1));
}
if (output(0)->name() != "ignore") { if (input(0).template IsType<float>()) RunWithType<float>();
output(0)->ReshapeLike(input(-1)); else LOG(FATAL) << "Unsupported input types.";
output(0)->Share(input(-1));
}
} }
DEPLOY_CPU(BiasAddGradient); DEPLOY_CPU(BiasAddGradient);
......
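The per-sample Gemv calls above accumulate the bias gradient by multiplying dY with the ones vector (bias_multiplier); as a NumPy sketch, the whole reduction is:

import numpy as np

def bias_grad(dy, data_format='NCHW'):
    # reduce over the batch and spatial axes, keeping the channel axis
    axes = (0, 2, 3) if data_format == 'NCHW' else (0, 1, 2)
    return dy.sum(axis=axes)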
#include "operators/cast/float2half_op.h"
#include "core/workspace.h"
#include "utils/op_kernel.h"
namespace dragon {
#ifdef WITH_CUDA_FP16
template <class Context>
void FloatToHalfOp<Context>::RunOnDevice() {
CHECK(input(0).template IsType<float>())
<< "The type of tensor should be float32.";
output(0)->ReshapeLike(input(0));
// cast
auto* Xdata = input(0).template data<float, Context>();
auto* Ydata = output(0)->template mutable_data<float16, Context>();
kernel::Float2Half<float, Context>(output(0)->count(), Xdata, Ydata);
// release & share
input(0).Reset();
input(0).ReshapeLike(*output(0));
input(0).Share(*output(0));
}
#ifdef WITH_CUDA
DEPLOY_CUDA(FloatToHalf);
#endif
OPERATOR_SCHEMA(FloatToHalf).NumInputs(1).NumOutputs(1);
NO_GRADIENT(FloatToHalf);
#endif
} // namespace dragon
\ No newline at end of file
...@@ -17,11 +17,11 @@ void SoftmaxCrossEntropyOp<Context>::RunWithType() { ...@@ -17,11 +17,11 @@ void SoftmaxCrossEntropyOp<Context>::RunWithType() {
if (normalization == "UNIT") { if (normalization == "UNIT") {
output(0)->Reshape(vector<TIndex>(1, outer_dim * inner_dim)); output(0)->Reshape(vector<TIndex>(1, outer_dim * inner_dim));
auto* Ydata = output(0)->template mutable_data<T, Context>(); auto* Ydata = output(0)->template mutable_data<T, Context>();
kernel::Sum<T, Context>(losses.count(), kernel::Sum<T, Context>(outer_dim * inner_dim,
input(0).dim(axis), input(0).dim(axis),
inner_dim, inner_dim,
Ldata, Ldata,
Ydata); Ydata);
return; return;
} }
...@@ -65,12 +65,12 @@ void SoftmaxCrossEntropyGradientOp<Context>::RunWithType() { ...@@ -65,12 +65,12 @@ void SoftmaxCrossEntropyGradientOp<Context>::RunWithType() {
if (normalization == "UNIT") { if (normalization == "UNIT") {
auto* dYdata = input(-1).template data<T, Context>(); auto* dYdata = input(-1).template data<T, Context>();
kernel::SumGrad<T, Context>(input(0).count() / input(0).dim(axis), kernel::SumGrad<T, Context>(outer_dim * inner_dim,
input(0).dim(axis), input(0).dim(axis),
inner_dim, inner_dim,
1.0, 1.0,
dYdata, dYdata,
Pdata); Pdata);
math::Mul<T, Context>(output(0)->count(), Pdata, dXdata, dXdata); math::Mul<T, Context>(output(0)->count(), Pdata, dXdata, dXdata);
return; return;
} }
......
#include "operators/misc/image_data_op.h"
#include "utils/op_kernel.h"
namespace dragon {
template <class Context> template <typename Tx, typename Ty>
void ImageDataOp<Context>::RunWithType() {
auto* Xdata = input(0).template data<Tx, Context>();
auto* Mdata = mean.count() > 0 ? mean.template data<float, Context>() : nullptr;
auto* Sdata = std.count() > 0 ? std.template data<float, Context>() : nullptr;
auto* Ydata = output(0)->template mutable_data<Ty, Context>();
kernel::ImageData<Tx, Ty, Context>(output(0)->count(),
n, c, h, w,
Mdata, Sdata,
data_format,
Xdata,
Ydata);
}
template <class Context>
void ImageDataOp<Context>::RunOnDevice() {
n = input(0).dim(0);
c = input(0).dim(3);
h = input(0).dim(1);
w = input(0).dim(2);
if (data_format == "NCHW") {
output(0)->Reshape(vector<TIndex>({ n, c, h, w }));
} else if (data_format == "NHWC") {
output(0)->ReshapeLike(input(0));
} else LOG(FATAL) << "Unknown data format: " << data_format;
if (input(0).template IsType<float>()) {
if (dtype == "FLOAT32") RunWithType<float, float>();
#ifdef WITH_CUDA_FP16
else if (dtype == "FLOAT16") RunWithType<float, float16>();
#endif
else LOG(FATAL) << "Unsupported output type: " << dtype;
} else if (input(0).template IsType<uint8_t>()) {
if (dtype == "FLOAT32") RunWithType<uint8_t, float>();
#ifdef WITH_CUDA_FP16
else if (dtype == "FLOAT16") RunWithType<uint8_t, float16>();
#endif
else LOG(FATAL) << "Unsupported output type: " << dtype;
}
else { LOG(FATAL) << "Unsupported input types."; }
}
DEPLOY_CPU(ImageData);
#ifdef WITH_CUDA
DEPLOY_CUDA(ImageData);
#endif
OPERATOR_SCHEMA(ImageData).NumInputs(1).NumOutputs(1);
NO_GRADIENT(ImageData);
} // namespace dragon
\ No newline at end of file
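A rough NumPy analogue of what the new op does (the per-channel mean/std handling is inferred from the ``Mdata``/``Sdata`` pointers above and may differ from the real kernel):

import numpy as np

def image_data(x_nhwc, mean=None, std=None, data_format='NCHW', dtype=np.float32):
    y = x_nhwc.astype(np.float32)
    if mean is not None: y -= np.asarray(mean).reshape(1, 1, 1, -1)
    if std is not None: y /= np.asarray(std).reshape(1, 1, 1, -1)
    if data_format == 'NCHW':
        y = y.transpose(0, 3, 1, 2)  # NHWC -> NCHW
    return y.astype(dtype)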
#include "operators/misc/memory_data_op.h"
#include "utils/op_kernel.h"
namespace dragon {
template <class Context> template <typename Tx, typename Ty>
void MemoryDataOp<Context>::RunWithType() {
auto* Xdata = input(0).template data<Tx, Context>();
auto* Ydata = output(0)->template mutable_data<Ty, Context>();
kernel::MemoryData<Tx, Ty, Context>(output(0)->count(),
output(0)->dim(0),
output(0)->dim(1),
output(0)->dim(2),
output(0)->dim(3),
Xdata, Ydata);
}
template <class Context>
void MemoryDataOp<Context>::RunOnDevice() {
vector<TIndex> dims({ input(0).dim(0), input(0).dim(3),
input(0).dim(1), input(0).dim(2) });
output(0)->Reshape(dims);
if (input(0).template IsType<float>()) {
if (data_type == TensorProto_DataType_FLOAT) RunWithType<float, float>();
#ifdef WITH_CUDA_FP16
else if (data_type == TensorProto_DataType_FLOAT16) RunWithType<float, float16>();
#endif
else LOG(FATAL) << "Unsupported input types.";
}
else if (input(0).template IsType<uint8_t>()) {
if (data_type == TensorProto_DataType_FLOAT) RunWithType<uint8_t, float>();
#ifdef WITH_CUDA_FP16
if (data_type == TensorProto_DataType_FLOAT16) RunWithType<uint8_t, float16>();
#endif
}
else { LOG(FATAL) << "Unsupported input types."; }
}
DEPLOY_CPU(MemoryData);
#ifdef WITH_CUDA
DEPLOY_CUDA(MemoryData);
#endif
OPERATOR_SCHEMA(MemoryData).NumInputs(1).NumOutputs(1);
NO_GRADIENT(MemoryData);
} // namespace dragon
\ No newline at end of file
...@@ -13,7 +13,7 @@ void MPIBroadcastOp<Context>::RunWithType() { ...@@ -13,7 +13,7 @@ void MPIBroadcastOp<Context>::RunWithType() {
#else #else
auto* Xdata = input(0).template mutable_data<T, CPUContext>(); auto* Xdata = input(0).template mutable_data<T, CPUContext>();
#endif #endif
MPI_Bcast(Xdata, input(0).count(), MPI_FLOAT, this->comm_root, this->comm); MPI_Bcast(Xdata, input(0).count(), mpi_dtype(), this->comm_root, this->comm);
output(0)->Share(input(0)); output(0)->Share(input(0));
} else { } else {
#ifdef WITH_MPI_CUDA #ifdef WITH_MPI_CUDA
...@@ -21,7 +21,7 @@ void MPIBroadcastOp<Context>::RunWithType() { ...@@ -21,7 +21,7 @@ void MPIBroadcastOp<Context>::RunWithType() {
#else #else
auto* Ydata = output(0)->template mutable_data<T, CPUContext>(); auto* Ydata = output(0)->template mutable_data<T, CPUContext>();
#endif #endif
MPI_Bcast(Ydata, output(0)->count(), MPI_FLOAT, this->comm_root, this->comm); MPI_Bcast(Ydata, output(0)->count(), mpi_dtype(), this->comm_root, this->comm);
} }
} }
...@@ -41,13 +41,13 @@ void MPIBroadcastOp<Context>::RunOnDevice() { ...@@ -41,13 +41,13 @@ void MPIBroadcastOp<Context>::RunOnDevice() {
} }
MPI_Bcast(ndim, 1, MPI_UNSIGNED_LONG_LONG, this->comm_root, this->comm); MPI_Bcast(ndim, 1, MPI_UNSIGNED_LONG_LONG, this->comm_root, this->comm);
if (dims == nullptr) dims = new TIndex[ndim[0]]; if (dims == nullptr) dims = new TIndex[ndim[0]];
MPI_Bcast(dims, 4, MPI_LONG_LONG, this->comm_root, this->comm); MPI_Bcast(dims, (int)ndim[0], MPI_LONG_LONG, this->comm_root, this->comm);
vector<TIndex> _dims; vector<TIndex> _dims;
for (int i = 0; i < ndim[0]; i++) _dims.push_back(dims[i]); for (int i = 0; i < (int)ndim[0]; i++) _dims.push_back(dims[i]);
output(0)->Reshape(_dims); output(0)->Reshape(_dims);
if (input(0).template IsType<float>()) RunWithType<float>(); if (this->dtype == "FLOAT32") RunWithType<float>();
else LOG(FATAL) << "Unsupported input types."; else LOG(FATAL) << "Unsupported input type: " << this->dtype;
} }
DEPLOY_CPU(MPIBroadcast); DEPLOY_CPU(MPIBroadcast);
...@@ -71,7 +71,7 @@ void MPIBroadcastGradientOp<Context>::RunWithType() { ...@@ -71,7 +71,7 @@ void MPIBroadcastGradientOp<Context>::RunWithType() {
#endif #endif
for (int i = 0; i < this->comm_size; i++) { for (int i = 0; i < this->comm_size; i++) {
if (i == this->comm_root) continue; if (i == this->comm_root) continue;
MPI_Recv(dYdata, output(0)->count(), MPI_FLOAT, i, 0, this->comm, MPI_STATUS_IGNORE); MPI_Recv(dYdata, output(0)->count(), mpi_dtype(), i, 0, this->comm, MPI_STATUS_IGNORE);
#ifdef WITH_MPI_CUDA #ifdef WITH_MPI_CUDA
math::Add<T, Context>(output(0)->count(), dYdata, dXdata, dXdata); math::Add<T, Context>(output(0)->count(), dYdata, dXdata, dXdata);
#else #else
...@@ -85,7 +85,7 @@ void MPIBroadcastGradientOp<Context>::RunWithType() { ...@@ -85,7 +85,7 @@ void MPIBroadcastGradientOp<Context>::RunWithType() {
#else #else
auto* dYdata = input(-1).template data<T, CPUContext>(); auto* dYdata = input(-1).template data<T, CPUContext>();
#endif #endif
MPI_Send(dYdata, input(-1).count(), MPI_FLOAT, this->comm_root, 0, this->comm); MPI_Send(dYdata, input(-1).count(), mpi_dtype(), this->comm_root, 0, this->comm);
} }
} }
...@@ -93,10 +93,10 @@ template <class Context> ...@@ -93,10 +93,10 @@ template <class Context>
void MPIBroadcastGradientOp<Context>::RunOnDevice() { void MPIBroadcastGradientOp<Context>::RunOnDevice() {
output(0)->ReshapeLike(input(-1)); output(0)->ReshapeLike(input(-1));
if (input(0).template IsType<float>()) RunWithType<float>(); if (this->dtype == "FLOAT32") RunWithType<float>();
else LOG(FATAL) << "Unsupported input types."; else LOG(FATAL) << "Unsupported input type: " << this->dtype;
} }
DEPLOY_CPU(MPIBroadcastGradient); DEPLOY_CPU(MPIBroadcastGradient);
#ifdef WITH_CUDA #ifdef WITH_CUDA
DEPLOY_CUDA(MPIBroadcastGradient); DEPLOY_CUDA(MPIBroadcastGradient);
......
...@@ -16,29 +16,50 @@ void MPIGatherOp<Context>::RunWithType() { ...@@ -16,29 +16,50 @@ void MPIGatherOp<Context>::RunWithType() {
#else #else
auto* Ydata = output(i)->template mutable_data<T, CPUContext>(); auto* Ydata = output(i)->template mutable_data<T, CPUContext>();
#endif #endif
MPI_Recv(Ydata, output(i)->count(), MPI_FLOAT, i, 0, this->comm, MPI_STATUS_IGNORE); MPI_Recv(Ydata, output(i)->count(), mpi_dtype(), i, 0, this->comm, MPI_STATUS_IGNORE);
} }
} }
else{ else {
#ifdef WITH_MPI_CUDA #ifdef WITH_MPI_CUDA
auto* Xdata = input(0).template data<T, Context>(); auto* Xdata = input(0).template data<T, Context>();
#else #else
auto* Xdata = input(0).template data<T, CPUContext>(); auto* Xdata = input(0).template data<T, CPUContext>();
#endif #endif
MPI_Send(Xdata, input(0).count(), MPI_FLOAT, this->comm_root, 0, this->comm); MPI_Send(Xdata, input(0).count(), mpi_dtype(), this->comm_root, 0, this->comm);
} }
} }
template <class Context> template <class Context>
void MPIGatherOp<Context>::RunOnDevice() { void MPIGatherOp<Context>::RunOnDevice() {
if (this->comm_rank == this->comm_root) { CHECK_EQ(this->comm_size, OutputSize());
CHECK_EQ(this->comm_size, OutputSize()); // reshape from root
for (int i = 0; i < OutputSize(); i++) if (this->comm_rank == this->comm_root)
output(i)->ReshapeLike(input(0)); output(0)->ReshapeLike(input(0));
// reshape from others
size_t* all_ndim = new size_t[this->comm_size];
size_t ndim[1];
if (this->comm_rank != this->comm_root) {
ndim[0] = input(0).ndim();
MPI_Send(ndim, 1, MPI_UNSIGNED_LONG_LONG, this->comm_root, 0, this->comm);
} else {
for (int i = 1; i < this->comm_size; i++)
MPI_Recv(all_ndim + i, 1, MPI_UNSIGNED_LONG_LONG, i, 0, this->comm, MPI_STATUS_IGNORE);
}
if (this->comm_rank != this->comm_root) {
MPI_Send(input(0).dims().data(), (int)ndim[0], MPI_LONG_LONG, this->comm_root, 0, this->comm);
} else {
for (int i = 1; i < this->comm_size; i++) {
TIndex* dims = new TIndex[all_ndim[i]];
MPI_Recv(dims, (int)all_ndim[i], MPI_LONG_LONG, i, 0, this->comm, MPI_STATUS_IGNORE);
vector<TIndex> dims_;
for (int j = 0; j < (int)all_ndim[i]; j++) dims_.push_back(dims[j]);
output(i)->Reshape(dims_);
}
} }
if (input(0).template IsType<float>()) RunWithType<float>(); if (this->dtype == "FLOAT32") RunWithType<float>();
else LOG(FATAL) << "Unsupported input types."; else LOG(FATAL) << "Unsupported input type: " << this->dtype;
} }
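The shape handshake added above (each non-root rank reports ndim and dims so the root can reshape its outputs) follows this pattern, sketched here with mpi4py purely for illustration:

from mpi4py import MPI

comm, root = MPI.COMM_WORLD, 0
local_dims = [4, 3, 32, 32]  # illustrative; in the op this is input(0).dims()
if comm.Get_rank() != root:
    comm.send(local_dims, dest=root, tag=0)       # worker reports its shape
else:
    shapes = {root: local_dims}
    for i in range(1, comm.Get_size()):
        shapes[i] = comm.recv(source=i, tag=0)    # root reshapes output(i) accordingly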
DEPLOY_CPU(MPIGather); DEPLOY_CPU(MPIGather);
...@@ -58,7 +79,7 @@ void MPIGatherGradientOp<Context>::RunWithType() { ...@@ -58,7 +79,7 @@ void MPIGatherGradientOp<Context>::RunWithType() {
#else #else
auto* dYdata = input(this->comm_rank + 1).template data<T, CPUContext>(); auto* dYdata = input(this->comm_rank + 1).template data<T, CPUContext>();
#endif #endif
MPI_Send(dYdata, input(this->comm_rank + 1).count(), MPI_FLOAT, i, 0, this->comm); MPI_Send(dYdata, input(this->comm_rank + 1).count(), mpi_dtype(), i, 0, this->comm);
} }
} }
else{ else{
...@@ -67,7 +88,7 @@ void MPIGatherGradientOp<Context>::RunWithType() { ...@@ -67,7 +88,7 @@ void MPIGatherGradientOp<Context>::RunWithType() {
#else #else
auto* dXdata = output(0)->template mutable_data<T, CPUContext>(); auto* dXdata = output(0)->template mutable_data<T, CPUContext>();
#endif #endif
MPI_Recv(dXdata, output(0)->count(), MPI_FLOAT, this->comm_root, 0, this->comm, MPI_STATUS_IGNORE); MPI_Recv(dXdata, output(0)->count(), mpi_dtype(), this->comm_root, 0, this->comm, MPI_STATUS_IGNORE);
} }
} }
...@@ -75,8 +96,8 @@ template <class Context> ...@@ -75,8 +96,8 @@ template <class Context>
void MPIGatherGradientOp<Context>::RunOnDevice() { void MPIGatherGradientOp<Context>::RunOnDevice() {
output(0)->ReshapeLike(input(0)); output(0)->ReshapeLike(input(0));
if (input(0).template IsType<float>()) RunWithType<float>(); if (this->dtype == "FLOAT32") RunWithType<float>();
else LOG(FATAL) << "Unsupported input types."; else LOG(FATAL) << "Unsupported input type: " << this->dtype;
} }
DEPLOY_CPU(MPIGatherGradient); DEPLOY_CPU(MPIGatherGradient);
......
...@@ -15,7 +15,8 @@ void CropOp<Context>::RunWithType() { ...@@ -15,7 +15,8 @@ void CropOp<Context>::RunWithType() {
inner_dim, inner_dim,
starts[axis], starts[axis],
Xdata, Xdata,
Ydata); Ydata,
&ctx());
} }
template <class Context> template <class Context>
...@@ -219,7 +220,6 @@ template <class Context> template <typename T> ...@@ -219,7 +220,6 @@ template <class Context> template <typename T>
void CropGradientOp<Context>::RunWithType() { void CropGradientOp<Context>::RunWithType() {
auto* dYdata = source->template data<T, Context>(); auto* dYdata = source->template data<T, Context>();
auto* dXdata = dest->template mutable_data<T, Context>(); auto* dXdata = dest->template mutable_data<T, Context>();
math::Set<T, Context>(dest->count(), 0, dXdata);
kernel::Crop1DGrad<T, Context>(dest->count(), kernel::Crop1DGrad<T, Context>(dest->count(),
input(0).dim(axis), input(0).dim(axis),
dim, dim,
...@@ -227,7 +227,8 @@ void CropGradientOp<Context>::RunWithType() { ...@@ -227,7 +227,8 @@ void CropGradientOp<Context>::RunWithType() {
starts[axis], starts[axis],
ends[axis], ends[axis],
dYdata, dYdata,
dXdata); dXdata,
&ctx());
} }
template <class Context> template <class Context>
......
...@@ -16,7 +16,8 @@ void PadOp<Context>::ConstRunWithType() { ...@@ -16,7 +16,8 @@ void PadOp<Context>::ConstRunWithType() {
pad_l[axis], pad_l[axis],
value, value,
Xdata, Xdata,
Ydata); Ydata,
&ctx());
} }
template <class Context> template <typename T> template <class Context> template <typename T>
...@@ -29,7 +30,8 @@ void PadOp<Context>::ReflectRunWithType() { ...@@ -29,7 +30,8 @@ void PadOp<Context>::ReflectRunWithType() {
inner_dim, inner_dim,
pad_l[axis], pad_l[axis],
Xdata, Xdata,
Ydata); Ydata,
&ctx());
} }
template <class Context> template <typename T> template <class Context> template <typename T>
...@@ -42,7 +44,8 @@ void PadOp<Context>::EdgeRunWithType() { ...@@ -42,7 +44,8 @@ void PadOp<Context>::EdgeRunWithType() {
inner_dim, inner_dim,
pad_l[axis], pad_l[axis],
Xdata, Xdata,
Ydata); Ydata,
&ctx());
} }
template <class Context> template <class Context>
...@@ -109,14 +112,14 @@ template <class Context> template <typename T> ...@@ -109,14 +112,14 @@ template <class Context> template <typename T>
void PadGradientOp<Context>::ConstRunWithType() { void PadGradientOp<Context>::ConstRunWithType() {
auto* dYdata = source->template data<T, Context>(); auto* dYdata = source->template data<T, Context>();
auto* dXdata = dest->template mutable_data<T, Context>(); auto* dXdata = dest->template mutable_data<T, Context>();
math::Set<T, Context>(dest->count(), 0, dXdata);
kernel::ConstPad1DGrad<T, Context>(dest->count(), kernel::ConstPad1DGrad<T, Context>(dest->count(),
dim - pad_l[axis] - pad_r[axis], dim - pad_l[axis] - pad_r[axis],
dim, dim,
inner_dim, inner_dim,
pad_l[axis], pad_l[axis],
dYdata, dYdata,
dXdata); dXdata,
&ctx());
} }
template <class Context> template <typename T> template <class Context> template <typename T>
...@@ -144,7 +147,8 @@ void PadGradientOp<Context>::EdgeRunWithType() { ...@@ -144,7 +147,8 @@ void PadGradientOp<Context>::EdgeRunWithType() {
inner_dim, inner_dim,
pad_l[axis], pad_l[axis],
dYdata, dYdata,
dXdata); dXdata,
&ctx());
} }
template <class Context> template <class Context>
......
...@@ -7,13 +7,28 @@ namespace dragon { ...@@ -7,13 +7,28 @@ namespace dragon {
template <class Context> template <typename T> template <class Context> template <typename T>
void BilinearResizeOp<Context>::RunWithType() { void BilinearResizeOp<Context>::RunWithType() {
if (data_format == "NCHW") {
n = dims[0];
c = dims[1];
h = input(0).dim(2);
w = input(0).dim(3);
out_h = dims[2];
out_w = dims[3];
} else if (data_format == "NHWC") {
n = dims[0];
h = input(0).dim(1);
w = input(0).dim(2);
out_h = dims[1];
out_w = dims[2];
c = dims[3];
}
auto* Xdata = input(0).template data<T, Context>(); auto* Xdata = input(0).template data<T, Context>();
auto* Ydata = output(0)->template mutable_data<T, Context>(); auto* Ydata = output(0)->template mutable_data<T, Context>();
kernel::BilinearResize<T, Context>(output(0)->count(), dims[0], dims[1], kernel::BilinearResize<T, Context>(output(0)->count(), n, c, h, w,
input(0).dim(2), input(0).dim(3), out_h, out_w,
dims[2], dims[3], data_format,
Xdata, Xdata,
Ydata); Ydata);
} }
template <class Context> template <class Context>
...@@ -25,9 +40,9 @@ void BilinearResizeOp<Context>::RunOnDevice() { ...@@ -25,9 +40,9 @@ void BilinearResizeOp<Context>::RunOnDevice() {
for (int i = 0; i < 2; i++) { for (int i = 0; i < 2; i++) {
Tensor* t = ws()->GetTensor(dynamic_dsize[i]); Tensor* t = ws()->GetTensor(dynamic_dsize[i]);
if (t->IsType<int>()) { if (t->IsType<int>()) {
dims[2 + i] = t->template data<int, CPUContext>()[0]; dims[spatial_axis + i] = t->template data<int, CPUContext>()[0];
} else if (t->IsType<float>()) { } else if (t->IsType<float>()) {
dims[2 + i] = t->template data<float, CPUContext>()[0]; dims[spatial_axis + i] = t->template data<float, CPUContext>()[0];
} else { } else {
LOG(FATAL) << "Unsupported types of dsize."; LOG(FATAL) << "Unsupported types of dsize.";
} }
...@@ -35,12 +50,12 @@ void BilinearResizeOp<Context>::RunOnDevice() { ...@@ -35,12 +50,12 @@ void BilinearResizeOp<Context>::RunOnDevice() {
} else if (static_dsize.size() > 0) { } else if (static_dsize.size() > 0) {
CHECK_EQ(static_dsize.size(), 2) CHECK_EQ(static_dsize.size(), 2)
<< "\nThe dsize should be a scalar with 2 elements."; << "\nThe dsize should be a scalar with 2 elements.";
for (int i = 0; i < 2; i++) dims[2 + i] = static_dsize[i]; for (int i = 0; i < 2; i++) dims[spatial_axis + i] = static_dsize[i];
} else { } else {
CHECK(fy != -1.0 && fx != -1.0) CHECK(fy != -1.0 && fx != -1.0)
<< "\nThe fx and fy should be set."; << "\nThe fx and fy should be set.";
dims[2] = int(dims[2] * fy); dims[spatial_axis] = int(dims[spatial_axis] * fy);
dims[3] = int(dims[3] * fx); dims[spatial_axis + 1] = int(dims[spatial_axis + 1] * fx);
} }
output(0)->Reshape(dims); output(0)->Reshape(dims);
...@@ -56,14 +71,28 @@ OPERATOR_SCHEMA(BilinearResize).NumInputs(1).NumOutputs(1); ...@@ -56,14 +71,28 @@ OPERATOR_SCHEMA(BilinearResize).NumInputs(1).NumOutputs(1);
template <class Context> template <typename T> template <class Context> template <typename T>
void BilinearResizeGradientOp<Context>::RunWithType() { void BilinearResizeGradientOp<Context>::RunWithType() {
if (data_format == "NCHW") {
n = input(0).dim(0);
c = input(0).dim(1);
h = input(0).dim(2);
w = input(0).dim(3);
out_h = input(-1).dim(2);
out_w = input(-1).dim(3);
} else if (data_format == "NHWC") {
n = input(0).dim(0);
h = input(0).dim(1);
w = input(0).dim(2);
c = input(0).dim(3);
out_h = input(-1).dim(1);
out_w = input(-1).dim(2);
}
auto* dYdata = input(-1).template data<T, Context>(); auto* dYdata = input(-1).template data<T, Context>();
auto* dXdata = output(0)->template mutable_data<T, Context>(); auto* dXdata = output(0)->template mutable_data<T, Context>();
math::Set<T, Context>(output(0)->count(), 0, dXdata); kernel::BilinearResizeGrad<T, Context>(input(-1).count(), n, c, h, w,
kernel::BilinearResizeGrad<T, Context>(input(-1).count(), input(0).dim(0), input(0).dim(1), out_h, out_w,
input(-1).dim(2), input(-1).dim(3), data_format,
output(0)->dim(2), output(0)->dim(3), dYdata,
dYdata, dXdata);
dXdata);
} }
template <class Context> template <class Context>
......
...@@ -4,35 +4,24 @@ ...@@ -4,35 +4,24 @@
namespace dragon { namespace dragon {
template <class Context>
void ConvOp<Context>::ComputeOutputShape() {
this->output_shape.clear();
for (int i = 0; i < this->num_spatial_axes; i++) {
const int input_dim = this->bottom_shape[this->channel_axis + i + 1];
const int dilated_kernel = this->dilation[i] * (this->kernel_size[i] - 1) + 1;
const int output_dim = (input_dim + 2 * this->pad[i] - dilated_kernel) / this->stride[i] + 1;
this->output_shape.push_back(output_dim);
}
}
template <class Context> template <typename T> template <class Context> template <typename T>
void ConvOp<Context>::RunWithType() { void Conv2dOp<Context>::RunWithType() {
// get buffer // get buffer
this->col_buffer = ws()->GetBuffer(); this->col_buffer = ws()->GetBuffer();
this->col_buffer->Reshape(this->col_buffer_shape); this->col_buffer->Reshape(this->col_shape);
auto* Xdata = input(0).template data<T, Context>(); auto* Xdata = input(0).template data<T, Context>();
auto* Ydata = output(0)->template mutable_data<T, Context>(); auto* Ydata = output(0)->template mutable_data<T, Context>();
TENSOR_FILL(input(1), this->weight_shape); TENSOR_FILL(input(1), this->weight_shape);
auto* Wdata = input(1).template data<T, Context>(); auto* Wdata = input(1).template data<T, Context>();
if (InputSize() > 2) { if (HasBias()) {
TENSOR_FILL(input(2), this->bias_shape); TENSOR_FILL(input(2), this->bias_shape);
INIT_MULTIPLIER(this->bias_multiplier, this->out_spatial_dim); INIT_MULTIPLIER(this->bias_multiplier, this->out_spatial_dim);
} }
for (int n = 0; n < input(0).dim(0); n++) { for (int n = 0; n < input(0).dim(0); n++) {
Wx(Xdata + n * this->x_offset, Wdata, Ydata + n * this->y_offset); Wx(Xdata + n * this->x_offset, Wdata, Ydata + n * this->y_offset);
if (InputSize() > 2) { if (HasBias()) {
auto* Bdata = input(2).template data<T, Context>(); auto* Bdata = input(2).template data<T, Context>();
Pb(Bdata, Ydata + n * this->y_offset); Pb(Bdata, Ydata + n * this->y_offset);
} }
...@@ -43,28 +32,28 @@ void ConvOp<Context>::RunWithType() { ...@@ -43,28 +32,28 @@ void ConvOp<Context>::RunWithType() {
} }
template <class Context> template <class Context>
void ConvOp<Context>::RunOnDevice() { void Conv2dOp<Context>::RunOnDevice() {
Reshape(); Reshape();
if (input(0).template IsType<float>()) RunWithType<float>(); if (input(0).template IsType<float>()) RunWithType<float>();
else LOG(FATAL) << "Unsupported input types."; else LOG(FATAL) << "Unsupported input types.";
} }
DEPLOY_CPU(Conv); DEPLOY_CPU(Conv2d);
#ifdef WITH_CUDA #ifdef WITH_CUDA
DEPLOY_CUDA(Conv); DEPLOY_CUDA(Conv2d);
#endif #endif
OPERATOR_SCHEMA(Conv).NumInputs(2, 3).NumOutputs(1); OPERATOR_SCHEMA(Conv2d).NumInputs(2, 3).NumOutputs(1);
template <class Context> template <typename T> template <class Context> template <typename T>
void ConvGradientOp<Context>::RunWithType() { void Conv2dGradientOp<Context>::RunWithType() {
// get buffer // get buffer
this->col_buffer = ws()->GetBuffer(); this->col_buffer = ws()->GetBuffer();
this->col_buffer->Reshape(this->col_buffer_shape); this->col_buffer->Reshape(this->col_shape);
auto* dYdata = input(-1).template data<T, Context>(); auto* dYdata = input(-1).template data<T, Context>();
if (output(2)->name() != "ignore") { if (HasBias()) {
INIT_MULTIPLIER(this->bias_multiplier, this->out_spatial_dim); INIT_MULTIPLIER(this->bias_multiplier, this->out_spatial_dim);
T* dBdata = output(2)->template mutable_data<T, Context>(); T* dBdata = output(2)->template mutable_data<T, Context>();
for (int n = 0; n < input(2).dim(0); n++) for (int n = 0; n < input(2).dim(0); n++)
...@@ -89,28 +78,28 @@ void ConvGradientOp<Context>::RunWithType() { ...@@ -89,28 +78,28 @@ void ConvGradientOp<Context>::RunWithType() {
} }
template <class Context> template <class Context>
void ConvGradientOp<Context>::RunOnDevice() { void Conv2dGradientOp<Context>::RunOnDevice() {
GradientReshape(); GradientReshape();
if (input(0).template IsType<float>()) RunWithType<float>(); if (input(0).template IsType<float>()) RunWithType<float>();
else LOG(FATAL) << "Unsupported input types."; else LOG(FATAL) << "Unsupported input types.";
} }
DEPLOY_CPU(ConvGradient); DEPLOY_CPU(Conv2dGradient);
#ifdef WITH_CUDA #ifdef WITH_CUDA
DEPLOY_CUDA(ConvGradient); DEPLOY_CUDA(Conv2dGradient);
#endif #endif
OPERATOR_SCHEMA(ConvGradient).NumInputs(3).NumOutputs(3); OPERATOR_SCHEMA(Conv2dGradient).NumInputs(3).NumOutputs(3);
class GetConvGradient final : public GradientMakerBase { class GetConv2dGradient final : public GradientMakerBase {
public: public:
GRADIENT_MAKER_CTOR(GetConvGradient); GRADIENT_MAKER_CTOR(GetConv2dGradient);
vector<OperatorDef> MakeDefs() override { vector<OperatorDef> MakeDefs() override {
return SingleDef(def.type() + "Gradient", "", return SingleDef(def.type() + "Gradient", "",
vector<string> {I(0), I(1), GO(0)}, vector<string> {I(0), I(1), GO(0)},
vector<string> {GI(0), GI(1), GI(2)}); vector<string> {GI(0), GI(1), GI(2)});
} }
}; };
REGISTER_GRADIENT(Conv, GetConvGradient); REGISTER_GRADIENT(Conv2d, GetConv2dGradient);
} // namespace dragon } // namespace dragon
\ No newline at end of file
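For reference, the output spatial size that Conv2d's Reshape() is expected to produce follows standard convolution arithmetic. Below is a minimal standalone sketch of that formula (the function and parameter names are illustrative, not members of Conv2dOp):

// Hedged sketch: 2d convolution output-size arithmetic.
inline int ConvOutDim(int in_dim, int kernel, int stride, int pad, int dilation) {
    const int dilated_kernel = dilation * (kernel - 1) + 1;   // effective window size
    return (in_dim + 2 * pad - dilated_kernel) / stride + 1;
}
// e.g. ConvOutDim(224, 3, /*stride=*/1, /*pad=*/1, /*dilation=*/1) == 224
// e.g. ConvOutDim(224, 7, /*stride=*/2, /*pad=*/3, /*dilation=*/1) == 112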
#include "operators/vision/deconv_op.h" #include "operators/vision/conv_transpose_op.h"
#include "core/workspace.h" #include "core/workspace.h"
#include "utils/filler.h" #include "utils/filler.h"
namespace dragon { namespace dragon {
template <class Context>
void DeConvOp<Context>::ComputeOutputShape() {
this->output_shape.clear();
for (int i = 0; i < this->num_spatial_axes; i++) {
const int input_dim = this->bottom_shape[this->channel_axis + i + 1];
const int dilated_kernel = this->dilation[i] * (this->kernel_size[i] - 1) + 1;
const int output_dim = this->stride[i] * (input_dim - 1) + dilated_kernel - 2 * this->pad[i];
this->output_shape.push_back(output_dim);
}
}
template <class Context> template <typename T> template <class Context> template <typename T>
void DeConvOp<Context>::RunWithType() { void Conv2dTransposeOp<Context>::RunWithType() {
// get buffer // get buffer
this->col_buffer = ws()->GetBuffer(); this->col_buffer = ws()->GetBuffer();
this->col_buffer->Reshape(this->col_buffer_shape); this->col_buffer->Reshape(this->col_shape);
auto* Xdata = input(0).template data<T, Context>(); auto* Xdata = input(0).template data<T, Context>();
auto* Ydata = output(0)->template mutable_data<T, Context>(); auto* Ydata = output(0)->template mutable_data<T, Context>();
...@@ -43,24 +32,27 @@ void DeConvOp<Context>::RunWithType() { ...@@ -43,24 +32,27 @@ void DeConvOp<Context>::RunWithType() {
} }
template <class Context> template <class Context>
void DeConvOp<Context>::RunOnDevice() { void Conv2dTransposeOp<Context>::RunOnDevice() {
Reshape(); Reshape();
// fix the output shape for im2col/col2im
for (int i = 0; i < this->num_spatial_axes; i++)
this->output_shape[i] = input(0).dim(this->spatial_axis + i);
if (input(0).template IsType<float>()) RunWithType<float>(); if (input(0).template IsType<float>()) RunWithType<float>();
else LOG(FATAL) << "Unsupported input types."; else LOG(FATAL) << "Unsupported input types.";
} }
DEPLOY_CPU(DeConv); DEPLOY_CPU(Conv2dTranspose);
#ifdef WITH_CUDA #ifdef WITH_CUDA
DEPLOY_CUDA(DeConv); DEPLOY_CUDA(Conv2dTranspose);
#endif #endif
OPERATOR_SCHEMA(DeConv).NumInputs(2, 3).NumOutputs(1); OPERATOR_SCHEMA(Conv2dTranspose).NumInputs(2, 3).NumOutputs(1);
template <class Context> template <typename T> template <class Context> template <typename T>
void DeConvGradientOp<Context>::RunWithType() { void Conv2dTransposeGradientOp<Context>::RunWithType() {
// get buffer // get buffer
this->col_buffer = ws()->GetBuffer(); this->col_buffer = ws()->GetBuffer();
this->col_buffer->Reshape(this->col_buffer_shape); this->col_buffer->Reshape(this->col_shape);
auto* dYdata = input(-1).template data<T, Context>(); auto* dYdata = input(-1).template data<T, Context>();
...@@ -90,28 +82,31 @@ void DeConvGradientOp<Context>::RunWithType() { ...@@ -90,28 +82,31 @@ void DeConvGradientOp<Context>::RunWithType() {
} }
template <class Context> template <class Context>
void DeConvGradientOp<Context>::RunOnDevice() { void Conv2dTransposeGradientOp<Context>::RunOnDevice() {
GradientReshape(); GradientReshape();
// fix the output shape for im2col/col2im
for (int i = 0; i < this->num_spatial_axes; i++)
this->output_shape[i] = input(0).dim(this->spatial_axis + i);
if (input(0).template IsType<float>()) RunWithType<float>(); if (input(0).template IsType<float>()) RunWithType<float>();
else LOG(FATAL) << "Unsupported input types."; else LOG(FATAL) << "Unsupported input types.";
} }
DEPLOY_CPU(DeConvGradient); DEPLOY_CPU(Conv2dTransposeGradient);
#ifdef WITH_CUDA #ifdef WITH_CUDA
DEPLOY_CUDA(DeConvGradient); DEPLOY_CUDA(Conv2dTransposeGradient);
#endif #endif
OPERATOR_SCHEMA(DeConvGradient).NumInputs(3).NumOutputs(3); OPERATOR_SCHEMA(Conv2dTransposeGradient).NumInputs(3).NumOutputs(3);
class GetDeConvGradient final : public GradientMakerBase { class GetConv2dTransposeGradient final : public GradientMakerBase {
public: public:
GRADIENT_MAKER_CTOR(GetDeConvGradient); GRADIENT_MAKER_CTOR(GetConv2dTransposeGradient);
vector<OperatorDef> MakeDefs() override { vector<OperatorDef> MakeDefs() override {
return SingleDef(def.type() + "Gradient", "", return SingleDef(def.type() + "Gradient", "",
vector<string> {I(0), I(1), GO(0)}, vector<string> {I(0), I(1), GO(0)},
vector<string> {GI(0), GI(1), GI(2)}); vector<string> {GI(0), GI(1), GI(2)});
} }
}; };
REGISTER_GRADIENT(DeConv, GetDeConvGradient); REGISTER_GRADIENT(Conv2dTranspose, GetConv2dTransposeGradient);
} // namespace dragon } // namespace dragon
\ No newline at end of file
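The removed DeConvOp::ComputeOutputShape() above spells out the transposed-convolution output size; here is the same arithmetic in isolation as a quick reference (illustrative names only):

// Hedged sketch: transposed (de)convolution output-size arithmetic,
// mirroring the removed ComputeOutputShape().
inline int ConvTransposeOutDim(int in_dim, int kernel, int stride, int pad, int dilation) {
    const int dilated_kernel = dilation * (kernel - 1) + 1;
    return stride * (in_dim - 1) + dilated_kernel - 2 * pad;
}
// e.g. ConvTransposeOutDim(112, 4, /*stride=*/2, /*pad=*/1, /*dilation=*/1) == 224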
...@@ -5,38 +5,36 @@ ...@@ -5,38 +5,36 @@
namespace dragon { namespace dragon {
template <class Context> template <typename T> template <class Context> template <typename T>
void CuDNNPoolingOp<Context>::RunWithType() { void CuDNNPooling2dOp<Context>::RunWithType() {
cudnnSetTensorDesc<T>(&input_desc, &input(0)); cudnnSetTensor4dDesc<T>(&input_desc, this->data_format, &input(0));
cudnnSetTensorDesc<T>(&output_desc, output(0)); cudnnSetTensor4dDesc<T>(&output_desc, this->data_format, output(0));
if (this->global_pooling) {
#if CUDNN_VERSION_MIN(5, 0, 0) #if CUDNN_VERSION_MIN(5, 0, 0)
CUDNN_CHECK(cudnnSetPooling2dDescriptor(pool_desc, CUDNN_CHECK(cudnnSetPooling2dDescriptor(pool_desc,
pool_mode, pool_mode,
CUDNN_PROPAGATE_NAN, CUDNN_PROPAGATE_NAN,
input(0).dim(2), input(0).dim(3), this->kernel_size[0], this->kernel_size[1],
0, 0, this->pad[0], this->pad[1],
1, 1)); this->stride[0], this->stride[1]));
#else #else
CUDNN_CHECK(cudnnSetPooling2dDescriptor_v4(pool_desc, CUDNN_CHECK(cudnnSetPooling2dDescriptor_v4(pool_desc,
pool_mode, pool_mode,
CUDNN_PROPAGATE_NAN, CUDNN_PROPAGATE_NAN,
input(0).dim(2), input(0).dim(3), this->kernel_size[0], this->kernel_size[1],
0, 0, this->pad[0], this->pad[1],
1, 1)); this->stride[0], this->stride[1]));
#endif #endif
}
auto* Xdata = input(0).template data<T, Context>(); auto* Xdata = input(0).template data<T, Context>();
auto* Ydata = output(0)->template mutable_data<T, Context>(); auto* Ydata = output(0)->template mutable_data<T, Context>();
CUDNN_CHECK(cudnnPoolingForward(cudnn_handle(), CUDNN_CHECK(cudnnPoolingForward(cudnn_handle(),
pool_desc, pool_desc,
CUDNNType<T>::one, input_desc, Xdata, CUDNNType<T>::one, input_desc, Xdata,
CUDNNType<T>::zero, output_desc, Ydata)); CUDNNType<T>::zero, output_desc, Ydata));
} }
template <class Context> template <class Context>
void CuDNNPoolingOp<Context>::RunOnDevice() { void CuDNNPooling2dOp<Context>::RunOnDevice() {
PoolingOp<Context>::Reshape(); Pooling2dOp<Context>::Reshape();
if (input(0).template IsType<float>()) return RunWithType<float>(); if (input(0).template IsType<float>()) return RunWithType<float>();
#ifdef WITH_CUDA_FP16 #ifdef WITH_CUDA_FP16
...@@ -45,29 +43,27 @@ void CuDNNPoolingOp<Context>::RunOnDevice() { ...@@ -45,29 +43,27 @@ void CuDNNPoolingOp<Context>::RunOnDevice() {
else LOG(FATAL) << "Unsupported input types."; else LOG(FATAL) << "Unsupported input types.";
} }
DEPLOY_CUDNN(Pooling); DEPLOY_CUDNN(Pooling2d);
template <class Context> template <typename T> template <class Context> template <typename T>
void CuDNNPoolingGradientOp<Context>::RunWithType() { void CuDNNPooling2dGradientOp<Context>::RunWithType() {
cudnnSetTensorDesc<T>(&input_desc, &input(-1)); cudnnSetTensor4dDesc<T>(&input_desc, this->data_format, &input(-1));
cudnnSetTensorDesc<T>(&output_desc, output(0)); cudnnSetTensor4dDesc<T>(&output_desc, this->data_format, output(0));
if (this->global_pooling) {
#if CUDNN_VERSION_MIN(5, 0, 0) #if CUDNN_VERSION_MIN(5, 0, 0)
CUDNN_CHECK(cudnnSetPooling2dDescriptor(pool_desc, CUDNN_CHECK(cudnnSetPooling2dDescriptor(pool_desc,
pool_mode, pool_mode,
CUDNN_PROPAGATE_NAN, CUDNN_PROPAGATE_NAN,
input(0).dim(2), input(0).dim(3), this->kernel_size[0], this->kernel_size[1],
0, 0, this->pad[0], this->pad[1],
1, 1)); this->stride[0], this->stride[1]));
#else #else
CUDNN_CHECK(cudnnSetPooling2dDescriptor_v4(pool_desc, CUDNN_CHECK(cudnnSetPooling2dDescriptor_v4(pool_desc,
pool_mode, pool_mode,
CUDNN_PROPAGATE_NAN, CUDNN_PROPAGATE_NAN,
input(0).dim(2), input(0).dim(3), this->kernel_size[0], this->kernel_size[1],
0, 0, this->pad[0], this->pad[1],
1, 1)); this->stride[0], this->stride[1]));
#endif #endif
}
auto* dYdata = input(-1).template data<T, Context>(); auto* dYdata = input(-1).template data<T, Context>();
auto* Xdata = input(0).template data<T, Context>(); auto* Xdata = input(0).template data<T, Context>();
auto* Ydata = input(1).template data<T, Context>(); auto* Ydata = input(1).template data<T, Context>();
...@@ -82,8 +78,8 @@ void CuDNNPoolingGradientOp<Context>::RunWithType() { ...@@ -82,8 +78,8 @@ void CuDNNPoolingGradientOp<Context>::RunWithType() {
} }
template <class Context> template <class Context>
void CuDNNPoolingGradientOp<Context>::RunOnDevice() { void CuDNNPooling2dGradientOp<Context>::RunOnDevice() {
PoolingGradientOp<Context>::Reshape(); Pooling2dGradientOp<Context>::Reshape();
if (input(0).template IsType<float>()) return RunWithType<float>(); if (input(0).template IsType<float>()) return RunWithType<float>();
#ifdef WITH_CUDA_FP16 #ifdef WITH_CUDA_FP16
...@@ -92,7 +88,7 @@ void CuDNNPoolingGradientOp<Context>::RunOnDevice() { ...@@ -92,7 +88,7 @@ void CuDNNPoolingGradientOp<Context>::RunOnDevice() {
else LOG(FATAL) << "Unsupported input types."; else LOG(FATAL) << "Unsupported input types.";
} }
DEPLOY_CUDNN(PoolingGradient); DEPLOY_CUDNN(Pooling2dGradient);
} // namespace dragon } // namespace dragon
......
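After the refactor the cuDNN pooling descriptor is built from the operator's kernel_size/pad/stride instead of the hard-coded global-pooling window. For orientation, cudnnSetPooling2dDescriptor takes the window, padding, and stride in that order; a minimal standalone sketch for a 3x3, stride-2, pad-1 max pooling (assumes cuDNN >= 5, matching the CUDNN_VERSION_MIN(5, 0, 0) branch):

// Hedged sketch: setting a 2d pooling descriptor the way CuDNNPooling2dOp now does.
cudnnPoolingDescriptor_t pool_desc;
CUDNN_CHECK(cudnnCreatePoolingDescriptor(&pool_desc));
CUDNN_CHECK(cudnnSetPooling2dDescriptor(pool_desc,
    CUDNN_POOLING_MAX,
    CUDNN_PROPAGATE_NAN,
    /* window  */ 3, 3,
    /* padding */ 1, 1,
    /* stride  */ 2, 2));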
...@@ -49,7 +49,7 @@ void LRNOp<Context>::PoolRunWithType() { ...@@ -49,7 +49,7 @@ void LRNOp<Context>::PoolRunWithType() {
ks.set_name("kernel_size"); ks.add_ints(local_size); ks.set_name("kernel_size"); ks.add_ints(local_size);
s.set_name("stride"); s.add_ints(1); s.set_name("stride"); s.add_ints(1);
p.set_name("pad"); p.add_ints((local_size - 1) / 2); p.set_name("pad"); p.add_ints((local_size - 1) / 2);
mode.set_name("mode"); mode.set_i(AVG_POOLING); mode.set_name("mode"); mode.set_s("AVG");
OperatorDef pool_op_def = MakeOperatorDef("Pooling", "", OperatorDef pool_op_def = MakeOperatorDef("Pooling", "",
vector<string>({ sqr_out->name() }), vector<string>({ sqr_out->name() }),
vector<string>({ pool_out->name() }), vector<string>({ pool_out->name() }),
...@@ -177,7 +177,7 @@ void LRNGradientOp<Context>::PoolRunWithType() { ...@@ -177,7 +177,7 @@ void LRNGradientOp<Context>::PoolRunWithType() {
ks.set_name("kernel_size"); ks.add_ints(local_size); ks.set_name("kernel_size"); ks.add_ints(local_size);
s.set_name("stride"); s.add_ints(1); s.set_name("stride"); s.add_ints(1);
p.set_name("pad"); p.add_ints((local_size - 1) / 2); p.set_name("pad"); p.add_ints((local_size - 1) / 2);
mode.set_name("mode"); mode.set_i(AVG_POOLING); mode.set_name("mode"); mode.set_s("AVG");
OperatorDef pool_op_def = MakeOperatorDef("PoolingGradient", "", OperatorDef pool_op_def = MakeOperatorDef("PoolingGradient", "",
vector<string>({ sqr_out->name(), vector<string>({ sqr_out->name(),
pool_out->name(), pool_out->name(),
......
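Note that the pooling mode handed to the internal Pooling op is now a string rather than the old integer enum, matching what Pooling2dOp::RunOnDevice() dispatches on; a minimal sketch of building that argument (Argument stands in for the protobuf argument type used by MakeOperatorDef, whose exact class name is not shown in this hunk):

// Hedged sketch: the pooling mode is passed as a string argument ("MAX" / "AVG").
Argument mode;
mode.set_name("mode");
mode.set_s("AVG");   // previously: mode.set_i(AVG_POOLING)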
...@@ -7,27 +7,42 @@ namespace dragon { ...@@ -7,27 +7,42 @@ namespace dragon {
template <class Context> template <typename T> template <class Context> template <typename T>
void NNResizeOp<Context>::RunWithType() { void NNResizeOp<Context>::RunWithType() {
if (data_format == "NCHW") {
n = input(0).dim(0);
c = input(0).dim(1);
h = input(0).dim(2);
w = input(0).dim(3);
out_h = output(0)->dim(2);
out_w = output(0)->dim(3);
} else if (data_format == "NHWC") {
n = input(0).dim(0);
h = input(0).dim(1);
w = input(0).dim(2);
c = input(0).dim(3);
out_h = output(0)->dim(1);
out_w = output(0)->dim(2);
}
auto* Xdata = input(0).template data<T, Context>(); auto* Xdata = input(0).template data<T, Context>();
auto* Ydata = output(0)->template mutable_data<T, Context>(); auto* Ydata = output(0)->template mutable_data<T, Context>();
kernel::NNResize<T, Context>(output(0)->count(), dims[0], dims[1], kernel::NNResize<T, Context>(output(0)->count(), n, c, h, w,
input(0).dim(2), input(0).dim(3), out_h, out_w,
dims[2], dims[3], data_format,
Xdata, Xdata,
Ydata); Ydata);
} }
template <class Context> template <class Context>
void NNResizeOp<Context>::RunOnDevice() { void NNResizeOp<Context>::RunOnDevice() {
dims = input(0).dims(); vector<TIndex> dims = input(0).dims();
if (dynamic_dsize.size() > 0) { if (dynamic_dsize.size() > 0) {
CHECK_EQ(dynamic_dsize.size(), 2) CHECK_EQ(dynamic_dsize.size(), 2)
<< "\nThe dsize should be a scalar with 2 elements."; << "\nThe dsize should be a scalar with 2 elements.";
for (int i = 0; i < 2; i++) { for (int i = 0; i < 2; i++) {
Tensor* t = ws()->GetTensor(dynamic_dsize[i]); Tensor* t = ws()->GetTensor(dynamic_dsize[i]);
if (t->IsType<int>()) { if (t->IsType<int>()) {
dims[2 + i] = t->template data<int, CPUContext>()[0]; dims[spatial_axis + i] = t->template data<int, CPUContext>()[0];
} else if (t->IsType<float>()) { } else if (t->IsType<float>()) {
dims[2 + i] = t->template data<float, CPUContext>()[0]; dims[spatial_axis + i] = t->template data<float, CPUContext>()[0];
} else { } else {
LOG(FATAL) << "Unsupported types of dsize."; LOG(FATAL) << "Unsupported types of dsize.";
} }
...@@ -35,15 +50,15 @@ void NNResizeOp<Context>::RunOnDevice() { ...@@ -35,15 +50,15 @@ void NNResizeOp<Context>::RunOnDevice() {
} else if (static_dsize.size() > 0) { } else if (static_dsize.size() > 0) {
CHECK_EQ(static_dsize.size(), 2) CHECK_EQ(static_dsize.size(), 2)
<< "\nThe dsize should be a scalar with 2 elements."; << "\nThe dsize should be a scalar with 2 elements.";
for (int i = 0; i < 2; i++) dims[2 + i] = static_dsize[i]; for (int i = 0; i < 2; i++) dims[spatial_axis + i] = static_dsize[i];
} else { } else {
CHECK(fy != -1.0 && fx != -1.0) CHECK(fy != -1.0 && fx != -1.0)
<< "\nThe fx and fy should be set."; << "\nThe fx and fy should be set.";
dims[2] = int(dims[2] * fy); dims[spatial_axis] = int(dims[spatial_axis] * fy);
dims[3] = int(dims[3] * fx); dims[spatial_axis + 1] = int(dims[spatial_axis + 1] * fx);
} }
output(0)->Reshape(dims); output(0)->Reshape(dims);
if (input(0).template IsType<float>()) return RunWithType<float>(); if (input(0).template IsType<float>()) return RunWithType<float>();
else LOG(FATAL) << "Unsupported input types."; else LOG(FATAL) << "Unsupported input types.";
} }
...@@ -56,14 +71,28 @@ OPERATOR_SCHEMA(NNResize).NumInputs(1).NumOutputs(1); ...@@ -56,14 +71,28 @@ OPERATOR_SCHEMA(NNResize).NumInputs(1).NumOutputs(1);
template <class Context> template <typename T> template <class Context> template <typename T>
void NNResizeGradientOp<Context>::RunWithType() { void NNResizeGradientOp<Context>::RunWithType() {
if (data_format == "NCHW") {
n = input(0).dim(0);
c = input(0).dim(1);
h = input(0).dim(2);
w = input(0).dim(3);
out_h = input(-1).dim(2);
out_w = input(-1).dim(3);
} else if (data_format == "NHWC") {
n = input(0).dim(0);
h = input(0).dim(1);
w = input(0).dim(2);
c = input(0).dim(3);
out_h = input(-1).dim(1);
out_w = input(-1).dim(2);
}
auto* dYdata = input(-1).template data<T, Context>(); auto* dYdata = input(-1).template data<T, Context>();
auto* dXdata = output(0)->template mutable_data<T, Context>(); auto* dXdata = output(0)->template mutable_data<T, Context>();
math::Set<T, Context>(output(0)->count(), 0, dXdata); kernel::NNResizeGrad<T, Context>(input(-1).count(), n, c, h, w,
kernel::NNResizeGrad<T, Context>(input(-1).count(), input(0).dim(0), input(0).dim(1), out_h, out_w,
input(-1).dim(2), input(-1).dim(3), data_format,
output(0)->dim(2), output(0)->dim(3), dYdata,
dYdata, dXdata);
dXdata);
} }
template <class Context> template <class Context>
......
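NNResize now derives its output shape either from dsize or from the fy/fx scale factors applied along the spatial axes; a small standalone sketch of the fy/fx path, assuming NCHW input so that spatial_axis == 2 (names illustrative):

// Hedged sketch of the fy/fx branch in NNResizeOp::RunOnDevice().
#include <cstdint>
#include <vector>
using TIndex = int64_t;
std::vector<TIndex> ResizedDims(std::vector<TIndex> dims, float fy, float fx,
                                int spatial_axis = 2) {
    dims[spatial_axis]     = static_cast<TIndex>(dims[spatial_axis] * fy);      // new height
    dims[spatial_axis + 1] = static_cast<TIndex>(dims[spatial_axis + 1] * fx);  // new width
    return dims;
}
// e.g. ResizedDims({1, 3, 224, 224}, 0.5f, 0.5f) -> {1, 3, 112, 112}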
#include "operators/vision/pooling_op.h"
#include "core/workspace.h"
#include "utils/math_functions.h"
#include "utils/op_kernel.h"
namespace dragon {
template <class Context> template <typename T>
void Pooling2dOp<Context>::MAXRunWithType() {
mask = ws()->CreateTensor("_t_" + anchor() + "_pool_mask");
mask->ReshapeLike(*output(0));
auto* Xdata = input(0).template data<T, Context>();
auto* Ydata = output(0)->template mutable_data<T, Context>();
auto* Mdata = mask->template mutable_data<int, Context>();
kernel::MAXPooling2d<T, Context>(output(0)->count(),
n, c, h, w,
pool_h, pool_w,
kernel_size[0], kernel_size[1],
stride[0], stride[1],
pad[0], pad[1],
data_format,
Xdata,
Mdata,
Ydata);
}
template <class Context> template <typename T>
void Pooling2dOp<Context>::AVGRunWithType() {
auto* Xdata = input(0).template data<T, Context>();
auto* Ydata = output(0)->template mutable_data<T, Context>();
kernel::AVGPooling2d<T, Context>(output(0)->count(),
n, c, h, w,
pool_h, pool_w,
kernel_size[0], kernel_size[1],
stride[0], stride[1],
pad[0], pad[1],
data_format,
Xdata,
Ydata);
}
template <class Context>
void Pooling2dOp<Context>::Reshape() {
if (data_format == "NCHW") {
n = input(0).dim(0);
c = input(0).dim(1);
h = input(0).dim(2);
w = input(0).dim(3);
if (global_pooling) {
for (int i = 0; i < 2; i++)
kernel_size[i] = input(0).dim(i + 2);
}
if (padding == "SAME") {
for (int i = 0; i < 2; i++) {
TIndex input_size = input(0).dim(i + 2);
TIndex output_size = (input_size + stride[i] - 1) / (float)stride[i];
TIndex padding_needed = std::max(TIndex(0), (output_size - 1) * stride[i] + kernel_size[i] - input_size);
TIndex pad_l = padding_needed / 2;
TIndex pad_r = padding_needed - pad_l;
pad[i] = pad_l;
}
}
} else if (data_format == "NHWC") {
n = input(0).dim(0);
h = input(0).dim(1);
w = input(0).dim(2);
c = input(0).dim(3);
if (global_pooling) {
for (int i = 0; i < 2; i++)
kernel_size[i] = input(0).dim(i + 1);
}
if (padding == "SAME") {
for (int i = 0; i < 2; i++) {
TIndex input_size = input(0).dim(i + 1);
TIndex output_size = (input_size + stride[i] - 1) / (float)stride[i];
TIndex padding_needed = std::max(TIndex(0), (output_size - 1) * stride[i] + kernel_size[i] - input_size);
TIndex pad_l = padding_needed / 2;
TIndex pad_r = padding_needed - pad_l;
pad[i] = pad_l;
}
}
} else LOG(FATAL) << "Unknown data format: " << data_format;
if (padding != "SAME") {
    // case 1: infer output shape with symmetric pad size
pool_h = ceil((h + 2 * pad[0] - kernel_size[0]) / (float)stride[0]) + 1;
pool_w = ceil((w + 2 * pad[1] - kernel_size[1]) / (float)stride[1]) + 1;
if ((pool_h - 1) * stride[0] >= (h + pad[0])) pool_h--;
if ((pool_w - 1) * stride[1] >= (w + pad[1])) pool_w--;
} else {
// case 2: infer output shape with adaptive pad size
pool_h = (h + stride[0] - 1) / (float)stride[0];
pool_w = (w + stride[1] - 1) / (float)stride[1];
}
if (data_format == "NCHW") output(0)->Reshape(vector<TIndex>({ n, c, pool_h, pool_w }));
else if (data_format == "NHWC") output(0)->Reshape(vector<TIndex>({ n, pool_h, pool_w, c }));
}
template <class Context>
void Pooling2dOp<Context>::RunOnDevice() {
Reshape();
if (mode == "MAX") {
if (input(0).template IsType<float>()) MAXRunWithType<float>();
else LOG(FATAL) << "Unsupported input types.";
} else if (mode == "AVG") {
if (input(0).template IsType<float>()) AVGRunWithType<float>();
else LOG(FATAL) << "Unsupported input types.";
} else {
LOG(FATAL) << "Unsupported pooling mode: " << mode;
}
}
DEPLOY_CPU(Pooling2d);
#ifdef WITH_CUDA
DEPLOY_CUDA(Pooling2d);
#endif
OPERATOR_SCHEMA(Pooling2d).NumInputs(1).NumOutputs(1);
template <class Context> template <typename T>
void Pooling2dGradientOp<Context>::MAXRunWithType() {
mask = ws()->GetTensor("_t_" + anchor() + "_pool_mask");
auto* dYdata = input(-1).template data<T, Context>();
auto* dXdata = output(0)->template mutable_data<T, Context>();
auto* Mdata = mask->template data<int, Context>();
kernel::MAXPooling2dGrad<T, Context>(output(0)->count(),
n, c, h, w,
pool_h, pool_w,
kernel_size[0], kernel_size[1],
stride[0], stride[1],
pad[0], pad[1],
data_format,
dYdata,
Mdata,
dXdata);
}
template <class Context> template <typename T>
void Pooling2dGradientOp<Context>::AVGRunWithType() {
auto* dYdata = input(-1).template data<T, Context>();
auto* dXdata = output(0)->template mutable_data<T, Context>();
kernel::AVGPooling2dGrad<T, Context>(output(0)->count(),
n, c, h, w,
pool_h, pool_w,
kernel_size[0], kernel_size[1],
stride[0], stride[1],
pad[0], pad[1],
data_format,
dYdata,
dXdata);
}
template <class Context>
void Pooling2dGradientOp<Context>::Reshape() {
if (data_format == "NCHW") {
n = input(0).dim(0);
c = input(0).dim(1);
h = input(0).dim(2);
w = input(0).dim(3);
if (global_pooling) {
for (int i = 0; i < 2; i++)
kernel_size[i] = input(0).dim(i + 2);
}
if (padding == "SAME") {
for (int i = 0; i < 2; i++) {
TIndex input_size = input(0).dim(i + 2);
TIndex output_size = (input_size + stride[i] - 1) / (float)stride[i];
TIndex padding_needed = std::max(TIndex(0), (output_size - 1) * stride[i] + kernel_size[i] - input_size);
TIndex pad_l = padding_needed / 2;
TIndex pad_r = padding_needed - pad_l;
pad[i] = pad_l;
}
}
} else if (data_format == "NHWC") {
n = input(0).dim(0);
h = input(0).dim(1);
w = input(0).dim(2);
c = input(0).dim(3);
if (global_pooling) {
for (int i = 0; i < 2; i++)
kernel_size[i] = input(0).dim(i + 1);
}
if (padding == "SAME") {
for (int i = 0; i < 2; i++) {
TIndex input_size = input(0).dim(i + 1);
TIndex output_size = (input_size + stride[i] - 1) / (float)stride[i];
TIndex padding_needed = std::max(TIndex(0), (output_size - 1) * stride[i] + kernel_size[i] - input_size);
TIndex pad_l = padding_needed / 2;
TIndex pad_r = padding_needed - pad_l;
pad[i] = pad_l;
}
}
} else LOG(FATAL) << "Unknown data format: " << data_format;
if (padding != "SAME") {
    // case 1: infer output shape with symmetric pad size
pool_h = ceil((h + 2 * pad[0] - kernel_size[0]) / (float)stride[0]) + 1;
pool_w = ceil((w + 2 * pad[1] - kernel_size[1]) / (float)stride[1]) + 1;
if ((pool_h - 1) * stride[0] >= (h + pad[0])) pool_h--;
if ((pool_w - 1) * stride[1] >= (w + pad[1])) pool_w--;
} else {
// case 2: infer output shape with adaptive pad size
pool_h = (h + stride[0] - 1) / (float)stride[0];
pool_w = (w + stride[1] - 1) / (float)stride[1];
}
output(0)->ReshapeLike(input(0));
}
template <class Context>
void Pooling2dGradientOp<Context>::RunOnDevice() {
Reshape();
if (mode == "MAX") {
if (input(0).template IsType<float>()) MAXRunWithType<float>();
else LOG(FATAL) << "Unsupported input types.";
} else if (mode == "AVG") {
if (input(0).template IsType<float>()) AVGRunWithType<float>();
else LOG(FATAL) << "Unsupported input types.";
} else {
LOG(FATAL) << "Unsupported pooling mode: " << mode;
}
}
DEPLOY_CPU(Pooling2dGradient);
#ifdef WITH_CUDA
DEPLOY_CUDA(Pooling2dGradient);
#endif
OPERATOR_SCHEMA(Pooling2dGradient).NumInputs(3).NumOutputs(1);
class GetPooling2dGradient final : public GradientMakerBase {
public:
GRADIENT_MAKER_CTOR(GetPooling2dGradient);
vector<OperatorDef> MakeDefs() override {
return SingleDef(def.type() + "Gradient", "",
vector<string> {I(0), O(0), GO(0)},
vector<string> {GI(0)});
}
};
REGISTER_GRADIENT(Pooling2d, GetPooling2dGradient);
} // namespace dragon
\ No newline at end of file
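The "SAME" branch in Pooling2dOp::Reshape() sets the output size to ceil(input / stride) and then derives the total padding needed for the window to cover the input; a small worked sketch of that arithmetic (illustrative names, same formulas as above; note the operator itself only keeps the left pad):

// Hedged sketch of the SAME-padding arithmetic in Pooling2dOp::Reshape().
#include <algorithm>
#include <cstdint>
using TIndex = int64_t;
struct SamePad { TIndex out, pad_l, pad_r; };
inline SamePad ComputeSamePad(TIndex in, TIndex kernel, TIndex stride) {
    const TIndex out = (in + stride - 1) / stride;                            // ceil(in / stride)
    const TIndex needed = std::max<TIndex>(0, (out - 1) * stride + kernel - in);
    return { out, needed / 2, needed - needed / 2 };
}
// e.g. in = 7, kernel = 3, stride = 2  ->  out = 4, pad_l = 1, pad_r = 1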
#include "operators/vision/pooling_op.h"
#include "core/workspace.h"
#include "utils/math_functions.h"
#include "utils/op_kernel.h"
namespace dragon {
template <class Context> template <typename T>
void PoolingOp<Context>::MaxRunWithType() {
mask = ws()->CreateTensor("_t_" + anchor() + "_pool_mask");
mask->ReshapeLike(*output(0));
auto* Xdata = input(0).template data<T, Context>();
auto* Ydata = output(0)->template mutable_data<T, Context>();
auto* Mdata = mask->template mutable_data<int, Context>();
kernel::MAXPooling<T, Context>(output(0)->count(),
num, channels, height, width,
pool_height, pool_width,
kernel_size[0], kernel_size[1],
stride[0], stride[1],
pad[0], pad[1],
Xdata,
Mdata,
Ydata);
}
template <class Context> template <typename T>
void PoolingOp<Context>::AvgRunWithType() {
auto* Xdata = input(0).template data<T, Context>();
auto* Ydata = output(0)->template mutable_data<T, Context>();
kernel::AVEPooling<T, Context>(output(0)->count(),
num, channels, height, width,
pool_height, pool_width,
kernel_size[0], kernel_size[1],
stride[0], stride[1],
pad[0], pad[1],
Xdata,
Ydata);
}
template <class Context>
void PoolingOp<Context>::Reshape() {
num = input(0).dim(0);
channels = input(0).dim(1);
height = input(0).dim(2);
width = input(0).dim(3);
if (global_pooling) {
for (int i = 0; i < 2; i++)
kernel_size[i] = input(0).dim(i + 2);
}
pool_height = ceil((height + 2 * pad[0] - kernel_size[0]) / (float)stride[0]) + 1;
pool_width = ceil((width + 2 * pad[1] - kernel_size[1]) / (float)stride[1]) + 1;
if ((pool_height - 1) * stride[0] >= (height + pad[0])) pool_height--;
if ((pool_width - 1) * stride[1] >= (width + pad[1])) pool_width--;
vector<TIndex> top_shape({ num, channels, pool_height, pool_width });
if (input(0).ndim() == 3) top_shape.erase(top_shape.begin());
output(0)->Reshape(top_shape);
}
template <class Context>
void PoolingOp<Context>::RunOnDevice() {
Reshape();
if (mode == MAX_POOLING) {
if (input(0).template IsType<float>()) MaxRunWithType<float>();
else LOG(FATAL) << "Unsupported input types.";
}
else if (mode == AVG_POOLING) {
if (input(0).template IsType<float>()) AvgRunWithType<float>();
else LOG(FATAL) << "Unsupported input types.";
}
else {
LOG(FATAL) << "Unsupported pooling mode.";
}
}
DEPLOY_CPU(Pooling);
#ifdef WITH_CUDA
DEPLOY_CUDA(Pooling);
#endif
OPERATOR_SCHEMA(Pooling).NumInputs(1).NumOutputs(1);
template <class Context> template <typename T>
void PoolingGradientOp<Context>::MaxRunWithType() {
mask = ws()->GetTensor("_t_" + anchor() + "_pool_mask");
auto* dYdata = input(-1).template data<T, Context>();
auto* dXdata = output(0)->template mutable_data<T, Context>();
auto* Mdata = mask->template data<int, Context>();
kernel::MAXPoolingGrad<T, Context>(output(0)->count(),
num, channels, height, width,
pool_height, pool_width,
kernel_size[0], kernel_size[1],
stride[0], stride[1],
pad[0], pad[1],
dYdata,
Mdata,
dXdata);
}
template <class Context> template <typename T>
void PoolingGradientOp<Context>::AvgRunWithType() {
auto* dYdata = input(-1).template data<T, Context>();
auto* dXdata = output(0)->template mutable_data<T, Context>();
kernel::AVEPoolingGrad<T, Context>(output(0)->count(),
num, channels, height, width,
pool_height, pool_width,
kernel_size[0], kernel_size[1],
stride[0], stride[1],
pad[0], pad[1],
dYdata,
dXdata);
}
template <class Context>
void PoolingGradientOp<Context>::Reshape() {
num = input(0).dim(0);
channels = input(0).dim(1);
height = input(0).dim(2);
width = input(0).dim(3);
if (global_pooling) {
for (int i = 0; i < 2; i++)
kernel_size[i] = input(0).dim(i + 2);
}
pool_height = ceil((height + 2 * pad[0] - kernel_size[0]) / (float)stride[0]) + 1;
pool_width = ceil((width + 2 * pad[1] - kernel_size[1]) / (float)stride[1]) + 1;
if ((pool_height - 1) * stride[0] >= (height + pad[0])) pool_height--;
if ((pool_width - 1)* stride[1] >= (width + pad[1])) pool_width--;
output(0)->ReshapeLike(input(0));
}
template <class Context>
void PoolingGradientOp<Context>::RunOnDevice() {
Reshape();
if (mode == MAX_POOLING) {
if (input(0).template IsType<float>()) MaxRunWithType<float>();
else LOG(FATAL) << "Unsupported input types.";
}
else if (mode == AVG_POOLING) {
if (input(0).template IsType<float>()) AvgRunWithType<float>();
else LOG(FATAL) << "Unsupported input types.";
}
else {
LOG(FATAL) << "Unsupported pooling mode.";
}
}
DEPLOY_CPU(PoolingGradient);
#ifdef WITH_CUDA
DEPLOY_CUDA(PoolingGradient);
#endif
OPERATOR_SCHEMA(PoolingGradient).NumInputs(3).NumOutputs(1);
class GetPoolingGradient final : public GradientMakerBase {
public:
GRADIENT_MAKER_CTOR(GetPoolingGradient);
vector<OperatorDef> MakeDefs() override {
return SingleDef(def.type() + "Gradient", "",
vector<string> {I(0), O(0), GO(0)},
vector<string> {GI(0)});
}
};
REGISTER_GRADIENT(Pooling, GetPoolingGradient);
} // namespace dragon
\ No newline at end of file
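Both the old Pooling op above and the new Pooling2d op record an argmax mask in the forward pass so that the MAX-pooling gradient can route dY back to the winning input position; a simplified single-map CPU sketch of that idea (this is not the actual kernel::MAXPooling2d implementation, only an illustration of the mask's role):

// Hedged sketch: naive 2d max pooling with an argmax mask (single feature map).
#include <algorithm>
#include <cfloat>
#include <vector>
void MaxPool2dWithMask(const std::vector<float>& x, int h, int w,
                       int pool_h, int pool_w, int kernel, int stride, int pad,
                       std::vector<float>* y, std::vector<int>* mask) {
    y->assign(pool_h * pool_w, 0.f);
    mask->assign(pool_h * pool_w, -1);
    for (int ph = 0; ph < pool_h; ++ph) {
        for (int pw = 0; pw < pool_w; ++pw) {
            const int hstart = std::max(ph * stride - pad, 0);
            const int wstart = std::max(pw * stride - pad, 0);
            const int hend = std::min(ph * stride - pad + kernel, h);
            const int wend = std::min(pw * stride - pad + kernel, w);
            float best = -FLT_MAX; int best_idx = -1;
            for (int i = hstart; i < hend; ++i)
                for (int j = wstart; j < wend; ++j)
                    if (x[i * w + j] > best) { best = x[i * w + j]; best_idx = i * w + j; }
            (*y)[ph * pool_w + pw] = best;
            (*mask)[ph * pool_w + pw] = best_idx;   // the backward pass adds dY at this index
        }
    }
}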
...@@ -48,8 +48,69 @@ void cudnnSetTensorDesc(cudnnTensorDescriptor_t* desc, const vector<TIndex>& dim ...@@ -48,8 +48,69 @@ void cudnnSetTensorDesc(cudnnTensorDescriptor_t* desc, const vector<TIndex>& dim
} }
template <typename T> template <typename T>
void cudnnSetTensorDesc(cudnnTensorDescriptor_t* desc, void cudnnSetTensor4dDesc(cudnnTensorDescriptor_t* desc,
const vector<TIndex>& dims, const vector<TIndex>& strides) { const string& data_format,
const vector<TIndex>& dims) {
if (data_format == "NCHW") {
CUDNN_CHECK(cudnnSetTensor4dDescriptor(*desc, CUDNN_TENSOR_NCHW,
CUDNNType<T>::type,
dims[0],
dims[1],
dims[2],
dims[3]));
} else if (data_format == "NHWC") {
CUDNN_CHECK(cudnnSetTensor4dDescriptor(*desc, CUDNN_TENSOR_NHWC,
CUDNNType<T>::type,
dims[0],
dims[3],
dims[1],
dims[2]));
} else LOG(FATAL) << "Unknown data format: " << data_format;
}
template <typename T>
void cudnnSetTensor5dDesc(cudnnTensorDescriptor_t* desc,
const string& data_format,
const vector<TIndex>& dims) {
if (data_format == "NCHW") {
cudnnSetTensorDesc<T>(desc, dims);
} else if (data_format == "NHWC") {
const int N = (int)dims[0];
const int C = (int)dims[4];
const int H = (int)dims[1];
const int W = (int)dims[2];
const int D = (int)dims[3];
vector<int> fake_dims = { N, C, H, W, D };
vector<int> fake_strides = { H * W * D * C, 1, W * D * C, D * C, C };
CUDNN_CHECK(cudnnSetTensorNdDescriptor(*desc,
CUDNNType<T>::type,
5,
fake_dims.data(),
fake_strides.data()));
} else LOG(FATAL) << "Unknown data format: " << data_format;
}
template <typename T>
void cudnnSetTensor3dDesc(cudnnTensorDescriptor_t* desc,
const string& data_format,
const vector<TIndex>& dims) {
vector<TIndex> fake_dims = dims;
if (data_format == "NCHW") {
        // NCH -> NCH11: append two singleton spatial dims
fake_dims.push_back(1);
fake_dims.push_back(1);
} else if (data_format == "NHWC") {
        // NHC -> NH11C: insert two singleton dims before the channel dim
fake_dims.insert(fake_dims.begin() + 2, 1);
fake_dims.insert(fake_dims.begin() + 2, 1);
} else LOG(FATAL) << "Unknown data format: " << data_format;
cudnnSetTensor5dDesc<T>(desc, data_format, fake_dims);
}
template <typename T>
void cudnnSetTensorDesc(cudnnTensorDescriptor_t* desc,
const vector<TIndex>& dims,
const vector<TIndex>& strides) {
CHECK_EQ(dims.size(), strides.size()); CHECK_EQ(dims.size(), strides.size());
CHECK(dims.size() >= 3 && dims.size() <= 8); CHECK(dims.size() >= 3 && dims.size() <= 8);
int ndim = (int)dims.size(); int ndim = (int)dims.size();
...@@ -76,22 +137,64 @@ void cudnnSetTensorDesc(cudnnTensorDescriptor_t* desc, Tensor* tensor) { ...@@ -76,22 +137,64 @@ void cudnnSetTensorDesc(cudnnTensorDescriptor_t* desc, Tensor* tensor) {
cudnnSetTensorDesc<T>(desc, fake_dims); cudnnSetTensorDesc<T>(desc, fake_dims);
} }
template <typename T>
void cudnnSetTensor4dDesc(cudnnTensorDescriptor_t* desc, const string& data_format, Tensor* tensor) {
CHECK_EQ((int)tensor->ndim(), 4)
<< "\nThe num of dimensions of Tensor(" << tensor->name() << ") "
<< "should be 4, but got " << tensor->ndim() << ".";
cudnnSetTensor4dDesc<T>(desc, data_format, tensor->dims());
}
template <typename T>
void cudnnSetTensor5dDesc(cudnnTensorDescriptor_t* desc, const string& data_format, Tensor* tensor) {
CHECK_EQ((int)tensor->ndim(), 5)
<< "\nThe num of dimensions of Tensor(" << tensor->name() << ") "
<< "should be 5, but got " << tensor->ndim() << ".";
cudnnSetTensor5dDesc<T>(desc, data_format, tensor->dims());
}
template <typename T>
void cudnnSetTensor3dDesc(cudnnTensorDescriptor_t* desc, const string& data_format, Tensor* tensor) {
CHECK_EQ((int)tensor->ndim(), 3)
<< "\nThe num of dimensions of Tensor(" << tensor->name() << ") "
<< "should be 3, but got " << tensor->ndim() << ".";
cudnnSetTensor3dDesc<T>(desc, data_format, tensor->dims());
}
template void cudnnSetTensorDesc<float>(cudnnTensorDescriptor_t*, Tensor*); template void cudnnSetTensorDesc<float>(cudnnTensorDescriptor_t*, Tensor*);
template void cudnnSetTensor4dDesc<float>(cudnnTensorDescriptor_t*, const string&, Tensor*);
template void cudnnSetTensor5dDesc<float>(cudnnTensorDescriptor_t*, const string&, Tensor*);
template void cudnnSetTensor3dDesc<float>(cudnnTensorDescriptor_t*, const string&, Tensor*);
template void cudnnSetTensorDesc<float>(cudnnTensorDescriptor_t*, const vector<TIndex>&); template void cudnnSetTensorDesc<float>(cudnnTensorDescriptor_t*, const vector<TIndex>&);
template void cudnnSetTensor4dDesc<float>(cudnnTensorDescriptor_t*, const string&, const vector<TIndex>&);
template void cudnnSetTensor5dDesc<float>(cudnnTensorDescriptor_t*, const string&, const vector<TIndex>&);
template void cudnnSetTensor3dDesc<float>(cudnnTensorDescriptor_t*, const string&, const vector<TIndex>&);
template void cudnnSetTensorDesc<float>(cudnnTensorDescriptor_t*, const vector<TIndex>&, const vector<TIndex>&); template void cudnnSetTensorDesc<float>(cudnnTensorDescriptor_t*, const vector<TIndex>&, const vector<TIndex>&);
template void cudnnSetTensorDesc<double>(cudnnTensorDescriptor_t*, Tensor*); template void cudnnSetTensorDesc<double>(cudnnTensorDescriptor_t*, Tensor*);
template void cudnnSetTensor4dDesc<double>(cudnnTensorDescriptor_t*, const string&, Tensor*);
template void cudnnSetTensor5dDesc<double>(cudnnTensorDescriptor_t*, const string&, Tensor*);
template void cudnnSetTensor3dDesc<double>(cudnnTensorDescriptor_t*, const string&, Tensor*);
template void cudnnSetTensorDesc<double>(cudnnTensorDescriptor_t*, const vector<TIndex>&); template void cudnnSetTensorDesc<double>(cudnnTensorDescriptor_t*, const vector<TIndex>&);
template void cudnnSetTensor4dDesc<double>(cudnnTensorDescriptor_t*, const string&, const vector<TIndex>&);
template void cudnnSetTensor5dDesc<double>(cudnnTensorDescriptor_t*, const string&, const vector<TIndex>&);
template void cudnnSetTensor3dDesc<double>(cudnnTensorDescriptor_t*, const string&, const vector<TIndex>&);
template void cudnnSetTensorDesc<double>(cudnnTensorDescriptor_t*, const vector<TIndex>&, const vector<TIndex>&); template void cudnnSetTensorDesc<double>(cudnnTensorDescriptor_t*, const vector<TIndex>&, const vector<TIndex>&);
#ifdef WITH_CUDA_FP16 #ifdef WITH_CUDA_FP16
template void cudnnSetTensorDesc<float16>(cudnnTensorDescriptor_t*, Tensor*); template void cudnnSetTensorDesc<float16>(cudnnTensorDescriptor_t*, Tensor*);
template void cudnnSetTensor4dDesc<float16>(cudnnTensorDescriptor_t*, const string&, Tensor*);
template void cudnnSetTensor5dDesc<float16>(cudnnTensorDescriptor_t*, const string&, Tensor*);
template void cudnnSetTensor3dDesc<float16>(cudnnTensorDescriptor_t*, const string&, Tensor*);
template void cudnnSetTensorDesc<float16>(cudnnTensorDescriptor_t*, const vector<TIndex>&); template void cudnnSetTensorDesc<float16>(cudnnTensorDescriptor_t*, const vector<TIndex>&);
template void cudnnSetTensor4dDesc<float16>(cudnnTensorDescriptor_t*, const string&, const vector<TIndex>&);
template void cudnnSetTensor5dDesc<float16>(cudnnTensorDescriptor_t*, const string&, const vector<TIndex>&);
template void cudnnSetTensor3dDesc<float16>(cudnnTensorDescriptor_t*, const string&, const vector<TIndex>&);
template void cudnnSetTensorDesc<float16>(cudnnTensorDescriptor_t*, const vector<TIndex>&, const vector<TIndex>&); template void cudnnSetTensorDesc<float16>(cudnnTensorDescriptor_t*, const vector<TIndex>&, const vector<TIndex>&);
#endif #endif
} // namespace dragon } // namespace dragon
#endif // WITH_CUDNN #endif // WITH_CUDNN
\ No newline at end of file
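cudnnSetTensor4dDescriptor always takes its sizes as (n, c, h, w), so the NHWC overload above reorders the tensor's dims before the call; a short usage sketch with a hypothetical NHWC shape:

// Hedged usage sketch of the new 4d descriptor helper.
cudnnTensorDescriptor_t desc;
CUDNN_CHECK(cudnnCreateTensorDescriptor(&desc));
std::vector<TIndex> dims = { 32, 224, 224, 3 };          // N, H, W, C
cudnnSetTensor4dDesc<float>(&desc, "NHWC", dims);
// internally forwards (n=32, c=3, h=224, w=224) with CUDNN_TENSOR_NHWC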
This diff could not be displayed because it is too large.