Commit 413dbad0 by Ting PAN

Add the explicit clear method for workspace

Summary:
This commit adds the clear method so that resources can be freed manually
when the workspace instance is referenced circularly and is therefore
reclaimed only by the garbage collector.
1 parent 02ad90d5
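Why the explicit method matters: when a workspace takes part in a reference cycle, CPython reclaims it only during a garbage-collection pass, so large tensor memory may linger long after the last use. A minimal usage sketch; the `Trainer` holder is hypothetical and exists only to form the cycle:

```python
import dragon

class Trainer(object):
    """Hypothetical holder that forms a reference cycle with its workspace."""

    def __init__(self):
        self.workspace = dragon.Workspace()
        self.workspace.owner = self  # the cycle: trainer <-> workspace

trainer = Trainer()
# ... feed tensors and run graphs inside ``trainer.workspace`` ...
trainer.workspace.clear()  # free cached tensors, operators and graphs now
del trainer                # collecting the cycle itself is left to the GC
```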
@@ -127,7 +127,7 @@ class BatchNorm(Layer):
}
self.add_blob(value=0, no_grad=True) # running_mean
self.add_blob(value=1, no_grad=True) # running_var
self.add_blob(value=1, no_grad=True) # running_num_batches
self.add_blob(value=1, no_grad=True) # running_num
self.add_blob(value=1, no_grad=True) # fixed_gamma
self.add_blob(value=0, no_grad=True) # fixed_beta
self._blobs[2]['data'].set_value([1.])
@@ -140,6 +140,17 @@ class BatchNorm(Layer):
self._bias = scale_layer._blobs[1]['data']
scale_layer.__call__ = lambda *args, **kwargs: None
def from_proto(self, proto):
super(BatchNorm, self).from_proto(proto)
current_ws = workspace.get_workspace()
running_num = float(current_ws.fetch_tensor(self._blobs[2]['data']))
if running_num != 1:
running_mean = current_ws.fetch_tensor(self._blobs[0]['data'])
running_var = current_ws.fetch_tensor(self._blobs[1]['data'])
current_ws.feed_tensor(self._blobs[0]['data'], running_mean / running_num)
current_ws.feed_tensor(self._blobs[1]['data'], running_var / running_num)
current_ws.feed_tensor(self._blobs[2]['data'], [1], dtype='float32')
def __call__(self, bottom):
inputs = [bottom, self._weight, self._bias] + \
[blob['data'] for blob in self._blobs[:2]]
......
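The division in `from_proto` undoes the Caffe-style moving-average accumulation: blobs 0 and 1 store `mean * running_num` and `var * running_num`, and blob 2 stores the accumulated factor itself. A toy check with illustrative numbers:

```python
import numpy as np

running_num = 64.0                   # accumulated factor (blob 2)
stored_mean = np.array([6.4, 12.8])  # running_mean * running_num (blob 0)
stored_var = np.array([64.0, 32.0])  # running_var * running_num (blob 1)

running_mean = stored_mean / running_num  # -> [0.1, 0.2]
running_var = stored_var / running_num    # -> [1.0, 0.5]
running_num = 1.0  # reset so the division is not applied twice
```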
@@ -54,14 +54,14 @@ class Data(Layer):
data_param {
source: "/data/train"
batch_size: 128
prefetch: 4
}
image_data_param {
shuffle: true
num_chunks: 0
prefetch: 5
}
transform_param {
mirror: true
random_crop_size: 224
augment_color: true
crop_size: 224
mean_value: 104.00698793
mean_value: 116.66876762
mean_value: 122.67891434
@@ -79,7 +79,6 @@ class Data(Layer):
batch_size: 64
}
transform_param {
resize: 256
crop_size: 224
mean_value: 104.00698793
mean_value: 116.66876762
@@ -92,20 +91,16 @@ class Data(Layer):
def __init__(self, layer_param):
super(Data, self).__init__(layer_param)
param = layer_param.data_param
data_param = layer_param.data_param
image_data_param = layer_param.image_data_param
transform_param = layer_param.transform_param
self.data_args = {
'source': param.source,
'prefetch': param.prefetch,
'shuffle': param.shuffle,
'num_chunks': param.num_chunks,
'batch_size': param.batch_size,
'source': data_param.source,
'batch_size': data_param.batch_size,
'prefetch': data_param.prefetch,
'shuffle': image_data_param.shuffle,
'phase': {0: 'TRAIN', 1: 'TEST'}[int(layer_param.phase)],
'resize': transform_param.resize,
'padding': transform_param.padding,
'crop_size': transform_param.crop_size,
'random_crop_size': transform_param.random_crop_size,
'augment_color': transform_param.augment_color,
'mirror': transform_param.mirror,
}
self.norm_args = {
......
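After the rewrite, `prefetch` is read from `data_param` and `shuffle` from `image_data_param`. A hedged sketch of reading the relocated fields from a parsed `LayerParameter`; the import path of the generated proto module is an assumption:

```python
from google.protobuf import text_format

import caffe_pb2  # assumption: the module generated from this caffe.proto

layer_param = text_format.Parse("""
    name: "data"
    type: "Data"
    data_param { source: "/data/train" batch_size: 128 prefetch: 4 }
    image_data_param { shuffle: true }
    transform_param { mirror: true crop_size: 224 }
""", caffe_pb2.LayerParameter())

print(layer_param.data_param.prefetch)       # 4
print(layer_param.image_data_param.shuffle)  # True
```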
@@ -38,7 +38,6 @@ message Datum {
repeated float float_data = 6;
// If true, data contains an encoded image that needs to be decoded
optional bool encoded = 7 [default = false];
repeated int32 labels = 8;
}
message FillerParameter {
@@ -438,14 +437,6 @@ message TransformationParameter {
optional bool force_color = 6 [default = false];
// Force the decoded image to have 1 color channel.
optional bool force_gray = 7 [default = false];
// Distort the color?
optional bool augment_color = 9 [default = false];
// Target size.
optional uint32 resize = 10 [default = 0];
// Padding size.
optional uint32 padding = 11 [default = 0];
// Crop size during scale jittering
optional uint32 random_crop_size = 12 [default = 0];
}
// Message that stores parameters shared by loss layers
@@ -676,11 +667,7 @@ message DataParameter {
optional bool force_encoded_color = 9 [default = false];
// Prefetch queue (Number of batches to prefetch to host memory, increase if
// data access bandwidth varies).
optional uint32 prefetch = 10 [default = 5];
// Whether to shuffle the data.
optional bool shuffle = 11 [default = false];
// The number of chunks to shuffle.
optional int32 num_chunks = 12 [default = 2048];
optional uint32 prefetch = 10 [default = 4];
}
message DropoutParameter {
......
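One consequence of dropping the fields: prototxt files that still set them will no longer parse. A hedged illustration, with the same import-path assumption as above:

```python
from google.protobuf import text_format

import caffe_pb2  # assumption: the module generated from this caffe.proto

try:
    text_format.Parse('shuffle: true', caffe_pb2.DataParameter())
except text_format.ParseError as e:
    print(e)  # e.g. '... "DataParameter" has no field named "shuffle".'
```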
@@ -14,6 +14,10 @@ as_default
##########
.. automethod:: dragon.Workspace.as_default
clear
#####
.. automethod:: dragon.Workspace.clear
feed_tensor
###########
.. automethod:: dragon.Workspace.feed_tensor
......
@@ -29,6 +29,23 @@ void Workspace::MergeFrom(Workspace* other) {
}
}
void Workspace::Clear() {
// The following resources usually hold large memory blocks.
// It's necessary to clear them manually if the workspace is
// referenced circularly and kept alive by the frontend GC.
graph_map_.clear();
operator_map_.clear();
for (const auto& it : tensor_map_) {
// The tensor pointer may still be referenced by the frontend.
// Reset the memory only, to avoid dangling pointers.
it.second->Reset();
}
// Reinitialize the tensor flags
GetTensor("/share/flag/recomputing")
->Reshape({})
->mutable_data<bool, CPUContext>()[0] = false;
}
Tensor* Workspace::TryGetTensor(const string& name, bool external) const {
// Check the alias first
const auto& alias_it = alias_map_.find(name);
......
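Seen from the frontend, `Clear` drops graphs and operators and resets tensor memory, while the `Tensor` objects themselves stay valid, so existing Python references do not dangle. A hedged sketch; whether `feed_tensor` accepts a plain name string is an assumption:

```python
import numpy as np
import dragon

ws = dragon.Workspace()
ws.feed_tensor('x', np.ones((1024, 1024), 'float32'))  # assumption: name string
ws.clear()
# The tensor named 'x' still resolves, but its memory has been reset;
# feed it again before fetching.
```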
@@ -25,6 +25,9 @@ class Workspace {
/*! \brief Merge resources from another workspace */
DRAGON_API void MergeFrom(Workspace*);
/*! \brief Clear the cached resources */
DRAGON_API void Clear();
/*! \brief Return a unique name */
DRAGON_API string UniqueName(
const string& name,
......
@@ -47,6 +47,9 @@ PYBIND11_MODULE(libdragon_python, m) {
/*! \brief Merge resources from another workspace */
.def("MergeFrom", &Workspace::MergeFrom)
/*! \brief Clear the cached resources */
.def("Clear", &Workspace::Clear)
/*! \brief Return a unique name */
.def("UniqueName", &Workspace::UniqueName)
@@ -73,7 +76,7 @@ PYBIND11_MODULE(libdragon_python, m) {
}
return self->CreateTensor(name);
},
py::return_value_policy::reference_internal)
py::return_value_policy::reference)
/*! \brief Return the tensor */
.def(
@@ -81,7 +84,7 @@ PYBIND11_MODULE(libdragon_python, m) {
[](Workspace* self, const string& name) {
return self->TryGetTensor(name);
},
py::return_value_policy::reference_internal)
py::return_value_policy::reference)
/*! \brief Register an alias for the name */
.def(
......
@@ -157,7 +157,7 @@ void RegisterModule(py::module& m) {
self->raw_mutable_data<CPUContext>();
return self;
},
py::return_value_policy::reference_internal)
py::return_value_policy::reference)
/*! \brief Construct from an external pointer */
.def(
@@ -192,7 +192,7 @@ void RegisterModule(py::module& m) {
if (self->ExternalDeleter) self->ExternalDeleter();
return self;
},
py::return_value_policy::reference_internal)
py::return_value_policy::reference)
/*! \brief Construct from a numpy array */
.def(
@@ -219,7 +219,7 @@ void RegisterModule(py::module& m) {
self->ExternalDeleter = [array]() -> void { Py_XDECREF(array); };
return self;
},
py::return_value_policy::reference_internal)
py::return_value_policy::reference)
/*! \brief Construct from a dlpack tensor */
.def(
@@ -227,7 +227,7 @@ void RegisterModule(py::module& m) {
[](Tensor* self, py::object object) {
return DLPackWrapper(self).From(object);
},
py::return_value_policy::reference_internal)
py::return_value_policy::reference)
/*! \brief Switch memory to the cpu context */
.def(
......
@@ -120,6 +120,20 @@ class Workspace(backend.Workspace):
"""
return _GLOBAL_DEFAULT_WORKSPACE_STACK.get_controller(self)
def clear(self):
"""Clear the cached tensors, operators and graphs.
Call this method before deleting to free cached resources certainly:
```python
my_workspace = dragon.Workspace()
my_workspace.clear()
del my_workspace
```
"""
self.Clear()
def create_graph(self, graph_def):
"""Create the graph.
@@ -425,7 +439,7 @@ def reset_workspace():
"""Reset the current default workspace."""
if not _GLOBAL_DEFAULT_WORKSPACE_STACK.is_cleared():
raise AssertionError(
"Do not use reset_default() to clear "
"Do not use reset_workspace() to clear "
"nested workspaces.\nIf you need a cleared workspace, "
"exit the nesting and create a new workspace.")
_GLOBAL_DEFAULT_WORKSPACE_STACK.reset()
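For reference, the guarded misuse looks like this; the sketch assumes `reset_workspace` is exported at the package level:

```python
import dragon

with dragon.Workspace().as_default():
    dragon.reset_workspace()  # raises AssertionError: nested workspace
```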
@@ -457,6 +471,8 @@ class _DefaultWorkspaceStack(tls.Stack):
def reset(self):
super(_DefaultWorkspaceStack, self).reset()
if self._global_default_workspace is not None:
self._global_default_workspace.clear()
self._global_default_workspace = None
@contextlib.contextmanager
......
@@ -102,9 +102,7 @@ class DataIterator(object):
The optional running phase.
batch_size : int, optional, default=128
The size of a mini-batch.
partition : bool, optional, default=False
Whether to partition batch for parallelism.
prefetch : int, optional, default=5
prefetch : int, optional, default=4
The prefetch count.
num_transformers : int, optional, default=-1
The number of transformers to process images.
@@ -122,12 +120,10 @@ class DataIterator(object):
rank = distributed.get_rank(process_group)
# Configuration.
self._prefetch = kwargs.get('prefetch', 5)
self._prefetch = kwargs.get('prefetch', 4)
self._num_readers = kwargs.get('num_readers', 1)
self._num_transformers = kwargs.get('num_transformers', -1)
self._batch_size = kwargs.get('batch_size', 128)
if kwargs.get('partition', False):
self._batch_size //= group_size
self.daemon = True
# IO-aware policy.
......
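A hedged construction sketch reflecting the new default; `dataset`, `source` and `shuffle` are assumptions based on the parameters documented above:

```python
iterator = DataIterator(
    dataset=MyDataset,   # hypothetical dataset class
    source='/data/train',
    batch_size=128,      # now always the per-rank batch size
    prefetch=4,          # the new default, was 5
    shuffle=True,
)
```

With the `partition` switch removed, the iterator no longer divides `batch_size` by the process-group size; distributed callers presumably pass the per-rank batch size directly.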