Commit bbfecf22 by Ting PAN

Add sysconfig module

Summary:
This commit adds the sysconfig module to retrieve the build information.
Build information is helpful for selecting tests or reporting issues.
1 parent 2c90589f
Showing with 422 additions and 135 deletions
dragon.sysconfig
================
.. only:: html

  Functions
  ---------

  `get_build_info(...) <sysconfig/get_build_info.html>`_
  : Return the environment information of built binaries.

  `get_include(...) <sysconfig/get_include.html>`_
  : Return the directory of framework header files.

  `get_lib(...) <sysconfig/get_lib.html>`_
  : Return the directory of framework libraries.

.. toctree::
  :hidden:

  sysconfig/get_build_info
  sysconfig/get_include
  sysconfig/get_lib

.. raw:: html

  <style>
  h1:before {
    content: "Module: ";
    color: #103d3e;
  }
  </style>
get_build_info
==============
.. autofunction:: dragon.sysconfig.get_build_info
.. raw:: html

  <style>
  h1:before {
    content: "dragon.sysconfig.";
    color: #103d3e;
  }
  </style>
get_include
===========
.. autofunction:: dragon.sysconfig.get_include
.. raw:: html

  <style>
  h1:before {
    content: "dragon.sysconfig.";
    color: #103d3e;
  }
  </style>
get_lib
=======
.. autofunction:: dragon.sysconfig.get_lib
.. raw:: html

  <style>
  h1:before {
    content: "dragon.sysconfig.";
    color: #103d3e;
  }
  </style>
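The two path helpers documented above are typically used when compiling C++ extensions against the framework. A minimal usage sketch (the printed paths depend on where the package is installed):

import dragon

include_dir = dragon.sysconfig.get_include()  # headers, e.g. for -I flags
lib_dir = dragon.sysconfig.get_lib()          # libraries, e.g. for -L flags
print(include_dir, lib_dir)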
......@@ -41,6 +41,7 @@ Dragon
* `dragon.onnx <dragon/onnx.html>`_
* `dragon.optimizers <dragon/optimizers.html>`_
* `dragon.random <dragon/random.html>`_
* `dragon.sysconfig <dragon/sysconfig.html>`_
* `dragon.vision <dragon/vision.html>`_
Caffe
......@@ -210,6 +211,9 @@ Modules
`Module random <dragon/random.html>`_
: Native API for ``dragon.random`` namespace.
`Module sysconfig <dragon/sysconfig.html>`_
: Native API for ``dragon.sysconfig`` namespace.
`Module vision <dragon/vision.html>`_
: Native API for ``dragon.vision`` namespace.
......@@ -315,6 +319,7 @@ Modules
dragon/onnx
dragon/optimizers
dragon/random
dragon/sysconfig
dragon/vision
caffe
caffe/layers
......
......@@ -151,7 +151,10 @@ class CUDAObjects {
/*! \brief The flag that allows cuDNN or not */
bool cudnn_enabled_ = true;
/*! \brief The flag that allows cuDNN benchmark or not */
/*! \brief The flag that enforces deterministic cuDNN algorithms or not */
bool cudnn_deterministic_ = false;
/*! \brief The flag that benchmarks fastest cuDNN algorithms or not */
bool cudnn_benchmark_ = false;
/*! \brief The flag that allows cuDNN TF32 math type or not */
......
......@@ -94,14 +94,17 @@ void RegisterModule(py::module& m) {
});
/*! \brief Activate the CuDNN engine */
m.def("cudaEnableDNN", [](bool enabled, bool benchmark, bool allow_tf32) {
m.def(
"cudaEnableDNN",
[](bool enabled, bool deterministic, bool benchmark, bool allow_tf32) {
#ifdef USE_CUDA
auto& cuda_objects = CUDAContext::objects();
cuda_objects.cudnn_enabled_ = enabled;
cuda_objects.cudnn_benchmark_ = benchmark;
cuda_objects.cudnn_allow_tf32_ = allow_tf32;
auto& cuda_objects = CUDAContext::objects();
cuda_objects.cudnn_enabled_ = enabled;
cuda_objects.cudnn_deterministic_ = deterministic;
cuda_objects.cudnn_benchmark_ = benchmark;
cuda_objects.cudnn_allow_tf32_ = allow_tf32;
#endif
});
});
/*! \brief Return the index of current device */
m.def("cudaGetDevice", []() { return CUDAContext::current_device(); });
......
#include "dragon/modules/python/autograd.h"
#include "dragon/modules/python/config.h"
#include "dragon/modules/python/cuda.h"
#include "dragon/modules/python/dlpack.h"
#include "dragon/modules/python/mpi.h"
#include "dragon/modules/python/operator.h"
#include "dragon/modules/python/proto.h"
#include "dragon/modules/python/sysconfig.h"
#include "dragon/modules/python/tensor.h"
namespace dragon {
......@@ -288,11 +288,11 @@ PYBIND11_MODULE(libdragon_python, m) {
[]() { import_array1(); }();
REGISTER_MODULE(autograd);
REGISTER_MODULE(config);
REGISTER_MODULE(cuda);
REGISTER_MODULE(mpi);
REGISTER_MODULE(ops);
REGISTER_MODULE(proto);
REGISTER_MODULE(sysconfig);
REGISTER_MODULE(tensor);
#undef REGISTER_MODULE
}
......
......@@ -10,8 +10,8 @@
* ------------------------------------------------------------
*/
#ifndef DRAGON_MODULES_PYTHON_CONFIG_H_
#define DRAGON_MODULES_PYTHON_CONFIG_H_
#ifndef DRAGON_MODULES_PYTHON_SYSCONFIG_H_
#define DRAGON_MODULES_PYTHON_SYSCONFIG_H_
#include "dragon/modules/python/common.h"
#include "dragon/utils/device/common_eigen.h"
......@@ -20,7 +20,7 @@ namespace dragon {
namespace python {
namespace config {
namespace sysconfig {
void RegisterModule(py::module& m) {
/*! \brief Set the logging severity */
......@@ -33,12 +33,53 @@ void RegisterModule(py::module& m) {
/*! \brief Return the number of threads for cpu parallelism */
m.def("GetNumThreads", []() { return Eigen::nbThreads(); });
m.def("GetBuildInformation", []() {
static string build_info;
if (!build_info.empty()) {
return build_info;
}
build_info += "cpu_features:";
#if defined(USE_AVX)
build_info += " AVX";
#endif
#if defined(USE_AVX2)
build_info += " AVX2";
#endif
#if defined(USE_FMA)
build_info += " FMA";
#endif
build_info += "\ncuda_version:";
#if defined(USE_CUDA)
build_info += " " + str::to(CUDA_VERSION / 1000) + "." +
str::to(CUDA_VERSION % 1000 / 10);
#endif
build_info += "\ncudnn_version:";
#if defined(USE_CUDNN)
build_info += " " + str::to(CUDNN_MAJOR) + "." + str::to(CUDNN_MINOR) +
"." + str::to(CUDNN_PATCHLEVEL);
#endif
build_info += "\nthird_party: eigen protobuf pybind11";
#if defined(USE_OPENMP)
build_info += " openmp";
#endif
#if defined(USE_MPI)
build_info += " mpi";
#endif
#if defined(USE_CUDA)
build_info += " cuda cub";
#endif
#if defined(USE_CUDNN)
build_info += " cudnn";
#endif
return build_info;
});
}
} // namespace config
} // namespace sysconfig
} // namespace python
} // namespace dragon
#endif // DRAGON_MODULES_PYTHON_CONFIG_H_
#endif // DRAGON_MODULES_PYTHON_SYSCONFIG_H_
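For reference, GetBuildInformation above emits one key: value entry per line, leaving the value empty when a feature is compiled out. An illustrative result for a CUDA build (the feature names and versions here are assumptions, not taken from the source):

build_info = ("cpu_features: AVX AVX2 FMA\n"
              "cuda_version: 10.2\n"
              "cudnn_version: 7.6.5\n"
              "third_party: eigen protobuf pybind11 openmp cuda cub cudnn")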
......@@ -31,7 +31,9 @@ void CuDNNConvOp<Context>::ResetDesc() {
}
this->template SetConvDesc<T>();
// Get or search the appropriate algorithm
if (CUDAContext::objects().cudnn_benchmark_) {
if (CUDAContext::objects().cudnn_deterministic_) {
fwd_algo_ = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM;
} else if (CUDAContext::objects().cudnn_benchmark_) {
exhaustive_search_ = true;
} else {
#if CUDNN_VERSION_MIN(7, 0, 0)
......@@ -122,14 +124,23 @@ void CuDNNConvOp<Context>::DoRunWithType() {
// Determine the workspace size for selected algorithm
if (cudnn_ws_nbytes_ == SIZE_MAX) {
CUDNN_CHECK(cudnnGetConvolutionForwardWorkspaceSize(
ctx()->cudnn_handle(),
input_desc_,
filter_desc_,
conv_desc_,
output_desc_,
fwd_algo_,
&cudnn_ws_nbytes_));
auto algo_status = CUDNN_STATUS_SUCCESS;
for (int step = 0; step < 2; ++step) {
algo_status = cudnnGetConvolutionForwardWorkspaceSize(
ctx()->cudnn_handle(),
input_desc_,
filter_desc_,
conv_desc_,
output_desc_,
fwd_algo_,
&cudnn_ws_nbytes_);
if (algo_status != CUDNN_STATUS_SUCCESS && step == 0 &&
CUDAContext::objects().cudnn_deterministic_) {
fwd_algo_ = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM;
} else {
CUDNN_CHECK(algo_status);
}
}
}
// Alloc the memory for workspace data
......@@ -205,7 +216,10 @@ void CuDNNConvGradientOp<Context>::ResetDesc() {
}
this->template SetConvDesc<T>();
// Get or search the appropriate algorithm
if (CUDAContext::objects().cudnn_benchmark_) {
if (CUDAContext::objects().cudnn_deterministic_) {
bwd_data_algo_ = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1;
bwd_filter_algo_ = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1;
} else if (CUDAContext::objects().cudnn_benchmark_) {
exhaustive_search_data_ = true;
exhaustive_search_filter_ = true;
} else {
......@@ -359,23 +373,40 @@ void CuDNNConvGradientOp<Context>::DoRunWithType() {
// Determine the workspace size for selected algorithm
if (cudnn_ws_nbytes_ == SIZE_MAX) {
auto algo_status = CUDNN_STATUS_SUCCESS;
size_t bwd_filter_size = 0, bwd_data_size = 0;
CUDNN_CHECK(cudnnGetConvolutionBackwardFilterWorkspaceSize(
ctx()->cudnn_handle(),
output_desc_,
input_desc_,
conv_desc_,
filter_desc_,
bwd_filter_algo_,
&bwd_filter_size));
CUDNN_CHECK(cudnnGetConvolutionBackwardDataWorkspaceSize(
ctx()->cudnn_handle(),
filter_desc_,
input_desc_,
conv_desc_,
output_desc_,
bwd_data_algo_,
&bwd_data_size));
for (int step = 0; step < 2; ++step) {
algo_status = cudnnGetConvolutionBackwardFilterWorkspaceSize(
ctx()->cudnn_handle(),
output_desc_,
input_desc_,
conv_desc_,
filter_desc_,
bwd_filter_algo_,
&bwd_filter_size);
if (algo_status != CUDNN_STATUS_SUCCESS && step == 0 &&
CUDAContext::objects().cudnn_deterministic_) {
bwd_filter_algo_ = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0;
} else {
CUDNN_CHECK(algo_status);
}
}
for (int step = 0; step < 2; ++step) {
algo_status = cudnnGetConvolutionBackwardDataWorkspaceSize(
ctx()->cudnn_handle(),
filter_desc_,
input_desc_,
conv_desc_,
output_desc_,
bwd_data_algo_,
&bwd_data_size);
if (algo_status != CUDNN_STATUS_SUCCESS && step == 0 &&
CUDAContext::objects().cudnn_deterministic_) {
bwd_data_algo_ = CUDNN_CONVOLUTION_BWD_DATA_ALGO_0;
} else {
CUDNN_CHECK(algo_status);
}
}
cudnn_ws_nbytes_ = std::max(bwd_filter_size, bwd_data_size);
}
......
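The workspace-size queries above share a two-step pattern: ask cuDNN for the preferred deterministic algorithm first, and if that query fails on the first pass, retry once with a simpler deterministic fallback before treating the failure as fatal. A minimal Python sketch of that control flow (hypothetical names; the authoritative logic is the C++ above, which returns as soon as a query succeeds):

def query_workspace(query, algo, fallback, deterministic):
    # Mirrors the two-step cudnnGet*WorkspaceSize loop above.
    for step in range(2):
        status, nbytes = query(algo)
        if status != 0 and step == 0 and deterministic:
            algo = fallback  # e.g. ALGO_1 -> ALGO_0
        else:
            if status != 0:
                raise RuntimeError('workspace query failed')  # CUDNN_CHECK
            return algo, nbytes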
......@@ -31,7 +31,9 @@ void CuDNNConvTransposeOp<Context>::ResetDesc() {
}
this->template SetConvDesc<T>();
// Get or search the appropriate algorithm
if (CUDAContext::objects().cudnn_benchmark_) {
if (CUDAContext::objects().cudnn_deterministic_) {
fwd_algo_ = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1;
} else if (CUDAContext::objects().cudnn_benchmark_) {
exhaustive_search_ = true;
} else {
#if CUDNN_VERSION_MIN(7, 0, 0)
......@@ -122,14 +124,23 @@ void CuDNNConvTransposeOp<Context>::DoRunWithType() {
// Determine the workspace size for selected algorithm
if (cudnn_ws_nbytes_ == SIZE_MAX) {
CUDNN_CHECK(cudnnGetConvolutionBackwardDataWorkspaceSize(
ctx()->cudnn_handle(),
filter_desc_,
input_desc_,
conv_desc_,
output_desc_,
fwd_algo_,
&cudnn_ws_nbytes_));
auto algo_status = CUDNN_STATUS_SUCCESS;
for (int step = 0; step < 2; ++step) {
algo_status = cudnnGetConvolutionBackwardDataWorkspaceSize(
ctx()->cudnn_handle(),
filter_desc_,
input_desc_,
conv_desc_,
output_desc_,
fwd_algo_,
&cudnn_ws_nbytes_);
if (algo_status != CUDNN_STATUS_SUCCESS && step == 0 &&
CUDAContext::objects().cudnn_deterministic_) {
fwd_algo_ = CUDNN_CONVOLUTION_BWD_DATA_ALGO_0;
} else {
CUDNN_CHECK(algo_status);
}
}
}
// Alloc the memory for workspace data
......@@ -205,7 +216,10 @@ void CuDNNConvTransposeGradientOp<Context>::ResetDesc() {
}
this->template SetConvDesc<T>();
// Get the appropriate algorithm
if (CUDAContext::objects().cudnn_benchmark_) {
if (CUDAContext::objects().cudnn_deterministic_) {
bwd_data_algo_ = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM;
bwd_filter_algo_ = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1;
} else if (CUDAContext::objects().cudnn_benchmark_) {
exhaustive_search_data_ = true;
exhaustive_search_filter_ = true;
} else {
......@@ -359,23 +373,40 @@ void CuDNNConvTransposeGradientOp<Context>::DoRunWithType() {
// Determine the workspace size for selected algorithm
if (cudnn_ws_nbytes_ == SIZE_MAX) {
auto algo_status = CUDNN_STATUS_SUCCESS;
size_t bwd_filter_size = 0, bwd_data_size = 0;
CUDNN_CHECK(cudnnGetConvolutionBackwardFilterWorkspaceSize(
ctx()->cudnn_handle(),
input_desc_,
output_desc_,
conv_desc_,
filter_desc_,
bwd_filter_algo_,
&bwd_filter_size));
CUDNN_CHECK(cudnnGetConvolutionForwardWorkspaceSize(
ctx()->cudnn_handle(),
input_desc_,
filter_desc_,
conv_desc_,
output_desc_,
bwd_data_algo_,
&bwd_data_size));
for (int step = 0; step < 2; ++step) {
algo_status = cudnnGetConvolutionBackwardFilterWorkspaceSize(
ctx()->cudnn_handle(),
input_desc_,
output_desc_,
conv_desc_,
filter_desc_,
bwd_filter_algo_,
&bwd_filter_size);
if (algo_status != CUDNN_STATUS_SUCCESS && step == 0 &&
CUDAContext::objects().cudnn_deterministic_) {
bwd_filter_algo_ = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0;
} else {
CUDNN_CHECK(algo_status);
}
}
for (int step = 0; step < 2; ++step) {
algo_status = cudnnGetConvolutionForwardWorkspaceSize(
ctx()->cudnn_handle(),
input_desc_,
filter_desc_,
conv_desc_,
output_desc_,
bwd_data_algo_,
&bwd_data_size);
if (algo_status != CUDNN_STATUS_SUCCESS && step == 0 &&
CUDAContext::objects().cudnn_deterministic_) {
bwd_data_algo_ = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM;
} else {
CUDNN_CHECK(algo_status);
}
}
cudnn_ws_nbytes_ = std::max(bwd_filter_size, bwd_data_size);
}
......
......@@ -32,6 +32,7 @@ from dragon._api import nn
from dragon._api import onnx
from dragon._api import optimizers
from dragon._api import random
from dragon._api import sysconfig
from dragon._api import vision
from dragon import vm
......@@ -42,12 +43,12 @@ from dragon.core.eager.backprop import GradientTape
from dragon.core.framework.workspace import Workspace
# Functions
from dragon.backend import load_library
from dragon.core.autograph.def_function import function
from dragon.core.autograph.function_lib import create_function
from dragon.core.autograph.grad_impl import gradients
from dragon.core.eager.context import eager_mode
from dragon.core.eager.context import graph_mode
from dragon.core.framework.backend import load_library
from dragon.core.framework.config import get_num_threads
from dragon.core.framework.config import set_num_threads
from dragon.core.framework.context import device
......
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
from __future__ import absolute_import as _absolute_import
from __future__ import division as _division
from __future__ import print_function as _print_function
from dragon.core.framework.sysconfig import get_build_info
from dragon.core.framework.sysconfig import get_include
from dragon.core.framework.sysconfig import get_lib
__all__ = [_s for _s in dir() if not _s.startswith('_')]
......@@ -8,7 +8,7 @@
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Autograph options."""
"""Autograph configurations."""
from __future__ import absolute_import
from __future__ import division
......
......@@ -16,8 +16,8 @@ from __future__ import print_function
from collections import defaultdict
from dragon import backend
from dragon.core.autograph.op_def import OpDef
from dragon.core.framework import backend
from dragon.core.framework import proto_util
from dragon.core.proto import dragon_pb2
......
......@@ -8,7 +8,7 @@
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""The graph executing tensor."""
"""Graph executing tensor."""
from __future__ import absolute_import
from __future__ import division
......
......@@ -14,7 +14,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from dragon import backend
from dragon.core.framework import backend
from dragon.core.framework import config
from dragon.core.framework import workspace
......@@ -62,20 +62,27 @@ def current_device():
return backend.cudaGetDevice()
def enable_cudnn(enabled=True, benchmark=False, allow_tf32=False):
def enable_cudnn(
enabled=True,
deterministic=False,
benchmark=False,
allow_tf32=False,
):
"""Enable backend to use the cuDNN library.
Parameters
----------
enabled : bool, optional, default=True
Use cuDNN library or not.
deterministic : bool, optional, default=False
Select deterministic algorithms instead of fastest.
benchmark : bool, optional, default=False
Select algorithms according to the benchmark or not.
Select fastest algorithms via benchmark or heuristics.
allow_tf32 : bool, optional, default=False
Allow TF32 Tensor core operation or not.
Allow TF32 tensor core operation or not.
"""
return backend.cudaEnableDNN(enabled, benchmark, allow_tf32)
return backend.cudaEnableDNN(enabled, deterministic, benchmark, allow_tf32)
def get_device_capability(device_index=None):
......
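With the new flag in place, a reproducible configuration can be requested from Python; a minimal sketch:

import dragon

# Trade speed for reproducibility: fixed deterministic algorithms,
# no benchmark autotuning, and no TF32 math.
dragon.cuda.enable_cudnn(enabled=True, deterministic=True,
                         benchmark=False, allow_tf32=False)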
......@@ -8,7 +8,7 @@
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Distributed utilities equipped with Python."""
"""Distributed backend."""
from __future__ import absolute_import
from __future__ import division
......@@ -16,7 +16,7 @@ from __future__ import print_function
import atexit
from dragon import backend as _b
from dragon.core.framework import backend as _b
from dragon.core.util import nest
from dragon.core.util import six
from dragon.core.util import tls
......
......@@ -12,7 +12,7 @@
# <https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/eager/context.py>
#
# ------------------------------------------------------------
"""State management for eager execution."""
"""Eager execution context."""
from __future__ import absolute_import
from __future__ import division
......
......@@ -8,6 +8,7 @@
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""DLPack utilities."""
from __future__ import absolute_import
from __future__ import division
......
......@@ -8,7 +8,7 @@
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""The eager executing tensor."""
"""Eager executing tensor."""
from __future__ import absolute_import
from __future__ import division
......
......@@ -8,7 +8,7 @@
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""List the exported C++ API."""
"""Framework backend."""
from __future__ import absolute_import
from __future__ import division
......
......@@ -8,7 +8,7 @@
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Define the global configurations."""
"""Framework configurations."""
from __future__ import absolute_import
from __future__ import division
......@@ -16,31 +16,29 @@ from __future__ import print_function
import threading
from dragon import backend
from dragon.core.framework import backend
class Config(object):
"""Store the common configurations for frontend."""
"""Framework configuration class."""
def __init__(self):
# The type of device.
# Device type.
# Enumeration in ('cpu', 'cuda', 'cnml').
self.device_type = 'cpu'
# The device index.
# Device index.
self.device_index = 0
# The global random seed.
# Device random seed.
self.random_seed = 3
# The graph type for various scheduling.
# Graph type for various scheduling.
self.graph_type = ''
# The graph optimization level.
# Graph optimization level.
self.graph_optimization = 3
# The graph verbosity level.
# Graph verbosity level.
self.graph_verbosity = 0
# The execution mode for graph.
# Graph execution mode.
self.graph_execution = 'EAGER_MODE'
# The directory to store logging files.
# Directory to store logging files.
self.log_dir = None
......
......@@ -8,6 +8,7 @@
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Framework context."""
from __future__ import absolute_import
from __future__ import division
......
......@@ -20,7 +20,7 @@ import sys
from google.protobuf.message import Message
import numpy
from dragon import backend
from dragon.core.framework import backend
from dragon.core.framework import config
from dragon.core.framework import context
from dragon.core.proto import dragon_pb2
......
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""System configurations."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from dragon.core.framework import backend
def get_build_info():
"""Return the environment information of built binaries.
The return value is a dictionary with string keys:
* cpu_features
* cuda_version
* cudnn_version
* is_cuda_build
* third_party
Returns
-------
dict
The info dict.
"""
build_info = {}
build_info_str = backend.GetBuildInformation()
for entry in build_info_str.split('\n'):
k, v = entry.split(':')
if len(v) > 0:
build_info[k] = v[1:]
build_info['is_cuda_build'] = 'cuda_version' in build_info
return build_info
def get_include():
"""Return the directory of framework header files.
Returns
-------
str
The include directory.
"""
core_root = os.path.dirname(os.path.dirname(__file__))
return os.path.join(os.path.dirname(core_root), 'include')
def get_lib():
"""Return the directory of framework libraries.
Returns
-------
str
The library directory.
"""
core_root = os.path.dirname(os.path.dirname(__file__))
return os.path.join(os.path.dirname(core_root), 'lib')
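A quick sketch of consuming the parsed dictionary (which keys are present depends on the build flags):

import dragon

info = dragon.sysconfig.get_build_info()
if info['is_cuda_build']:
    print('cuda:', info['cuda_version'])        # e.g. '10.2'
    print('cudnn:', info.get('cudnn_version'))  # e.g. '7.6.5'
print('cpu features:', info.get('cpu_features', ''))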
......@@ -8,7 +8,7 @@
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Define the basic prototypes."""
"""Basic prototypes."""
from __future__ import absolute_import
from __future__ import division
......
......@@ -8,7 +8,7 @@
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Generic interfaces of current default workspace."""
"""Functions and helpers for workspace."""
from __future__ import absolute_import
from __future__ import division
......@@ -18,7 +18,7 @@ import collections
import contextlib
import numpy
from dragon import backend
from dragon.core.framework import backend
from dragon.core.framework import config
from dragon.core.framework import mapping
from dragon.core.framework import proto_util
......
......@@ -8,7 +8,7 @@
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Utilities for KPLRecord."""
"""KPLRecord utilities."""
from __future__ import absolute_import
from __future__ import division
......
......@@ -8,6 +8,7 @@
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Process to read the distributed data."""
from __future__ import absolute_import
from __future__ import division
......
......@@ -8,7 +8,7 @@
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Utilities for TFRecord."""
"""TFRecord utilities."""
from __future__ import absolute_import
from __future__ import division
......
......@@ -22,10 +22,12 @@ import dragon
# The global argument parser
parser = argparse.ArgumentParser(add_help=False)
build_info = dragon.sysconfig.get_build_info()
# The optional testing flags
TEST_CUDA = dragon.cuda.is_available()
TEST_MPI = dragon.distributed.is_mpi_available()
TEST_CUDNN_CONV3D_NHWC = build_info.get('cudnn_version', '0.0.0') > '8.0.0'
def run_tests(argv=None):
......
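Flags derived from the build info can then gate individual cases. A hedged sketch in the same spirit (the real tests below pass a test_nhwc argument instead of skipping whole cases):

import unittest
from dragon.core.testing.unittest.common_utils import TEST_CUDNN_CONV3D_NHWC

class TestConv3d(unittest.TestCase):
    @unittest.skipIf(not TEST_CUDNN_CONV3D_NHWC,
                     'NHWC conv3d requires cuDNN >= 8')
    def test_nhwc(self):
        pass  # exercise the NHWC entries only on capable builds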
......@@ -8,7 +8,7 @@
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""The Adam optimizers."""
"""Adam optimizers."""
from __future__ import absolute_import
from __future__ import division
......
......@@ -8,7 +8,7 @@
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""The optimizer to update parameters."""
"""Basic optimizers."""
from __future__ import absolute_import
from __future__ import division
......
......@@ -8,7 +8,7 @@
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""The RMSprop optimizers."""
"""RMSprop optimizers."""
from __future__ import absolute_import
from __future__ import division
......
......@@ -8,7 +8,7 @@
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""The SGD optimizers."""
"""SGD optimizers."""
from __future__ import absolute_import
from __future__ import division
......
......@@ -20,7 +20,7 @@ import os
import sys as _sys
import threading
from dragon import backend
from dragon.core.framework import backend
from dragon.core.framework import config
_logger = None
......
......@@ -328,7 +328,7 @@ def _get_cuda_arch_flags(cflags=None):
'5.0', '5.2', '5.3',
'6.0', '6.1', '6.2',
'7.0', '7.2', '7.5',
'8.0']
'8.0', '8.6']
valid_arch_strings = supported_arches + [s + "+PTX" for s in supported_arches]
capability = _cuda.get_device_capability()
arch_list = ['{}.{}'.format(capability[0], capability[1])]
......
......@@ -144,7 +144,7 @@ CONVERSIONS_DECL half To<half, half>(half val) {
template <>
CONVERSIONS_DECL half To<half, float>(float val) {
#if CUDA_VERSION_MIN(9, 2, 0)
#if CUDA_VERSION_MIN(9, 2)
return __float2half(val);
#else
#if defined(__CUDA_ARCH__)
......@@ -161,7 +161,7 @@ CONVERSIONS_DECL half To<half, float>(float val) {
template <>
CONVERSIONS_DECL half2 To<half2, float>(float val) {
#if CUDA_VERSION_MIN(9, 2, 0)
#if CUDA_VERSION_MIN(9, 2)
return __float2half2_rn(val);
#else
#if defined(__CUDA_ARCH__)
......
......@@ -39,11 +39,11 @@ constexpr int CUDA_MAX_DEVICES = 16;
/*! \brief The maximum number of tensor dimensions */
constexpr int CUDA_TENSOR_MAX_DIMS = 8;
#define CUDA_VERSION_MIN(major, minor, patch) \
(CUDA_VERSION >= (major * 1000 + minor * 100 + patch))
#define CUDA_VERSION_MIN(major, minor) \
(CUDA_VERSION >= (major * 1000 + minor * 10))
#define CUDA_VERSION_MAX(major, minor, patch) \
(CUDA_VERSION < (major * 1000 + minor * 100 + patch))
#define CUDA_VERSION_MAX(major, minor) \
(CUDA_VERSION < (major * 1000 + minor * 10))
#define CUDA_CHECK(condition) \
do { \
......@@ -87,7 +87,7 @@ inline int CUDA_2D_BLOCKS(const int N) {
return std::max(std::min(N, CUDA_MAX_BLOCKS), 1);
}
#if CUDA_VERSION_MAX(9, 0, 0)
#if CUDA_VERSION_MAX(9, 0)
#define __hdiv hdiv
#endif
......
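CUDA encodes CUDA_VERSION as major * 1000 + minor * 10 (9.2 is 9020, 10.1 is 10010), whereas the old three-argument macros computed major * 1000 + minor * 100 + patch, encoding 9.2.0 as 9200 and therefore comparing against the wrong value. A quick check of the corrected arithmetic:

def cuda_version_min(cuda_version, major, minor):
    # Mirrors the rewritten CUDA_VERSION_MIN comparison.
    return cuda_version >= major * 1000 + minor * 10

assert cuda_version_min(9020, 9, 2)      # CUDA 9.2 passes
assert not cuda_version_min(9000, 9, 2)  # CUDA 9.0 does not
assert cuda_version_min(10010, 9, 2)     # CUDA 10.1 passes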
......@@ -35,19 +35,9 @@ namespace dragon {
} while (0)
constexpr size_t CUDNN_CONV_WORKSPACE_LIMIT_BYTES = 64 * 1024 * 1024;
#if CUDNN_VERSION_MIN(7, 0, 0)
constexpr size_t CUDNN_CONV_NUM_FWD_ALGOS =
2 * CUDNN_CONVOLUTION_FWD_ALGO_COUNT;
constexpr size_t CUDNN_CONV_NUM_BWD_FILTER_ALGOS =
2 * CUDNN_CONVOLUTION_BWD_FILTER_ALGO_COUNT;
constexpr size_t CUDNN_CONV_NUM_BWD_DATA_ALGOS =
2 * CUDNN_CONVOLUTION_BWD_DATA_ALGO_COUNT;
#else
constexpr size_t CUDNN_CONV_NUM_FWD_ALGOS = 7;
constexpr size_t CUDNN_CONV_NUM_BWD_FILTER_ALGOS = 4;
constexpr size_t CUDNN_CONV_NUM_BWD_DATA_ALGOS = 5;
#endif
class Tensor;
......
......@@ -25,6 +25,7 @@ from dragon.core.eager.context import context as execution_context
from dragon.core.util import nest
from dragon.core.testing.unittest.common_utils import run_tests
from dragon.core.testing.unittest.common_utils import TEST_CUDA
from dragon.core.testing.unittest.common_utils import TEST_CUDNN_CONV3D_NHWC
# Fix the duplicate linked omp runtime
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
......@@ -3240,11 +3241,12 @@ class TestVisionOps(OpTestCase):
with dragon.device('cuda'), self.cudnn_ws.as_default():
self.test_conv2d()
def test_conv3d(self, prec=1e-3):
def test_conv3d(self, prec=1e-3, test_nhwc=True):
entries = [((2, 2, 2, 2, 2), (3, 2, 1, 1, 1), (3,), 1, 1, 0, 1, 1, 'NCHW'),
((2, 2, 2, 2, 2), (3, 2, 3, 3, 3), (3,), 3, 1, 1, 1, 1, 'NCHW'),
((2, 2, 2, 2, 2), (3, 2, 1, 1, 1), (3,), 1, 1, 0, 1, 1, 'NHWC'),
((2, 2, 2, 2, 2), (3, 2, 3, 3, 3), (3,), 3, 1, 1, 1, 1, 'NHWC')]
((2, 2, 2, 2, 2), (3, 2, 3, 3, 3), (3,), 3, 1, 1, 1, 1, 'NCHW')]
if test_nhwc:
entries += [((2, 2, 2, 2, 2), (3, 2, 1, 1, 1), (3,), 1, 1, 0, 1, 1, 'NHWC'),
((2, 2, 2, 2, 2), (3, 2, 3, 3, 3), (3,), 3, 1, 1, 1, 1, 'NHWC')]
results = [[[[[[0.08, 0.09], [0.1, 0.11]], [[0.12, 0.13], [0.14, 0.15]]],
[[[0.34, 0.39], [0.44, 0.49]], [[0.54, 0.59], [0.64, 0.69]]],
[[[0.6, 0.69], [0.78, 0.87]], [[0.96, 1.05], [1.14, 1.23]]]],
......@@ -3362,7 +3364,7 @@ class TestVisionOps(OpTestCase):
def test_conv3d_cudnn(self):
dragon.cuda.enable_cudnn(True)
with dragon.device('cuda'), self.cudnn_ws.as_default():
self.test_conv3d()
self.test_conv3d(test_nhwc=TEST_CUDNN_CONV3D_NHWC)
def test_conv1d_transpose(self, prec=1e-3):
entries = [((2, 2, 2), (2, 3, 1), (3,), 1, 1, 0, 1, 1, 'NCHW'),
......@@ -3508,11 +3510,12 @@ class TestVisionOps(OpTestCase):
with dragon.device('cuda'), self.cudnn_ws.as_default():
self.test_conv2d_transpose()
def test_conv3d_transpose(self, prec=1e-3):
def test_conv3d_transpose(self, prec=1e-3, test_nhwc=True):
entries = [((2, 2, 2, 2, 2), (2, 3, 1, 1, 1), (3,), 1, 1, 0, 1, 1, 'NCHW'),
((2, 2, 2, 2, 2), (2, 3, 3, 3, 3), (3,), 3, 1, 1, 1, 1, 'NCHW'),
((2, 2, 2, 2, 2), (2, 3, 1, 1, 1), (3,), 1, 1, 0, 1, 1, 'NHWC'),
((2, 2, 2, 2, 2), (2, 3, 3, 3, 3), (3,), 3, 1, 1, 1, 1, 'NHWC')]
((2, 2, 2, 2, 2), (2, 3, 3, 3, 3), (3,), 3, 1, 1, 1, 1, 'NCHW')]
if test_nhwc:
entries += [((2, 2, 2, 2, 2), (2, 3, 1, 1, 1), (3,), 1, 1, 0, 1, 1, 'NHWC'),
((2, 2, 2, 2, 2), (2, 3, 3, 3, 3), (3,), 3, 1, 1, 1, 1, 'NHWC')]
results = [[[[[[0.24, 0.27], [0.3, 0.33]], [[0.36, 0.39], [0.42, 0.45]]],
[[[0.42, 0.47], [0.52, 0.57]], [[0.62, 0.67], [0.72, 0.77]]],
[[[0.6, 0.67], [0.74, 0.81]], [[0.88, 0.95], [1.02, 1.09]]]],
......@@ -3631,7 +3634,7 @@ class TestVisionOps(OpTestCase):
def test_conv3d_transpose_cudnn(self):
dragon.cuda.enable_cudnn(True)
with dragon.device('cuda'), self.cudnn_ws.as_default():
self.test_conv3d_transpose()
self.test_conv3d_transpose(test_nhwc=TEST_CUDNN_CONV3D_NHWC)
def test_depthwise_conv2d(self, test_grad=False):
entries = [((2, 2, 2, 2), (2, 1, 1, 1), (2,), 1, 1, 0, 1, 'NCHW'),
......
......@@ -113,13 +113,13 @@ def dirac_(tensor, groups=1):
sizes = tensor.size()
if sizes[0] % groups != 0:
raise ValueError('Dimension 0 should be divisible by groups.')
group_dim = sizes[0] // groups
min_dim = min(group_dim, sizes[1])
out_channels_per_grp = sizes[0] // groups
min_dim = min(out_channels_per_grp, sizes[1])
with grad_mode.no_grad():
tensor.zero_()
for g in range(groups):
for d in range(min_dim):
item = [g * group_dim + d, d]
item = [g * out_channels_per_grp + d, d]
for i in range(2, dimensions):
item.append(sizes[i] // 2)
tensor[tuple(item)] = 1
......
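For reference, a small sketch of the grouped Dirac initialization edited above, via the PyTorch-compatible frontend (assuming dragon.vm.torch exposes nn.init.dirac_ as PyTorch does):

from dragon.vm import torch

# 4 output channels in 2 groups: each group's first
# min(out_per_group, in_channels) filters become identity taps.
weight = torch.zeros(4, 2, 3, 3)
torch.nn.init.dirac_(weight, groups=2)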
......@@ -12,6 +12,7 @@
# <https://github.com/pytorch/pytorch/blob/master/torch/optim/adam.py>
#
# ------------------------------------------------------------
"""Adam optimizers."""
from __future__ import absolute_import
from __future__ import division
......
......@@ -12,6 +12,7 @@
# <https://github.com/pytorch/pytorch/blob/master/torch/optim/optimizer.py>
#
# ------------------------------------------------------------
"""Basic optimizers."""
from __future__ import absolute_import
from __future__ import division
......
......@@ -12,6 +12,7 @@
# <https://github.com/pytorch/pytorch/blob/master/torch/optim/rmsprop.py>
#
# ------------------------------------------------------------
"""RMSprop optimizers."""
from __future__ import absolute_import
from __future__ import division
......
......@@ -12,6 +12,7 @@
# <https://github.com/pytorch/pytorch/blob/master/torch/optim/sgd.py>
#
# ------------------------------------------------------------
"""SGD optimizers."""
from __future__ import absolute_import
from __future__ import division
......
......@@ -8,6 +8,7 @@
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""DLPack utilities."""
from __future__ import absolute_import
from __future__ import division
......
......@@ -12,6 +12,7 @@
# <https://github.com/pytorch/pytorch/blob/master/torch/utils/hooks.py>
#
# ------------------------------------------------------------
"""Hook utilities."""
from __future__ import absolute_import
from __future__ import division
......