Skip to content
Toggle navigation
P
Projects
G
Groups
S
Snippets
Help
BoxuanXu
/
FlaskDriverMXNet2SeetaNet
This project
Loading...
Sign in
Toggle navigation
Go to a project
Project
Repository
Issues
0
Merge Requests
0
Pipelines
Wiki
Snippets
Settings
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Commit a1bbf95f
authored
Sep 12, 2017
by
BoxuanXu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
fix convert net's bug
1 parent
c0737563
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
11 additions
and
7 deletions
src/converter.py
src/converter.py
View file @
a1bbf95
...
@@ -237,7 +237,9 @@ class Converter(object):
...
@@ -237,7 +237,9 @@ class Converter(object):
pool
.
pool
=
hd
.
Holiday_PoolingParameter
.
MAX
pool
.
pool
=
hd
.
Holiday_PoolingParameter
.
MAX
else
:
else
:
raise
NotImplementedError
raise
NotImplementedError
pad
=
parse_tuple
(
attr
[
'pad'
])
kernel
=
parse_tuple
(
attr
[
'kernel'
])
pool
.
kernel_height
=
kernel
[
0
]
pool
.
kernel_width
=
kernel
[
1
]
# TODO:现在强制VALID=True,后续有mxnet使用其他pool padding的方式时再修改
# TODO:现在强制VALID=True,后续有mxnet使用其他pool padding的方式时再修改
try
:
try
:
stride
=
parse_tuple
(
attr
[
'stride'
])
stride
=
parse_tuple
(
attr
[
'stride'
])
...
@@ -350,8 +352,10 @@ class Converter(object):
...
@@ -350,8 +352,10 @@ class Converter(object):
# 唯一的区别是,blob中多一个label blob
# 唯一的区别是,blob中多一个label blob
self
.
write_layer_names
(
fo
,
True
)
self
.
write_layer_names
(
fo
,
True
)
self
.
write_layer_names
(
fo
)
self
.
write_layer_names
(
fo
)
return_flag
=
None
try
:
try
:
flag
=
1
;
flag
=
1
;
for
i
in
range
(
self
.
__graph
.
idx_count
):
for
i
in
range
(
self
.
__graph
.
idx_count
):
if
i
==
1
:
continue
if
i
==
1
:
continue
layer
=
self
.
__graph
.
get_layer
(
self
.
__graph
.
get_name
(
i
))
layer
=
self
.
__graph
.
get_layer
(
self
.
__graph
.
get_name
(
i
))
...
@@ -377,7 +381,6 @@ class Converter(object):
...
@@ -377,7 +381,6 @@ class Converter(object):
length
=
np
.
array
(
len
(
ss
),
np
.
int64
)
length
=
np
.
array
(
len
(
ss
),
np
.
int64
)
fo
.
write
(
length
.
tobytes
())
fo
.
write
(
length
.
tobytes
())
fo
.
write
(
ss
)
fo
.
write
(
ss
)
#post return progress
#post return progress
if
(
i
>
(
flag
*
(
self
.
__graph
.
idx_count
/
8
))):
if
(
i
>
(
flag
*
(
self
.
__graph
.
idx_count
/
8
))):
GP
.
set_progress_var
(
flag
*
10
+
10
)
GP
.
set_progress_var
(
flag
*
10
+
10
)
...
@@ -389,12 +392,13 @@ class Converter(object):
...
@@ -389,12 +392,13 @@ class Converter(object):
length
=
np
.
array
(
len
(
ss
),
np
.
int64
)
length
=
np
.
array
(
len
(
ss
),
np
.
int64
)
fo
.
write
(
length
.
tobytes
())
fo
.
write
(
length
.
tobytes
())
fo
.
write
(
ss
)
fo
.
write
(
ss
)
return
True
return
_flag
=
True
except
Exception
,
e
:
except
Exception
,
e
:
GP
.
set_err_msg_var
(
repr
(
e
))
GP
.
set_err_msg_var
(
repr
(
e
))
return
None
return
_flag
=
None
finally
:
finally
:
fo
.
close
()
fo
.
close
()
return
return_flag
def
test
():
def
test
():
...
@@ -416,7 +420,6 @@ def load_checkpoint(params, network_struct):
...
@@ -416,7 +420,6 @@ def load_checkpoint(params, network_struct):
aux_params
[
name
]
=
v
aux_params
[
name
]
=
v
return
(
symbol
,
arg_params
,
aux_params
)
return
(
symbol
,
arg_params
,
aux_params
)
except
Exception
,
e
:
except
Exception
,
e
:
logging
.
info
(
'model load failed!!!'
)
return
(
None
,
None
,
None
)
return
(
None
,
None
,
None
)
#function created by xuboxuan@20170807
#function created by xuboxuan@20170807
...
@@ -426,13 +429,14 @@ def Run_Converter(model_param,model_json,seetanet_model):
...
@@ -426,13 +429,14 @@ def Run_Converter(model_param,model_json,seetanet_model):
#parser.add_argument('--model_param',type=str,default = None)
#parser.add_argument('--model_param',type=str,default = None)
#parser.add_argument('--model_json',type=str,default = None)
#parser.add_argument('--model_json',type=str,default = None)
#args = parser.parse_args()
#args = parser.parse_args()
#model_param = "wKgB
7Vm2iBaALVPdENbR4G1D6sc
.params"
#model_param = "wKgB
6Fmo2w2ASRqpBky2APcM8zs
.params"
#model_json = "wKgB6
lm2iBaAcq2PAAANvOLMPRI82
.json"
#model_json = "wKgB6
Vmo2w2AXrJbAAGjO2NrZLE75
.json"
#seetanet_model = "model_test"
#seetanet_model = "model_test"
try
:
try
:
sym
,
arg_params
,
aux_params
=
\
sym
,
arg_params
,
aux_params
=
\
load_checkpoint
(
model_param
,
model_json
)
load_checkpoint
(
model_param
,
model_json
)
if
sym
is
None
or
arg_params
is
None
or
aux_params
is
None
:
if
sym
is
None
or
arg_params
is
None
or
aux_params
is
None
:
logging
.
info
(
'load module failed!!!'
)
return
None
return
None
graph
=
load_graph
(
model_json
)
graph
=
load_graph
(
model_json
)
if
graph
is
None
:
if
graph
is
None
:
...
...
Write
Preview
Markdown
is supported
Attach a file
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to post a comment