Commit 93943fc8 by Ting PAN
1 parent 2f5edb5c
...@@ -58,7 +58,7 @@ PyObject* RegisteredOperatorsCC(PyObject* self, PyObject* args) { ...@@ -58,7 +58,7 @@ PyObject* RegisteredOperatorsCC(PyObject* self, PyObject* args) {
PyObject* list = PyList_New(all_keys.size()); PyObject* list = PyList_New(all_keys.size());
int idx = 0; int idx = 0;
for (const string& name : all_keys) for (const string& name : all_keys)
CHECK_EQ(PyList_SetItem(list, idx++, StdStringToPyBytes(name)), 0); CHECK_EQ(PyList_SetItem(list, idx++, StdStringToPyUnicode(name)), 0);
return list; return list;
} }
...@@ -68,7 +68,7 @@ PyObject* NoGradientOperatorsCC(PyObject* self, PyObject* args) { ...@@ -68,7 +68,7 @@ PyObject* NoGradientOperatorsCC(PyObject* self, PyObject* args) {
PyObject* list = PyList_New(all_keys.size()); PyObject* list = PyList_New(all_keys.size());
int idx = 0; int idx = 0;
for (const string& name : all_keys) for (const string& name : all_keys)
CHECK_EQ(PyList_SetItem(list, idx++, StdStringToPyBytes(name)), 0); CHECK_EQ(PyList_SetItem(list, idx++, StdStringToPyUnicode(name)), 0);
return list; return list;
} }
...@@ -106,7 +106,7 @@ PyObject* CreateGradientDefsCC(PyObject* self, PyObject* args) { ...@@ -106,7 +106,7 @@ PyObject* CreateGradientDefsCC(PyObject* self, PyObject* args) {
PyObject* g_input_py = PyList_New(grad.g_inputs.size()); PyObject* g_input_py = PyList_New(grad.g_inputs.size());
for (int i = 0; i < grad.g_inputs.size(); i++) for (int i = 0; i < grad.g_inputs.size(); i++)
CHECK_EQ(PyList_SetItem(g_input_py, i, StdStringToPyBytes(grad.g_inputs[i])), 0); CHECK_EQ(PyList_SetItem(g_input_py, i, StdStringToPyUnicode(grad.g_inputs[i])), 0);
PyObject* defaults_py = PyList_New(grad.defaults.size()); PyObject* defaults_py = PyList_New(grad.defaults.size());
for (int i = 0; i < grad.defaults.size(); i++) for (int i = 0; i < grad.defaults.size(); i++)
...@@ -149,14 +149,14 @@ PyObject* SwitchWorkspaceCC(PyObject* self, PyObject *args) { ...@@ -149,14 +149,14 @@ PyObject* SwitchWorkspaceCC(PyObject* self, PyObject *args) {
} }
PyObject* CurrentWorkspaceCC(PyObject* self, PyObject* args) { PyObject* CurrentWorkspaceCC(PyObject* self, PyObject* args) {
return StdStringToPyBytes(g_current_workspace); return StdStringToPyUnicode(g_current_workspace);
} }
PyObject* WorkspacesCC(PyObject* self, PyObject* args) { PyObject* WorkspacesCC(PyObject* self, PyObject* args) {
PyObject* list = PyList_New(g_workspaces.size()); PyObject* list = PyList_New(g_workspaces.size());
int i = 0; int i = 0;
for (auto const& it : g_workspaces) for (auto const& it : g_workspaces)
CHECK_EQ(PyList_SetItem(list, i++, StdStringToPyBytes(it.first)), 0); CHECK_EQ(PyList_SetItem(list, i++, StdStringToPyUnicode(it.first)), 0);
return list; return list;
} }
...@@ -176,14 +176,14 @@ PyObject* ResetWorkspaceCC(PyObject* self, PyObject* args) { ...@@ -176,14 +176,14 @@ PyObject* ResetWorkspaceCC(PyObject* self, PyObject* args) {
} }
PyObject* RootFolderCC(PyObject* self, PyObject* args) { PyObject* RootFolderCC(PyObject* self, PyObject* args) {
return StdStringToPyBytes(g_workspace->GetRootFolder()); return StdStringToPyUnicode(g_workspace->GetRootFolder());
} }
PyObject* TensorsCC(PyObject* self, PyObject* args) { PyObject* TensorsCC(PyObject* self, PyObject* args) {
vector<string> tensor_strings = g_workspace->GetTensors(); vector<string> tensor_strings = g_workspace->GetTensors();
PyObject* list = PyList_New(tensor_strings.size()); PyObject* list = PyList_New(tensor_strings.size());
for (int i = 0; i < tensor_strings.size(); i++) for (int i = 0; i < tensor_strings.size(); i++)
CHECK_EQ(PyList_SetItem(list, i, StdStringToPyBytes(tensor_strings[i])), 0); CHECK_EQ(PyList_SetItem(list, i, StdStringToPyUnicode(tensor_strings[i])), 0);
return list; return list;
} }
...@@ -224,7 +224,7 @@ PyObject* GetTensorNameCC(PyObject* self, PyObject* args) { ...@@ -224,7 +224,7 @@ PyObject* GetTensorNameCC(PyObject* self, PyObject* args) {
char* cname; char* cname;
if (!PyArg_ParseTuple(args, "s", &cname)) return nullptr; if (!PyArg_ParseTuple(args, "s", &cname)) return nullptr;
string query = g_workspace->GetTensorName(string(cname)); string query = g_workspace->GetTensorName(string(cname));
return StdStringToPyBytes(query); return StdStringToPyUnicode(query);
} }
PyObject* CreateGraphCC(PyObject* self, PyObject* args) { PyObject* CreateGraphCC(PyObject* self, PyObject* args) {
...@@ -263,7 +263,7 @@ PyObject* GraphsCC(PyObject* self, PyObject* args) { ...@@ -263,7 +263,7 @@ PyObject* GraphsCC(PyObject* self, PyObject* args) {
vector<string> graph_string = g_workspace->GetGraphs(); vector<string> graph_string = g_workspace->GetGraphs();
PyObject* list = PyList_New(graph_string.size()); PyObject* list = PyList_New(graph_string.size());
for (int i = 0; i < graph_string.size(); i++) for (int i = 0; i < graph_string.size(); i++)
CHECK_EQ(PyList_SetItem(list, i, StdStringToPyBytes(graph_string[i])), 0); CHECK_EQ(PyList_SetItem(list, i, StdStringToPyUnicode(graph_string[i])), 0);
return list; return list;
} }
......
...@@ -21,7 +21,6 @@ ...@@ -21,7 +21,6 @@
#ifdef WITH_PYTHON3 #ifdef WITH_PYTHON3
#define PyString_AsString PyUnicode_AsUTF8 #define PyString_AsString PyUnicode_AsUTF8
#define PyBytes_FromStringAndSize PyUnicode_FromStringAndSize
#endif #endif
using namespace dragon; using namespace dragon;
...@@ -33,6 +32,15 @@ inline std::string PyBytesToStdString(PyObject* pystring) { ...@@ -33,6 +32,15 @@ inline std::string PyBytesToStdString(PyObject* pystring) {
inline PyObject* StdStringToPyBytes(const std::string& str) { inline PyObject* StdStringToPyBytes(const std::string& str) {
return PyBytes_FromStringAndSize(str.c_str(), str.size()); return PyBytes_FromStringAndSize(str.c_str(), str.size());
} }
inline PyObject* StdStringToPyUnicode(const std::string& str) {
#ifdef WITH_PYTHON3
return PyUnicode_FromStringAndSize(str.c_str(), str.size());
#else
return PyBytes_FromStringAndSize(str.c_str(), str.size());
#endif
}
template <typename T> template <typename T>
inline void MakeStringInternal(std::stringstream& ss, const T& t) { ss << t; } inline void MakeStringInternal(std::stringstream& ss, const T& t) { ss << t; }
...@@ -114,7 +122,7 @@ class StringFetcher : public TensorFetcherBase { ...@@ -114,7 +122,7 @@ class StringFetcher : public TensorFetcherBase {
public: public:
PyObject* Fetch(const Tensor& tensor) override { PyObject* Fetch(const Tensor& tensor) override {
CHECK_GT(tensor.count(), 0); CHECK_GT(tensor.count(), 0);
return StdStringToPyBytes(*tensor.data<string,CPUContext>()); return StdStringToPyBytes(*tensor.data<string, CPUContext>());
} }
}; };
......
...@@ -18,7 +18,6 @@ class GraphGradientMaker(object): ...@@ -18,7 +18,6 @@ class GraphGradientMaker(object):
""" parse ops from string """ """ parse ops from string """
g_ops, g_inputs, defaults = CreateGradientDefsCC(op_def.SerializeToString(), g_output) g_ops, g_inputs, defaults = CreateGradientDefsCC(op_def.SerializeToString(), g_output)
for idx, g_op in enumerate(g_ops): for idx, g_op in enumerate(g_ops):
if sys.version_info >= (3, 0): g_op = g_op.encode()
new_def = pb.OperatorDef() new_def = pb.OperatorDef()
new_def.ParseFromString(g_op) new_def.ParseFromString(g_op)
_, new_def.name = GetOperatorName() _, new_def.name = GetOperatorName()
......
Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!