Commit bbfecf22 by Ting PAN

Add sysconfig module

Summary:
This commit adds the sysconfig module for retrieving build information.
Build information is helpful for selecting tests and for reporting issues.
1 parent 2c90589f
Showing with 379 additions and 92 deletions
dragon.sysconfig
================

.. only:: html

  Functions
  ---------

  `get_build_info(...) <sysconfig/get_build_info.html>`_
  : Return the environment information of built binaries.

  `get_include(...) <sysconfig/get_include.html>`_
  : Return the directory of framework header files.

  `get_lib(...) <sysconfig/get_lib.html>`_
  : Return the directory of framework libraries.

.. toctree::
  :hidden:

  sysconfig/get_build_info
  sysconfig/get_include
  sysconfig/get_lib

.. raw:: html

  <style>
  h1:before {
    content: "Module: ";
    color: #103d3e;
  }
  </style>
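A quick usage sketch of the module (the values shown in comments are illustrative assumptions, not guaranteed output):

```python
import dragon

# Query how the installed binaries were built.
info = dragon.sysconfig.get_build_info()
print(info['is_cuda_build'])      # e.g. True for a CUDA build
print(info.get('cudnn_version'))  # e.g. '7.6.5' (hypothetical value)

# Locate headers and libraries, e.g. when compiling extensions.
print(dragon.sysconfig.get_include())
print(dragon.sysconfig.get_lib())
```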
get_build_info
==============
.. autofunction:: dragon.sysconfig.get_build_info
.. raw:: html

  <style>
  h1:before {
    content: "dragon.sysconfig.";
    color: #103d3e;
  }
  </style>
get_include
===========
.. autofunction:: dragon.sysconfig.get_include
.. raw:: html

  <style>
  h1:before {
    content: "dragon.sysconfig.";
    color: #103d3e;
  }
  </style>
get_lib
=======
.. autofunction:: dragon.sysconfig.get_lib
.. raw:: html

  <style>
  h1:before {
    content: "dragon.sysconfig.";
    color: #103d3e;
  }
  </style>
@@ -41,6 +41,7 @@ Dragon
 * `dragon.onnx <dragon/onnx.html>`_
 * `dragon.optimizers <dragon/optimizers.html>`_
 * `dragon.random <dragon/random.html>`_
+* `dragon.sysconfig <dragon/sysconfig.html>`_
 * `dragon.vision <dragon/vision.html>`_
 Caffe
@@ -210,6 +211,9 @@ Modules
 `Module random <dragon/random.html>`_
 : Native API for ``dragon.random`` namespace.
+`Module sysconfig <dragon/sysconfig.html>`_
+: Native API for ``dragon.sysconfig`` namespace.
 `Module vision <dragon/vision.html>`_
 : Native API for ``dragon.vision`` namespace.
@@ -315,6 +319,7 @@ Modules
 dragon/onnx
 dragon/optimizers
 dragon/random
+dragon/sysconfig
 dragon/vision
 caffe
 caffe/layers
......
@@ -151,7 +151,10 @@ class CUDAObjects {
   /*! \brief The flag that allows cuDNN or not */
   bool cudnn_enabled_ = true;
-  /*! \brief The flag that allows cuDNN benchmark or not */
+  /*! \brief The flag that enforces deterministic cuDNN algorithms or not */
+  bool cudnn_deterministic_ = false;
+
+  /*! \brief The flag that benchmarks fastest cuDNN algorithms or not */
   bool cudnn_benchmark_ = false;
   /*! \brief The flag that allows cuDNN TF32 math type or not */
......
@@ -94,10 +94,13 @@ void RegisterModule(py::module& m) {
   });
   /*! \brief Activate the CuDNN engine */
-  m.def("cudaEnableDNN", [](bool enabled, bool benchmark, bool allow_tf32) {
+  m.def(
+      "cudaEnableDNN",
+      [](bool enabled, bool deterministic, bool benchmark, bool allow_tf32) {
 #ifdef USE_CUDA
     auto& cuda_objects = CUDAContext::objects();
     cuda_objects.cudnn_enabled_ = enabled;
+    cuda_objects.cudnn_deterministic_ = deterministic;
     cuda_objects.cudnn_benchmark_ = benchmark;
     cuda_objects.cudnn_allow_tf32_ = allow_tf32;
 #endif
......
#include "dragon/modules/python/autograd.h" #include "dragon/modules/python/autograd.h"
#include "dragon/modules/python/config.h"
#include "dragon/modules/python/cuda.h" #include "dragon/modules/python/cuda.h"
#include "dragon/modules/python/dlpack.h" #include "dragon/modules/python/dlpack.h"
#include "dragon/modules/python/mpi.h" #include "dragon/modules/python/mpi.h"
#include "dragon/modules/python/operator.h" #include "dragon/modules/python/operator.h"
#include "dragon/modules/python/proto.h" #include "dragon/modules/python/proto.h"
#include "dragon/modules/python/sysconfig.h"
#include "dragon/modules/python/tensor.h" #include "dragon/modules/python/tensor.h"
namespace dragon { namespace dragon {
...@@ -288,11 +288,11 @@ PYBIND11_MODULE(libdragon_python, m) { ...@@ -288,11 +288,11 @@ PYBIND11_MODULE(libdragon_python, m) {
[]() { import_array1(); }(); []() { import_array1(); }();
REGISTER_MODULE(autograd); REGISTER_MODULE(autograd);
REGISTER_MODULE(config);
REGISTER_MODULE(cuda); REGISTER_MODULE(cuda);
REGISTER_MODULE(mpi); REGISTER_MODULE(mpi);
REGISTER_MODULE(ops); REGISTER_MODULE(ops);
REGISTER_MODULE(proto); REGISTER_MODULE(proto);
REGISTER_MODULE(sysconfig);
REGISTER_MODULE(tensor); REGISTER_MODULE(tensor);
#undef REGISTER_MODULE #undef REGISTER_MODULE
} }
......
@@ -10,8 +10,8 @@
  * ------------------------------------------------------------
  */
-#ifndef DRAGON_MODULES_PYTHON_CONFIG_H_
-#define DRAGON_MODULES_PYTHON_CONFIG_H_
+#ifndef DRAGON_MODULES_PYTHON_SYSCONFIG_H_
+#define DRAGON_MODULES_PYTHON_SYSCONFIG_H_
 #include "dragon/modules/python/common.h"
 #include "dragon/utils/device/common_eigen.h"
@@ -20,7 +20,7 @@ namespace dragon {
 namespace python {
-namespace config {
+namespace sysconfig {
 void RegisterModule(py::module& m) {
   /*! \brief Set the logging severity */
@@ -33,12 +33,53 @@ void RegisterModule(py::module& m) {
   /*! \brief Return the number of threads for cpu parallelism */
   m.def("GetNumThreads", []() { return Eigen::nbThreads(); });
+  m.def("GetBuildInformation", []() {
+    static string build_info;
+    if (!build_info.empty()) {
+      return build_info;
+    }
+    build_info += "cpu_features:";
+#if defined(USE_AVX)
+    build_info += " AVX";
+#endif
+#if defined(USE_AVX2)
+    build_info += " AVX2";
+#endif
+#if defined(USE_FMA)
+    build_info += " FMA";
+#endif
+    build_info += "\ncuda_version:";
+#if defined(USE_CUDA)
+    build_info += " " + str::to(CUDA_VERSION / 1000) + "." +
+        str::to(CUDA_VERSION % 1000 / 10);
+#endif
+    build_info += "\ncudnn_version:";
+#if defined(USE_CUDNN)
+    build_info += " " + str::to(CUDNN_MAJOR) + "." + str::to(CUDNN_MINOR) +
+        "." + str::to(CUDNN_PATCHLEVEL);
+#endif
+    build_info += "\nthird_party: eigen protobuf pybind11";
+#if defined(USE_OPENMP)
+    build_info += " openmp";
+#endif
+#if defined(USE_MPI)
+    build_info += " mpi";
+#endif
+#if defined(USE_CUDA)
+    build_info += " cuda cub";
+#endif
+#if defined(USE_CUDNN)
+    build_info += " cudnn";
+#endif
+    return build_info;
+  });
 }
-} // namespace config
+} // namespace sysconfig
 } // namespace python
 } // namespace dragon
-#endif // DRAGON_MODULES_PYTHON_CONFIG_H_
+#endif // DRAGON_MODULES_PYTHON_SYSCONFIG_H_
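For reference, `GetBuildInformation` returns newline-separated `key: value` pairs. A sketch of the string for a hypothetical CUDA build (actual entries depend on the compile-time flags above):

```python
# Illustrative shape of backend.GetBuildInformation() output;
# every value here is an assumption, not a guaranteed result.
EXAMPLE_BUILD_INFO = (
    'cpu_features: AVX AVX2 FMA\n'
    'cuda_version: 10.2\n'
    'cudnn_version: 7.6.5\n'
    'third_party: eigen protobuf pybind11 openmp cuda cub cudnn'
)
```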
@@ -31,7 +31,9 @@ void CuDNNConvOp<Context>::ResetDesc() {
   }
   this->template SetConvDesc<T>();
   // Get or search the appropriate algorithm
-  if (CUDAContext::objects().cudnn_benchmark_) {
+  if (CUDAContext::objects().cudnn_deterministic_) {
+    fwd_algo_ = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM;
+  } else if (CUDAContext::objects().cudnn_benchmark_) {
     exhaustive_search_ = true;
   } else {
 #if CUDNN_VERSION_MIN(7, 0, 0)
@@ -122,14 +124,23 @@ void CuDNNConvOp<Context>::DoRunWithType() {
   // Determine the workspace size for selected algorithm
   if (cudnn_ws_nbytes_ == SIZE_MAX) {
-    CUDNN_CHECK(cudnnGetConvolutionForwardWorkspaceSize(
-        ctx()->cudnn_handle(),
-        input_desc_,
-        filter_desc_,
-        conv_desc_,
-        output_desc_,
-        fwd_algo_,
-        &cudnn_ws_nbytes_));
+    auto algo_status = CUDNN_STATUS_SUCCESS;
+    for (int step = 0; step < 2; ++step) {
+      algo_status = cudnnGetConvolutionForwardWorkspaceSize(
+          ctx()->cudnn_handle(),
+          input_desc_,
+          filter_desc_,
+          conv_desc_,
+          output_desc_,
+          fwd_algo_,
+          &cudnn_ws_nbytes_);
+      if (algo_status != CUDNN_STATUS_SUCCESS && step == 0 &&
+          CUDAContext::objects().cudnn_deterministic_) {
+        fwd_algo_ = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM;
+      } else {
+        CUDNN_CHECK(algo_status);
+      }
+    }
   }
   // Alloc the memory for workspace data
@@ -205,7 +216,10 @@ void CuDNNConvGradientOp<Context>::ResetDesc() {
   }
   this->template SetConvDesc<T>();
   // Get or search the appropriate algorithm
-  if (CUDAContext::objects().cudnn_benchmark_) {
+  if (CUDAContext::objects().cudnn_deterministic_) {
+    bwd_data_algo_ = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1;
+    bwd_filter_algo_ = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1;
+  } else if (CUDAContext::objects().cudnn_benchmark_) {
     exhaustive_search_data_ = true;
     exhaustive_search_filter_ = true;
   } else {
@@ -359,23 +373,40 @@ void CuDNNConvGradientOp<Context>::DoRunWithType() {
   // Determine the workspace size for selected algorithm
   if (cudnn_ws_nbytes_ == SIZE_MAX) {
+    auto algo_status = CUDNN_STATUS_SUCCESS;
     size_t bwd_filter_size = 0, bwd_data_size = 0;
-    CUDNN_CHECK(cudnnGetConvolutionBackwardFilterWorkspaceSize(
-        ctx()->cudnn_handle(),
-        output_desc_,
-        input_desc_,
-        conv_desc_,
-        filter_desc_,
-        bwd_filter_algo_,
-        &bwd_filter_size));
-    CUDNN_CHECK(cudnnGetConvolutionBackwardDataWorkspaceSize(
-        ctx()->cudnn_handle(),
-        filter_desc_,
-        input_desc_,
-        conv_desc_,
-        output_desc_,
-        bwd_data_algo_,
-        &bwd_data_size));
+    for (int step = 0; step < 2; ++step) {
+      algo_status = cudnnGetConvolutionBackwardFilterWorkspaceSize(
+          ctx()->cudnn_handle(),
+          output_desc_,
+          input_desc_,
+          conv_desc_,
+          filter_desc_,
+          bwd_filter_algo_,
+          &bwd_filter_size);
+      if (algo_status != CUDNN_STATUS_SUCCESS && step == 0 &&
+          CUDAContext::objects().cudnn_deterministic_) {
+        bwd_filter_algo_ = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0;
+      } else {
+        CUDNN_CHECK(algo_status);
+      }
+    }
+    for (int step = 0; step < 2; ++step) {
+      algo_status = cudnnGetConvolutionBackwardDataWorkspaceSize(
+          ctx()->cudnn_handle(),
+          filter_desc_,
+          input_desc_,
+          conv_desc_,
+          output_desc_,
+          bwd_data_algo_,
+          &bwd_data_size);
+      if (algo_status != CUDNN_STATUS_SUCCESS && step == 0 &&
+          CUDAContext::objects().cudnn_deterministic_) {
+        bwd_data_algo_ = CUDNN_CONVOLUTION_BWD_DATA_ALGO_0;
+      } else {
+        CUDNN_CHECK(algo_status);
+      }
+    }
     cudnn_ws_nbytes_ = std::max(bwd_filter_size, bwd_data_size);
   }
......
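The two-step loops above repeat one pattern for each cuDNN workspace query: try the preferred deterministic algorithm first and, if cuDNN rejects it, retry once with a more conservative algorithm before failing hard. A minimal Python sketch of that control flow (not the framework's C++ code; `query` is a hypothetical callable wrapping a `cudnnGet*WorkspaceSize` call):

```python
def query_workspace_size(query, preferred_algo, fallback_algo):
    """Return (algo, nbytes), retrying once with a fallback algorithm."""
    algo = preferred_algo
    for step in range(2):
        status, nbytes = query(algo)  # mirrors a cudnnGet*WorkspaceSize call
        if status == 'CUDNN_STATUS_SUCCESS':
            return algo, nbytes
        if step == 0:
            algo = fallback_algo  # one retry with the conservative algorithm
    raise RuntimeError('workspace query failed: ' + status)
```

In the C++ code the fallback is taken only when `cudnn_deterministic_` is set; otherwise the first failure is reported immediately via `CUDNN_CHECK`.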
@@ -31,7 +31,9 @@ void CuDNNConvTransposeOp<Context>::ResetDesc() {
   }
   this->template SetConvDesc<T>();
   // Get or search the appropriate algorithm
-  if (CUDAContext::objects().cudnn_benchmark_) {
+  if (CUDAContext::objects().cudnn_deterministic_) {
+    fwd_algo_ = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1;
+  } else if (CUDAContext::objects().cudnn_benchmark_) {
     exhaustive_search_ = true;
   } else {
 #if CUDNN_VERSION_MIN(7, 0, 0)
@@ -122,14 +124,23 @@ void CuDNNConvTransposeOp<Context>::DoRunWithType() {
   // Determine the workspace size for selected algorithm
   if (cudnn_ws_nbytes_ == SIZE_MAX) {
-    CUDNN_CHECK(cudnnGetConvolutionBackwardDataWorkspaceSize(
-        ctx()->cudnn_handle(),
-        filter_desc_,
-        input_desc_,
-        conv_desc_,
-        output_desc_,
-        fwd_algo_,
-        &cudnn_ws_nbytes_));
+    auto algo_status = CUDNN_STATUS_SUCCESS;
+    for (int step = 0; step < 2; ++step) {
+      algo_status = cudnnGetConvolutionBackwardDataWorkspaceSize(
+          ctx()->cudnn_handle(),
+          filter_desc_,
+          input_desc_,
+          conv_desc_,
+          output_desc_,
+          fwd_algo_,
+          &cudnn_ws_nbytes_);
+      if (algo_status != CUDNN_STATUS_SUCCESS && step == 0 &&
+          CUDAContext::objects().cudnn_deterministic_) {
+        fwd_algo_ = CUDNN_CONVOLUTION_BWD_DATA_ALGO_0;
+      } else {
+        CUDNN_CHECK(algo_status);
+      }
+    }
   }
   // Alloc the memory for workspace data
@@ -205,7 +216,10 @@ void CuDNNConvTransposeGradientOp<Context>::ResetDesc() {
   }
   this->template SetConvDesc<T>();
   // Get the appropriate algorithm
-  if (CUDAContext::objects().cudnn_benchmark_) {
+  if (CUDAContext::objects().cudnn_deterministic_) {
+    bwd_data_algo_ = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM;
+    bwd_filter_algo_ = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1;
+  } else if (CUDAContext::objects().cudnn_benchmark_) {
     exhaustive_search_data_ = true;
     exhaustive_search_filter_ = true;
   } else {
@@ -359,23 +373,40 @@ void CuDNNConvTransposeGradientOp<Context>::DoRunWithType() {
   // Determine the workspace size for selected algorithm
   if (cudnn_ws_nbytes_ == SIZE_MAX) {
+    auto algo_status = CUDNN_STATUS_SUCCESS;
     size_t bwd_filter_size = 0, bwd_data_size = 0;
-    CUDNN_CHECK(cudnnGetConvolutionBackwardFilterWorkspaceSize(
-        ctx()->cudnn_handle(),
-        input_desc_,
-        output_desc_,
-        conv_desc_,
-        filter_desc_,
-        bwd_filter_algo_,
-        &bwd_filter_size));
-    CUDNN_CHECK(cudnnGetConvolutionForwardWorkspaceSize(
-        ctx()->cudnn_handle(),
-        input_desc_,
-        filter_desc_,
-        conv_desc_,
-        output_desc_,
-        bwd_data_algo_,
-        &bwd_data_size));
+    for (int step = 0; step < 2; ++step) {
+      algo_status = cudnnGetConvolutionBackwardFilterWorkspaceSize(
+          ctx()->cudnn_handle(),
+          input_desc_,
+          output_desc_,
+          conv_desc_,
+          filter_desc_,
+          bwd_filter_algo_,
+          &bwd_filter_size);
+      if (algo_status != CUDNN_STATUS_SUCCESS && step == 0 &&
+          CUDAContext::objects().cudnn_deterministic_) {
+        bwd_filter_algo_ = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0;
+      } else {
+        CUDNN_CHECK(algo_status);
+      }
+    }
+    for (int step = 0; step < 2; ++step) {
+      algo_status = cudnnGetConvolutionForwardWorkspaceSize(
+          ctx()->cudnn_handle(),
+          input_desc_,
+          filter_desc_,
+          conv_desc_,
+          output_desc_,
+          bwd_data_algo_,
+          &bwd_data_size);
+      if (algo_status != CUDNN_STATUS_SUCCESS && step == 0 &&
+          CUDAContext::objects().cudnn_deterministic_) {
+        bwd_data_algo_ = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM;
+      } else {
+        CUDNN_CHECK(algo_status);
+      }
+    }
     cudnn_ws_nbytes_ = std::max(bwd_filter_size, bwd_data_size);
   }
......
@@ -32,6 +32,7 @@ from dragon._api import nn
 from dragon._api import onnx
 from dragon._api import optimizers
 from dragon._api import random
+from dragon._api import sysconfig
 from dragon._api import vision
 from dragon import vm
@@ -42,12 +43,12 @@ from dragon.core.eager.backprop import GradientTape
 from dragon.core.framework.workspace import Workspace
 # Functions
-from dragon.backend import load_library
 from dragon.core.autograph.def_function import function
 from dragon.core.autograph.function_lib import create_function
 from dragon.core.autograph.grad_impl import gradients
 from dragon.core.eager.context import eager_mode
 from dragon.core.eager.context import graph_mode
+from dragon.core.framework.backend import load_library
 from dragon.core.framework.config import get_num_threads
 from dragon.core.framework.config import set_num_threads
 from dragon.core.framework.context import device
......
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
from __future__ import absolute_import as _absolute_import
from __future__ import division as _division
from __future__ import print_function as _print_function

from dragon.core.framework.sysconfig import get_build_info
from dragon.core.framework.sysconfig import get_include
from dragon.core.framework.sysconfig import get_lib

__all__ = [_s for _s in dir() if not _s.startswith('_')]
@@ -8,7 +8,7 @@
 # <https://opensource.org/licenses/BSD-2-Clause>
 #
 # ------------------------------------------------------------
-"""Autograph options."""
+"""Autograph configurations."""
 from __future__ import absolute_import
 from __future__ import division
......
@@ -16,8 +16,8 @@ from __future__ import print_function
 from collections import defaultdict
-from dragon import backend
 from dragon.core.autograph.op_def import OpDef
+from dragon.core.framework import backend
 from dragon.core.framework import proto_util
 from dragon.core.proto import dragon_pb2
......
@@ -8,7 +8,7 @@
 # <https://opensource.org/licenses/BSD-2-Clause>
 #
 # ------------------------------------------------------------
-"""The graph executing tensor."""
+"""Graph executing tensor."""
 from __future__ import absolute_import
 from __future__ import division
......
@@ -14,7 +14,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
-from dragon import backend
+from dragon.core.framework import backend
 from dragon.core.framework import config
 from dragon.core.framework import workspace
@@ -62,20 +62,27 @@ def current_device():
     return backend.cudaGetDevice()
-def enable_cudnn(enabled=True, benchmark=False, allow_tf32=False):
+def enable_cudnn(
+    enabled=True,
+    deterministic=False,
+    benchmark=False,
+    allow_tf32=False,
+):
     """Enable backend to use the cuDNN library.
     Parameters
     ----------
     enabled : bool, optional, default=True
         Use cuDNN library or not.
+    deterministic : bool, optional, default=False
+        Select deterministic algorithms instead of fastest.
     benchmark : bool, optional, default=False
-        Select algorithms according to the benchmark or not.
+        Select fastest algorithms via benchmark or heuristics.
     allow_tf32 : bool, optional, default=False
-        Allow TF32 Tensor core operation or not.
+        Allow TF32 tensor core operation or not.
     """
-    return backend.cudaEnableDNN(enabled, benchmark, allow_tf32)
+    return backend.cudaEnableDNN(enabled, deterministic, benchmark, allow_tf32)
 def get_device_capability(device_index=None):
......
@@ -8,7 +8,7 @@
 # <https://opensource.org/licenses/BSD-2-Clause>
 #
 # ------------------------------------------------------------
-"""Distributed utilities equipped with Python."""
+"""Distributed backend."""
 from __future__ import absolute_import
 from __future__ import division
@@ -16,7 +16,7 @@ from __future__ import print_function
 import atexit
-from dragon import backend as _b
+from dragon.core.framework import backend as _b
 from dragon.core.util import nest
 from dragon.core.util import six
 from dragon.core.util import tls
......
@@ -12,7 +12,7 @@
 # <https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/eager/context.py>
 #
 # ------------------------------------------------------------
-"""State management for eager execution."""
+"""Eager execution context."""
 from __future__ import absolute_import
 from __future__ import division
......
@@ -8,6 +8,7 @@
 # <https://opensource.org/licenses/BSD-2-Clause>
 #
 # ------------------------------------------------------------
+"""DLPack utilities."""
 from __future__ import absolute_import
 from __future__ import division
......
@@ -8,7 +8,7 @@
 # <https://opensource.org/licenses/BSD-2-Clause>
 #
 # ------------------------------------------------------------
-"""The eager executing tensor."""
+"""Eager executing tensor."""
 from __future__ import absolute_import
 from __future__ import division
......
@@ -8,7 +8,7 @@
 # <https://opensource.org/licenses/BSD-2-Clause>
 #
 # ------------------------------------------------------------
-"""List the exported C++ API."""
+"""Framework backend."""
 from __future__ import absolute_import
 from __future__ import division
......
@@ -8,7 +8,7 @@
 # <https://opensource.org/licenses/BSD-2-Clause>
 #
 # ------------------------------------------------------------
-"""Define the global configurations."""
+"""Framework configurations."""
 from __future__ import absolute_import
 from __future__ import division
@@ -16,31 +16,29 @@ from __future__ import print_function
 import threading
-from dragon import backend
+from dragon.core.framework import backend
 class Config(object):
-    """Store the common configurations for frontend."""
+    """Framework configuration class."""
     def __init__(self):
-        # The type of device.
+        # Device type.
         # Enumeration in ('cpu', 'cuda', 'cnml').
         self.device_type = 'cpu'
-        # The device index.
+        # Device index.
         self.device_index = 0
-        # The global random seed.
+        # Device random seed.
         self.random_seed = 3
-        # The graph type for various scheduling.
+        # Graph type for various scheduling.
         self.graph_type = ''
-        # The graph optimization level.
+        # Graph optimization level.
         self.graph_optimization = 3
-        # The graph verbosity level.
+        # Graph verbosity level.
         self.graph_verbosity = 0
-        # The execution mode for graph.
+        # Graph execution mode.
         self.graph_execution = 'EAGER_MODE'
-        # The directory to store logging files.
+        # Directory to store logging files.
         self.log_dir = None
......
@@ -8,6 +8,7 @@
 # <https://opensource.org/licenses/BSD-2-Clause>
 #
 # ------------------------------------------------------------
+"""Framework context."""
 from __future__ import absolute_import
 from __future__ import division
......
@@ -20,7 +20,7 @@ import sys
 from google.protobuf.message import Message
 import numpy
-from dragon import backend
+from dragon.core.framework import backend
 from dragon.core.framework import config
 from dragon.core.framework import context
 from dragon.core.proto import dragon_pb2
......
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""System configurations."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from dragon.core.framework import backend
def get_build_info():
"""Return the environment information of built binaries.
The return value is a dictionary with string keys:
* cpu_features
* cuda_version
* cudnn_version
* is_cuda_build
* third_party
Returns
-------
dict
The info dict.
"""
build_info = {}
build_info_str = backend.GetBuildInformation()
for entry in build_info_str.split('\n'):
k, v = entry.split(':')
if len(v) > 0:
build_info[k] = v[1:]
build_info['is_cuda_build'] = 'cuda_version' in build_info
return build_info
def get_include():
"""Return the directory of framework header files.
Returns
-------
str
The include directory.
"""
core_root = os.path.dirname(os.path.dirname(__file__))
return os.path.join(os.path.dirname(core_root), 'include')
def get_lib():
"""Return the directory of framework libraries.
Returns
-------
str
The library directory.
"""
core_root = os.path.dirname(os.path.dirname(__file__))
return os.path.join(os.path.dirname(core_root), 'lib')
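A sketch of what the parsed dictionary can look like for a CUDA build (which keys are present depends on compile-time flags; the values below are assumptions):

```python
import dragon

info = dragon.sysconfig.get_build_info()
# Hypothetical result for a CUDA build:
# {'cpu_features': 'AVX AVX2 FMA',
#  'cuda_version': '10.2',
#  'cudnn_version': '7.6.5',
#  'third_party': 'eigen protobuf pybind11 openmp cuda cub cudnn',
#  'is_cuda_build': True}
if not info['is_cuda_build']:
    print('CPU-only build; skipping CUDA-specific tests.')
```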
@@ -8,7 +8,7 @@
 # <https://opensource.org/licenses/BSD-2-Clause>
 #
 # ------------------------------------------------------------
-"""Define the basic prototypes."""
+"""Basic prototypes."""
 from __future__ import absolute_import
 from __future__ import division
......
@@ -8,7 +8,7 @@
 # <https://opensource.org/licenses/BSD-2-Clause>
 #
 # ------------------------------------------------------------
-"""Generic interfaces of current default workspace."""
+"""Functions and helpers for workspace."""
 from __future__ import absolute_import
 from __future__ import division
@@ -18,7 +18,7 @@ import collections
 import contextlib
 import numpy
-from dragon import backend
+from dragon.core.framework import backend
 from dragon.core.framework import config
 from dragon.core.framework import mapping
 from dragon.core.framework import proto_util
......
@@ -8,7 +8,7 @@
 # <https://opensource.org/licenses/BSD-2-Clause>
 #
 # ------------------------------------------------------------
-"""Utilities for KPLRecord."""
+"""KPLRecord utilities."""
 from __future__ import absolute_import
 from __future__ import division
......
@@ -8,6 +8,7 @@
 # <https://opensource.org/licenses/BSD-2-Clause>
 #
 # ------------------------------------------------------------
+"""Process to read the distributed data."""
 from __future__ import absolute_import
 from __future__ import division
......
@@ -8,7 +8,7 @@
 # <https://opensource.org/licenses/BSD-2-Clause>
 #
 # ------------------------------------------------------------
-"""Utilities for TFRecord."""
+"""TFRecord utilities."""
 from __future__ import absolute_import
 from __future__ import division
......
@@ -22,10 +22,12 @@ import dragon
 # The global argument parser
 parser = argparse.ArgumentParser(add_help=False)
+build_info = dragon.sysconfig.get_build_info()
 # The optional testing flags
 TEST_CUDA = dragon.cuda.is_available()
 TEST_MPI = dragon.distributed.is_mpi_available()
+TEST_CUDNN_CONV3D_NHWC = build_info.get('cudnn_version', '0.0.0') > '8.0.0'
 def run_tests(argv=None):
......
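A sketch of how such a flag can gate a test case (the test class and body here are hypothetical, not part of the commit):

```python
import unittest

from dragon.core.testing.unittest.common_utils import TEST_CUDNN_CONV3D_NHWC


class TestConv3dNHWC(unittest.TestCase):
    """Hypothetical test gated on the cuDNN version from build info."""

    @unittest.skipIf(not TEST_CUDNN_CONV3D_NHWC,
                     'NHWC conv3d requires cuDNN > 8.0.0')
    def test_nhwc_layout(self):
        pass  # placeholder body
```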
@@ -8,7 +8,7 @@
 # <https://opensource.org/licenses/BSD-2-Clause>
 #
 # ------------------------------------------------------------
-"""The Adam optimizers."""
+"""Adam optimizers."""
 from __future__ import absolute_import
 from __future__ import division
......
@@ -8,7 +8,7 @@
 # <https://opensource.org/licenses/BSD-2-Clause>
 #
 # ------------------------------------------------------------
-"""The optimizer to update parameters."""
+"""Basic optimizers."""
 from __future__ import absolute_import
 from __future__ import division
......
@@ -8,7 +8,7 @@
 # <https://opensource.org/licenses/BSD-2-Clause>
 #
 # ------------------------------------------------------------
-"""The RMSprop optimizers."""
+"""RMSprop optimizers."""
 from __future__ import absolute_import
 from __future__ import division
......
@@ -8,7 +8,7 @@
 # <https://opensource.org/licenses/BSD-2-Clause>
 #
 # ------------------------------------------------------------
-"""The SGD optimizers."""
+"""SGD optimizers."""
 from __future__ import absolute_import
 from __future__ import division
......
@@ -20,7 +20,7 @@ import os
 import sys as _sys
 import threading
-from dragon import backend
+from dragon.core.framework import backend
 from dragon.core.framework import config
 _logger = None
......
@@ -328,7 +328,7 @@ def _get_cuda_arch_flags(cflags=None):
     '5.0', '5.2', '5.3',
     '6.0', '6.1', '6.2',
     '7.0', '7.2', '7.5',
-    '8.0']
+    '8.0', '8.6']
 valid_arch_strings = supported_arches + [s + "+PTX" for s in supported_arches]
 capability = _cuda.get_device_capability()
 arch_list = ['{}.{}'.format(capability[0], capability[1])]
......
@@ -144,7 +144,7 @@ CONVERSIONS_DECL half To<half, half>(half val) {
 template <>
 CONVERSIONS_DECL half To<half, float>(float val) {
-#if CUDA_VERSION_MIN(9, 2, 0)
+#if CUDA_VERSION_MIN(9, 2)
   return __float2half(val);
 #else
 #if defined(__CUDA_ARCH__)
@@ -161,7 +161,7 @@ CONVERSIONS_DECL half To<half, float>(float val) {
 template <>
 CONVERSIONS_DECL half2 To<half2, float>(float val) {
-#if CUDA_VERSION_MIN(9, 2, 0)
+#if CUDA_VERSION_MIN(9, 2)
   return __float2half2_rn(val);
 #else
 #if defined(__CUDA_ARCH__)
......
@@ -39,11 +39,11 @@ constexpr int CUDA_MAX_DEVICES = 16;
 /*! \brief The maximum number of tensor dimensions */
 constexpr int CUDA_TENSOR_MAX_DIMS = 8;
-#define CUDA_VERSION_MIN(major, minor, patch) \
-  (CUDA_VERSION >= (major * 1000 + minor * 100 + patch))
+#define CUDA_VERSION_MIN(major, minor) \
+  (CUDA_VERSION >= (major * 1000 + minor * 10))
-#define CUDA_VERSION_MAX(major, minor, patch) \
-  (CUDA_VERSION < (major * 1000 + minor * 100 + patch))
+#define CUDA_VERSION_MAX(major, minor) \
+  (CUDA_VERSION < (major * 1000 + minor * 10))
 #define CUDA_CHECK(condition) \
   do { \
@@ -87,7 +87,7 @@ inline int CUDA_2D_BLOCKS(const int N) {
   return std::max(std::min(N, CUDA_MAX_BLOCKS), 1);
 }
-#if CUDA_VERSION_MAX(9, 0, 0)
+#if CUDA_VERSION_MAX(9, 0)
 #define __hdiv hdiv
 #endif
......
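For reference, nvcc encodes `CUDA_VERSION` as `major * 1000 + minor * 10` (CUDA 9.2 is 9020), which is why the corrected macros drop the `patch` argument. A quick check of that arithmetic:

```python
def cuda_version_min(version, major, minor):
    # Mirrors the corrected macro: CUDA_VERSION >= major * 1000 + minor * 10.
    return version >= major * 1000 + minor * 10

assert cuda_version_min(9020, 9, 2)      # CUDA 9.2 satisfies MIN(9, 2)
assert not cuda_version_min(9000, 9, 2)  # CUDA 9.0 does not
```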
@@ -35,19 +35,9 @@ namespace dragon {
 } while (0)
 constexpr size_t CUDNN_CONV_WORKSPACE_LIMIT_BYTES = 64 * 1024 * 1024;
-#if CUDNN_VERSION_MIN(7, 0, 0)
-constexpr size_t CUDNN_CONV_NUM_FWD_ALGOS =
-    2 * CUDNN_CONVOLUTION_FWD_ALGO_COUNT;
-constexpr size_t CUDNN_CONV_NUM_BWD_FILTER_ALGOS =
-    2 * CUDNN_CONVOLUTION_BWD_FILTER_ALGO_COUNT;
-constexpr size_t CUDNN_CONV_NUM_BWD_DATA_ALGOS =
-    2 * CUDNN_CONVOLUTION_BWD_DATA_ALGO_COUNT;
-#else
 constexpr size_t CUDNN_CONV_NUM_FWD_ALGOS = 7;
 constexpr size_t CUDNN_CONV_NUM_BWD_FILTER_ALGOS = 4;
 constexpr size_t CUDNN_CONV_NUM_BWD_DATA_ALGOS = 5;
-#endif
 class Tensor;
......
@@ -25,6 +25,7 @@ from dragon.core.eager.context import context as execution_context
 from dragon.core.util import nest
 from dragon.core.testing.unittest.common_utils import run_tests
 from dragon.core.testing.unittest.common_utils import TEST_CUDA
+from dragon.core.testing.unittest.common_utils import TEST_CUDNN_CONV3D_NHWC
 # Fix the duplicate linked omp runtime
 os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
@@ -3240,10 +3241,11 @@ class TestVisionOps(OpTestCase):
         with dragon.device('cuda'), self.cudnn_ws.as_default():
             self.test_conv2d()
-    def test_conv3d(self, prec=1e-3):
-        entries = [((2, 2, 2, 2, 2), (3, 2, 1, 1, 1), (3,), 1, 1, 0, 1, 1, 'NCHW'),
-                   ((2, 2, 2, 2, 2), (3, 2, 3, 3, 3), (3,), 3, 1, 1, 1, 1, 'NCHW'),
-                   ((2, 2, 2, 2, 2), (3, 2, 1, 1, 1), (3,), 1, 1, 0, 1, 1, 'NHWC'),
-                   ((2, 2, 2, 2, 2), (3, 2, 3, 3, 3), (3,), 3, 1, 1, 1, 1, 'NHWC')]
+    def test_conv3d(self, prec=1e-3, test_nhwc=True):
+        entries = [((2, 2, 2, 2, 2), (3, 2, 1, 1, 1), (3,), 1, 1, 0, 1, 1, 'NCHW'),
+                   ((2, 2, 2, 2, 2), (3, 2, 3, 3, 3), (3,), 3, 1, 1, 1, 1, 'NCHW')]
+        if test_nhwc:
+            entries += [((2, 2, 2, 2, 2), (3, 2, 1, 1, 1), (3,), 1, 1, 0, 1, 1, 'NHWC'),
+                        ((2, 2, 2, 2, 2), (3, 2, 3, 3, 3), (3,), 3, 1, 1, 1, 1, 'NHWC')]
         results = [[[[[[0.08, 0.09], [0.1, 0.11]], [[0.12, 0.13], [0.14, 0.15]]],
                      [[[0.34, 0.39], [0.44, 0.49]], [[0.54, 0.59], [0.64, 0.69]]],
@@ -3362,7 +3364,7 @@ class TestVisionOps(OpTestCase):
     def test_conv3d_cudnn(self):
         dragon.cuda.enable_cudnn(True)
         with dragon.device('cuda'), self.cudnn_ws.as_default():
-            self.test_conv3d()
+            self.test_conv3d(test_nhwc=TEST_CUDNN_CONV3D_NHWC)
     def test_conv1d_transpose(self, prec=1e-3):
         entries = [((2, 2, 2), (2, 3, 1), (3,), 1, 1, 0, 1, 1, 'NCHW'),
@@ -3508,10 +3510,11 @@ class TestVisionOps(OpTestCase):
         with dragon.device('cuda'), self.cudnn_ws.as_default():
             self.test_conv2d_transpose()
-    def test_conv3d_transpose(self, prec=1e-3):
-        entries = [((2, 2, 2, 2, 2), (2, 3, 1, 1, 1), (3,), 1, 1, 0, 1, 1, 'NCHW'),
-                   ((2, 2, 2, 2, 2), (2, 3, 3, 3, 3), (3,), 3, 1, 1, 1, 1, 'NCHW'),
-                   ((2, 2, 2, 2, 2), (2, 3, 1, 1, 1), (3,), 1, 1, 0, 1, 1, 'NHWC'),
-                   ((2, 2, 2, 2, 2), (2, 3, 3, 3, 3), (3,), 3, 1, 1, 1, 1, 'NHWC')]
+    def test_conv3d_transpose(self, prec=1e-3, test_nhwc=True):
+        entries = [((2, 2, 2, 2, 2), (2, 3, 1, 1, 1), (3,), 1, 1, 0, 1, 1, 'NCHW'),
+                   ((2, 2, 2, 2, 2), (2, 3, 3, 3, 3), (3,), 3, 1, 1, 1, 1, 'NCHW')]
+        if test_nhwc:
+            entries += [((2, 2, 2, 2, 2), (2, 3, 1, 1, 1), (3,), 1, 1, 0, 1, 1, 'NHWC'),
+                        ((2, 2, 2, 2, 2), (2, 3, 3, 3, 3), (3,), 3, 1, 1, 1, 1, 'NHWC')]
         results = [[[[[[0.24, 0.27], [0.3, 0.33]], [[0.36, 0.39], [0.42, 0.45]]],
                      [[[0.42, 0.47], [0.52, 0.57]], [[0.62, 0.67], [0.72, 0.77]]],
@@ -3631,7 +3634,7 @@ class TestVisionOps(OpTestCase):
     def test_conv3d_transpose_cudnn(self):
         dragon.cuda.enable_cudnn(True)
         with dragon.device('cuda'), self.cudnn_ws.as_default():
-            self.test_conv3d_transpose()
+            self.test_conv3d_transpose(test_nhwc=TEST_CUDNN_CONV3D_NHWC)
     def test_depthwise_conv2d(self, test_grad=False):
         entries = [((2, 2, 2, 2), (2, 1, 1, 1), (2,), 1, 1, 0, 1, 'NCHW'),
......
@@ -113,13 +113,13 @@ def dirac_(tensor, groups=1):
     sizes = tensor.size()
     if sizes[0] % groups != 0:
         raise ValueError('Dimension 0 should be divisible by groups.')
-    group_dim = sizes[0] // groups
-    min_dim = min(group_dim, sizes[1])
+    out_channels_per_grp = sizes[0] // groups
+    min_dim = min(out_channels_per_grp, sizes[1])
     with grad_mode.no_grad():
         tensor.zero_()
         for g in range(groups):
             for d in range(min_dim):
-                item = [g * group_dim + d, d]
+                item = [g * out_channels_per_grp + d, d]
                 for i in range(2, dimensions):
                     item.append(sizes[i] // 2)
                 tensor[tuple(item)] = 1
......
@@ -12,6 +12,7 @@
 # <https://github.com/pytorch/pytorch/blob/master/torch/optim/adam.py>
 #
 # ------------------------------------------------------------
+"""Adam optimizers."""
 from __future__ import absolute_import
 from __future__ import division
......
@@ -12,6 +12,7 @@
 # <https://github.com/pytorch/pytorch/blob/master/torch/optim/optimizer.py>
 #
 # ------------------------------------------------------------
+"""Basic optimizers."""
 from __future__ import absolute_import
 from __future__ import division
......
@@ -12,6 +12,7 @@
 # <https://github.com/pytorch/pytorch/blob/master/torch/optim/rmsprop.py>
 #
 # ------------------------------------------------------------
+"""RMSprop optimizers."""
 from __future__ import absolute_import
 from __future__ import division
......
@@ -12,6 +12,7 @@
 # <https://github.com/pytorch/pytorch/blob/master/torch/optim/sgd.py>
 #
 # ------------------------------------------------------------
+"""SGD optimizers."""
 from __future__ import absolute_import
 from __future__ import division
......
@@ -8,6 +8,7 @@
 # <https://opensource.org/licenses/BSD-2-Clause>
 #
 # ------------------------------------------------------------
+"""DLPack utilities."""
 from __future__ import absolute_import
 from __future__ import division
......
@@ -12,6 +12,7 @@
 # <https://github.com/pytorch/pytorch/blob/master/torch/utils/hooks.py>
 #
 # ------------------------------------------------------------
+"""Hook utilities."""
 from __future__ import absolute_import
 from __future__ import division
......