Commit 413dbad0 by Ting PAN

Add the explicit clear method for workspace

Summary:
This commit adds a clear method to free resources manually
when the workspace instance is kept alive by a circular reference.
1 parent 02ad90d5
...@@ -127,7 +127,7 @@ class BatchNorm(Layer): ...@@ -127,7 +127,7 @@ class BatchNorm(Layer):
} }
self.add_blob(value=0, no_grad=True) # running_mean self.add_blob(value=0, no_grad=True) # running_mean
self.add_blob(value=1, no_grad=True) # running_var self.add_blob(value=1, no_grad=True) # running_var
self.add_blob(value=1, no_grad=True) # running_num_batches self.add_blob(value=1, no_grad=True) # running_num
self.add_blob(value=1, no_grad=True) # fixed_gamma self.add_blob(value=1, no_grad=True) # fixed_gamma
self.add_blob(value=0, no_grad=True) # fixed_beta self.add_blob(value=0, no_grad=True) # fixed_beta
self._blobs[2]['data'].set_value([1.]) self._blobs[2]['data'].set_value([1.])
...@@ -140,6 +140,17 @@ class BatchNorm(Layer): ...@@ -140,6 +140,17 @@ class BatchNorm(Layer):
self._bias = scale_layer._blobs[1]['data'] self._bias = scale_layer._blobs[1]['data']
scale_layer.__call__ = lambda *args, **kwargs: None scale_layer.__call__ = lambda *args, **kwargs: None
def from_proto(self, proto):
    """Deserialize the layer parameters from a caffe proto.

    Blob 2 holds the batch counter. When it is not 1, the stored
    running mean/var are accumulated sums, so divide them by the
    counter to recover plain averages and reset the counter to 1.
    """
    super(BatchNorm, self).from_proto(proto)
    ws = workspace.get_workspace()
    mean_ref, var_ref, num_ref = [blob['data'] for blob in self._blobs[:3]]
    num_batches = float(ws.fetch_tensor(num_ref))
    if num_batches != 1:
        ws.feed_tensor(mean_ref, ws.fetch_tensor(mean_ref) / num_batches)
        ws.feed_tensor(var_ref, ws.fetch_tensor(var_ref) / num_batches)
        ws.feed_tensor(num_ref, [1], dtype='float32')
