Commit 04fdadb0 by Ting PAN

Add ClearWorkspace interface

1 parent 94863c22
...@@ -48,6 +48,14 @@ class Workspace { ...@@ -48,6 +48,14 @@ class Workspace {
return workspace_map_[ws->name()] = ws; return workspace_map_[ws->name()] = ws;
} }
inline void ClearWorkspace() {
// clear the relationship of avatars
avatar_map_.clear();
// clear the buffers
ResetBuffer("Common", WORKSPACE_COMMON_BUFFER_SIZE);
ResetBuffer("Grad", WORKSPACE_GRAD_BUFFER_SIZE);
}
/******************** Tensor ********************/ /******************** Tensor ********************/
inline string GetTensorName(const string& name) { inline string GetTensorName(const string& name) {
...@@ -159,8 +167,8 @@ class Workspace { ...@@ -159,8 +167,8 @@ class Workspace {
/******************** Buffer ********************/ /******************** Buffer ********************/
void CreateBuffer(string category, int num) { void CreateBuffer(string category, int num) {
CHECK(!buffer_map_.count(category)); if (!buffer_map_.count(category))
buffer_map_[category] = stack<string>(); buffer_map_[category] = stack<string>();
for (int i = 1; i <= num; i++) { for (int i = 1; i <= num; i++) {
string name = "/share/buffer/" + category + "_" + dragon_cast<string, int>(i); string name = "/share/buffer/" + category + "_" + dragon_cast<string, int>(i);
buffer_map_[category].push(name); buffer_map_[category].push(name);
...@@ -179,6 +187,15 @@ class Workspace { ...@@ -179,6 +187,15 @@ class Workspace {
return nullptr; return nullptr;
} }
void ResetBuffer(string category, int num) {
while (!buffer_map_[category].empty()) {
string name = buffer_map_[category].top();
buffer_map_[category].pop();
tensor_map_[name]->Reset();
}
CreateBuffer(category, num);
}
void ReleaseBuffer(Tensor* tensor, void ReleaseBuffer(Tensor* tensor,
string category = "Common", string category = "Common",
bool enforce = false) { bool enforce = false) {
......
...@@ -162,12 +162,12 @@ PyObject* WorkspacesCC(PyObject* self, PyObject* args) { ...@@ -162,12 +162,12 @@ PyObject* WorkspacesCC(PyObject* self, PyObject* args) {
PyObject* ResetWorkspaceCC(PyObject* self, PyObject* args) { PyObject* ResetWorkspaceCC(PyObject* self, PyObject* args) {
char* cname; char* cname;
if (!PyArg_ParseTuple(args, "|s", &cname)) { if (!PyArg_ParseTuple(args, "s", &cname)) {
PyErr_SetString(PyExc_ValueError, "You can only provide a optional name for the new workspace."); PyErr_SetString(PyExc_ValueError, "You should provide a name to locate the workspace.");
return nullptr; return nullptr;
} }
string target_workspace = g_current_workspace; string target_workspace = g_current_workspace;
if (cname) target_workspace = string(cname); if (!string(cname).empty()) target_workspace = string(cname);
CHECK(g_workspaces.count(target_workspace)) CHECK(g_workspaces.count(target_workspace))
<< "\nWorkspace(" << target_workspace << ") does not exist, can not be reset."; << "\nWorkspace(" << target_workspace << ") does not exist, can not be reset.";
LOG(INFO) << "Reset the Workspace(" << target_workspace << ")"; LOG(INFO) << "Reset the Workspace(" << target_workspace << ")";
...@@ -176,6 +176,21 @@ PyObject* ResetWorkspaceCC(PyObject* self, PyObject* args) { ...@@ -176,6 +176,21 @@ PyObject* ResetWorkspaceCC(PyObject* self, PyObject* args) {
Py_RETURN_TRUE; Py_RETURN_TRUE;
} }
PyObject* ClearWorkspaceCC(PyObject* self, PyObject* args) {
char* cname;
if (!PyArg_ParseTuple(args, "s", &cname)) {
PyErr_SetString(PyExc_ValueError, "You should provide a name to locate the workspace.");
return nullptr;
}
string target_workspace = g_current_workspace;
if (!string(cname).empty()) target_workspace = string(cname);
CHECK(g_workspaces.count(target_workspace))
<< "\nWorkspace(" << target_workspace << ") does not exist, can not be reset.";
LOG(INFO) << "Clear the Workspace(" << target_workspace << ")";
g_workspaces[target_workspace]->ClearWorkspace();
Py_RETURN_TRUE;
}
PyObject* TensorsCC(PyObject* self, PyObject* args) { PyObject* TensorsCC(PyObject* self, PyObject* args) {
vector<string> tensor_strings = g_workspace->GetTensors(); vector<string> tensor_strings = g_workspace->GetTensors();
PyObject* list = PyList_New(tensor_strings.size()); PyObject* list = PyList_New(tensor_strings.size());
...@@ -375,6 +390,7 @@ PyMethodDef* GetAllMethods() { ...@@ -375,6 +390,7 @@ PyMethodDef* GetAllMethods() {
PYFUNC(CurrentWorkspaceCC), PYFUNC(CurrentWorkspaceCC),
PYFUNC(WorkspacesCC), PYFUNC(WorkspacesCC),
PYFUNC(ResetWorkspaceCC), PYFUNC(ResetWorkspaceCC),
PYFUNC(ClearWorkspaceCC),
PYFUNC(TensorsCC), PYFUNC(TensorsCC),
PYFUNC(HasTensorCC), PYFUNC(HasTensorCC),
PYFUNC(GetTensorNameCC), PYFUNC(GetTensorNameCC),
......
...@@ -795,7 +795,7 @@ class Tensor(object): ...@@ -795,7 +795,7 @@ class Tensor(object):
# transpose # transpose
output = Tensor.CreateOperator(inputs=self, nout=1, output = Tensor.CreateOperator(inputs=self, nout=1,
op_type='Transpose', perms=perms, **kwargs) op_type='Transpose', perms=perms, **kwargs)
if self.shape is not None: if self.shape is not None:
if len(self.shape) != len(perms): if len(self.shape) != len(perms):
raise ValueError('The ndim of inputs is {}, but perms provide {}'. \ raise ValueError('The ndim of inputs is {}, but perms provide {}'. \
......
...@@ -27,6 +27,8 @@ CURRENT_GRAPH_IDX = 0 ...@@ -27,6 +27,8 @@ CURRENT_GRAPH_IDX = 0
__all__ = [ __all__ = [
'SwitchWorkspace', 'SwitchWorkspace',
'ResetWorkspace',
'ClearWorkspace',
'CreateGraph', 'CreateGraph',
'RunGraph', 'RunGraph',
'HasTensor', 'HasTensor',
...@@ -60,12 +62,12 @@ def _stringify_proto(obj): ...@@ -60,12 +62,12 @@ def _stringify_proto(obj):
else: raise TypeError('object can not be serialized as a string') else: raise TypeError('object can not be serialized as a string')
def SwitchWorkspace(workspace, create_if_missing=True): def SwitchWorkspace(workspace_name, create_if_missing=True):
"""Switch to the specific workspace. """Switch to the specific workspace.
Parameters Parameters
---------- ----------
workspace : str workspace_name : str
The name of the specific workspace. The name of the specific workspace.
create_if_missing : boolean create_if_missing : boolean
Whether to create the specific workspace if it does not exist. Whether to create the specific workspace if it does not exist.
...@@ -79,7 +81,57 @@ def SwitchWorkspace(workspace, create_if_missing=True): ...@@ -79,7 +81,57 @@ def SwitchWorkspace(workspace, create_if_missing=True):
The wrapper of ``SwitchWorkspaceCC``. The wrapper of ``SwitchWorkspaceCC``.
""" """
SwitchWorkspaceCC(workspace, create_if_missing) if workspace_name == '':
raise ValueError('The workspace name should not be empty.')
SwitchWorkspaceCC(workspace_name, create_if_missing)
def ResetWorkspace(workspace_name=''):
"""Reset the specific workspace.
Remove all resources of given workspace.
If workspace name is empty, the current workspace will be modified.
Parameters
----------
workspace_name : str
The name of the specific workspace.
Returns
-------
None
References
----------
The wrapper of ``ResetWorkspaceCC``.
"""
ResetWorkspaceCC(workspace_name)
def ClearWorkspace(workspace_name=''):
"""Clear the specific workspace.
You may need to clear the workspace when sharing grads.
If workspace name is empty, the current workspace will be modified.
Parameters
----------
workspace_name : str
The name of the specific workspace.
Returns
-------
None
References
----------
The wrapper of ``ClearWorkspaceCC``.
"""
ClearWorkspaceCC(workspace_name)
def CreateGraph(meta_graph): def CreateGraph(meta_graph):
...@@ -498,7 +550,10 @@ def Restore(filepath, format='default'): ...@@ -498,7 +550,10 @@ def Restore(filepath, format='default'):
from dragon.config import logger from dragon.config import logger
assert os.path.exists(filepath), 'model of path({}) does not exist.'.format(filepath) assert os.path.exists(filepath), 'model of path({}) does not exist.'.format(filepath)
if format == 'default': if format == 'default':
content = cPickle.load(open(filepath, 'rb')) try:
content = cPickle.load(open(filepath, 'rb'))
except UnicodeDecodeError:
content = cPickle.load(open(filepath, 'rb'), encoding='iso-8859-1')
logger.info('Restore From Model@: ' + filepath) logger.info('Restore From Model@: ' + filepath)
logger.info('Model Format: cPickle') logger.info('Model Format: cPickle')
for key, ndarray in content.items(): for key, ndarray in content.items():
......
...@@ -37,6 +37,8 @@ List Brief ...@@ -37,6 +37,8 @@ List Brief
`Snapshot`_ Snapshot tensors into a binary file. `Snapshot`_ Snapshot tensors into a binary file.
`Restore`_ Restore tensors from a binary file. `Restore`_ Restore tensors from a binary file.
`SwitchWorkspace`_ Switch to the specific Workspace. `SwitchWorkspace`_ Switch to the specific Workspace.
`ResetWorkspace`_ Reset the specific workspace.
`ClearWorkspace`_ Clear the specific workspace.
`LogMetaGraph`_ Log the meta graph. `LogMetaGraph`_ Log the meta graph.
`LogOptimizedGraph`_ Log the optimized graph. `LogOptimizedGraph`_ Log the optimized graph.
`ExportMetaGraph`_ Export the meta graph into a file under specific folder. `ExportMetaGraph`_ Export the meta graph into a file under specific folder.
...@@ -51,6 +53,8 @@ API Reference ...@@ -51,6 +53,8 @@ API Reference
:show-inheritance: :show-inheritance:
.. _SwitchWorkspace: #dragon.core.workspace.SwitchWorkspace .. _SwitchWorkspace: #dragon.core.workspace.SwitchWorkspace
.. _ResetWorkspace: #dragon.core.workspace.ResetWorkspace
.. _ClearWorkspace: #dragon.core.workspace.ClearWorkspace
.. _CreateGraph: #dragon.core.workspace.CreateGraph .. _CreateGraph: #dragon.core.workspace.CreateGraph
.. _HasTensor: #dragon.core.workspace.HasTensor .. _HasTensor: #dragon.core.workspace.HasTensor
.. _GetTensorName: #dragon.core.workspace.GetTensorName .. _GetTensorName: #dragon.core.workspace.GetTensorName
......
...@@ -164,7 +164,6 @@ List Brief ...@@ -164,7 +164,6 @@ List Brief
`Accuracy`_ Calculate the Top-K accuracy. `Accuracy`_ Calculate the Top-K accuracy.
`StopGradient`_ Return the identity of input with truncated gradient flow. `StopGradient`_ Return the identity of input with truncated gradient flow.
`MovingAverage`_ Calculate the moving average. `MovingAverage`_ Calculate the moving average.
`Proposal`_ Generate Regional Proposals, introduced by `[Ren et.al, 2015] <https://arxiv.org/abs/1506.01497>`_.
================= ====================================================================== ================= ======================================================================
Contrib Contrib
......
...@@ -36,7 +36,7 @@ find_packages('dragon') ...@@ -36,7 +36,7 @@ find_packages('dragon')
find_modules() find_modules()
setup(name = 'dragon', setup(name = 'dragon',
version='0.2.1.7', version='0.2.1.8',
description = 'Dragon: A Computation Graph Virtual Machine Based Deep Learning Framework', description = 'Dragon: A Computation Graph Virtual Machine Based Deep Learning Framework',
url='https://github.com/neopenx/Dragon', url='https://github.com/neopenx/Dragon',
author='Ting Pan', author='Ting Pan',
......
Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!