Commit 73ed1b96 by Ting PAN

Remove support for CUDNN v6

Summary:
To keep the selection of cuDNN convolution algorithms consistent across versions,
support for cuDNN v6 (used mainly with CUDA 8.0) has been dropped.
1 parent bbfecf22
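
Context for the change (editorial note, not part of the commit message): cuDNN 7 introduced the `cudnnGetConvolutionForwardAlgorithm_v7` family, which returns a performance-ranked list of candidate algorithms, whereas the older heuristic entry points behave differently across versions. Requiring v7+ lets every build use the same query path. A minimal sketch of that query, assuming a cuDNN 7+ toolchain and Dragon's `CUDNN_CHECK` macro; `PickFwdAlgo` is an illustrative helper, not Dragon's API:

```cpp
#include <cudnn.h>

// Illustrative only: query forward algorithms uniformly via the cuDNN 7+
// perf-result API. Results come back sorted by expected performance.
cudnnConvolutionFwdAlgo_t PickFwdAlgo(
    cudnnHandle_t handle,
    cudnnTensorDescriptor_t x_desc,
    cudnnFilterDescriptor_t w_desc,
    cudnnConvolutionDescriptor_t conv_desc,
    cudnnTensorDescriptor_t y_desc) {
  int count = 0;
  cudnnConvolutionFwdAlgoPerf_t perf[CUDNN_CONVOLUTION_FWD_ALGO_COUNT];
  CUDNN_CHECK(cudnnGetConvolutionForwardAlgorithm_v7(
      handle, x_desc, w_desc, conv_desc, y_desc,
      CUDNN_CONVOLUTION_FWD_ALGO_COUNT, &count, perf));
  return perf[0].algo;  // Best-ranked algorithm.
}
```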
......@@ -58,20 +58,20 @@ endif()
# ---[ Library directories
list(APPEND THIRD_PARTY_LIBRARY_DIRS ${PROTOBUF_SDK_ROOT_DIR}/lib)
if (USE_CUDA)
list(APPEND THIRD_PARTY_LIBRARY_DIRS ${CUDA_TOOLKIT_ROOT_DIR}/lib64)
list(APPEND THIRD_PARTY_LIBRARY_DIRS ${CUDA_TOOLKIT_ROOT_DIR}/lib/x64)
if (USE_MPI)
list(APPEND THIRD_PARTY_LIBRARY_DIRS ${THIRD_PARTY_DIR}/mpi/lib)
endif()
if (USE_CUDNN)
list(APPEND THIRD_PARTY_LIBRARY_DIRS ${THIRD_PARTY_DIR}/cudnn/lib64)
list(APPEND THIRD_PARTY_LIBRARY_DIRS ${THIRD_PARTY_DIR}/cudnn/lib/x64)
endif()
if (USE_MPI)
list(APPEND THIRD_PARTY_LIBRARY_DIRS ${THIRD_PARTY_DIR}/mpi/lib)
endif()
if (USE_TENSORRT)
list(APPEND THIRD_PARTY_LIBRARY_DIRS ${TENSORRT_SDK_ROOT_DIR}/lib)
endif()
if (USE_CUDA)
list(APPEND THIRD_PARTY_LIBRARY_DIRS ${CUDA_TOOLKIT_ROOT_DIR}/lib64)
list(APPEND THIRD_PARTY_LIBRARY_DIRS ${CUDA_TOOLKIT_ROOT_DIR}/lib/x64)
endif()
# ---[ Defines
if (BUILD_PYTHON)
......
......@@ -35,6 +35,12 @@ else()
endif()
foreach(_plain_name ${ARGN})
# Filter the linker option
string(FIND "${_plain_name}" "-Wl" _is_linker_option)
if (${_is_linker_option} GREATER -1)
list(APPEND _libraries ${_plain_name})
continue()
endif()
# Firstly, search in the third party
set(LIB_VAR "${_target}/lib/${_plain_name}")
find_library(
......
......@@ -55,10 +55,6 @@ Share
#####
.. doxygenfunction:: dragon::Tensor::Share
SwitchToDevice
##############
.. doxygenfunction:: dragon::Tensor::SwitchToDevice
axis
####
.. doxygenfunction:: dragon::Tensor::axis
......
......@@ -19,10 +19,6 @@ State
Public Functions
----------------
SwitchToDevice
##############
.. doxygenfunction:: dragon::UnifiedMemory::SwitchToDevice
SwitchToCUDADevice
##################
.. doxygenfunction:: dragon::UnifiedMemory::SwitchToCUDADevice
......
......@@ -55,6 +55,7 @@ const void* UnifiedMemory::cpu_data(size_t size) {
}
const void* UnifiedMemory::cuda_data(size_t size) {
SwitchToCUDADevice(CUDAContext::current_device());
ToCUDA(size);
return (const void*)cuda_ptr_;
}
......@@ -70,6 +71,7 @@ void* UnifiedMemory::mutable_cpu_data(size_t size) {
}
void* UnifiedMemory::mutable_cuda_data(size_t size) {
SwitchToCUDADevice(CUDAContext::current_device());
ToCUDA(size);
state_ = STATE_AT_CUDA;
return cuda_ptr_;
......@@ -116,10 +118,6 @@ UnifiedMemory::~UnifiedMemory() {
}
}
void UnifiedMemory::SwitchToDevice(int device_id) {
if (cuda_ptr_) SwitchToCUDADevice(device_id);
}
void UnifiedMemory::SwitchToCUDADevice(int device_id) {
#ifdef USE_CUDA
if (cuda_ptr_) {
......@@ -129,7 +127,9 @@ void UnifiedMemory::SwitchToCUDADevice(int device_id) {
new_ptr_ = CUDAContext::New(size_);
CUDAContext::Memcpy<CUDAContext, CUDAContext>(
size_, new_ptr_, cuda_ptr_, device_id_);
if (own_cuda_ptr_) CUDAContext::Delete(cuda_ptr_);
if (own_cuda_ptr_) {
CUDAContext::Delete(cuda_ptr_);
}
cuda_ptr_ = new_ptr_;
device_id_ = device_id;
}
......
......@@ -49,9 +49,6 @@ class DRAGON_API UnifiedMemory {
/*! \brief Destructor */
~UnifiedMemory();
/*! \brief Switch to the given device */
void SwitchToDevice(int device);
/*! \brief Switch to the given cuda device */
void SwitchToCUDADevice(int device);
......
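
With `SwitchToDevice` removed from `UnifiedMemory` and `Tensor`, the device switch is folded into the data accessors themselves, as the `cuda_data`/`mutable_cuda_data` hunks above show. A hedged usage sketch under that reading:

```cpp
#include "dragon/core/memory.h"

// Sketch: callers no longer coordinate devices explicitly. Requesting a
// CUDA pointer switches to the caller's current device and migrates the
// data in one step. Assumes the Dragon headers shown in this diff.
float* GetDeviceBuffer(dragon::UnifiedMemory* mem, size_t nbytes) {
  // Internally: SwitchToCUDADevice(CUDAContext::current_device()),
  // then ToCUDA(nbytes); the state moves to STATE_AT_CUDA.
  return static_cast<float*>(mem->mutable_cuda_data(nbytes));
}
```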
......@@ -136,20 +136,6 @@ void Operator<Context>::Release() {
}
}
template <class Context>
void Operator<Context>::SwitchToDevice() {
for (auto* tensor : inputs_) {
if (tensor->has_name()) {
tensor->SwitchToDevice(ctx()->device());
}
}
for (auto* tensor : outputs_) {
if (tensor->has_name()) {
tensor->SwitchToDevice(ctx()->device());
}
}
}
OperatorBase*
TryCreateOperator(const string& key, const OperatorDef& def, Workspace* ws) {
switch (def.device_option().device_type()) {
......
......@@ -207,9 +207,6 @@ class DRAGON_API Operator : public OperatorBase {
/*! \brief Release the ownership of inputs */
virtual void Release();
/*! \brief Coordinate the context of inputs and outputs */
virtual void SwitchToDevice();
/*! \brief The detailed execution on device */
virtual void RunOnDevice() = 0;
......@@ -217,7 +214,6 @@ class DRAGON_API Operator : public OperatorBase {
void Run(int stream = 0) final {
Prepare();
ctx()->SwitchToDevice(stream);
SwitchToDevice();
RunOnDevice();
if (do_sync_) {
ctx()->FinishDeviceComputation();
......
......@@ -109,12 +109,6 @@ class DRAGON_API Tensor {
return Reshape(other.dims_);
}
/*! \brief Switch memory to the specific device */
void SwitchToDevice(int device_id) {
UnifiedMemory* mem = memory();
if (mem) mem->SwitchToDevice(device_id);
}
/*! \brief Copy memory from a tensor with context */
template <class Context>
Tensor* CopyFrom(const Tensor& other, Context* ctx) {
......
......@@ -71,7 +71,7 @@ if (USE_CUDNN)
if (USE_SHARED_LIBS)
target_link_libraries_v2(dragon cudnn)
else()
target_link_libraries_v2(dragon cudnn_static)
target_link_libraries_v2(dragon -Wl,--whole-archive cudnn_static -Wl,--no-whole-archive)
endif()
endif()
if (USE_NCCL)
......
......@@ -79,7 +79,7 @@ if (USE_CUDNN)
if (USE_SHARED_LIBS)
target_link_libraries_v2(dragonrt cudnn)
else()
target_link_libraries_v2(dragonrt cudnn_static)
target_link_libraries_v2(dragonrt -Wl,--whole-archive cudnn_static -Wl,--no-whole-archive)
endif()
endif()
if (USE_CUDA AND (NOT USE_SHARED_LIBS))
......
......@@ -58,8 +58,6 @@ DEFINE_OP_SINGLE_ARG_WITH_DESC(float, DropoutGradientOp, ratio);
#ifdef USE_CUDNN
#if CUDNN_VERSION_MIN(7, 0, 0)
template <class Context>
class CuDNNDropoutOp final : public DropoutOp<Context> {
public:
......@@ -118,8 +116,6 @@ class CuDNNDropoutGradientOp final : public DropoutGradientOp<Context> {
unsigned long long rng_seed_;
};
#endif // CUDNN_VERSION_MIN(7, 0, 0)
#endif // USE_CUDNN
} // namespace dragon
......
......@@ -3,8 +3,6 @@
#include "dragon/core/workspace.h"
#include "dragon/operators/activation/dropout_op.h"
#if CUDNN_VERSION_MIN(7, 0, 0)
namespace dragon {
template <class Context>
......@@ -124,6 +122,4 @@ DEPLOY_CUDNN_OPERATOR(DropoutGradient);
} // namespace dragon
#endif // CUDNN_VERSION_MIN(7, 0, 0)
#endif // USE_CUDNN
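
The guards removed here and below test against cuDNN 7.0. Version-check macros of this kind are commonly defined as below (a sketch; Dragon's actual definition may differ slightly):

```cpp
#include <cudnn.h>

// A common shape for the macro (sketch): compare the compile-time
// CUDNN_VERSION against an encoded (major, minor, patch) triple.
#define CUDNN_VERSION_MIN(major, minor, patch) \
  (CUDNN_VERSION >= ((major) * 1000 + (minor) * 100 + (patch)))

// With cuDNN 7 as the new minimum, CUDNN_VERSION_MIN(7, 0, 0) is always
// true, so the guards around the dropout and CTC loss operators are dead
// code, which is why this commit deletes them.
```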
......@@ -23,8 +23,6 @@ class ShapeOp final : public Operator<Context> {
SIMPLE_CTOR_DTOR(ShapeOp);
USE_OPERATOR_FUNCTIONS;
void SwitchToDevice() override {}
void RunOnDevice() override;
};
......
......@@ -44,8 +44,6 @@ class CTCLossGradientOp final : public Operator<Context> {
#ifdef USE_CUDNN
#if CUDNN_VERSION_MIN(7, 0, 0)
template <class Context>
class CuDNNCTCLossOp final : public Operator<Context> {
public:
......@@ -81,8 +79,6 @@ class CuDNNCTCLossOp final : public Operator<Context> {
vec32_t packed_labels_, label_lengths_, input_lengths_;
};
#endif // CUDNN_VERSION_MIN(7, 0, 0)
#endif // USE_CUDNN
} // namespace dragon
......
......@@ -3,8 +3,6 @@
#include "dragon/core/workspace.h"
#include "dragon/operators/loss/ctc_loss_op.h"
#if CUDNN_VERSION_MIN(7, 0, 0)
#define CUDNN_LABEL_LENGTH_LIMIT 256
namespace dragon {
......@@ -95,6 +93,4 @@ DEPLOY_CUDNN_OPERATOR(CTCLoss);
} // namespace dragon
#endif // CUDNN_VERSION_MIN(7, 0, 0)
#endif // USE_CUDNN
......@@ -19,11 +19,10 @@ void CuDNNRecurrentOpBase<Context>::ResetDesc() {
// Setup Dropout
if (dropout_ratio_ < 1.f) {
#if CUDNN_VERSION_MIN(7, 0, 0)
if (!states_initialized_) {
states_initialized_ = 1;
CUDNN_CHECK(
cudnnDropoutGetStatesSize(ctx()->cudnn_handle(), &states_size_));
auto cudnn_handle = ctx()->cudnn_handle();
CUDNN_CHECK(cudnnDropoutGetStatesSize(cudnn_handle, &states_size_));
std::lock_guard<std::mutex> lk(CUDAContext::mutex());
auto* states_tensor = workspace()->CreateTensor(
"/share/cudnn/dropout:" + str::to(rng_seed_) + "/states");
......@@ -31,7 +30,7 @@ void CuDNNRecurrentOpBase<Context>::ResetDesc() {
auto* states = states_tensor->template mutable_data<uint8_t, Context>();
CUDNN_CHECK(cudnnRestoreDropoutDescriptor(
dropout_desc_,
ctx()->cudnn_handle(),
cudnn_handle,
dropout_ratio_,
states,
states_size_,
......@@ -41,16 +40,13 @@ void CuDNNRecurrentOpBase<Context>::ResetDesc() {
->template mutable_data<uint8_t, Context>();
CUDNN_CHECK(cudnnSetDropoutDescriptor(
dropout_desc_,
ctx()->cudnn_handle(),
cudnn_handle,
dropout_ratio_,
states,
states_size_,
rng_seed_));
}
}
#else
LOG(FATAL) << "Dropout has been supported since CuDNN 7.0";
#endif
}
// Setup RNN
......@@ -61,7 +57,6 @@ void CuDNNRecurrentOpBase<Context>::ResetDesc() {
} else if (input_type == TypeMeta::Id<double>()) {
compute_type_ = CUDNN_DATA_DOUBLE;
}
#if CUDNN_VERSION_MIN(7, 0, 0)
CUDNN_CHECK(cudnnSetRNNDescriptor_v6(
ctx()->cudnn_handle(),
rnn_desc_,
......@@ -73,35 +68,22 @@ void CuDNNRecurrentOpBase<Context>::ResetDesc() {
rnn_mode_,
CUDNN_RNN_ALGO_STANDARD,
compute_type_));
#else
CUDNN_CHECK(cudnnSetRNNDescriptor(
rnn_desc_,
hidden_size_,
num_layers_,
dropout_desc_,
rnn_input_mode_,
rnn_direction_,
rnn_mode_,
compute_type_));
#endif
// Setup TensorCore
#if CUDNN_VERSION_MIN(7, 0, 0)
if (enable_tensor_core_ > 0) {
if (TENSOR_CORE_AVAILABLE()) {
cudnnMathType_t math_type;
if (input_type == TypeMeta::Id<float16>()) {
math_type = CUDNN_TENSOR_OP_MATH;
} else {
math_type = CUDNN_DEFAULT_MATH;
#if CUDNN_VERSION_MIN(8, 0, 0)
if (!CUDAContext::objects().cudnn_allow_tf32_) {
math_type = CUDNN_FMA_MATH;
}
#endif
}
}
CUDNN_CHECK(cudnnSetRNNMatrixMathType(rnn_desc_, math_type));
}
#endif
// Setup X and Y
output_dims_ = {seq_length_, batch_size, y_dim};
......
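
The RNN hunk above also shows the set-versus-restore split for dropout RNG states: the state buffer is seeded once per `rng_seed_` in a shared workspace tensor, then reattached on later uses. A simplified sketch of that pattern, with illustrative names and Dragon's `CUDNN_CHECK` assumed:

```cpp
#include <cudnn.h>

// Sketch: initialize cuDNN dropout states once per seed, restore on reuse.
void SetupDropout(
    cudnnHandle_t handle,
    cudnnDropoutDescriptor_t desc,
    float ratio,
    void* states,
    size_t states_size,
    unsigned long long seed,
    bool states_already_seeded) {
  if (states_already_seeded) {
    // Reattach previously seeded states without regenerating them.
    CUDNN_CHECK(cudnnRestoreDropoutDescriptor(
        desc, handle, ratio, states, states_size, seed));
  } else {
    // First use for this seed: this call fills the state buffer.
    CUDNN_CHECK(cudnnSetDropoutDescriptor(
        desc, handle, ratio, states, states_size, seed));
  }
}
```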
......@@ -100,9 +100,10 @@ class CuDNNConvOp final : public CuDNNConvOpBase<Context> {
template <typename T>
void ResetDesc();
size_t cudnn_ws_nbytes_;
size_t cudnn_ws_size_;
vec64_t input_dims_, filter_dims_;
bool exhaustive_search_ = false;
bool algo_deterministic_ = false;
cudnnConvolutionFwdAlgo_t fwd_algo_;
cudnnTensorDescriptor_t input_desc_, output_desc_;
cudnnTensorDescriptor_t bias_desc_, output_desc_for_bias_;
......@@ -148,8 +149,10 @@ class CuDNNConvGradientOp final : public CuDNNConvOpBase<Context> {
template <typename T>
void ResetDesc();
size_t cudnn_ws_nbytes_;
size_t cudnn_ws_size_;
vec64_t input_dims_, filter_dims_;
bool data_algo_deterministic_ = false;
bool filter_algo_deterministic_ = false;
bool exhaustive_search_data_ = false;
bool exhaustive_search_filter_ = false;
cudnnConvolutionBwdFilterAlgo_t bwd_filter_algo_;
......
......@@ -136,11 +136,6 @@ class CuDNNConvOpBase : public ConvOpBase<Context> {
CuDNNConvOpBase(const OperatorDef& def, Workspace* ws)
: ConvOpBase<Context>(def, ws) {
GetBaseArguments();
#if CUDNN_VERSION_MIN(7, 0, 0)
group_v2_ = 1;
#else
group_v2_ = group_;
#endif
if (data_format() == "NCHW") {
tensor_format_ = CUDNN_TENSOR_NCHW;
} else if (data_format() == "NHWC") {
......@@ -184,7 +179,6 @@ class CuDNNConvOpBase : public ConvOpBase<Context> {
CUDNN_CROSS_CORRELATION,
compute_type_));
}
#if CUDNN_VERSION_MIN(7, 0, 0)
CUDNN_CHECK(cudnnSetConvolutionGroupCount(conv_desc_, group_));
if (TENSOR_CORE_AVAILABLE()) {
cudnnMathType_t math_type;
......@@ -192,15 +186,14 @@ class CuDNNConvOpBase : public ConvOpBase<Context> {
math_type = CUDNN_TENSOR_OP_MATH;
} else {
math_type = CUDNN_DEFAULT_MATH;
#if CUDNN_VERSION_MIN(8, 0, 0)
if (!CUDAContext::objects().cudnn_allow_tf32_) {
math_type = CUDNN_FMA_MATH;
}
#endif
}
}
CUDNN_CHECK(cudnnSetConvolutionMathType(conv_desc_, math_type));
}
#endif
}
template <typename T>
......@@ -210,13 +203,12 @@ class CuDNNConvOpBase : public ConvOpBase<Context> {
filter_desc_,
CuDNNType<T>::type,
tensor_format_,
conv_out_channels_ / group_v2_,
conv_out_channels_,
conv_in_channels_ / group_,
kshape_[0],
num_axes_ == 1 ? 1 : kshape_[1]));
} else {
vec64_t dims = {conv_out_channels_ / group_v2_,
conv_in_channels_ / group_};
vec64_t dims = {conv_out_channels_, conv_in_channels_ / group_};
dims.insert(dims.end(), kshape_.begin(), kshape_.end());
CUDNN_CHECK(cudnnSetFilterNdDescriptor(
filter_desc_,
......@@ -227,7 +219,6 @@ class CuDNNConvOpBase : public ConvOpBase<Context> {
}
}
int64_t group_v2_;
cudnnConvolutionDescriptor_t conv_desc_;
cudnnFilterDescriptor_t filter_desc_;
cudnnDataType_t compute_type_;
......@@ -240,8 +231,7 @@ class CuDNNConvOpBase : public ConvOpBase<Context> {
using CuDNNConvOpBase<Context>::conv_desc_; \
using CuDNNConvOpBase<Context>::filter_desc_; \
using CuDNNConvOpBase<Context>::compute_type_; \
using CuDNNConvOpBase<Context>::tensor_format_; \
using CuDNNConvOpBase<Context>::group_v2_;
using CuDNNConvOpBase<Context>::tensor_format_
#endif // USE_CUDNN
......
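
With cuDNN 7 as the floor, grouped convolution no longer needs the per-group descriptor workaround that `group_v2_` supported: a single group count on the convolution descriptor suffices, and filter dimensions use the full output-channel count. A minimal sketch under that assumption (descriptor parameters are illustrative):

```cpp
#include <cudnn.h>

// cuDNN 7+ native grouped convolution (sketch): set the group count once
// on the convolution descriptor instead of slicing tensors per group.
void MakeGroupedConvDesc() {
  cudnnConvolutionDescriptor_t conv_desc;
  CUDNN_CHECK(cudnnCreateConvolutionDescriptor(&conv_desc));
  CUDNN_CHECK(cudnnSetConvolution2dDescriptor(
      conv_desc, /*pad_h=*/1, /*pad_w=*/1, /*stride_h=*/1, /*stride_w=*/1,
      /*dilation_h=*/1, /*dilation_w=*/1,
      CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT));
  // Before cuDNN 7, each group needed its own descriptors and pointer
  // offsets; now one call covers all groups.
  CUDNN_CHECK(cudnnSetConvolutionGroupCount(conv_desc, /*group=*/32));
  CUDNN_CHECK(cudnnDestroyConvolutionDescriptor(conv_desc));
}
```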
......@@ -112,9 +112,10 @@ class CuDNNConvTransposeOp final : public CuDNNConvOpBase<Context> {
template <typename T>
void ResetDesc();
size_t cudnn_ws_nbytes_;
size_t cudnn_ws_size_;
vec64_t input_dims_, filter_dims_;
bool exhaustive_search_ = false;
bool algo_deterministic_ = false;
cudnnConvolutionBwdDataAlgo_t fwd_algo_;
cudnnTensorDescriptor_t input_desc_, output_desc_;
cudnnTensorDescriptor_t bias_desc_, output_desc_for_bias_;
......@@ -164,10 +165,12 @@ class CuDNNConvTransposeGradientOp final : public CuDNNConvOpBase<Context> {
template <typename T>
void ResetDesc();
size_t cudnn_ws_nbytes_;
size_t cudnn_ws_size_;
vec64_t input_dims_, filter_dims_;
bool exhaustive_search_data_ = false;
bool exhaustive_search_filter_ = false;
bool data_algo_deterministic_ = false;
bool filter_algo_deterministic_ = false;
cudnnConvolutionBwdFilterAlgo_t bwd_filter_algo_;
cudnnConvolutionFwdAlgo_t bwd_data_algo_;
cudnnTensorDescriptor_t input_desc_, output_desc_;
......
......@@ -83,10 +83,9 @@ template <typename T>
void CuDNNSetTensorDesc(
cudnnTensorDescriptor_t* desc,
const vec64_t& dims,
const string& data_format,
const int64_t group) {
const string& data_format) {
const int N = dims[0];
const int C = (data_format == "NCHW" ? dims[1] : dims.back()) / group;
const int C = data_format == "NCHW" ? dims[1] : dims.back();
int D, H, W;
if (dims.size() == 3) {
D = W = 1;
......@@ -153,7 +152,7 @@ void CuDNNSetBiasDesc(
template void CuDNNSetTensorDesc<T>( \
cudnnTensorDescriptor_t*, const vec64_t&, const vec64_t&); \
template void CuDNNSetTensorDesc<T>( \
cudnnTensorDescriptor_t*, const vec64_t&, const string&, const int64_t); \
cudnnTensorDescriptor_t*, const vec64_t&, const string&); \
template void CuDNNSetBiasDesc<T>( \
cudnnTensorDescriptor_t*, const int, const int64_t, const string&);
......
......@@ -93,8 +93,7 @@ template <typename T>
void CuDNNSetTensorDesc(
cudnnTensorDescriptor_t* desc,
const vec64_t& dims,
const std::string& data_format,
const int64_t group = 1);
const std::string& data_format);
/*! \brief Set a bias desc with expanding dimensions */
template <typename T>
......
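
The `group` parameter leaves `CuDNNSetTensorDesc` for the same reason: with native grouped convolution, tensor descriptors carry the full channel count and no `/ group` division remains. A hedged usage sketch against the declaration above (the include path is hypothetical):

```cpp
#include <cudnn.h>
#include "dragon/utils/math/cudnn.h"  // hypothetical header for the decls above

// Sketch: describe an NCHW tensor with its full channel count.
void DescribeInput() {
  cudnnTensorDescriptor_t desc;
  CUDNN_CHECK(cudnnCreateTensorDescriptor(&desc));
  dragon::CuDNNSetTensorDesc<float>(&desc, {8, 64, 32, 32}, "NCHW");
  CUDNN_CHECK(cudnnDestroyTensorDescriptor(desc));
}
```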