def __call__(self, bottom): def __call__(self, bottom):
inputs = [bottom, self._weight, self._bias] + \ inputs = [bottom, self._weight, self._bias] + \
[blob['data'] for blob in self._blobs[:2]] [blob['data'] for blob in self._blobs[:2]]
......
...@@ -54,14 +54,14 @@ class Data(Layer): ...@@ -54,14 +54,14 @@ class Data(Layer):
data_param { data_param {
source: "/data/train" source: "/data/train"
batch_size: 128 batch_size: 128
prefetch: 4
}
image_data_param {
shuffle: true shuffle: true
num_chunks: 0
prefetch: 5
} }
transform_param { transform_param {
mirror: true mirror: true
random_crop_size: 224 crop_size: 224
augment_color: true
mean_value: 104.00698793 mean_value: 104.00698793
mean_value: 116.66876762 mean_value: 116.66876762
mean_value: 122.67891434 mean_value: 122.67891434
...@@ -79,7 +79,6 @@ class Data(Layer): ...@@ -79,7 +79,6 @@ class Data(Layer):
batch_size: 64 batch_size: 64
} }
transform_param { transform_param {
resize: 256
crop_size: 224 crop_size: 224
mean_value: 104.00698793 mean_value: 104.00698793
mean_value: 116.66876762 mean_value: 116.66876762
...@@ -92,20 +91,16 @@ class Data(Layer): ...@@ -92,20 +91,16 @@ class Data(Layer):
def __init__(self, layer_param): def __init__(self, layer_param):
super(Data, self).__init__(layer_param) super(Data, self).__init__(layer_param)
param = layer_param.data_param data_param = layer_param.data_param
image_data_param = layer_param.image_data_param
transform_param = layer_param.transform_param transform_param = layer_param.transform_param
self.data_args = { self.data_args = {
'source': param.source, 'source': data_param.source,
'prefetch': param.prefetch, 'batch_size': data_param.batch_size,
'shuffle': param.shuffle, 'prefetch': data_param.prefetch,
'num_chunks': param.num_chunks, 'shuffle': image_data_param.shuffle,
'batch_size': param.batch_size,
'phase': {0: 'TRAIN', 1: 'TEST'}[int(layer_param.phase)], 'phase': {0: 'TRAIN', 1: 'TEST'}[int(layer_param.phase)],
'resize': transform_param.resize,
'padding': transform_param.padding,
'crop_size': transform_param.crop_size, 'crop_size': transform_param.crop_size,
'random_crop_size': transform_param.random_crop_size,
'augment_color': transform_param.augment_color,
'mirror': transform_param.mirror, 'mirror': transform_param.mirror,
} }
self.norm_args = { self.norm_args = {
......
...@@ -38,7 +38,6 @@ message Datum { ...@@ -38,7 +38,6 @@ message Datum {
repeated float float_data = 6; repeated float float_data = 6;
// If true data contains an encoded image that need to be decoded // If true data contains an encoded image that need to be decoded
optional bool encoded = 7 [default = false]; optional bool encoded = 7 [default = false];
repeated int32 labels = 8;
} }
message FillerParameter { message FillerParameter {
...@@ -438,14 +437,6 @@ message TransformationParameter { ...@@ -438,14 +437,6 @@ message TransformationParameter {
optional bool force_color = 6 [default = false]; optional bool force_color = 6 [default = false];
// Force the decoded image to have 1 color channels. // Force the decoded image to have 1 color channels.
optional bool force_gray = 7 [default = false]; optional bool force_gray = 7 [default = false];
// Distort the color?
optional bool augment_color = 9 [default = false];
// Target size.
optional uint32 resize = 10 [default = 0];
// Padding size.
optional uint32 padding = 11 [default = 0];
// Crop size during scale jittering
optional uint32 random_crop_size = 12 [default = 0];
} }
// Message that stores parameters shared by loss layers // Message that stores parameters shared by loss layers
...@@ -676,11 +667,7 @@ message DataParameter { ...@@ -676,11 +667,7 @@ message DataParameter {
optional bool force_encoded_color = 9 [default = false]; optional bool force_encoded_color = 9 [default = false];
// Prefetch queue (Number of batches to prefetch to host memory, increase if // Prefetch queue (Number of batches to prefetch to host memory, increase if
// data access bandwidth varies). // data access bandwidth varies).
optional uint32 prefetch = 10 [default = 5]; optional uint32 prefetch = 10 [default = 4];
// Whether to shuffle the data.
optional bool shuffle = 11 [default = false];
// The number of chunks to shuffle.
optional int32 num_chunks = 12 [default = 2048];
} }
message DropoutParameter { message DropoutParameter {
......
...@@ -14,6 +14,10 @@ as_default ...@@ -14,6 +14,10 @@ as_default
########## ##########
.. automethod:: dragon.Workspace.as_default .. automethod:: dragon.Workspace.as_default
clear
#####
.. automethod:: dragon.Workspace.clear
feed_tensor feed_tensor
########### ###########
.. automethod:: dragon.Workspace.feed_tensor .. automethod:: dragon.Workspace.feed_tensor
......
...@@ -29,6 +29,23 @@ void Workspace::MergeFrom(Workspace* other) { ...@@ -29,6 +29,23 @@ void Workspace::MergeFrom(Workspace* other) {
} }
} }
void Workspace::Clear() {
  // Release the cached graphs, operators and tensor memory.
  // These resources usually hold large memory blobs, so they are
  // cleared manually here in case the workspace object itself is kept
  // alive by a circular reference in the frontend garbage collector.
  graph_map_.clear();
  operator_map_.clear();
  for (const auto& it : tensor_map_) {
    // The tensor pointer may still be referenced by the frontend.
    // Reset the backing memory only, to avoid a dangling pointer.
    it.second->Reset();
  }
  // Reinitialize the shared flag tensor to its default (false) state.
  GetTensor("/share/flag/recomputing")
      ->Reshape({})
      ->mutable_data<bool, CPUContext>()[0] = false;
}
Tensor* Workspace::TryGetTensor(const string& name, bool external) const { Tensor* Workspace::TryGetTensor(const string& name, bool external) const {
// Check the alias firstly // Check the alias firstly
const auto& alias_it = alias_map_.find(name); const auto& alias_it = alias_map_.find(name);
......
...@@ -25,6 +25,9 @@ class Workspace { ...@@ -25,6 +25,9 @@ class Workspace {
/*! \brief Merge resources from other */ /*! \brief Merge resources from other */
DRAGON_API void MergeFrom(Workspace*); DRAGON_API void MergeFrom(Workspace*);
/*! \brief Clear the cached resources */
DRAGON_API void Clear();
/* \brief Return an unique name */ /* \brief Return an unique name */
DRAGON_API string UniqueName( DRAGON_API string UniqueName(
const string& name, const string& name,
......
...@@ -47,6 +47,9 @@ PYBIND11_MODULE(libdragon_python, m) { ...@@ -47,6 +47,9 @@ PYBIND11_MODULE(libdragon_python, m) {
/*! \brief Merge resources from another workspace */ /*! \brief Merge resources from another workspace */
.def("MergeFrom", &Workspace::MergeFrom) .def("MergeFrom", &Workspace::MergeFrom)
/*! \brief Clear the cached resources */
.def("Clear", &Workspace::Clear)
/*! \brief Return an unique name */ /*! \brief Return an unique name */
.def("UniqueName", &Workspace::UniqueName) .def("UniqueName", &Workspace::UniqueName)
...@@ -73,7 +76,7 @@ PYBIND11_MODULE(libdragon_python, m) { ...@@ -73,7 +76,7 @@ PYBIND11_MODULE(libdragon_python, m) {
} }
return self->CreateTensor(name); return self->CreateTensor(name);
}, },
py::return_value_policy::reference_internal) py::return_value_policy::reference)
/*! \brief Return the tensor */ /*! \brief Return the tensor */
.def( .def(
...@@ -81,7 +84,7 @@ PYBIND11_MODULE(libdragon_python, m) { ...@@ -81,7 +84,7 @@ PYBIND11_MODULE(libdragon_python, m) {
[](Workspace* self, const string& name) { [](Workspace* self, const string& name) {
return self->TryGetTensor(name); return self->TryGetTensor(name);
}, },
py::return_value_policy::reference_internal) py::return_value_policy::reference)
/* \brief Register an alias for the name */ /* \brief Register an alias for the name */
.def( .def(
......
...@@ -157,7 +157,7 @@ void RegisterModule(py::module& m) { ...@@ -157,7 +157,7 @@ void RegisterModule(py::module& m) {
self->raw_mutable_data<CPUContext>(); self->raw_mutable_data<CPUContext>();
return self; return self;
}, },
py::return_value_policy::reference_internal) py::return_value_policy::reference)
/*! \brief Construct from an external pointer */ /*! \brief Construct from an external pointer */
.def( .def(
...@@ -192,7 +192,7 @@ void RegisterModule(py::module& m) { ...@@ -192,7 +192,7 @@ void RegisterModule(py::module& m) {
if (self->ExternalDeleter) self->ExternalDeleter(); if (self->ExternalDeleter) self->ExternalDeleter();
return self; return self;
}, },
py::return_value_policy::reference_internal) py::return_value_policy::reference)
/*! \brief Construct from a numpy array */ /*! \brief Construct from a numpy array */
.def( .def(
...@@ -219,7 +219,7 @@ void RegisterModule(py::module& m) { ...@@ -219,7 +219,7 @@ void RegisterModule(py::module& m) {
self->ExternalDeleter = [array]() -> void { Py_XDECREF(array); }; self->ExternalDeleter = [array]() -> void { Py_XDECREF(array); };
return self; return self;
}, },
py::return_value_policy::reference_internal) py::return_value_policy::reference)
/*! \brief Construct from a dlpack tensor */ /*! \brief Construct from a dlpack tensor */
.def( .def(
...@@ -227,7 +227,7 @@ void RegisterModule(py::module& m) { ...@@ -227,7 +227,7 @@ void RegisterModule(py::module& m) {
[](Tensor* self, py::object object) { [](Tensor* self, py::object object) {
return DLPackWrapper(self).From(object); return DLPackWrapper(self).From(object);
}, },
py::return_value_policy::reference_internal) py::return_value_policy::reference)
/*! \brief Switch memory to the cpu context */ /*! \brief Switch memory to the cpu context */
.def( .def(
......
...@@ -120,6 +120,20 @@ class Workspace(backend.Workspace): ...@@ -120,6 +120,20 @@ class Workspace(backend.Workspace):
""" """
return _GLOBAL_DEFAULT_WORKSPACE_STACK.get_controller(self) return _GLOBAL_DEFAULT_WORKSPACE_STACK.get_controller(self)
def clear(self):
    """Clear the cached tensors, operators and graphs.

    Call this method before deletion to make sure the cached
    resources are actually released:

    ```python
    my_workspace = dragon.Workspace()
    my_workspace.clear()
    del my_workspace
    ```

    """
    # Delegate to the C++ backend (pybind-exported Workspace::Clear).
    self.Clear()
def create_graph(self, graph_def): def create_graph(self, graph_def):
"""Create the graph. """Create the graph.
...@@ -425,7 +439,7 @@ def reset_workspace(): ...@@ -425,7 +439,7 @@ def reset_workspace():
"""Reset the current default workspace.""" """Reset the current default workspace."""
if not _GLOBAL_DEFAULT_WORKSPACE_STACK.is_cleared(): if not _GLOBAL_DEFAULT_WORKSPACE_STACK.is_cleared():
raise AssertionError( raise AssertionError(
"Do not use reset_default() to clear " "Do not use reset_workspace() to clear "
"nested workspaces.\nIf you need a cleared workspace, " "nested workspaces.\nIf you need a cleared workspace, "
"exit the nesting and create a new workspace.") "exit the nesting and create a new workspace.")
_GLOBAL_DEFAULT_WORKSPACE_STACK.reset() _GLOBAL_DEFAULT_WORKSPACE_STACK.reset()
...@@ -457,6 +471,8 @@ class _DefaultWorkspaceStack(tls.Stack): ...@@ -457,6 +471,8 @@ class _DefaultWorkspaceStack(tls.Stack):
def reset(self): def reset(self):
super(_DefaultWorkspaceStack, self).reset() super(_DefaultWorkspaceStack, self).reset()
if self._global_default_workspace is not None:
self._global_default_workspace.clear()
self._global_default_workspace = None self._global_default_workspace = None
@contextlib.contextmanager @contextlib.contextmanager
......
...@@ -102,9 +102,7 @@ class DataIterator(object): ...@@ -102,9 +102,7 @@ class DataIterator(object):
The optional running phase. The optional running phase.
batch_size : int, optional, default=128 batch_size : int, optional, default=128
The size of a mini-batch. The size of a mini-batch.
partition : bool, optional, default=False prefetch : int, optional, default=4
Whether to partition batch for parallelism.
prefetch : int, optional, default=5
The prefetch count. The prefetch count.
num_transformers : int, optional, default=-1 num_transformers : int, optional, default=-1
The number of transformers to process image. The number of transformers to process image.
...@@ -122,12 +120,10 @@ class DataIterator(object): ...@@ -122,12 +120,10 @@ class DataIterator(object):
rank = distributed.get_rank(process_group) rank = distributed.get_rank(process_group)
# Configuration. # Configuration.
self._prefetch = kwargs.get('prefetch', 5) self._prefetch = kwargs.get('prefetch', 4)
self._num_readers = kwargs.get('num_readers', 1) self._num_readers = kwargs.get('num_readers', 1)
self._num_transformers = kwargs.get('num_transformers', -1) self._num_transformers = kwargs.get('num_transformers', -1)
self._batch_size = kwargs.get('batch_size', 128) self._batch_size = kwargs.get('batch_size', 128)
if kwargs.get('partition', False):
self._batch_size //= group_size
self.daemon = True self.daemon = True
# Io-Aware Policy. # Io-Aware Policy.
......
Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!