Skip to content
Toggle navigation
P
Projects
G
Groups
S
Snippets
Help
SeetaResearch
/
Dragon
This project
Loading...
Sign in
Toggle navigation
Go to a project
Project
Repository
Issues
0
Merge Requests
0
Pipelines
Wiki
Snippets
Settings
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Commit 93943fc8
authored
Aug 08, 2017
by
Ting PAN
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
fix bugs of
https://github.com/neopenx/Dragon/issues/6
1 parent
2f5edb5c
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
19 additions
and
12 deletions
Dragon/modules/python/dragon.cc
Dragon/modules/python/dragon.h
Dragon/python/dragon/core/gradient_maker.py
Dragon/modules/python/dragon.cc
View file @
93943fc
...
@@ -58,7 +58,7 @@ PyObject* RegisteredOperatorsCC(PyObject* self, PyObject* args) {
...
@@ -58,7 +58,7 @@ PyObject* RegisteredOperatorsCC(PyObject* self, PyObject* args) {
PyObject
*
list
=
PyList_New
(
all_keys
.
size
());
PyObject
*
list
=
PyList_New
(
all_keys
.
size
());
int
idx
=
0
;
int
idx
=
0
;
for
(
const
string
&
name
:
all_keys
)
for
(
const
string
&
name
:
all_keys
)
CHECK_EQ
(
PyList_SetItem
(
list
,
idx
++
,
StdStringToPy
Bytes
(
name
)),
0
);
CHECK_EQ
(
PyList_SetItem
(
list
,
idx
++
,
StdStringToPy
Unicode
(
name
)),
0
);
return
list
;
return
list
;
}
}
...
@@ -68,7 +68,7 @@ PyObject* NoGradientOperatorsCC(PyObject* self, PyObject* args) {
...
@@ -68,7 +68,7 @@ PyObject* NoGradientOperatorsCC(PyObject* self, PyObject* args) {
PyObject
*
list
=
PyList_New
(
all_keys
.
size
());
PyObject
*
list
=
PyList_New
(
all_keys
.
size
());
int
idx
=
0
;
int
idx
=
0
;
for
(
const
string
&
name
:
all_keys
)
for
(
const
string
&
name
:
all_keys
)
CHECK_EQ
(
PyList_SetItem
(
list
,
idx
++
,
StdStringToPy
Bytes
(
name
)),
0
);
CHECK_EQ
(
PyList_SetItem
(
list
,
idx
++
,
StdStringToPy
Unicode
(
name
)),
0
);
return
list
;
return
list
;
}
}
...
@@ -106,7 +106,7 @@ PyObject* CreateGradientDefsCC(PyObject* self, PyObject* args) {
...
@@ -106,7 +106,7 @@ PyObject* CreateGradientDefsCC(PyObject* self, PyObject* args) {
PyObject
*
g_input_py
=
PyList_New
(
grad
.
g_inputs
.
size
());
PyObject
*
g_input_py
=
PyList_New
(
grad
.
g_inputs
.
size
());
for
(
int
i
=
0
;
i
<
grad
.
g_inputs
.
size
();
i
++
)
for
(
int
i
=
0
;
i
<
grad
.
g_inputs
.
size
();
i
++
)
CHECK_EQ
(
PyList_SetItem
(
g_input_py
,
i
,
StdStringToPy
Bytes
(
grad
.
g_inputs
[
i
])),
0
);
CHECK_EQ
(
PyList_SetItem
(
g_input_py
,
i
,
StdStringToPy
Unicode
(
grad
.
g_inputs
[
i
])),
0
);
PyObject
*
defaults_py
=
PyList_New
(
grad
.
defaults
.
size
());
PyObject
*
defaults_py
=
PyList_New
(
grad
.
defaults
.
size
());
for
(
int
i
=
0
;
i
<
grad
.
defaults
.
size
();
i
++
)
for
(
int
i
=
0
;
i
<
grad
.
defaults
.
size
();
i
++
)
...
@@ -149,14 +149,14 @@ PyObject* SwitchWorkspaceCC(PyObject* self, PyObject *args) {
...
@@ -149,14 +149,14 @@ PyObject* SwitchWorkspaceCC(PyObject* self, PyObject *args) {
}
}
PyObject
*
CurrentWorkspaceCC
(
PyObject
*
self
,
PyObject
*
args
)
{
PyObject
*
CurrentWorkspaceCC
(
PyObject
*
self
,
PyObject
*
args
)
{
return
StdStringToPy
Bytes
(
g_current_workspace
);
return
StdStringToPy
Unicode
(
g_current_workspace
);
}
}
PyObject
*
WorkspacesCC
(
PyObject
*
self
,
PyObject
*
args
)
{
PyObject
*
WorkspacesCC
(
PyObject
*
self
,
PyObject
*
args
)
{
PyObject
*
list
=
PyList_New
(
g_workspaces
.
size
());
PyObject
*
list
=
PyList_New
(
g_workspaces
.
size
());
int
i
=
0
;
int
i
=
0
;
for
(
auto
const
&
it
:
g_workspaces
)
for
(
auto
const
&
it
:
g_workspaces
)
CHECK_EQ
(
PyList_SetItem
(
list
,
i
++
,
StdStringToPy
Bytes
(
it
.
first
)),
0
);
CHECK_EQ
(
PyList_SetItem
(
list
,
i
++
,
StdStringToPy
Unicode
(
it
.
first
)),
0
);
return
list
;
return
list
;
}
}
...
@@ -176,14 +176,14 @@ PyObject* ResetWorkspaceCC(PyObject* self, PyObject* args) {
...
@@ -176,14 +176,14 @@ PyObject* ResetWorkspaceCC(PyObject* self, PyObject* args) {
}
}
PyObject
*
RootFolderCC
(
PyObject
*
self
,
PyObject
*
args
)
{
PyObject
*
RootFolderCC
(
PyObject
*
self
,
PyObject
*
args
)
{
return
StdStringToPy
Bytes
(
g_workspace
->
GetRootFolder
());
return
StdStringToPy
Unicode
(
g_workspace
->
GetRootFolder
());
}
}
PyObject
*
TensorsCC
(
PyObject
*
self
,
PyObject
*
args
)
{
PyObject
*
TensorsCC
(
PyObject
*
self
,
PyObject
*
args
)
{
vector
<
string
>
tensor_strings
=
g_workspace
->
GetTensors
();
vector
<
string
>
tensor_strings
=
g_workspace
->
GetTensors
();
PyObject
*
list
=
PyList_New
(
tensor_strings
.
size
());
PyObject
*
list
=
PyList_New
(
tensor_strings
.
size
());
for
(
int
i
=
0
;
i
<
tensor_strings
.
size
();
i
++
)
for
(
int
i
=
0
;
i
<
tensor_strings
.
size
();
i
++
)
CHECK_EQ
(
PyList_SetItem
(
list
,
i
,
StdStringToPy
Bytes
(
tensor_strings
[
i
])),
0
);
CHECK_EQ
(
PyList_SetItem
(
list
,
i
,
StdStringToPy
Unicode
(
tensor_strings
[
i
])),
0
);
return
list
;
return
list
;
}
}
...
@@ -224,7 +224,7 @@ PyObject* GetTensorNameCC(PyObject* self, PyObject* args) {
...
@@ -224,7 +224,7 @@ PyObject* GetTensorNameCC(PyObject* self, PyObject* args) {
char
*
cname
;
char
*
cname
;
if
(
!
PyArg_ParseTuple
(
args
,
"s"
,
&
cname
))
return
nullptr
;
if
(
!
PyArg_ParseTuple
(
args
,
"s"
,
&
cname
))
return
nullptr
;
string
query
=
g_workspace
->
GetTensorName
(
string
(
cname
));
string
query
=
g_workspace
->
GetTensorName
(
string
(
cname
));
return
StdStringToPy
Bytes
(
query
);
return
StdStringToPy
Unicode
(
query
);
}
}
PyObject
*
CreateGraphCC
(
PyObject
*
self
,
PyObject
*
args
)
{
PyObject
*
CreateGraphCC
(
PyObject
*
self
,
PyObject
*
args
)
{
...
@@ -263,7 +263,7 @@ PyObject* GraphsCC(PyObject* self, PyObject* args) {
...
@@ -263,7 +263,7 @@ PyObject* GraphsCC(PyObject* self, PyObject* args) {
vector
<
string
>
graph_string
=
g_workspace
->
GetGraphs
();
vector
<
string
>
graph_string
=
g_workspace
->
GetGraphs
();
PyObject
*
list
=
PyList_New
(
graph_string
.
size
());
PyObject
*
list
=
PyList_New
(
graph_string
.
size
());
for
(
int
i
=
0
;
i
<
graph_string
.
size
();
i
++
)
for
(
int
i
=
0
;
i
<
graph_string
.
size
();
i
++
)
CHECK_EQ
(
PyList_SetItem
(
list
,
i
,
StdStringToPy
Bytes
(
graph_string
[
i
])),
0
);
CHECK_EQ
(
PyList_SetItem
(
list
,
i
,
StdStringToPy
Unicode
(
graph_string
[
i
])),
0
);
return
list
;
return
list
;
}
}
...
...
Dragon/modules/python/dragon.h
View file @
93943fc
...
@@ -21,7 +21,6 @@
...
@@ -21,7 +21,6 @@
#ifdef WITH_PYTHON3
#ifdef WITH_PYTHON3
#define PyString_AsString PyUnicode_AsUTF8
#define PyString_AsString PyUnicode_AsUTF8
#define PyBytes_FromStringAndSize PyUnicode_FromStringAndSize
#endif
#endif
using
namespace
dragon
;
using
namespace
dragon
;
...
@@ -33,6 +32,15 @@ inline std::string PyBytesToStdString(PyObject* pystring) {
...
@@ -33,6 +32,15 @@ inline std::string PyBytesToStdString(PyObject* pystring) {
inline
PyObject
*
StdStringToPyBytes
(
const
std
::
string
&
str
)
{
inline
PyObject
*
StdStringToPyBytes
(
const
std
::
string
&
str
)
{
return
PyBytes_FromStringAndSize
(
str
.
c_str
(),
str
.
size
());
return
PyBytes_FromStringAndSize
(
str
.
c_str
(),
str
.
size
());
}
}
inline
PyObject
*
StdStringToPyUnicode
(
const
std
::
string
&
str
)
{
#ifdef WITH_PYTHON3
return
PyUnicode_FromStringAndSize
(
str
.
c_str
(),
str
.
size
());
#else
return
PyBytes_FromStringAndSize
(
str
.
c_str
(),
str
.
size
());
#endif
}
template
<
typename
T
>
template
<
typename
T
>
inline
void
MakeStringInternal
(
std
::
stringstream
&
ss
,
const
T
&
t
)
{
ss
<<
t
;
}
inline
void
MakeStringInternal
(
std
::
stringstream
&
ss
,
const
T
&
t
)
{
ss
<<
t
;
}
...
@@ -114,7 +122,7 @@ class StringFetcher : public TensorFetcherBase {
...
@@ -114,7 +122,7 @@ class StringFetcher : public TensorFetcherBase {
public
:
public
:
PyObject
*
Fetch
(
const
Tensor
&
tensor
)
override
{
PyObject
*
Fetch
(
const
Tensor
&
tensor
)
override
{
CHECK_GT
(
tensor
.
count
(),
0
);
CHECK_GT
(
tensor
.
count
(),
0
);
return
StdStringToPyBytes
(
*
tensor
.
data
<
string
,
CPUContext
>
());
return
StdStringToPyBytes
(
*
tensor
.
data
<
string
,
CPUContext
>
());
}
}
};
};
...
...
Dragon/python/dragon/core/gradient_maker.py
View file @
93943fc
...
@@ -18,7 +18,6 @@ class GraphGradientMaker(object):
...
@@ -18,7 +18,6 @@ class GraphGradientMaker(object):
""" parse ops from string """
""" parse ops from string """
g_ops
,
g_inputs
,
defaults
=
CreateGradientDefsCC
(
op_def
.
SerializeToString
(),
g_output
)
g_ops
,
g_inputs
,
defaults
=
CreateGradientDefsCC
(
op_def
.
SerializeToString
(),
g_output
)
for
idx
,
g_op
in
enumerate
(
g_ops
):
for
idx
,
g_op
in
enumerate
(
g_ops
):
if
sys
.
version_info
>=
(
3
,
0
):
g_op
=
g_op
.
encode
()
new_def
=
pb
.
OperatorDef
()
new_def
=
pb
.
OperatorDef
()
new_def
.
ParseFromString
(
g_op
)
new_def
.
ParseFromString
(
g_op
)
_
,
new_def
.
name
=
GetOperatorName
()
_
,
new_def
.
name
=
GetOperatorName
()
...
...
Write
Preview
Markdown
is supported
Attach a file
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to post a comment