Commit c40eaf7b by Ting PAN

Add support for building with Ampere GPUs, CUDA 11 and cuDNN 8

Summary:
This commit fixes the issues encountered when building with CUDA 11 and cuDNN 8.
In addition, C++14 is now enabled by default instead of C++11 to support CUB 1.9+,
so a compiler of GCC 5 / Clang 6 / MSVC 14.1 or higher is required.
1 parent d8f612c8
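For reference, the compiler requirement stated above can be made explicit with a preprocessor guard. This is a minimal, illustrative sketch (not part of this commit); it assumes GCC 5+/Clang 6+ report C++14 through __cplusplus when invoked in C++14 mode, and that MSVC v141 corresponds to _MSC_VER >= 1910:

// Illustrative only: fail fast on toolchains older than the new baseline.
#if defined(_MSC_VER)
#if _MSC_VER < 1910
#error "MSVC v141 (Visual Studio 2017) or newer is required."
#endif
#elif !defined(__cplusplus) || __cplusplus < 201402L
#error "A C++14 compiler (GCC 5+ or Clang 6+) is required."
#endif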
......@@ -20,6 +20,8 @@ if (USE_CUDA)
if (MSVC)
# Suppress all warnings for msvc compiler
set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -w")
else()
set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -std=c++14")
endif()
endif()
if (USE_TENSORRT)
......
......@@ -49,11 +49,6 @@ foreach(_proto ${ARGN})
-I=${_proto_dir}
--cpp_out=${PROTOBUF_DLLEXPORT_STRING}${_proto_dir}
${_proto})
if (MSVC)
string(REPLACE ".proto" ".pb.h" _pb_h "${_proto}")
string(REPLACE ".proto" ".pb.cc" _pb_cc "${_proto}")
protobuf_remove_constexpr(${_pb_h} ${_pb_cc})
endif()
endforeach()
endfunction()
......@@ -69,11 +64,6 @@ foreach(_proto ${ARGN})
-I=${_proto_dir}
--cpp_out=${PROTOBUF_DLLEXPORT_STRING}${_proto_dir}
${_proto})
if (MSVC)
string(REPLACE ".proto" ".pb.h" _pb_h "${_proto}")
string(REPLACE ".proto" ".pb.cc" _pb_cc "${_proto}")
protobuf_remove_constexpr(${_pb_h} ${_pb_cc})
endif()
endforeach()
endfunction()
......
include(CheckCXXCompilerFlag)
# ---[ Check if CXX11 is supported
set(CMAKE_CXX_STANDARD 11)
# ---[ Check if CXX14 is supported
set(CMAKE_CXX_STANDARD 14)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
# ---[ Use ``-fPIC`` for all compilers
......@@ -30,7 +30,7 @@ if (MSVC)
endif()
else() # GNU, Clang, AppleClang
set(CMAKE_ORIGIN)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O2 -std=c++11")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O2 -std=c++14")
if (USE_NATIVE_ARCH)
check_cxx_compiler_flag("-march=native" COMPILER_SUPPORTS_MARCH_NATIVE)
if (COMPILER_SUPPORTS_MARCH_NATIVE)
......
......@@ -5,9 +5,9 @@
# - "Auto" detects local machine GPU compute arch at runtime.
# - "Common" and "All" cover common and entire subsets of architectures
# ARCH_AND_PTX : NAME | NUM.NUM | NUM.NUM(NUM.NUM) | NUM.NUM+PTX
# NAME: Kepler Maxwell Kepler+Tesla Maxwell+Tegra Pascal Volta Turing
# NAME: Kepler Maxwell Kepler+Tesla Maxwell+Tegra Pascal Volta Turing Ampere
# NUM: Any number. Only those pairs are currently accepted by NVCC though:
# 3.5 3.7 5.0 5.2 5.3 6.0 6.1 6.2 7.0 7.2 7.5
# 3.5 3.7 5.0 5.2 5.3 6.0 6.1 6.2 7.0 7.2 7.5 8.0
# Returns LIST of flags to be added to CUDA_NVCC_FLAGS in ${out_variable}
# Additionally, sets ${out_variable}_readable to the resulting numeric list
# Example:
......@@ -55,27 +55,39 @@ if(CUDA_VERSION VERSION_GREATER "7.5")
list(APPEND CUDA_ALL_GPU_ARCHITECTURES "6.0" "6.1" "6.2")
if(CUDA_VERSION VERSION_LESS "9.0")
list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "6.1+PTX")
list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "6.2+PTX")
set(CUDA_LIMIT_GPU_ARCHITECTURE "7.0")
endif()
endif ()
if(CUDA_VERSION VERSION_GREATER "8.5")
list(APPEND CUDA_KNOWN_GPU_ARCHITECTURES "Volta")
list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "7.0" "7.0+PTX")
list(APPEND CUDA_ALL_GPU_ARCHITECTURES "7.0" "7.0+PTX" "7.2" "7.2+PTX")
list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "7.0")
list(APPEND CUDA_ALL_GPU_ARCHITECTURES "7.0" "7.2")
if(CUDA_VERSION VERSION_LESS "10.0")
list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "7.2+PTX")
set(CUDA_LIMIT_GPU_ARCHITECTURE "8.0")
endif()
endif()
if(CUDA_VERSION VERSION_GREATER "9.5")
list(APPEND CUDA_KNOWN_GPU_ARCHITECTURES "Turing")
list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "7.5" "7.5+PTX")
list(APPEND CUDA_ALL_GPU_ARCHITECTURES "7.5" "7.5+PTX")
list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "7.5")
list(APPEND CUDA_ALL_GPU_ARCHITECTURES "7.5")
if(CUDA_VERSION VERSION_LESS "11.0")
list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "7.5+PTX")
set(CUDA_LIMIT_GPU_ARCHITECTURE "8.0")
endif()
endif()
if(CUDA_VERSION VERSION_GREATER "10.5")
list(APPEND CUDA_KNOWN_GPU_ARCHITECTURES "Ampere")
list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "8.0" "8.0+PTX")
list(APPEND CUDA_ALL_GPU_ARCHITECTURES "8.0")
if(CUDA_VERSION VERSION_LESS "12.0")
set(CUDA_LIMIT_GPU_ARCHITECTURE "9.0")
endif()
endif()
......@@ -211,6 +223,9 @@ function(CUDA_SELECT_NVCC_ARCH_FLAGS out_variable)
elseif(${arch_name} STREQUAL "Turing")
set(arch_bin 7.5)
set(arch_ptx 7.5)
elseif(${arch_name} STREQUAL "Ampere")
set(arch_bin 8.0)
set(arch_ptx 8.0)
else()
message(SEND_ERROR "Unknown CUDA Architecture Name ${arch_name} in CUDA_SELECT_NVCC_ARCH_FLAGS")
endif()
......
......@@ -9,10 +9,9 @@ std::mutex& CUDAContext::mutex() {
return m;
}
CUDAObject* CUDAContext::object() {
static TLS_OBJECT CUDAObject* cuda_object_;
if (!cuda_object_) cuda_object_ = new CUDAObject();
return cuda_object_;
CUDAObjects& CUDAContext::objects() {
static thread_local CUDAObjects cuda_objects_;
return cuda_objects_;
}
#endif // USE_CUDA
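The objects() accessor above replaces the hand-rolled TLS_OBJECT pointer (allocated once per thread and never freed) with a function-local thread_local instance that is constructed on first use in each thread and destroyed at thread exit. A minimal standalone sketch of the pattern, using illustrative names rather than Dragon's actual types:

#include <iostream>
#include <thread>

struct Objects {
  int id = 0;  // stand-in for per-thread streams/handles
};

Objects& objects() {
  // One instance per thread: constructed lazily, destroyed automatically at thread exit.
  static thread_local Objects objects_;
  return objects_;
}

int main() {
  objects().id = 1;                              // main thread's instance
  std::thread([] { objects().id = 2; }).join();  // worker thread mutates its own instance
  std::cout << objects().id << std::endl;        // prints 1: main's instance is unaffected
  return 0;
}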
......
......@@ -21,10 +21,10 @@ namespace dragon {
#ifdef USE_CUDA
class CUDAObject {
class CUDAObjects {
public:
/*! \brief Default Constructor */
CUDAObject() {
CUDAObjects() {
for (int i = 0; i < CUDA_MAX_DEVICES; i++) {
cuda_streams_[i] = vector<cudaStream_t>();
cublas_handles_[i] = vector<cublasHandle_t>();
......@@ -38,7 +38,7 @@ class CUDAObject {
}
/*! \brief Destructor */
~CUDAObject() {
~CUDAObjects() {
for (int i = 0; i < CUDA_MAX_DEVICES; i++) {
for (int j = 0; j < cuda_streams_[i].size(); j++) {
auto& stream = cuda_streams_[i][j];
......@@ -164,7 +164,7 @@ class CUDAObject {
bool cudnn_benchmark_ = false;
private:
DISABLE_COPY_AND_ASSIGN(CUDAObject);
DISABLE_COPY_AND_ASSIGN(CUDAObjects);
};
/*!
......@@ -197,7 +197,7 @@ class DRAGON_API CUDAContext {
/*! \brief Set a memory block to the given value */
static void Memset(size_t n, void* ptr, int value = 0) {
auto stream = object()->default_stream();
auto stream = objects().default_stream();
CUDA_CHECK(cudaMemsetAsync(ptr, value, n, stream));
SynchronizeStream(stream);
}
......@@ -216,7 +216,7 @@ class DRAGON_API CUDAContext {
/*! \brief Copy a memory block to the destination using given device */
template <class DestContext, class SrcContext>
static void Memcpy(size_t n, void* dest, const void* src, int device) {
auto stream = object()->default_stream(device);
auto stream = objects().default_stream(device);
CUDA_CHECK(cudaMemcpyAsync(dest, src, n, cudaMemcpyDefault, stream));
SynchronizeStream(stream);
}
......@@ -269,12 +269,12 @@ class DRAGON_API CUDAContext {
/*! \brief Return the specified cuda stream */
cudaStream_t cuda_stream(int device, int stream) {
return object()->stream(device, stream);
return objects().stream(device, stream);
}
/*! \brief Return the cublas handle */
cublasHandle_t cublas_handle() {
return object()->cublas_handle(device_id_, stream_id_);
return objects().cublas_handle(device_id_, stream_id_);
}
/*! \brief Return the curand generator */
......@@ -293,7 +293,7 @@ class DRAGON_API CUDAContext {
/*! \brief Return the cudnn handle */
#ifdef USE_CUDNN
cudnnHandle_t cudnn_handle() {
return object()->cudnn_handle(device_id_, stream_id_);
return objects().cudnn_handle(device_id_, stream_id_);
}
#endif
......@@ -315,8 +315,8 @@ class DRAGON_API CUDAContext {
/*! \brief Return the shared context mutex */
static std::mutex& mutex();
/*! \brief Return the thread-local cuda object */
static CUDAObject* object();
/*! \brief Return the thread-local cuda objects */
static CUDAObjects& objects();
/*! \brief Return the random generator */
std::mt19937* rand_generator() {
......
......@@ -158,7 +158,7 @@ TryCreateOperator(const string& key, const OperatorDef& def, Workspace* ws) {
case PROTO_CUDA:
#ifdef USE_CUDNN
if (CUDNNOperatorRegistry()->Has(key) &&
CUDAContext::object()->cudnn_enabled_) {
CUDAContext::objects().cudnn_enabled_) {
return CUDNNOperatorRegistry()->Create(key, def, ws);
}
#endif
......
......@@ -98,9 +98,9 @@ void RegisterModule(py::module& m) {
/*! \brief Activate the CuDNN engine */
m.def("cudaEnableDNN", [](bool enabled, bool benchmark) {
#ifdef USE_CUDA
auto* cuda_object = CUDAContext::object();
cuda_object->cudnn_enabled_ = enabled;
cuda_object->cudnn_benchmark_ = benchmark;
auto& cuda_objects = CUDAContext::objects();
cuda_objects.cudnn_enabled_ = enabled;
cuda_objects.cudnn_benchmark_ = benchmark;
#endif
});
......@@ -129,7 +129,7 @@ void RegisterModule(py::module& m) {
m.def("cudaStreamSynchronize", [](int device_id, int stream_id) {
#ifdef USE_CUDA
if (device_id < 0) device_id = CUDAContext::current_device();
auto stream = CUDAContext::object()->stream(device_id, stream_id);
auto stream = CUDAContext::objects().stream(device_id, stream_id);
CUDAContext::SynchronizeStream(stream);
#endif
});
......
......@@ -9,7 +9,6 @@ template <typename T>
void CuDNNReluOp<Context>::DoRunWithType() {
auto &X = Input(0), *Y = Output(0, {0});
CuDNNSetTensorDesc<T>(&input_desc_, X.dims());
#if CUDNN_VERSION_MIN(5, 0, 0)
CUDNN_CHECK(cudnnActivationForward(
ctx()->cudnn_handle(),
act_desc_,
......@@ -19,17 +18,6 @@ void CuDNNReluOp<Context>::DoRunWithType() {
CuDNNType<T>::zero,
input_desc_,
Y->ReshapeLike(X)->template mutable_data<T, Context>()));
#else
CUDNN_CHECK(cudnnActivationForward_v4(
ctx()->cudnn_handle(),
act_desc_,
CuDNNType<Dtype>::one,
input_desc_,
X.template data<T, Context>(),
CuDNNType<Dtype>::zero,
input_desc_,
Y->ReshapeLike(X)->template mutable_data<T, Context>()));
#endif
}
template <class Context>
......@@ -46,7 +34,6 @@ template <typename T>
void CuDNNReluGradientOp<Context>::DoRunWithType() {
auto &Y = Input(0), &dY = Input(1), *dX = Output(0);
CuDNNSetTensorDesc<T>(&input_desc_, Y.dims());
#if CUDNN_VERSION_MIN(5, 0, 0)
CUDNN_CHECK(cudnnActivationBackward(
ctx()->cudnn_handle(),
act_desc_,
......@@ -60,21 +47,6 @@ void CuDNNReluGradientOp<Context>::DoRunWithType() {
CuDNNType<T>::zero,
input_desc_,
dX->ReshapeLike(Y)->template mutable_data<T, Context>()));
#else
CUDNN_CHECK(cudnnActivationBackward_v4(
ctx()->cudnn_handle(),
act_desc_,
CuDNNType<T>::one,
input_desc_,
Y.template data<T, Context>(),
input_desc_,
dY.template data<T, Context>(),
input_desc_,
Y.template data<T, Context>(),
CuDNNType<T>::zero,
input_desc_,
dX->ReshapeLike(Y)->template mutable_data<T, Context>()));
#endif
}
template <class Context>
......
......@@ -9,7 +9,6 @@ template <typename T>
void CuDNNSigmoidOp<Context>::DoRunWithType() {
auto &X = Input(0), *Y = Output(0, {0});
CuDNNSetTensorDesc<T>(&input_desc_, X.dims());
#if CUDNN_VERSION_MIN(5, 0, 0)
CUDNN_CHECK(cudnnActivationForward(
ctx()->cudnn_handle(),
act_desc_,
......@@ -19,17 +18,6 @@ void CuDNNSigmoidOp<Context>::DoRunWithType() {
CuDNNType<T>::zero,
input_desc_,
Y->ReshapeLike(X)->template mutable_data<T, Context>()));
#else
CUDNN_CHECK(cudnnActivationForward_v4(
ctx()->cudnn_handle(),
act_desc_,
CuDNNType<T>::one,
input_desc_,
X.template data<T, Context>(),
CuDNNType<T>::zero,
input_desc_,
Y->ReshapeLike(X)->template mutable_data<T, Context>()));
#endif
}
template <class Context>
......@@ -42,7 +30,6 @@ template <typename T>
void CuDNNSigmoidGradientOp<Context>::DoRunWithType() {
auto &Y = Input(0), &dY = Input(1), *dX = Output(0);
CuDNNSetTensorDesc<T>(&input_desc_, Y.dims());
#if CUDNN_VERSION_MIN(5, 0, 0)
CUDNN_CHECK(cudnnActivationBackward(
ctx()->cudnn_handle(),
act_desc_,
......@@ -56,21 +43,6 @@ void CuDNNSigmoidGradientOp<Context>::DoRunWithType() {
CuDNNType<T>::zero,
input_desc_,
dX->ReshapeLike(Y)->template mutable_data<T, Context>()));
#else
CUDNN_CHECK(cudnnActivationBackward_v4(
ctx()->cudnn_handle(),
act_desc_,
CuDNNType<T>::one,
input_desc_,
Y.template data<T, Context>(),
input_desc_,
dY.template data<T, Context>(),
input_desc_,
Y.template data<T, Context>(),
CuDNNType<T>::zero,
input_desc_,
dX->ReshapeLike(Y)->template mutable_data<T, Context>()));
#endif
}
template <class Context>
......
......@@ -9,7 +9,6 @@ template <typename T>
void CuDNNTanhOp<Context>::DoRunWithType() {
auto &X = Input(0), *Y = Output(0, {0});
CuDNNSetTensorDesc<T>(&input_desc_, X.dims());
#if CUDNN_VERSION_MIN(5, 0, 0)
CUDNN_CHECK(cudnnActivationForward(
ctx()->cudnn_handle(),
act_desc_,
......@@ -19,17 +18,6 @@ void CuDNNTanhOp<Context>::DoRunWithType() {
CuDNNType<T>::zero,
input_desc_,
Y->ReshapeLike(X)->template mutable_data<T, Context>()));
#else
CUDNN_CHECK(cudnnActivationForward_v4(
ctx()->cudnn_handle(),
act_desc_,
CuDNNType<T>::one,
input_desc_,
X.template data<T, Context>(),
CuDNNType<T>::zero,
output_desc_,
Y->ReshapeLike(X)->template mutable_data<T, Context>()));
#endif
}
template <class Context>
......@@ -42,7 +30,6 @@ template <typename T>
void CuDNNTanhGradientOp<Context>::DoRunWithType() {
auto &Y = Input(0), &dY = Input(1), *dX = Output(0);
CuDNNSetTensorDesc<T>(&input_desc_, Y.dims());
#if CUDNN_VERSION_MIN(5, 0, 0)
CUDNN_CHECK(cudnnActivationBackward(
ctx()->cudnn_handle(),
act_desc_,
......@@ -56,21 +43,6 @@ void CuDNNTanhGradientOp<Context>::DoRunWithType() {
CuDNNType<T>::zero,
input_desc_,
dX->ReshapeLike(Y)->template mutable_data<T, Context>()));
#else
CUDNN_CHECK(cudnnActivationBackward_v4(
ctx()->cudnn_handle(),
act_desc_,
CuDNNType<T>::one,
input_desc_,
Y.template data<T, Context>(),
input_desc_,
dY.template data<T, Context>(),
input_desc_,
Y.template data<T, Context>(),
CuDNNType<T>::zero,
input_desc_,
dX->ReshapeLike(Y)->template mutable_data<T, Context>()));
#endif
}
template <class Context>
......
......@@ -148,7 +148,7 @@ class CollectiveOpBase : public Operator<Context> {
}
ncclComm_t nccl_comm() {
auto ret = CUDAContext::object()->nccl_comm(
auto ret = CUDAContext::objects().nccl_comm(
this->ctx()->template device(),
group_str_,
nullptr,
......@@ -161,7 +161,7 @@ class CollectiveOpBase : public Operator<Context> {
NCCL_CHECK(ncclGetUniqueId(&comm_uuid));
}
Broadcast((uint8_t*)&comm_uuid, sizeof(comm_uuid));
ret = CUDAContext::object()->nccl_comm(
ret = CUDAContext::objects().nccl_comm(
this->ctx()->template device(),
group_str_,
&comm_uuid,
......
......@@ -144,8 +144,6 @@ class SyncBatchNormGradientOp : public BatchNormGradientOp<Context> {
#ifdef USE_CUDNN
#if CUDNN_VERSION_MIN(5, 0, 0)
template <class Context>
class CuDNNBatchNormOp final : public BatchNormOpBase<Context> {
public:
......@@ -211,8 +209,6 @@ class CuDNNBatchNormGradientOp final : public BatchNormGradientOp<Context> {
cudnnBatchNormMode_t bn_mode_;
};
#endif // CUDNN_VERSION_MIN(5, 0, 0)
#endif // USE_CUDNN
} // namespace dragon
......
......@@ -4,8 +4,6 @@
#ifdef USE_CUDNN
#if CUDNN_VERSION_MIN(5, 0, 0)
namespace dragon {
template <class Context>
......@@ -171,6 +169,4 @@ DEPLOY_CUDNN(BatchNormGradient);
} // namespace dragon
#endif // CUDNN_VERSION_MIN(5, 0, 0)
#endif // USE_CUDNN
......@@ -4,8 +4,6 @@
#include "dragon/core/workspace.h"
#include "dragon/utils/filler.h"
#if CUDNN_VERSION_MIN(5, 0, 0)
namespace dragon {
template <class Context>
......@@ -56,7 +54,7 @@ void CuDNNRecurrentOpBase<Context>::ResetDesc() {
// Setup RNN
#if CUDNN_VERSION_MIN(7, 0, 0)
CUDNN_CHECK(cudnnSetRNNDescriptor(
CUDNN_CHECK(cudnnSetRNNDescriptor_v6(
ctx()->cudnn_handle(),
rnn_desc_,
hidden_size_,
......@@ -323,6 +321,4 @@ DEPLOY_CUDNN(RecurrentGradient);
} // namespace dragon
#endif // CUDNN_VERSION_MIN(5, 0, 0)
#endif // USE_CUDNN
......@@ -19,8 +19,6 @@ namespace dragon {
#ifdef USE_CUDNN
#if CUDNN_VERSION_MIN(5, 0, 0)
class CuDNNTensorDescs {
public:
CuDNNTensorDescs(int num_descs) {
......@@ -174,8 +172,6 @@ class CuDNNRecurrentGradientOp final : public CuDNNRecurrentOpBase<Context> {
void DoRunWithType();
};
#endif // CUDNN_VERSION_MIN(5, 0, 0)
#endif // USE_CUDNN
} // namespace dragon
......
......@@ -66,7 +66,6 @@ void CuDNNConvTranspose2dOp<Context>::ResetDesc() {
}
}
if (filter_changed) {
#if CUDNN_VERSION_MIN(5, 0, 0)
CUDNN_CHECK(cudnnSetFilter4dDescriptor(
filter_desc_,
CuDNNType<T>::type,
......@@ -75,16 +74,6 @@ void CuDNNConvTranspose2dOp<Context>::ResetDesc() {
out_channels_ / group_,
kshape_[0],
kshape_[1]));
#else
CUDNN_CHECK(cudnnSetFilter4dDescriptor_v4(
filter_desc_,
CuDNNType<T>::type,
format_,
in_channels_ / cudnn_group_,
out_channels_ / group_,
kshape_[0],
kshape_[1]));
#endif
// Determine the bias shape
if (HasBias()) {
CuDNNSetBiasDesc<T>(
......@@ -94,9 +83,34 @@ void CuDNNConvTranspose2dOp<Context>::ResetDesc() {
// Set the conv configuration
SetConvDesc<T>();
// Get or search the appropriate algorithm
if (CUDAContext::object()->cudnn_benchmark_) {
if (CUDAContext::objects().cudnn_benchmark_) {
exhaustive_search_ = true;
} else {
#if CUDNN_VERSION_MIN(7, 0, 0)
int num_valid_algos;
constexpr int num_algos = CUDNN_CONV_NUM_BWD_DATA_ALGOS;
cudnnConvolutionBwdDataAlgoPerf_t stats[num_algos];
CUDNN_CHECK(cudnnGetConvolutionBackwardDataAlgorithm_v7(
ctx()->cudnn_handle(),
filter_desc_,
input_desc_,
conv_desc_,
output_desc_,
num_algos,
&num_valid_algos,
stats));
bool algo_is_found = false;
for (int i = 0; i < num_valid_algos; ++i) {
if (stats[i].memory <= CUDNN_CONV_WORKSPACE_LIMIT_BYTES) {
fwd_algo_ = stats[i].algo;
algo_is_found = true;
break;
}
}
CHECK(algo_is_found)
<< "\nNo algorithms available for <cudnnConvolutionBackwardData> "
<< "under the current desc and workspace limit.";
#else
CUDNN_CHECK(cudnnGetConvolutionBackwardDataAlgorithm(
ctx()->cudnn_handle(),
filter_desc_,
......@@ -106,6 +120,7 @@ void CuDNNConvTranspose2dOp<Context>::ResetDesc() {
CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT,
CUDNN_CONV_WORKSPACE_LIMIT_BYTES,
&fwd_algo_));
#endif // CUDNN_VERSION_MIN(7, 0, 0)
}
cudnn_ws_nbytes_ = SIZE_MAX; // Request a new size
}
......@@ -134,11 +149,9 @@ void CuDNNConvTranspose2dOp<Context>::DoRunWithType() {
scratch =
ws()->template data<Context>({CUDNN_CONV_WORKSPACE_LIMIT_BYTES})[0];
auto algo = algo_cache_.get(X.dims(), W.dims(), compute_type_, [&]() {
int returned_algo_count;
std::array<
cudnnConvolutionBwdDataAlgoPerf_t,
CUDNN_CONV_NUM_BWD_DATA_ALGOS>
stat;
int num_valid_algos;
constexpr int num_algos = CUDNN_CONV_NUM_BWD_DATA_ALGOS;
cudnnConvolutionBwdDataAlgoPerf_t stats[num_algos];
CUDNN_CHECK(cudnnFindConvolutionBackwardDataAlgorithmEx(
ctx()->cudnn_handle(),
filter_desc_,
......@@ -148,12 +161,12 @@ void CuDNNConvTranspose2dOp<Context>::DoRunWithType() {
conv_desc_,
output_desc_,
y,
CUDNN_CONV_NUM_BWD_DATA_ALGOS,
&returned_algo_count,
stat.data(),
num_algos,
&num_valid_algos,
stats,
scratch,
CUDNN_CONV_WORKSPACE_LIMIT_BYTES));
return FwdAlgoWithCost(stat[0].algo, stat[0].time);
return FwdAlgoWithCost(stats[0].algo, stats[0].time);
});
exhaustive_search_ = false;
fwd_algo_ = std::get<0>(algo);
......@@ -274,6 +287,7 @@ template <class Context>
template <typename T>
void CuDNNConvTranspose2dGradientOp<Context>::ResetDesc() {
auto &X = Input(0), &W = Input(1), &dY = Input(-1);
auto *dX = Output(0), *dW = Output(1);
bool input_changed = (X.dims() != input_dims_);
bool filter_changed = (W.dims() != filter_dims_);
if (input_changed || filter_changed) {
......@@ -289,7 +303,6 @@ void CuDNNConvTranspose2dGradientOp<Context>::ResetDesc() {
}
}
if (filter_changed) {
#if CUDNN_VERSION_MIN(5, 0, 0)
CUDNN_CHECK(cudnnSetFilter4dDescriptor(
filter_desc_,
CuDNNType<T>::type,
......@@ -298,16 +311,6 @@ void CuDNNConvTranspose2dGradientOp<Context>::ResetDesc() {
out_channels_ / group_,
kshape_[0],
kshape_[1]));
#else
CUDNN_CHECK(cudnnSetFilter4dDescriptor_v4(
filter_desc_,
CuDNNType<T>::type,
format_,
in_channels_ / cudnn_group_,
out_channels_ / group_,
kshape_[0],
kshape_[1]));
#endif
// Determine the bias shape
if (HasBias()) {
CuDNNSetBiasDesc<T>(
......@@ -317,28 +320,84 @@ void CuDNNConvTranspose2dGradientOp<Context>::ResetDesc() {
// Set the conv configuration
SetConvDesc<T>();
// Get the appropriate algorithm
if (CUDAContext::object()->cudnn_benchmark_) {
if (CUDAContext::objects().cudnn_benchmark_) {
exhaustive_search_data_ = true;
exhaustive_search_filter_ = true;
} else {
CUDNN_CHECK(cudnnGetConvolutionBackwardFilterAlgorithm(
ctx()->cudnn_handle(),
input_desc_,
output_desc_,
conv_desc_,
filter_desc_,
CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT,
CUDNN_CONV_WORKSPACE_LIMIT_BYTES,
&bwd_filter_algo_));
CUDNN_CHECK(cudnnGetConvolutionForwardAlgorithm(
ctx()->cudnn_handle(),
input_desc_,
filter_desc_,
conv_desc_,
output_desc_,
CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT,
CUDNN_CONV_WORKSPACE_LIMIT_BYTES,
&bwd_data_algo_));
if (dW->has_name()) {
#if CUDNN_VERSION_MIN(7, 0, 0)
int num_valid_algos;
constexpr int num_algos = CUDNN_CONV_NUM_BWD_FILTER_ALGOS;
cudnnConvolutionBwdFilterAlgoPerf_t stats[num_algos];
CUDNN_CHECK(cudnnGetConvolutionBackwardFilterAlgorithm_v7(
ctx()->cudnn_handle(),
input_desc_,
output_desc_,
conv_desc_,
filter_desc_,
num_algos,
&num_valid_algos,
stats));
bool algo_is_found = false;
for (int i = 0; i < num_valid_algos; ++i) {
if (stats[i].memory <= CUDNN_CONV_WORKSPACE_LIMIT_BYTES) {
bwd_filter_algo_ = stats[i].algo;
algo_is_found = true;
break;
}
}
CHECK(algo_is_found)
<< "\nNo algorithms available for <cudnnConvolutionBackwardFilter> "
<< "under the current desc and workspace limit.";
#else
CUDNN_CHECK(cudnnGetConvolutionBackwardFilterAlgorithm(
ctx()->cudnn_handle(),
input_desc_,
output_desc_,
conv_desc_,
filter_desc_,
CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT,
CUDNN_CONV_WORKSPACE_LIMIT_BYTES,
&bwd_filter_algo_));
#endif // CUDNN_VERSION_MIN(7, 0, 0)
}
if (dX->has_name()) {
#if CUDNN_VERSION_MIN(7, 0, 0)
int num_valid_algos;
constexpr int num_algos = CUDNN_CONV_NUM_FWD_ALGOS;
cudnnConvolutionFwdAlgoPerf_t stats[num_algos];
CUDNN_CHECK(cudnnGetConvolutionForwardAlgorithm_v7(
ctx()->cudnn_handle(),
input_desc_,
filter_desc_,
conv_desc_,
output_desc_,
num_algos,
&num_valid_algos,
stats));
bool algo_is_found = false;
for (int i = 0; i < num_valid_algos; ++i) {
if (stats[i].memory <= CUDNN_CONV_WORKSPACE_LIMIT_BYTES) {
bwd_data_algo_ = stats[i].algo;
algo_is_found = true;
break;
}
}
CHECK(algo_is_found)
<< "\nNo algorithms available for <cudnnConvolutionForward> "
<< "under the current desc and workspace limit.";
#else
CUDNN_CHECK(cudnnGetConvolutionForwardAlgorithm(
ctx()->cudnn_handle(),
input_desc_,
filter_desc_,
conv_desc_,
output_desc_,
CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT,
CUDNN_CONV_WORKSPACE_LIMIT_BYTES,
&bwd_data_algo_));
#endif // CUDNN_VERSION_MIN(7, 0, 0)
}
}
cudnn_ws_nbytes_ = SIZE_MAX; // Request a new size
}
......@@ -364,11 +423,9 @@ void CuDNNConvTranspose2dGradientOp<Context>::DoRunWithType() {
dw = dW->template mutable_data<T, Context>();
auto algo =
filter_algo_cache_.get(X.dims(), W.dims(), compute_type_, [&]() {
int returned_algo_count;
std::array<
cudnnConvolutionBwdFilterAlgoPerf_t,
CUDNN_CONV_NUM_BWD_FILTER_ALGOS>
stat;
int num_valid_algos;
constexpr int num_algos = CUDNN_CONV_NUM_BWD_FILTER_ALGOS;
cudnnConvolutionBwdFilterAlgoPerf_t stats[num_algos];
CUDNN_CHECK(cudnnFindConvolutionBackwardFilterAlgorithmEx(
ctx()->cudnn_handle(),
input_desc_,
......@@ -378,12 +435,12 @@ void CuDNNConvTranspose2dGradientOp<Context>::DoRunWithType() {
conv_desc_,
filter_desc_,
dw,
CUDNN_CONV_NUM_BWD_FILTER_ALGOS,
&returned_algo_count,
stat.data(),
num_algos,
&num_valid_algos,
stats,
scratch,
CUDNN_CONV_WORKSPACE_LIMIT_BYTES));
return BwdFilterAlgoWithCost(stat[0].algo, stat[0].time);
return BwdFilterAlgoWithCost(stats[0].algo, stats[0].time);
});
exhaustive_search_filter_ = false;
bwd_filter_algo_ = std::get<0>(algo);
......@@ -395,8 +452,9 @@ void CuDNNConvTranspose2dGradientOp<Context>::DoRunWithType() {
w = W.template data<T, Context>();
dx = dX->template mutable_data<T, Context>();
auto algo = data_algo_cache_.get(X.dims(), W.dims(), compute_type_, [&]() {
int returned_algo_count;
std::array<cudnnConvolutionFwdAlgoPerf_t, CUDNN_CONV_NUM_FWD_ALGOS> stat;
int num_valid_algos;
constexpr int num_algos = CUDNN_CONV_NUM_FWD_ALGOS;
cudnnConvolutionFwdAlgoPerf_t stats[num_algos];
CUDNN_CHECK(cudnnFindConvolutionForwardAlgorithmEx(
ctx()->cudnn_handle(),
input_desc_,
......@@ -406,12 +464,12 @@ void CuDNNConvTranspose2dGradientOp<Context>::DoRunWithType() {
conv_desc_,
output_desc_,
dx,
CUDNN_CONV_NUM_FWD_ALGOS,
&returned_algo_count,
stat.data(),
num_algos,
&num_valid_algos,
stats,
scratch,
CUDNN_CONV_WORKSPACE_LIMIT_BYTES));
return BwdDataAlgoWithCost(stat[0].algo, stat[0].time);
return BwdDataAlgoWithCost(stats[0].algo, stats[0].time);
});
exhaustive_search_data_ = false;
bwd_data_algo_ = std::get<0>(algo);
......
......@@ -13,7 +13,6 @@ void CuDNNPool2dOp<Context>::DoRunWithType() {
CuDNNSetTensorDesc<T>(&input_desc_, X.dims(), data_format());
CuDNNSetTensorDesc<T>(&output_desc_, out_shape_, data_format());
#if CUDNN_VERSION_MIN(5, 0, 0)
CUDNN_CHECK(cudnnSetPooling2dDescriptor(
pool_desc_,
pool_mode_,
......@@ -24,18 +23,6 @@ void CuDNNPool2dOp<Context>::DoRunWithType() {
pad_l_[1],
stride_[0],
stride_[1]));
#else
CUDNN_CHECK(cudnnSetPooling2dDescriptor_v4(
pool_desc_,
pool_mode_,
CUDNN_PROPAGATE_NAN,
kshape_[0],
kshape_[1],
pad_l_[0],
pad_l_[1],
stride_[0],
stride_[1]));
#endif
CUDNN_CHECK(cudnnPoolingForward(
ctx()->cudnn_handle(),
......@@ -63,7 +50,6 @@ void CuDNNPool2dGradientOp<Context>::DoRunWithType() {
CuDNNSetTensorDesc<T>(&input_desc_, dY.dims(), data_format());
CuDNNSetTensorDesc<T>(&output_desc_, X.dims(), data_format());
#if CUDNN_VERSION_MIN(5, 0, 0)
CUDNN_CHECK(cudnnSetPooling2dDescriptor(
pool_desc_,
pool_mode_,
......@@ -74,18 +60,6 @@ void CuDNNPool2dGradientOp<Context>::DoRunWithType() {
pad_l_[1],
stride_[0],
stride_[1]));
#else
CUDNN_CHECK(cudnnSetPooling2dDescriptor_v4(
pool_desc_,
pool_mode_,
CUDNN_PROPAGATE_NAN,
kshape_[0],
kshape_[1],
pad_l_[0],
pad_l_[1],
stride_[0],
stride_[1]));
#endif
CUDNN_CHECK(cudnnPoolingBackward(
ctx()->cudnn_handle(),
......
......@@ -138,16 +138,15 @@ class BuildExtension(_build_ext):
self.compiler.set_executable('compiler_so', nvcc)
if isinstance(cflags, dict):
cflags = cflags['nvcc']
cflags = \
COMMON_NVCC_FLAGS + \
['--compiler-options', "'-fPIC'"] + \
cflags + _get_cuda_arch_flags(cflags)
cflags = (COMMON_NVCC_FLAGS +
['--compiler-options', "'-fPIC'"] +
cflags + _get_cuda_arch_flags(cflags))
else:
if isinstance(cflags, dict):
cflags = cflags['cxx']
cflags += COMMON_CC_FLAGS
if not any(flag.startswith('-std=') for flag in cflags):
cflags.append('-std=c++11')
cflags.append('-std=c++14')
original_compile(obj, src, ext, cc_args, cflags, pp_opts)
finally:
self.compiler.set_executable('compiler_so', original_compiler)
......@@ -328,6 +327,7 @@ def _get_cuda_arch_flags(cflags=None):
'5.0', '5.2', '5.3',
'6.0', '6.1', '6.2',
'7.0', '7.2', '7.5',
'8.0',
]
valid_arch_strings = supported_arches + [s + "+PTX" for s in supported_arches]
capability = _cuda.get_device_capability()
......@@ -365,6 +365,6 @@ CUDA_HOME = _find_cuda()
CUDNN_HOME = _os.environ.get('CUDNN_HOME') or _os.environ.get('CUDNN_PATH')
COMMON_CC_FLAGS = ['-Wno-sign-compare', '-Wno-unused-variable', '-Wno-reorder']
COMMON_MSVC_FLAGS = ['/EHsc', '/wd4819', '/wd4244', '/wd4251', '/wd4275', '/wd4800', '/wd4996']
COMMON_NVCC_FLAGS = ['-w'] if IS_WINDOWS else []
COMMON_NVCC_FLAGS = ['-w'] if IS_WINDOWS else ['-std=c++14']
COMMON_LINK_LIBRARIES = ['protobuf'] if IS_WINDOWS else []
DLLIMPORT_STR = '__declspec(dllimport)' if IS_WINDOWS else ''
......@@ -32,16 +32,16 @@ namespace dragon {
#ifdef USE_CUDA
/*! \brief The number of cuda threads to use */
const int CUDA_THREADS = 256;
constexpr int CUDA_THREADS = 256;
/*! \brief The maximum number of blocks to use in the default kernel call */
const int CUDA_MAX_BLOCKS = 4096;
constexpr int CUDA_MAX_BLOCKS = 4096;
/*! \brief The maximum number of devices in a single machine */
const int CUDA_MAX_DEVICES = 16;
constexpr int CUDA_MAX_DEVICES = 16;
/*! \brief The maximum number of tensor dimensions */
const int CUDA_TENSOR_MAX_DIMS = 8;
constexpr int CUDA_TENSOR_MAX_DIMS = 8;
#define CUDA_VERSION_MIN(major, minor, patch) \
(CUDA_VERSION >= (major * 1000 + minor * 100 + patch))
......
......@@ -34,19 +34,19 @@ namespace dragon {
<< cudnnGetErrorString(status); \
} while (0)
static const size_t CUDNN_CONV_WORKSPACE_LIMIT_BYTES = 64 * 1024 * 1024;
constexpr size_t CUDNN_CONV_WORKSPACE_LIMIT_BYTES = 64 * 1024 * 1024;
#if CUDNN_VERSION_MIN(7, 0, 0)
static const size_t CUDNN_CONV_NUM_FWD_ALGOS =
constexpr size_t CUDNN_CONV_NUM_FWD_ALGOS =
2 * CUDNN_CONVOLUTION_FWD_ALGO_COUNT;
static const size_t CUDNN_CONV_NUM_BWD_FILTER_ALGOS =
constexpr size_t CUDNN_CONV_NUM_BWD_FILTER_ALGOS =
2 * CUDNN_CONVOLUTION_BWD_FILTER_ALGO_COUNT;
static const size_t CUDNN_CONV_NUM_BWD_DATA_ALGOS =
constexpr size_t CUDNN_CONV_NUM_BWD_DATA_ALGOS =
2 * CUDNN_CONVOLUTION_BWD_DATA_ALGO_COUNT;
#else
static const size_t CUDNN_CONV_NUM_FWD_ALGOS = 7;
static const size_t CUDNN_CONV_NUM_BWD_FILTER_ALGOS = 4;
static const size_t CUDNN_CONV_NUM_BWD_DATA_ALGOS = 5;
constexpr size_t CUDNN_CONV_NUM_FWD_ALGOS = 7;
constexpr size_t CUDNN_CONV_NUM_BWD_FILTER_ALGOS = 4;
constexpr size_t CUDNN_CONV_NUM_BWD_DATA_ALGOS = 5;
#endif
class Tensor;
......
......@@ -17,13 +17,6 @@
#define DRAGON_API
#endif
// Avoid using of "thread_local" for VS2013 or older Xcode
#if defined(__clang__) || defined(__GNUC__)
#define TLS_OBJECT __thread
#else
#define TLS_OBJECT __declspec(thread)
#endif
// Disable the copy and assignment operator for a class
#define DISABLE_COPY_AND_ASSIGN(classname) \
classname(const classname&) = delete; \
......
:: ##############################################################
:: Command file to build on Windows for Visual Studio 2013 (VC12)
:: ##############################################################
@echo off
setlocal
:: Build variables
set ORIGINAL_DIR=%cd%
set REPO_ROOT=%~dp0%..
set DRAGON_ROOT=%REPO_ROOT%\dragon
set THIRD_PARTY_DIR=%REPO_ROOT%\third_party
set CMAKE_GENERATOR="Visual Studio 12 2013 Win64"
:: Build options
set BUILD_PYTHON=ON
set BUILD_RUNTIME=OFF
:: Optional libraries
set USE_CUDA=ON
set USE_CUDNN=ON
set USE_OPENMP=ON
set USE_AVX=ON
set USE_AVX2=ON
set USE_FMA=ON
:: Protobuf SDK options
set PROTOBUF_SDK_ROOT_DIR=%THIRD_PARTY_DIR%\protobuf
:: Protobuf Compiler options
:: Set the protobuf compiler (i.e., protoc) if necessary.
:: If not, a compiler from the SDK or the environment will be used.
set PROTOBUF_PROTOC_EXECUTABLE=%PROTOBUF_SDK_ROOT_DIR%\bin\protoc
:: Python options
:: Set your Python interpreter if necessary.
:: If not, a default interpreter will be used.
:: set PYTHON_EXECUTABLE=X:/Anaconda3/python
if %BUILD_PYTHON% == ON (
if NOT DEFINED PYTHON_EXECUTABLE (
for /F %%i in ('python -c "import sys;print(sys.executable)"') do (set PYTHON_EXECUTABLE=%%i)
)
)
echo=
echo ------------------------- BUILDING CONFIGS -------------------------
echo=
echo -- DRAGON_ROOT=%DRAGON_ROOT%
echo -- CMAKE_GENERATOR=%CMAKE_GENERATOR%
if not exist %DRAGON_ROOT%\build mkdir %DRAGON_ROOT%\build
cd %DRAGON_ROOT%\build
cmake .. ^
-G%CMAKE_GENERATOR% ^
-DBUILD_PYTHON=%BUILD_PYTHON% ^
-DBUILD_RUNTIME=%BUILD_RUNTIME% ^
-DUSE_CUDA=%USE_CUDA% ^
-DUSE_CUDNN=%USE_CUDNN% ^
-DUSE_OPENMP=%USE_OPENMP% ^
-DUSE_AVX=%USE_AVX% ^
-DUSE_AVX2=%USE_AVX2% ^
-DUSE_FMA=%USE_FMA% ^
-DTHIRD_PARTY_DIR=%THIRD_PARTY_DIR% ^
-DPROTOBUF_SDK_ROOT_DIR=%PROTOBUF_SDK_ROOT_DIR% ^
-DPROTOBUF_PROTOC_EXECUTABLE=%PROTOBUF_PROTOC_EXECUTABLE% ^
-DPYTHON_EXECUTABLE=%PYTHON_EXECUTABLE% ^
|| goto :label_error
echo=
echo ------------------------- BUILDING CONFIGS -------------------------
echo=
cmake --build . --target INSTALL --config Release -- /maxcpucount:%NUMBER_OF_PROCESSORS% || goto :label_error
cd %DRAGON_ROOT%
%PYTHON_EXECUTABLE% setup.py install || goto :label_error
echo=
echo Built successfully
cd %ORIGINAL_DIR%
endlocal
pause
exit /b 0
:label_error
echo=
echo Building failed
cd %ORIGINAL_DIR%
endlocal
pause
exit /b 1
:: ##############################################################
:: Command file to build on Windows for Visual Studio 2015 (VC14)
:: ##############################################################
@echo off
setlocal
:: Build variables
set ORIGINAL_DIR=%cd%
set REPO_ROOT=%~dp0%..
set DRAGON_ROOT=%REPO_ROOT%\dragon
set THIRD_PARTY_DIR=%REPO_ROOT%\third_party
set CMAKE_GENERATOR="Visual Studio 14 2015 Win64"
:: Build options
set BUILD_PYTHON=ON
set BUILD_RUNTIME=OFF
:: Optional libraries
set USE_CUDA=ON
set USE_CUDNN=ON
set USE_OPENMP=ON
set USE_AVX=ON
set USE_AVX2=ON
set USE_FMA=ON
:: Protobuf SDK options
set PROTOBUF_SDK_ROOT_DIR=%THIRD_PARTY_DIR%\protobuf
:: Protobuf Compiler options
:: Set the protobuf compiler (i.e., protoc) if necessary.
:: If not, a compiler from the SDK or the environment will be used.
set PROTOBUF_PROTOC_EXECUTABLE=%PROTOBUF_SDK_ROOT_DIR%\bin\protoc
:: Python options
:: Set your Python interpreter if necessary.
:: If not, a default interpreter will be used.
:: set PYTHON_EXECUTABLE=X:/Anaconda3/python
if %BUILD_PYTHON% == ON (
if NOT DEFINED PYTHON_EXECUTABLE (
for /F %%i in ('python -c "import sys;print(sys.executable)"') do (set PYTHON_EXECUTABLE=%%i)
)
)
echo=
echo ------------------------- BUILDING CONFIGS -------------------------
echo=
echo -- DRAGON_ROOT=%DRAGON_ROOT%
echo -- CMAKE_GENERATOR=%CMAKE_GENERATOR%
if not exist %DRAGON_ROOT%\build mkdir %DRAGON_ROOT%\build
cd %DRAGON_ROOT%\build
cmake .. ^
-G%CMAKE_GENERATOR% ^
-DBUILD_PYTHON=%BUILD_PYTHON% ^
-DBUILD_RUNTIME=%BUILD_RUNTIME% ^
-DUSE_CUDA=%USE_CUDA% ^
-DUSE_CUDNN=%USE_CUDNN% ^
-DUSE_OPENMP=%USE_OPENMP% ^
-DUSE_AVX=%USE_AVX% ^
-DUSE_AVX2=%USE_AVX2% ^
-DUSE_FMA=%USE_FMA% ^
-DTHIRD_PARTY_DIR=%THIRD_PARTY_DIR% ^
-DPROTOBUF_SDK_ROOT_DIR=%PROTOBUF_SDK_ROOT_DIR% ^
-DPROTOBUF_PROTOC_EXECUTABLE=%PROTOBUF_PROTOC_EXECUTABLE% ^
-DPYTHON_EXECUTABLE=%PYTHON_EXECUTABLE% ^
|| goto :label_error
echo=
echo ------------------------- BUILDING CONFIGS -------------------------
echo=
cmake --build . --target INSTALL --config Release -- /maxcpucount:%NUMBER_OF_PROCESSORS% || goto :label_error
cd %DRAGON_ROOT%
%PYTHON_EXECUTABLE% setup.py install || goto :label_error
echo=
echo Built successfully
cd %ORIGINAL_DIR%
endlocal
pause
exit /b 0
:label_error
echo=
echo Building failed
cd %ORIGINAL_DIR%
endlocal
pause
exit /b 1
......@@ -8,6 +8,7 @@
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Command line to run tests."""
from __future__ import absolute_import
from __future__ import division
......