Commit 58708021 by Ting PAN

Fix the stream issue with NCCL2 on CUDA 9.2 and later

Summary:
This commit enforces stream synchronization before dispatching NCCL collectives.
Otherwise, data corruption can occur because the default value of ``NCCL_GROUP_CUDA_STREAM``
changed to 0 as of CUDA 9.2, i.e., there is no longer an explicit event wait for unfinished kernels.
1 parent 58c5371e
Showing with 571 additions and 591 deletions
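
For orientation, the pattern the commit enforces looks roughly like the sketch below. This is an illustration only, not the commit's code: the buffer, communicator, and stream arguments are placeholder names, and the commit itself synchronizes through the framework's own context rather than a raw `cudaStreamSynchronize`.

```cpp
// Illustrative sketch only: finish outstanding work on the compute stream
// before issuing grouped NCCL collectives, since with CUDA 9.2+ the default
// NCCL_GROUP_CUDA_STREAM=0 no longer waits on previously enqueued kernels.
#include <cuda_runtime.h>
#include <nccl.h>
#include <vector>

// Hypothetical helper: in-place all-reduce of several float buffers.
void AllReduceBuffers(const std::vector<float*>& buffers,
                      const std::vector<size_t>& counts,
                      ncclComm_t comm,
                      cudaStream_t compute_stream,
                      cudaStream_t nccl_stream) {
  // Block until the compute stream has drained its enqueued kernels,
  // so the collectives never read data that is still being produced.
  cudaStreamSynchronize(compute_stream);
  ncclGroupStart();
  for (size_t i = 0; i < buffers.size(); ++i) {
    ncclAllReduce(buffers[i], buffers[i], counts[i],
                  ncclFloat, ncclSum, comm, nccl_stream);
  }
  ncclGroupEnd();
}
```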
......@@ -118,7 +118,7 @@ class BatchNorm(Layer):
'use_stats': int(param.use_global_stats)
if param.HasField('use_global_stats') else -1,
'momentum': param.moving_average_fraction,
'eps': param.eps,
'epsilon': param.eps,
'axis': 1,
}
self.add_blob(value=0, no_grad=True) # running_mean
......@@ -398,7 +398,7 @@ class Normalize(Layer):
self.l2norm_arguments = {
'axis': 1,
'num_axes': -1 if param.across_spatial else 1,
'eps': param.eps,
'epsilon': param.eps,
}
self.affine_arguments = {
'axis': 1,
......
......@@ -510,12 +510,22 @@ message ConcatParameter {
}
message BatchNormParameter {
// If false, accumulate global mean/variance values via a moving average. If
// true, use those accumulated values instead of computing mean/variance
// across the batch.
// If false, normalization is performed over the current mini-batch
// and global statistics are accumulated (but not yet used) by a moving
// average.
// If true, those accumulated mean and variance values are used for the
// normalization.
// By default, it is set to false when the network is in the training
// phase and true when the network is in the testing phase.
optional bool use_global_stats = 1;
// How much does the moving average decay each iteration?
optional float moving_average_fraction = 2 [default = 0.9];
// What fraction of the moving average remains each iteration?
// Smaller values make the moving average decay faster, giving more
// weight to the recent values.
// Each iteration updates the moving average @f$S_{t-1}@f$ with the
// current mean @f$ Y_t @f$ by
// @f$ S_t = (1-\beta)Y_t + \beta \cdot S_{t-1} @f$, where @f$ \beta @f$
// is the moving_average_fraction parameter.
optional float moving_average_fraction = 2 [default = .999];
// Small value to add to the variance estimate so that we don't divide by
// zero.
optional float eps = 3 [default = 1e-5];
......
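
As a concrete illustration of the moving-average comment in the hunk above (a standalone sketch, not repository code): with ``moving_average_fraction`` close to 1 the running statistic changes slowly, while smaller values give more weight to recent batch statistics.

```cpp
// Sketch: running-statistic update S_t = (1 - beta) * Y_t + beta * S_{t-1},
// where beta is the moving_average_fraction from BatchNormParameter.
#include <cstdio>

int main() {
  const float beta = 0.999f;                          // moving_average_fraction
  float running_mean = 0.0f;                          // S_{t-1}, accumulated statistic
  const float batch_means[] = {0.50f, 0.52f, 0.48f};  // Y_t from three iterations
  for (float y : batch_means) {
    running_mean = (1.0f - beta) * y + beta * running_mean;
    std::printf("running_mean = %.6f\n", running_mean);
  }
  return 0;
}
```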
......@@ -22,11 +22,14 @@ from dragon.vm.dali.core.ops.generic_ops import Cast
from dragon.vm.dali.core.ops.generic_ops import Pad
from dragon.vm.dali.core.ops.generic_ops import Reshape
from dragon.vm.dali.core.ops.generic_ops import Slice
from dragon.vm.dali.core.ops.image_ops import Brightness
from dragon.vm.dali.core.ops.image_ops import BrightnessContrast
from dragon.vm.dali.core.ops.image_ops import Contrast
from dragon.vm.dali.core.ops.image_ops import CropMirrorNormalize
from dragon.vm.dali.core.ops.image_ops import Hsv
from dragon.vm.dali.core.ops.image_ops import Paste
from dragon.vm.dali.core.ops.image_ops import RandomBBoxCrop
from dragon.vm.dali.core.ops.image_ops import RandomResizedCrop
from dragon.vm.dali.core.ops.image_ops import Resize
from dragon.vm.dali.core.ops.random_ops import CoinFlip
from dragon.vm.dali.core.ops.random_ops import Uniform
......
......@@ -54,7 +54,7 @@ try:
batch_size=batch_size,
num_threads=num_threads,
device_id=device_id,
seed=seed + device_id,
seed=seed,
prefetch_queue_depth=prefetch_queue_depth,
)
......
......@@ -24,6 +24,33 @@ from dragon.vm.dali.core.framework import context
from dragon.vm.dali.core.framework import types
class Brightness(object):
"""Adjust the brightness of image.
Examples:
```python
# Historical jitter range for brightness
twist_rng = dali.ops.Uniform(range=[0.6, 1.4])
brightness = dali.ops.Brightness()
y = brightness(inputs['x'], brightness=twist_rng())
```
"""
def __new__(cls, **kwargs):
"""Create a ``Brightness`` operator.
Returns
-------
nvidia.dali.ops.Brightness
The operator.
"""
return ops.Brightness(device=context.get_device_type(), **kwargs)
class BrightnessContrast(object):
"""Adjust the brightness and contrast of image.
......@@ -40,7 +67,7 @@ class BrightnessContrast(object):
"""
def __new__(cls, **kwargs):
"""Create a ``BrightnessContrastBrightnessContrast`` operator.
"""Create a ``BrightnessContrast`` operator.
Returns
-------
......@@ -51,6 +78,33 @@ class BrightnessContrast(object):
return ops.BrightnessContrast(device=context.get_device_type(), **kwargs)
class Contrast(object):
"""Adjust the contrast of image.
Examples:
```python
# Historical jitter range for contrast
twist_rng = dali.ops.Uniform(range=[0.6, 1.4])
contrast = dali.ops.Contrast()
y = contrast(inputs['x'], contrast=twist_rng())
```
"""
def __new__(cls, **kwargs):
"""Create a ``Contrast`` operator.
Returns
-------
nvidia.dali.ops.Contrast
The operator.
"""
return ops.Contrast(device=context.get_device_type(), **kwargs)
class CropMirrorNormalize(object):
"""Crop and normalize image with the horizontal flip.
......@@ -285,6 +339,66 @@ class RandomBBoxCrop(object):
)
class RandomResizedCrop(object):
"""Return a resized random crop of image.
Examples:
```python
resize = dali.ops.RandomResizedCrop(
size=(224, 224),
# Inception sampling policy for image classification
random_area=(0.08, 1.00),
random_aspect_ratio=(0.75, 1.33),
)
y = resize(inputs['x'])
```
"""
def __new__(
cls,
size,
interp_type='LINEAR',
random_area=(0.08, 1.),
random_aspect_ratio=(0.75, 1.33),
num_attempts=10,
**kwargs
):
"""Create a ``ImageDecoderRandomCrop`` operator.
Parameters
----------
size : Union[int, Sequence[int]]
The output image size.
interp_type : {'NN', 'LINEAR', 'TRIANGULAR', 'CUBIC', 'GAUSSIAN', 'LANCZOS3'}, optional
The interpolation method.
random_area : Sequence[float], optional, default=(0.08, 1.)
The range of scale for sampling.
random_aspect_ratio : Sequence[float], optional, default=(0.75, 1.33)
The range of aspect ratio for sampling.
num_attempts : int, optional, default=10
The max number of sampling trials.
Returns
-------
nvidia.dali.ops.ImageDecoderRandomCrop
The operator.
"""
if isinstance(interp_type, six.string_types):
interp_type = getattr(types, 'INTERP_' + interp_type.upper())
return ops.RandomResizedCrop(
size=size,
interp_type=interp_type,
random_area=random_area,
random_aspect_ratio=random_aspect_ratio,
num_attempts=num_attempts,
device=context.get_device_type(),
**kwargs
)
class Resize(object):
"""Resize the image.
......@@ -310,7 +424,8 @@ class Resize(object):
resize_shorter=None,
resize_longer=None,
max_size=None,
interp_type='TRIANGULAR',
interp_type='LINEAR',
**kwargs
):
"""Create a ``Resize`` operator.
......@@ -340,4 +455,5 @@ class Resize(object):
max_size=max_size,
interp_type=interp_type,
device=context.get_device_type(),
**kwargs
)
......@@ -30,7 +30,6 @@ def path_to(href, index=False):
# Basic
html_static_path = ['../_static']
exclude_patterns = ['../_build']
master_doc = 'index'
source_suffix = '.rst'
......@@ -78,7 +77,7 @@ html_theme_options = {
'breadcrumb_links': [
('Dragon', path_to('../..', 1)),
('API', path_to('../../versions', 1)),
('Dragon v{}'.format(version.replace('a0', '-a0')), path_to('../../api', 1)),
('Dragon v{}'.format(version.replace('a0', '')), path_to('../../api', 1)),
('C++', path_to('', 1)),
],
}
......
......@@ -26,7 +26,6 @@ def path_to(href, index=False):
# Basic
html_static_path = ['../_static']
exclude_patterns = ['../_build']
master_doc = 'index'
source_suffix = '.rst'
......@@ -72,7 +71,7 @@ html_theme_options = {
'breadcrumb_links': [
('Dragon', path_to('../..', 1)),
('API', path_to('../../versions', 1)),
('Dragon v{}'.format(version.replace('a0', '-a0')), path_to('../../api', 1)),
('Dragon v{}'.format(version.replace('a0', '')), path_to('../../api', 1)),
('Python', path_to('', 1)),
],
}
......
......@@ -12,6 +12,9 @@ vm.dali.ops
`class BBoxPaste <ops/BBoxPaste.html>`_
: Transform bounding boxes to match the ``Paste`` operator.
`class Brightness <ops/Brightness.html>`_
: Adjust the brightness of image.
`class BrightnessContrast <ops/BrightnessContrast.html>`_
: Adjust the brightness and contrast of image.
......@@ -21,6 +24,9 @@ vm.dali.ops
`class CoinFlip <ops/CoinFlip.html>`_
: Sample values from a bernoulli distribution.
`class Contrast <ops/Contrast.html>`_
: Adjust the contrast of image.
`class CropMirrorNormalize <ops/CropMirrorNormalize.html>`_
: Crop and normalize image with the horizontal flip.
......@@ -45,6 +51,9 @@ vm.dali.ops
`class RandomBBoxCrop <ops/RandomBBoxCrop.html>`_
: Return an valid image crop restricted by bounding boxes.
`class RandomResizedCrop <ops/RandomResizedCrop.html>`_
: Return a resized random crop of image.
`class Reshape <ops/Reshape.html>`_
: Change the dimensions of input.
......@@ -68,9 +77,11 @@ vm.dali.ops
ops/BbFlip
ops/BBoxPaste
ops/Brightness
ops/BrightnessContrast
ops/Cast
ops/CoinFlip
ops/Contrast
ops/CropMirrorNormalize
ops/ExternalSource
ops/Hsv
......@@ -79,6 +90,7 @@ vm.dali.ops
ops/Pad
ops/Paste
ops/RandomBBoxCrop
ops/RandomResizedCrop
ops/Reshape
ops/Resize
ops/Slice
......
Brightness
==========
.. autoclass:: dragon.vm.dali.ops.Brightness
__new__
--------
.. automethod:: dragon.vm.dali.ops.Brightness.__new__
.. raw:: html
<style>
h1:before {
content: "dali.ops.";
color: #103d3e;
}
</style>
Contrast
========
.. autoclass:: dragon.vm.dali.ops.Contrast
__new__
--------
.. automethod:: dragon.vm.dali.ops.Contrast.__new__
.. raw:: html
<style>
h1:before {
content: "dali.ops.";
color: #103d3e;
}
</style>
RandomResizedCrop
=================
.. autoclass:: dragon.vm.dali.ops.RandomResizedCrop
__new__
--------
.. automethod:: dragon.vm.dali.ops.RandomResizedCrop.__new__
.. raw:: html
<style>
h1:before {
content: "dali.ops.";
color: #103d3e;
}
</style>
......@@ -13,7 +13,7 @@ dragon.random
: Return a tensor initialized from the glorot uniform distribution.
`multinomial(...) <random/multinomial.html>`_
: Return a tensor with index sampled from the multinomial distribution.
: Return a tensor with index sampled from multinomial distribution.
`normal(...) <random/normal.html>`_
: Return a tensor initialized from the normal distribution.
......
......@@ -77,7 +77,7 @@ Name Supported Reference
`Log`_ |v| :func:`dragon.math.log`
`LogSoftmax`_ |v| :func:`dragon.nn.log_softmax`
`Loop`_
`LpNormalization`_ |v| :func:`dragon.math.l2_normalize`
`LpNormalization`_ |v| :func:`dragon.math.lp_normalize`
`LpPool`_
`MatMul`_ |v| :func:`dragon.math.matmul`
`MatMulInteger`_
......
......@@ -152,7 +152,7 @@ vm.torch
: Compute the element-wise multiplication.
`multinomial(...) <torch/multinomial.html>`_
: Return a tensor where each row sampled from the multinomial distribution.
: Return a tensor with index sampled from multinomial distribution.
`narrow(...) <torch/narrow.html>`_
: Return a new tensor that is a narrowed version of input tensor.
......
......@@ -14,8 +14,9 @@
#define DRAGON_CORE_CONTEXT_CUDA_H_
#include "dragon/core/common.h"
#include "dragon/utils/cuda_device.h"
#include "dragon/utils/cudnn_device.h"
#include "dragon/utils/device/common_cuda.h"
#include "dragon/utils/device/common_cudnn.h"
#include "dragon/utils/device/common_nccl.h"
namespace dragon {
......@@ -130,7 +131,7 @@ class CUDAObjects {
/*! \brief Return the default cuda stream of current device */
cudaStream_t default_stream() {
return stream(CUDA_GET_DEVICE(), 0);
return stream(GetCUDADevice(), 0);
}
/*! \brief Return the default cuda stream of given device */
......@@ -311,7 +312,7 @@ class DRAGON_API CUDAContext {
/*! \brief Return the device index of current thread */
static int current_device() {
return CUDA_GET_DEVICE();
return GetCUDADevice();
}
/*! \brief Return the shared context mutex */
......
#include "dragon/core/memory.h"
#include "dragon/utils/cuda_device.h"
#include "dragon/utils/device/common_cuda.h"
namespace dragon {
......
......@@ -266,6 +266,7 @@ DEFINE_REGISTRY(
}
INSTANTIATE_GET_SINGLE_ARGUMENT(float, f)
INSTANTIATE_GET_SINGLE_ARGUMENT(double, f)
INSTANTIATE_GET_SINGLE_ARGUMENT(int, i)
INSTANTIATE_GET_SINGLE_ARGUMENT(bool, i)
INSTANTIATE_GET_SINGLE_ARGUMENT(int64_t, i)
......
......@@ -320,7 +320,7 @@ class DRAGON_API Tensor {
/*! \brief Try to return the raw const data pointer */
template <class Context>
const void* const_data_ptr() const {
TypeId context_type = TypeMeta::Id<Context>();
auto context_type = TypeMeta::Id<Context>();
if (context_type == TypeMeta::Id<CPUContext>()) {
return memory(true)->cpu_data(nbytes());
} else if (context_type == TypeMeta::Id<CUDAContext>()) {
......@@ -340,7 +340,7 @@ class DRAGON_API Tensor {
if (!memory_ptr) {
*data_ptr = nullptr;
} else {
TypeId context_type = TypeMeta::Id<Context>();
auto context_type = TypeMeta::Id<Context>();
if (context_type == TypeMeta::Id<CPUContext>()) {
*data_ptr = memory_ptr->mutable_cpu_data(nbytes());
} else if (context_type == TypeMeta::Id<CUDAContext>()) {
......
......@@ -9,7 +9,6 @@ add_subdirectory(vision)
# ---[ Extended sources
if (NOT BUILD_RUNTIME)
add_subdirectory(framework)
add_subdirectory(loss)
add_subdirectory(training)
endif()
......
#ifdef USE_CUDA
#include "dragon/core/context_cuda.h"
#include "dragon/utils/cub_device.h"
#include "dragon/utils/device/common_cub.h"
#include "dragon/utils/op_kernels.h"
namespace dragon {
......
......@@ -2,7 +2,7 @@
#include "dragon/core/context_cuda.h"
#include "dragon/utils/cast.h"
#include "dragon/utils/cub_device.h"
#include "dragon/utils/device/common_cub.h"
#include "dragon/utils/math_functions.h"
#include "dragon/utils/op_kernels.h"
......
#ifdef USE_CUDA
#include "dragon/core/context_cuda.h"
#include "dragon/utils/cub_device.h"
#include "dragon/utils/device/common_cub.h"
#include "dragon/utils/math_functions.h"
#include "dragon/utils/op_kernels.h"
......
#include "dragon/utils/eigen_utils.h"
#include "dragon/utils/op_kernels.h"
namespace dragon {
namespace kernel {
namespace {
template <typename T>
void _GradientTwoSum(const int count, const T* dy1, const T* dy2, T* dx) {
EigenVectorArrayMap<T>(dx, count) +=
(ConstEigenVectorArrayMap<T>(dy1, count) +
ConstEigenVectorArrayMap<T>(dy2, count));
}
} // namespace
/* ------------------- Launcher Separator ------------------- */
template <>
void GradientTwoSum<float16, CPUContext>(
const int count,
const float16* dy1,
const float16* dy2,
float16* dx,
CPUContext* ctx) {
CPU_FP16_NOT_SUPPORTED;
} // TwoSumGrad
#define DEFINE_GRAD_KERNEL_LAUNCHER(T) \
template <> \
void GradientTwoSum<T, CPUContext>( \
const int count, const T* dy1, const T* dy2, T* dx, CPUContext* ctx) { \
_GradientTwoSum(count, dy1, dy2, dx); \
}
DEFINE_GRAD_KERNEL_LAUNCHER(float);
DEFINE_GRAD_KERNEL_LAUNCHER(double);
#undef DEFINE_KERNEL_LAUNCHER
} // namespace kernel
} // namespace dragon
#ifdef USE_CUDA
#include "dragon/core/context_cuda.h"
#include "dragon/utils/op_kernels.h"
namespace dragon {
namespace kernel {
namespace {
template <typename T>
__global__ void
_GradientTwoSum(const int nthreads, const T* dy1, const T* dy2, T* dx) {
CUDA_1D_KERNEL_LOOP(i, nthreads) {
dx[i] += (dy1[i] + dy2[i]);
}
}
template <>
__global__ void _GradientTwoSum<half>(
const int nthreads,
const half* dy1,
const half* dy2,
half* dx) {
CUDA_1D_KERNEL_LOOP(i, nthreads) {
#if __CUDA_ARCH__ >= 530
dx[i] = __hadd(dx[i], __hadd(dy1[i], dy2[i]));
#endif
}
}
template <>
__global__ void _GradientTwoSum<half2>(
const int nthreads,
const half2* dy1,
const half2* dy2,
half2* dx) {
CUDA_1D_KERNEL_LOOP(i, nthreads) {
#if __CUDA_ARCH__ >= 530
dx[i] = __hadd2(dx[i], __hadd2(dy1[i], dy2[i]));
#endif
}
}
} // namespace
/* ------------------- Launcher Separator ------------------- */
template <>
void GradientTwoSum<float16, CUDAContext>(
const int count,
const float16* dy1,
const float16* dy2,
float16* dx,
CUDAContext* ctx) {
if ((count & 1) == 0) {
_GradientTwoSum<<<
CUDA_BLOCKS(count >> 2),
CUDA_THREADS,
0,
ctx->cuda_stream()>>>(
count >> 2,
reinterpret_cast<const half2*>(dy1),
reinterpret_cast<const half2*>(dy2),
reinterpret_cast<half2*>(dx));
} else {
_GradientTwoSum<<<
CUDA_BLOCKS(count),
CUDA_THREADS,
0,
ctx->cuda_stream()>>>(
count,
reinterpret_cast<const half*>(dy1),
reinterpret_cast<const half*>(dy2),
reinterpret_cast<half*>(dx));
}
} // TwoSumGrad
#define DEFINE_GRAD_KERNEL_LAUNCHER(T) \
template <> \
void GradientTwoSum<T, CUDAContext>( \
const int count, const T* dy1, const T* dy2, T* dx, CUDAContext* ctx) { \
_GradientTwoSum<<< \
CUDA_BLOCKS(count), \
CUDA_THREADS, \
0, \
ctx->cuda_stream()>>>(count, dy1, dy2, dx); \
}
DEFINE_GRAD_KERNEL_LAUNCHER(float);
DEFINE_GRAD_KERNEL_LAUNCHER(double);
#undef DEFINE_GRAD_KERNEL_LAUNCHER
} // namespace kernel
} // namespace dragon
#endif // USE_CUDA
......@@ -2,7 +2,7 @@
#include "dragon/core/context_cuda.h"
#include "dragon/utils/cast.h"
#include "dragon/utils/cub_device.h"
#include "dragon/utils/device/common_cub.h"
#include "dragon/utils/math_functions.h"
#include "dragon/utils/op_kernels.h"
......
#ifdef USE_CUDA
#include "dragon/core/memory.h"
#include "dragon/utils/cub_device.h"
#include "dragon/utils/device/common_cub.h"
#include "dragon/utils/math_functions.h"
#include "dragon/utils/op_kernels.h"
......
#ifdef USE_CUDA
#include "dragon/core/memory.h"
#include "dragon/utils/cub_device.h"
#include "dragon/utils/device/common_cub.h"
#include "dragon/utils/math_functions.h"
#include "dragon/utils/op_kernels.h"
......
#ifdef USE_CUDA
#include "dragon/core/context_cuda.h"
#include "dragon/utils/cub_device.h"
#include "dragon/utils/device/common_cub.h"
#include "dragon/utils/op_kernels.h"
namespace dragon {
......
#ifdef USE_CUDA
#include "dragon/core/context_cuda.h"
#include "dragon/utils/cub_device.h"
#include "dragon/utils/device/common_cub.h"
#include "dragon/utils/op_kernels.h"
namespace dragon {
......
......@@ -21,16 +21,16 @@ namespace python {
namespace cuda {
class CudaStream {
class CUDAStream {
public:
explicit CudaStream(int device_id) : device_id_(device_id) {
explicit CUDAStream(int device_id) : device_id_(device_id) {
#ifdef USE_CUDA
CUDADeviceGuard guard(device_id);
CUDA_CHECK(cudaStreamCreateWithFlags(&stream_, cudaStreamNonBlocking));
#endif
}
~CudaStream() {
~CUDAStream() {
#ifdef USE_CUDA
CUDA_CHECK(cudaStreamDestroy(stream_));
#endif
......@@ -132,19 +132,19 @@ void RegisterModule(py::module& m) {
#endif
});
/*! \brief Export the Stream class */
py::class_<CudaStream>(m, "CudaStream")
/*! \brief Export the stream class */
py::class_<CUDAStream>(m, "CUDAStream")
/*! \brief Default constructor */
.def(py::init<int>())
/*! \brief Return the device index */
.def_property_readonly("device_id", &CudaStream::device_id)
.def_property_readonly("device_id", &CUDAStream::device_id)
/*! \brief Return the stream pointer */
.def_property_readonly("ptr", &CudaStream::ptr)
.def_property_readonly("ptr", &CUDAStream::ptr)
/*! \brief Synchronize the stream */
.def("Synchronize", &CudaStream::Synchronize);
.def("Synchronize", &CUDAStream::Synchronize);
}
} // namespace cuda
......
......@@ -184,7 +184,6 @@ ONNXBackend::get_special_nodes() const {
const Map<string, Map<string, string>>& ONNXBackend::get_node_renamed_attrs()
const {
const static Map<string, Map<string, string>> kPerNodeRenamedAttrs = {
{"BatchNormalization", {{"epsilon", "eps"}}},
{"DepthToSpace", {{"blocksize", "block_size"}}},
{"Gemm", {{"transB", "transW"}}},
{"RoiAlign",
......
......@@ -2,6 +2,7 @@
add_subdirectory(activation)
add_subdirectory(array)
add_subdirectory(control_flow)
add_subdirectory(generic)
add_subdirectory(math)
add_subdirectory(normalization)
add_subdirectory(recurrent)
......@@ -10,7 +11,6 @@ add_subdirectory(vision)
# ---[ Extended sources
if (NOT BUILD_RUNTIME)
add_subdirectory(distributed)
add_subdirectory(framework)
add_subdirectory(loss)
add_subdirectory(metric)
add_subdirectory(training)
......
......@@ -31,11 +31,11 @@ void MultinomialOp<Context>::DoRunWithType() {
}
auto* rng = ctx()->rand_generator();
std::uniform_real_distribution<float> eps_dist;
std::uniform_real_distribution<double> epsilon_dist;
for (int i = 0; i < X.count(0, axis); ++i) {
running_total = 0.;
if (eps_ > 0.f && eps_dist(*rng) < eps_) {
if (epsilon_ > 0. && epsilon_dist(*rng) < epsilon_) {
for (int j = 0; j < num_classes; ++j) {
running_total += uniform_p;
cdf[j] = running_total;
......
......@@ -22,7 +22,7 @@ class MultinomialOp final : public Operator<Context> {
public:
MultinomialOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
eps_(OpArg<float>("eps", 0.f)),
epsilon_(OpArg<double>("epsilon", 0.)),
normalize_(OpArg<int64_t>("normalize", 0)),
num_samples_(OpArg<int64_t>("num_samples", 1)) {}
USE_OPERATOR_FUNCTIONS;
......@@ -33,7 +33,7 @@ class MultinomialOp final : public Operator<Context> {
void DoRunWithType();
protected:
float eps_;
double epsilon_;
int64_t normalize_, num_samples_;
};
......
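
The ``eps`` to ``epsilon`` rename above keeps the operator's epsilon-greedy behavior: with probability ``epsilon`` the sample is drawn uniformly over the classes, otherwise from the input distribution. A minimal standalone sketch of that selection logic follows; the function name is hypothetical and it uses the standard-library RNG rather than the operator's generator.

```cpp
// Sketch: epsilon-greedy categorical sampling. With probability `epsilon`,
// draw uniformly over the classes; otherwise draw from the given distribution.
#include <random>
#include <vector>

int SampleEpsilonGreedy(const std::vector<double>& probs,
                        double epsilon,
                        std::mt19937& rng) {
  std::uniform_real_distribution<double> coin(0.0, 1.0);
  if (epsilon > 0.0 && coin(rng) < epsilon) {
    std::uniform_int_distribution<int> uniform(0, static_cast<int>(probs.size()) - 1);
    return uniform(rng);  // uniform fallback over all classes
  }
  std::discrete_distribution<int> categorical(probs.begin(), probs.end());
  return categorical(rng);  // sample proportionally to probs
}
```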
......@@ -8,101 +8,84 @@ namespace dragon {
template <class Context>
template <typename T>
void CollectiveOp<Context>::AllReduceMPI(Tensor* tensor) {
MPI_Request recv_req;
int64_t count = tensor->count();
void CollectiveOp<Context>::AllReduceMPI() {
int64_t count = src_tensor_->count();
int64_t seg_size = count / comm_size_;
int64_t residual = count % comm_size_;
vec64_t sizes(comm_size_, seg_size);
for (int i = 0; i < residual; i++)
for (int i = 0; i < residual; ++i) {
sizes[i]++;
}
vec64_t ends(comm_size_);
ends[0] = sizes[0];
for (int i = 1; i < ends.size(); i++)
vec64_t ends(comm_size_, sizes[0]);
for (int i = 1; i < ends.size(); ++i) {
ends[i] = sizes[i] + ends[i - 1];
}
auto to = (comm_rank_ + 1) % comm_size_;
auto from = (comm_rank_ - 1 + comm_size_) % comm_size_;
auto* x = tensor->template mutable_data<T, Context>();
auto* data = src_tensor_->template mutable_data<T, Context>();
auto* scratch = ws()->template data<T, Context>({sizes[0]})[0];
// Scatter-Reduce
MPI_Request recv_req;
for (int i = 0; i < comm_size_ - 1; i++) {
auto base_id = comm_rank_ - i + comm_size_;
auto recv_i = (base_id - 1) % comm_size_;
auto send_i = base_id % comm_size_;
auto* send = &(x[ends[send_i] - sizes[send_i]]);
auto* update = &(x[ends[recv_i] - sizes[recv_i]]);
IRecv(scratch, sizes[recv_i], from, &recv_req);
Send(send, sizes[send_i], to);
auto base_idx = comm_rank_ - i + comm_size_;
auto recv_idx = (base_idx - 1) % comm_size_;
auto send_idx = base_idx % comm_size_;
auto* send_buf = &(data[ends[send_idx] - sizes[send_idx]]);
auto* recv_buf = &(data[ends[recv_idx] - sizes[recv_idx]]);
IRecv(scratch, sizes[recv_idx], from, &recv_req);
Send(send_buf, sizes[send_idx], to);
MPI_Wait(&recv_req, MPI_STATUS_IGNORE);
math::Axpy(sizes[recv_i], 1.f, scratch, update, ctx());
math::Axpy(sizes[recv_idx], 1.f, scratch, recv_buf, ctx());
// Wait stream to finish the local reduce before next sending
ctx()->FinishDeviceComputation();
}
// Allgather
for (int i = 0; i < comm_size_ - 1; i++) {
auto base_id = comm_rank_ - i + comm_size_;
auto send_i = (base_id + 1) % comm_size_;
auto recv_i = base_id % comm_size_;
auto* send = &(x[ends[send_i] - sizes[send_i]]);
auto* recv = &(x[ends[recv_i] - sizes[recv_i]]);
SendRecv(send, sizes[send_i], to, recv, sizes[recv_i], from);
}
// Normalization
if (comm_size_ > 1 && operation_ == "MEAN") {
math::Scale(count, 1.f / comm_size_, x, x, ctx());
auto base_idx = comm_rank_ - i + comm_size_;
auto send_idx = (base_idx + 1) % comm_size_;
auto recv_idx = base_idx % comm_size_;
auto* send_buf = &(data[ends[send_idx] - sizes[send_idx]]);
auto* recv_buf = &(data[ends[recv_idx] - sizes[recv_idx]]);
SendRecv(send_buf, sizes[send_idx], to, recv_buf, sizes[recv_idx], from);
}
}
template <class Context>
template <typename T>
void CollectiveOp<Context>::AllReduceNCCL(Tensor* tensor) {
void CollectiveOp<Context>::AllReduceNCCL() {
#ifdef USE_NCCL
auto* x = tensor->template mutable_data<T, Context>();
auto* data = src_tensor_->template mutable_data<T, Context>();
NCCL_CHECK(ncclAllReduce(
(const void*)x,
(void*)x,
tensor->count(),
(const void*)data,
(void*)data,
src_tensor_->count(),
this->template nccl_dtype<T>(),
ncclSum,
this->nccl_comm(),
((CUDAContext*)ctx())->cuda_stream()));
if (comm_size_ > 1 && operation_ == "MEAN") {
math::Scale(tensor->count(), 1.f / comm_size_, x, x, ctx());
}
#endif // USE_NCCL
}
template <class Context>
template <typename T>
void CollectiveOp<Context>::AllReduceDispatcher(Tensor* tensor) {
if (enable_nccl_) {
AllReduceNCCL<T>(tensor);
} else {
AllReduceMPI<T>(tensor);
}
}
template <class Context>
template <typename T>
void CollectiveOp<Context>::BroadcastMPI(Tensor* tensor) {
auto* x = tensor->template mutable_data<T, Context>();
Broadcast(x, tensor->count());
void CollectiveOp<Context>::BroadcastMPI() {
auto* data = src_tensor_->template mutable_data<T, Context>();
Broadcast(data, src_tensor_->count());
}
template <class Context>
template <typename T>
void CollectiveOp<Context>::BroadcastNCCL(Tensor* tensor) {
void CollectiveOp<Context>::BroadcastNCCL() {
#ifdef USE_NCCL
auto* x = tensor->template mutable_data<T, Context>();
NCCL_CHECK(ncclBcast(
(void*)x,
tensor->count(),
(void*)src_tensor_->template mutable_data<T, Context>(),
src_tensor_->count(),
this->template nccl_dtype<T>(),
comm_root_,
this->nccl_comm(),
......@@ -112,79 +95,63 @@ void CollectiveOp<Context>::BroadcastNCCL(Tensor* tensor) {
template <class Context>
template <typename T>
void CollectiveOp<Context>::BroadcastDispatcher(Tensor* tensor) {
if (enable_nccl_) {
BroadcastNCCL<T>(tensor);
} else {
BroadcastMPI<T>(tensor);
}
}
template <class Context>
void CollectiveOp<Context>::RunOnDevice() {
if (communication_ == "ALLREDUCE") {
for (int i = 0; i < InputSize(); i++) {
auto& X = Input(i);
if (XIsType(X, int8_t)) {
AllReduceDispatcher<int8_t>(&Input(i));
} else if (XIsType(X, uint8_t)) {
AllReduceDispatcher<uint8_t>(&Input(i));
} else if (XIsType(X, int)) {
AllReduceDispatcher<int>(&Input(i));
} else if (XIsType(X, int64_t)) {
AllReduceDispatcher<int64_t>(&Input(i));
} else if (XIsType(X, float16)) {
AllReduceDispatcher<float16>(&Input(i));
} else if (XIsType(X, float)) {
AllReduceDispatcher<float>(&Input(i));
} else if (XIsType(X, double)) {
AllReduceDispatcher<double>(&Input(i));
void CollectiveOp<Context>::DoRunWithType() {
if (src_tensor_ != nullptr) {
// Dispatch collective communication
if (communication_ == "ALLREDUCE") {
if (enable_nccl_) {
AllReduceNCCL<T>();
} else {
LOG(FATAL) << MessageForUnsupported(
types::to_string(X.meta()),
{"int8",
"uint8",
"int32",
"int64",
"float16",
"float32",
"float64"});
AllReduceMPI<T>();
}
}
} else if (communication_ == "BROADCAST") {
for (int i = 0; i < InputSize(); i++) {
auto& X = Input(i);
if (XIsType(X, bool)) {
BroadcastDispatcher<bool>(&Input(i));
} else if (XIsType(X, int8_t)) {
BroadcastDispatcher<int8_t>(&Input(i));
} else if (XIsType(X, uint8_t)) {
BroadcastDispatcher<uint8_t>(&Input(i));
} else if (XIsType(X, int)) {
BroadcastDispatcher<int>(&Input(i));
} else if (XIsType(X, int64_t)) {
BroadcastDispatcher<int64_t>(&Input(i));
} else if (XIsType(X, float16)) {
BroadcastDispatcher<float16>(&Input(i));
} else if (XIsType(X, float)) {
BroadcastDispatcher<float>(&Input(i));
} else if (XIsType(X, double)) {
BroadcastDispatcher<double>(&Input(i));
} else if (communication_ == "BROADCAST") {
if (enable_nccl_) {
BroadcastNCCL<T>();
} else {
LOG(FATAL) << MessageForUnsupported(
types::to_string(X.meta()),
{"bool",
"int8",
"uint8",
"int32",
"int64",
"float16",
"float32",
"float64"});
BroadcastMPI<T>();
}
} else {
LOG(FATAL) << "Unknown communication: " << communication_;
}
} else {
LOG(FATAL) << "Unknown communication: " << communication_;
// Dispatch other computation
if (communication_ == "ALLREDUCE" && operation_ == "MEAN") {
auto* data = dest_tensor_->template mutable_data<T, Context>();
math::Scale(dest_tensor_->count(), 1.f / comm_size_, data, data, ctx());
}
}
}
template <class Context>
void CollectiveOp<Context>::RunOnDevice() {
if (comm_size_ <= 1) return;
// Wait stream to finish the enqueued kernels.
// Otherwise, data corruption will happen through GPUDirect(UVA)
// during executing collectives asynchronously.
ctx()->FinishDeviceComputation();
#ifdef USE_NCCL
#if NCCL_VERSION_MIN(2, 2, 0)
if (enable_nccl_ && InputSize() <= 2048) {
this->nccl_comm(); // Ensure the comm created
NCCL_CHECK(ncclGroupStart());
}
#endif
#endif
for (int i = 0; i < InputSize(); i++) {
src_tensor_ = &Input(i);
DispatchHelper<MathTensorTypes>::Call(this, *src_tensor_);
}
#ifdef USE_NCCL
#if NCCL_VERSION_MIN(2, 2, 0)
if (enable_nccl_ && InputSize() <= 2048) {
NCCL_CHECK(ncclGroupEnd());
}
#endif
#endif
src_tensor_ = nullptr;
for (int i = 0; i < InputSize(); i++) {
dest_tensor_ = &Input(i);
DispatchHelper<MathTensorTypes>::Call(this, *dest_tensor_);
}
}
......
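
The refactored ``AllReduceMPI`` above splits the tensor into ``comm_size`` contiguous segments, gives one extra element to the first ``count % comm_size`` ranks, and tracks each segment's end offset as a running sum. A standalone sketch of that partitioning arithmetic, with illustrative names only:

```cpp
// Sketch: segment sizes and exclusive end offsets for a ring all-reduce,
// mirroring the sizes/ends computed in AllReduceMPI above.
#include <cstdint>
#include <utility>
#include <vector>

std::pair<std::vector<int64_t>, std::vector<int64_t>> PartitionSegments(
    int64_t count, int comm_size) {
  std::vector<int64_t> sizes(comm_size, count / comm_size);
  for (int64_t i = 0; i < count % comm_size; ++i) sizes[i]++;  // spread residual
  std::vector<int64_t> ends(comm_size, sizes[0]);
  for (int i = 1; i < comm_size; ++i) ends[i] = ends[i - 1] + sizes[i];
  return {sizes, ends};
}
```

Each rank then reduces into one segment per step during the scatter-reduce phase and forwards completed segments around the ring during the all-gather phase, as the loop structure in the hunk above shows.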
......@@ -32,26 +32,23 @@ class CollectiveOp final : public CollectiveOpBase<Context> {
void RunOnDevice() override;
template <typename T>
void AllReduceMPI(Tensor*);
void AllReduceMPI();
template <typename T>
void AllReduceNCCL(Tensor*);
void AllReduceNCCL();
template <typename T>
void AllReduceDispatcher(Tensor*);
void BroadcastMPI();
template <typename T>
void BroadcastMPI(Tensor*);
void BroadcastNCCL();
template <typename T>
void BroadcastNCCL(Tensor*);
template <typename T>
void BroadcastDispatcher(Tensor*);
void DoRunWithType();
protected:
string communication_;
string operation_;
string communication_, operation_;
Tensor *src_tensor_, *dest_tensor_;
};
} // namespace dragon
......
# ---[ General sources
file(GLOB INCLUDES *.h)
file(GLOB SOURCES *.cc)
set(MODULE_INCLUDES ${MODULE_INCLUDES} ${INCLUDES})
set(MODULE_SOURCES ${MODULE_SOURCES} ${SOURCES})
# ---[ CUDA sources
if (USE_CUDA)
file(GLOB CUDA_SOURCES *.cu)
set(KERNEL_CUDA_SOURCES ${KERNEL_CUDA_SOURCES} ${CUDA_SOURCES})
if (BUILD_RUNTIME)
# Remove gradient ops
list(REMOVE_ITEM MODULE_SOURCES ${CMAKE_CURRENT_SOURCE_DIR}/gradient_ops.cc)
endif()
# ---[ Submit to the parent scope
set(MODULE_INCLUDES ${MODULE_INCLUDES} PARENT_SCOPE)
set(MODULE_SOURCES ${MODULE_SOURCES} PARENT_SCOPE)
set(KERNEL_CUDA_SOURCES ${KERNEL_CUDA_SOURCES} PARENT_SCOPE)
#include "dragon/operators/framework/gradient_ops.h"
#include "dragon/operators/generic/gradient_ops.h"
#include "dragon/utils/math_functions.h"
#include "dragon/utils/op_kernels.h"
......@@ -8,57 +8,62 @@ template <class Context>
template <typename T>
void GradientGenerateOp<Context>::DoRunWithType() {
for (int i = 0; i < OutputSize(); i++) {
if (!Output(i)->has_name()) continue;
Output(i)->ReshapeLike(Input(i));
auto value = cast::to<T>(defaults[i]);
auto* y = Output(i)->template mutable_data<T, Context>();
math::Set(Output(i)->count(), value, y, ctx());
auto* Y = Output(i);
if (!Y->has_name()) continue;
Y->ReshapeLike(Input(i));
math::Set(
Y->count(),
cast::to<T>(defaults_[i]),
Y->template mutable_data<T, Context>(),
ctx());
}
}
template <class Context>
void GradientGenerateOp<Context>::RunOnDevice() {
CHECK_EQ(InputSize(), OutputSize());
CHECK_EQ(defaults_.size(), OutputSize());
DispatchHelper<FloatingTensorTypes>::Call(this, Input(0));
}
template <class Context>
template <typename T>
void GradientGatherOp<Context>::DoRunWithType() {
int64_t count = Output(0)->count();
auto* y = Output(0)->template mutable_data<T, Context>();
if (indices.size() == 1) {
auto* x = Input(indices[0]).template data<T, Context>();
math::Copy(count, x, y, ctx());
} else if (indices.size() == 2) {
CHECK_EQ(count, Input(indices[1]).count());
auto* a = Input(indices[0]).template data<T, Context>();
auto* b = Input(indices[1]).template data<T, Context>();
math::Add(count, a, b, y, ctx());
auto* Y = Output(0)->ReshapeLike(*grads_[0]);
if (grads_.size() == 1) {
math::Copy(
Y->count(),
grads_[0]->template data<T, Context>(),
Y->template mutable_data<T, Context>(),
ctx());
} else {
size_t i = 1;
auto* x = Input(indices[0]).template data<T, Context>();
math::Copy(count, x, y, ctx());
while (i < indices.size()) {
if (indices.size() - i >= 2) {
auto* a = Input(indices[i]).template data<T, Context>();
auto* b = Input(indices[i + 1]).template data<T, Context>();
kernel::GradientTwoSum(count, a, b, y, ctx());
i += 2;
} else {
x = Input(indices[i]).template data<T, Context>();
math::Add(count, y, x, y, ctx());
break;
}
CHECK_EQ(Y->count(), grads_[1]->count());
auto* y = Y->template mutable_data<T, Context>();
math::Add(
Y->count(),
grads_[0]->template data<T, Context>(),
grads_[1]->template data<T, Context>(),
y,
ctx());
for (int i = 2; i < grads_.size(); ++i) {
CHECK_EQ(Y->count(), grads_[i]->count());
math::Add(
Y->count(), y, grads_[i]->template data<T, Context>(), y, ctx());
}
}
}
template <class Context>
void GradientGatherOp<Context>::RunOnDevice() {
if (indices.size() == 0) return;
auto& Xi = Input(indices[0]);
Output(0)->ReshapeLike(Xi);
DispatchHelper<FloatingTensorTypes>::Call(this, Xi);
grads_.clear();
for (int i = 0; i < InputSize(); i++) {
auto* X = &Input(i);
if (X->has_name()) {
grads_.push_back(X);
}
}
if (grads_.empty() || !Output(0)->has_name()) return;
DispatchHelper<FloatingTensorTypes>::Call(this, *grads_[0]);
}
template <class Context>
......@@ -72,7 +77,7 @@ void GradientAddOp<Context>::DoRunWithType() {
template <class Context>
void GradientAddOp<Context>::RunOnDevice() {
CHECK_EQ(Input(0).name(), Output(0)->name())
<< "\nRequires Input(0) == Output(0).";
<< "\nExcepted Input(0) and Output(0) are the same tensor.";
DispatchHelper<FloatingTensorTypes>::Call(this, Input(0));
}
......@@ -117,11 +122,11 @@ OPERATOR_SCHEMA(GradientGather)
.NumOutputs(1);
OPERATOR_SCHEMA(GradientAdd)
/* X1, X2 */
/* A, B */
.NumInputs(2)
/* Y */
.NumOutputs(1)
/* X1 => Y */
/* A => Y */
.AllowInplace({{0, 0}});
OPERATOR_SCHEMA(StopGradient)
......
......@@ -10,8 +10,8 @@
* ------------------------------------------------------------
*/
#ifndef DRAGON_OPERATORS_FRAMEWORK_GRADIENT_OP_H_
#define DRAGON_OPERATORS_FRAMEWORK_GRADIENT_OP_H_
#ifndef DRAGON_OPERATORS_GENERIC_GRADIENT_OPS_H_
#define DRAGON_OPERATORS_GENERIC_GRADIENT_OPS_H_
#include "dragon/core/operator.h"
......@@ -21,10 +21,7 @@ template <class Context>
class GradientGenerateOp final : public Operator<Context> {
public:
GradientGenerateOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws), defaults(OpArgs<float>("defaults")) {
CHECK_EQ(InputSize(), OutputSize());
CHECK_EQ(defaults.size(), OutputSize());
}
: Operator<Context>(def, ws), defaults_(OpArgs<float>("defaults")) {}
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
......@@ -33,20 +30,13 @@ class GradientGenerateOp final : public Operator<Context> {
void DoRunWithType();
protected:
vector<float> defaults;
vector<float> defaults_;
};
template <class Context>
class GradientGatherOp final : public Operator<Context> {
public:
GradientGatherOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws) {
for (int i = 0; i < InputSize(); i++) {
if (Input(i).has_name()) {
indices.push_back(i);
}
}
}
SIMPLE_CTOR_DTOR(GradientGatherOp);
USE_OPERATOR_FUNCTIONS;
void RunOnDevice() override;
......@@ -55,7 +45,7 @@ class GradientGatherOp final : public Operator<Context> {
void DoRunWithType();
protected:
vec32_t indices;
vector<Tensor*> grads_;
};
template <class Context>
......@@ -81,4 +71,4 @@ class StopGradientOp final : public Operator<Context> {
} // namespace dragon
#endif // DRAGON_OPERATORS_FRAMEWORK_GRADIENT_OP_H_
#endif // DRAGON_OPERATORS_GENERIC_GRADIENT_OPS_H_
......@@ -49,7 +49,7 @@ void BatchNormOp<Context>::TrainingImpl() {
// Fuse parameters along channel axis
// [mu, rsig, alpha, beta] => [scale, bias]
math::InvStd(C_, eps_, rsig, rsig, ctx());
math::InvStd(C_, epsilon_, rsig, rsig, ctx());
math::Mul(C_, gamma, rsig, scale, ctx());
math::Mul(C_, scale, mu, bias, ctx());
math::Sub(C_, beta, bias, bias, ctx());
......@@ -84,7 +84,7 @@ void BatchNormOp<Context>::InferenceImpl() {
// Fuse parameters along channel axis
// [mu, rsig, alpha, beta] => [scale, bias]
math::InvStd(C_, eps_, rv, bias, ctx());
math::InvStd(C_, epsilon_, rv, bias, ctx());
math::Mul(C_, gamma, bias, scale, ctx());
math::Mul(C_, scale, rm, bias, ctx());
math::Sub(C_, beta, bias, bias, ctx());
......@@ -103,7 +103,7 @@ void BatchNormOp<Context>::RunOnDevice() {
// Get the recomputing flag
auto* flag = ws()->GetTensor("/share/flag/recomputing");
is_recomputing_ = flag->template data<bool, CPUContext>()[0];
is_recomputing_ = flag->template data<bool, CPUContext>()[0] ? 1 : 0;
// Dispatch the training or inference impl
Output(0)->ReshapeLike(Input(0));
......@@ -159,7 +159,7 @@ void BatchNormGradientOp<Context>::InferenceImpl() {
}
// Restore inverse stddev from variance
math::InvStd(C_, eps_, rv, rsig, ctx());
math::InvStd(C_, epsilon_, rv, rsig, ctx());
// Gradient w.r.t. gamma, beta and input
kernel::BatchNormBackwardInference(
......
......@@ -21,20 +21,20 @@
namespace dragon {
// Multiple inheritance is forbidden by the registry.
// So, we should inherit the collective base as the meta.
// So, we should inherit the collective op base if mpi available.
#ifdef USE_MPI
#define BatchNormOpBaseMeta CollectiveOpBase
#define GenericOpBase CollectiveOpBase
#else
#define BatchNormOpBaseMeta Operator
#define GenericOpBase Operator
#endif
template <class Context>
class BatchNormOpBase : public BatchNormOpBaseMeta<Context> {
class BatchNormOpBase : public GenericOpBase<Context> {
public:
BatchNormOpBase(const OperatorDef& def, Workspace* ws)
: BatchNormOpBaseMeta<Context>(def, ws),
: GenericOpBase<Context>(def, ws),
momentum_(OpArg<float>("momentum", 0.9f)),
eps_(OpArg<float>("eps", 1e-5f)),
epsilon_(OpArg<double>("epsilon", 1e-5)),
use_stats_(OpArg<int64_t>("use_stats", -1)) {}
USE_OPERATOR_FUNCTIONS;
......@@ -56,17 +56,18 @@ class BatchNormOpBase : public BatchNormOpBaseMeta<Context> {
}
protected:
float momentum_, eps_;
float momentum_;
double epsilon_;
int64_t use_stats_, N_, C_, S_;
int64_t is_training_, is_recomputing_;
};
#undef BatchNormOpBaseMeta
#undef GenericOpBase
#define USE_BATCHNORM_FUNCTIONS \
using BatchNormOpBase<Context>::DetermineBaseArguments; \
using BatchNormOpBase<Context>::momentum_; \
using BatchNormOpBase<Context>::eps_; \
using BatchNormOpBase<Context>::epsilon_; \
using BatchNormOpBase<Context>::use_stats_; \
using BatchNormOpBase<Context>::N_; \
using BatchNormOpBase<Context>::C_; \
......@@ -148,14 +149,15 @@ template <class Context>
class CuDNNBatchNormOp final : public BatchNormOpBase<Context> {
public:
CuDNNBatchNormOp(const OperatorDef& def, Workspace* ws)
: BatchNormOpBase<Context>(def, ws), eps64_(OpArg<float>("eps", 1e-5f)) {
: BatchNormOpBase<Context>(def, ws) {
CuDNNCreateTensorDesc(&bn_desc_);
CuDNNCreateTensorDesc(&input_desc_);
if (eps64_ <= CUDNN_BN_MIN_EPSILON - FLT_EPSILON)
LOG(FATAL) << "Provided epsilon is smaller than "
if (epsilon_ <= CUDNN_BN_MIN_EPSILON) {
LOG(ERROR) << "Provided epsilon is smaller than "
<< "CUDNN_BN_MIN_EPSILON. \nSet it to "
<< "CUDNN_BN_MIN_EPSILON instead.";
eps64_ = std::max(eps64_, CUDNN_BN_MIN_EPSILON);
epsilon_ = CUDNN_BN_MIN_EPSILON;
}
}
USE_OPERATOR_FUNCTIONS;
USE_BATCHNORM_FUNCTIONS;
......@@ -171,7 +173,6 @@ class CuDNNBatchNormOp final : public BatchNormOpBase<Context> {
void DoRunWithType();
protected:
double eps64_;
cudnnTensorDescriptor_t input_desc_, bn_desc_;
cudnnBatchNormMode_t bn_mode_;
};
......@@ -180,15 +181,15 @@ template <class Context>
class CuDNNBatchNormGradientOp final : public BatchNormGradientOp<Context> {
public:
CuDNNBatchNormGradientOp(const OperatorDef& def, Workspace* ws)
: BatchNormGradientOp<Context>(def, ws),
eps64_(OpArg<float>("eps", 1e-5f)) {
: BatchNormGradientOp<Context>(def, ws) {
CuDNNCreateTensorDesc(&bn_desc_);
CuDNNCreateTensorDesc(&input_desc_);
if (eps64_ <= CUDNN_BN_MIN_EPSILON - FLT_EPSILON)
LOG(FATAL) << "Provided epsilon is smaller than "
if (epsilon_ <= CUDNN_BN_MIN_EPSILON) {
LOG(ERROR) << "Provided epsilon is smaller than "
<< "CUDNN_BN_MIN_EPSILON. \nSet it to "
<< "CUDNN_BN_MIN_EPSILON instead.";
eps64_ = std::max(eps64_, CUDNN_BN_MIN_EPSILON);
epsilon_ = CUDNN_BN_MIN_EPSILON;
}
}
USE_OPERATOR_FUNCTIONS;
USE_BATCHNORM_FUNCTIONS;
......@@ -204,7 +205,6 @@ class CuDNNBatchNormGradientOp final : public BatchNormGradientOp<Context> {
void TrainingImpl();
protected:
double eps64_;
cudnnTensorDescriptor_t input_desc_, bn_desc_;
cudnnBatchNormMode_t bn_mode_;
};
......
......@@ -17,11 +17,6 @@ void CuDNNBatchNormOp<Context>::DoRunWithType() {
CuDNNSetTensorDesc<T>(&input_desc_, vec64_t({N_, C_, 1, 1}));
} else {
bn_mode_ = CUDNN_BATCHNORM_SPATIAL;
#if CUDNN_VERSION_MIN(7, 0, 0)
if (is_training_ > 0) {
bn_mode_ = CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
}
#endif
CuDNNSetTensorDesc<T>(&input_desc_, Input(0).dims(), data_format());
}
......@@ -48,10 +43,10 @@ void CuDNNBatchNormOp<Context>::DoRunWithType() {
bn_desc_,
Input(1).template data<ParamType, Context>(), // gamma
Input(2).template data<ParamType, Context>(), // beta
is_recomputing_ ? 0.f : 1.f - this->momentum_,
is_recomputing_ > 0 ? 0.f : 1.f - this->momentum_,
Input(3).template mutable_data<ParamType, Context>(), // rm
Input(4).template mutable_data<ParamType, Context>(), // rv
eps64_,
epsilon_,
X_mu->template mutable_data<ParamType, Context>(), // sm
X_rsig->template mutable_data<ParamType, Context>())); // sv
} else {
......@@ -69,7 +64,7 @@ void CuDNNBatchNormOp<Context>::DoRunWithType() {
Input(2).template data<ParamType, Context>(), // beta
Input(3).template data<ParamType, Context>(), // rm
Input(4).template data<ParamType, Context>(), // rv
eps64_));
epsilon_));
}
}
......@@ -79,7 +74,7 @@ void CuDNNBatchNormOp<Context>::RunOnDevice() {
// Get the recomputing flag
auto* flag = ws()->GetTensor("/share/flag/recomputing");
is_recomputing_ = flag->template data<bool, CPUContext>()[0];
is_recomputing_ = flag->template data<bool, CPUContext>()[0] ? 1 : 0;
// Dispatch the training or inference impl
Output(0)->ReshapeLike(Input(0));
......@@ -106,9 +101,6 @@ void CuDNNBatchNormGradientOp<Context>::TrainingImpl() {
CuDNNSetTensorDesc<T>(&input_desc_, vec64_t({N_, C_, 1, 1}));
} else {
bn_mode_ = CUDNN_BATCHNORM_SPATIAL;
#if CUDNN_VERSION_MIN(7, 0, 0)
bn_mode_ = CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
#endif
CuDNNSetTensorDesc<T>(&input_desc_, Input(0).dims(), data_format());
}
......@@ -133,7 +125,7 @@ void CuDNNBatchNormGradientOp<Context>::TrainingImpl() {
Input(1).template data<ParamType, Context>(), // gamma
dW->Reshape({C_})->template mutable_data<ParamType, Context>(), // dw
dB->Reshape({C_})->template mutable_data<ParamType, Context>(), // db
eps64_,
epsilon_,
X_mu->template data<ParamType, Context>(), // mu
X_rsig->template data<ParamType, Context>())); // rsig
}
......@@ -154,7 +146,6 @@ void CuDNNBatchNormGradientOp<Context>::RunOnDevice() {
if (is_training_ > 0) {
TrainingImpl<float16>();
} else {
// We will support it some day -:)
LOG(FATAL) << MessageForUnsupported(
types::to_string(Input(0).meta()), {"float32"});
}
......
......@@ -81,7 +81,7 @@ void SyncBatchNormOp<Context>::TrainingImpl() {
// Fuse parameters along channel axis
// [mu, rsig, alpha, beta] => [scale, bias]
math::InvStd(C_, eps_, rsig, rsig, ctx());
math::InvStd(C_, epsilon_, rsig, rsig, ctx());
math::Mul(C_, gamma, rsig, scale, ctx());
math::Mul(C_, scale, mu, bias, ctx());
math::Sub(C_, beta, bias, bias, ctx());
......@@ -100,7 +100,7 @@ void SyncBatchNormOp<Context>::RunOnDevice() {
// Get the recomputing flag
auto* flag = ws()->GetTensor("/share/flag/recomputing");
is_recomputing_ = flag->template data<bool, CPUContext>()[0];
is_recomputing_ = flag->template data<bool, CPUContext>()[0] ? 1 : 0;
// Dispatch the training or inference impl
Output(0)->ReshapeLike(Input(0));
......
......@@ -31,7 +31,7 @@ void GroupNormOp<Context>::DoRunWithType() {
kernel::Moments(4, dims.data(), 2, axes.data(), x, mu, rsig, ctx());
}
math::InvStd(N_ * G_, eps_, rsig, rsig, ctx());
math::InvStd(N_ * G_, epsilon_, rsig, rsig, ctx());
kernel::GroupNormForward(
N_,
......
......@@ -23,7 +23,7 @@ class GroupNormOpBase : public Operator<Context> {
GroupNormOpBase(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
group_(OpArg<int64_t>("group", 0)),
eps_(OpArg<float>("eps", 1e-5f)) {}
epsilon_(OpArg<double>("epsilon", 1e-5)) {}
USE_OPERATOR_FUNCTIONS;
void DetermineBaseArguments() {
......@@ -45,7 +45,7 @@ class GroupNormOpBase : public Operator<Context> {
}
protected:
float eps_;
double epsilon_;
int64_t group_;
int64_t N_, C_, G_, D_, S_;
};
......@@ -53,7 +53,7 @@ class GroupNormOpBase : public Operator<Context> {
#define USE_GROUPNORM_FUNCTIONS \
using GroupNormOpBase<Context>::DetermineBaseArguments; \
using GroupNormOpBase<Context>::group_; \
using GroupNormOpBase<Context>::eps_; \
using GroupNormOpBase<Context>::epsilon_; \
using GroupNormOpBase<Context>::N_; \
using GroupNormOpBase<Context>::C_; \
using GroupNormOpBase<Context>::G_; \
......
......@@ -31,7 +31,7 @@ void LpNormalizeOp<Context>::DoRunWithType() {
reduce_dim,
X.count(axis + num_axes),
reduction_ == "MEAN" ? 1.f / (float)reduce_dim : 1.f,
eps_,
epsilon_,
X.template data<T, Context>(),
Y->ReshapeLike(X)->template mutable_data<T, Context>(),
ctx());
......@@ -41,7 +41,7 @@ void LpNormalizeOp<Context>::DoRunWithType() {
reduce_dim,
X.count(axis + num_axes),
reduction_ == "MEAN" ? 1.f / (float)reduce_dim : 1.f,
eps_,
epsilon_,
X.template data<T, Context>(),
Y->ReshapeLike(X)->template mutable_data<T, Context>(),
ctx());
......@@ -68,7 +68,7 @@ void LpNormalizeGradientOp<Context>::DoRunWithType() {
reduce_dim,
X.count(axis + num_axes),
reduction_ == "MEAN" ? 1.f / (float)reduce_dim : 1.f,
eps_,
epsilon_,
dY.template data<T, Context>(),
X.template data<T, Context>(),
dX->ReshapeLike(X)->template mutable_data<T, Context>(),
......@@ -79,7 +79,7 @@ void LpNormalizeGradientOp<Context>::DoRunWithType() {
reduce_dim,
X.count(axis + num_axes),
reduction_ == "MEAN" ? 1.f / (float)reduce_dim : 1.f,
eps_,
epsilon_,
dY.template data<T, Context>(),
X.template data<T, Context>(),
dX->ReshapeLike(X)->template mutable_data<T, Context>(),
......
......@@ -23,7 +23,7 @@ class LpNormalizeOp final : public Operator<Context> {
LpNormalizeOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
p_(OpArg<int64_t>("p", 2)),
eps_(OpArg<float>("eps", 1e-12f)),
epsilon_(OpArg<double>("epsilon", 1e-12)),
reduction_(OpArg<string>("reduction", "SUM")) {}
USE_OPERATOR_FUNCTIONS;
......@@ -33,9 +33,9 @@ class LpNormalizeOp final : public Operator<Context> {
void DoRunWithType();
protected:
float eps_;
string reduction_;
int64_t p_;
double epsilon_;
string reduction_;
};
template <class Context>
......@@ -44,7 +44,7 @@ class LpNormalizeGradientOp final : public Operator<Context> {
LpNormalizeGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
p_(OpArg<int64_t>("p", 2)),
eps_(OpArg<float>("eps", 1e-12f)),
epsilon_(OpArg<double>("epsilon", 1e-12)),
reduction_(OpArg<string>("reduction", "SUM")) {}
USE_OPERATOR_FUNCTIONS;
......@@ -54,9 +54,9 @@ class LpNormalizeGradientOp final : public Operator<Context> {
void DoRunWithType();
protected:
float eps_;
string reduction_;
int64_t p_;
float epsilon_;
string reduction_;
};
} // namespace dragon
......
......@@ -19,7 +19,7 @@ from dragon.core.framework import config
from dragon.core.framework import workspace
class Stream(backend.CudaStream):
class Stream(backend.CUDAStream):
"""The CUDA stream wrapper."""
def __init__(self, device_index):
......
......@@ -850,11 +850,10 @@ def moments(inputs, axis=None, keep_dims=False, **kwargs):
@OpSchema.num_inputs(1)
def multinomial(inputs, num_samples=1, eps=0., normalize=False, **kwargs):
def multinomial(inputs, num_samples=1, epsilon=0, normalize=False, **kwargs):
"""Return a tensor with index sampled from multinomial distribution.
If ``normalize`` is **True**, negative input is accepted,
and will be normalized by a **Softmax** function.
If ``normalize`` is **True**, negative input is accepted.
Otherwise, input should be non-negative.
......@@ -864,8 +863,8 @@ def multinomial(inputs, num_samples=1, eps=0., normalize=False, **kwargs):
The input tensor.
num_samples : int, optional, default=1
The number of samples.
eps : float, optional, default=0.
The prob to a uniform sampling.
epsilon : float, optional, default=0
The epsilon value to apply e-greedy strategy.
normalize : bool, optional, default=False
Whether to normalize the input.
......@@ -876,13 +875,13 @@ def multinomial(inputs, num_samples=1, eps=0., normalize=False, **kwargs):
"""
args = parse_args(locals())
args['eps'] = float(eps)
args['epsilon'] = float(epsilon)
op_lib = array_ops_lib.Multinomial
if context.executing_eagerly():
return op_lib \
.instantiate(
num_samples=num_samples,
eps=args['eps'],
epsilon=args['epsilon'],
normalize=normalize,
).apply([inputs])
else:
......
......@@ -305,7 +305,7 @@ class Moments(Operator):
class Multinomial(Operator):
def __init__(self, key, dev, **kwargs):
super(Multinomial, self).__init__(key, dev, **kwargs)
self.eps = kwargs.get('eps', 0.)
self.epsilon = kwargs.get('epsilon', 0.)
self.normalize = kwargs.get('normalize', False)
self.num_samples = kwargs.get('num_samples', 1)
......@@ -313,7 +313,7 @@ class Multinomial(Operator):
return {
'op_type': 'Multinomial',
'arguments': {
'eps': self.eps,
'epsilon': self.epsilon,
'normalize': self.normalize,
'num_samples': self.num_samples,
}
......
......@@ -27,7 +27,7 @@ def batch_norm(
inputs,
axis=-1,
momentum=0.9,
eps=1e-5,
epsilon=1e-5,
use_stats=-1,
**kwargs
):
......@@ -58,7 +58,7 @@ def batch_norm(
The channel axis.
momentum : float, optional, default=0.9
The momentum for running average.
eps : float, optional, default=1e-5
epsilon : float, optional, default=1e-5
The value to :math:`\epsilon`.
use_stats : int, optional, default=-1
Whether to use estimated statistics or not.
......@@ -70,14 +70,14 @@ def batch_norm(
"""
args = parse_args(locals())
args['momentum'], args['eps'] = float(momentum), float(eps)
args['momentum'], args['epsilon'] = float(momentum), float(epsilon)
op_lib = normalization_ops_lib.BatchNorm
if context.executing_eagerly():
return op_lib \
.instantiate(
axis=axis,
momentum=args['momentum'],
eps=args['eps'],
epsilon=args['epsilon'],
use_stats=use_stats,
).apply(inputs)
else:
......@@ -85,7 +85,7 @@ def batch_norm(
@OpSchema.num_inputs(3)
def group_norm(inputs, axis=-1, group=32, eps=1e-5, **kwargs):
def group_norm(inputs, axis=-1, group=32, epsilon=1e-5, **kwargs):
r"""Apply the group normalization.
`[Wu & He, 2018] <https://arxiv.org/abs/1803.08494>`_.
......@@ -111,7 +111,7 @@ def group_norm(inputs, axis=-1, group=32, eps=1e-5, **kwargs):
The channel axis.
group : int, optional, default=32
The group size.
eps : float, optional, default=1e-5
epsilon : float, optional, default=1e-5
The value to :math:`\epsilon`.
Returns
......@@ -121,21 +121,21 @@ def group_norm(inputs, axis=-1, group=32, eps=1e-5, **kwargs):
"""
args = parse_args(locals())
args['eps'] = float(eps)
args['epsilon'] = float(epsilon)
op_lib = normalization_ops_lib.GroupNorm
if context.executing_eagerly():
return op_lib \
.instantiate(
axis=axis,
group=group,
eps=args['eps'],
epsilon=args['epsilon'],
).apply(inputs)
else:
return op_lib.blend(**args)
@OpSchema.num_inputs(3)
def instance_norm(inputs, axis=-1, eps=1e-5, **kwargs):
def instance_norm(inputs, axis=-1, epsilon=1e-5, **kwargs):
r"""Apply the instance normalization.
`[Ulyanov et.al, 2016] <https://arxiv.org/abs/1607.08022>`_
......@@ -156,7 +156,7 @@ def instance_norm(inputs, axis=-1, eps=1e-5, **kwargs):
The tensor ``x``, ``gamma`` and ``beta``.
axis : int, optional, default=-1
The channel axis.
eps : float, optional, default=1e-5
epsilon : float, optional, default=1e-5
The value to :math:`\epsilon`.
Returns
......@@ -165,11 +165,11 @@ def instance_norm(inputs, axis=-1, eps=1e-5, **kwargs):
The output tensor.
"""
return group_norm(inputs, axis=axis, group=0, eps=eps, **kwargs)
return group_norm(inputs, axis=axis, group=0, epsilon=epsilon, **kwargs)
@OpSchema.num_inputs(1)
def lp_normalize(inputs, axis=None, p=2, eps=1e-12, reduction='sum', **kwargs):
def lp_normalize(inputs, axis=None, p=2, epsilon=1e-12, reduction='sum', **kwargs):
r"""Apply the lp normalization.
The **Lp-Normalization** is defined as:
......@@ -201,7 +201,7 @@ def lp_normalize(inputs, axis=None, p=2, eps=1e-12, reduction='sum', **kwargs):
The order of the normalization.
axis : Union[int, Sequence[int]], optional
The axis to compute the norm.
eps : float, optional, default=1e-12
epsilon : float, optional, default=1e-12
The value to :math:`\epsilon`.
reduction : {'sum', 'mean'}, optional
The reduction method for norm.
......@@ -222,7 +222,7 @@ def lp_normalize(inputs, axis=None, p=2, eps=1e-12, reduction='sum', **kwargs):
raise ValueError('The <axis> should be a continuous sequence.')
args['axis'], args['num_axes'] = axes[0], len(axes)
args['num_axes'] = kwargs.get('num_axes', args['num_axes'])
args['eps'] = float(eps)
args['epsilon'] = float(epsilon)
args['reduction'] = reduction.upper()
op_lib = normalization_ops_lib.LpNormalize
if context.executing_eagerly():
......@@ -231,7 +231,7 @@ def lp_normalize(inputs, axis=None, p=2, eps=1e-12, reduction='sum', **kwargs):
p=p,
axis=args['axis'],
num_axes=args['num_axes'],
eps=args['eps'],
epsilon=args['epsilon'],
reduction=args['reduction'],
).apply([inputs])
else:
......@@ -239,7 +239,7 @@ def lp_normalize(inputs, axis=None, p=2, eps=1e-12, reduction='sum', **kwargs):
@OpSchema.num_inputs(3)
def layer_norm(inputs, axis=-1, eps=1e-5, **kwargs):
def layer_norm(inputs, axis=-1, epsilon=1e-5, **kwargs):
r"""Apply the layer normalization.
`[Ba et.al, 2016] <https://arxiv.org/abs/1607.06450>`_
......@@ -260,7 +260,7 @@ def layer_norm(inputs, axis=-1, eps=1e-5, **kwargs):
The tensor ``x``, ``gamma`` and ``beta``.
axis : int, optional, default=-1
The channel axis.
eps : float, optional, default=1e-5
epsilon : float, optional, default=1e-5
The value to :math:`\epsilon`.
Returns
......@@ -269,7 +269,7 @@ def layer_norm(inputs, axis=-1, eps=1e-5, **kwargs):
The output tensor.
"""
return group_norm(inputs, axis=axis, group=1, eps=eps, **kwargs)
return group_norm(inputs, axis=axis, group=1, epsilon=epsilon, **kwargs)
@OpSchema.num_inputs(1)
......@@ -337,7 +337,7 @@ def sync_batch_norm(
inputs,
axis=-1,
momentum=0.9,
eps=1e-5,
epsilon=1e-5,
use_stats=-1,
process_group=None,
**kwargs
......@@ -369,7 +369,7 @@ def sync_batch_norm(
The channel axis.
momentum : float, optional, default=0.9
The momentum for average.
eps : float, optional, default=1e-5
epsilon : float, optional, default=1e-5
The value to :math:`\epsilon`.
use_stats : int, optional, default=-1
Whether to use estimated statistics or not.
......@@ -383,7 +383,7 @@ def sync_batch_norm(
"""
args = parse_args(locals())
args['momentum'], args['eps'] = float(momentum), float(eps)
args['momentum'], args['epsilon'] = float(momentum), float(epsilon)
if process_group is None:
process_group = distributed.get_group()
if process_group is None:
......@@ -394,7 +394,7 @@ def sync_batch_norm(
.instantiate(
axis=axis,
momentum=args['momentum'],
eps=args['eps'],
epsilon=args['epsilon'],
use_stats=use_stats,
process_group=process_group,
).apply(inputs)
......
......@@ -22,7 +22,7 @@ class BatchNorm(Operator):
super(BatchNorm, self).__init__(key, dev, **kwargs)
self.axis = kwargs.get('axis', -1)
self.momentum = kwargs.get('momentum', 0.9)
self.eps = kwargs.get('eps', 1e-5)
self.epsilon = kwargs.get('epsilon', 1e-5)
self.use_stats = kwargs.get('use_stats', 0)
if self.use_stats not in (0, 1):
raise ValueError('Excepted determined stats mode.')
......@@ -33,7 +33,7 @@ class BatchNorm(Operator):
'arguments': {
'axis': self.axis,
'momentum': self.momentum,
'eps': self.eps,
'epsilon': self.epsilon,
'use_stats': self.use_stats,
}
}
......@@ -47,7 +47,7 @@ class GroupNorm(Operator):
super(GroupNorm, self).__init__(key, dev, **kwargs)
self.axis = kwargs.get('axis', -1)
self.group = kwargs.get('group', 32)
self.eps = kwargs.get('eps', 1e-5)
self.epsilon = kwargs.get('epsilon', 1e-5)
def attributes(self):
return {
......@@ -55,7 +55,7 @@ class GroupNorm(Operator):
'arguments': {
'axis': self.axis,
'group': self.group,
'eps': self.eps,
'epsilon': self.epsilon,
}
}
......@@ -69,7 +69,7 @@ class LpNormalize(Operator):
self.p = kwargs.get('p', 2)
self.axis = kwargs.get('axis', 0)
self.num_axes = kwargs.get('num_axes', -1)
self.eps = kwargs.get('eps', 1e-5)
self.epsilon = kwargs.get('epsilon', 1e-12)
self.reduction = kwargs.get('reduction', 'SUM')
def attributes(self):
......@@ -79,7 +79,7 @@ class LpNormalize(Operator):
'p': self.p,
'axis': self.axis,
'num_axes': self.num_axes,
'eps': self.eps,
'epsilon': self.epsilon,
'reduction': self.reduction,
}
}
......
......@@ -22,7 +22,7 @@ def batch_norm_exporter(op_def, shape_dict, ws):
node, const_tensors = exporter.translate(**locals())
node.op_type = 'BatchNormalization'
for arg in op_def.arg:
if arg.name == 'eps':
if arg.name == 'epsilon':
helper.add_attribute(node, 'epsilon', arg.f)
elif arg.name == 'momentum':
helper.add_attribute(node, 'momentum', arg.f)
......@@ -36,7 +36,7 @@ def group_norm_exporter(op_def, shape_dict, ws):
node, const_tensors = exporter.translate(**locals())
node.op_type = 'ATen' # Currently not supported in ai.onnx
for arg in op_def.arg:
if arg.name == 'eps':
if arg.name == 'epsilon':
helper.add_attribute(node, 'epsilon', arg.f)
elif arg.name == 'group':
if arg.i == 0:
......
......@@ -11,6 +11,7 @@ if (USE_CUDA)
endif()
# ---[ Subdirectory sources
add_subdirectory(device)
add_subdirectory(math)
# ---[ Submit to the parent scope
......
......@@ -16,7 +16,7 @@
#include <cstring>
#include "dragon/core/types.h"
#include "dragon/utils/cuda_device.h"
#include "dragon/utils/device/common_cuda.h"
namespace dragon {
......
#ifndef DRAGON_UTILS_CUB_DEVICE_H_
#define DRAGON_UTILS_CUB_DEVICE_H_
#ifndef DRAGON_UTILS_DEVICE_COMMON_CUB_H_
#define DRAGON_UTILS_DEVICE_COMMON_CUB_H_
#ifdef USE_CUDA
......@@ -7,7 +7,7 @@
#include <cub/device/device_select.cuh>
#include <cub/iterator/counting_input_iterator.cuh>
#include "dragon/utils/cuda_device.h"
#include "dragon/utils/device/common_cuda.h"
namespace dragon {
......@@ -18,4 +18,4 @@ using BlockReduce = cub::BlockReduce<T, CUDA_THREADS>;
#endif // USE_CUDA
#endif // DRAGON_UTILS_CUB_DEVICE_H_
#endif // DRAGON_UTILS_DEVICE_COMMON_CUB_H_
......@@ -10,8 +10,8 @@
* ------------------------------------------------------------
*/
#ifndef DRAGON_UTILS_CUDA_DEVICE_H_
#define DRAGON_UTILS_CUDA_DEVICE_H_
#ifndef DRAGON_UTILS_DEVICE_COMMON_CUDA_H_
#define DRAGON_UTILS_DEVICE_COMMON_CUDA_H_
#ifdef USE_CUDA
#include <cublas_v2.h>
......@@ -21,10 +21,6 @@
#include <device_launch_parameters.h>
#endif
#ifdef USE_NCCL
#include <nccl.h>
#endif
#include "dragon/core/common.h"
namespace dragon {
......@@ -67,14 +63,6 @@ constexpr int CUDA_TENSOR_MAX_DIMS = 8;
CHECK_EQ(status, CURAND_STATUS_SUCCESS); \
} while (0)
#ifdef USE_NCCL
#define NCCL_CHECK(condition) \
do { \
ncclResult_t status = condition; \
CHECK_EQ(status, ncclSuccess) << "\n" << ncclGetErrorString(status); \
} while (0)
#endif // USE_NCCL
#define CUDA_TENSOR_DIMS_CHECK(num_dims) \
CHECK_LE(num_dims, CUDA_TENSOR_MAX_DIMS) \
<< "Too many (> " << CUDA_TENSOR_MAX_DIMS \
......@@ -114,7 +102,7 @@ inline int CUDA_NUM_DEVICES() {
return count;
}
inline int CUDA_GET_DEVICE() {
inline int GetCUDADevice() {
int device_id;
cudaGetDevice(&device_id);
return device_id;
......@@ -138,7 +126,7 @@ inline const cudaDeviceProp& GetCUDADeviceProp(int device_id) {
}
inline bool CUDA_TRUE_FP16_AVAILABLE() {
int device = CUDA_GET_DEVICE();
int device = GetCUDADevice();
auto& prop = GetCUDADeviceProp(device);
return prop.major >= 6;
}
......@@ -147,7 +135,7 @@ inline bool TENSOR_CORE_AVAILABLE() {
#if CUDA_VERSION < 9000
return false;
#else
int device = CUDA_GET_DEVICE();
int device = GetCUDADevice();
auto& prop = GetCUDADeviceProp(device);
return prop.major >= 7;
#endif
......@@ -172,7 +160,7 @@ class CUDADeviceGuard {
#else
#define CUDA_NOT_COMPILED LOG(FATAL) << "CUDA was not compiled."
#define CUDA_NOT_COMPILED LOG(FATAL) << "CUDA library is not compiled with."
class CUDADeviceGuard {
public:
......@@ -185,4 +173,4 @@ class CUDADeviceGuard {
} // namespace dragon
#endif // DRAGON_UTILS_CUDA_DEVICE_H_
#endif // DRAGON_UTILS_DEVICE_COMMON_CUDA_H_
#ifdef USE_CUDNN
#include "dragon/utils/cudnn_device.h"
#include "dragon/utils/device/common_cudnn.h"
#include "dragon/core/tensor.h"
#include "dragon/core/types.h"
......
......@@ -10,8 +10,8 @@
* ------------------------------------------------------------
*/
#ifndef DRAGON_UTILS_CUDNN_DEVICE_H_
#define DRAGON_UTILS_CUDNN_DEVICE_H_
#ifndef DRAGON_UTILS_DEVICE_COMMON_CUDNN_H_
#define DRAGON_UTILS_DEVICE_COMMON_CUDNN_H_
#ifdef USE_CUDNN
......@@ -118,4 +118,4 @@ void CuDNNSetBiasDesc(
#endif // USE_CUDNN
#endif // DRAGON_UTILS_CUDNN_DEVICE_H_
#endif // DRAGON_UTILS_DEVICE_COMMON_CUDNN_H_
#ifndef DRAGON_UTILS_DEVICE_COMMON_NCCL_H_
#define DRAGON_UTILS_DEVICE_COMMON_NCCL_H_
#ifdef USE_NCCL
#include <nccl.h>
#define NCCL_VERSION_MIN(major, minor, patch) \
(NCCL_VERSION_CODE >= NCCL_VERSION(major, minor, patch))
#define NCCL_CHECK(condition) \
do { \
ncclResult_t status = condition; \
CHECK_EQ(status, ncclSuccess) << "\n" << ncclGetErrorString(status); \
} while (0)
#endif // USE_NCCL
#endif // DRAGON_UTILS_DEVICE_COMMON_NCCL_H_
......@@ -336,8 +336,7 @@ DRAGON_API void Gemv<float16, CUDAContext>(
float16* y,
CUDAContext* ctx,
const string math_type) {
cublasOperation_t cuTransA =
TransA == CblasNoTrans ? CUBLAS_OP_T : CUBLAS_OP_N;
auto cuTransA = TransA == CblasNoTrans ? CUBLAS_OP_T : CUBLAS_OP_N;
int m = cuTransA == CUBLAS_OP_N ? N : M;
int k = cuTransA == CUBLAS_OP_N ? M : N;
int LDA = cuTransA == CUBLAS_OP_N ? m : k;
......@@ -487,8 +486,7 @@ DRAGON_API void Gemv<float, CUDAContext>(
float* y,
CUDAContext* ctx,
const string math_type) {
cublasOperation_t cuTransA =
TransA == CblasNoTrans ? CUBLAS_OP_T : CUBLAS_OP_N;
auto cuTransA = TransA == CblasNoTrans ? CUBLAS_OP_T : CUBLAS_OP_N;
CUBLAS_CHECK(
cublasSetPointerMode(ctx->cublas_handle(), CUBLAS_POINTER_MODE_HOST));
CUBLAS_CHECK(cublasSgemv(
......@@ -507,8 +505,7 @@ DRAGON_API void Gemv<double, CUDAContext>(
double* y,
CUDAContext* ctx,
const string math_type) {
cublasOperation_t cuTransA =
TransA == CblasNoTrans ? CUBLAS_OP_T : CUBLAS_OP_N;
auto cuTransA = TransA == CblasNoTrans ? CUBLAS_OP_T : CUBLAS_OP_N;
const auto alpha64 = static_cast<double>(alpha);
const auto beta64 = static_cast<double>(beta);
CUBLAS_CHECK(
......@@ -544,10 +541,8 @@ DRAGON_API void Gemm<float16, CUDAContext>(
const std::string math_type) {
int lda = (TransA == CblasNoTrans) ? K : M;
int ldb = (TransB == CblasNoTrans) ? N : K;
cublasOperation_t cuTransA =
TransA == CblasNoTrans ? CUBLAS_OP_N : CUBLAS_OP_T;
cublasOperation_t cuTransB =
TransB == CblasNoTrans ? CUBLAS_OP_N : CUBLAS_OP_T;
auto cuTransA = TransA == CblasNoTrans ? CUBLAS_OP_N : CUBLAS_OP_T;
auto cuTransB = TransB == CblasNoTrans ? CUBLAS_OP_N : CUBLAS_OP_T;
CUBLAS_CHECK(
cublasSetPointerMode(ctx->cublas_handle(), CUBLAS_POINTER_MODE_HOST));
if (math_type == "float32") {
......@@ -697,10 +692,8 @@ DRAGON_API void Gemm<float, CUDAContext>(
const string math_type) {
int lda = TransA == CblasNoTrans ? K : M;
int ldb = TransB == CblasNoTrans ? N : K;
cublasOperation_t cuTransA =
TransA == CblasNoTrans ? CUBLAS_OP_N : CUBLAS_OP_T;
cublasOperation_t cuTransB =
TransB == CblasNoTrans ? CUBLAS_OP_N : CUBLAS_OP_T;
auto cuTransA = TransA == CblasNoTrans ? CUBLAS_OP_N : CUBLAS_OP_T;
auto cuTransB = TransB == CblasNoTrans ? CUBLAS_OP_N : CUBLAS_OP_T;
CUBLAS_CHECK(
cublasSetPointerMode(ctx->cublas_handle(), CUBLAS_POINTER_MODE_HOST));
CUBLAS_CHECK(cublasSgemm(
......@@ -736,10 +729,8 @@ DRAGON_API void Gemm<double, CUDAContext>(
const string math_type) {
int lda = (TransA == CblasNoTrans) ? K : M;
int ldb = (TransB == CblasNoTrans) ? N : K;
cublasOperation_t cuTransA =
TransA == CblasNoTrans ? CUBLAS_OP_N : CUBLAS_OP_T;
cublasOperation_t cuTransB =
TransB == CblasNoTrans ? CUBLAS_OP_N : CUBLAS_OP_T;
auto cuTransA = TransA == CblasNoTrans ? CUBLAS_OP_N : CUBLAS_OP_T;
auto cuTransB = TransB == CblasNoTrans ? CUBLAS_OP_N : CUBLAS_OP_T;
const auto alpha64 = static_cast<double>(alpha);
const auto beta64 = static_cast<double>(beta);
CUBLAS_CHECK(
......
#ifdef USE_CUDA
#include "dragon/utils/cub_device.h"
#include "dragon/utils/device/common_cub.h"
#include "dragon/utils/math/reduce.h"
#include "dragon/utils/math/utils.h"
......
......@@ -551,16 +551,6 @@ void Assign(
T* y,
Context* ctx);
/* framework.gradient */
template <typename T, class Context>
void GradientTwoSum(
const int count,
const T* dy1,
const T* dy2,
T* dx,
Context* ctx);
/* loss.generic_loss */
template <typename T, class Context>
......
......@@ -57,12 +57,12 @@ cmake .. ^
-G%CMAKE_GENERATOR% ^
-DBUILD_PYTHON=%BUILD_PYTHON% ^
-DBUILD_RUNTIME=%BUILD_RUNTIME% ^
-USE_CUDA==%USE_CUDA% ^
-USE_CUDNN==%USE_CUDNN% ^
-USE_OPENMP==%USE_OPENMP% ^
-USE_AVX==%USE_AVX% ^
-USE_AVX2==%USE_AVX2% ^
-USE_FMA==%USE_FMA% ^
-DUSE_CUDA=%USE_CUDA% ^
-DUSE_CUDNN=%USE_CUDNN% ^
-DUSE_OPENMP=%USE_OPENMP% ^
-DUSE_AVX=%USE_AVX% ^
-DUSE_AVX2=%USE_AVX2% ^
-DUSE_FMA=%USE_FMA% ^
-DTHIRD_PARTY_DIR=%THIRD_PARTY_DIR% ^
-DPROTOBUF_SDK_ROOT_DIR=%PROTOBUF_SDK_ROOT_DIR% ^
-DPROTOBUF_PROTOC_EXECUTABLE=%PROTOBUF_PROTOC_EXECUTABLE% ^
......
......@@ -58,12 +58,12 @@ cmake .. ^
-Ax64 ^
-DBUILD_PYTHON=%BUILD_PYTHON% ^
-DBUILD_RUNTIME=%BUILD_RUNTIME% ^
-USE_CUDA==%USE_CUDA% ^
-USE_CUDNN==%USE_CUDNN% ^
-USE_OPENMP==%USE_OPENMP% ^
-USE_AVX==%USE_AVX% ^
-USE_AVX2==%USE_AVX2% ^
-USE_FMA==%USE_FMA% ^
-DUSE_CUDA=%USE_CUDA% ^
-DUSE_CUDNN=%USE_CUDNN% ^
-DUSE_OPENMP=%USE_OPENMP% ^
-DUSE_AVX=%USE_AVX% ^
-DUSE_AVX2=%USE_AVX2% ^
-DUSE_FMA=%USE_FMA% ^
-DTHIRD_PARTY_DIR=%THIRD_PARTY_DIR% ^
-DPROTOBUF_SDK_ROOT_DIR=%PROTOBUF_SDK_ROOT_DIR% ^
-DPROTOBUF_PROTOC_EXECUTABLE=%PROTOBUF_PROTOC_EXECUTABLE% ^
......
......@@ -80,7 +80,7 @@ def batch_normalization(
moving_variance],
axis=axis,
momentum=momentum,
eps=variance_epsilon,
epsilon=variance_epsilon,
use_stats=not trainable,
name=name,
)
......@@ -131,7 +131,7 @@ def l2_normalize(x, axis=None, epsilon=1e-12, name=None):
x,
p=2,
axis=axis,
eps=epsilon,
epsilon=epsilon,
name=name,
)
......
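The ``epsilon`` keyword in ``l2_normalize`` above plays the usual numerical-safety role. A NumPy sketch of the TensorFlow convention, ``x / sqrt(max(sum(x**2), epsilon))``, illustrates why the guard matters; Dragon's kernel may place epsilon differently, so treat this only as an illustration:

```python
import numpy as np

def l2_normalize_ref(x, axis=None, epsilon=1e-12):
    """TensorFlow-style L2 normalization: x / sqrt(max(sum(x**2), epsilon))."""
    sq = np.sum(np.square(x), axis=axis, keepdims=True)
    return x / np.sqrt(np.maximum(sq, epsilon))

x = np.array([[3., 4.], [0., 0.]])
print(l2_normalize_ref(x, axis=1))
# [[0.6 0.8]
#  [0.  0. ]]   <- the all-zero row stays finite thanks to epsilon
```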
......@@ -157,7 +157,7 @@ class BatchNorm(layer.Layer):
self.moving_var],
axis=self.axis,
momentum=self.decay,
eps=self.epsilon,
epsilon=self.epsilon,
use_stats=0 if self.training else 1,
)
if self.act:
......
......@@ -153,7 +153,7 @@ def batch_norm(
input.device,
training=training,
momentum=momentum,
eps=eps,
epsilon=eps,
).apply(input, running_mean, running_var, weight, bias)
......@@ -670,11 +670,8 @@ def group_norm(input, weight, bias, groups=32, eps=1e-5):
"""
return _functions.GroupNorm \
.instantiate(
input.device,
group=groups,
eps=eps,
).apply(input, weight, bias)
.instantiate(input.device, group=groups, epsilon=eps) \
.apply(input, weight, bias)
def interpolate(
......@@ -1126,7 +1123,7 @@ def normalize(input, p=2, dim=1, eps=1e-12, out=None):
input.device,
p=p,
axis=dim,
eps=eps,
epsilon=eps,
).apply(input, out)
......@@ -1558,7 +1555,7 @@ def sync_batch_norm(
input.device,
training=training,
momentum=momentum,
eps=eps,
epsilon=eps,
process_group=process_group,
).apply(input, running_mean, running_var, weight, bias)
......
......@@ -120,8 +120,8 @@ class BatchNorm(function.Function):
def __init__(self, key, dev, **kwargs):
super(BatchNorm, self).__init__(key, dev, **kwargs)
self.momentum = kwargs.get('momentum', 0.1)
self.eps = kwargs.get('eps', 1e-5)
self.training = kwargs.get('training', 'False')
self.epsilon = kwargs.get('epsilon', 1e-5)
self.training = kwargs.get('training', False)
def attributes(self):
return {
......@@ -129,7 +129,7 @@ class BatchNorm(function.Function):
'arguments': {
'axis': 1,
'momentum': 1. - self.momentum,
'eps': self.eps,
'epsilon': self.epsilon,
'use_stats': int(not self.training),
}
}
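Note the ``1. - self.momentum`` above: the PyTorch-style ``momentum`` (default 0.1, the weight given to the newest batch statistic) is converted into Dragon's convention, where ``momentum`` is the weight kept for the running value. A one-line sanity check, with nothing framework-specific assumed:

```python
torch_momentum = 0.1                    # weight of the new batch statistic
dragon_momentum = 1. - torch_momentum   # weight kept for the running value
# running = dragon_momentum * running + (1 - dragon_momentum) * batch_stat
assert abs(dragon_momentum - 0.9) < 1e-12
```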
......@@ -276,15 +276,15 @@ class GroupNorm(function.Function):
def __init__(self, key, dev, **kwargs):
super(GroupNorm, self).__init__(key, dev, **kwargs)
self.group = kwargs.get('group', 32)
self.eps = kwargs.get('eps', 1e-5)
self.epsilon = kwargs.get('epsilon', 1e-5)
def attributes(self):
return {
'op_type': 'GroupNorm',
'arguments': {
'group': self.group,
'axis': 1,
'eps': self.eps,
'group': self.group,
'epsilon': self.epsilon,
}
}
......@@ -325,7 +325,7 @@ class LpNormalize(function.Function):
super(LpNormalize, self).__init__(key, dev, **kwargs)
self.p = kwargs.get('p', 2)
self.axis = kwargs.get('axis', 0)
self.eps = kwargs.get('eps', 1e-12)
self.epsilon = kwargs.get('epsilon', 1e-12)
def attributes(self):
return {
......@@ -333,7 +333,7 @@ class LpNormalize(function.Function):
'arguments': {
'p': self.p,
'axis': self.axis,
'eps': self.eps,
'epsilon': self.epsilon,
'num_axes': 1,
'reduction': 'SUM',
}
......
......@@ -47,8 +47,7 @@ class _BatchNorm(Module):
self.register_buffer('bias', init.zeros(num_features))
self.register_buffer('running_mean', init.zeros(num_features))
self.register_buffer('running_var', init.ones(num_features))
self.inputs = [self.running_mean, self.running_var,
self.weight, self.bias]
self.inputs = [self.running_mean, self.running_var, self.weight, self.bias]
self.reset_parameters()
def reset_parameters(self):
......@@ -56,6 +55,11 @@ class _BatchNorm(Module):
self.weight.data.one_()
self.bias.data.zero_()
def reset_running_stats(self):
if self.track_running_stats:
self.running_mean.zero_()
self.running_var.fill_(1)
def extra_repr(self):
return '{num_features}, ' \
'eps={eps}, ' \
......@@ -65,10 +69,9 @@ class _BatchNorm(Module):
.format(**self.__dict__)
def forward(self, input):
training = self.training or not self.track_running_stats
return F.batch_norm(
input, *self.inputs,
training=training,
training=self.training,
momentum=self.momentum,
eps=self.eps
)
......@@ -293,12 +296,10 @@ class SyncBatchNorm(_BatchNorm):
self.process_group = process_group
def forward(self, input):
training = self.training or \
not self.track_running_stats
if training:
if self.training:
return F.sync_batch_norm(
input, *self.inputs,
training=training,
training=self.training,
momentum=self.momentum,
eps=self.eps,
process_group=self.process_group
......@@ -306,7 +307,7 @@ class SyncBatchNorm(_BatchNorm):
else:
return F.batch_norm(
input, *self.inputs,
training=training,
training=self.training,
momentum=self.momentum,
eps=self.eps
)
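The forward path above only pays for cross-process statistics while training; in evaluation mode it falls back to the plain ``batch_norm`` call, since the running statistics are already identical on every rank. A tiny sketch of that dispatch decision (an editor's illustration of the control flow, not the module itself):

```python
def pick_norm_path(training):
    """Mirror the SyncBatchNorm.forward branch shown in this diff."""
    if training:
        # Batch statistics must be all-reduced across the process group.
        return 'sync_batch_norm'
    # Evaluation uses the running statistics; no communication is needed.
    return 'batch_norm'

print(pick_norm_path(True), pick_norm_path(False))
```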
......@@ -260,14 +260,14 @@ class MaskedSelect(function.Function):
class Multinomial(function.Function):
def __init__(self, key, dev, **kwargs):
super(Multinomial, self).__init__(key, dev, **kwargs)
self.eps = kwargs.get('eps', 0.)
self.epsilon = kwargs.get('epsilon', 0.)
self.num_samples = kwargs.get('num_samples', 1)
def attributes(self):
return {
'op_type': 'Multinomial',
'arguments': {
'eps': self.eps,
'epsilon': self.epsilon,
'normalize': False,
'num_samples': self.num_samples,
},
......
......@@ -544,8 +544,8 @@ def min(input, dim=None, keepdim=False, out=None):
return _reduce(input, 'Min', dim, keepdim, out)
def multinomial(input, num_samples, eps=0., out=None):
"""Return a tensor where each row sampled from the multinomial distribution.
def multinomial(input, num_samples, epsilon=0, out=None):
"""Return a tensor with index sampled from multinomial distribution.
Parameters
----------
......@@ -553,8 +553,8 @@ def multinomial(input, num_samples, eps=0., out=None):
The input tensor.
num_samples : int
The number of samples in each row.
eps : float, optional, default=0.
The prob to a uniform sampling.
epsilon : float, optional, default=0
        The epsilon value for the epsilon-greedy strategy.
out : dragon.vm.torch.Tensor, optional
The output tensor.
......@@ -568,7 +568,7 @@ def multinomial(input, num_samples, eps=0., out=None):
.instantiate(
input.device,
num_samples=num_samples,
eps=float(eps),
epsilon=float(epsilon),
).apply(input, out)
......
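To make the ``epsilon`` semantics concrete: with probability ``epsilon`` an index is drawn uniformly, otherwise it is drawn from the row's distribution. A NumPy sketch of that epsilon-greedy sampling follows; how Dragon's kernel treats unnormalized inputs is an assumption here, so this is a reference illustration only:

```python
import numpy as np

def multinomial_eps_greedy(probs, num_samples, epsilon=0., seed=None):
    """Row-wise sampling: uniform with probability ``epsilon``,
    otherwise from the given per-row distribution."""
    rng = np.random.default_rng(seed)
    n, k = probs.shape
    out = np.empty((n, num_samples), dtype=np.int64)
    for i in range(n):
        for j in range(num_samples):
            if rng.random() < epsilon:
                out[i, j] = rng.integers(k)            # explore: uniform draw
            else:
                out[i, j] = rng.choice(k, p=probs[i])  # exploit: multinomial draw
    return out

probs = np.array([[0.7, 0.2, 0.1], [0.1, 0.1, 0.8]])
print(multinomial_eps_greedy(probs, num_samples=4, epsilon=0.1, seed=0))
```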
......@@ -1035,16 +1035,15 @@ def mul_(self, other):
return math_funcs.mul(self, other, self)
def multinomial(self, num_samples, eps=0.):
"""Return a tensor where each row contains ``num_samples``,
sampled from the multinomial distribution.
def multinomial(self, num_samples, epsilon=0):
"""Return a tensor with index sampled from multinomial distribution.
Parameters
----------
num_samples : int
The number of samples.
eps : float, optional, default=0.
The prob to a uniform sampling.
epsilon : float, optional, default=0
        The epsilon value for the epsilon-greedy strategy.
Returns
-------
......@@ -1052,7 +1051,7 @@ def multinomial(self, num_samples, eps=0.):
The output tensor.
"""
return array_funcs.multinomial(self, num_samples, eps)
return array_funcs.multinomial(self, num_samples, epsilon)
def narrow(self, dimension, start, length):
......
......@@ -1200,16 +1200,15 @@ class Tensor(object):
"""
def multinomial(self, num_samples, eps=0.):
"""Return a tensor where each row contains ``num_samples``,
sampled from the multinomial distribution.
def multinomial(self, num_samples, epsilon=0):
"""Return a tensor with index sampled from multinomial distribution.
Parameters
----------
num_samples : int
The number of samples.
eps : float, optional, default=0.
The prob to a uniform sampling.
The number of samples in each row.
epsilon : float, optional, default=0
            The epsilon value for the epsilon-greedy strategy.
Returns
-------
......