Commit 654febe3 by Ting PAN

Implement softmax kernels via warp reduce

Summary:
This commit adds extra CUDA softmax kernels implemented with warp reduce.
Warp reduce gives better performance when the reduced dimension is <= 256,
a regime that is common in recent vision transformers.
1 parent 3a97dc22
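The warp-reduce kernels themselves are not part of the hunks below (only the dispatch sites and the shared reduce helpers are), so here is a minimal, self-contained CUDA sketch of the technique the summary refers to. It is illustrative only: the kernel name, the one-warp-per-row launch shape, and the float-only data type are assumptions for the example, not the commit's actual implementation.

```cuda
// warp_softmax_demo.cu -- illustrative sketch of softmax via warp reduce.
// NOT the kernel added by this commit; names and launch parameters are
// assumptions made for the example.
#include <cfloat>
#include <cstdio>
#include <vector>
#include <cuda_runtime.h>

// Butterfly all-reduce: after the loop, every lane holds the reduced value.
__device__ float WarpAllReduceMax(float val) {
  for (int mask = 16; mask > 0; mask /= 2) {
    val = fmaxf(val, __shfl_xor_sync(0xffffffff, val, mask));
  }
  return val;
}

__device__ float WarpAllReduceSum(float val) {
  for (int mask = 16; mask > 0; mask /= 2) {
    val += __shfl_xor_sync(0xffffffff, val, mask);
  }
  return val;
}

// One warp per row of an [N, C] matrix; each lane strides over the columns.
// No shared memory is touched, which is why this wins for small C.
__global__ void WarpSoftmax(int N, int C, const float* x, float* y) {
  const int row = blockIdx.x * (blockDim.x / 32) + threadIdx.x / 32;
  const int lane = threadIdx.x % 32;
  if (row >= N) return;  // row is uniform per warp, so the warp exits together
  const float* xr = x + row * C;
  float* yr = y + row * C;
  float vmax = -FLT_MAX;
  for (int j = lane; j < C; j += 32) vmax = fmaxf(vmax, xr[j]);
  vmax = WarpAllReduceMax(vmax);
  float vsum = 0.f;
  for (int j = lane; j < C; j += 32) vsum += expf(xr[j] - vmax);
  vsum = WarpAllReduceSum(vsum);
  for (int j = lane; j < C; j += 32) yr[j] = expf(xr[j] - vmax) / vsum;
}

int main() {
  const int N = 8, C = 197;  // 197 = ViT-Base/16 sequence length at 224x224
  std::vector<float> hx(N * C), hy(N * C);
  for (int i = 0; i < N * C; ++i) hx[i] = 0.01f * float(i % C);
  float *dx, *dy;
  cudaMalloc(&dx, N * C * sizeof(float));
  cudaMalloc(&dy, N * C * sizeof(float));
  cudaMemcpy(dx, hx.data(), N * C * sizeof(float), cudaMemcpyHostToDevice);
  const int warps_per_block = 4;  // 128 threads per block
  const int blocks = (N + warps_per_block - 1) / warps_per_block;
  WarpSoftmax<<<blocks, 32 * warps_per_block>>>(N, C, dx, dy);
  cudaMemcpy(hy.data(), dy, N * C * sizeof(float), cudaMemcpyDeviceToHost);
  float row_sum = 0.f;
  for (int j = 0; j < C; ++j) row_sum += hy[j];
  printf("row 0 sums to %.6f (expected ~1.0)\n", row_sum);
  cudaFree(dx);
  cudaFree(dy);
  return 0;
}
```

Recomputing `expf` in the last loop trades a little math for not storing intermediates; a production kernel would typically cache the exponentials in registers when C is small. The `To<float, half>` conversion added further down suggests the in-tree kernels accumulate `half` inputs in `float`, a common choice for reductions.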
@@ -34,6 +34,7 @@ from dragon.vm.dali.core.ops.image_ops import RandomResizedCrop
 from dragon.vm.dali.core.ops.image_ops import Resize
 from dragon.vm.dali.core.ops.image_ops import Rotate
 from dragon.vm.dali.core.ops.image_ops import WarpAffine
+from dragon.vm.dali.core.ops.math_ops import Normalize
 from dragon.vm.dali.core.ops.random_ops import CoinFlip
 from dragon.vm.dali.core.ops.random_ops import Uniform
 from dragon.vm.dali.core.ops.reader_ops import KPLRecordReader
......
@@ -118,13 +118,13 @@ class CropMirrorNormalize(object):
        # (H, W) for 2d input
        # (D, H, W) for 3d input
        crop=(224, 224),
-       # Historical values to normalize input
-       mean=(102., 115., 122.),
-       std=(1., 1., 1.),
+       # Historical BGR values to normalize input
+       mean=(103.53, 116.28, 123.675),
+       std=(57.375, 57.12, 58.395),
        # Or ``float16`` for fp16 training
        dtype='float32',
-       # Or ``NHWC``
-       output_layout='NCHW'
+       # Or ``HWC``
+       output_layout='CHW',
    )
    y = cmn(inputs['x'], mirror=flip_rng())
    ```
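For reference, the new defaults are the standard ImageNet channel statistics scaled to the 0-255 pixel range and listed in BGR order: 255 × (0.406, 0.456, 0.485) = (103.53, 116.28, 123.675) for the mean and 255 × (0.225, 0.224, 0.229) = (57.375, 57.12, 58.395) for the standard deviation.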
@@ -138,7 +138,7 @@ class CropMirrorNormalize(object):
        mean=0.,
        std=1.,
        dtype='float32',
-       output_layout='NCHW',
+       output_layout='CHW',
        **kwargs
    ):
        """Create a ``CropMirrorNormalize`` operator.
@@ -153,10 +153,10 @@ class CropMirrorNormalize(object):
            The values to subtract.
        std : Union[float, Sequence[float]], optional
            The values to divide after subtraction.
-       dtype : {'float16', 'float32'}, optional
-           The data type of output.
-       output_layout : {'NCHW', 'NHWC'}, optional
-           The data format of output.
+       dtype : str, optional, default='float32'
+           The output data type.
+       output_layout : str, optional
+           The output data layout.
        Returns
        -------
@@ -167,7 +167,7 @@ class CropMirrorNormalize(object):
        if isinstance(dtype, six.string_types):
            dtype = getattr(types, dtype.upper())
        if isinstance(output_layout, six.string_types):
-           output_layout = getattr(types, output_layout.upper())
+           output_layout = output_layout.upper()
        return ops.CropMirrorNormalize(
            crop=crop,
            mirror=mirror,
......
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Math ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
try:
from nvidia.dali import ops
except ImportError:
from dragon.core.util import deprecation
ops = deprecation.not_installed('nvidia.dali')
from dragon.core.util import six
from dragon.vm.dali.core.framework import context
from dragon.vm.dali.core.framework import types
class Normalize(object):
"""Normalize input.
Examples:
```python
norm = dali.ops.Normalize(
# Batch normalization case of HWC layout
axes=(0, 1), batch=True, epsilon=1e-5,
)
y = norm(inputs['x'])
```
"""
def __new__(
cls,
axes=(0, 1),
mean=None,
stddev=None,
scale=1.0,
shift=0.0,
batch=False,
epsilon=0,
dtype='float32',
**kwargs
):
"""Create a ``Normalize`` operator.
Parameters
----------
axes : Sequence[int], optional
The axes to normalize.
mean : float, optional
The value to subtract.
stddev : float, optional
The value to divide after subtraction.
scale : float, optional, default=1.0
The scale factor after normalization.
shift : float, optional, default=0.0
The shift factor after normalization.
batch : bool, optional, default=False
Whether to compute mean and stddev across the batch.
epsilon : float, optional, default=0
The value added to the computed variance.
dtype : str, optional, default='float32'
The output data type.
Returns
-------
nvidia.dali.ops.Normalize
The operator.
"""
if isinstance(dtype, six.string_types):
dtype = getattr(types, dtype.upper())
return ops.Normalize(
axes=axes,
mean=mean,
stddev=stddev,
scale=scale,
shift=shift,
batch=batch,
epsilon=epsilon,
dtype=dtype,
device=context.get_device_type(),
**kwargs
)
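Reading the parameter descriptions above together, the wrapped operator computes, in effect, `y = scale * (x - mean) / stddev + shift`, with `mean` and `stddev` taken from the arguments or, when omitted, computed over the reduced `axes` (per sample, or across the whole batch when `batch=True`), and `epsilon` added to the computed variance. This summary is inferred from the docstring rather than quoted from DALI; consult the NVIDIA DALI documentation for the authoritative definition.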
@@ -17,7 +17,7 @@ Public Functions
 CopyFrom
 ########
-.. doxygenfunction:: dragon::Tensor::CopyFrom(const Tensor &other, Context *ctx)
+.. doxygenfunction:: dragon::Tensor::CopyFrom(Tensor &other, Context *ctx)
 CopyFrom
 ########
@@ -39,6 +39,10 @@ IsType
 ######
 .. doxygenfunction:: dragon::Tensor::IsType
+MapFrom
+#######
+.. doxygenfunction:: dragon::Tensor::MapFrom
 Reset
 #####
 .. doxygenfunction:: dragon::Tensor::Reset
@@ -51,10 +55,6 @@ ReshapeLike
 ###########
 .. doxygenfunction:: dragon::Tensor::ReshapeLike
-Share
-#####
-.. doxygenfunction:: dragon::Tensor::Share
 axis
 ####
 .. doxygenfunction:: dragon::Tensor::axis
@@ -135,10 +135,6 @@ raw_mutable_data
 ################
 .. doxygenfunction:: dragon::Tensor::raw_mutable_data()
-raw_mutable_data
-################
-.. doxygenfunction:: dragon::Tensor::raw_mutable_data(const TypeMeta &meta)
 size
 ####
 .. doxygenfunction:: dragon::Tensor::size
......
@@ -45,6 +45,9 @@ vm.dali.ops
 `class ImageDecoderRandomCrop <ops/ImageDecoderRandomCrop.html>`_
 : Decode image and return a random crop.
+`class Normalize <ops/Normalize.html>`_
+: Normalize input.
 `class Pad <ops/Pad.html>`_
 : Pad input to have the same dimensions.
@@ -97,6 +100,7 @@ vm.dali.ops
   ops/Hsv
   ops/ImageDecoder
   ops/ImageDecoderRandomCrop
+  ops/Normalize
   ops/Pad
   ops/Paste
   ops/RandomBBoxCrop
......
Normalize
=========
.. autoclass:: dragon.vm.dali.ops.Normalize
__new__
--------
.. automethod:: dragon.vm.dali.ops.Normalize.__new__
.. raw:: html
<style>
h1:before {
content: "dali.ops.";
color: #103d3e;
}
</style>
@@ -143,6 +143,9 @@ class CUDAObjects {
 #ifdef USE_CUDNN
   /*! \brief The cached cuDNN handles of each device */
   vector<cudnnHandle_t> cudnn_handles_[CUDA_MAX_DEVICES];
+  /*! \brief The disabled cuDNN operators */
+  Set<string> cudnn_disabled_ops_;
 #endif
 #ifdef USE_NCCL
......
@@ -75,11 +75,6 @@ class DRAGON_API UnifiedMemory {
     return size_t(0);
   }
-  /*! \brief Return the number of batch chunks */
-  size_t num_chunks() const {
-    return num_chunks_;
-  }
   /*! \brief Return the storage order */
   StorageOrder order() const {
     return order_;
@@ -105,11 +100,6 @@ class DRAGON_API UnifiedMemory {
   /*! \brief Return the mutable cuda data */
   void* mutable_cuda_data(size_t size = 0);
-  /*! \brief Set the number of data chunks */
-  void set_num_chunks(size_t num_chunks) {
-    num_chunks_ = num_chunks;
-  }
   /*! \brief Set the storage order */
   void set_order(StorageOrder order) {
     order_ = order;
@@ -125,8 +115,8 @@ class DRAGON_API UnifiedMemory {
   /*! \brief The data state */
   State state_ = UNINITIALIZED;
-  /*! \brief The size and number of chunks */
-  size_t size_ = 0, num_chunks_ = 1;
+  /*! \brief The data size */
+  size_t size_ = 0;
   /*! \brief The type meta */
   TypeMeta meta_;
@@ -140,12 +130,12 @@ class DRAGON_API UnifiedMemory {
   /*! \brief The cpu data pointer */
   void* cpu_ptr_ = nullptr;
-  /*! \brief The ownership of cpu data pointer */
-  bool own_cpu_ptr_ = true;
   /*! \brief The cuda data pointer */
   void* cuda_ptr_ = nullptr;
+  /*! \brief The ownership of cpu data pointer */
+  bool own_cpu_ptr_ = true;
   /*! \brief The ownership of cuda data pointer */
   bool own_cuda_ptr_ = true;
......
@@ -60,15 +60,13 @@ Tensor* OperatorBase::Output(int i, const vec32_t& inputs) {
   auto* Y = Output(i);
   if (i < output_aliases_.size()) {
     for (auto j : inputs) {
-      const auto& X = Input(j);
+      auto& X = Input(j);
       if (output_aliases_[i].count(X.name())) {
-        Output(i)->ReshapeLike(X)->Share(X.memory());
-        return Y;
+        return Y->ReshapeLike(X)->MapFrom(&X);
       }
     }
   }
-  Y->Share(nullptr);
-  return Y;
+  return Y->MapFrom(nullptr);
 }
 Tensor* OperatorBase::Buffer(const string& name) {
@@ -85,7 +83,8 @@ OperatorBase* OperatorBase::New(const OperatorDef& def, Workspace* ws) {
     case PROTO_CUDA:
 #ifdef USE_CUDNN
       if (CUDNNOperatorRegistry()->Has(op_type) &&
-          CUDAContext::objects().cudnn_enabled_) {
+          CUDAContext::objects().cudnn_enabled_ &&
+          !CUDAContext::objects().cudnn_disabled_ops_.count(op_type)) {
         return CUDNNOperatorRegistry()->Create(op_type, def, ws);
       }
 #endif
......
@@ -79,7 +79,7 @@ class DRAGON_API Tensor {
     }
   }
-  /*! \brief Change the tensor dimensions */
+  /*! \brief Change the dimensions */
   Tensor* Reshape(const vec64_t& dims) {
     dims_ = dims;
     strides_.resize(dims.size());
@@ -91,31 +91,23 @@ class DRAGON_API Tensor {
       CHECK_GE(d, 0);
       if (d > 0) new_size *= d;
     }
-    if (capacity_ < new_size * meta_.itemsize()) {
-      if (own_memory_ptr_) {
-        memory_.reset();
-      } else {
-        mapped_memory_ = nullptr;
-        own_memory_ptr_ = true;
-      }
-      capacity_ = 0;
-    }
     size_ = new_size;
     return this;
   }
-  /*! \brief Change the tensor dimensions as the other */
+  /*! \brief Change the dimensions as the other */
   Tensor* ReshapeLike(const Tensor& other) {
     return Reshape(other.dims_);
   }
-  /*! \brief Copy memory from a tensor with context */
+  /*! \brief Copy memory from a tensor */
   template <class Context>
-  Tensor* CopyFrom(const Tensor& other, Context* ctx) {
+  Tensor* CopyFrom(Tensor& other, Context* ctx) {
     if ((void*)&other == (void*)this) return this;
     CHECK_EQ(size_, other.size_);
+    meta_ = other.meta_;
     auto* src = other.template raw_data<Context>();
-    auto* dst = raw_mutable_data<Context>(other.meta_);
+    auto* dst = raw_mutable_data<Context>();
     if (dst == src) return this;
     ctx->template MemcpyAsync<Context, Context>(nbytes(), dst, src);
     return this;
@@ -144,20 +136,22 @@ class DRAGON_API Tensor {
     }
   }
-  /*! \brief Share an external memory */
-  void Share(UnifiedMemory* memory) {
-    if (memory != nullptr) {
-      CHECK_LE(size_, memory->size())
-          << "\nShare an external memory with smaller capacity.";
-      memory_.reset();
-      capacity_ = memory->size();
-    } else {
-      if (memory_) {
-        capacity_ = memory_->size();
-      }
-    }
-    mapped_memory_ = memory;
-    own_memory_ptr_ = (memory == nullptr);
-  }
+  /*! \brief Map memory from a tensor */
+  Tensor* MapFrom(Tensor* other) {
+    if (other == nullptr) {
+      mapped_memory_ = nullptr;
+      capacity_ = (memory_ != nullptr ? memory_->size() : 0);
+    } else {
+      auto* new_memory = other->memory();
+      if (new_memory != nullptr) {
+        CHECK_LE(size_, new_memory->size())
+            << "\nMap from a memory with smaller capacity.";
+        memory_.reset();
+        mapped_memory_ = new_memory;
+        capacity_ = new_memory->size();
+      }
+    }
+    return this;
+  }
   /*! \brief Reset tensor to release all resources */
@@ -167,7 +161,6 @@ class DRAGON_API Tensor {
     memory_.reset();
     meta_ = TypeMeta();
     size_ = capacity_ = 0;
-    own_memory_ptr_ = true;
     mapped_memory_ = nullptr;
     if (ExternalDeleter != nullptr) {
       ExternalDeleter();
@@ -216,17 +209,17 @@ class DRAGON_API Tensor {
     return version_;
   }
-  /*! \brief Return the total number of elements */
+  /*! \brief Return the number of elements */
   size_t size() const {
     return size_;
   }
-  /*! \brief Return the memory capacity */
+  /*! \brief Return the byte length of memory */
   size_t capacity() const {
     return capacity_;
   }
-  /*! \brief Return the total number of data bytes */
+  /*! \brief Return the byte length of all elements */
   size_t nbytes() const {
     return size_ * meta_.itemsize();
   }
@@ -260,22 +253,22 @@ class DRAGON_API Tensor {
     return strides_[axis(i)];
   }
-  /*! \brief Return the tensor dimensions */
+  /*! \brief Return the dimensions */
   const vec64_t& dims() const {
     return dims_;
   }
-  /*! \brief Return the tensor strides */
+  /*! \brief Return the strides */
   const vec64_t& strides() const {
     return strides_;
   }
-  /*! \brief Return the total number of elements */
+  /*! \brief Return the number of elements counting along all axes */
   int64_t count() const {
     return (int64_t)size_;
   }
-  /*! \brief Return the number of elements counting along the given axes */
+  /*! \brief Return the number of elements counting along given axes */
   int64_t count(int64_t start, int64_t end) const {
     int64_t num = 1;
     for (int64_t i = start; i < end; i++) {
@@ -284,12 +277,12 @@ class DRAGON_API Tensor {
     return num;
   }
-  /*! \brief Return the number of elements counting from the given axis */
+  /*! \brief Return the number of elements counting from given axis */
   int64_t count(int64_t start) const {
     return count(start, ndim());
   }
-  /*! \brief Return whether the total number of elements is zero */
+  /*! \brief Return whether the number of elements is zero */
   bool empty() const {
     return size_ == 0;
   }
@@ -299,22 +292,30 @@ class DRAGON_API Tensor {
     return memory_ != nullptr || mapped_memory_ != nullptr;
   }
-  /*! \brief Return the memory pointer */
-  UnifiedMemory* memory(bool required = false, bool owned = false) const {
-    auto* ptr = own_memory_ptr_ || owned ? memory_.get() : mapped_memory_;
+  /*! \brief Return the memory */
+  UnifiedMemory* memory(bool required = false, bool owned = false) {
+    if (capacity_ < size_ * meta_.itemsize()) {
+      if (mapped_memory_ != nullptr) {
+        mapped_memory_ = nullptr;
+      } else {
+        memory_.reset();
+      }
+      capacity_ = 0;
+    }
+    auto* ptr = (owned || !mapped_memory_ ? memory_.get() : mapped_memory_);
     if (required) CHECK(ptr) << "\nAccess the empty memory.";
     return ptr;
   }
   /*! \brief Return the memory state */
-  UnifiedMemory::State memory_state() const {
+  UnifiedMemory::State memory_state() {
     return memory(true)->state();
   }
-  /*! \brief Try to return the raw const data pointer */
+  /*! \brief Return the raw data pointer */
   template <class Context>
-  const void* const_data_ptr() const {
-    auto context_type = TypeMeta::Id<Context>();
+  const void* raw_data() {
+    const auto context_type = TypeMeta::Id<Context>();
     if (context_type == TypeMeta::Id<CPUContext>()) {
       return memory(true)->cpu_data(nbytes());
     } else if (context_type == TypeMeta::Id<CUDAContext>()) {
@@ -325,14 +326,14 @@ class DRAGON_API Tensor {
     }
   }
-  /*! \brief Try to return the raw mutable data pointer */
+  /*! \brief Get the raw mutable data pointer */
   template <class Context>
-  void mutable_data_ptr(void** data_ptr) {
+  void raw_mutable_data(void** data_ptr) {
    auto* memory_ptr = memory();
    if (!memory_ptr) {
      *data_ptr = nullptr;
    } else {
-      auto context_type = TypeMeta::Id<Context>();
+      const auto context_type = TypeMeta::Id<Context>();
      if (context_type == TypeMeta::Id<CPUContext>()) {
        *data_ptr = memory_ptr->mutable_cpu_data(nbytes());
      } else if (context_type == TypeMeta::Id<CUDAContext>()) {
@@ -343,67 +344,38 @@ class DRAGON_API Tensor {
     }
   }
-  /*!
-   * \brief Return the raw mutable data pointer.
-   *
-   * If memory is not set, create to manage it with the given meta.
-   */
-  template <class Context>
-  void* raw_mutable_data(const TypeMeta& meta) {
-    void* data_ptr;
-    mutable_data_ptr<Context>(&data_ptr);
-    // Return the data pointer directly
-    if (meta_ == meta && data_ptr) return data_ptr;
-    // Create a new memory created with size and meta
-    CHECK_GT(size_, 0) << "\nInvalid tensor size.";
-    meta_ = meta;
-    capacity_ = size_ * meta.itemsize();
-    memory_.reset(new UnifiedMemory(meta_, capacity_));
-    mutable_data_ptr<Context>(&data_ptr);
-    if (meta_.ctor()) meta_.ctor()(data_ptr, size_);
-    return data_ptr;
-  }
   /*! \brief Return the raw mutable data pointer */
   template <class Context>
   void* raw_mutable_data() {
     CHECK_NE(meta_.id(), 0) << "\nTensor(" << name_ << "): unknown type, "
                             << "or does not have a type.";
-    return raw_mutable_data<Context>(meta_);
-  }
-  /*! \brief Return the raw const data pointer */
-  template <class Context>
-  const void* raw_data() const {
-    return const_data_ptr<Context>();
-  }
-  /*! \brief Return the typed mutable data pointer */
-  template <typename T, class Context>
-  T* mutable_data() {
-    void* data_ptr;
-    mutable_data_ptr<Context>(&data_ptr);
-    if (data_ptr) {
-      auto meta = TypeMeta::Make<T>();
-      if (meta_ == meta) {
-        return static_cast<T*>(data_ptr);
-      } else if (capacity_ >= size_ * meta.itemsize()) {
-        meta_ = meta;
-        return static_cast<T*>(data_ptr);
-      }
-    }
-    return static_cast<T*>(raw_mutable_data<Context>(TypeMeta::Make<T>()));
+    void* data_ptr;
+    raw_mutable_data<Context>(&data_ptr);
+    if (data_ptr) return data_ptr;
+    CHECK_GT(size_, 0) << "\nInvalid tensor size.";
+    capacity_ = size_ * meta_.itemsize();
+    memory_.reset(new UnifiedMemory(meta_, capacity_));
+    raw_mutable_data<Context>(&data_ptr);
+    if (meta_.ctor()) meta_.ctor()(data_ptr, size_);
+    return data_ptr;
   }
-  /*! \brief Return the typed const data pointer */
+  /*! \brief Return the typed data pointer */
   template <typename T, class Context>
-  const T* data() const {
+  const T* data() {
     CHECK(meta_.Match<T>()) << "\nThe type of Tensor(" << name() << ") is "
                             << dtypes::to_string(meta_) << ", while requesting "
                             << dtypes::to_string(TypeMeta::Make<T>()) << ".";
     return static_cast<const T*>(raw_data<Context>());
   }
+  /*! \brief Return the typed mutable data pointer */
+  template <typename T, class Context>
+  T* mutable_data() {
+    meta_ = TypeMeta::Make<T>();
+    return static_cast<T*>(raw_mutable_data<Context>());
+  }
   /*! \brief Set the tensor version */
   void set_version(int version) {
     version_ = version;
@@ -415,12 +387,14 @@ class DRAGON_API Tensor {
     return this;
   }
-  /*! \brief Set to manage the memory */
-  void set_memory(UnifiedMemory* memory_ptr) {
-    if (memory_ptr != memory_.get()) {
-      memory_.reset(memory_ptr);
-    }
-    capacity_ = memory_ptr->size();
-  }
+  /*! \brief Set the managed memory */
+  void set_memory(UnifiedMemory* memory) {
+    if (memory != nullptr) {
+      if (memory != memory_.get()) {
+        memory_.reset(memory);
+      }
+      capacity_ = memory->size();
+    }
+  }
 private:
@@ -430,14 +404,20 @@ class DRAGON_API Tensor {
   /*! \brief The type meta */
   TypeMeta meta_;
-  /*! \brief The size and capacity */
-  size_t size_ = 0, capacity_ = 0;
+  /*! \brief The number of elements */
+  size_t size_ = 0;
+  /*! \brief The byte length of memory */
+  size_t capacity_ = 0;
   /*! \brief The tensor version */
   int version_ = -1;
-  /*! \brief The dimensions and strides */
-  vec64_t dims_, strides_;
+  /*! \brief The dimensions */
+  vec64_t dims_;
+  /*! \brief The strides */
+  vec64_t strides_;
   /*! \brief The managed memory */
   unique_ptr<UnifiedMemory> memory_;
@@ -445,9 +425,6 @@ class DRAGON_API Tensor {
   /*! \brief The mapped memory */
   UnifiedMemory* mapped_memory_ = nullptr;
-  /*! \brief The ownership of memory pointer */
-  bool own_memory_ptr_ = true;
   DISABLE_COPY_AND_ASSIGN(Tensor);
 };
......
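The `Share(UnifiedMemory*)` API is gone; tensors now alias each other through `MapFrom(Tensor*)`, and `memory()` lazily drops a mapping or reallocation happens when `capacity_` is too small. Below is a hedged sketch of the difference between mapping and copying, using only the methods shown above; the header paths, the default constructor, and the free function are assumptions made for the example, not code from this commit.

```cpp
#include "dragon/core/context.h"  // assumed path for CPUContext
#include "dragon/core/tensor.h"   // assumed path for Tensor

// Illustrative only: Y aliases X's memory, Z owns a separate copy.
void AliasVersusCopy(dragon::Tensor& X, dragon::CPUContext* ctx) {
  dragon::Tensor Y, Z;
  // Map: Y borrows X's UnifiedMemory; no bytes move, and Y's capacity
  // tracks X's buffer size.
  Y.ReshapeLike(X)->MapFrom(&X);
  // Copy: Z adopts X's type meta, allocates its own buffer, and memcpys.
  Z.ReshapeLike(X)->CopyFrom(X, ctx);
  // Unmapping detaches Y from X's buffer again.
  Y.MapFrom(nullptr);
}
```

This is the mechanism `OperatorBase::Output(i, inputs)` uses above to make in-place outputs (such as the softmax `Output(0, {0})`) reuse the input's memory instead of copying it.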
@@ -91,16 +91,24 @@ void RegisterModule_cuda(py::module& m) {
 #endif
   });
-  /*! \brief Activate the CuDNN engine */
+  /*! \brief Enable using the cuDNN library */
   m.def(
       "cudaEnableDNN",
-      [](bool enabled, bool deterministic, bool benchmark, bool allow_tf32) {
+      [](bool enabled,
+         bool deterministic,
+         bool benchmark,
+         bool allow_tf32,
+         const vector<string>& disabled_ops) {
 #ifdef USE_CUDA
         auto& cuda_objects = CUDAContext::objects();
         cuda_objects.cudnn_enabled_ = enabled;
         cuda_objects.cudnn_deterministic_ = deterministic;
         cuda_objects.cudnn_benchmark_ = benchmark;
         cuda_objects.cudnn_allow_tf32_ = allow_tf32;
+        cuda_objects.cudnn_disabled_ops_.clear();
+        for (const auto& op_type : disabled_ops) {
+          cuda_objects.cudnn_disabled_ops_.insert(op_type);
+        }
 #endif
       });
......
@@ -29,15 +29,22 @@ void RegisterModule_tensor(py::module& m) {
      /*! \brief Return the number of dimensions */
      .def_property_readonly("ndim", &Tensor::ndim)
-     /*! \brief Return all the dimensions */
+     /*! \brief Return the dimensions */
      .def_property_readonly("dims", &Tensor::dims)
-     /*! \brief Return the total number of elements */
+     /*! \brief Return the number of elements */
      .def_property_readonly("size", &Tensor::size)
-     /*! \brief Return the total number of bytes */
+     /*! \brief Return the byte length of one element */
+     .def_property_readonly(
+         "itemsize", [](Tensor* self) { return self->meta().itemsize(); })
+     /*! \brief Return the byte length of all elements */
      .def_property_readonly("nbytes", &Tensor::nbytes)
+     /*! \brief Return the byte length of allocated memory */
+     .def_property_readonly("capacity", &Tensor::capacity)
      /*! \brief Return the data type */
      .def_property_readonly(
          "dtype",
......
@@ -60,7 +60,9 @@ class CuDNNSoftmaxOp final : public SoftmaxOp<Context> {
     CuDNNDestroyTensorDesc(&input_desc_);
   }
-  void RunOnDevice() override;
+  void RunOnDevice() override {
+    DispatchHelper<dtypes::Floating>::Call(this, Input(0));
+  }
   template <typename T>
   void DoRunWithType();
@@ -82,7 +84,9 @@ class CuDNNSoftmaxGradientOp final : public SoftmaxGradientOp<Context> {
     CuDNNDestroyTensorDesc(&input_desc_);
   }
-  void RunOnDevice() override;
+  void RunOnDevice() override {
+    DispatchHelper<dtypes::Floating>::Call(this, Input(0));
+  }
   template <typename T>
   void DoRunWithType();
......
 #ifdef USE_CUDNN
 #include "dragon/operators/activation/softmax_op.h"
+#include "dragon/utils/op_kernels.h"
 namespace dragon {
@@ -9,8 +10,19 @@ template <typename T>
 void CuDNNSoftmaxOp<Context>::DoRunWithType() {
   auto &X = Input(0), *Y = Output(0, {0});
   GET_OP_AXIS_ARG(axis, X.ndim(), -1);
-  auto S = X.count(axis + 1);
-  CuDNNSetTensorDesc<T>(&input_desc_, {X.count(0, axis), X.dim(axis), S, 1});
+  const auto C = X.dim(axis);
+  const auto N = X.count(0, axis), S = X.count(axis + 1);
+  if (C < 384 && S == 1) {
+    kernels::Softmax(
+        N,
+        S,
+        C,
+        X.template data<T, Context>(),
+        Y->ReshapeLike(X)->template mutable_data<T, Context>(),
+        ctx());
+    return;
+  }
+  CuDNNSetTensorDesc<T>(&input_desc_, {N, C, S, 1});
   CUDNN_CHECK(cudnnSoftmaxForward(
       ctx()->cudnn_handle(),
       CUDNN_SOFTMAX_ACCURATE,
@@ -24,17 +36,24 @@ void CuDNNSoftmaxOp<Context>::DoRunWithType() {
 }
 template <class Context>
-void CuDNNSoftmaxOp<Context>::RunOnDevice() {
-  DispatchHelper<dtypes::Floating>::Call(this, Input(0));
-}
-template <class Context>
 template <typename T>
 void CuDNNSoftmaxGradientOp<Context>::DoRunWithType() {
   auto &Y = Input(0), &dY = Input(1), *dX = Output(0);
   GET_OP_AXIS_ARG(axis, Y.ndim(), -1);
-  auto S = Y.count(axis + 1);
-  CuDNNSetTensorDesc<T>(&input_desc_, {Y.count(0, axis), Y.dim(axis), S, 1});
+  const auto C = Y.dim(axis);
+  const auto N = Y.count(0, axis), S = Y.count(axis + 1);
+  if (C < 256 && S == 1) {
+    kernels::SoftmaxGrad(
+        Y.count(0, axis),
+        Y.count(axis + 1),
+        Y.dim(axis),
+        dY.template data<T, Context>(),
+        Y.template data<T, Context>(),
+        dX->ReshapeLike(Y)->template mutable_data<T, Context>(),
+        ctx());
+    return;
+  }
+  CuDNNSetTensorDesc<T>(&input_desc_, {N, C, S, 1});
   CUDNN_CHECK(cudnnSoftmaxBackward(
       ctx()->cudnn_handle(),
       CUDNN_SOFTMAX_ACCURATE,
@@ -49,11 +68,6 @@ void CuDNNSoftmaxGradientOp<Context>::DoRunWithType() {
       dX->ReshapeLike(Y)->template mutable_data<T, Context>()));
 }
-template <class Context>
-void CuDNNSoftmaxGradientOp<Context>::RunOnDevice() {
-  DispatchHelper<dtypes::Floating>::Call(this, Input(0));
-}
 DEPLOY_CUDNN_OPERATOR(Softmax);
 DEPLOY_CUDNN_OPERATOR(SoftmaxGradient);
......
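For scale: the fast path is taken only when the reduced axis is innermost (`S == 1`) and small (`C < 384` forward, `C < 256` backward). Attention softmax in ViT-Base/16 at 224×224 input reduces over the key length (224 / 16)² + 1 = 197, comfortably below both thresholds, which is the "recent vision transformers" case the commit summary mentions; otherwise the operator falls back to cuDNN as before.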
@@ -29,7 +29,7 @@ void ConcatOp<Context>::DoRunWithType() {
   int64_t output_offset = 0;
   for (int i = 0; i < InputSize(); ++i) {
-    const auto& Xi = Input(i);
+    auto& Xi = Input(i);
     math::CopyMatrix(
         Xi.count(0, axis),
         Xi.count(axis),
......
@@ -28,7 +28,7 @@ void StackOp<Context>::DoRunWithType() {
   int64_t output_offset = 0;
   for (int i = 0; i < num_stacks; i++) {
-    const auto& Xi = Input(i);
+    auto& Xi = Input(i);
     math::CopyMatrix(
         Xi.count(0, axis),
         Xi.count(axis),
......
@@ -67,6 +67,7 @@ def enable_cudnn(
     deterministic=False,
     benchmark=False,
     allow_tf32=False,
+    disabled_ops=None,
 ):
     """Enable backend to use the cuDNN library.
@@ -80,9 +81,12 @@ def enable_cudnn(
         Select fastest algorithms via benchmark or heuristics.
     allow_tf32 : bool, optional, default=False
         Allow TF32 tensor core operation or not.
+    disabled_ops : Sequence[str], optional
+        The operator types to disable using cuDNN.
     """
-    return backend.cudaEnableDNN(enabled, deterministic, benchmark, allow_tf32)
+    return backend.cudaEnableDNN(enabled, deterministic, benchmark,
+                                 allow_tf32, disabled_ops or [])
 def get_device_capability(device_index=None):
......
@@ -143,6 +143,11 @@ CONVERSIONS_DECL half To<half, half>(half val) {
 }
 template <>
+CONVERSIONS_DECL float To<float, half>(half val) {
+  return __half2float(val);
+}
+template <>
 CONVERSIONS_DECL half To<half, float>(float val) {
 #if CUDA_VERSION_MIN(9, 2)
   return __float2half(val);
......
@@ -20,12 +20,35 @@
 #include <cub/device/device_reduce.cuh>
 #include <cub/device/device_select.cuh>
 #include <cub/iterator/counting_input_iterator.cuh>
+#include <cub/warp/warp_reduce.cuh>
 namespace dragon {
 template <typename T>
+using WarpReduce = cub::BlockReduce<T, CUDA_WARP_SIZE>;
+template <typename T>
 using BlockReduce = cub::BlockReduce<T, CUDA_THREADS>;
+template <typename T, typename Reducer>
+__inline__ __device__ T WarpAllReduce(T val) {
+  for (int mask = CUDA_WARP_SIZE / 2; mask > 0; mask /= 2) {
+    val = Reducer()(val, __shfl_xor_sync(0xffffffff, val, mask));
+  }
+  return val;
+}
+template <typename T, typename Reducer, int kThreadsPerBlock>
+__inline__ __device__ T BlockAllReduce(T val) {
+  typedef cub::BlockReduce<T, kThreadsPerBlock> BlockReduce;
+  __shared__ T block_val;
+  __shared__ typename BlockReduce::TempStorage storage;
+  val = BlockReduce(storage).Reduce(val, Reducer());
+  if (threadIdx.x == 0) block_val = val;
+  __syncthreads();
+  return block_val;
+}
 } // namespace dragon
 #endif // USE_CUDA
......
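A small usage sketch of the new helpers follows. It is hypothetical, not code from this commit: a kernel that subtracts the row maximum from each row of an [N, C] matrix, launched with one 256-thread block per row so the block size matches `kThreadsPerBlock`. It assumes the enclosing translation unit includes the header above (for `BlockAllReduce` and the CUB headers) plus `<cfloat>`.

```cuda
// Hypothetical example; assumes the reduce helpers above are in scope.
#include <cfloat>

namespace dragon {

__global__ void SubtractRowMax(const int C, const float* x, float* y) {
  const float* xr = x + blockIdx.x * C;
  // Threads past the row length contribute the identity element for max.
  float v = threadIdx.x < C ? xr[threadIdx.x] : -FLT_MAX;
  // Every thread in the block receives the same row maximum.
  const float row_max = BlockAllReduce<float, cub::Max, 256>(v);
  if (threadIdx.x < C) {
    y[blockIdx.x * C + threadIdx.x] = xr[threadIdx.x] - row_max;
  }
}

// Launch: SubtractRowMax<<<N, 256, 0, stream>>>(C, x, y);  // requires C <= 256

}  // namespace dragon
```

`WarpAllReduce` differs in that the `__shfl_xor_sync` butterfly leaves the result in every lane without touching shared memory, which is what makes the small-dimension softmax path cheap.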
@@ -27,10 +27,13 @@ namespace dragon {
 #ifdef USE_CUDA
-/*! \brief The number of cuda threads to use */
+/*! \brief The number of cuda threads in a warp */
+constexpr int CUDA_WARP_SIZE = 32;
+/*! \brief The number of cuda threads in a block */
 constexpr int CUDA_THREADS = 256;
-/*! \brief The maximum number of blocks to use in the default kernel call */
+/*! \brief The maximum number of blocks to use in a default kernel call */
 constexpr int CUDA_MAX_BLOCKS = 4096;
 /*! \brief The maximum number of devices in a single machine */
......
@@ -29,11 +29,9 @@ void _Transpose(
 } // namespace
-/* ------------------- Launcher Separator ------------------- */
 #define DEFINE_TRANSPOSE_FUNC(T) \
   template <> \
-  void Transpose<T, CPUContext>( \
+  DRAGON_API void Transpose<T, CPUContext>( \
       const int num_dims, \
       const int64_t* dims, \
      const int64_t* axes, \
......
@@ -82,11 +82,9 @@ void _TransposeImpl(
 } // namespace
-/* ------------------- Launcher Separator ------------------- */
 #define DEFINE_TRANSPOSE_FUNC(T) \
   template <> \
-  void Transpose<T, CUDAContext>( \
+  DRAGON_API void Transpose<T, CUDAContext>( \
      const int num_dims, \
      const int64_t* dims, \
      const int64_t* axes, \
......