Commit 7e98dfd9 by Ting PAN
1 parent 5d518b6c
......@@ -132,15 +132,14 @@ bool SwitchWorkspaceInternal(const string& name, const bool create_if_missing) {
}
PyObject* SwitchWorkspaceCC(PyObject* self, PyObject *args) {
PyObject* name = nullptr;
char* cname;
PyObject* create_if_missing = nullptr;
if (!PyArg_ParseTuple(args, "S|O", &name, &create_if_missing)) {
if (!PyArg_ParseTuple(args, "s|O", &cname, &create_if_missing)) {
PyErr_SetString(PyExc_ValueError, "SwitchWorkspaceCC takes a workspace name and a optional "
"bool value that specific whether to create the workspace if missing.");
return nullptr;
}
bool success = SwitchWorkspaceInternal(PyBytesToStdString(name),
PyObject_IsTrue(create_if_missing));
bool success = SwitchWorkspaceInternal(string(cname), PyObject_IsTrue(create_if_missing));
if (!success) {
PyErr_SetString(PyExc_RuntimeError, "workspace of the given name is not exist and "
"is not allowed to create. (try alllow ?)");
......@@ -162,15 +161,16 @@ PyObject* WorkspacesCC(PyObject* self, PyObject* args) {
}
PyObject* ResetWorkspaceCC(PyObject* self, PyObject* args) {
PyObject* root_folder = nullptr;
if (!PyArg_ParseTuple(args, "|S", &root_folder)) {
char* cname;
if (!PyArg_ParseTuple(args, "|s", &cname)) {
PyErr_SetString(PyExc_ValueError, "ResetWorkspaceCC takes in either no args or a string "
"specifing the root folder of the workspace.");
"specifing the name of the workspace.");
return nullptr;
}
LOG(INFO) << "Reset the Workspace(" << g_current_workspace << ")";
if (root_folder == nullptr) g_workspaces[g_current_workspace].reset(new Workspace());
else g_workspaces[g_current_workspace].reset(new Workspace(PyBytesToStdString(root_folder)));
string workspace_name = string(cname);
if (workspace_name.empty()) g_workspaces[g_current_workspace].reset(new Workspace());
else g_workspaces[g_current_workspace].reset(new Workspace(workspace_name));
g_workspace = g_workspaces[g_current_workspace].get();
Py_RETURN_TRUE;
}
......@@ -343,12 +343,12 @@ PyObject* SnapshotCC(PyObject* self, PyObject* args) {
return nullptr;
}
switch (format) {
case 0: // cPickle
case 0: // cPickle
PyErr_SetString(PyExc_NotImplementedError, "format(0) depends on cPickle, should not be used in CC.");
break;
case 1: // caffemodel
case 1: // caffemodel
for (int i = 0; i < PyList_Size(names); i++)
tensors.push_back(g_workspace->GetTensor(PyBytesToStdString(PyList_GetItem(names, i))));
tensors.push_back(g_workspace->GetTensor(PyString_AsString(PyList_GetItem(names, i))));
SavaCaffeModel(string(cname), tensors);
break;
default: LOG(FATAL) << "Unknwon Restore Format, code: " << format;
......
......@@ -21,6 +21,7 @@
#ifdef WITH_PYTHON3
#define PyString_AsString PyUnicode_AsUTF8
#define PyBytes_FromStringAndSize PyUnicode_FromStringAndSize
#endif
using namespace dragon;
......
......@@ -69,9 +69,4 @@ def SetLoggingLevel(level):
'WARNING': logging.WARNING,
'ERROR': logging.ERROR,
'FATAL': logging.CRITICAL
}[level])
}[level])
\ No newline at end of file
......@@ -17,9 +17,8 @@ class GraphGradientMaker(object):
def CreateGradientForOp(cls, op_def, g_output):
""" parse ops from string """
g_ops, g_inputs, defaults = CreateGradientDefsCC(op_def.SerializeToString(), g_output)
if sys.version_info >= (3, 0):
g_inputs = [g_input.decode('ascii') for g_input in g_inputs]
for idx, g_op in enumerate(g_ops):
if sys.version_info >= (3, 0): g_op = g_op.encode()
new_def = pb.OperatorDef()
new_def.ParseFromString(g_op)
_, new_def.name = GetOperatorName()
......
......@@ -36,7 +36,7 @@ def GetTensorName():
class TensorScope(object):
SEPARATOR = '/'
def __init__(self, prefix):
assert isinstance(prefix, basestring), \
assert isinstance(prefix, type('str')), \
"TensorScope takes in a string as its argument."
self.prefix = prefix + TensorScope.SEPARATOR
......@@ -51,7 +51,7 @@ class TensorScope(object):
class PhaseScope(object):
def __init__(self, phase):
assert isinstance(phase, basestring), \
assert isinstance(phase, type('str')), \
"PhaseScope takes in a string as its argument."
self.phase = phase
......
......@@ -49,7 +49,7 @@ class LMDB(object):
def put(self, key, value):
self._buffer.append((wrapper_str(key), wrapper_str(value)))
self._buffer.append((wrapper_str(key), value))
self._cur_put += 1
if (self._cur_put >= self._max_commit): self._try_put()
......
......@@ -260,7 +260,7 @@ class BNLayer(Layer):
'eps': param.eps}
mean = Tensor(LayerParameter.name + '@param0').Constant()
var = Tensor(LayerParameter.name + '@param1').Constant()
scale = Tensor(LayerParameter.name + '@param2').Uniform(low=0.0, high=1.0)
scale = Tensor(LayerParameter.name + '@param2').Constant(value=1.0)
bias = Tensor(LayerParameter.name + '@param3').Constant(value=0.0)
self.norm_blobs = [{'data': mean, 'diff': None},
{'data': var, 'diff': None}]
......
Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!