Commit 04fdadb0 by Ting PAN

Add ClearWorkspace interface

1 parent 94863c22
......@@ -48,6 +48,14 @@ class Workspace {
return workspace_map_[ws->name()] = ws;
}
inline void ClearWorkspace() {
// clear the relationship of avatars
avatar_map_.clear();
// clear the buffers
ResetBuffer("Common", WORKSPACE_COMMON_BUFFER_SIZE);
ResetBuffer("Grad", WORKSPACE_GRAD_BUFFER_SIZE);
}
/******************** Tensor ********************/
inline string GetTensorName(const string& name) {
......@@ -159,8 +167,8 @@ class Workspace {
/******************** Buffer ********************/
void CreateBuffer(string category, int num) {
CHECK(!buffer_map_.count(category));
buffer_map_[category] = stack<string>();
if (!buffer_map_.count(category))
buffer_map_[category] = stack<string>();
for (int i = 1; i <= num; i++) {
string name = "/share/buffer/" + category + "_" + dragon_cast<string, int>(i);
buffer_map_[category].push(name);
......@@ -179,6 +187,15 @@ class Workspace {
return nullptr;
}
void ResetBuffer(string category, int num) {
while (!buffer_map_[category].empty()) {
string name = buffer_map_[category].top();
buffer_map_[category].pop();
tensor_map_[name]->Reset();
}
CreateBuffer(category, num);
}
void ReleaseBuffer(Tensor* tensor,
string category = "Common",
bool enforce = false) {
......
......@@ -162,12 +162,12 @@ PyObject* WorkspacesCC(PyObject* self, PyObject* args) {
PyObject* ResetWorkspaceCC(PyObject* self, PyObject* args) {
char* cname;
if (!PyArg_ParseTuple(args, "|s", &cname)) {
PyErr_SetString(PyExc_ValueError, "You can only provide a optional name for the new workspace.");
if (!PyArg_ParseTuple(args, "s", &cname)) {
PyErr_SetString(PyExc_ValueError, "You should provide a name to locate the workspace.");
return nullptr;
}
string target_workspace = g_current_workspace;
if (cname) target_workspace = string(cname);
if (!string(cname).empty()) target_workspace = string(cname);
CHECK(g_workspaces.count(target_workspace))
<< "\nWorkspace(" << target_workspace << ") does not exist, can not be reset.";
LOG(INFO) << "Reset the Workspace(" << target_workspace << ")";
......@@ -176,6 +176,21 @@ PyObject* ResetWorkspaceCC(PyObject* self, PyObject* args) {
Py_RETURN_TRUE;
}
PyObject* ClearWorkspaceCC(PyObject* self, PyObject* args) {
char* cname;
if (!PyArg_ParseTuple(args, "s", &cname)) {
PyErr_SetString(PyExc_ValueError, "You should provide a name to locate the workspace.");
return nullptr;
}
string target_workspace = g_current_workspace;
if (!string(cname).empty()) target_workspace = string(cname);
CHECK(g_workspaces.count(target_workspace))
<< "\nWorkspace(" << target_workspace << ") does not exist, can not be reset.";
LOG(INFO) << "Clear the Workspace(" << target_workspace << ")";
g_workspaces[target_workspace]->ClearWorkspace();
Py_RETURN_TRUE;
}
PyObject* TensorsCC(PyObject* self, PyObject* args) {
vector<string> tensor_strings = g_workspace->GetTensors();
PyObject* list = PyList_New(tensor_strings.size());
......@@ -375,6 +390,7 @@ PyMethodDef* GetAllMethods() {
PYFUNC(CurrentWorkspaceCC),
PYFUNC(WorkspacesCC),
PYFUNC(ResetWorkspaceCC),
PYFUNC(ClearWorkspaceCC),
PYFUNC(TensorsCC),
PYFUNC(HasTensorCC),
PYFUNC(GetTensorNameCC),
......
......@@ -795,7 +795,7 @@ class Tensor(object):
# transpose
output = Tensor.CreateOperator(inputs=self, nout=1,
op_type='Transpose', perms=perms, **kwargs)
op_type='Transpose', perms=perms, **kwargs)
if self.shape is not None:
if len(self.shape) != len(perms):
raise ValueError('The ndim of inputs is {}, but perms provide {}'. \
......
......@@ -27,6 +27,8 @@ CURRENT_GRAPH_IDX = 0
__all__ = [
'SwitchWorkspace',
'ResetWorkspace',
'ClearWorkspace',
'CreateGraph',
'RunGraph',
'HasTensor',
......@@ -60,12 +62,12 @@ def _stringify_proto(obj):
else: raise TypeError('object can not be serialized as a string')
def SwitchWorkspace(workspace, create_if_missing=True):
def SwitchWorkspace(workspace_name, create_if_missing=True):
"""Switch to the specific workspace.
Parameters
----------
workspace : str
workspace_name : str
The name of the specific workspace.
create_if_missing : boolean
Whether to create the specific workspace if it does not exist.
......@@ -79,7 +81,57 @@ def SwitchWorkspace(workspace, create_if_missing=True):
The wrapper of ``SwitchWorkspaceCC``.
"""
SwitchWorkspaceCC(workspace, create_if_missing)
if workspace_name == '':
raise ValueError('The workspace name should not be empty.')
SwitchWorkspaceCC(workspace_name, create_if_missing)
def ResetWorkspace(workspace_name=''):
"""Reset the specific workspace.
Remove all resources of given workspace.
If workspace name is empty, the current workspace will be modified.
Parameters
----------
workspace_name : str
The name of the specific workspace.
Returns
-------
None
References
----------
The wrapper of ``ResetWorkspaceCC``.
"""
ResetWorkspaceCC(workspace_name)
def ClearWorkspace(workspace_name=''):
"""Clear the specific workspace.
You may need to clear the workspace when sharing grads.
If workspace name is empty, the current workspace will be modified.
Parameters
----------
workspace_name : str
The name of the specific workspace.
Returns
-------
None
References
----------
The wrapper of ``ClearWorkspaceCC``.
"""
ClearWorkspaceCC(workspace_name)
def CreateGraph(meta_graph):
......@@ -498,7 +550,10 @@ def Restore(filepath, format='default'):
from dragon.config import logger
assert os.path.exists(filepath), 'model of path({}) does not exist.'.format(filepath)
if format == 'default':
content = cPickle.load(open(filepath, 'rb'))
try:
content = cPickle.load(open(filepath, 'rb'))
except UnicodeDecodeError:
content = cPickle.load(open(filepath, 'rb'), encoding='iso-8859-1')
logger.info('Restore From Model@: ' + filepath)
logger.info('Model Format: cPickle')
for key, ndarray in content.items():
......
......@@ -37,6 +37,8 @@ List Brief
`Snapshot`_ Snapshot tensors into a binary file.
`Restore`_ Restore tensors from a binary file.
`SwitchWorkspace`_ Switch to the specific Workspace.
`ResetWorkspace`_ Reset the specific workspace.
`ClearWorkspace`_ Clear the specific workspace.
`LogMetaGraph`_ Log the meta graph.
`LogOptimizedGraph`_ Log the optimized graph.
`ExportMetaGraph`_ Export the meta graph into a file under specific folder.
......@@ -51,6 +53,8 @@ API Reference
:show-inheritance:
.. _SwitchWorkspace: #dragon.core.workspace.SwitchWorkspace
.. _ResetWorkspace: #dragon.core.workspace.ResetWorkspace
.. _ClearWorkspace: #dragon.core.workspace.ClearWorkspace
.. _CreateGraph: #dragon.core.workspace.CreateGraph
.. _HasTensor: #dragon.core.workspace.HasTensor
.. _GetTensorName: #dragon.core.workspace.GetTensorName
......
......@@ -164,7 +164,6 @@ List Brief
`Accuracy`_ Calculate the Top-K accuracy.
`StopGradient`_ Return the identity of input with truncated gradient flow.
`MovingAverage`_ Calculate the moving average.
`Proposal`_ Generate Regional Proposals, introduced by `[Ren et.al, 2015] <https://arxiv.org/abs/1506.01497>`_.
================= ======================================================================
Contrib
......
......@@ -36,7 +36,7 @@ find_packages('dragon')
find_modules()
setup(name = 'dragon',
version='0.2.1.7',
version='0.2.1.8',
description = 'Dragon: A Computation Graph Virtual Machine Based Deep Learning Framework',
url='https://github.com/neopenx/Dragon',
author='Ting Pan',
......
Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!