Commit 73ed1b96 by Ting PAN

Remove support for CUDNN v6

Summary:
To keep the retrieval of CUDNN convolution algorithms consistent,
CUDNN v6 (relied on mainly by CUDA 8.0) is now dropped.
1 parent bbfecf22
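
With v6 gone, every CUDNN convolution operator obtains its algorithm through the same v7 enumeration call and falls back to a deterministic choice when no result fits the workspace limit, instead of branching onto the old cudnnGetConvolutionForwardAlgorithm path. A minimal sketch of that selection pattern, mirroring the logic in the diff below; the function name PickFwdAlgo and the workspace_limit parameter are illustrative, not Dragon's actual symbols:

#include <cudnn.h>

// Enumerate forward algorithms (cuDNN >= 7) and take the fastest one whose
// workspace fits; fall back to a deterministic algorithm otherwise.
cudnnConvolutionFwdAlgo_t PickFwdAlgo(
    cudnnHandle_t handle,
    cudnnTensorDescriptor_t x_desc,
    cudnnFilterDescriptor_t w_desc,
    cudnnConvolutionDescriptor_t conv_desc,
    cudnnTensorDescriptor_t y_desc,
    size_t workspace_limit) {
  int num_valid = 0;
  cudnnConvolutionFwdAlgoPerf_t perf[CUDNN_CONVOLUTION_FWD_ALGO_COUNT];
  cudnnGetConvolutionForwardAlgorithm_v7(
      handle, x_desc, w_desc, conv_desc, y_desc,
      CUDNN_CONVOLUTION_FWD_ALGO_COUNT, &num_valid, perf);
  for (int i = 0; i < num_valid; ++i) {
    // Results are sorted by expected speed; take the first that fits.
    if (perf[i].status == CUDNN_STATUS_SUCCESS &&
        perf[i].memory <= workspace_limit) {
      return perf[i].algo;
    }
  }
  // Nothing fits the limit: fall back instead of aborting, as the
  // convolution operators below now do.
  return CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM;
}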
...@@ -58,20 +58,20 @@ endif() ...@@ -58,20 +58,20 @@ endif()
# ---[ Library directories # ---[ Library directories
list(APPEND THIRD_PARTY_LIBRARY_DIRS ${PROTOBUF_SDK_ROOT_DIR}/lib) list(APPEND THIRD_PARTY_LIBRARY_DIRS ${PROTOBUF_SDK_ROOT_DIR}/lib)
if (USE_CUDA) if (USE_MPI)
list(APPEND THIRD_PARTY_LIBRARY_DIRS ${CUDA_TOOLKIT_ROOT_DIR}/lib64) list(APPEND THIRD_PARTY_LIBRARY_DIRS ${THIRD_PARTY_DIR}/mpi/lib)
list(APPEND THIRD_PARTY_LIBRARY_DIRS ${CUDA_TOOLKIT_ROOT_DIR}/lib/x64)
endif() endif()
if (USE_CUDNN) if (USE_CUDNN)
list(APPEND THIRD_PARTY_LIBRARY_DIRS ${THIRD_PARTY_DIR}/cudnn/lib64) list(APPEND THIRD_PARTY_LIBRARY_DIRS ${THIRD_PARTY_DIR}/cudnn/lib64)
list(APPEND THIRD_PARTY_LIBRARY_DIRS ${THIRD_PARTY_DIR}/cudnn/lib/x64) list(APPEND THIRD_PARTY_LIBRARY_DIRS ${THIRD_PARTY_DIR}/cudnn/lib/x64)
endif() endif()
if (USE_MPI)
list(APPEND THIRD_PARTY_LIBRARY_DIRS ${THIRD_PARTY_DIR}/mpi/lib)
endif()
if (USE_TENSORRT) if (USE_TENSORRT)
list(APPEND THIRD_PARTY_LIBRARY_DIRS ${TENSORRT_SDK_ROOT_DIR}/lib) list(APPEND THIRD_PARTY_LIBRARY_DIRS ${TENSORRT_SDK_ROOT_DIR}/lib)
endif() endif()
if (USE_CUDA)
list(APPEND THIRD_PARTY_LIBRARY_DIRS ${CUDA_TOOLKIT_ROOT_DIR}/lib64)
list(APPEND THIRD_PARTY_LIBRARY_DIRS ${CUDA_TOOLKIT_ROOT_DIR}/lib/x64)
endif()
# ---[ Defines # ---[ Defines
if (BUILD_PYTHON) if (BUILD_PYTHON)
......
...@@ -35,6 +35,12 @@ else() ...@@ -35,6 +35,12 @@ else()
endif() endif()
foreach(_plain_name ${ARGN}) foreach(_plain_name ${ARGN})
# Filter the linker option
string(FIND "${_plain_name}" "-Wl" _is_linker_option)
if (${_is_linker_option} GREATER -1)
list(APPEND _libraries ${_plain_name})
continue()
endif()
# Firstly, search in the third party # Firstly, search in the third party
set(LIB_VAR "${_target}/lib/${_plain_name}") set(LIB_VAR "${_target}/lib/${_plain_name}")
find_library( find_library(
......
...@@ -55,10 +55,6 @@ Share ...@@ -55,10 +55,6 @@ Share
##### #####
.. doxygenfunction:: dragon::Tensor::Share .. doxygenfunction:: dragon::Tensor::Share
SwitchToDevice
##############
.. doxygenfunction:: dragon::Tensor::SwitchToDevice
axis axis
#### ####
.. doxygenfunction:: dragon::Tensor::axis .. doxygenfunction:: dragon::Tensor::axis
......
...@@ -19,10 +19,6 @@ State ...@@ -19,10 +19,6 @@ State
Public Functions Public Functions
---------------- ----------------
SwitchToDevice
##############
.. doxygenfunction:: dragon::UnifiedMemory::SwitchToDevice
SwitchToCUDADevice SwitchToCUDADevice
################## ##################
.. doxygenfunction:: dragon::UnifiedMemory::SwitchToCUDADevice .. doxygenfunction:: dragon::UnifiedMemory::SwitchToCUDADevice
......
...@@ -55,6 +55,7 @@ const void* UnifiedMemory::cpu_data(size_t size) { ...@@ -55,6 +55,7 @@ const void* UnifiedMemory::cpu_data(size_t size) {
} }
const void* UnifiedMemory::cuda_data(size_t size) { const void* UnifiedMemory::cuda_data(size_t size) {
SwitchToCUDADevice(CUDAContext::current_device());
ToCUDA(size); ToCUDA(size);
return (const void*)cuda_ptr_; return (const void*)cuda_ptr_;
} }
...@@ -70,6 +71,7 @@ void* UnifiedMemory::mutable_cpu_data(size_t size) { ...@@ -70,6 +71,7 @@ void* UnifiedMemory::mutable_cpu_data(size_t size) {
} }
void* UnifiedMemory::mutable_cuda_data(size_t size) { void* UnifiedMemory::mutable_cuda_data(size_t size) {
SwitchToCUDADevice(CUDAContext::current_device());
ToCUDA(size); ToCUDA(size);
state_ = STATE_AT_CUDA; state_ = STATE_AT_CUDA;
return cuda_ptr_; return cuda_ptr_;
...@@ -116,10 +118,6 @@ UnifiedMemory::~UnifiedMemory() { ...@@ -116,10 +118,6 @@ UnifiedMemory::~UnifiedMemory() {
} }
} }
void UnifiedMemory::SwitchToDevice(int device_id) {
if (cuda_ptr_) SwitchToCUDADevice(device_id);
}
void UnifiedMemory::SwitchToCUDADevice(int device_id) { void UnifiedMemory::SwitchToCUDADevice(int device_id) {
#ifdef USE_CUDA #ifdef USE_CUDA
if (cuda_ptr_) { if (cuda_ptr_) {
...@@ -129,7 +127,9 @@ void UnifiedMemory::SwitchToCUDADevice(int device_id) { ...@@ -129,7 +127,9 @@ void UnifiedMemory::SwitchToCUDADevice(int device_id) {
new_ptr_ = CUDAContext::New(size_); new_ptr_ = CUDAContext::New(size_);
CUDAContext::Memcpy<CUDAContext, CUDAContext>( CUDAContext::Memcpy<CUDAContext, CUDAContext>(
size_, new_ptr_, cuda_ptr_, device_id_); size_, new_ptr_, cuda_ptr_, device_id_);
if (own_cuda_ptr_) CUDAContext::Delete(cuda_ptr_); if (own_cuda_ptr_) {
CUDAContext::Delete(cuda_ptr_);
}
cuda_ptr_ = new_ptr_; cuda_ptr_ = new_ptr_;
device_id_ = device_id; device_id_ = device_id;
} }
......
...@@ -49,9 +49,6 @@ class DRAGON_API UnifiedMemory { ...@@ -49,9 +49,6 @@ class DRAGON_API UnifiedMemory {
/*! \brief Destructor */ /*! \brief Destructor */
~UnifiedMemory(); ~UnifiedMemory();
/*! \brief Switch to the given device */
void SwitchToDevice(int device);
/*! \brief Switch to the given cuda device */ /*! \brief Switch to the given cuda device */
void SwitchToCUDADevice(int device); void SwitchToCUDADevice(int device);
......
...@@ -136,20 +136,6 @@ void Operator<Context>::Release() { ...@@ -136,20 +136,6 @@ void Operator<Context>::Release() {
} }
} }
template <class Context>
void Operator<Context>::SwitchToDevice() {
for (auto* tensor : inputs_) {
if (tensor->has_name()) {
tensor->SwitchToDevice(ctx()->device());
}
}
for (auto* tensor : outputs_) {
if (tensor->has_name()) {
tensor->SwitchToDevice(ctx()->device());
}
}
}
OperatorBase* OperatorBase*
TryCreateOperator(const string& key, const OperatorDef& def, Workspace* ws) { TryCreateOperator(const string& key, const OperatorDef& def, Workspace* ws) {
switch (def.device_option().device_type()) { switch (def.device_option().device_type()) {
......
...@@ -207,9 +207,6 @@ class DRAGON_API Operator : public OperatorBase { ...@@ -207,9 +207,6 @@ class DRAGON_API Operator : public OperatorBase {
/*! \brief Release the ownership of inputs */ /*! \brief Release the ownership of inputs */
virtual void Release(); virtual void Release();
/*! \brief Coordinate the context of inputs and outputs */
virtual void SwitchToDevice();
/*! \brief The detailed execution on device */ /*! \brief The detailed execution on device */
virtual void RunOnDevice() = 0; virtual void RunOnDevice() = 0;
...@@ -217,7 +214,6 @@ class DRAGON_API Operator : public OperatorBase { ...@@ -217,7 +214,6 @@ class DRAGON_API Operator : public OperatorBase {
void Run(int stream = 0) final { void Run(int stream = 0) final {
Prepare(); Prepare();
ctx()->SwitchToDevice(stream); ctx()->SwitchToDevice(stream);
SwitchToDevice();
RunOnDevice(); RunOnDevice();
if (do_sync_) { if (do_sync_) {
ctx()->FinishDeviceComputation(); ctx()->FinishDeviceComputation();
......
...@@ -109,12 +109,6 @@ class DRAGON_API Tensor { ...@@ -109,12 +109,6 @@ class DRAGON_API Tensor {
return Reshape(other.dims_); return Reshape(other.dims_);
} }
/*! \brief Switch memory to the specific device */
void SwitchToDevice(int device_id) {
UnifiedMemory* mem = memory();
if (mem) mem->SwitchToDevice(device_id);
}
/*! \brief Copy memory from a tensor with context */ /*! \brief Copy memory from a tensor with context */
template <class Context> template <class Context>
Tensor* CopyFrom(const Tensor& other, Context* ctx) { Tensor* CopyFrom(const Tensor& other, Context* ctx) {
......
...@@ -71,7 +71,7 @@ if (USE_CUDNN) ...@@ -71,7 +71,7 @@ if (USE_CUDNN)
if (USE_SHARED_LIBS) if (USE_SHARED_LIBS)
target_link_libraries_v2(dragon cudnn) target_link_libraries_v2(dragon cudnn)
else() else()
target_link_libraries_v2(dragon cudnn_static) target_link_libraries_v2(dragon_python -Wl,--whole-archive cudnn_static -Wl,--no-whole-archive)
endif() endif()
endif() endif()
if (USE_NCCL) if (USE_NCCL)
......
...@@ -79,7 +79,7 @@ if (USE_CUDNN) ...@@ -79,7 +79,7 @@ if (USE_CUDNN)
if (USE_SHARED_LIBS) if (USE_SHARED_LIBS)
target_link_libraries_v2(dragonrt cudnn) target_link_libraries_v2(dragonrt cudnn)
else() else()
target_link_libraries_v2(dragonrt cudnn_static) target_link_libraries_v2(dragonrt -Wl,--whole-archive cudnn_static -Wl,--no-whole-archive)
endif() endif()
endif() endif()
if (USE_CUDA AND (NOT USE_SHARED_LIBS)) if (USE_CUDA AND (NOT USE_SHARED_LIBS))
......
...@@ -58,8 +58,6 @@ DEFINE_OP_SINGLE_ARG_WITH_DESC(float, DropoutGradientOp, ratio); ...@@ -58,8 +58,6 @@ DEFINE_OP_SINGLE_ARG_WITH_DESC(float, DropoutGradientOp, ratio);
#ifdef USE_CUDNN #ifdef USE_CUDNN
#if CUDNN_VERSION_MIN(7, 0, 0)
template <class Context> template <class Context>
class CuDNNDropoutOp final : public DropoutOp<Context> { class CuDNNDropoutOp final : public DropoutOp<Context> {
public: public:
...@@ -118,8 +116,6 @@ class CuDNNDropoutGradientOp final : public DropoutGradientOp<Context> { ...@@ -118,8 +116,6 @@ class CuDNNDropoutGradientOp final : public DropoutGradientOp<Context> {
unsigned long long rng_seed_; unsigned long long rng_seed_;
}; };
#endif // CUDNN_VERSION_MIN(7, 0, 0)
#endif // USE_CUDNN #endif // USE_CUDNN
} // namespace dragon } // namespace dragon
......
...@@ -3,8 +3,6 @@ ...@@ -3,8 +3,6 @@
#include "dragon/core/workspace.h" #include "dragon/core/workspace.h"
#include "dragon/operators/activation/dropout_op.h" #include "dragon/operators/activation/dropout_op.h"
#if CUDNN_VERSION_MIN(7, 0, 0)
namespace dragon { namespace dragon {
template <class Context> template <class Context>
...@@ -124,6 +122,4 @@ DEPLOY_CUDNN_OPERATOR(DropoutGradient); ...@@ -124,6 +122,4 @@ DEPLOY_CUDNN_OPERATOR(DropoutGradient);
} // namespace dragon } // namespace dragon
#endif // CUDNN_VERSION_MIN(7, 0, 0)
#endif // USE_CUDNN #endif // USE_CUDNN
...@@ -23,8 +23,6 @@ class ShapeOp final : public Operator<Context> { ...@@ -23,8 +23,6 @@ class ShapeOp final : public Operator<Context> {
SIMPLE_CTOR_DTOR(ShapeOp); SIMPLE_CTOR_DTOR(ShapeOp);
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
void SwitchToDevice() override {}
void RunOnDevice() override; void RunOnDevice() override;
}; };
......
...@@ -44,8 +44,6 @@ class CTCLossGradientOp final : public Operator<Context> { ...@@ -44,8 +44,6 @@ class CTCLossGradientOp final : public Operator<Context> {
#ifdef USE_CUDNN #ifdef USE_CUDNN
#if CUDNN_VERSION_MIN(7, 0, 0)
template <class Context> template <class Context>
class CuDNNCTCLossOp final : public Operator<Context> { class CuDNNCTCLossOp final : public Operator<Context> {
public: public:
...@@ -81,8 +79,6 @@ class CuDNNCTCLossOp final : public Operator<Context> { ...@@ -81,8 +79,6 @@ class CuDNNCTCLossOp final : public Operator<Context> {
vec32_t packed_labels_, label_lengths_, input_lengths_; vec32_t packed_labels_, label_lengths_, input_lengths_;
}; };
#endif // CUDNN_VERSION_MIN(7, 0, 0)
#endif // USE_CUDNN #endif // USE_CUDNN
} // namespace dragon } // namespace dragon
......
...@@ -3,8 +3,6 @@ ...@@ -3,8 +3,6 @@
#include "dragon/core/workspace.h" #include "dragon/core/workspace.h"
#include "dragon/operators/loss/ctc_loss_op.h" #include "dragon/operators/loss/ctc_loss_op.h"
#if CUDNN_VERSION_MIN(7, 0, 0)
#define CUDNN_LABEL_LENGTH_LIMIT 256 #define CUDNN_LABEL_LENGTH_LIMIT 256
namespace dragon { namespace dragon {
...@@ -95,6 +93,4 @@ DEPLOY_CUDNN_OPERATOR(CTCLoss); ...@@ -95,6 +93,4 @@ DEPLOY_CUDNN_OPERATOR(CTCLoss);
} // namespace dragon } // namespace dragon
#endif // CUDNN_VERSION_MIN(7, 0, 0)
#endif // USE_CUDNN #endif // USE_CUDNN
...@@ -19,11 +19,10 @@ void CuDNNRecurrentOpBase<Context>::ResetDesc() { ...@@ -19,11 +19,10 @@ void CuDNNRecurrentOpBase<Context>::ResetDesc() {
// Setup Dropout // Setup Dropout
if (dropout_ratio_ < 1.f) { if (dropout_ratio_ < 1.f) {
#if CUDNN_VERSION_MIN(7, 0, 0)
if (!states_initialized_) { if (!states_initialized_) {
states_initialized_ = 1; states_initialized_ = 1;
CUDNN_CHECK( auto cudnn_handle = ctx()->cudnn_handle();
cudnnDropoutGetStatesSize(ctx()->cudnn_handle(), &states_size_)); CUDNN_CHECK(cudnnDropoutGetStatesSize(cudnn_handle, &states_size_));
std::lock_guard<std::mutex> lk(CUDAContext::mutex()); std::lock_guard<std::mutex> lk(CUDAContext::mutex());
auto* states_tensor = workspace()->CreateTensor( auto* states_tensor = workspace()->CreateTensor(
"/share/cudnn/dropout:" + str::to(rng_seed_) + "/states"); "/share/cudnn/dropout:" + str::to(rng_seed_) + "/states");
...@@ -31,7 +30,7 @@ void CuDNNRecurrentOpBase<Context>::ResetDesc() { ...@@ -31,7 +30,7 @@ void CuDNNRecurrentOpBase<Context>::ResetDesc() {
auto* states = states_tensor->template mutable_data<uint8_t, Context>(); auto* states = states_tensor->template mutable_data<uint8_t, Context>();
CUDNN_CHECK(cudnnRestoreDropoutDescriptor( CUDNN_CHECK(cudnnRestoreDropoutDescriptor(
dropout_desc_, dropout_desc_,
ctx()->cudnn_handle(), cudnn_handle,
dropout_ratio_, dropout_ratio_,
states, states,
states_size_, states_size_,
...@@ -41,16 +40,13 @@ void CuDNNRecurrentOpBase<Context>::ResetDesc() { ...@@ -41,16 +40,13 @@ void CuDNNRecurrentOpBase<Context>::ResetDesc() {
->template mutable_data<uint8_t, Context>(); ->template mutable_data<uint8_t, Context>();
CUDNN_CHECK(cudnnSetDropoutDescriptor( CUDNN_CHECK(cudnnSetDropoutDescriptor(
dropout_desc_, dropout_desc_,
ctx()->cudnn_handle(), cudnn_handle,
dropout_ratio_, dropout_ratio_,
states, states,
states_size_, states_size_,
rng_seed_)); rng_seed_));
} }
} }
#else
LOG(FATAL) << "Dropout has been supported since CuDNN 7.0";
#endif
} }
// Setup RNN // Setup RNN
...@@ -61,7 +57,6 @@ void CuDNNRecurrentOpBase<Context>::ResetDesc() { ...@@ -61,7 +57,6 @@ void CuDNNRecurrentOpBase<Context>::ResetDesc() {
} else if (input_type == TypeMeta::Id<double>()) { } else if (input_type == TypeMeta::Id<double>()) {
compute_type_ = CUDNN_DATA_DOUBLE; compute_type_ = CUDNN_DATA_DOUBLE;
} }
#if CUDNN_VERSION_MIN(7, 0, 0)
CUDNN_CHECK(cudnnSetRNNDescriptor_v6( CUDNN_CHECK(cudnnSetRNNDescriptor_v6(
ctx()->cudnn_handle(), ctx()->cudnn_handle(),
rnn_desc_, rnn_desc_,
...@@ -73,35 +68,22 @@ void CuDNNRecurrentOpBase<Context>::ResetDesc() { ...@@ -73,35 +68,22 @@ void CuDNNRecurrentOpBase<Context>::ResetDesc() {
rnn_mode_, rnn_mode_,
CUDNN_RNN_ALGO_STANDARD, CUDNN_RNN_ALGO_STANDARD,
compute_type_)); compute_type_));
#else
CUDNN_CHECK(cudnnSetRNNDescriptor(
rnn_desc_,
hidden_size_,
num_layers_,
dropout_desc_,
rnn_input_mode_,
rnn_direction_,
rnn_mode_,
compute_type_));
#endif
// Setup TensorCore // Setup TensorCore
#if CUDNN_VERSION_MIN(7, 0, 0) if (TENSOR_CORE_AVAILABLE()) {
if (enable_tensor_core_ > 0) {
cudnnMathType_t math_type; cudnnMathType_t math_type;
if (input_type == TypeMeta::Id<float16>()) { if (input_type == TypeMeta::Id<float16>()) {
math_type = CUDNN_TENSOR_OP_MATH; math_type = CUDNN_TENSOR_OP_MATH;
} else { } else {
math_type = CUDNN_DEFAULT_MATH; math_type = CUDNN_DEFAULT_MATH;
#if CUDNN_VERSION_MIN(8, 0, 0)
if (!CUDAContext::objects().cudnn_allow_tf32_) { if (!CUDAContext::objects().cudnn_allow_tf32_) {
#if CUDNN_VERSION_MIN(8, 0, 0)
math_type = CUDNN_FMA_MATH; math_type = CUDNN_FMA_MATH;
}
#endif #endif
} }
}
CUDNN_CHECK(cudnnSetRNNMatrixMathType(rnn_desc_, math_type)); CUDNN_CHECK(cudnnSetRNNMatrixMathType(rnn_desc_, math_type));
} }
#endif
// Setup X and Y // Setup X and Y
output_dims_ = {seq_length_, batch_size, y_dim}; output_dims_ = {seq_length_, batch_size, y_dim};
......
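
The recurrent and convolution hunks now share the same math-type decision: TensorCore math for float16 inputs, otherwise default math unless TF32 has been disallowed (cuDNN 8 only). A hedged sketch of that decision, where is_float16 and allow_tf32 stand in for the framework's own flags:

#include <cudnn.h>

// Choose the cuDNN math type for the current input type.
cudnnMathType_t PickMathType(bool is_float16, bool allow_tf32) {
  if (is_float16) {
    return CUDNN_TENSOR_OP_MATH;  // TensorCore kernels for fp16
  }
  cudnnMathType_t math_type = CUDNN_DEFAULT_MATH;
#if CUDNN_VERSION >= 8000
  // Since cuDNN 8 the default math may use TF32 on Ampere GPUs;
  // request FMA math when TF32 is not allowed.
  if (!allow_tf32) {
    math_type = CUDNN_FMA_MATH;
  }
#endif
  return math_type;
}

// Applied to the descriptors, e.g.
//   cudnnSetRNNMatrixMathType(rnn_desc, math_type);
//   cudnnSetConvolutionMathType(conv_desc, math_type);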
...@@ -100,9 +100,10 @@ class CuDNNConvOp final : public CuDNNConvOpBase<Context> { ...@@ -100,9 +100,10 @@ class CuDNNConvOp final : public CuDNNConvOpBase<Context> {
template <typename T> template <typename T>
void ResetDesc(); void ResetDesc();
size_t cudnn_ws_nbytes_; size_t cudnn_ws_size_;
vec64_t input_dims_, filter_dims_; vec64_t input_dims_, filter_dims_;
bool exhaustive_search_ = false; bool exhaustive_search_ = false;
bool algo_deterministic_ = false;
cudnnConvolutionFwdAlgo_t fwd_algo_; cudnnConvolutionFwdAlgo_t fwd_algo_;
cudnnTensorDescriptor_t input_desc_, output_desc_; cudnnTensorDescriptor_t input_desc_, output_desc_;
cudnnTensorDescriptor_t bias_desc_, output_desc_for_bias_; cudnnTensorDescriptor_t bias_desc_, output_desc_for_bias_;
...@@ -148,8 +149,10 @@ class CuDNNConvGradientOp final : public CuDNNConvOpBase<Context> { ...@@ -148,8 +149,10 @@ class CuDNNConvGradientOp final : public CuDNNConvOpBase<Context> {
template <typename T> template <typename T>
void ResetDesc(); void ResetDesc();
size_t cudnn_ws_nbytes_; size_t cudnn_ws_size_;
vec64_t input_dims_, filter_dims_; vec64_t input_dims_, filter_dims_;
bool data_algo_deterministic_ = false;
bool filter_algo_deterministic_ = false;
bool exhaustive_search_data_ = false; bool exhaustive_search_data_ = false;
bool exhaustive_search_filter_ = false; bool exhaustive_search_filter_ = false;
cudnnConvolutionBwdFilterAlgo_t bwd_filter_algo_; cudnnConvolutionBwdFilterAlgo_t bwd_filter_algo_;
......
...@@ -136,11 +136,6 @@ class CuDNNConvOpBase : public ConvOpBase<Context> { ...@@ -136,11 +136,6 @@ class CuDNNConvOpBase : public ConvOpBase<Context> {
CuDNNConvOpBase(const OperatorDef& def, Workspace* ws) CuDNNConvOpBase(const OperatorDef& def, Workspace* ws)
: ConvOpBase<Context>(def, ws) { : ConvOpBase<Context>(def, ws) {
GetBaseArguments(); GetBaseArguments();
#if CUDNN_VERSION_MIN(7, 0, 0)
group_v2_ = 1;
#else
group_v2_ = group_;
#endif
if (data_format() == "NCHW") { if (data_format() == "NCHW") {
tensor_format_ = CUDNN_TENSOR_NCHW; tensor_format_ = CUDNN_TENSOR_NCHW;
} else if (data_format() == "NHWC") { } else if (data_format() == "NHWC") {
...@@ -184,7 +179,6 @@ class CuDNNConvOpBase : public ConvOpBase<Context> { ...@@ -184,7 +179,6 @@ class CuDNNConvOpBase : public ConvOpBase<Context> {
CUDNN_CROSS_CORRELATION, CUDNN_CROSS_CORRELATION,
compute_type_)); compute_type_));
} }
#if CUDNN_VERSION_MIN(7, 0, 0)
CUDNN_CHECK(cudnnSetConvolutionGroupCount(conv_desc_, group_)); CUDNN_CHECK(cudnnSetConvolutionGroupCount(conv_desc_, group_));
if (TENSOR_CORE_AVAILABLE()) { if (TENSOR_CORE_AVAILABLE()) {
cudnnMathType_t math_type; cudnnMathType_t math_type;
...@@ -192,15 +186,14 @@ class CuDNNConvOpBase : public ConvOpBase<Context> { ...@@ -192,15 +186,14 @@ class CuDNNConvOpBase : public ConvOpBase<Context> {
math_type = CUDNN_TENSOR_OP_MATH; math_type = CUDNN_TENSOR_OP_MATH;
} else { } else {
math_type = CUDNN_DEFAULT_MATH; math_type = CUDNN_DEFAULT_MATH;
#if CUDNN_VERSION_MIN(8, 0, 0)
if (!CUDAContext::objects().cudnn_allow_tf32_) { if (!CUDAContext::objects().cudnn_allow_tf32_) {
#if CUDNN_VERSION_MIN(8, 0, 0)
math_type = CUDNN_FMA_MATH; math_type = CUDNN_FMA_MATH;
}
#endif #endif
} }
}
CUDNN_CHECK(cudnnSetConvolutionMathType(conv_desc_, math_type)); CUDNN_CHECK(cudnnSetConvolutionMathType(conv_desc_, math_type));
} }
#endif
} }
template <typename T> template <typename T>
...@@ -210,13 +203,12 @@ class CuDNNConvOpBase : public ConvOpBase<Context> { ...@@ -210,13 +203,12 @@ class CuDNNConvOpBase : public ConvOpBase<Context> {
filter_desc_, filter_desc_,
CuDNNType<T>::type, CuDNNType<T>::type,
tensor_format_, tensor_format_,
conv_out_channels_ / group_v2_, conv_out_channels_,
conv_in_channels_ / group_, conv_in_channels_ / group_,
kshape_[0], kshape_[0],
num_axes_ == 1 ? 1 : kshape_[1])); num_axes_ == 1 ? 1 : kshape_[1]));
} else { } else {
vec64_t dims = {conv_out_channels_ / group_v2_, vec64_t dims = {conv_out_channels_, conv_in_channels_ / group_};
conv_in_channels_ / group_};
dims.insert(dims.end(), kshape_.begin(), kshape_.end()); dims.insert(dims.end(), kshape_.begin(), kshape_.end());
CUDNN_CHECK(cudnnSetFilterNdDescriptor( CUDNN_CHECK(cudnnSetFilterNdDescriptor(
filter_desc_, filter_desc_,
...@@ -227,7 +219,6 @@ class CuDNNConvOpBase : public ConvOpBase<Context> { ...@@ -227,7 +219,6 @@ class CuDNNConvOpBase : public ConvOpBase<Context> {
} }
} }
int64_t group_v2_;
cudnnConvolutionDescriptor_t conv_desc_; cudnnConvolutionDescriptor_t conv_desc_;
cudnnFilterDescriptor_t filter_desc_; cudnnFilterDescriptor_t filter_desc_;
cudnnDataType_t compute_type_; cudnnDataType_t compute_type_;
...@@ -240,8 +231,7 @@ class CuDNNConvOpBase : public ConvOpBase<Context> { ...@@ -240,8 +231,7 @@ class CuDNNConvOpBase : public ConvOpBase<Context> {
using CuDNNConvOpBase<Context>::conv_desc_; \ using CuDNNConvOpBase<Context>::conv_desc_; \
using CuDNNConvOpBase<Context>::filter_desc_; \ using CuDNNConvOpBase<Context>::filter_desc_; \
using CuDNNConvOpBase<Context>::compute_type_; \ using CuDNNConvOpBase<Context>::compute_type_; \
using CuDNNConvOpBase<Context>::tensor_format_; \ using CuDNNConvOpBase<Context>::tensor_format_
using CuDNNConvOpBase<Context>::group_v2_;
#endif // USE_CUDNN #endif // USE_CUDNN
......
...@@ -15,8 +15,8 @@ void CuDNNConvOp<Context>::ResetDesc() { ...@@ -15,8 +15,8 @@ void CuDNNConvOp<Context>::ResetDesc() {
if (input_changed || filter_changed) { if (input_changed || filter_changed) {
if (input_changed) { if (input_changed) {
input_dims_ = X.dims(); input_dims_ = X.dims();
CuDNNSetTensorDesc<T>(&input_desc_, X.dims(), data_format(), group_v2_); CuDNNSetTensorDesc<T>(&input_desc_, X.dims(), data_format());
CuDNNSetTensorDesc<T>(&output_desc_, Y->dims(), data_format(), group_v2_); CuDNNSetTensorDesc<T>(&output_desc_, Y->dims(), data_format());
if (HasBias()) { if (HasBias()) {
CuDNNSetTensorDesc<T>(&output_desc_for_bias_, Y->dims(), data_format()); CuDNNSetTensorDesc<T>(&output_desc_for_bias_, Y->dims(), data_format());
} }
...@@ -33,10 +33,10 @@ void CuDNNConvOp<Context>::ResetDesc() { ...@@ -33,10 +33,10 @@ void CuDNNConvOp<Context>::ResetDesc() {
// Get or search the appropriate algorithm // Get or search the appropriate algorithm
if (CUDAContext::objects().cudnn_deterministic_) { if (CUDAContext::objects().cudnn_deterministic_) {
fwd_algo_ = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM; fwd_algo_ = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM;
algo_deterministic_ = true;
} else if (CUDAContext::objects().cudnn_benchmark_) { } else if (CUDAContext::objects().cudnn_benchmark_) {
exhaustive_search_ = true; exhaustive_search_ = true;
} else { } else {
#if CUDNN_VERSION_MIN(7, 0, 0)
int num_valid_algos; int num_valid_algos;
constexpr int num_algos = CUDNN_CONV_NUM_FWD_ALGOS; constexpr int num_algos = CUDNN_CONV_NUM_FWD_ALGOS;
cudnnConvolutionFwdAlgoPerf_t stats[num_algos]; cudnnConvolutionFwdAlgoPerf_t stats[num_algos];
...@@ -49,30 +49,21 @@ void CuDNNConvOp<Context>::ResetDesc() { ...@@ -49,30 +49,21 @@ void CuDNNConvOp<Context>::ResetDesc() {
num_algos, num_algos,
&num_valid_algos, &num_valid_algos,
stats)); stats));
bool algo_is_found = false; bool algo_found = false;
for (int i = 0; i < num_valid_algos; ++i) { for (int i = 0; i < num_valid_algos; ++i) {
if (stats[i].memory <= CUDNN_CONV_WORKSPACE_LIMIT_BYTES) { if (stats[i].memory <= CUDNN_CONV_WORKSPACE_LIMIT_BYTES) {
fwd_algo_ = stats[i].algo; fwd_algo_ = stats[i].algo;
algo_is_found = true; algo_found = true;
algo_deterministic_ = false;
break; break;
} }
} }
CHECK(algo_is_found) if (!algo_found) {
<< "\nNo algorithms available for <cudnnConvolutionForward> " fwd_algo_ = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM;
<< "under the current desc and workspace limit."; algo_deterministic_ = true;
#else }
CUDNN_CHECK(cudnnGetConvolutionForwardAlgorithm(
ctx()->cudnn_handle(),
input_desc_,
filter_desc_,
conv_desc_,
output_desc_,
CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT,
CUDNN_CONV_WORKSPACE_LIMIT_BYTES,
&fwd_algo_));
#endif // CUDNN_VERSION_MIN(7, 0, 0)
} }
cudnn_ws_nbytes_ = SIZE_MAX; // Request a new size cudnn_ws_size_ = SIZE_MAX; // Request a new size
} }
} }
...@@ -80,18 +71,15 @@ template <class Context> ...@@ -80,18 +71,15 @@ template <class Context>
template <typename T> template <typename T>
void CuDNNConvOp<Context>::DoRunWithType() { void CuDNNConvOp<Context>::DoRunWithType() {
auto &X = Input(0), &W = Input(1); auto &X = Input(0), &W = Input(1);
TENSOR_FILL(W, w_shape_); TENSOR_FILL(W, w_shape_);
if (HasBias()) { if (HasBias()) {
TENSOR_FILL(Input(2), b_shape_); TENSOR_FILL(Input(2), b_shape_);
} }
ResetDesc<T>(); ResetDesc<T>();
auto* x = X.template data<T, Context>(); auto* x = X.template data<T, Context>();
auto* w = W.template data<T, Context>(); auto* w = W.template data<T, Context>();
auto* y = Output(0)->template mutable_data<T, Context>(); auto* y = Output(0)->template mutable_data<T, Context>();
void* scratch = nullptr; // workspace buffer void* scratch = nullptr; // workspace buffer
// Find the appropriate algorithm if necessary // Find the appropriate algorithm if necessary
...@@ -120,10 +108,11 @@ void CuDNNConvOp<Context>::DoRunWithType() { ...@@ -120,10 +108,11 @@ void CuDNNConvOp<Context>::DoRunWithType() {
}); });
exhaustive_search_ = false; exhaustive_search_ = false;
fwd_algo_ = std::get<0>(algo); fwd_algo_ = std::get<0>(algo);
algo_deterministic_ = false;
} }
// Determine the workspace size for selected algorithm // Determine the workspace size for selected algorithm
if (cudnn_ws_nbytes_ == SIZE_MAX) { if (cudnn_ws_size_ == SIZE_MAX) {
auto algo_status = CUDNN_STATUS_SUCCESS; auto algo_status = CUDNN_STATUS_SUCCESS;
for (int step = 0; step < 2; ++step) { for (int step = 0; step < 2; ++step) {
algo_status = cudnnGetConvolutionForwardWorkspaceSize( algo_status = cudnnGetConvolutionForwardWorkspaceSize(
...@@ -133,9 +122,9 @@ void CuDNNConvOp<Context>::DoRunWithType() { ...@@ -133,9 +122,9 @@ void CuDNNConvOp<Context>::DoRunWithType() {
conv_desc_, conv_desc_,
output_desc_, output_desc_,
fwd_algo_, fwd_algo_,
&cudnn_ws_nbytes_); &cudnn_ws_size_);
if (algo_status != CUDNN_STATUS_SUCCESS && step == 0 && if (algo_status != CUDNN_STATUS_SUCCESS && step == 0 &&
CUDAContext::objects().cudnn_deterministic_) { algo_deterministic_) {
fwd_algo_ = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM; fwd_algo_ = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM;
} else { } else {
CUDNN_CHECK(algo_status); CUDNN_CHECK(algo_status);
...@@ -144,26 +133,24 @@ void CuDNNConvOp<Context>::DoRunWithType() { ...@@ -144,26 +133,24 @@ void CuDNNConvOp<Context>::DoRunWithType() {
} }
// Alloc the memory for workspace data // Alloc the memory for workspace data
if (cudnn_ws_nbytes_ > 0) { if (cudnn_ws_size_ > 0) {
scratch = ctx()->workspace()->template data<Context>({cudnn_ws_nbytes_})[0]; scratch = ctx()->workspace()->template data<Context>({cudnn_ws_size_})[0];
} }
for (int g = 0; g < group_v2_; g++) {
CUDNN_CHECK(cudnnConvolutionForward( CUDNN_CHECK(cudnnConvolutionForward(
ctx()->cudnn_handle(), ctx()->cudnn_handle(),
CuDNNType<T>::one, CuDNNType<T>::one,
input_desc_, input_desc_,
x + X_stride_ * g, x,
filter_desc_, filter_desc_,
w + W_stride_ * g, w,
conv_desc_, conv_desc_,
fwd_algo_, fwd_algo_,
scratch, scratch,
cudnn_ws_nbytes_, cudnn_ws_size_,
CuDNNType<T>::zero, CuDNNType<T>::zero,
output_desc_, output_desc_,
y + Y_stride_ * g)); y));
}
if (HasBias()) { if (HasBias()) {
auto* b = Input(2).template data<T, Context>(); auto* b = Input(2).template data<T, Context>();
...@@ -181,13 +168,6 @@ void CuDNNConvOp<Context>::DoRunWithType() { ...@@ -181,13 +168,6 @@ void CuDNNConvOp<Context>::DoRunWithType() {
template <class Context> template <class Context>
void CuDNNConvOp<Context>::RunOnDevice() { void CuDNNConvOp<Context>::RunOnDevice() {
ConvOpBase<Context>::Reshape(); ConvOpBase<Context>::Reshape();
if (data_format() == "NCHW") {
X_stride_ = Input(0).stride(0) / group_v2_;
Y_stride_ = Output(0)->stride(0) / group_v2_;
} else if (data_format() == "NHWC") {
X_stride_ = Input(0).dim(-1) / group_v2_;
Y_stride_ = Output(0)->dim(-1) / group_v2_;
}
DispatchHelper<FloatingTensorTypes>::Call(this, Input(0)); DispatchHelper<FloatingTensorTypes>::Call(this, Input(0));
} }
...@@ -200,8 +180,8 @@ void CuDNNConvGradientOp<Context>::ResetDesc() { ...@@ -200,8 +180,8 @@ void CuDNNConvGradientOp<Context>::ResetDesc() {
if (input_changed || filter_changed) { if (input_changed || filter_changed) {
if (input_changed) { if (input_changed) {
input_dims_ = X.dims(); input_dims_ = X.dims();
CuDNNSetTensorDesc<T>(&input_desc_, dY.dims(), data_format(), group_v2_); CuDNNSetTensorDesc<T>(&input_desc_, dY.dims(), data_format());
CuDNNSetTensorDesc<T>(&output_desc_, X.dims(), data_format(), group_v2_); CuDNNSetTensorDesc<T>(&output_desc_, X.dims(), data_format());
if (HasBias()) { if (HasBias()) {
CuDNNSetTensorDesc<T>(&input_desc_for_bias_, dY.dims(), data_format()); CuDNNSetTensorDesc<T>(&input_desc_for_bias_, dY.dims(), data_format());
} }
...@@ -219,11 +199,12 @@ void CuDNNConvGradientOp<Context>::ResetDesc() { ...@@ -219,11 +199,12 @@ void CuDNNConvGradientOp<Context>::ResetDesc() {
if (CUDAContext::objects().cudnn_deterministic_) { if (CUDAContext::objects().cudnn_deterministic_) {
bwd_data_algo_ = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1; bwd_data_algo_ = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1;
bwd_filter_algo_ = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1; bwd_filter_algo_ = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1;
data_algo_deterministic_ = true;
filter_algo_deterministic_ = true;
} else if (CUDAContext::objects().cudnn_benchmark_) { } else if (CUDAContext::objects().cudnn_benchmark_) {
exhaustive_search_data_ = true; exhaustive_search_data_ = true;
exhaustive_search_filter_ = true; exhaustive_search_filter_ = true;
} else { } else {
#if CUDNN_VERSION_MIN(7, 0, 0)
{ {
int num_valid_algos; int num_valid_algos;
constexpr int num_algos = CUDNN_CONV_NUM_BWD_FILTER_ALGOS; constexpr int num_algos = CUDNN_CONV_NUM_BWD_FILTER_ALGOS;
...@@ -237,17 +218,19 @@ void CuDNNConvGradientOp<Context>::ResetDesc() { ...@@ -237,17 +218,19 @@ void CuDNNConvGradientOp<Context>::ResetDesc() {
num_algos, num_algos,
&num_valid_algos, &num_valid_algos,
stats)); stats));
bool algo_is_found = false; bool algo_found = false;
for (int i = 0; i < num_valid_algos; ++i) { for (int i = 0; i < num_valid_algos; ++i) {
if (stats[i].memory <= CUDNN_CONV_WORKSPACE_LIMIT_BYTES) { if (stats[i].memory <= CUDNN_CONV_WORKSPACE_LIMIT_BYTES) {
algo_found = true;
bwd_filter_algo_ = stats[i].algo; bwd_filter_algo_ = stats[i].algo;
algo_is_found = true; filter_algo_deterministic_ = false;
break; break;
} }
} }
CHECK(algo_is_found) if (!algo_found) {
<< "\nNo algorithms available for <cudnnConvolutionBackwardFilter> " bwd_filter_algo_ = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1;
<< "under the current desc and workspace limit."; filter_algo_deterministic_ = true;
}
} }
{ {
int num_valid_algos; int num_valid_algos;
...@@ -262,40 +245,22 @@ void CuDNNConvGradientOp<Context>::ResetDesc() { ...@@ -262,40 +245,22 @@ void CuDNNConvGradientOp<Context>::ResetDesc() {
num_algos, num_algos,
&num_valid_algos, &num_valid_algos,
stats)); stats));
bool algo_is_found = false; bool algo_found = false;
for (int i = 0; i < num_valid_algos; ++i) { for (int i = 0; i < num_valid_algos; ++i) {
if (stats[i].memory <= CUDNN_CONV_WORKSPACE_LIMIT_BYTES) { if (stats[i].memory <= CUDNN_CONV_WORKSPACE_LIMIT_BYTES) {
bwd_data_algo_ = stats[i].algo; bwd_data_algo_ = stats[i].algo;
algo_is_found = true; algo_found = true;
data_algo_deterministic_ = false;
break; break;
} }
} }
CHECK(algo_is_found) if (!algo_found) {
<< "\nNo algorithms available for <cudnnConvolutionBackwardData> " bwd_data_algo_ = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1;
<< "under the current desc and workspace limit."; data_algo_deterministic_ = true;
}
} }
#else
CUDNN_CHECK(cudnnGetConvolutionBackwardFilterAlgorithm(
ctx()->cudnn_handle(),
output_desc_,
input_desc_,
conv_desc_,
filter_desc_,
CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT,
CUDNN_CONV_WORKSPACE_LIMIT_BYTES,
&bwd_filter_algo_));
CUDNN_CHECK(cudnnGetConvolutionBackwardDataAlgorithm(
ctx()->cudnn_handle(),
filter_desc_,
input_desc_,
conv_desc_,
output_desc_,
CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT,
CUDNN_CONV_WORKSPACE_LIMIT_BYTES,
&bwd_data_algo_));
#endif // CUDNN_VERSION_MIN(7, 0, 0)
} }
cudnn_ws_nbytes_ = SIZE_MAX; // Request a new size cudnn_ws_size_ = SIZE_MAX; // Request a new size
} }
} }
...@@ -340,6 +305,7 @@ void CuDNNConvGradientOp<Context>::DoRunWithType() { ...@@ -340,6 +305,7 @@ void CuDNNConvGradientOp<Context>::DoRunWithType() {
}); });
exhaustive_search_filter_ = false; exhaustive_search_filter_ = false;
bwd_filter_algo_ = std::get<0>(algo); bwd_filter_algo_ = std::get<0>(algo);
filter_algo_deterministic_ = false;
} }
if (dX->has_name() && exhaustive_search_data_) { if (dX->has_name() && exhaustive_search_data_) {
...@@ -369,10 +335,11 @@ void CuDNNConvGradientOp<Context>::DoRunWithType() { ...@@ -369,10 +335,11 @@ void CuDNNConvGradientOp<Context>::DoRunWithType() {
}); });
exhaustive_search_data_ = false; exhaustive_search_data_ = false;
bwd_data_algo_ = std::get<0>(algo); bwd_data_algo_ = std::get<0>(algo);
data_algo_deterministic_ = false;
} }
// Determine the workspace size for selected algorithm // Determine the workspace size for selected algorithm
if (cudnn_ws_nbytes_ == SIZE_MAX) { if (cudnn_ws_size_ == SIZE_MAX) {
auto algo_status = CUDNN_STATUS_SUCCESS; auto algo_status = CUDNN_STATUS_SUCCESS;
size_t bwd_filter_size = 0, bwd_data_size = 0; size_t bwd_filter_size = 0, bwd_data_size = 0;
for (int step = 0; step < 2; ++step) { for (int step = 0; step < 2; ++step) {
...@@ -385,7 +352,7 @@ void CuDNNConvGradientOp<Context>::DoRunWithType() { ...@@ -385,7 +352,7 @@ void CuDNNConvGradientOp<Context>::DoRunWithType() {
bwd_filter_algo_, bwd_filter_algo_,
&bwd_filter_size); &bwd_filter_size);
if (algo_status != CUDNN_STATUS_SUCCESS && step == 0 && if (algo_status != CUDNN_STATUS_SUCCESS && step == 0 &&
CUDAContext::objects().cudnn_deterministic_) { filter_algo_deterministic_) {
bwd_filter_algo_ = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0; bwd_filter_algo_ = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0;
} else { } else {
CUDNN_CHECK(algo_status); CUDNN_CHECK(algo_status);
...@@ -401,18 +368,18 @@ void CuDNNConvGradientOp<Context>::DoRunWithType() { ...@@ -401,18 +368,18 @@ void CuDNNConvGradientOp<Context>::DoRunWithType() {
bwd_data_algo_, bwd_data_algo_,
&bwd_data_size); &bwd_data_size);
if (algo_status != CUDNN_STATUS_SUCCESS && step == 0 && if (algo_status != CUDNN_STATUS_SUCCESS && step == 0 &&
CUDAContext::objects().cudnn_deterministic_) { data_algo_deterministic_) {
bwd_data_algo_ = CUDNN_CONVOLUTION_BWD_DATA_ALGO_0; bwd_data_algo_ = CUDNN_CONVOLUTION_BWD_DATA_ALGO_0;
} else { } else {
CUDNN_CHECK(algo_status); CUDNN_CHECK(algo_status);
} }
} }
cudnn_ws_nbytes_ = std::max(bwd_filter_size, bwd_data_size); cudnn_ws_size_ = std::max(bwd_filter_size, bwd_data_size);
} }
// Alloc the memory for workspace data // Alloc the memory for workspace data
if (cudnn_ws_nbytes_ > 0) { if (cudnn_ws_size_ > 0) {
scratch = ctx()->workspace()->template data<Context>({cudnn_ws_nbytes_})[0]; scratch = ctx()->workspace()->template data<Context>({cudnn_ws_size_})[0];
} }
if (Output(2)->has_name()) { if (Output(2)->has_name()) {
...@@ -430,56 +397,45 @@ void CuDNNConvGradientOp<Context>::DoRunWithType() { ...@@ -430,56 +397,45 @@ void CuDNNConvGradientOp<Context>::DoRunWithType() {
if (dW->has_name()) { if (dW->has_name()) {
x = X.template data<T, Context>(); x = X.template data<T, Context>();
dw = dW->template mutable_data<T, Context>(); dw = dW->template mutable_data<T, Context>();
for (int g = 0; g < group_v2_; g++) {
CUDNN_CHECK(cudnnConvolutionBackwardFilter( CUDNN_CHECK(cudnnConvolutionBackwardFilter(
ctx()->cudnn_handle(), ctx()->cudnn_handle(),
CuDNNType<T>::one, CuDNNType<T>::one,
output_desc_, output_desc_,
x + X_stride_ * g, x,
input_desc_, input_desc_,
dy + Y_stride_ * g, dy,
conv_desc_, conv_desc_,
bwd_filter_algo_, bwd_filter_algo_,
scratch, scratch,
cudnn_ws_nbytes_, cudnn_ws_size_,
CuDNNType<T>::zero, CuDNNType<T>::zero,
filter_desc_, filter_desc_,
dw + W_stride_ * g)); dw));
}
} }
if (dX->has_name()) { if (dX->has_name()) {
w = W.template data<T, Context>(); w = W.template data<T, Context>();
dx = dX->template mutable_data<T, Context>(); dx = dX->template mutable_data<T, Context>();
for (int g = 0; g < group_v2_; g++) {
CUDNN_CHECK(cudnnConvolutionBackwardData( CUDNN_CHECK(cudnnConvolutionBackwardData(
ctx()->cudnn_handle(), ctx()->cudnn_handle(),
CuDNNType<T>::one, CuDNNType<T>::one,
filter_desc_, filter_desc_,
w + W_stride_ * g, w,
input_desc_, input_desc_,
dy + Y_stride_ * g, dy,
conv_desc_, conv_desc_,
bwd_data_algo_, bwd_data_algo_,
scratch, scratch,
cudnn_ws_nbytes_, cudnn_ws_size_,
CuDNNType<T>::zero, CuDNNType<T>::zero,
output_desc_, output_desc_,
dx + X_stride_ * g)); dx));
}
} }
} }
template <class Context> template <class Context>
void CuDNNConvGradientOp<Context>::RunOnDevice() { void CuDNNConvGradientOp<Context>::RunOnDevice() {
ConvOpBase<Context>::Reshape(true); ConvOpBase<Context>::Reshape(true);
if (data_format() == "NCHW") {
X_stride_ = Input(0).stride(0) / group_v2_;
Y_stride_ = Input(-1).stride(0) / group_v2_;
} else if (data_format() == "NHWC") {
X_stride_ = Input(0).dim(-1) / group_v2_;
Y_stride_ = Input(-1).dim(-1) / group_v2_;
}
DispatchHelper<FloatingTensorTypes>::Call(this, Input(-1)); DispatchHelper<FloatingTensorTypes>::Call(this, Input(-1));
} }
......
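
Since cuDNN 7 handles grouped convolution natively, the per-group loops with hand-computed X_stride_/W_stride_/Y_stride_ offsets are gone: the group count is attached to the convolution descriptor once and a single cuDNN call covers all groups. A minimal sketch of the pattern for float data; GroupedConvForward is an illustrative helper, not a Dragon function:

#include <cudnn.h>

// One convolution call per layer instead of one per group (cuDNN >= 7).
void GroupedConvForward(
    cudnnHandle_t handle,
    cudnnTensorDescriptor_t x_desc, const void* x,
    cudnnFilterDescriptor_t w_desc, const void* w,
    cudnnConvolutionDescriptor_t conv_desc, int group,
    cudnnConvolutionFwdAlgo_t algo, void* scratch, size_t scratch_size,
    cudnnTensorDescriptor_t y_desc, void* y) {
  const float one = 1.f, zero = 0.f;  // float descriptors assumed
  // The descriptor carries the group count; no pointer striding needed.
  cudnnSetConvolutionGroupCount(conv_desc, group);
  cudnnConvolutionForward(
      handle, &one, x_desc, x, w_desc, w, conv_desc, algo,
      scratch, scratch_size, &zero, y_desc, y);
}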
...@@ -112,9 +112,10 @@ class CuDNNConvTransposeOp final : public CuDNNConvOpBase<Context> { ...@@ -112,9 +112,10 @@ class CuDNNConvTransposeOp final : public CuDNNConvOpBase<Context> {
template <typename T> template <typename T>
void ResetDesc(); void ResetDesc();
size_t cudnn_ws_nbytes_; size_t cudnn_ws_size_;
vec64_t input_dims_, filter_dims_; vec64_t input_dims_, filter_dims_;
bool exhaustive_search_ = false; bool exhaustive_search_ = false;
bool algo_deterministic_ = false;
cudnnConvolutionBwdDataAlgo_t fwd_algo_; cudnnConvolutionBwdDataAlgo_t fwd_algo_;
cudnnTensorDescriptor_t input_desc_, output_desc_; cudnnTensorDescriptor_t input_desc_, output_desc_;
cudnnTensorDescriptor_t bias_desc_, output_desc_for_bias_; cudnnTensorDescriptor_t bias_desc_, output_desc_for_bias_;
...@@ -164,10 +165,12 @@ class CuDNNConvTransposeGradientOp final : public CuDNNConvOpBase<Context> { ...@@ -164,10 +165,12 @@ class CuDNNConvTransposeGradientOp final : public CuDNNConvOpBase<Context> {
template <typename T> template <typename T>
void ResetDesc(); void ResetDesc();
size_t cudnn_ws_nbytes_; size_t cudnn_ws_size_;
vec64_t input_dims_, filter_dims_; vec64_t input_dims_, filter_dims_;
bool exhaustive_search_data_ = false; bool exhaustive_search_data_ = false;
bool exhaustive_search_filter_ = false; bool exhaustive_search_filter_ = false;
bool data_algo_deterministic_ = false;
bool filter_algo_deterministic_ = false;
cudnnConvolutionBwdFilterAlgo_t bwd_filter_algo_; cudnnConvolutionBwdFilterAlgo_t bwd_filter_algo_;
cudnnConvolutionFwdAlgo_t bwd_data_algo_; cudnnConvolutionFwdAlgo_t bwd_data_algo_;
cudnnTensorDescriptor_t input_desc_, output_desc_; cudnnTensorDescriptor_t input_desc_, output_desc_;
......
...@@ -15,8 +15,8 @@ void CuDNNConvTransposeOp<Context>::ResetDesc() { ...@@ -15,8 +15,8 @@ void CuDNNConvTransposeOp<Context>::ResetDesc() {
if (input_changed || filter_changed) { if (input_changed || filter_changed) {
if (input_changed) { if (input_changed) {
input_dims_ = X.dims(); input_dims_ = X.dims();
CuDNNSetTensorDesc<T>(&input_desc_, X.dims(), data_format(), group_v2_); CuDNNSetTensorDesc<T>(&input_desc_, X.dims(), data_format());
CuDNNSetTensorDesc<T>(&output_desc_, Y->dims(), data_format(), group_v2_); CuDNNSetTensorDesc<T>(&output_desc_, Y->dims(), data_format());
if (HasBias()) { if (HasBias()) {
CuDNNSetTensorDesc<T>(&output_desc_for_bias_, Y->dims(), data_format()); CuDNNSetTensorDesc<T>(&output_desc_for_bias_, Y->dims(), data_format());
} }
...@@ -33,10 +33,10 @@ void CuDNNConvTransposeOp<Context>::ResetDesc() { ...@@ -33,10 +33,10 @@ void CuDNNConvTransposeOp<Context>::ResetDesc() {
// Get or search the appropriate algorithm // Get or search the appropriate algorithm
if (CUDAContext::objects().cudnn_deterministic_) { if (CUDAContext::objects().cudnn_deterministic_) {
fwd_algo_ = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1; fwd_algo_ = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1;
algo_deterministic_ = true;
} else if (CUDAContext::objects().cudnn_benchmark_) { } else if (CUDAContext::objects().cudnn_benchmark_) {
exhaustive_search_ = true; exhaustive_search_ = true;
} else { } else {
#if CUDNN_VERSION_MIN(7, 0, 0)
int num_valid_algos; int num_valid_algos;
constexpr int num_algos = CUDNN_CONV_NUM_BWD_DATA_ALGOS; constexpr int num_algos = CUDNN_CONV_NUM_BWD_DATA_ALGOS;
cudnnConvolutionBwdDataAlgoPerf_t stats[num_algos]; cudnnConvolutionBwdDataAlgoPerf_t stats[num_algos];
...@@ -49,30 +49,21 @@ void CuDNNConvTransposeOp<Context>::ResetDesc() { ...@@ -49,30 +49,21 @@ void CuDNNConvTransposeOp<Context>::ResetDesc() {
num_algos, num_algos,
&num_valid_algos, &num_valid_algos,
stats)); stats));
bool algo_is_found = false; bool algo_found = false;
for (int i = 0; i < num_valid_algos; ++i) { for (int i = 0; i < num_valid_algos; ++i) {
if (stats[i].memory <= CUDNN_CONV_WORKSPACE_LIMIT_BYTES) { if (stats[i].memory <= CUDNN_CONV_WORKSPACE_LIMIT_BYTES) {
fwd_algo_ = stats[i].algo; fwd_algo_ = stats[i].algo;
algo_is_found = true; algo_found = true;
algo_deterministic_ = false;
break; break;
} }
} }
CHECK(algo_is_found) if (!algo_found) {
<< "\nNo algorithms available for <cudnnConvolutionBackwardData> " fwd_algo_ = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1;
<< "under the current desc and workspace limit."; algo_deterministic_ = true;
#else }
CUDNN_CHECK(cudnnGetConvolutionBackwardDataAlgorithm(
ctx()->cudnn_handle(),
filter_desc_,
input_desc_,
conv_desc_,
output_desc_,
CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT,
CUDNN_CONV_WORKSPACE_LIMIT_BYTES,
&fwd_algo_));
#endif // CUDNN_VERSION_MIN(7, 0, 0)
} }
cudnn_ws_nbytes_ = SIZE_MAX; // Request a new size cudnn_ws_size_ = SIZE_MAX; // Request a new size
} }
} }
...@@ -80,12 +71,10 @@ template <class Context> ...@@ -80,12 +71,10 @@ template <class Context>
template <typename T> template <typename T>
void CuDNNConvTransposeOp<Context>::DoRunWithType() { void CuDNNConvTransposeOp<Context>::DoRunWithType() {
auto &X = Input(0), &W = Input(1); auto &X = Input(0), &W = Input(1);
TENSOR_FILL(W, w_shape_); TENSOR_FILL(W, w_shape_);
if (HasBias()) { if (HasBias()) {
TENSOR_FILL(Input(2), b_shape_); TENSOR_FILL(Input(2), b_shape_);
} }
ResetDesc<T>(); ResetDesc<T>();
auto* x = X.template data<T, Context>(); auto* x = X.template data<T, Context>();
...@@ -120,10 +109,11 @@ void CuDNNConvTransposeOp<Context>::DoRunWithType() { ...@@ -120,10 +109,11 @@ void CuDNNConvTransposeOp<Context>::DoRunWithType() {
}); });
exhaustive_search_ = false; exhaustive_search_ = false;
fwd_algo_ = std::get<0>(algo); fwd_algo_ = std::get<0>(algo);
algo_deterministic_ = false;
} }
// Determine the workspace size for selected algorithm // Determine the workspace size for selected algorithm
if (cudnn_ws_nbytes_ == SIZE_MAX) { if (cudnn_ws_size_ == SIZE_MAX) {
auto algo_status = CUDNN_STATUS_SUCCESS; auto algo_status = CUDNN_STATUS_SUCCESS;
for (int step = 0; step < 2; ++step) { for (int step = 0; step < 2; ++step) {
algo_status = cudnnGetConvolutionBackwardDataWorkspaceSize( algo_status = cudnnGetConvolutionBackwardDataWorkspaceSize(
...@@ -133,9 +123,9 @@ void CuDNNConvTransposeOp<Context>::DoRunWithType() { ...@@ -133,9 +123,9 @@ void CuDNNConvTransposeOp<Context>::DoRunWithType() {
conv_desc_, conv_desc_,
output_desc_, output_desc_,
fwd_algo_, fwd_algo_,
&cudnn_ws_nbytes_); &cudnn_ws_size_);
if (algo_status != CUDNN_STATUS_SUCCESS && step == 0 && if (algo_status != CUDNN_STATUS_SUCCESS && step == 0 &&
CUDAContext::objects().cudnn_deterministic_) { algo_deterministic_) {
fwd_algo_ = CUDNN_CONVOLUTION_BWD_DATA_ALGO_0; fwd_algo_ = CUDNN_CONVOLUTION_BWD_DATA_ALGO_0;
} else { } else {
CUDNN_CHECK(algo_status); CUDNN_CHECK(algo_status);
...@@ -144,26 +134,24 @@ void CuDNNConvTransposeOp<Context>::DoRunWithType() { ...@@ -144,26 +134,24 @@ void CuDNNConvTransposeOp<Context>::DoRunWithType() {
} }
// Alloc the memory for workspace data // Alloc the memory for workspace data
if (cudnn_ws_nbytes_ > 0) { if (cudnn_ws_size_ > 0) {
scratch = ctx()->workspace()->template data<Context>({cudnn_ws_nbytes_})[0]; scratch = ctx()->workspace()->template data<Context>({cudnn_ws_size_})[0];
} }
for (int g = 0; g < group_v2_; g++) {
CUDNN_CHECK(cudnnConvolutionBackwardData( CUDNN_CHECK(cudnnConvolutionBackwardData(
ctx()->cudnn_handle(), ctx()->cudnn_handle(),
CuDNNType<T>::one, CuDNNType<T>::one,
filter_desc_, filter_desc_,
w + W_stride_ * g, w,
input_desc_, input_desc_,
x + X_stride_ * g, x,
conv_desc_, conv_desc_,
fwd_algo_, fwd_algo_,
scratch, scratch,
cudnn_ws_nbytes_, cudnn_ws_size_,
CuDNNType<T>::zero, CuDNNType<T>::zero,
output_desc_, output_desc_,
y + Y_stride_ * g)); y));
}
if (HasBias()) { if (HasBias()) {
auto* b = Input(2).template data<T, Context>(); auto* b = Input(2).template data<T, Context>();
...@@ -181,13 +169,6 @@ void CuDNNConvTransposeOp<Context>::DoRunWithType() { ...@@ -181,13 +169,6 @@ void CuDNNConvTransposeOp<Context>::DoRunWithType() {
template <class Context> template <class Context>
void CuDNNConvTransposeOp<Context>::RunOnDevice() { void CuDNNConvTransposeOp<Context>::RunOnDevice() {
ConvOpBase<Context>::Reshape(); ConvOpBase<Context>::Reshape();
if (data_format() == "NCHW") {
X_stride_ = Input(0).stride(0) / group_v2_;
Y_stride_ = Output(0)->stride(0) / group_v2_;
} else if (data_format() == "NHWC") {
X_stride_ = Input(0).dim(-1) / group_v2_;
Y_stride_ = Output(0)->dim(-1) / group_v2_;
}
DispatchHelper<FloatingTensorTypes>::Call(this, Input(0)); DispatchHelper<FloatingTensorTypes>::Call(this, Input(0));
} }
...@@ -200,8 +181,8 @@ void CuDNNConvTransposeGradientOp<Context>::ResetDesc() { ...@@ -200,8 +181,8 @@ void CuDNNConvTransposeGradientOp<Context>::ResetDesc() {
if (input_changed || filter_changed) { if (input_changed || filter_changed) {
if (input_changed) { if (input_changed) {
input_dims_ = X.dims(); input_dims_ = X.dims();
CuDNNSetTensorDesc<T>(&input_desc_, dY.dims(), data_format(), group_v2_); CuDNNSetTensorDesc<T>(&input_desc_, dY.dims(), data_format());
CuDNNSetTensorDesc<T>(&output_desc_, X.dims(), data_format(), group_v2_); CuDNNSetTensorDesc<T>(&output_desc_, X.dims(), data_format());
if (HasBias()) { if (HasBias()) {
CuDNNSetTensorDesc<T>(&input_desc_for_bias_, dY.dims(), data_format()); CuDNNSetTensorDesc<T>(&input_desc_for_bias_, dY.dims(), data_format());
} }
...@@ -219,11 +200,12 @@ void CuDNNConvTransposeGradientOp<Context>::ResetDesc() { ...@@ -219,11 +200,12 @@ void CuDNNConvTransposeGradientOp<Context>::ResetDesc() {
if (CUDAContext::objects().cudnn_deterministic_) { if (CUDAContext::objects().cudnn_deterministic_) {
bwd_data_algo_ = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM; bwd_data_algo_ = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM;
bwd_filter_algo_ = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1; bwd_filter_algo_ = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1;
data_algo_deterministic_ = true;
filter_algo_deterministic_ = true;
} else if (CUDAContext::objects().cudnn_benchmark_) { } else if (CUDAContext::objects().cudnn_benchmark_) {
exhaustive_search_data_ = true; exhaustive_search_data_ = true;
exhaustive_search_filter_ = true; exhaustive_search_filter_ = true;
} else { } else {
#if CUDNN_VERSION_MIN(7, 0, 0)
{ {
int num_valid_algos; int num_valid_algos;
constexpr int num_algos = CUDNN_CONV_NUM_BWD_FILTER_ALGOS; constexpr int num_algos = CUDNN_CONV_NUM_BWD_FILTER_ALGOS;
...@@ -237,17 +219,19 @@ void CuDNNConvTransposeGradientOp<Context>::ResetDesc() { ...@@ -237,17 +219,19 @@ void CuDNNConvTransposeGradientOp<Context>::ResetDesc() {
num_algos, num_algos,
&num_valid_algos, &num_valid_algos,
stats)); stats));
bool algo_is_found = false; bool algo_found = false;
for (int i = 0; i < num_valid_algos; ++i) { for (int i = 0; i < num_valid_algos; ++i) {
if (stats[i].memory <= CUDNN_CONV_WORKSPACE_LIMIT_BYTES) { if (stats[i].memory <= CUDNN_CONV_WORKSPACE_LIMIT_BYTES) {
bwd_filter_algo_ = stats[i].algo; bwd_filter_algo_ = stats[i].algo;
algo_is_found = true; algo_found = true;
filter_algo_deterministic_ = false;
break; break;
} }
} }
CHECK(algo_is_found) if (!algo_found) {
<< "\nNo algorithms available for <cudnnConvolutionBackwardFilter> " bwd_filter_algo_ = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1;
<< "under the current desc and workspace limit."; filter_algo_deterministic_ = true;
}
} }
{ {
int num_valid_algos; int num_valid_algos;
...@@ -262,40 +246,22 @@ void CuDNNConvTransposeGradientOp<Context>::ResetDesc() { ...@@ -262,40 +246,22 @@ void CuDNNConvTransposeGradientOp<Context>::ResetDesc() {
num_algos, num_algos,
&num_valid_algos, &num_valid_algos,
stats)); stats));
bool algo_is_found = false; bool algo_found = false;
for (int i = 0; i < num_valid_algos; ++i) { for (int i = 0; i < num_valid_algos; ++i) {
if (stats[i].memory <= CUDNN_CONV_WORKSPACE_LIMIT_BYTES) { if (stats[i].memory <= CUDNN_CONV_WORKSPACE_LIMIT_BYTES) {
bwd_data_algo_ = stats[i].algo; bwd_data_algo_ = stats[i].algo;
algo_is_found = true; algo_found = true;
data_algo_deterministic_ = false;
break; break;
} }
} }
CHECK(algo_is_found) if (!algo_found) {
<< "\nNo algorithms available for <cudnnConvolutionForward> " bwd_data_algo_ = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM;
<< "under the current desc and workspace limit."; data_algo_deterministic_ = true;
} }
#else
CUDNN_CHECK(cudnnGetConvolutionBackwardFilterAlgorithm(
ctx()->cudnn_handle(),
input_desc_,
output_desc_,
conv_desc_,
filter_desc_,
CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT,
CUDNN_CONV_WORKSPACE_LIMIT_BYTES,
&bwd_filter_algo_));
CUDNN_CHECK(cudnnGetConvolutionForwardAlgorithm(
ctx()->cudnn_handle(),
input_desc_,
filter_desc_,
conv_desc_,
output_desc_,
CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT,
CUDNN_CONV_WORKSPACE_LIMIT_BYTES,
&bwd_data_algo_));
#endif // CUDNN_VERSION_MIN(7, 0, 0)
} }
cudnn_ws_nbytes_ = SIZE_MAX; // Request a new size }
cudnn_ws_size_ = SIZE_MAX; // Request a new size
} }
} }
...@@ -340,6 +306,7 @@ void CuDNNConvTransposeGradientOp<Context>::DoRunWithType() { ...@@ -340,6 +306,7 @@ void CuDNNConvTransposeGradientOp<Context>::DoRunWithType() {
}); });
exhaustive_search_filter_ = false; exhaustive_search_filter_ = false;
bwd_filter_algo_ = std::get<0>(algo); bwd_filter_algo_ = std::get<0>(algo);
filter_algo_deterministic_ = false;
} }
if (dX->has_name() && exhaustive_search_data_) { if (dX->has_name() && exhaustive_search_data_) {
...@@ -369,10 +336,11 @@ void CuDNNConvTransposeGradientOp<Context>::DoRunWithType() { ...@@ -369,10 +336,11 @@ void CuDNNConvTransposeGradientOp<Context>::DoRunWithType() {
}); });
exhaustive_search_data_ = false; exhaustive_search_data_ = false;
bwd_data_algo_ = std::get<0>(algo); bwd_data_algo_ = std::get<0>(algo);
data_algo_deterministic_ = false;
} }
// Determine the workspace size for selected algorithm // Determine the workspace size for selected algorithm
if (cudnn_ws_nbytes_ == SIZE_MAX) { if (cudnn_ws_size_ == SIZE_MAX) {
auto algo_status = CUDNN_STATUS_SUCCESS; auto algo_status = CUDNN_STATUS_SUCCESS;
size_t bwd_filter_size = 0, bwd_data_size = 0; size_t bwd_filter_size = 0, bwd_data_size = 0;
for (int step = 0; step < 2; ++step) { for (int step = 0; step < 2; ++step) {
...@@ -385,7 +353,7 @@ void CuDNNConvTransposeGradientOp<Context>::DoRunWithType() { ...@@ -385,7 +353,7 @@ void CuDNNConvTransposeGradientOp<Context>::DoRunWithType() {
bwd_filter_algo_, bwd_filter_algo_,
&bwd_filter_size); &bwd_filter_size);
if (algo_status != CUDNN_STATUS_SUCCESS && step == 0 && if (algo_status != CUDNN_STATUS_SUCCESS && step == 0 &&
CUDAContext::objects().cudnn_deterministic_) { filter_algo_deterministic_) {
bwd_filter_algo_ = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0; bwd_filter_algo_ = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0;
} else { } else {
CUDNN_CHECK(algo_status); CUDNN_CHECK(algo_status);
...@@ -401,18 +369,18 @@ void CuDNNConvTransposeGradientOp<Context>::DoRunWithType() { ...@@ -401,18 +369,18 @@ void CuDNNConvTransposeGradientOp<Context>::DoRunWithType() {
bwd_data_algo_, bwd_data_algo_,
&bwd_data_size); &bwd_data_size);
if (algo_status != CUDNN_STATUS_SUCCESS && step == 0 && if (algo_status != CUDNN_STATUS_SUCCESS && step == 0 &&
CUDAContext::objects().cudnn_deterministic_) { data_algo_deterministic_) {
bwd_data_algo_ = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM; bwd_data_algo_ = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM;
} else { } else {
CUDNN_CHECK(algo_status); CUDNN_CHECK(algo_status);
} }
} }
cudnn_ws_nbytes_ = std::max(bwd_filter_size, bwd_data_size); cudnn_ws_size_ = std::max(bwd_filter_size, bwd_data_size);
} }
// Alloc the memory for workspace data // Alloc the memory for workspace data
if (cudnn_ws_nbytes_ > 0) { if (cudnn_ws_size_ > 0) {
scratch = ctx()->workspace()->template data<Context>({cudnn_ws_nbytes_})[0]; scratch = ctx()->workspace()->template data<Context>({cudnn_ws_size_})[0];
} }
if (Output(2)->has_name()) { if (Output(2)->has_name()) {
...@@ -430,56 +398,45 @@ void CuDNNConvTransposeGradientOp<Context>::DoRunWithType() { ...@@ -430,56 +398,45 @@ void CuDNNConvTransposeGradientOp<Context>::DoRunWithType() {
if (dW->has_name()) { if (dW->has_name()) {
x = X.template data<T, Context>(); x = X.template data<T, Context>();
dw = dW->template mutable_data<T, Context>(); dw = dW->template mutable_data<T, Context>();
for (int g = 0; g < group_v2_; g++) {
CUDNN_CHECK(cudnnConvolutionBackwardFilter( CUDNN_CHECK(cudnnConvolutionBackwardFilter(
ctx()->cudnn_handle(), ctx()->cudnn_handle(),
CuDNNType<T>::one, CuDNNType<T>::one,
input_desc_, input_desc_,
dy + Y_stride_ * g, dy,
output_desc_, output_desc_,
x + X_stride_ * g, x,
conv_desc_, conv_desc_,
bwd_filter_algo_, bwd_filter_algo_,
scratch, scratch,
cudnn_ws_nbytes_, cudnn_ws_size_,
CuDNNType<T>::zero, CuDNNType<T>::zero,
filter_desc_, filter_desc_,
dw + W_stride_ * g)); dw));
}
} }
if (dX->has_name()) { if (dX->has_name()) {
auto* w = W.template data<T, Context>(); auto* w = W.template data<T, Context>();
auto* dx = dX->template mutable_data<T, Context>(); auto* dx = dX->template mutable_data<T, Context>();
for (int g = 0; g < group_v2_; g++) {
CUDNN_CHECK(cudnnConvolutionForward( CUDNN_CHECK(cudnnConvolutionForward(
ctx()->cudnn_handle(), ctx()->cudnn_handle(),
CuDNNType<T>::one, CuDNNType<T>::one,
input_desc_, input_desc_,
dy + Y_stride_ * g, dy,
filter_desc_, filter_desc_,
w + W_stride_ * g, w,
conv_desc_, conv_desc_,
bwd_data_algo_, bwd_data_algo_,
scratch, scratch,
cudnn_ws_nbytes_, cudnn_ws_size_,
CuDNNType<T>::zero, CuDNNType<T>::zero,
output_desc_, output_desc_,
dx + X_stride_ * g)); dx));
}
} }
} }
template <class Context> template <class Context>
void CuDNNConvTransposeGradientOp<Context>::RunOnDevice() { void CuDNNConvTransposeGradientOp<Context>::RunOnDevice() {
ConvOpBase<Context>::Reshape(true); ConvOpBase<Context>::Reshape(true);
if (data_format() == "NCHW") {
X_stride_ = Input(0).stride(0) / group_v2_;
Y_stride_ = Input(-1).stride(0) / group_v2_;
} else if (data_format() == "NHWC") {
X_stride_ = Input(0).dim(-1) / group_v2_;
Y_stride_ = Input(-1).dim(-1) / group_v2_;
}
DispatchHelper<FloatingTensorTypes>::Call(this, Input(-1)); DispatchHelper<FloatingTensorTypes>::Call(this, Input(-1));
} }
......
...@@ -83,10 +83,9 @@ template <typename T> ...@@ -83,10 +83,9 @@ template <typename T>
void CuDNNSetTensorDesc( void CuDNNSetTensorDesc(
cudnnTensorDescriptor_t* desc, cudnnTensorDescriptor_t* desc,
const vec64_t& dims, const vec64_t& dims,
const string& data_format, const string& data_format) {
const int64_t group) {
const int N = dims[0]; const int N = dims[0];
const int C = (data_format == "NCHW" ? dims[1] : dims.back()) / group; const int C = data_format == "NCHW" ? dims[1] : dims.back();
int D, H, W; int D, H, W;
if (dims.size() == 3) { if (dims.size() == 3) {
D = W = 1; D = W = 1;
...@@ -153,7 +152,7 @@ void CuDNNSetBiasDesc( ...@@ -153,7 +152,7 @@ void CuDNNSetBiasDesc(
template void CuDNNSetTensorDesc<T>( \ template void CuDNNSetTensorDesc<T>( \
cudnnTensorDescriptor_t*, const vec64_t&, const vec64_t&); \ cudnnTensorDescriptor_t*, const vec64_t&, const vec64_t&); \
template void CuDNNSetTensorDesc<T>( \ template void CuDNNSetTensorDesc<T>( \
cudnnTensorDescriptor_t*, const vec64_t&, const string&, const int64_t); \ cudnnTensorDescriptor_t*, const vec64_t&, const string&); \
template void CuDNNSetBiasDesc<T>( \ template void CuDNNSetBiasDesc<T>( \
cudnnTensorDescriptor_t*, const int, const int64_t, const string&); cudnnTensorDescriptor_t*, const int, const int64_t, const string&);
......
...@@ -93,8 +93,7 @@ template <typename T> ...@@ -93,8 +93,7 @@ template <typename T>
void CuDNNSetTensorDesc( void CuDNNSetTensorDesc(
cudnnTensorDescriptor_t* desc, cudnnTensorDescriptor_t* desc,
const vec64_t& dims, const vec64_t& dims,
const std::string& data_format, const std::string& data_format);
const int64_t group = 1);
/*! \brief Set a bias desc with expanding dimensions */ /*! \brief Set a bias desc with expanding dimensions */
template <typename T> template <typename T>
......