Commit 7e98dfd9 by Ting PAN
1 parent 5d518b6c
...@@ -132,15 +132,14 @@ bool SwitchWorkspaceInternal(const string& name, const bool create_if_missing) { ...@@ -132,15 +132,14 @@ bool SwitchWorkspaceInternal(const string& name, const bool create_if_missing) {
} }
PyObject* SwitchWorkspaceCC(PyObject* self, PyObject *args) { PyObject* SwitchWorkspaceCC(PyObject* self, PyObject *args) {
PyObject* name = nullptr; char* cname;
PyObject* create_if_missing = nullptr; PyObject* create_if_missing = nullptr;
if (!PyArg_ParseTuple(args, "S|O", &name, &create_if_missing)) { if (!PyArg_ParseTuple(args, "s|O", &cname, &create_if_missing)) {
PyErr_SetString(PyExc_ValueError, "SwitchWorkspaceCC takes a workspace name and a optional " PyErr_SetString(PyExc_ValueError, "SwitchWorkspaceCC takes a workspace name and a optional "
"bool value that specific whether to create the workspace if missing."); "bool value that specific whether to create the workspace if missing.");
return nullptr; return nullptr;
} }
bool success = SwitchWorkspaceInternal(PyBytesToStdString(name), bool success = SwitchWorkspaceInternal(string(cname), PyObject_IsTrue(create_if_missing));
PyObject_IsTrue(create_if_missing));
if (!success) { if (!success) {
PyErr_SetString(PyExc_RuntimeError, "workspace of the given name is not exist and " PyErr_SetString(PyExc_RuntimeError, "workspace of the given name is not exist and "
"is not allowed to create. (try alllow ?)"); "is not allowed to create. (try alllow ?)");
...@@ -162,15 +161,16 @@ PyObject* WorkspacesCC(PyObject* self, PyObject* args) { ...@@ -162,15 +161,16 @@ PyObject* WorkspacesCC(PyObject* self, PyObject* args) {
} }
PyObject* ResetWorkspaceCC(PyObject* self, PyObject* args) { PyObject* ResetWorkspaceCC(PyObject* self, PyObject* args) {
PyObject* root_folder = nullptr; char* cname;
if (!PyArg_ParseTuple(args, "|S", &root_folder)) { if (!PyArg_ParseTuple(args, "|s", &cname)) {
PyErr_SetString(PyExc_ValueError, "ResetWorkspaceCC takes in either no args or a string " PyErr_SetString(PyExc_ValueError, "ResetWorkspaceCC takes in either no args or a string "
"specifing the root folder of the workspace."); "specifing the name of the workspace.");
return nullptr; return nullptr;
} }
LOG(INFO) << "Reset the Workspace(" << g_current_workspace << ")"; LOG(INFO) << "Reset the Workspace(" << g_current_workspace << ")";
if (root_folder == nullptr) g_workspaces[g_current_workspace].reset(new Workspace()); string workspace_name = string(cname);
else g_workspaces[g_current_workspace].reset(new Workspace(PyBytesToStdString(root_folder))); if (workspace_name.empty()) g_workspaces[g_current_workspace].reset(new Workspace());
else g_workspaces[g_current_workspace].reset(new Workspace(workspace_name));
g_workspace = g_workspaces[g_current_workspace].get(); g_workspace = g_workspaces[g_current_workspace].get();
Py_RETURN_TRUE; Py_RETURN_TRUE;
} }
...@@ -343,12 +343,12 @@ PyObject* SnapshotCC(PyObject* self, PyObject* args) { ...@@ -343,12 +343,12 @@ PyObject* SnapshotCC(PyObject* self, PyObject* args) {
return nullptr; return nullptr;
} }
switch (format) { switch (format) {
case 0: // cPickle case 0: // cPickle
PyErr_SetString(PyExc_NotImplementedError, "format(0) depends on cPickle, should not be used in CC."); PyErr_SetString(PyExc_NotImplementedError, "format(0) depends on cPickle, should not be used in CC.");
break; break;
case 1: // caffemodel case 1: // caffemodel
for (int i = 0; i < PyList_Size(names); i++) for (int i = 0; i < PyList_Size(names); i++)
tensors.push_back(g_workspace->GetTensor(PyBytesToStdString(PyList_GetItem(names, i)))); tensors.push_back(g_workspace->GetTensor(PyString_AsString(PyList_GetItem(names, i))));
SavaCaffeModel(string(cname), tensors); SavaCaffeModel(string(cname), tensors);
break; break;
default: LOG(FATAL) << "Unknwon Restore Format, code: " << format; default: LOG(FATAL) << "Unknwon Restore Format, code: " << format;
......
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
#ifdef WITH_PYTHON3 #ifdef WITH_PYTHON3
#define PyString_AsString PyUnicode_AsUTF8 #define PyString_AsString PyUnicode_AsUTF8
#define PyBytes_FromStringAndSize PyUnicode_FromStringAndSize
#endif #endif
using namespace dragon; using namespace dragon;
......
...@@ -69,9 +69,4 @@ def SetLoggingLevel(level): ...@@ -69,9 +69,4 @@ def SetLoggingLevel(level):
'WARNING': logging.WARNING, 'WARNING': logging.WARNING,
'ERROR': logging.ERROR, 'ERROR': logging.ERROR,
'FATAL': logging.CRITICAL 'FATAL': logging.CRITICAL
}[level]) }[level])
\ No newline at end of file
...@@ -17,9 +17,8 @@ class GraphGradientMaker(object): ...@@ -17,9 +17,8 @@ class GraphGradientMaker(object):
def CreateGradientForOp(cls, op_def, g_output): def CreateGradientForOp(cls, op_def, g_output):
""" parse ops from string """ """ parse ops from string """
g_ops, g_inputs, defaults = CreateGradientDefsCC(op_def.SerializeToString(), g_output) g_ops, g_inputs, defaults = CreateGradientDefsCC(op_def.SerializeToString(), g_output)
if sys.version_info >= (3, 0):
g_inputs = [g_input.decode('ascii') for g_input in g_inputs]
for idx, g_op in enumerate(g_ops): for idx, g_op in enumerate(g_ops):
if sys.version_info >= (3, 0): g_op = g_op.encode()
new_def = pb.OperatorDef() new_def = pb.OperatorDef()
new_def.ParseFromString(g_op) new_def.ParseFromString(g_op)
_, new_def.name = GetOperatorName() _, new_def.name = GetOperatorName()
......
...@@ -36,7 +36,7 @@ def GetTensorName(): ...@@ -36,7 +36,7 @@ def GetTensorName():
class TensorScope(object): class TensorScope(object):
SEPARATOR = '/' SEPARATOR = '/'
def __init__(self, prefix): def __init__(self, prefix):
assert isinstance(prefix, basestring), \ assert isinstance(prefix, type('str')), \
"TensorScope takes in a string as its argument." "TensorScope takes in a string as its argument."
self.prefix = prefix + TensorScope.SEPARATOR self.prefix = prefix + TensorScope.SEPARATOR
...@@ -51,7 +51,7 @@ class TensorScope(object): ...@@ -51,7 +51,7 @@ class TensorScope(object):
class PhaseScope(object): class PhaseScope(object):
def __init__(self, phase): def __init__(self, phase):
assert isinstance(phase, basestring), \ assert isinstance(phase, type('str')), \
"PhaseScope takes in a string as its argument." "PhaseScope takes in a string as its argument."
self.phase = phase self.phase = phase
......
...@@ -49,7 +49,7 @@ class LMDB(object): ...@@ -49,7 +49,7 @@ class LMDB(object):
def put(self, key, value): def put(self, key, value):
self._buffer.append((wrapper_str(key), wrapper_str(value))) self._buffer.append((wrapper_str(key), value))
self._cur_put += 1 self._cur_put += 1
if (self._cur_put >= self._max_commit): self._try_put() if (self._cur_put >= self._max_commit): self._try_put()
......
...@@ -260,7 +260,7 @@ class BNLayer(Layer): ...@@ -260,7 +260,7 @@ class BNLayer(Layer):
'eps': param.eps} 'eps': param.eps}
mean = Tensor(LayerParameter.name + '@param0').Constant() mean = Tensor(LayerParameter.name + '@param0').Constant()
var = Tensor(LayerParameter.name + '@param1').Constant() var = Tensor(LayerParameter.name + '@param1').Constant()
scale = Tensor(LayerParameter.name + '@param2').Uniform(low=0.0, high=1.0) scale = Tensor(LayerParameter.name + '@param2').Constant(value=1.0)
bias = Tensor(LayerParameter.name + '@param3').Constant(value=0.0) bias = Tensor(LayerParameter.name + '@param3').Constant(value=0.0)
self.norm_blobs = [{'data': mean, 'diff': None}, self.norm_blobs = [{'data': mean, 'diff': None},
{'data': var, 'diff': None}] {'data': var, 'diff': None}]
......
Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!