Commit 93943fc8 by Ting PAN
1 parent 2f5edb5c
......@@ -58,7 +58,7 @@ PyObject* RegisteredOperatorsCC(PyObject* self, PyObject* args) {
PyObject* list = PyList_New(all_keys.size());
int idx = 0;
for (const string& name : all_keys)
CHECK_EQ(PyList_SetItem(list, idx++, StdStringToPyBytes(name)), 0);
CHECK_EQ(PyList_SetItem(list, idx++, StdStringToPyUnicode(name)), 0);
return list;
}
......@@ -68,7 +68,7 @@ PyObject* NoGradientOperatorsCC(PyObject* self, PyObject* args) {
PyObject* list = PyList_New(all_keys.size());
int idx = 0;
for (const string& name : all_keys)
CHECK_EQ(PyList_SetItem(list, idx++, StdStringToPyBytes(name)), 0);
CHECK_EQ(PyList_SetItem(list, idx++, StdStringToPyUnicode(name)), 0);
return list;
}
......@@ -106,7 +106,7 @@ PyObject* CreateGradientDefsCC(PyObject* self, PyObject* args) {
PyObject* g_input_py = PyList_New(grad.g_inputs.size());
for (int i = 0; i < grad.g_inputs.size(); i++)
CHECK_EQ(PyList_SetItem(g_input_py, i, StdStringToPyBytes(grad.g_inputs[i])), 0);
CHECK_EQ(PyList_SetItem(g_input_py, i, StdStringToPyUnicode(grad.g_inputs[i])), 0);
PyObject* defaults_py = PyList_New(grad.defaults.size());
for (int i = 0; i < grad.defaults.size(); i++)
......@@ -149,14 +149,14 @@ PyObject* SwitchWorkspaceCC(PyObject* self, PyObject *args) {
}
PyObject* CurrentWorkspaceCC(PyObject* self, PyObject* args) {
return StdStringToPyBytes(g_current_workspace);
return StdStringToPyUnicode(g_current_workspace);
}
PyObject* WorkspacesCC(PyObject* self, PyObject* args) {
PyObject* list = PyList_New(g_workspaces.size());
int i = 0;
for (auto const& it : g_workspaces)
CHECK_EQ(PyList_SetItem(list, i++, StdStringToPyBytes(it.first)), 0);
CHECK_EQ(PyList_SetItem(list, i++, StdStringToPyUnicode(it.first)), 0);
return list;
}
......@@ -176,14 +176,14 @@ PyObject* ResetWorkspaceCC(PyObject* self, PyObject* args) {
}
PyObject* RootFolderCC(PyObject* self, PyObject* args) {
return StdStringToPyBytes(g_workspace->GetRootFolder());
return StdStringToPyUnicode(g_workspace->GetRootFolder());
}
PyObject* TensorsCC(PyObject* self, PyObject* args) {
vector<string> tensor_strings = g_workspace->GetTensors();
PyObject* list = PyList_New(tensor_strings.size());
for (int i = 0; i < tensor_strings.size(); i++)
CHECK_EQ(PyList_SetItem(list, i, StdStringToPyBytes(tensor_strings[i])), 0);
CHECK_EQ(PyList_SetItem(list, i, StdStringToPyUnicode(tensor_strings[i])), 0);
return list;
}
......@@ -224,7 +224,7 @@ PyObject* GetTensorNameCC(PyObject* self, PyObject* args) {
char* cname;
if (!PyArg_ParseTuple(args, "s", &cname)) return nullptr;
string query = g_workspace->GetTensorName(string(cname));
return StdStringToPyBytes(query);
return StdStringToPyUnicode(query);
}
PyObject* CreateGraphCC(PyObject* self, PyObject* args) {
......@@ -263,7 +263,7 @@ PyObject* GraphsCC(PyObject* self, PyObject* args) {
vector<string> graph_string = g_workspace->GetGraphs();
PyObject* list = PyList_New(graph_string.size());
for (int i = 0; i < graph_string.size(); i++)
CHECK_EQ(PyList_SetItem(list, i, StdStringToPyBytes(graph_string[i])), 0);
CHECK_EQ(PyList_SetItem(list, i, StdStringToPyUnicode(graph_string[i])), 0);
return list;
}
......
......@@ -21,7 +21,6 @@
#ifdef WITH_PYTHON3
#define PyString_AsString PyUnicode_AsUTF8
#define PyBytes_FromStringAndSize PyUnicode_FromStringAndSize
#endif
using namespace dragon;
......@@ -33,6 +32,15 @@ inline std::string PyBytesToStdString(PyObject* pystring) {
inline PyObject* StdStringToPyBytes(const std::string& str) {
return PyBytes_FromStringAndSize(str.c_str(), str.size());
}
inline PyObject* StdStringToPyUnicode(const std::string& str) {
#ifdef WITH_PYTHON3
return PyUnicode_FromStringAndSize(str.c_str(), str.size());
#else
return PyBytes_FromStringAndSize(str.c_str(), str.size());
#endif
}
template <typename T>
inline void MakeStringInternal(std::stringstream& ss, const T& t) { ss << t; }
......@@ -114,7 +122,7 @@ class StringFetcher : public TensorFetcherBase {
public:
PyObject* Fetch(const Tensor& tensor) override {
CHECK_GT(tensor.count(), 0);
return StdStringToPyBytes(*tensor.data<string,CPUContext>());
return StdStringToPyBytes(*tensor.data<string, CPUContext>());
}
};
......
......@@ -18,7 +18,6 @@ class GraphGradientMaker(object):
""" parse ops from string """
g_ops, g_inputs, defaults = CreateGradientDefsCC(op_def.SerializeToString(), g_output)
for idx, g_op in enumerate(g_ops):
if sys.version_info >= (3, 0): g_op = g_op.encode()
new_def = pb.OperatorDef()
new_def.ParseFromString(g_op)
_, new_def.name = GetOperatorName()
......
Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!