Commit c9db9eee by Ting PAN

Fix/Refactor the GroupConvolution on cuDNN

1 parent 6f2751b1
@@ -51,6 +51,9 @@ class CPUContext {
     inline static void Memcpy(size_t nbytes, void* dst, const void* src) { memcpy(dst, src, nbytes); }
     inline static void Delete(void* data) { free(data); }
+    template<class DstContext, class SrcContext>
+    inline static void MemcpyAsync(size_t nbytes, void* dst, const void* src) { NOT_IMPLEMENTED; }
     template<typename T, class DstContext, class SrcContext>
     inline static void Copy(int n, T* dst, const T* src) {
         if (dst == src) return;
......
@@ -108,11 +108,11 @@ class CUDAContext {
         CUDA_CHECK(cudaMemcpy(dst, src, nbytes, cudaMemcpyDefault));
     }
+    template<class DstContext, class SrcContext>
     inline static void MemcpyAsync(size_t nbytes, void* dst, const void* src) {
         cudaStream_t stream;
         CUDA_CHECK(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking));
         CUDA_CHECK(cudaMemcpyAsync(dst, src, nbytes, cudaMemcpyDefault, stream));
-        CUDA_CHECK(cudaStreamSynchronize(stream));
         CUDA_CHECK(cudaStreamDestroy(stream));
     }
......
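For reference, a minimal standalone sketch of the pattern the refactored MemcpyAsync relies on (only the CUDA runtime is assumed; buffer sizes and names are illustrative): a non-blocking stream is created per call, the copy is enqueued, and the stream is destroyed without a host-side synchronize, so the copy completes asynchronously.

    #include <cuda_runtime.h>
    #include <cstdio>

    int main() {
        // illustrative device buffers
        float *src = nullptr, *dst = nullptr;
        cudaMalloc(&src, 1024 * sizeof(float));
        cudaMalloc(&dst, 1024 * sizeof(float));
        cudaStream_t stream;
        cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking);
        cudaMemcpyAsync(dst, src, 1024 * sizeof(float), cudaMemcpyDefault, stream);
        // no cudaStreamSynchronize here: destroying the stream only releases it
        // after pending work drains, so the copy still completes asynchronously
        cudaStreamDestroy(stream);
        cudaDeviceSynchronize();  // wait once at the end, for demonstration only
        printf("async copy issued and drained\n");
        cudaFree(src); cudaFree(dst);
        return 0;
    }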
@@ -215,7 +215,7 @@ class Tensor {
     TIndex size_ = 0, capacity_ = 0;
     TypeMeta meta_;
     string name_;
-    shared_ptr<MixedMemory> memory_;
+    shared_ptr<MixedMemory> memory_, host_memory_;
     MixedMemory* ex_memory_ = nullptr;
     bool is_corrupted_ = false, own_mem_ = true;
 };
......
@@ -18,7 +18,8 @@ class L2NormOp final : public Operator<Context> {
         : Operator<Context>(op_def, ws),
           axis(OperatorBase::GetSingleArg<int>("axis", 0)),
           num_axes(OperatorBase::GetSingleArg<int>("num_axes", -1)),
-          eps(OperatorBase::GetSingleArg<float>("eps", float(1e-5))) {}
+          eps(OperatorBase::GetSingleArg<float>("eps", float(1e-5))),
+          mode(OperatorBase::GetSingleArg<string>("mode", "SUM")) {}
     void RunOnDevice() override;
     template <typename T> void RunWithType();
@@ -26,6 +27,7 @@ class L2NormOp final : public Operator<Context> {
  protected:
     float eps;
     TIndex axis, num_axes, end_axis;
+    string mode;
     bool across_inner;
     Tensor* norm, *buffer, *multiplier;
     TIndex outer_dim, dim, inner_dim, spatial_dim;
......
@@ -48,10 +48,15 @@ class CuDNNConv2dOp : public Conv2dOp<Context> {
  public:
     CuDNNConv2dOp(const OperatorDef& def, Workspace* ws)
         : Conv2dOp<Context>(def, ws) {
-        handle = new cudnnHandle_t[this->group];
-        stream = new cudaStream_t[this->group];
+#if CUDNN_VERSION_MIN(7, 0, 0)
+        cudnn_group = 1;
+#else
+        cudnn_group = this->group;
+#endif
+        handle = new cudnnHandle_t[cudnn_group];
+        stream = new cudaStream_t[cudnn_group];
         ctx().SwitchToDevice();
-        for (int g = 0; g < this->group; g++) {
+        for (int g = 0; g < cudnn_group; g++) {
             CUDA_CHECK(cudaStreamCreate(&stream[g]));
             CUDNN_CHECK(cudnnCreate(&handle[g]));
             CUDNN_CHECK(cudnnSetStream(handle[g], stream[g]));
@@ -78,7 +83,7 @@ class CuDNNConv2dOp : public Conv2dOp<Context> {
     cudnnConvolutionDescriptor_t conv_desc;
     cudnnFilterDescriptor_t filter_desc;
     size_t workspace_fwd_data_size;
-    TIndex bias_offset;
+    TIndex bias_offset, cudnn_group;
 };
 template <class Context>
@@ -86,9 +91,14 @@ class CuDNNConv2dGradientOp : public Conv2dGradientOp<Context> {
  public:
     CuDNNConv2dGradientOp(const OperatorDef& def, Workspace* ws)
         : Conv2dGradientOp<Context>(def, ws) {
-        handle = new cudnnHandle_t[this->group * 3];
-        stream = new cudaStream_t[this->group * 3];
-        for (int g = 0; g < this->group * 3; g++) {
+#if CUDNN_VERSION_MIN(7, 0, 0)
+        cudnn_group = 1;
+#else
+        cudnn_group = this->group;
+#endif
+        handle = new cudnnHandle_t[cudnn_group * 3];
+        stream = new cudaStream_t[cudnn_group * 3];
+        for (int g = 0; g < cudnn_group * 3; g++) {
             CUDA_CHECK(cudaStreamCreate(&stream[g]));
             CUDNN_CHECK(cudnnCreate(&handle[g]));
             CUDNN_CHECK(cudnnSetStream(handle[g], stream[g]));
@@ -116,7 +126,7 @@ class CuDNNConv2dGradientOp : public Conv2dGradientOp<Context> {
     cudnnConvolutionDescriptor_t conv_desc;
     cudnnFilterDescriptor_t filter_desc;
     size_t workspace_bwd_filter_size, workspace_bwd_data_size;
-    int bias_offset;
+    TIndex bias_offset, cudnn_group;
 };
 #endif    // WITH_CUDNN
......
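The idea behind cudnn_group, restated as a compilable sketch (CUDNN_VERSION_MIN is Dragon's macro; the plain CUDNN_VERSION check below is an equivalent assumption): with cuDNN 7 the group count moves onto the convolution descriptor and a single handle suffices, while older releases emulate grouped convolution with one handle/stream pair per group.

    #include <cudnn.h>

    // How many handle/stream pairs a grouped convolution needs.
    inline int EffectiveCudnnGroups(int group) {
    #if CUDNN_VERSION >= 7000
        return 1;      // native grouped conv: cuDNN splits the groups internally
    #else
        return group;  // pre-7 emulation: one handle/stream per group
    #endif
    }

    // Attach the group count to the descriptor where supported.
    inline void ConfigureGroups(cudnnConvolutionDescriptor_t conv_desc, int group) {
    #if CUDNN_VERSION >= 7000
        cudnnSetConvolutionGroupCount(conv_desc, group);
    #endif
    }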
@@ -52,8 +52,13 @@ class CuDNNConv2dTransposeOp : public Conv2dTransposeOp<Context> {
  public:
     CuDNNConv2dTransposeOp(const OperatorDef& def, Workspace* ws)
         : Conv2dTransposeOp<Context>(def, ws) {
-        handle = new cudnnHandle_t[this->group];
-        stream = new cudaStream_t[this->group];
+#if CUDNN_VERSION_MIN(7, 0, 0)
+        cudnn_group = 1;
+#else
+        cudnn_group = this->group;
+#endif
+        handle = new cudnnHandle_t[cudnn_group];
+        stream = new cudaStream_t[cudnn_group];
         for (int g = 0; g < this->group; g++) {
             CUDA_CHECK(cudaStreamCreate(&stream[g]));
             CUDNN_CHECK(cudnnCreate(&handle[g]));
@@ -80,7 +85,7 @@ class CuDNNConv2dTransposeOp : public Conv2dTransposeOp<Context> {
     cudnnConvolutionDescriptor_t conv_desc;
     cudnnFilterDescriptor_t filter_desc;
     size_t workspace_fwd_data_size;
-    int bias_offset;
+    TIndex bias_offset, cudnn_group;
 };
 template <class Context>
@@ -88,9 +93,14 @@ class CuDNNConv2dTransposeGradientOp : public Conv2dTransposeGradientOp<Context> {
  public:
     CuDNNConv2dTransposeGradientOp(const OperatorDef& def, Workspace* ws)
         : Conv2dTransposeGradientOp<Context>(def, ws) {
-        handle = new cudnnHandle_t[this->group * 3];
-        stream = new cudaStream_t[this->group * 3];
-        for (int g = 0; g < this->group * 3; g++) {
+#if CUDNN_VERSION_MIN(7, 0, 0)
+        cudnn_group = 1;
+#else
+        cudnn_group = this->group;
+#endif
+        handle = new cudnnHandle_t[cudnn_group * 3];
+        stream = new cudaStream_t[cudnn_group * 3];
+        for (int g = 0; g < cudnn_group * 3; g++) {
             CUDA_CHECK(cudaStreamCreate(&stream[g]));
             CUDNN_CHECK(cudnnCreate(&handle[g]));
             CUDNN_CHECK(cudnnSetStream(handle[g], stream[g]));
@@ -117,7 +127,7 @@ public:
     cudnnConvolutionDescriptor_t conv_desc;
     cudnnFilterDescriptor_t filter_desc;
     size_t workspace_bwd_filter_size, workspace_bwd_data_size;
-    int bias_offset;
+    TIndex bias_offset, cudnn_group;
 };
 #endif    // WITH_CUDNN
......
@@ -18,11 +18,12 @@ class LRNOp : public Operator<Context> {
  public:
     LRNOp(const OperatorDef& op_def, Workspace* ws)
         : Operator<Context>(op_def, ws),
-          mode((LRNMode)OperatorBase::GetSingleArg<int>("mode", ACROSS_CHANNELS)),
           local_size(OperatorBase::GetSingleArg<int>("local_size", 5)),
           alpha(OperatorBase::GetSingleArg<float>("alpha", float(0.0001))),
           beta(OperatorBase::GetSingleArg<float>("beta", float(0.75))),
-          k(OperatorBase::GetSingleArg<float>("k", float(2.0))) {}
+          k(OperatorBase::GetSingleArg<float>("k", float(2.0))),
+          mode(OperatorBase::GetSingleArg<string>("mode", "ACROSS_CHANNELS")),
+          data_format(OperatorBase::GetSingleArg<string>("data_format", "NCHW")) {}
     void RunOnDevice() override;
     template <typename T> void RunWithType();
@@ -34,9 +35,9 @@ class LRNOp : public Operator<Context> {
     template <typename T> void ProdRunWithType();
  protected:
-    LRNMode mode;
     int local_size;
     float alpha, beta, k;
+    string mode, data_format;
     unique_ptr<OperatorBase> sqr_op, pool_op, pow_op, prod_op;
     Tensor* sqr_in, *prod_in, *sqr_out, *pool_out, *pow_out;
     Tensor* scale;
@@ -47,11 +48,12 @@ class LRNGradientOp : public Operator<Context> {
  public:
     LRNGradientOp(const OperatorDef& op_def, Workspace* ws)
         : Operator<Context>(op_def, ws),
-          mode((LRNMode)OperatorBase::GetSingleArg<int>("mode", ACROSS_CHANNELS)),
           local_size(OperatorBase::GetSingleArg<int>("local_size", 5)),
           alpha(OperatorBase::GetSingleArg<float>("alpha", float(0.0001))),
           beta(OperatorBase::GetSingleArg<float>("beta", float(0.75))),
-          k(OperatorBase::GetSingleArg<float>("k", float(2.0))) {}
+          k(OperatorBase::GetSingleArg<float>("k", float(2.0))),
+          mode(OperatorBase::GetSingleArg<string>("mode", "ACROSS_CHANNELS")),
+          data_format(OperatorBase::GetSingleArg<string>("data_format", "NCHW")) {}
     void RunOnDevice() override;
     template <typename T> void RunWithType();
@@ -63,9 +65,9 @@ class LRNGradientOp : public Operator<Context> {
     template <typename T> void ProdRunWithType();
  protected:
-    LRNMode mode;
     int local_size;
     float alpha, beta, k;
+    string mode, data_format;
     unique_ptr<OperatorBase> sqr_op, pool_op, pow_op, prod_op;
     Tensor* sqr_in, *prod_in, *sqr_out, *pool_out, *pow_out;
     Tensor* scale;
......
@@ -76,6 +76,9 @@ template <typename T>
 void cudnnSetTensor4dDesc(cudnnTensorDescriptor_t* desc, const string& data_format, const std::vector<int64_t>& dims);
 template <typename T>
+void cudnnSetTensor4dDescWithGroup(cudnnTensorDescriptor_t* desc, const string& data_format, const std::vector<int64_t>& dims, const int64_t group);
+template <typename T>
 void cudnnSetTensor5dDesc(cudnnTensorDescriptor_t* desc, const string& data_format, const std::vector<int64_t>& dims);
 template <typename T>
......
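The definition of cudnnSetTensor4dDescWithGroup is not part of this diff; a plausible NCHW-only reading, matching the old fake_dims[1] /= group path it replaces (the function name, the float-only type, and the free-function form below are assumptions):

    #include <cudnn.h>
    #include <cstdint>
    #include <vector>

    // Hypothetical sketch: like cudnnSetTensor4dDesc, but with the channel
    // dimension divided by 'group' before the descriptor is filled.
    void SetTensor4dDescWithGroupNCHW(cudnnTensorDescriptor_t desc,
                                      const std::vector<int64_t>& dims,
                                      int64_t group) {
        cudnnSetTensor4dDescriptor(desc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT,
                                   static_cast<int>(dims[0]),
                                   static_cast<int>(dims[1] / group),
                                   static_cast<int>(dims[2]),
                                   static_cast<int>(dims[3]));
    }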
@@ -156,29 +156,39 @@ class Tensor(object):
         """
         return self.Normal(mu=mean, sigma=std)

-    def Xavier(self):
+    def Xavier(self, scale=3.0):
         """
         Register as a variable with xavier initializer.
         """
-        return self._no_parameter_filler('xavier')
+        filler = pb.TensorFiller()
+        filler.tensor = self.name
+        filler.type = 'xavier'
+        filler.scale = scale
+        ws.CreateFiller(filler)
+        return self

-    def MSRA(self):
+    def MSRA(self, scale=2.0):
         """
         Register as a variable with msra initializer.
         """
-        return self._no_parameter_filler('msra')
+        filler = pb.TensorFiller()
+        filler.tensor = self.name
+        filler.type = 'msra'
+        filler.scale = scale
+        ws.CreateFiller(filler)
+        return self

-    def GlorotUniform(self):
+    def GlorotUniform(self, scale=3.0):
         """
         Register as a variable with glorot uniform initializer.
         """
-        return self.Xavier()
+        return self.Xavier(scale)

-    def GlorotNormal(self):
+    def GlorotNormal(self, scale=2.0):
         """
         Register as a variable with glorot normal initializer.
         """
-        return self.MSRA()
+        return self.MSRA(scale)

     ##############################################
     #                                            #
......
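The filler bodies are not shown here; under the common definitions of these initializers, the new scale argument maps to the distributions sketched below (an assumption, not Dragon's filler code): xavier draws from U(-sqrt(scale/fan_in), +sqrt(scale/fan_in)) and msra from N(0, sqrt(scale/fan_in)), so the new defaults of 3.0 and 2.0 recover the usual Glorot-uniform and He-normal variants.

    #include <cmath>
    #include <cstdio>
    #include <random>

    int main() {
        const int fan_in = 256;  // illustrative fan-in
        std::mt19937 rng(0);
        // xavier, scale = 3.0 (the new default above)
        float limit = std::sqrt(3.0f / fan_in);
        std::uniform_real_distribution<float> xavier(-limit, limit);
        // msra, scale = 2.0 (the new default above)
        float stddev = std::sqrt(2.0f / fan_in);
        std::normal_distribution<float> msra(0.0f, stddev);
        printf("xavier sample: %f, msra sample: %f\n", xavier(rng), msra(rng));
        return 0;
    }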
@@ -19,10 +19,18 @@ Installation - Linux (Normal, CPU)

 **Step 1:** Install C++ Dependencies

+**$** Setup Python Development Environment
+
 .. code-block:: shell

     sudo apt-get install libpython-dev

+**Note:** You can also use `Anaconda`_, a powerful toolkit for Data Science.
+
+**$** Setup C++ Development Environment
+
     sudo apt-get install libprotobuf-dev
+    sudo apt-get install protobuf-compiler
     sudo apt-get install libopenblas-dev

 **Step 2:** Install Python Requirements
@@ -83,10 +91,18 @@ Installation - Linux (Normal, GPU)

 **Step 2:** Install C++ Dependencies

+**$** Setup Python Development Environment
+
 .. code-block:: shell

     sudo apt-get install libpython-dev

+**Note:** You can also use `Anaconda`_, a powerful toolkit for Data Science.
+
+**$** Setup C++ Development Environment
+
     sudo apt-get install libprotobuf-dev
+    sudo apt-get install protobuf-compiler
     sudo apt-get install libopenblas-dev

 **Step 3:** Install Python Requirements
@@ -149,10 +165,18 @@ Installation - Linux (Distributed, CPU)

 **Step 2:** Install C++ Dependencies

+**$** Setup Python Development Environment
+
 .. code-block:: shell

     sudo apt-get install libpython-dev

+**Note:** You can also use `Anaconda`_, a powerful toolkit for Data Science.
+
+**$** Setup C++ Development Environment
+
     sudo apt-get install libprotobuf-dev
+    sudo apt-get install protobuf-compiler
     sudo apt-get install libopenblas-dev

 **Step 3:** Install Python Requirements
@@ -229,10 +253,18 @@ Installation - Linux (Distributed, GPU)

 **Step 3:** Install C++ Dependencies

+**$** Setup Python Development Environment
+
 .. code-block:: shell

     sudo apt-get install libpython-dev

+**Note:** You can also use `Anaconda`_, a powerful toolkit for Data Science.
+
+**$** Setup C++ Development Environment
+
     sudo apt-get install libprotobuf-dev
+    sudo apt-get install protobuf-compiler
     sudo apt-get install libopenblas-dev

 **Step 4:** Install Python Requirements
@@ -564,6 +596,7 @@ Add ``REPO_ROOT/3rdparty/bin`` to system environment variables

     python setup.py install --user

+.. _Anaconda: https://www.anaconda.com/download
 .. _CUDA: https://developer.nvidia.com/cuda-toolkit
 .. _CUDNN: https://developer.nvidia.com/cudnn
 .. _NCCL: https://developer.nvidia.com/nccl
......
@@ -673,6 +673,7 @@ def Reshape(inputs, shape, **kwargs):
         output.shape = [1] * len(shape)
         for i, s in enumerate(shape):
             if s == -1: output.shape[i] = 1
+            elif s == 0: output.shape[i] = inputs.shape[i]
             else: output.shape[i] = s
     return output
......
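The new 0 case copies the corresponding input dimension, while -1 stays a placeholder for deferred inference (tracked as 1 here). A self-contained restatement of the rule, with illustrative shapes:

    #include <cstdio>
    #include <vector>

    std::vector<long> InferReshape(const std::vector<long>& in,
                                   const std::vector<long>& shape) {
        std::vector<long> out(shape.size(), 1);
        for (size_t i = 0; i < shape.size(); ++i) {
            if (shape[i] == -1) out[i] = 1;          // deferred until run time
            else if (shape[i] == 0) out[i] = in[i];  // keep the input dim
            else out[i] = shape[i];
        }
        return out;
    }

    int main() {
        auto out = InferReshape({2, 3, 4}, {0, -1, 4});
        printf("%ld %ld %ld\n", out[0], out[1], out[2]);  // 2 1 4
        return 0;
    }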
@@ -189,7 +189,7 @@ def InstanceNorm(inputs, axis=-1, eps=1e-3, **kwargs):
     return output

-def L2Norm(inputs, axis=0, num_axes=-1, eps=1e-5, **kwargs):
+def L2Norm(inputs, axis=0, num_axes=-1, eps=1e-5, mode='SUM', **kwargs):
     """L2 Normalization, introduced by `[Liu et.al, 2015] <https://arxiv.org/abs/1506.04579>`_.

     Parameters
@@ -202,6 +202,8 @@ def L2Norm(inputs, axis=0, num_axes=-1, eps=1e-5, **kwargs):
         The number of axes of stats region. Default is ``-1`` (Till End).
     eps : float
         The eps.
+    mode : str
+        The mode for computing the normalizer, ``SUM`` or ``MEAN``.

     Returns
     -------
......
@@ -61,6 +61,12 @@ def Conv2d(inputs, num_output, kernel_size,
     """
     CheckInputs(inputs, 2, 3)
     arguments = ParseArguments(locals())
+    if padding not in ('VALID', 'SAME'):
+        raise ValueError('Unsupported padding algorithm: {}'.format(padding))
+    if data_format not in ('NCHW', 'NHWC'):
+        raise ValueError('Unsupported data format: {}'.format(data_format))
     if not isinstance(arguments['kernel_size'], list):
         arguments['kernel_size'] = [arguments['kernel_size']]
     if not isinstance(arguments['stride'], list):
@@ -154,6 +160,11 @@ def Conv2dTranspose(inputs, num_output, kernel_size,
     CheckInputs(inputs, 2, 3)
     arguments = ParseArguments(locals())
+    if padding not in ('VALID', 'SAME'):
+        raise ValueError('Unsupported padding algorithm: {}'.format(padding))
+    if data_format not in ('NCHW', 'NHWC'):
+        raise ValueError('Unsupported data format: {}'.format(data_format))
     arguments['output_shape'] = None
     if output_shape is not None:
         if not isinstance(output_shape, list):
@@ -170,17 +181,43 @@ def Conv2dTranspose(inputs, num_output, kernel_size,
     if not isinstance(arguments['kernel_size'], list):
         arguments['kernel_size'] = [arguments['kernel_size']]
     if not isinstance(arguments['stride'], list):
         arguments['stride'] = [arguments['stride']]
     if not isinstance(arguments['pad'], list):
         arguments['pad'] = [arguments['pad']]
     if not isinstance(arguments['dilation'], list):
         arguments['dilation'] = [arguments['dilation']]

-    return Tensor.CreateOperator(nout=1, op_type='Conv2dTranspose', **arguments)
+    output = Tensor.CreateOperator(nout=1, op_type='Conv2dTranspose', **arguments)
+
+    if inputs[0].shape is not None:
+        output.shape = inputs[0].shape[:]
+        channel_axis = 1 if data_format == 'NCHW' else -1
+        spatial_axis = 2 if data_format == 'NCHW' else 1
+        output.shape[channel_axis] = num_output
+        for i in xrange(2):
+            k = arguments['kernel_size'][i] if i < len(arguments['kernel_size']) \
+                else arguments['kernel_size'][-1]
+            s = arguments['stride'][i] if i < len(arguments['stride']) \
+                else arguments['stride'][-1]
+            p = arguments['pad'][i] if i < len(arguments['pad']) \
+                else arguments['pad'][-1]
+            d = arguments['dilation'][i] if i < len(arguments['dilation']) \
+                else arguments['dilation'][-1]
+            dk = d * (k - 1) + 1
+            dp = 2 * p
+            input_size = output.shape[i + spatial_axis]
+            if padding != 'SAME':
+                output.shape[i + spatial_axis] = s * (input_size - 1) + dk - dp
+            else:
+                if output_shape is None:
+                    raise ValueError('The output shape must be specified if using SAME padding algorithm.')
+                if 'dynamic_dsize' in arguments:
+                    output.shape = None
+                    return output
+                output.shape[i + spatial_axis] = output_shape[i + spatial_axis]
+
+    return output

 def Pool2d(inputs, kernel_size, stride, pad=0, padding='VALID',
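A worked instance of the non-SAME shape rule added to Conv2dTranspose above, out = stride * (in - 1) + dilated_kernel - 2 * pad, with illustrative sizes:

    #include <cstdio>

    int main() {
        int in = 7, k = 4, s = 2, p = 1, d = 1;
        int dk = d * (k - 1) + 1;             // dilated kernel extent: 4
        int out = s * (in - 1) + dk - 2 * p;  // 2 * 6 + 4 - 2 = 14
        printf("deconv output dim: %d\n", out);
        return 0;
    }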
@@ -222,6 +259,14 @@ def Pool2d(inputs, kernel_size, stride, pad=0, padding='VALID',
     """
     CheckInputs(inputs, 1)
     arguments = ParseArguments(locals())
+    if mode not in ('MAX', 'AVG'):
+        raise ValueError('Unsupported pooling mode: {}'.format(mode))
+    if padding not in ('VALID', 'SAME'):
+        raise ValueError('Unsupported padding algorithm: {}'.format(padding))
+    if data_format not in ('NCHW', 'NHWC'):
+        raise ValueError('Unsupported data format: {}'.format(data_format))
     if not isinstance(arguments['kernel_size'], list):
         arguments['kernel_size'] = [arguments['kernel_size']]
     if not isinstance(arguments['stride'], list):
@@ -311,7 +356,8 @@ def ROIAlign(inputs, pool_h=0, pool_w=0, spatial_scale=1.0, **kwargs):
     return Tensor.CreateOperator(nout=1, op_type='ROIAlign', **arguments)
-def LRN(inputs, local_size=5, alpha=0.0001, beta=0.75, k=2.0, mode='ACROSS_CHANNELS', **kwargs):
+def LRN(inputs, local_size=5, alpha=0.0001, beta=0.75, k=2.0,
+        mode='ACROSS_CHANNELS', data_format='NCHW', **kwargs):
     """Local Response Normalization, introduced by `[Krizhevsky et.al, 2012] <http://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networks>`_.

     Parameters
@@ -328,17 +374,22 @@ def LRN(inputs, local_size=5, alpha=0.0001, beta=0.75, k=2.0,
         The k of LRN.
     mode : str
         The mode, ``ACROSS_CHANNELS`` or ``WITHIN_CHANNEL``.
+    data_format : str
+        The data format, ``NCHW`` or ``NHWC``.

     Returns
     -------
     Tensor
-        The normalized tensor.
+        The output tensor.

     """
     CheckInputs(inputs, 1)
     arguments = ParseArguments(locals())
-    SUPPORT_MODES = {'ACROSS_CHANNELS': 0, 'WITHIN_CHANNEL': 1}
-    arguments['mode'] = SUPPORT_MODES[mode]
+    if mode not in ('ACROSS_CHANNELS', 'WITHIN_CHANNEL'):
+        raise ValueError('Unsupported lrn mode: {}'.format(mode))
+    if data_format not in ('NCHW', 'NHWC'):
+        raise ValueError('Unsupported data format: {}'.format(data_format))

     output = Tensor.CreateOperator(nout=1, op_type='LRN', **arguments)
@@ -356,9 +407,9 @@ def NNResize(inputs, dsize, fy=-1.0, fx=-1.0, data_format='NCHW', **kwargs):
     Parameters
     ----------
     inputs : Tensor
-        The input tenosr.
+        The input tensor.
     dsize : tuple, list, Tensor or None
-        The output size.
+        The output size, formatted as (h, w).
     fy : float
         The scale factor based on src height. Default is ``-1.0`` (Discarded).
     fx : float
@@ -374,6 +425,10 @@ def NNResize(inputs, dsize, fy=-1.0, fx=-1.0, data_format='NCHW', **kwargs):
     """
     CheckInputs(inputs, 1)
     arguments = ParseArguments(locals())
+    if data_format not in ('NCHW', 'NHWC'):
+        raise ValueError('Unsupported data format: {}'.format(data_format))
     if arguments['dsize'] is not None:
         if isinstance(arguments['dsize'][0], Tensor):
             arguments['dynamic_dsize'] = [arguments['dsize'][0].name,
@@ -388,6 +443,20 @@ def NNResize(inputs, dsize, fy=-1.0, fx=-1.0, data_format='NCHW', **kwargs):
     output = Tensor.CreateOperator(nout=1, op_type='NNResize', **arguments)

+    if inputs.shape is not None:
+        if len(inputs.shape) != 4:
+            raise ValueError('The inputs should be a 4d Tensor.')
+        if 'dynamic_dsize' not in arguments:
+            output.shape = inputs.shape[:]
+            spatial_axis = 2 if data_format == 'NCHW' else 1
+            for i in xrange(2):
+                output_dim = output.shape[spatial_axis + i]
+                if 'static_size' in arguments:
+                    output_dim = dsize[i]
+                else:
+                    output_dim = int(float(output_dim) * ([fy, fx])[i])
+                output.shape[spatial_axis + i] = output_dim
+
     return output
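Restating the inference rule added above (it is duplicated verbatim for BilinearResize below): a statically known dsize wins, otherwise each spatial dimension is scaled by (fy, fx), and a tensor-valued dsize leaves the shape unknown. A tiny worked example with illustrative sizes:

    #include <cstdio>

    int main() {
        int h = 32, w = 48;           // input spatial dims
        float fy = 0.5f, fx = 2.0f;   // per-axis scale factors
        int out_h = (int)((float)h * fy);  // 16
        int out_w = (int)((float)w * fx);  // 96
        printf("resized: %d x %d\n", out_h, out_w);
        return 0;
    }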
@@ -399,9 +468,9 @@ def BilinearResize(inputs, dsize, fy=-1.0, fx=-1.0, data_format='NCHW', **kwargs):
     Parameters
     ----------
     inputs : Tensor
-        The input tenosr.
+        The input tensor.
     dsize : tuple, list, Tensor or None
-        The dest output size.
+        The output size, formatted as (h, w).
     fy : float
         The scale factor based on src height. Default is ``-1.0`` (Discarded).
     fx : float
@@ -417,6 +486,10 @@ def BilinearResize(inputs, dsize, fy=-1.0, fx=-1.0, data_format='NCHW', **kwargs):
     """
     CheckInputs(inputs, 1)
     arguments = ParseArguments(locals())
+    if data_format not in ('NCHW', 'NHWC'):
+        raise ValueError('Unsupported data format: {}'.format(data_format))
     if arguments['dsize'] is not None:
         if isinstance(arguments['dsize'][0], Tensor):
             arguments['dynamic_dsize'] = [arguments['dsize'][0].name,
@@ -431,6 +504,20 @@ def BilinearResize(inputs, dsize, fy=-1.0, fx=-1.0, data_format='NCHW', **kwargs):
     output = Tensor.CreateOperator(nout=1, op_type='BilinearResize', **arguments)

+    if inputs.shape is not None:
+        if len(inputs.shape) != 4:
+            raise ValueError('The inputs should be a 4d Tensor.')
+        if 'dynamic_dsize' not in arguments:
+            output.shape = inputs.shape[:]
+            spatial_axis = 2 if data_format == 'NCHW' else 1
+            for i in xrange(2):
+                output_dim = output.shape[spatial_axis + i]
+                if 'static_size' in arguments:
+                    output_dim = dsize[i]
+                else:
+                    output_dim = int(float(output_dim) * ([fy, fx])[i])
+                output.shape[spatial_axis + i] = output_dim
+
     return output
@@ -453,6 +540,9 @@ def BiasAdd(inputs, data_format='NCHW', **kwargs):
     CheckInputs(inputs, 2)
     arguments = ParseArguments(locals())
+    if data_format not in ('NCHW', 'NHWC'):
+        raise ValueError('Unsupported data format: {}'.format(data_format))

     output = Tensor.CreateOperator(nout=1, op_type='BiasAdd', **arguments)

     if inputs[0].shape is not None:
......
@@ -229,7 +229,9 @@ class LRNLayer(Layer):
         self._param = {'local_size': param.local_size,
                        'alpha': param.alpha,
                        'beta': param.beta,
-                       'mode': {0: 'ACROSS_CHANNELS', 1: 'WITHIN_CHANNEL'}[param.norm_region]}
+                       'mode': {0: 'ACROSS_CHANNELS', 1: 'WITHIN_CHANNEL'}[param.norm_region],
+                       'data_format': 'NCHW'}

     def Setup(self, bottom):
         super(LRNLayer, self).Setup(bottom)
         input = bottom[0] if isinstance(bottom, list) else bottom
......
@@ -261,7 +261,7 @@ GraphDef Graph::MakeUpdate(const GraphDef& meta_graph) {
     } else if (this->args_["parallel_mode"].s() == "MIXED") {
         /*
             See: Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour
-            Links: http://arxiv.org/abs/1706.02677
+            Link: http://arxiv.org/abs/1706.02677
         */
         NOT_IMPLEMENTED;
     }
......
@@ -150,6 +150,8 @@ void Operator<Context>::CleanResource() {
             if (output(i)->memory() != buffer->memory()) buffer->Move(output(i)->memory());
         }
     }
+
+    //  post-process for sharing grads
     if (allow_share_grads_) {
         //  TODO(PhyscalX): we preset input(-1)->output(0) to share
         Tensor* dY = &input(-1);
......
@@ -30,6 +30,7 @@ void L2NormOp<Context>::RunWithType() {
         if (across_inner) {
             auto* Ndata_ = norm->template mutable_data<float, CPUContext>();
             float sum_of_sqr = math::Dot<T, Context>(buffer->count(), Xdata, Xdata);
+            if (mode == "MEAN") sum_of_sqr = sum_of_sqr / dim;
             Ndata_[n] = pow(sum_of_sqr + eps, 0.5);
             math::Scale<T, Context>(buffer->count(), 1.0 / Ndata_[n], Xdata, Ydata);
         } else {
@@ -37,7 +38,7 @@ void L2NormOp<Context>::RunWithType() {
             math::Square<T, Context>(buffer->count(), Xdata, Bdata);
             //  compute T1 = \sum_{i} x_{i,j}^{2}
             math::Gemv<T, Context>(CblasTrans, dim, inner_dim,
-                                   1.0,
+                                   mode == "MEAN" ? 1.0 / dim : 1.0,
                                    Bdata, DMuldata,
                                    1.0,
                                    Ndata);
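A worked comparison of the two normalizers wired in above: SUM divides by sqrt(sum_i x_i^2 + eps), while MEAN first divides the sum of squares by dim.

    #include <cmath>
    #include <cstdio>

    int main() {
        const float x[4] = {1.f, 2.f, 2.f, 4.f};
        const float eps = 1e-5f;
        float ss = 0.f;
        for (float v : x) ss += v * v;               // 25
        float norm_sum  = std::sqrt(ss + eps);       // ~5.0
        float norm_mean = std::sqrt(ss / 4 + eps);   // ~2.5
        printf("SUM: %.3f  MEAN: %.3f\n", norm_sum, norm_mean);
        return 0;
    }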
......
@@ -15,51 +15,33 @@ void CuDNNConv2dOp<Context>::RunWithType() {
     CUDNN_CHECK(cudnnSetFilter4dDescriptor(filter_desc,
                                            CUDNNType<T>::type,
                                            format,
-                                           this->num_output / this->group,
+                                           this->num_output / cudnn_group,
                                            this->channels / this->group,
                                            this->kernel_size[0], this->kernel_size[1]));
 #else
     CUDNN_CHECK(cudnnSetFilter4dDescriptor_v4(filter_desc,
                                               CUDNNType<T>::type,
                                               format,
-                                              this->num_output / this->group,
+                                              this->num_output / cudnn_group,
                                               this->channels / this->group,
                                               this->kernel_size[0], this->kernel_size[1]));
 #endif
-    Tensor fake_tensor;
-    vector<TIndex> fake_dims;
-    if (this->data_format == "NCHW") {
-        //  determine the input shape
-        fake_tensor.ReshapeLike(input(0));
-        fake_dims = fake_tensor.dims();
-        fake_dims[1] /= this->group;
-        cudnnSetTensor4dDesc<T>(&input_desc, this->data_format, fake_dims);
-        //  determine the output shape
-        fake_tensor.ReshapeLike(*output(0));
-        fake_dims = fake_tensor.dims();
-        fake_dims[1] /= this->group;
-        cudnnSetTensor4dDesc<T>(&output_desc, this->data_format, fake_dims);
-        //  determine the bias shape if necessary
+    //  determine the input & output shape
+    cudnnSetTensor4dDescWithGroup<T>(&input_desc, this->data_format, input(0).dims(), cudnn_group);
+    cudnnSetTensor4dDescWithGroup<T>(&output_desc, this->data_format, output(0)->dims(), cudnn_group);
+    //  determine the bias shape and misc
     if (HasBias()) {
-        bias_offset = this->num_output / this->group;
+        bias_offset = this->num_output / cudnn_group;
+        if (this->data_format == "NCHW") {
             cudnnSetTensor4dDesc<T>(&bias_desc, this->data_format, vector<TIndex>({ 1, bias_offset, 1, 1 }));
-        }
+            this->x_offset = input(0).count(1) / cudnn_group;
+            this->y_offset = output(0)->count(1) / cudnn_group;
     } else if (this->data_format == "NHWC") {
-        //  determine the input shape
-        fake_tensor.ReshapeLike(input(0));
-        fake_dims = fake_tensor.dims();
-        fake_dims[fake_dims.size() - 1] /= this->group;
-        cudnnSetTensor4dDesc<T>(&input_desc, this->data_format, fake_dims);
-        //  determine the output shape
-        fake_tensor.ReshapeLike(*output(0));
-        fake_dims = fake_tensor.dims();
-        fake_dims[fake_dims.size() - 1] /= this->group;
-        cudnnSetTensor4dDesc<T>(&output_desc, this->data_format, fake_dims);
-        //  determine the bias shape if necessary
-        if (HasBias()) {
-            bias_offset = this->num_output / this->group;
             cudnnSetTensor4dDesc<T>(&bias_desc, this->data_format, vector<TIndex>({ 1, 1, 1, bias_offset }));
+            this->x_offset = input(0).dim(-1) / cudnn_group;
+            this->y_offset = output(0)->dim(-1) / cudnn_group;
         }
     }
@@ -82,7 +64,7 @@ void CuDNNConv2dOp<Context>::RunWithType() {
     Tensor* buffer = ws()->GetBuffer();
     if (workspace_fwd_data_size == 0) workspace_fwd_data_size += 1;
-    buffer->Reshape(vector<TIndex>(1, this->group * workspace_fwd_data_size));
+    buffer->Reshape(vector<TIndex>(1, cudnn_group * workspace_fwd_data_size));

     auto* Xdata = input(0).template data<T, Context>();
     auto* Ydata = output(0)->template mutable_data<T, Context>();
@@ -90,7 +72,7 @@ void CuDNNConv2dOp<Context>::RunWithType() {
     auto* Wdata = input(1).template data<T, Context>();
     if (HasBias()) TENSOR_FILL(input(2), this->bias_shape);

-    for (int g = 0; g < this->group; g++) {
+    for (int g = 0; g < cudnn_group; g++) {
         auto* workspace = buffer->template mutable_data<char, Context>();
         CUDNN_CHECK(cudnnConvolutionForward(handle[g],
             CUDNNType<T>::one, input_desc, Xdata + this->x_offset * g,
@@ -117,8 +99,6 @@ void CuDNNConv2dOp<Context>::RunOnDevice() {
         if (this->dilation[i] != 1) return Conv2dOp<Context>::RunOnDevice();
 #endif
     Conv2dOp<Context>::Reshape();
-    this->x_offset /= this->group;
-    this->y_offset /= this->group;

     if (input(0).template IsType<float>()) {
 #if CUDNN_VERSION_MIN(6, 0, 0)
@@ -135,6 +115,9 @@ void CuDNNConv2dOp<Context>::RunOnDevice() {
                                                    1, 1,
                                                    CUDNN_CROSS_CORRELATION));
 #endif
+#if CUDNN_VERSION_MIN(7, 0, 0)
+        CUDNN_CHECK(cudnnSetConvolutionGroupCount(conv_desc, this->group));
+#endif
         RunWithType<float>();
     } else if (input(0).template IsType<float16>()) {
 #ifdef WITH_CUDA_FP16
@@ -152,6 +135,9 @@ void CuDNNConv2dOp<Context>::RunOnDevice() {
                                                    1, 1,
                                                    CUDNN_CROSS_CORRELATION));
 #endif
+#if CUDNN_VERSION_MIN(7, 0, 0)
+        CUDNN_CHECK(cudnnSetConvolutionGroupCount(conv_desc, this->group));
+#endif
         RunWithType<float16>();
 #endif  // WITH_CUDA_FP16
     } else { LOG(FATAL) << "Unsupported input types."; }
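A worked sketch of the offset arithmetic after this change (NCHW, illustrative sizes): with cuDNN >= 7, cudnn_group is 1, so x_offset spans the whole sample and cudnnSetConvolutionGroupCount lets cuDNN slice the groups internally; pre-7, the offsets step one group at a time across per-group handles.

    #include <cstdio>

    int main() {
        int channels = 64, group = 4, h = 28, w = 28;
        int sample = channels * h * w;
        int offset_cudnn7 = sample / 1;      // cudnn_group == 1: full sample
        int offset_legacy = sample / group;  // pre-7: one group per handle
        printf("cuDNN7 offset: %d, legacy offset: %d\n",
               offset_cudnn7, offset_legacy);
        return 0;
    }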
@@ -165,51 +151,33 @@ void CuDNNConv2dGradientOp<Context>::RunWithType() {
     CUDNN_CHECK(cudnnSetFilter4dDescriptor(filter_desc,
                                            CUDNNType<T>::type,
                                            format,
-                                           this->num_output / this->group,
+                                           this->num_output / cudnn_group,
                                            this->channels / this->group,
                                            this->kernel_size[0], this->kernel_size[1]));
 #else
     CUDNN_CHECK(cudnnSetFilter4dDescriptor_v4(filter_desc,
                                               CUDNNType<T>::type,
                                               format,
-                                              this->num_output / this->group,
+                                              this->num_output / cudnn_group,
                                               this->channels / this->group,
                                               this->kernel_size[0], this->kernel_size[1]));
 #endif
-    Tensor fake_tensor;
-    vector<TIndex> fake_dims;
-    if (this->data_format == "NCHW") {
-        //  determine the input shape
-        fake_tensor.ReshapeLike(input(-1));
-        fake_dims = fake_tensor.dims();
-        fake_dims[1] /= this->group;
-        cudnnSetTensor4dDesc<T>(&input_desc, this->data_format, fake_dims);
-        //  determine the output shape
-        fake_tensor.ReshapeLike(input(0));
-        fake_dims = fake_tensor.dims();
-        fake_dims[1] /= this->group;
-        cudnnSetTensor4dDesc<T>(&output_desc, this->data_format, fake_dims);
-        //  determine the bias shape if necessary
+    //  determine the input & output shape
+    cudnnSetTensor4dDescWithGroup<T>(&input_desc, this->data_format, input(-1).dims(), cudnn_group);
+    cudnnSetTensor4dDescWithGroup<T>(&output_desc, this->data_format, input(0).dims(), cudnn_group);
+    //  determine the bias shape and misc
     if (HasBias()) {
-        bias_offset = this->num_output / this->group;
+        bias_offset = this->num_output / cudnn_group;
+        if (this->data_format == "NCHW") {
             cudnnSetTensor4dDesc<T>(&bias_desc, this->data_format, vector<TIndex>({ 1, bias_offset, 1, 1 }));
-        }
+            this->x_offset = input(0).count(1) / cudnn_group;
+            this->y_offset = input(-1).count(1) / cudnn_group;
     } else if (this->data_format == "NHWC") {
-        //  determine the input shape
-        fake_tensor.ReshapeLike(input(-1));
-        fake_dims = fake_tensor.dims();
-        fake_dims[fake_dims.size() - 1] /= this->group;
-        cudnnSetTensor4dDesc<T>(&input_desc, this->data_format, fake_dims);
-        //  determine the output shape
-        fake_tensor.ReshapeLike(input(0));
-        fake_dims = fake_tensor.dims();
-        fake_dims[fake_dims.size() - 1] /= this->group;
-        cudnnSetTensor4dDesc<T>(&output_desc, this->data_format, fake_dims);
-        //  determine the bias shape if necessary
-        if (HasBias()) {
-            bias_offset = this->num_output / this->group;
             cudnnSetTensor4dDesc<T>(&bias_desc, this->data_format, vector<TIndex>({ 1, 1, 1, bias_offset }));
+            this->x_offset = input(0).dim(-1) / cudnn_group;
+            this->y_offset = input(-1).dim(-1) / cudnn_group;
         }
     }
@@ -251,11 +219,11 @@ void CuDNNConv2dGradientOp<Context>::RunWithType() {
     Tensor* buffer2 = ws()->GetBuffer();
     if (workspace_bwd_data_size == 0) workspace_bwd_data_size += 1;
     if (workspace_bwd_filter_size == 0) workspace_bwd_filter_size += 1;
-    buffer1->Reshape(vector<TIndex>(1, this->group * workspace_bwd_data_size));
-    buffer2->Reshape(vector<TIndex>(1, this->group * workspace_bwd_filter_size));
+    buffer1->Reshape(vector<TIndex>(1, cudnn_group * workspace_bwd_data_size));
+    buffer2->Reshape(vector<TIndex>(1, cudnn_group * workspace_bwd_filter_size));

     const T* dYdata = input(2).template data<T, Context>();
-    for (int g = 0; g < this->group; g++) {
+    for (int g = 0; g < cudnn_group; g++) {
         if (output(2)->name() != "ignore") {
             T* dBdata = output(2)->template mutable_data<T, Context>();
             CUDNN_CHECK(cudnnConvolutionBackwardBias(handle[g],
@@ -266,7 +234,7 @@ void CuDNNConv2dGradientOp<Context>::RunWithType() {
             auto* Xdata = input(0).template data<T, Context>();
             auto* dWdata = output(1)->template mutable_data<T, Context>();
             auto* workspace = buffer2->mutable_data<char, Context>();
-            CUDNN_CHECK(cudnnConvolutionBackwardFilter(handle[1 * this->group + g],
+            CUDNN_CHECK(cudnnConvolutionBackwardFilter(handle[1 * cudnn_group + g],
                 CUDNNType<T>::one, output_desc, Xdata + this->x_offset * g,
                 input_desc, dYdata + this->y_offset * g,
                 conv_desc,
@@ -278,7 +246,7 @@ void CuDNNConv2dGradientOp<Context>::RunWithType() {
             auto* Wdata = input(1).template data<T, Context>();
             auto* dXdata = output(0)->template mutable_data<T, Context>();
             auto* workspace = buffer1->mutable_data<char, Context>();
-            CUDNN_CHECK(cudnnConvolutionBackwardData(handle[2 * this->group + g],
+            CUDNN_CHECK(cudnnConvolutionBackwardData(handle[2 * cudnn_group + g],
                 CUDNNType<T>::one, filter_desc, Wdata + this->weight_offset * g,
                 input_desc, dYdata + this->y_offset * g,
                 conv_desc,
@@ -299,8 +267,6 @@ void CuDNNConv2dGradientOp<Context>::RunOnDevice() {
         if (this->dilation[i] != 1) return Conv2dGradientOp<Context>::RunOnDevice();
 #endif
     Conv2dGradientOp<Context>::GradientReshape();
-    this->x_offset /= this->group;
-    this->y_offset /= this->group;

     if (input(0).template IsType<float>()) {
 #if CUDNN_VERSION_MIN(6, 0, 0)
@@ -317,6 +283,9 @@ void CuDNNConv2dGradientOp<Context>::RunOnDevice() {
                                                    1, 1,
                                                    CUDNN_CROSS_CORRELATION));
 #endif
+#if CUDNN_VERSION_MIN(7, 0, 0)
+        CUDNN_CHECK(cudnnSetConvolutionGroupCount(conv_desc, this->group));
+#endif
         RunWithType<float>();
     } else if (input(0).template IsType<float16>()) {
 #ifdef WITH_CUDA_FP16
@@ -333,6 +302,9 @@ void CuDNNConv2dGradientOp<Context>::RunOnDevice() {
                                                    this->stride[0], this->stride[1],
                                                    1, 1, CUDNN_CROSS_CORRELATION));
 #endif
+#if CUDNN_VERSION_MIN(7, 0, 0)
+        CUDNN_CHECK(cudnnSetConvolutionGroupCount(conv_desc, this->group));
+#endif
         RunWithType<float16>();
 #endif  // WITH_CUDA_FP16
     } else { LOG(FATAL) << "Unsupported input types."; }
......
@@ -15,51 +15,34 @@ void CuDNNConv2dTransposeOp<Context>::RunWithType() {
     CUDNN_CHECK(cudnnSetFilter4dDescriptor(filter_desc,
                                            CUDNNType<T>::type,
                                            format,
-                                           this->num_output / this->group,
+                                           this->num_output / cudnn_group,
                                            this->channels / this->group,
                                            this->kernel_size[0], this->kernel_size[1]));
 #else
     CUDNN_CHECK(cudnnSetFilter4dDescriptor_v4(filter_desc,
                                               CUDNNType<T>::type,
                                               format,
-                                              this->num_output / this->group,
+                                              this->num_output / cudnn_group,
                                               this->channels / this->group,
                                               this->kernel_size[0], this->kernel_size[1]));
 #endif
-    Tensor fake_tensor;
-    vector<TIndex> fake_dims;
-    if (this->data_format == "NCHW") {
-        //  determine the input shape
-        fake_tensor.ReshapeLike(input(0));
-        fake_dims = fake_tensor.dims();
-        fake_dims[1] /= this->group;
-        cudnnSetTensor4dDesc<T>(&input_desc, this->data_format, fake_dims);
-        //  determine the output shape
-        fake_tensor.ReshapeLike(*output(0));
-        fake_dims = fake_tensor.dims();
-        fake_dims[1] /= this->group;
-        cudnnSetTensor4dDesc<T>(&output_desc, this->data_format, fake_dims);
-        //  determine the bias shape if necessary
+    //  determine the input & output shape
+    cudnnSetTensor4dDescWithGroup<T>(&input_desc, this->data_format, input(0).dims(), cudnn_group);
+    cudnnSetTensor4dDescWithGroup<T>(&output_desc, this->data_format, output(0)->dims(), cudnn_group);
+    //  determine the bias shape and misc
     if (HasBias()) {
-        bias_offset = this->num_output / this->group;
+        bias_offset = this->num_output / cudnn_group;
+        if (this->data_format == "NCHW") {
             cudnnSetTensor4dDesc<T>(&bias_desc, this->data_format, vector<TIndex>({ 1, bias_offset, 1, 1 }));
+            this->x_offset = input(0).count(1) / cudnn_group;
+            this->y_offset = output(0)->count(1) / cudnn_group;
         }
-    } else if (this->data_format == "NHWC") {
-        //  determine the input shape
-        fake_tensor.ReshapeLike(input(0));
-        fake_dims = fake_tensor.dims();
-        fake_dims[fake_dims.size() - 1] /= this->group;
-        cudnnSetTensor4dDesc<T>(&input_desc, this->data_format, fake_dims);
-        //  determine the output shape
-        fake_tensor.ReshapeLike(*output(0));
-        fake_dims = fake_tensor.dims();
-        fake_dims[fake_dims.size() - 1] /= this->group;
-        cudnnSetTensor4dDesc<T>(&output_desc, this->data_format, fake_dims);
-        //  determine the bias shape if necessary
-        if (HasBias()) {
-            bias_offset = this->num_output / this->group;
+        else if (this->data_format == "NHWC") {
             cudnnSetTensor4dDesc<T>(&bias_desc, this->data_format, vector<TIndex>({ 1, 1, 1, bias_offset }));
+            this->x_offset = input(0).dim(-1) / cudnn_group;
+            this->y_offset = output(0)->dim(-1) / cudnn_group;
         }
     }
@@ -82,7 +65,7 @@ void CuDNNConv2dTransposeOp<Context>::RunWithType() {
     Tensor* buffer = ws()->GetBuffer();
     if (workspace_fwd_data_size == 0) workspace_fwd_data_size += 1;
-    buffer->Reshape(vector<TIndex>(1, this->group * workspace_fwd_data_size));
+    buffer->Reshape(vector<TIndex>(1, cudnn_group * workspace_fwd_data_size));

     auto* Xdata = input(0).template data<T, Context>();
     auto* Ydata = output(0)->template mutable_data<T, Context>();
@@ -90,7 +73,7 @@ void CuDNNConv2dTransposeOp<Context>::RunWithType() {
     auto* Wdata = input(1).template data<T, Context>();
     if (HasBias()) TENSOR_FILL(input(2), this->bias_shape);

-    for (int g = 0; g < this->group; g++) {
+    for (int g = 0; g < cudnn_group; g++) {
         auto* workspace = buffer->template mutable_data<char, Context>();
         CUDNN_CHECK(cudnnConvolutionBackwardData(handle[g],
             CUDNNType<T>::one, filter_desc, Wdata + this->weight_offset * g,
@@ -118,8 +101,6 @@ void CuDNNConv2dTransposeOp<Context>::RunOnDevice() {
         if (this->dilation[i] != 1) return Conv2dTransposeOp<Context>::RunOnDevice();
 #endif
     Conv2dTransposeOp<Context>::Reshape();
-    this->x_offset /= this->group;
-    this->y_offset /= this->group;

     if (input(0).template IsType<float>()) {
 #if CUDNN_VERSION_MIN(6, 0, 0)
@@ -136,6 +117,9 @@ void CuDNNConv2dTransposeOp<Context>::RunOnDevice() {
                                                    1, 1,
                                                    CUDNN_CROSS_CORRELATION));
 #endif
+#if CUDNN_VERSION_MIN(7, 0, 0)
+        CUDNN_CHECK(cudnnSetConvolutionGroupCount(conv_desc, this->group));
+#endif
         RunWithType<float>();
     } else if (input(0).template IsType<float16>()) {
 #ifdef WITH_CUDA_FP16
@@ -153,6 +137,9 @@ void CuDNNConv2dTransposeOp<Context>::RunOnDevice() {
                                                    1, 1,
                                                    CUDNN_CROSS_CORRELATION));
 #endif
+#if CUDNN_VERSION_MIN(7, 0, 0)
+        CUDNN_CHECK(cudnnSetConvolutionGroupCount(conv_desc, this->group));
+#endif
         RunWithType<float16>();
 #endif  // WITH_CUDA_FP16
     } else { LOG(FATAL) << "Unsupported input types."; }
...@@ -166,51 +153,34 @@ void CuDNNConv2dTransposeGradientOp<Context>::RunWithType() { ...@@ -166,51 +153,34 @@ void CuDNNConv2dTransposeGradientOp<Context>::RunWithType() {
CUDNN_CHECK(cudnnSetFilter4dDescriptor(filter_desc, CUDNN_CHECK(cudnnSetFilter4dDescriptor(filter_desc,
CUDNNType<T>::type, CUDNNType<T>::type,
format, format,
this->num_output / this->group, this->num_output / cudnn_group,
this->channels / this->group, this->channels / this->group,
this->kernel_size[0], this->kernel_size[1])); this->kernel_size[0], this->kernel_size[1]));
#else #else
CUDNN_CHECK(cudnnSetFilter4dDescriptor_v4(filter_desc, CUDNN_CHECK(cudnnSetFilter4dDescriptor_v4(filter_desc,
CUDNNType<T>::type, CUDNNType<T>::type,
format, format,
this->num_output / this->group, this->num_output / cudnn_group,
this->channels / this->group, this->channels / this->group,
this->kernel_size[0], this->kernel_size[1])); this->kernel_size[0], this->kernel_size[1]));
#endif #endif
-    Tensor fake_tensor;
-    vector<TIndex> fake_dims;
-    if (this->data_format == "NCHW") {
-        // determine the input shape
-        fake_tensor.ReshapeLike(input(-1));
-        fake_dims = fake_tensor.dims();
-        fake_dims[1] /= this->group;
-        cudnnSetTensor4dDesc<T>(&input_desc, this->data_format, fake_dims);
-        // determine the output shape
-        fake_tensor.ReshapeLike(input(0));
-        fake_dims = fake_tensor.dims();
-        fake_dims[1] /= this->group;
-        cudnnSetTensor4dDesc<T>(&output_desc, this->data_format, fake_dims);
-        // determine the bias shape if necessary
+    // determine the input & output shape
+    cudnnSetTensor4dDescWithGroup<T>(&input_desc, this->data_format, input(-1).dims(), cudnn_group);
+    cudnnSetTensor4dDescWithGroup<T>(&output_desc, this->data_format, input(0).dims(), cudnn_group);
+    // determine the bias shape and misc
    if (HasBias()) {
-        bias_offset = this->num_output / this->group;
+        bias_offset = this->num_output / cudnn_group;
+        if (this->data_format == "NCHW") {
            cudnnSetTensor4dDesc<T>(&bias_desc, this->data_format, vector<TIndex>({ 1, bias_offset, 1, 1 }));
+            this->x_offset = input(0).count(1) / cudnn_group;
+            this->y_offset = input(-1).count(1) / cudnn_group;
        }
-    } else if (this->data_format == "NHWC") {
-        // determine the input shape
-        fake_tensor.ReshapeLike(input(-1));
-        fake_dims = fake_tensor.dims();
-        fake_dims[fake_dims.size() - 1] /= this->group;
-        cudnnSetTensor4dDesc<T>(&input_desc, this->data_format, fake_dims);
-        // determine the output shape
-        fake_tensor.ReshapeLike(input(0));
-        fake_dims = fake_tensor.dims();
-        fake_dims[fake_dims.size() - 1] /= this->group;
-        cudnnSetTensor4dDesc<T>(&output_desc, this->data_format, fake_dims);
-        // determine the bias shape if necessary
-        if (HasBias()) {
-            bias_offset = this->num_output / this->group;
+        else if (this->data_format == "NHWC") {
            cudnnSetTensor4dDesc<T>(&bias_desc, this->data_format, vector<TIndex>({ 1, 1, 1, bias_offset }));
+            this->x_offset = input(0).dim(-1) / cudnn_group;
+            this->y_offset = input(-1).dim(-1) / cudnn_group;
        }
    }
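The x_offset/y_offset values assigned above are per-group element strides for the pointer stepping in the compute loop below. A small sketch of the arithmetic (hypothetical helper, not repo code), for an NCHW input of shape (N, C, H, W):

// Where group g's slice begins, matching `Xdata + x_offset * g` below.
// NCHW: count(1) == C * H * W elements per sample.
size_t GroupOffsetNCHW(size_t C, size_t H, size_t W,
                       size_t cudnn_group, size_t g) {
    return (C * H * W) / cudnn_group * g;
}
// With cudnn_group == 1 (cuDNN 7+), the loop runs once with g == 0,
// so the offset is 0 and the whole tensor is handed to cuDNN in one call.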
@@ -252,11 +222,11 @@ void CuDNNConv2dTransposeGradientOp<Context>::RunWithType() {
    Tensor* buffer2 = ws()->GetBuffer();
    if (workspace_bwd_data_size == 0) workspace_bwd_data_size += 1;
    if (workspace_bwd_filter_size == 0) workspace_bwd_filter_size += 1;
-    buffer1->Reshape(vector<TIndex>(1, this->group * workspace_bwd_data_size));
-    buffer2->Reshape(vector<TIndex>(1, this->group * workspace_bwd_filter_size));
+    buffer1->Reshape(vector<TIndex>(1, cudnn_group * workspace_bwd_data_size));
+    buffer2->Reshape(vector<TIndex>(1, cudnn_group * workspace_bwd_filter_size));
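// [Sketch, not repo code] The += 1 above keeps zero-sized workspaces
// allocatable; each buffer then concatenates one workspace per group, so
// group g's scratch presumably starts at a fixed byte stride:
//   char* group_ws = workspace + g * workspace_bwd_filter_size;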
    const T* dYdata = input(2).template data<T, Context>();
-    for (int g = 0; g < this->group; g++) {
+    for (int g = 0; g < cudnn_group; g++) {
        if (output(2)->name() != "ignore") {
            T* dBdata = output(2)->template mutable_data<T, Context>();
            CUDNN_CHECK(cudnnConvolutionBackwardBias(handle[g],
@@ -267,7 +237,7 @@ void CuDNNConv2dTransposeGradientOp<Context>::RunWithType() {
            auto* Xdata = input(0).template data<T, Context>();
            auto* dWdata = output(1)->template mutable_data<T, Context>();
            auto* workspace = buffer2->mutable_data<char, Context>();
-            CUDNN_CHECK(cudnnConvolutionBackwardFilter(handle[1 * this->group + g],
+            CUDNN_CHECK(cudnnConvolutionBackwardFilter(handle[1 * cudnn_group + g],
                CUDNNType<T>::one, input_desc, dYdata + this->y_offset * g,
                output_desc, Xdata + this->x_offset * g,
                conv_desc,
@@ -279,7 +249,7 @@ void CuDNNConv2dTransposeGradientOp<Context>::RunWithType() {
            auto* Wdata = input(1).template data<T, Context>();
            auto* dXdata = output(0)->template mutable_data<T, Context>();
            auto* workspace = buffer1->mutable_data<char, Context>();
-            CUDNN_CHECK(cudnnConvolutionForward(handle[2 * this->group + g],
+            CUDNN_CHECK(cudnnConvolutionForward(handle[2 * cudnn_group + g],
                CUDNNType<T>::one, input_desc, dYdata + this->y_offset * g,
                filter_desc, Wdata + this->weight_offset * g,
                conv_desc,
@@ -300,8 +270,6 @@ void CuDNNConv2dTransposeGradientOp<Context>::RunOnDevice() {
    if (this->dilation[i] != 1) return Conv2dTransposeGradientOp<Context>::RunOnDevice();
#endif
    Conv2dTransposeGradientOp<Context>::GradientReshape();
-    this->x_offset /= this->group;
-    this->y_offset /= this->group;
    if (input(0).template IsType<float>()) {
#if CUDNN_VERSION_MIN(6, 0, 0)
@@ -318,6 +286,9 @@ void CuDNNConv2dTransposeGradientOp<Context>::RunOnDevice() {
            1, 1,
            CUDNN_CROSS_CORRELATION));
#endif
+#if CUDNN_VERSION_MIN(7, 0, 0)
+        CUDNN_CHECK(cudnnSetConvolutionGroupCount(conv_desc, this->group));
+#endif
        RunWithType<float>();
    } else if (input(0).template IsType<float16>()) {
#ifdef WITH_CUDA_FP16
@@ -335,6 +306,9 @@ void CuDNNConv2dTransposeGradientOp<Context>::RunOnDevice() {
            1, 1,
            CUDNN_CROSS_CORRELATION));
#endif
+#if CUDNN_VERSION_MIN(7, 0, 0)
+        CUDNN_CHECK(cudnnSetConvolutionGroupCount(conv_desc, this->group));
+#endif
        RunWithType<float16>();
#endif  // WITH_CUDA_FP16
    } else { LOG(FATAL) << "Unsupported input types."; }
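The gradient path indexes its handles as handle[g], handle[1 * cudnn_group + g], and handle[2 * cudnn_group + g], which suggests three banks of cudnn_group handles, one bank per backward kernel. A sketch of that layout (helper name and bank labels are guesses, not from this diff):

// Flat index into a handle array assumed to hold 3 * cudnn_group entries:
// bank 0 -> backward bias, bank 1 -> backward filter, bank 2 -> backward data.
inline int GroupHandleIndex(int bank, int cudnn_group, int g) {
    return bank * cudnn_group + g;
}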
...
@@ -6,9 +6,9 @@ namespace dragon {
template <class Context> template <typename T>
void CuDNNLRNOp<Context>::RunWithType() {
+    if (this->data_format == "NCHW") {
        cudnnSetTensorDesc<T>(&input_desc, &input(0));
        cudnnSetTensorDesc<T>(&output_desc, output(0));
        auto* Xdata = input(0).template data<T, Context>();
        auto* Ydata = output(0)->template mutable_data<T, Context>();
        CUDNN_CHECK(cudnnLRNCrossChannelForward(cudnn_handle(),
@@ -16,20 +16,23 @@ void CuDNNLRNOp<Context>::RunWithType() {
            CUDNN_LRN_CROSS_CHANNEL_DIM1,
            CUDNNType<T>::one, input_desc, Xdata,
            CUDNNType<T>::zero, output_desc, Ydata));
+    } else LOG(FATAL) << "Unknown data format: " << this->data_format;
}

template <class Context>
void CuDNNLRNOp<Context>::RunOnDevice() {
    output(0)->ReshapeLike(input(0));
-    if (this->mode == ACROSS_CHANNELS) {
+    if (this->mode == "ACROSS_CHANNELS") {
        if (input(0).template IsType<float>()) RunWithType<float>();
#ifdef WITH_CUDA_FP16
        else if (input(0).template IsType<float16>()) RunWithType<float16>();
#endif
        else LOG(FATAL) << "Unsupported input types.";
-    } else {
+    } else if (this->mode == "WITHIN_CHANNEL") {
        LRNOp<Context>::RunOnDevice();
+    } else {
+        LOG(FATAL) << "Unsupported lrn mode: " << this->mode;
    }
}
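Since mode is now matched as a plain string, callers pick the LRN path directly in the OperatorDef. A hypothetical construction (tensor names invented), following the MakeOperatorDef pattern used by the pooling delegation later in this commit:

Argument mode;
mode.set_name("mode"); mode.set_s("ACROSS_CHANNELS");  // or "WITHIN_CHANNEL"
OperatorDef lrn_def = MakeOperatorDef("LRN", "",
    vector<string>({ "conv1" }),
    vector<string>({ "norm1" }),
    vector<Argument>({ mode }));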
@@ -37,6 +40,7 @@ DEPLOY_CUDNN(LRN);
template <class Context> template <typename T>
void CuDNNLRNGradientOp<Context>::RunWithType() {
+    if (this->data_format == "NCHW") {
        cudnnSetTensorDesc<T>(&input_desc, &input(-1));
        cudnnSetTensorDesc<T>(&output_desc, output(0));
@@ -51,20 +55,23 @@ void CuDNNLRNGradientOp<Context>::RunWithType() {
            input_desc, dYdata,
            output_desc, Xdata,
            CUDNNType<T>::zero, output_desc, dXdata));
+    } else LOG(FATAL) << "Unknown data format: " << this->data_format;
}

template <class Context>
void CuDNNLRNGradientOp<Context>::RunOnDevice() {
    output(0)->ReshapeLike(input(0));
-    if (this->mode == ACROSS_CHANNELS) {
+    if (this->mode == "ACROSS_CHANNELS") {
        if (input(0).template IsType<float>()) RunWithType<float>();
#ifdef WITH_CUDA_FP16
        else if (input(0).template IsType<float16>()) RunWithType<float16>();
#endif
        else LOG(FATAL) << "Unsupported input types.";
-    } else {
+    } else if (this->mode == "WITHIN_CHANNEL") {
        LRNGradientOp<Context>::RunOnDevice();
+    } else {
+        LOG(FATAL) << "Unsupported lrn mode: " << this->mode;
    }
}
...
@@ -45,15 +45,16 @@ template <class Context> template <typename T>
void LRNOp<Context>::PoolRunWithType() {
    pool_out = ws()->CreateTensor("/mnt/" + anchor() + "/pool_out");
    if (!pool_op) {
-        Argument ks, s, p, mode;
+        Argument ks, s, p, m, df;
        ks.set_name("kernel_size"); ks.add_ints(local_size);
        s.set_name("stride"); s.add_ints(1);
        p.set_name("pad"); p.add_ints((local_size - 1) / 2);
-        mode.set_name("mode"); mode.set_s("AVG");
-        OperatorDef pool_op_def = MakeOperatorDef("Pooling", "",
+        m.set_name("mode"); m.set_s("AVG");
+        df.set_name("data_format"); df.set_s(data_format);
+        OperatorDef pool_op_def = MakeOperatorDef("Pooling2d", "",
            vector<string>({ sqr_out->name() }),
            vector<string>({ pool_out->name() }),
-            vector<Argument>({ ks, s, p, mode }));
+            vector<Argument>({ ks, s, p, m, df }));
        if (this->op_def().has_device_option())
            pool_op_def.mutable_device_option()->CopyFrom(this->op_def().device_option());
        pool_op.reset(CreateOperator(pool_op_def, ws()));
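With kernel_size = local_size, stride = 1, and pad = (local_size - 1) / 2, the delegated AVG pooling preserves spatial extent for any odd local_size, which is what the within-channel LRN window requires. A quick check of that arithmetic (standalone sketch, not repo code):

// Output extent of the pooling configured above; with odd k and
// pad = (k - 1) / 2, out == in. E.g. in = 14, k = 5 -> (14 + 4 - 5) + 1 = 14.
int PoolOutDim(int in, int k) {
    int pad = (k - 1) / 2;
    int stride = 1;
    return (in + 2 * pad - k) / stride + 1;
}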
@@ -99,12 +100,11 @@ void LRNOp<Context>::ProdRunWithType() {
template <class Context>
void LRNOp<Context>::RunOnDevice() {
-    if (mode == ACROSS_CHANNELS) {
+    if (mode == "ACROSS_CHANNELS") {
        if (input(0).template IsType<float>()) {
            AcrossRunWithType<float>();
        } else { LOG(FATAL) << "Unsupported input types."; }
-    }
-    else {
+    } else if (mode == "WITHIN_CHANNEL") {
        if (input(0).template IsType<float>()) {
            SplitRunWithType<float>();
            SquareRunWithType<float>();
@@ -112,6 +112,8 @@ void LRNOp<Context>::RunOnDevice() {
            PowRunWithType<float>();
            ProdRunWithType<float>();
        } else { LOG(FATAL) << "Unsupported input types."; }
+    } else {
+        LOG(FATAL) << "Unsupported lrn mode: " << mode;
    }
}
@@ -173,17 +175,18 @@ template <class Context> template <typename T>
void LRNGradientOp<Context>::PoolRunWithType() {
    sqr_out = ws()->GetTensor("/mnt/" + anchor() + "/sqr_out");
    if (!pool_op) {
-        Argument ks, s, p, mode;
+        Argument ks, s, p, m, df;
        ks.set_name("kernel_size"); ks.add_ints(local_size);
        s.set_name("stride"); s.add_ints(1);
        p.set_name("pad"); p.add_ints((local_size - 1) / 2);
-        mode.set_name("mode"); mode.set_s("AVG");
-        OperatorDef pool_op_def = MakeOperatorDef("PoolingGradient", "",
+        m.set_name("mode"); m.set_s("AVG");
+        df.set_name("data_format"); df.set_s(data_format);
+        OperatorDef pool_op_def = MakeOperatorDef("Pooling2dGradient", "",
            vector<string>({ sqr_out->name(),
                             pool_out->name(),
                             pool_out->name() + "_grad" }),
            vector<string>({ sqr_out->name() + "_grad" }),
-            vector<Argument>({ ks, s, p, mode }));
+            vector<Argument>({ ks, s, p, m, df }));
        if (this->op_def().has_device_option())
            pool_op_def.mutable_device_option()->CopyFrom(this->op_def().device_option());
        pool_op.reset(CreateOperator(pool_op_def, ws()));
@@ -224,12 +227,11 @@ void LRNGradientOp<Context>::SplitRunWithType() {
template <class Context>
void LRNGradientOp<Context>::RunOnDevice() {
-    if (mode == ACROSS_CHANNELS) {
+    if (mode == "ACROSS_CHANNELS") {
        if (input(0).template IsType<float>()) {
            AcrossRunWithType<float>();
        } else { LOG(FATAL) << "Unsupported input types."; }
-    }
-    else {
+    } else if (mode == "WITHIN_CHANNEL") {
        if (input(0).template IsType<float>()) {
            ProdRunWithType<float>();
            PowRunWithType<float>();
@@ -237,6 +239,8 @@ void LRNGradientOp<Context>::RunOnDevice() {
            SquareRunWithType<float>();
            SplitRunWithType<float>();
        } else { LOG(FATAL) << "Unsupported input types."; }
+    } else {
+        LOG(FATAL) << "Unsupported lrn mode: " << mode;
    }
}
...
@@ -69,6 +69,34 @@ void cudnnSetTensor4dDesc(cudnnTensorDescriptor_t* desc,
}

+template <typename T>
+void cudnnSetTensor4dDescWithGroup(cudnnTensorDescriptor_t* desc,
+                                   const string& data_format,
+                                   const vector<TIndex>& dims,
+                                   const TIndex group) {
+    if (data_format == "NCHW") {
+        CUDNN_CHECK(cudnnSetTensor4dDescriptorEx(*desc, CUDNNType<T>::type,
+            dims[0],                        // n
+            dims[1] / group,                // c: one group's channels
+            dims[2],                        // h
+            dims[3],                        // w
+            dims[1] * dims[2] * dims[3],    // nStride: full-tensor C*H*W
+            dims[2] * dims[3],              // cStride
+            dims[3],                        // hStride
+            1));                            // wStride
+    } else if (data_format == "NHWC") {
+        CUDNN_CHECK(cudnnSetTensor4dDescriptorEx(*desc, CUDNNType<T>::type,
+            dims[0],                        // n
+            dims[3] / group,                // c: one group's channels
+            dims[1],                        // h
+            dims[2],                        // w
+            dims[1] * dims[2] * dims[3],    // nStride: full-tensor H*W*C
+            1,                              // cStride: channels are innermost
+            dims[2] * dims[3],              // hStride
+            dims[3]));                      // wStride
+    } else LOG(FATAL) << "Unknown data format: " << data_format;
+}
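A hedged usage sketch of the helper above (shape values invented): the dims argument is the full tensor's shape; only the channel extent is divided by group, while the strides keep full-tensor spacing so a per-group base pointer can walk the same buffer.

cudnnTensorDescriptor_t desc;
CUDNN_CHECK(cudnnCreateTensorDescriptor(&desc));
// Full tensor is NCHW (8, 64, 14, 14); with group = 2 this describes an
// (8, 32, 14, 14) view whose strides are still those of the 64-channel tensor.
cudnnSetTensor4dDescWithGroup<float>(&desc, "NCHW",
    vector<TIndex>({ 8, 64, 14, 14 }), 2);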
template <typename T>
void cudnnSetTensor5dDesc(cudnnTensorDescriptor_t* desc,
                          const string& data_format,
                          const vector<TIndex>& dims) {
@@ -169,6 +197,7 @@ template void cudnnSetTensorDesc<float>(cudnnTensorDescriptor_t*, const vector<T
template void cudnnSetTensor4dDesc<float>(cudnnTensorDescriptor_t*, const string&, const vector<TIndex>&);
template void cudnnSetTensor5dDesc<float>(cudnnTensorDescriptor_t*, const string&, const vector<TIndex>&);
template void cudnnSetTensor3dDesc<float>(cudnnTensorDescriptor_t*, const string&, const vector<TIndex>&);
+template void cudnnSetTensor4dDescWithGroup<float>(cudnnTensorDescriptor_t*, const string&, const vector<TIndex>&, const TIndex);
template void cudnnSetTensorDesc<float>(cudnnTensorDescriptor_t*, const vector<TIndex>&, const vector<TIndex>&);
@@ -180,6 +209,7 @@ template void cudnnSetTensorDesc<double>(cudnnTensorDescriptor_t*, const vector<
template void cudnnSetTensor4dDesc<double>(cudnnTensorDescriptor_t*, const string&, const vector<TIndex>&);
template void cudnnSetTensor5dDesc<double>(cudnnTensorDescriptor_t*, const string&, const vector<TIndex>&);
template void cudnnSetTensor3dDesc<double>(cudnnTensorDescriptor_t*, const string&, const vector<TIndex>&);
+template void cudnnSetTensor4dDescWithGroup<double>(cudnnTensorDescriptor_t*, const string&, const vector<TIndex>&, const TIndex);
template void cudnnSetTensorDesc<double>(cudnnTensorDescriptor_t*, const vector<TIndex>&, const vector<TIndex>&);
@@ -192,6 +222,7 @@ template void cudnnSetTensorDesc<float16>(cudnnTensorDescriptor_t*, const vector
template void cudnnSetTensor4dDesc<float16>(cudnnTensorDescriptor_t*, const string&, const vector<TIndex>&);
template void cudnnSetTensor5dDesc<float16>(cudnnTensorDescriptor_t*, const string&, const vector<TIndex>&);
template void cudnnSetTensor3dDesc<float16>(cudnnTensorDescriptor_t*, const string&, const vector<TIndex>&);
+template void cudnnSetTensor4dDescWithGroup<float16>(cudnnTensorDescriptor_t*, const string&, const vector<TIndex>&, const TIndex);
template void cudnnSetTensorDesc<float16>(cudnnTensorDescriptor_t*, const vector<TIndex>&, const vector<TIndex>&);
#endif
...