Skip to content
Toggle navigation
P
Projects
G
Groups
S
Snippets
Help
BoxuanXu
/
FlaskDriverMXNet2SeetaNet
This project
Loading...
Sign in
Toggle navigation
Go to a project
Project
Repository
Issues
0
Merge Requests
0
Pipelines
Wiki
Snippets
Settings
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Commit d35d74a4
authored
Aug 07, 2017
by
BoxuanXu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
finish the convert stream
1 parent
073bfb99
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
141 additions
and
24 deletions
Drive_Converter.py
converter.py
Drive_Converter.py
View file @
d35d74a
#!/usr/bin/env python
#_*_ coding: UTF-8 _*_
import
pymysql.cursors
import
logging
import
subprocess
from
libpywrap
import
*
from
converter
import
Run_Converter
from
flask
import
Flask
,
request
__website__
=
"www.seetatech.com"
__author__
=
"seetatech"
__editor__
=
"xuboxuan"
__Date__
=
"20170807"
#Create an instance of the flask class
app
=
Flask
(
__name__
)
#initlization fdfs_client
stbf_stcnf
(
"/etc/fdfs/client.conf"
)
from
converter
import
Run_Converter
import
logging
from
flask
import
Flask
,
request
app
=
Flask
(
__name__
)
download_path
=
"."
#initlization the logging
logging
.
basicConfig
(
...
...
@@ -20,25 +29,130 @@ logging.basicConfig(
format
=
"[
%(asctime)
s]
%(name)
s:
%(levelname)
s:
%(message)
s"
)
db_atlas
=
pymysql
.
connect
(
"192.168.1.15"
,
"defaultUser"
,
"magician"
,
"seetaAtlas"
)
curl_atlas
=
db_atlas
.
cursor
()
curl_atlas_exe
=
db_atlas
.
cursor
()
def
get_path_from_db
(
modelid
,
seetanet_model
):
curl_atlas
.
execute
(
"select path from seetaAtlas.model where id='
%
s'"
%
modelid
)
result_atlas
=
curl_atlas
.
fetchall
()
if
len
(
result_atlas
)
!=
1
:
logging
(
'get wrong mxnet model path'
);
#return
else
:
path_name
=
str
(
result_atlas
[
0
])
path_list
=
path_name
.
split
(
"::"
)
if
len
(
path_list
)
!=
2
:
logging
.
info
(
'get wrong parameters!'
)
else
:
params_name_path
=
path_list
[
0
]
params_name_path
=
params_name_path
[
2
:
len
(
params_name_path
)]
#substr file name from path
params_name_index
=
params_name_path
.
rfind
(
'/'
,
0
,
len
(
path_list
[
0
]))
params_name
=
params_name_path
[
params_name_index
+
1
:
len
(
path_list
[
0
])]
graph_name_path
=
path_list
[
1
]
graph_name_path
=
graph_name_path
[
0
:
len
(
graph_name_path
)
-
3
]
#substr file name from path
graph_name_index
=
graph_name_path
.
rfind
(
'/'
,
0
,
len
(
path_list
[
1
]))
graph_name
=
graph_name_path
[
params_name_index
+
1
:
len
(
path_list
[
1
])]
logging
.
info
(
"The params's path is
%
s"
%
params_name_path
)
logging
.
info
(
"The model graph's path is
%
s"
%
graph_name_path
)
#download file from fastdfs
if
(
stbf_down
(
params_name_path
,
download_path
)):
logging
.
info
(
"download params file:
%
s success"
%
params_name
)
else
:
return
None
,
None
#download file from fastdfs
if
(
stbf_down
(
graph_name_path
,
download_path
)):
logging
.
info
(
"download params file:
%
s success"
%
params_name
)
else
:
return
None
,
None
#Driver converter model by params and model graph
Run_Converter
(
params_name
,
graph_name
,
seetanet_model
)
#remove params file and graph file
try
:
cmd_str
=
"rm -rf "
+
params_name
+
" "
+
graph_name
print
(
cmd_str
)
subprocess
.
check_call
(
cmd_str
,
shell
=
True
)
except
subprocess
.
CalledProcessError
as
err
:
logging
.
info
(
"shell command error!"
)
return
None
,
None
return
params_name
,
graph_name
def
upload_filetoFastDFS
(
params_name
,
graph_name
,
seetanet_model
):
#upload seetanet file to fastdfs
seetanet_model_id
=
stbf_up
(
seetanet_model
)
if
(
seetanet_model_id
):
logging
.
info
(
"upload seetanet file success,seetanet model_id is
%
s"
%
seetanet_model_id
)
else
:
logging
.
info
(
"upload seetanet file failed!"
)
return
None
return
seetanet_model_id
def
update_pkg_info_db
(
sdk_fid
,
seetanet_model_id
):
#seetanet_model_id = "group2/M00/00/04/wKgB61l-xXKARFT0AAANtlf6yRo13.json"
#update pkg info table
logging
.
info
(
"begin insert model_id ,sdk_fid,mxnet_modelid into pkg_info table"
)
sql_cmd
=
"insert into seetaAtlas.pkg_info (modelid,sdk_fid,stmodel_fid) values("
+
str
(
mxnet_modelid
)
+
",
\'
"
+
sdk_fid
+
"
\'
,
\'
"
+
seetanet_model_id
+
"
\'
)"
curl_atlas_exe
.
execute
(
sql_cmd
)
db_atlas
.
commit
()
#get parameters and driver the transition function
@app.route
(
'/'
,
methods
=
[
'POST'
])
def
Dirver_Convert
(
model_params
,
model_json
):
logging
.
info
(
"get begin conver"
);
#get two parameters
model_params
=
request
.
form
[
'model_params'
]
model_json
=
request
.
form
[
'model_json'
]
print
(
model_params
)
print
(
model_json
)
if
model_params
==
''
and
model_json
==
''
:
logging
.
info
(
'get null params or json'
)
return
def
Dirver_Convert
():
#get parameter modelid from post stream
modelid
=
request
.
form
[
'modelid'
]
logging
.
info
(
"We get modelid :
%
s from post stream,Start conversion:"
%
modelid
)
try
:
seetanet_model
=
"model_"
+
str
(
modelid
)
+
".data"
params_name
,
graph_name
=
get_path_from_db
(
modelid
,
seetanet_model
)
if
params_name
is
None
or
graph_name
is
None
:
logging
.
info
(
"get wrong params file"
)
else
:
seetanet_model_id
=
upload_filetoFastDFS
(
params_name
,
graph_name
,
seetanet_model
)
if
seetanet_model_id
is
None
:
logging
.
info
(
"upload filed"
)
else
:
logging
.
info
(
"The model_params:
%
s"
%
model_params
)
logging
.
info
(
"The model_json:
%
s"
%
model_json
)
Run_Converter
(
model_params
,
model_json
)
sdk_fid
=
"group2/M00/00/04/wKgB61l-xXKARFT0AAANtlf6yRo13.json"
update_pkg_info_db
(
sdk_fid
,
seetanet_model_id
)
logging
.
info
(
"successfully,Finish!"
)
finally
:
db_atlas
.
close
()
return
"finish"
if
__name__
==
'__main__'
:
app
.
run
(
'0.0.0.0'
)
app
.
run
(
host
=
'0.0.0.0'
)
converter.py
View file @
d35d74a
...
...
@@ -338,8 +338,8 @@ class Converter(object):
fo
.
write
(
length
.
tobytes
())
fo
.
write
(
str
.
encode
(
param_name
))
def
convert
(
self
):
fo
=
open
(
'model.data'
,
'wb'
)
def
convert
(
self
,
seetanet_model
):
fo
=
open
(
seetanet_model
,
'wb'
)
# 先写入blob的名称,再写入layer的名称,由于这里将blob和layer近似等同起来,所有写入的都是layer的name。
# 唯一的区别是,blob中多一个label blob
self
.
write_layer_names
(
fo
,
True
)
...
...
@@ -400,7 +400,7 @@ def load_checkpoint(params, network_struct):
#function created by xuboxuan@20170807
#if __name__ == '__main__':
def
Run_Converter
(
model_param
,
model_json
):
def
Run_Converter
(
model_param
,
model_json
,
seetanet_model
):
#parser = argparse.ArgumentParser(description='manual to this converter script')
#parser.add_argument('--model_param',type=str,default = None)
#parser.add_argument('--model_json',type=str,default = None)
...
...
@@ -411,7 +411,10 @@ def Run_Converter(model_param,model_json):
graph
=
load_graph
(
model_json
)
converter
=
Converter
(
graph
,
arg_params
,
aux_params
)
logging
.
info
(
'start to convert model parameters'
)
converter
.
convert
()
print
(
model_param
)
print
(
model_json
)
print
(
seetanet_model
)
#converter.convert(seetanet_model)
logging
.
info
(
'convert success!!!'
)
Write
Preview
Markdown
is supported
Attach a file
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to post a comment