Commit 58c5371e by Ting PAN

Add support to query the size for the allocated memory

Summary:
This commit adds the ``memory_allocated`` API for ``dragon.Workspace``
to query the size of allocated memory (and optionally on a specified device).
1 parent a7a7e4fc
...@@ -45,11 +45,9 @@ class Dropout(Layer): ...@@ -45,11 +45,9 @@ class Dropout(Layer):
def __init__(self, layer_param): def __init__(self, layer_param):
super(Dropout, self).__init__(layer_param) super(Dropout, self).__init__(layer_param)
param = layer_param.dropout_param param = layer_param.dropout_param
self.arguments = { if not param.scale_train:
'prob': param.dropout_ratio, raise ValueError('Unscaled dropout is not supported.')
'scale': param.scale_train self.arguments = {'prob': param.dropout_ratio}
if hasattr(param, 'scale_train') else True,
}
def __call__(self, bottom): def __call__(self, bottom):
return activation_ops.dropout(bottom, **self.arguments) return activation_ops.dropout(bottom, **self.arguments)
......
...@@ -69,7 +69,11 @@ set_cuda_data ...@@ -69,7 +69,11 @@ set_cuda_data
size size
#### ####
.. doxygenfunction:: dragon::UnifiedMemory::size .. doxygenfunction:: dragon::UnifiedMemory::size() const
size
####
.. doxygenfunction:: dragon::UnifiedMemory::size(const string &device_type, int device_id) const
state state
##### #####
......
...@@ -30,6 +30,10 @@ has_tensor ...@@ -30,6 +30,10 @@ has_tensor
########## ##########
.. automethod:: dragon.Workspace.has_tensor .. automethod:: dragon.Workspace.has_tensor
memory_allocated
################
.. automethod:: dragon.Workspace.memory_allocated
merge_from merge_from
########## ##########
.. automethod:: dragon.Workspace.merge_from .. automethod:: dragon.Workspace.merge_from
......
...@@ -16,7 +16,7 @@ dragon.cuda ...@@ -16,7 +16,7 @@ dragon.cuda
: Return the index of current selected device. : Return the index of current selected device.
`enable_cudnn(...) <cuda/enable_cudnn.html>`_ `enable_cudnn(...) <cuda/enable_cudnn.html>`_
: Enable the CuDNN library. : Enable backend to use the cuDNN library.
`get_device_capability(...) <cuda/get_device_capability.html>`_ `get_device_capability(...) <cuda/get_device_capability.html>`_
: Return the capability of specified device. : Return the capability of specified device.
...@@ -24,6 +24,9 @@ dragon.cuda ...@@ -24,6 +24,9 @@ dragon.cuda
`is_available(...) <cuda/is_available.html>`_ `is_available(...) <cuda/is_available.html>`_
: Return a bool reporting if runtime is available. : Return a bool reporting if runtime is available.
`memory_allocated(...) <cuda/memory_allocated.html>`_
: Return the size of memory used by tensors in current workspace.
`set_default_device(...) <cuda/set_default_device.html>`_ `set_default_device(...) <cuda/set_default_device.html>`_
: Set the default device. : Set the default device.
...@@ -40,6 +43,7 @@ dragon.cuda ...@@ -40,6 +43,7 @@ dragon.cuda
cuda/enable_cudnn cuda/enable_cudnn
cuda/get_device_capability cuda/get_device_capability
cuda/is_available cuda/is_available
cuda/memory_allocated
cuda/set_default_device cuda/set_default_device
cuda/set_device cuda/set_device
cuda/Stream cuda/Stream
......
memory_allocated
================
.. autofunction:: dragon.cuda.memory_allocated
.. _dragon.Workspace.memory_allocated(...): ../Workspace.html#memory-allocated
.. raw:: html
<style>
h1:before {
content: "dragon.cuda.";
color: #103d3e;
}
</style>
...@@ -41,16 +41,20 @@ class CUDAObjects { ...@@ -41,16 +41,20 @@ class CUDAObjects {
~CUDAObjects() { ~CUDAObjects() {
for (int i = 0; i < CUDA_MAX_DEVICES; i++) { for (int i = 0; i < CUDA_MAX_DEVICES; i++) {
#ifdef USE_NCCL #ifdef USE_NCCL
for (auto& comm : nccl_comms_[i]) { for (auto& comm_iter : nccl_comms_[i]) {
/*! if (comm_iter.second) {
* Temporarily disable the comm destroying, NCCL_CHECK(ncclCommDestroy(comm_iter.second));
* to avoid an unhandled error. }
*/
} }
#endif #endif
#ifdef USE_CUDNN #ifdef USE_CUDNN
for (auto& handle : cudnn_handles_[i]) { for (auto& handle : cudnn_handles_[i]) {
if (handle) CUDNN_CHECK(cudnnDestroy(handle)); /*!
* Temporarily disable the handle destroying,
* to avoid the segmentation fault in CUDNN v8.
*
* if (handle) CUDNN_CHECK(cudnnDestroy(handle));
*/
} }
#endif #endif
for (auto& handle : cublas_handles_[i]) { for (auto& handle : cublas_handles_[i]) {
......
...@@ -76,6 +76,16 @@ class DRAGON_API UnifiedMemory { ...@@ -76,6 +76,16 @@ class DRAGON_API UnifiedMemory {
return size_; return size_;
} }
/*! \brief Return the total number of bytes on given device */
size_t size(const string& device_type, int device_id) const {
if (device_type == "cuda") {
if (own_cuda_ptr_ && cuda_ptr_ && device_id_ == device_id) {
return size_;
}
}
return size_t(0);
}
/*! \brief Return the number of memory chunks */ /*! \brief Return the number of memory chunks */
size_t num_chunks() const { size_t num_chunks() const {
return num_chunks_; return num_chunks_;
...@@ -159,15 +169,18 @@ class DRAGON_API UnifiedMemory { ...@@ -159,15 +169,18 @@ class DRAGON_API UnifiedMemory {
/*! \brief The cpu data pointer */ /*! \brief The cpu data pointer */
void* cpu_ptr_ = nullptr; void* cpu_ptr_ = nullptr;
/*! \brief The ownership of cpu data pointer */
bool own_cpu_ptr_ = true;
/*! \brief The cuda data pointer */ /*! \brief The cuda data pointer */
void* cuda_ptr_ = nullptr; void* cuda_ptr_ = nullptr;
/*! \brief The ownership of cuda data pointer */
bool own_cuda_ptr_ = true;
/*! \brief The cnml data pointer */ /*! \brief The cnml data pointer */
void* cnml_ptr_ = nullptr; void* cnml_ptr_ = nullptr;
/*! \brief The ownership of data pointers */
int own_cpu_ptr_ = 1, own_cuda_ptr_ = 1;
/*! \brief The binding cpu tensor for cnml */ /*! \brief The binding cpu tensor for cnml */
cnmlCpuTensor_t cnml_cpu_tensor_ = nullptr; cnmlCpuTensor_t cnml_cpu_tensor_ = nullptr;
......
...@@ -92,11 +92,11 @@ class DRAGON_API Tensor { ...@@ -92,11 +92,11 @@ class DRAGON_API Tensor {
if (d > 0) new_size *= d; if (d > 0) new_size *= d;
} }
if (capacity_ < new_size * meta_.itemsize()) { if (capacity_ < new_size * meta_.itemsize()) {
if (own_memory_) { if (own_memory_ptr_) {
internal_memory_.reset(); memory_.reset();
} else { } else {
external_memory_ = nullptr; mapped_memory_ = nullptr;
own_memory_ = true; own_memory_ptr_ = true;
} }
capacity_ = 0; capacity_ = 0;
} }
...@@ -155,26 +155,26 @@ class DRAGON_API Tensor { ...@@ -155,26 +155,26 @@ class DRAGON_API Tensor {
if (memory != nullptr) { if (memory != nullptr) {
CHECK_LE(size_, memory->size()) CHECK_LE(size_, memory->size())
<< "\nShare an external memory with smaller capacity."; << "\nShare an external memory with smaller capacity.";
internal_memory_.reset(); memory_.reset();
capacity_ = memory->size(); capacity_ = memory->size();
} else { } else {
if (internal_memory_) { if (memory_) {
capacity_ = internal_memory_->size(); capacity_ = memory_->size();
} }
} }
external_memory_ = memory; mapped_memory_ = memory;
own_memory_ = (memory == nullptr); own_memory_ptr_ = (memory == nullptr);
} }
/*! \brief Reset tensor to release all resources */ /*! \brief Reset tensor to release all resources */
void Reset() { void Reset() {
dims_.clear(); dims_.clear();
strides_.clear(); strides_.clear();
internal_memory_.reset(); memory_.reset();
meta_ = TypeMeta(); meta_ = TypeMeta();
size_ = capacity_ = 0; size_ = capacity_ = 0;
own_memory_ = true; own_memory_ptr_ = true;
external_memory_ = nullptr; mapped_memory_ = nullptr;
if (ExternalDeleter != nullptr) { if (ExternalDeleter != nullptr) {
ExternalDeleter(); ExternalDeleter();
ExternalDeleter = nullptr; ExternalDeleter = nullptr;
...@@ -283,11 +283,11 @@ class DRAGON_API Tensor { ...@@ -283,11 +283,11 @@ class DRAGON_API Tensor {
/*! \brief Return the number of elements counting along the given axes */ /*! \brief Return the number of elements counting along the given axes */
int64_t count(int64_t start, int64_t end) const { int64_t count(int64_t start, int64_t end) const {
int64_t nelements = 1; int64_t num = 1;
for (int64_t i = start; i < end; i++) { for (int64_t i = start; i < end; i++) {
nelements *= dim(i); num *= dim(i);
} }
return nelements; return num;
} }
/*! \brief Return the number of elements counting from the given axis */ /*! \brief Return the number of elements counting from the given axis */
...@@ -302,12 +302,12 @@ class DRAGON_API Tensor { ...@@ -302,12 +302,12 @@ class DRAGON_API Tensor {
/*! \brief Return whether the memory is set */ /*! \brief Return whether the memory is set */
bool has_memory() const { bool has_memory() const {
return internal_memory_ != nullptr || external_memory_ != nullptr; return memory_ != nullptr || mapped_memory_ != nullptr;
} }
/*! \brief Return the memory pointer */ /*! \brief Return the memory pointer */
UnifiedMemory* memory(bool required = false) const { UnifiedMemory* memory(bool required = false, bool owned = false) const {
auto* ptr = own_memory_ ? internal_memory_.get() : external_memory_; auto* ptr = own_memory_ptr_ || owned ? memory_.get() : mapped_memory_;
if (required) CHECK(ptr) << "\nAccess the empty memory."; if (required) CHECK(ptr) << "\nAccess the empty memory.";
return ptr; return ptr;
} }
...@@ -320,15 +320,15 @@ class DRAGON_API Tensor { ...@@ -320,15 +320,15 @@ class DRAGON_API Tensor {
/*! \brief Try to return the raw const data pointer */ /*! \brief Try to return the raw const data pointer */
template <class Context> template <class Context>
const void* const_data_ptr() const { const void* const_data_ptr() const {
TypeId ctx_type = TypeMeta::Id<Context>(); TypeId context_type = TypeMeta::Id<Context>();
if (ctx_type == TypeMeta::Id<CPUContext>()) { if (context_type == TypeMeta::Id<CPUContext>()) {
return memory(true)->cpu_data(nbytes()); return memory(true)->cpu_data(nbytes());
} else if (ctx_type == TypeMeta::Id<CUDAContext>()) { } else if (context_type == TypeMeta::Id<CUDAContext>()) {
return memory(true)->cuda_data(nbytes()); return memory(true)->cuda_data(nbytes());
} else if (ctx_type == TypeMeta::Id<CNMLContext>()) { } else if (context_type == TypeMeta::Id<CNMLContext>()) {
return memory(true)->cnml_data(); return memory(true)->cnml_data();
} else { } else {
LOG(FATAL) << "Unknown memory type."; LOG(FATAL) << "Unsupported context type.";
return nullptr; return nullptr;
} }
} }
...@@ -336,19 +336,19 @@ class DRAGON_API Tensor { ...@@ -336,19 +336,19 @@ class DRAGON_API Tensor {
/*! \brief Try to return the raw mutable data pointer */ /*! \brief Try to return the raw mutable data pointer */
template <class Context> template <class Context>
void mutable_data_ptr(void** data_ptr) { void mutable_data_ptr(void** data_ptr) {
auto* mem = memory(); auto* memory_ptr = memory();
if (!mem) { if (!memory_ptr) {
*data_ptr = nullptr; *data_ptr = nullptr;
} else { } else {
TypeId ctx_type = TypeMeta::Id<Context>(); TypeId context_type = TypeMeta::Id<Context>();
if (ctx_type == TypeMeta::Id<CPUContext>()) { if (context_type == TypeMeta::Id<CPUContext>()) {
*data_ptr = mem->mutable_cpu_data(nbytes()); *data_ptr = memory_ptr->mutable_cpu_data(nbytes());
} else if (ctx_type == TypeMeta::Id<CUDAContext>()) { } else if (context_type == TypeMeta::Id<CUDAContext>()) {
*data_ptr = mem->mutable_cuda_data(nbytes()); *data_ptr = memory_ptr->mutable_cuda_data(nbytes());
} else if (ctx_type == TypeMeta::Id<CNMLContext>()) { } else if (context_type == TypeMeta::Id<CNMLContext>()) {
*data_ptr = mem->mutable_cnml_data(); *data_ptr = memory_ptr->mutable_cnml_data();
} else { } else {
LOG(FATAL) << "Unknown memory type."; LOG(FATAL) << "Unsupported context type.";
} }
} }
} }
...@@ -368,7 +368,7 @@ class DRAGON_API Tensor { ...@@ -368,7 +368,7 @@ class DRAGON_API Tensor {
CHECK_GT(size_, 0) << "\nInvalid tensor size."; CHECK_GT(size_, 0) << "\nInvalid tensor size.";
meta_ = meta; meta_ = meta;
capacity_ = size_ * meta.itemsize(); capacity_ = size_ * meta.itemsize();
internal_memory_.reset(new UnifiedMemory(meta_, capacity_)); memory_.reset(new UnifiedMemory(meta_, capacity_));
mutable_data_ptr<Context>(&data_ptr); mutable_data_ptr<Context>(&data_ptr);
if (meta_.ctor()) meta_.ctor()(data_ptr, size_); if (meta_.ctor()) meta_.ctor()(data_ptr, size_);
return data_ptr; return data_ptr;
...@@ -426,11 +426,11 @@ class DRAGON_API Tensor { ...@@ -426,11 +426,11 @@ class DRAGON_API Tensor {
} }
/*! \brief Set to manage the memory */ /*! \brief Set to manage the memory */
void set_memory(UnifiedMemory* memory) { void set_memory(UnifiedMemory* memory_ptr) {
if (memory != internal_memory_.get()) { if (memory_ptr != memory_.get()) {
internal_memory_.reset(memory); memory_.reset(memory_ptr);
} }
capacity_ = memory->size(); capacity_ = memory_ptr->size();
} }
private: private:
...@@ -449,14 +449,14 @@ class DRAGON_API Tensor { ...@@ -449,14 +449,14 @@ class DRAGON_API Tensor {
/*! \brief The dimensions and strides */ /*! \brief The dimensions and strides */
vec64_t dims_, strides_; vec64_t dims_, strides_;
/*! \brief The internal memory */ /*! \brief The managed memory */
unique_ptr<UnifiedMemory> internal_memory_; unique_ptr<UnifiedMemory> memory_;
/*! \brief The external memory */ /*! \brief The mapped memory */
UnifiedMemory* external_memory_ = nullptr; UnifiedMemory* mapped_memory_ = nullptr;
/*! \brief The external memory indicator */ /*! \brief The ownership of memory pointer */
bool own_memory_ = true; bool own_memory_ptr_ = true;
DISABLE_COPY_AND_ASSIGN(Tensor); DISABLE_COPY_AND_ASSIGN(Tensor);
}; };
......
...@@ -148,14 +148,16 @@ string Workspace::UniqueName( ...@@ -148,14 +148,16 @@ string Workspace::UniqueName(
return name + "_" + str::to(index_map[required_name]++) + suffix; return name + "_" + str::to(index_map[required_name]++) + suffix;
} }
vector<string> Workspace::tensors() const { vector<string> Workspace::tensors(bool external) const {
vector<string> names; vector<string> names;
for (const auto& it : tensor_map_) { for (const auto& it : tensor_map_) {
names.push_back(it.first); names.push_back(it.first);
} }
if (external) {
for (const auto& it : external_tensor_map_) { for (const auto& it : external_tensor_map_) {
names.push_back(it.first); names.push_back(it.first);
} }
}
return names; return names;
} }
......
...@@ -80,7 +80,7 @@ class DRAGON_API Workspace { ...@@ -80,7 +80,7 @@ class DRAGON_API Workspace {
} }
/*! \brief Return the name of cached tensors */ /*! \brief Return the name of cached tensors */
vector<string> tensors() const; vector<string> tensors(bool external = true) const;
/*! \brief Return the name of cached graphs */ /*! \brief Return the name of cached graphs */
vector<string> graphs() const; vector<string> graphs() const;
......
...@@ -27,14 +27,12 @@ class CudaStream { ...@@ -27,14 +27,12 @@ class CudaStream {
#ifdef USE_CUDA #ifdef USE_CUDA
CUDADeviceGuard guard(device_id); CUDADeviceGuard guard(device_id);
CUDA_CHECK(cudaStreamCreateWithFlags(&stream_, cudaStreamNonBlocking)); CUDA_CHECK(cudaStreamCreateWithFlags(&stream_, cudaStreamNonBlocking));
#else
CUDA_NOT_COMPILED;
#endif #endif
} }
~CudaStream() { ~CudaStream() {
#ifdef USE_CUDA #ifdef USE_CUDA
cudaStreamDestroy(stream_); CUDA_CHECK(cudaStreamDestroy(stream_));
#endif #endif
} }
......
...@@ -135,6 +135,24 @@ PYBIND11_MODULE(libdragon_python, m) { ...@@ -135,6 +135,24 @@ PYBIND11_MODULE(libdragon_python, m) {
} }
}) })
/*! \brief Return the size of memory used by tensors on given device */
.def(
"MemoryAllocated",
[](Workspace* self, const string& device_type, int device_id) {
size_t size = 0;
for (const auto& name : self->tensors(false)) {
auto* memory = self->GetTensor(name)->memory(false, true);
if (memory) {
if (device_type == "cpu") {
size += memory->size();
} else {
size += memory->size(device_type, device_id);
}
}
}
return size;
})
/*! \brief Run the operator */ /*! \brief Run the operator */
.def( .def(
"RunOperator", "RunOperator",
...@@ -200,7 +218,7 @@ PYBIND11_MODULE(libdragon_python, m) { ...@@ -200,7 +218,7 @@ PYBIND11_MODULE(libdragon_python, m) {
if (could_be_serialized) { if (could_be_serialized) {
auto msg = string("\n") + def.DebugString(); auto msg = string("\n") + def.DebugString();
msg.pop_back(); msg.pop_back();
PRINT(INFO) << "graph {" << str::replace_all(msg, "\n", "\n ") LOG(INFO) << "\ngraph {" << str::replace_all(msg, "\n", "\n ")
<< "\n}\n"; << "\n}\n";
} }
} }
......
...@@ -9,12 +9,10 @@ template <class Context> ...@@ -9,12 +9,10 @@ template <class Context>
template <typename T> template <typename T>
void DropBlock2dOp<Context>::DoRunWithType() { void DropBlock2dOp<Context>::DoRunWithType() {
auto &X = Input(0), *Y = Output(0, {0}); auto &X = Input(0), *Y = Output(0, {0});
if (phase() == "TEST") { if (phase() == "TEST") {
Y->ReshapeLike(X)->CopyFrom(X, ctx()); Y->ReshapeLike(X)->CopyFrom(X, ctx());
} else if (phase() == "TRAIN") { } else if (phase() == "TRAIN") {
int64_t feat_h, feat_w, seed_h, seed_w; int64_t feat_h, feat_w, seed_h, seed_w;
if (data_format() == "NCHW") { if (data_format() == "NCHW") {
feat_h = X.dim(2), feat_w = X.dim(3); feat_h = X.dim(2), feat_w = X.dim(3);
} else if (data_format() == "NHWC") { } else if (data_format() == "NHWC") {
...@@ -22,11 +20,9 @@ void DropBlock2dOp<Context>::DoRunWithType() { ...@@ -22,11 +20,9 @@ void DropBlock2dOp<Context>::DoRunWithType() {
} else { } else {
LOG(FATAL) << "Unknown DataFormat: " << data_format(); LOG(FATAL) << "Unknown DataFormat: " << data_format();
} }
seed_h = feat_h - block_size_ + 1; seed_h = feat_h - block_size_ + 1;
seed_w = feat_w - block_size_ + 1; seed_w = feat_w - block_size_ + 1;
CHECK(seed_h > 0 && seed_w > 0) << "\nExcepted block_size <= feat_size."; CHECK(seed_h > 0 && seed_w > 0) << "\nExcepted block_size <= feat_size.";
// Schedule the keep ratio // Schedule the keep ratio
auto kp = keep_prob(); auto kp = keep_prob();
if (decrement_ > 0.f && prob_ > kp) { if (decrement_ > 0.f && prob_ > kp) {
...@@ -34,27 +30,23 @@ void DropBlock2dOp<Context>::DoRunWithType() { ...@@ -34,27 +30,23 @@ void DropBlock2dOp<Context>::DoRunWithType() {
} else { } else {
prob_ = kp; // Fixed to the limit value prob_ = kp; // Fixed to the limit value
} }
// Compute the drop ratio
float gamma = (1.f - prob_) / std::pow(block_size_, 2); float gamma = (1.f - prob_) / std::pow(block_size_, 2);
gamma *= (alpha_ * (feat_h * feat_w) / (seed_h * seed_w)); gamma *= (alpha_ * (feat_h * feat_w) / (seed_h * seed_w));
// Prepare buffers
auto* mask = Buffer("mask") auto* mask = Buffer("mask")
->ReshapeLike(X) ->ReshapeLike(X)
->template mutable_data<uint8_t, Context>(); ->template mutable_data<uint8_t, Context>();
auto* scale = Buffer("scale") auto* scale = Buffer("scale")
->Reshape({}) ->Reshape({})
->template mutable_data<float, CPUContext>(); ->template mutable_data<float, CPUContext>();
auto scratches = ws()->template data<Context>({ auto scratches = ws()->template data<Context>({
X.dim(0) * seed_h * seed_w * sizeof(uint32_t), // seed points X.dim(0) * seed_h * seed_w * sizeof(uint32_t), // seed points
X.count() * sizeof(int), // int32 mask for seed growing X.count() * sizeof(int), // int32 mask for seed growing
}); });
// Fill mask with ones
// Fill the mask with ones
math::Set(X.count(), 1, (int*)scratches[1], ctx()); math::Set(X.count(), 1, (int*)scratches[1], ctx());
// Generate 2d mask from the seed region
// Generate 2d mask from seed region
kernel::DropBlock2d( kernel::DropBlock2d(
X.dim(0), X.dim(0),
data_format() == "NCHW" ? X.dim(1) : X.dim(-1), data_format() == "NCHW" ? X.dim(1) : X.dim(-1),
...@@ -68,13 +60,12 @@ void DropBlock2dOp<Context>::DoRunWithType() { ...@@ -68,13 +60,12 @@ void DropBlock2dOp<Context>::DoRunWithType() {
(uint32_t*)scratches[0], (uint32_t*)scratches[0],
(int*)scratches[1], (int*)scratches[1],
ctx()); ctx());
// Convert to uint8 mask
// Convert to uint8 mask for applying
kernel::Cast(X.count(), (int*)scratches[1], mask, ctx()); kernel::Cast(X.count(), (int*)scratches[1], mask, ctx());
// Count the number of zeros to compute scale factor
// Count && Apply
float normalizer = math::Sum(X.count(), 1.f, (int*)scratches[1], ctx()); float normalizer = math::Sum(X.count(), 1.f, (int*)scratches[1], ctx());
scale[0] = (float)X.count() / std::max(normalizer, 1.f); scale[0] = (float)X.count() / std::max(normalizer, 1.f);
// Apply mask to the feature
kernel::ApplyMask( kernel::ApplyMask(
X.count(), X.count(),
scale[0], scale[0],
...@@ -97,7 +88,6 @@ template <class Context> ...@@ -97,7 +88,6 @@ template <class Context>
template <typename T> template <typename T>
void DropBlock2dGradientOp<Context>::DoRunWithType() { void DropBlock2dGradientOp<Context>::DoRunWithType() {
auto &dY = Input(0), *dX = Output(0); auto &dY = Input(0), *dX = Output(0);
if (phase() == "TEST") { if (phase() == "TEST") {
NOT_IMPLEMENTED; NOT_IMPLEMENTED;
} else if (phase() == "TRAIN") { } else if (phase() == "TRAIN") {
......
...@@ -9,7 +9,6 @@ template <class Context> ...@@ -9,7 +9,6 @@ template <class Context>
template <typename T> template <typename T>
void DropPathOp<Context>::DoRunWithType() { void DropPathOp<Context>::DoRunWithType() {
auto &X = Input(0), *Y = Output(0, {0}); auto &X = Input(0), *Y = Output(0, {0});
if (phase() == "TEST") { if (phase() == "TEST") {
Y->ReshapeLike(X)->CopyFrom(X, ctx()); Y->ReshapeLike(X)->CopyFrom(X, ctx());
} else if (phase() == "TRAIN") { } else if (phase() == "TRAIN") {
...@@ -20,20 +19,15 @@ void DropPathOp<Context>::DoRunWithType() { ...@@ -20,20 +19,15 @@ void DropPathOp<Context>::DoRunWithType() {
} else { } else {
drop_prob_ = dp; // Fixed to the limit value drop_prob_ = dp; // Fixed to the limit value
} }
auto* mask = Buffer("mask") auto* mask = Buffer("mask")
->Reshape({X.dim(0)}) ->Reshape({X.dim(0)})
->template mutable_data<float, Context>(); ->template mutable_data<float, Context>();
auto* scale = Buffer("scale") auto* scale = Buffer("scale")
->Reshape({}) ->Reshape({})
->template mutable_data<float, CPUContext>(); ->template mutable_data<float, CPUContext>();
scale[0] = 1.f / (1.f - drop_prob_); scale[0] = 1.f / (1.f - drop_prob_);
// Generate mask for each example // Generate mask for each example
math::RandomUniform(X.dim(0), 0.f, 1.f, mask, ctx()); math::RandomUniform(X.dim(0), 0.f, 1.f, mask, ctx());
// Apply mask to the feature // Apply mask to the feature
kernel::DropPath( kernel::DropPath(
X.dim(0), X.dim(0),
...@@ -57,7 +51,6 @@ template <class Context> ...@@ -57,7 +51,6 @@ template <class Context>
template <typename T> template <typename T>
void DropPathGradientOp<Context>::DoRunWithType() { void DropPathGradientOp<Context>::DoRunWithType() {
auto &dY = Input(0), *dX = Output(0); auto &dY = Input(0), *dX = Output(0);
if (phase() == "TEST") { if (phase() == "TEST") {
NOT_IMPLEMENTED; NOT_IMPLEMENTED;
} else if (phase() == "TRAIN") { } else if (phase() == "TRAIN") {
......
...@@ -9,24 +9,14 @@ template <class Context> ...@@ -9,24 +9,14 @@ template <class Context>
template <typename T> template <typename T>
void DropoutOp<Context>::DoRunWithType() { void DropoutOp<Context>::DoRunWithType() {
auto &X = Input(0), *Y = Output(0, {0}); auto &X = Input(0), *Y = Output(0, {0});
auto scale = use_scale_ ? 1.f / (1.f - prob()) : 1.f;
if (phase() == "TEST") { if (phase() == "TEST") {
Y->ReshapeLike(X)->CopyFrom(X, ctx()); Y->ReshapeLike(X)->CopyFrom(X, ctx());
if (!use_scale_) {
math::Scale(
X.count(),
1.f - prob(),
Y->template data<T, Context>(),
Y->template mutable_data<T, Context>(),
ctx());
}
} else if (phase() == "TRAIN") { } else if (phase() == "TRAIN") {
Buffer("mask")->ReshapeLike(X); Buffer("mask")->ReshapeLike(X);
kernel::Dropout( kernel::Dropout(
X.count(), X.count(),
prob(), prob(),
scale, 1.f / (1.f - prob()),
X.template data<T, Context>(), X.template data<T, Context>(),
Buffer("mask")->template mutable_data<uint8_t, Context>(), Buffer("mask")->template mutable_data<uint8_t, Context>(),
Y->ReshapeLike(X)->template mutable_data<T, Context>(), Y->ReshapeLike(X)->template mutable_data<T, Context>(),
...@@ -46,14 +36,12 @@ template <class Context> ...@@ -46,14 +36,12 @@ template <class Context>
template <typename T> template <typename T>
void DropoutGradientOp<Context>::DoRunWithType() { void DropoutGradientOp<Context>::DoRunWithType() {
auto &dY = Input(0), *dX = Output(0); auto &dY = Input(0), *dX = Output(0);
auto scale = use_scale_ ? 1.f / (1.f - prob()) : 1.f;
if (phase() == "TEST") { if (phase() == "TEST") {
NOT_IMPLEMENTED; NOT_IMPLEMENTED;
} else if (phase() == "TRAIN") { } else if (phase() == "TRAIN") {
kernel::ApplyMask( kernel::ApplyMask(
dY.count(), dY.count(),
scale, 1.f / (1.f - prob()),
dY.template data<T, Context>(), dY.template data<T, Context>(),
Buffer("mask")->template data<uint8_t, Context>(), Buffer("mask")->template data<uint8_t, Context>(),
dX->ReshapeLike(dY)->template mutable_data<T, Context>(), dX->ReshapeLike(dY)->template mutable_data<T, Context>(),
......
...@@ -21,7 +21,7 @@ template <class Context> ...@@ -21,7 +21,7 @@ template <class Context>
class DropoutOp : public Operator<Context> { class DropoutOp : public Operator<Context> {
public: public:
DropoutOp(const OperatorDef& def, Workspace* ws) DropoutOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws), use_scale_(OpArg<bool>("scale", true)) { : Operator<Context>(def, ws) {
GET_ARG_WITH_DESC(float, prob, 0.5f); GET_ARG_WITH_DESC(float, prob, 0.5f);
} }
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
...@@ -32,7 +32,6 @@ class DropoutOp : public Operator<Context> { ...@@ -32,7 +32,6 @@ class DropoutOp : public Operator<Context> {
void DoRunWithType(); void DoRunWithType();
protected: protected:
bool use_scale_;
DECLARE_ARG_WITH_DESC(float, prob); DECLARE_ARG_WITH_DESC(float, prob);
}; };
...@@ -40,7 +39,7 @@ template <class Context> ...@@ -40,7 +39,7 @@ template <class Context>
class DropoutGradientOp : public Operator<Context> { class DropoutGradientOp : public Operator<Context> {
public: public:
DropoutGradientOp(const OperatorDef& def, Workspace* ws) DropoutGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws), use_scale_(OpArg<bool>("scale", true)) { : Operator<Context>(def, ws) {
GET_ARG_WITH_DESC(float, prob, 0.5f); GET_ARG_WITH_DESC(float, prob, 0.5f);
} }
USE_OPERATOR_FUNCTIONS; USE_OPERATOR_FUNCTIONS;
...@@ -51,7 +50,6 @@ class DropoutGradientOp : public Operator<Context> { ...@@ -51,7 +50,6 @@ class DropoutGradientOp : public Operator<Context> {
void DoRunWithType(); void DoRunWithType();
protected: protected:
bool use_scale_;
DECLARE_ARG_WITH_DESC(float, prob); DECLARE_ARG_WITH_DESC(float, prob);
}; };
......
...@@ -11,10 +11,7 @@ template <class Context> ...@@ -11,10 +11,7 @@ template <class Context>
template <typename T> template <typename T>
void CuDNNDropoutOp<Context>::DoRunWithType() { void CuDNNDropoutOp<Context>::DoRunWithType() {
auto &X = Input(0), *Y = Output(0, {0}); auto &X = Input(0), *Y = Output(0, {0});
CHECK(this->use_scale_) << "\nCuDNN only supports the scaled dropout.";
CuDNNSetTensorDesc<T>(&input_desc_, {X.count(), 1, 1, 1}); CuDNNSetTensorDesc<T>(&input_desc_, {X.count(), 1, 1, 1});
if (phase() == "TEST") { if (phase() == "TEST") {
Y->ReshapeLike(X)->CopyFrom(X, ctx()); Y->ReshapeLike(X)->CopyFrom(X, ctx());
} else if (phase() == "TRAIN") { } else if (phase() == "TRAIN") {
...@@ -46,12 +43,9 @@ void CuDNNDropoutOp<Context>::DoRunWithType() { ...@@ -46,12 +43,9 @@ void CuDNNDropoutOp<Context>::DoRunWithType() {
rng_seed_)); rng_seed_));
} }
} }
// Allocate for the reserve space
size_t reserve_size; size_t reserve_size;
CUDNN_CHECK(cudnnDropoutGetReserveSpaceSize(input_desc_, &reserve_size)); CUDNN_CHECK(cudnnDropoutGetReserveSpaceSize(input_desc_, &reserve_size));
auto* X_mask = Buffer("X_mask")->Reshape({(int64_t)reserve_size}); auto* X_mask = Buffer("X_mask")->Reshape({(int64_t)reserve_size});
CUDNN_CHECK(cudnnDropoutForward( CUDNN_CHECK(cudnnDropoutForward(
ctx()->cudnn_handle(), ctx()->cudnn_handle(),
dropout_desc_, dropout_desc_,
...@@ -76,11 +70,9 @@ template <typename T> ...@@ -76,11 +70,9 @@ template <typename T>
void CuDNNDropoutGradientOp<Context>::DoRunWithType() { void CuDNNDropoutGradientOp<Context>::DoRunWithType() {
auto &dY = Input(0), *dX = Output(0); auto &dY = Input(0), *dX = Output(0);
CuDNNSetTensorDesc<T>(&input_desc_, {dY.count(), 1, 1, 1}); CuDNNSetTensorDesc<T>(&input_desc_, {dY.count(), 1, 1, 1});
if (phase() == "TEST") { if (phase() == "TEST") {
NOT_IMPLEMENTED; NOT_IMPLEMENTED;
} else if (phase() == "TRAIN") { } else if (phase() == "TRAIN") {
CHECK(this->use_scale_) << "\nCuDNN only supports the scaled dropout.";
// Initialize the dropout states // Initialize the dropout states
if (!states_initialized_) { if (!states_initialized_) {
states_initialized_ = true; states_initialized_ = true;
...@@ -102,13 +94,12 @@ void CuDNNDropoutGradientOp<Context>::DoRunWithType() { ...@@ -102,13 +94,12 @@ void CuDNNDropoutGradientOp<Context>::DoRunWithType() {
LOG(FATAL) << "Missing dropout states with seed: " << rng_seed_; LOG(FATAL) << "Missing dropout states with seed: " << rng_seed_;
} }
} }
// Check the reserve space // Check the reserve space
size_t reserve_size; size_t reserve_size;
CUDNN_CHECK(cudnnDropoutGetReserveSpaceSize(input_desc_, &reserve_size)); CUDNN_CHECK(cudnnDropoutGetReserveSpaceSize(input_desc_, &reserve_size));
auto* X_mask = Buffer("X_mask"); auto* X_mask = Buffer("X_mask");
CHECK_EQ(X_mask->size(), reserve_size); CHECK_EQ(X_mask->size(), reserve_size);
// Compute the gradient using mask
CUDNN_CHECK(cudnnDropoutBackward( CUDNN_CHECK(cudnnDropoutBackward(
ctx()->cudnn_handle(), ctx()->cudnn_handle(),
dropout_desc_, dropout_desc_,
......
...@@ -21,6 +21,7 @@ from dragon.core.device.cuda import current_device ...@@ -21,6 +21,7 @@ from dragon.core.device.cuda import current_device
from dragon.core.device.cuda import enable_cudnn from dragon.core.device.cuda import enable_cudnn
from dragon.core.device.cuda import get_device_capability from dragon.core.device.cuda import get_device_capability
from dragon.core.device.cuda import is_available from dragon.core.device.cuda import is_available
from dragon.core.device.cuda import memory_allocated
from dragon.core.device.cuda import set_default_device from dragon.core.device.cuda import set_default_device
from dragon.core.device.cuda import set_device from dragon.core.device.cuda import set_device
from dragon.core.device.cuda import synchronize from dragon.core.device.cuda import synchronize
......
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
# <https://opensource.org/licenses/BSD-2-Clause> # <https://opensource.org/licenses/BSD-2-Clause>
# #
# ------------------------------------------------------------ # ------------------------------------------------------------
"""Define the options for autograph utilities.""" """Autograph options."""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
...@@ -52,9 +52,9 @@ def set_optimization(level=1): ...@@ -52,9 +52,9 @@ def set_optimization(level=1):
* level = ``1``: Eliminate the unused outputs and operators. * level = ``1``: Eliminate the unused outputs and operators.
* level = ``2``: Apply inplace to the inputs if available. * level = ``2``: Apply the inplace to inputs if available.
* level = ``3``: Allocate shared buffer for the outputs. * level = ``3``: Allocate the shared buffer to outputs if available.
Parameters Parameters
---------- ----------
......
...@@ -16,6 +16,7 @@ from __future__ import print_function ...@@ -16,6 +16,7 @@ from __future__ import print_function
from dragon import backend from dragon import backend
from dragon.core.framework import config from dragon.core.framework import config
from dragon.core.framework import workspace
class Stream(backend.CudaStream): class Stream(backend.CudaStream):
...@@ -62,14 +63,14 @@ def current_device(): ...@@ -62,14 +63,14 @@ def current_device():
def enable_cudnn(enabled=True, benchmark=False): def enable_cudnn(enabled=True, benchmark=False):
"""Enable the CuDNN library. """Enable backend to use the cuDNN library.
Parameters Parameters
---------- ----------
enabled : bool, optional, default=True enabled : bool, optional, default=True
**True** to enable the CuDNN. Use cuDNN library or not.
benchmark : bool, optional, default=False benchmark : bool, optional, default=False
**True** to select algorithms according to benchmark. Select algorithms according to the benchmark or not.
""" """
return backend.cudaEnableDNN(enabled, benchmark) return backend.cudaEnableDNN(enabled, benchmark)
...@@ -107,6 +108,32 @@ def is_available(): ...@@ -107,6 +108,32 @@ def is_available():
return backend.cudaIsDriverSufficient() return backend.cudaIsDriverSufficient()
def memory_allocated(device_index=None):
"""Return the size of memory used by tensors in current workspace.
If ``device_index`` is **None**, the current device will be selected.
Parameters
----------
device_index : int, optional
The device index.
Returns
-------
int
The total number of allocated bytes.
See Also
--------
`dragon.Workspace.memory_allocated(...)`_
"""
if device_index is None:
device_index = current_device()
current_ws = workspace.get_workspace()
return current_ws.memory_allocated('cuda', device_index)
def set_default_device(device_index=0): def set_default_device(device_index=0):
"""Set the default device. """Set the default device.
......
...@@ -24,6 +24,7 @@ from dragon.core.framework import mapping ...@@ -24,6 +24,7 @@ from dragon.core.framework import mapping
from dragon.core.framework import proto_util from dragon.core.framework import proto_util
from dragon.core.framework import types from dragon.core.framework import types
from dragon.core.proto import dragon_pb2 from dragon.core.proto import dragon_pb2
from dragon.core.util import logging
from dragon.core.util import serialization from dragon.core.util import serialization
from dragon.core.util import tls from dragon.core.util import tls
...@@ -149,7 +150,8 @@ class Workspace(backend.Workspace): ...@@ -149,7 +150,8 @@ class Workspace(backend.Workspace):
""" """
cfg = config.config() cfg = config.config()
if cfg.graph_verbosity == 2: if cfg.graph_verbosity == 2:
print(graph_def) msg = '\n' + str(graph_def)[:-1]
logging.info('\ngraph {' + msg.replace('\n', '\n ') + '\n}\n')
return self.CreateGraph( return self.CreateGraph(
serialization.serialize_proto(graph_def), serialization.serialize_proto(graph_def),
cfg.graph_verbosity == 1) cfg.graph_verbosity == 1)
...@@ -259,6 +261,24 @@ class Workspace(backend.Workspace): ...@@ -259,6 +261,24 @@ class Workspace(backend.Workspace):
""" """
return self.HasTensor(_stringify_object(tensor)) return self.HasTensor(_stringify_object(tensor))
def memory_allocated(self, device_type='cpu', device_index=0):
"""Return the size of memory used by tensors on given device.
Parameters
----------
device_type : str, optional
The device type.
device_index : int, optional
The device index.
Returns
-------
int
The total number of allocated bytes.
"""
return self.MemoryAllocated(device_type, device_index)
def merge_from(self, other): def merge_from(self, other):
"""Merge resources from the other. """Merge resources from the other.
......
...@@ -16,7 +16,6 @@ from __future__ import print_function ...@@ -16,7 +16,6 @@ from __future__ import print_function
from dragon.core.eager import context from dragon.core.eager import context
from dragon.core.framework import ops from dragon.core.framework import ops
from dragon.core.framework import types
from dragon.core.ops import math_ops_lib from dragon.core.ops import math_ops_lib
from dragon.core.ops.utils import OpSchema from dragon.core.ops.utils import OpSchema
from dragon.core.ops.utils import parse_args from dragon.core.ops.utils import parse_args
...@@ -122,10 +121,8 @@ def affine(inputs, axis=1, num_axes=1, **kwargs): ...@@ -122,10 +121,8 @@ def affine(inputs, axis=1, num_axes=1, **kwargs):
op_lib = math_ops_lib.Affine op_lib = math_ops_lib.Affine
if context.executing_eagerly(): if context.executing_eagerly():
return op_lib \ return op_lib \
.instantiate( .instantiate(axis=axis, num_axes=num_axes) \
axis=axis, .apply(inputs)
num_axes=num_axes,
).apply(inputs)
else: else:
return op_lib.blend(**args) return op_lib.blend(**args)
...@@ -568,10 +565,8 @@ def fully_connected(inputs, axis=1, transpose_w=True, **kwargs): ...@@ -568,10 +565,8 @@ def fully_connected(inputs, axis=1, transpose_w=True, **kwargs):
op_lib = math_ops_lib.FullyConnected op_lib = math_ops_lib.FullyConnected
if context.executing_eagerly(): if context.executing_eagerly():
return op_lib \ return op_lib \
.instantiate( .instantiate(axis=axis, transpose_w=transpose_w) \
axis=axis, .apply(inputs)
transpose_w=transpose_w,
).apply(inputs)
else: else:
args.pop('transpose_w') args.pop('transpose_w')
args['transW'] = transpose_w args['transW'] = transpose_w
......
...@@ -25,17 +25,18 @@ from dragon.core.testing.unittest.common_utils import TEST_CUDA ...@@ -25,17 +25,18 @@ from dragon.core.testing.unittest.common_utils import TEST_CUDA
class TestCUDA(unittest.TestCase): class TestCUDA(unittest.TestCase):
"""Test the cuda utilities.""" """Test the cuda utilities."""
@unittest.skipIf(not TEST_CUDA, 'CUDA unavailable')
def test_stream(self): def test_stream(self):
stream = dragon.cuda.Stream(device_index=0) stream = dragon.cuda.Stream(device_index=0)
self.assertGreater(stream.ptr, 0) self.assertGreater(stream.ptr, 0 if TEST_CUDA else -1)
stream.synchronize() stream.synchronize()
dragon.cuda.synchronize() dragon.cuda.synchronize()
@unittest.skipIf(not TEST_CUDA, 'CUDA unavailable') def test_cudnn(self):
dragon.cuda.enable_cudnn()
def test_device(self): def test_device(self):
major, minor = dragon.cuda.get_device_capability(0) major, minor = dragon.cuda.get_device_capability(0)
self.assertGreaterEqual(major, 1) self.assertGreaterEqual(major, 1 if TEST_CUDA else 0)
self.assertGreaterEqual(minor, 0) self.assertGreaterEqual(minor, 0)
dragon.cuda.set_device(0) dragon.cuda.set_device(0)
self.assertEqual(dragon.cuda.current_device(), 0) self.assertEqual(dragon.cuda.current_device(), 0)
......
...@@ -190,6 +190,12 @@ class TestWorkspace(unittest.TestCase): ...@@ -190,6 +190,12 @@ class TestWorkspace(unittest.TestCase):
pass pass
dragon.reset_workspace() dragon.reset_workspace()
def test_memory_allocated(self):
w = dragon.Workspace()
with w.as_default():
_ = w.memory_allocated()
_ = dragon.cuda.memory_allocated()
if __name__ == '__main__': if __name__ == '__main__':
run_tests() run_tests()
Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!