Skip to content
Toggle navigation
P
Projects
G
Groups
S
Snippets
Help
BoxuanXu
/
FlaskDriverMXNet2SeetaNet
This project
Loading...
Sign in
Toggle navigation
Go to a project
Project
Repository
Issues
0
Merge Requests
0
Pipelines
Wiki
Snippets
Settings
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Commit 9f92a31d
authored
Sep 04, 2017
by
BoxuanXu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
add multithread code
1 parent
23709c5b
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
38 additions
and
28 deletions
src/Drive_Converter.py
src/GProgress.py
src/converter.py
src/Drive_Converter.py
View file @
9f92a31
...
@@ -5,6 +5,7 @@ import pymysql.cursors
...
@@ -5,6 +5,7 @@ import pymysql.cursors
import
logging
import
logging
import
subprocess
import
subprocess
from
libpywrap
import
*
from
libpywrap
import
*
import
requests
import
requests
...
@@ -18,7 +19,7 @@ from jenkinsapi.jenkins import Jenkins
...
@@ -18,7 +19,7 @@ from jenkinsapi.jenkins import Jenkins
from
mysql_glock
import
Mysql_lock
from
mysql_glock
import
Mysql_lock
import
GProgress
as
GP
from
GProgress
import
GProgress_Var
import
Queue
import
Queue
__website__
=
"www.seetatech.com"
__website__
=
"www.seetatech.com"
...
@@ -62,7 +63,7 @@ curl_atlas_exe = db_atlas.cursor()
...
@@ -62,7 +63,7 @@ curl_atlas_exe = db_atlas.cursor()
db_l
=
Mysql_lock
()
db_l
=
Mysql_lock
()
lock_name
=
'db_queue'
lock_name
=
'db_queue'
def
get_path_from_db
(
modelid
,
seetanet_model
):
def
get_path_from_db
(
modelid
,
seetanet_model
,
GP
):
#curl_atlas.execute("select path from seetaAtlas.model where id='%s'" % modelid)
#curl_atlas.execute("select path from seetaAtlas.model where id='%s'" % modelid)
ret
=
db_l
.
lock
(
lock_name
,
10
)
ret
=
db_l
.
lock
(
lock_name
,
10
)
if
ret
!=
True
:
if
ret
!=
True
:
...
@@ -130,7 +131,7 @@ def get_path_from_db(modelid,seetanet_model):
...
@@ -130,7 +131,7 @@ def get_path_from_db(modelid,seetanet_model):
GP
.
Post_return
()
GP
.
Post_return
()
#Driver converter model by params and model graph
#Driver converter model by params and model graph
result
=
Run_Converter
(
params_name
,
graph_name
,
seetanet_model
)
result
=
Run_Converter
(
params_name
,
graph_name
,
seetanet_model
,
GP
)
if
result
is
"true"
:
if
result
is
"true"
:
return
None
,
None
return
None
,
None
...
@@ -149,6 +150,8 @@ def upload_filetoFastDFS(params_name, graph_name,seetanet_model):
...
@@ -149,6 +150,8 @@ def upload_filetoFastDFS(params_name, graph_name,seetanet_model):
return
stmodel_fid
return
stmodel_fid
def
get_info_from_queue
(
arg
):
def
get_info_from_queue
(
arg
):
GP
=
GProgress_Var
()
while
1
:
while
1
:
if
not
Info_Queue
.
empty
():
if
not
Info_Queue
.
empty
():
Info
=
Info_Queue
.
get
()
Info
=
Info_Queue
.
get
()
...
@@ -171,9 +174,9 @@ def get_info_from_queue(arg):
...
@@ -171,9 +174,9 @@ def get_info_from_queue(arg):
#GP.set_progress_var(10)
#GP.set_progress_var(10)
#GP.Post_return()
#GP.Post_return()
seetanet_model
=
"model_"
+
str
(
modelid
)
+
".data"
seetanet_model
=
"model_"
+
str
(
modelid
)
+
"
_"
+
str
(
threading
.
currentThread
()
.
ident
)
+
"
.data"
params_name
,
graph_name
=
get_path_from_db
(
modelid
,
seetanet_model
)
params_name
,
graph_name
=
get_path_from_db
(
modelid
,
seetanet_model
,
GP
)
if
params_name
is
None
or
graph_name
is
None
:
if
params_name
is
None
or
graph_name
is
None
:
...
@@ -240,7 +243,7 @@ def get_info_from_queue(arg):
...
@@ -240,7 +243,7 @@ def get_info_from_queue(arg):
GP
.
Post_return
()
GP
.
Post_return
()
t
=
threading
.
Thread
(
target
=
get_info_from_queue
,
args
=
(
1
,))
#
t = threading.Thread(target=get_info_from_queue,args=(1,))
#get parameters and driver the transition function
#get parameters and driver the transition function
@app.route
(
'/convert'
,
methods
=
[
'POST'
])
@app.route
(
'/convert'
,
methods
=
[
'POST'
])
...
@@ -261,10 +264,12 @@ def Dirver_Convert():
...
@@ -261,10 +264,12 @@ def Dirver_Convert():
Post_Info
=
{
"modelid"
:
modelid
,
"output_layer"
:
output_layer
,
"pool_id"
:
pool_id
}
Post_Info
=
{
"modelid"
:
modelid
,
"output_layer"
:
output_layer
,
"pool_id"
:
pool_id
}
Info_Queue
.
put
(
Post_Info
)
Info_Queue
.
put
(
Post_Info
)
'''
if t.is_alive():
if t.is_alive():
t.stop()
t.stop()
t = threading.Thread(target=get_info_from_queue,args=(1,))
t = threading.Thread(target=get_info_from_queue,args=(1,))
t.start()
t.start()
'''
finally
:
finally
:
return
return_flag
return
return_flag
...
@@ -272,7 +277,10 @@ def Dirver_Convert():
...
@@ -272,7 +277,10 @@ def Dirver_Convert():
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
try
:
try
:
t
.
start
();
for
i
in
range
(
0
,
5
):
logging
.
info
(
"Begin Thread
%
d"
%
i
)
t
=
threading
.
Thread
(
target
=
get_info_from_queue
,
args
=
(
i
,))
t
.
start
();
app
.
run
(
host
=
'0.0.0.0'
)
app
.
run
(
host
=
'0.0.0.0'
)
finally
:
finally
:
curl_atlas_exe
.
close
()
curl_atlas_exe
.
close
()
...
...
src/GProgress.py
View file @
9f92a31
...
@@ -9,28 +9,30 @@ __Date__ = "20170812"
...
@@ -9,28 +9,30 @@ __Date__ = "20170812"
import
requests
import
requests
#post_url = "http://192.168
l.1.170:1234/API
"
#post_url = "http://192.168
.1.244:5000/result
"
post_url
=
"http://192.168.1.
244:5000/result
"
post_url
=
"http://192.168.1.
127:1234/API
"
class
GProgress_Var
:
class
GProgress_Var
:
posttype
=
None
;
def
__init__
(
self
):
progress
=
None
;
self
.
posttype
=
""
;
pool_id
=
None
;
self
.
progress
=
""
;
err_msg
=
None
;
self
.
pool_id
=
""
;
self
.
err_msg
=
""
;
def
set_post_type_var
(
posttype
):
def
set_post_type_var
(
self
,
posttype
):
GProgress_Var
.
posttype
=
posttype
self
.
posttype
=
posttype
def
set_progress_var
(
progress
):
def
set_progress_var
(
self
,
progress
):
GProgress_Var
.
progress
=
progress
self
.
progress
=
progress
def
set_pool_id_var
(
pool_id
):
def
set_pool_id_var
(
self
,
pool_id
):
GProgress_Var
.
pool_id
=
pool_id
self
.
pool_id
=
pool_id
def
set_err_msg_var
(
err_msg
):
def
set_err_msg_var
(
self
,
err_msg
):
GProgress_Var
.
err_msg
=
err_msg
self
.
err_msg
=
err_msg
def
Post_return
():
def
Post_return
(
self
):
post_return
=
{
"posttype"
:
GProgress_Var
.
posttype
,
"progress"
:
GProgress_Var
.
progress
,
"pool_id"
:
GProgress_Var
.
pool_id
,
"err_msg"
:
GProgress_Var
.
err_msg
}
post_return
=
{
"posttype"
:
self
.
posttype
,
"progress"
:
self
.
progress
,
"pool_id"
:
self
.
pool_id
,
"err_msg"
:
self
.
err_msg
}
requests
.
post
(
post_url
,
data
=
post_return
)
print
(
post_return
)
requests
.
post
(
post_url
,
data
=
post_return
)
src/converter.py
View file @
9f92a31
...
@@ -5,7 +5,7 @@ from mxnet_graph import load_graph
...
@@ -5,7 +5,7 @@ from mxnet_graph import load_graph
import
numpy
as
np
import
numpy
as
np
import
logging
import
logging
import
config
as
cfg
import
config
as
cfg
import
GProgress
as
GP
import
GProgress
logging
.
basicConfig
(
logging
.
basicConfig
(
level
=
logging
.
INFO
,
level
=
logging
.
INFO
,
format
=
"[
%(asctime)
s]
%(name)
s:
%(levelname)
s:
%(message)
s"
format
=
"[
%(asctime)
s]
%(name)
s:
%(levelname)
s:
%(message)
s"
...
@@ -344,7 +344,7 @@ class Converter(object):
...
@@ -344,7 +344,7 @@ class Converter(object):
fo
.
write
(
length
.
tobytes
())
fo
.
write
(
length
.
tobytes
())
fo
.
write
(
str
.
encode
(
param_name
))
fo
.
write
(
str
.
encode
(
param_name
))
def
convert
(
self
,
seetanet_model
):
def
convert
(
self
,
seetanet_model
,
GP
):
fo
=
open
(
seetanet_model
,
'wb'
)
fo
=
open
(
seetanet_model
,
'wb'
)
# 先写入blob的名称,再写入layer的名称,由于这里将blob和layer近似等同起来,所有写入的都是layer的name。
# 先写入blob的名称,再写入layer的名称,由于这里将blob和layer近似等同起来,所有写入的都是layer的name。
# 唯一的区别是,blob中多一个label blob
# 唯一的区别是,blob中多一个label blob
...
@@ -415,7 +415,7 @@ def load_checkpoint(params, network_struct):
...
@@ -415,7 +415,7 @@ def load_checkpoint(params, network_struct):
#function created by xuboxuan@20170807
#function created by xuboxuan@20170807
#if __name__ == '__main__':
#if __name__ == '__main__':
def
Run_Converter
(
model_param
,
model_json
,
seetanet_model
):
def
Run_Converter
(
model_param
,
model_json
,
seetanet_model
,
GP
):
#parser = argparse.ArgumentParser(description='manual to this converter script')
#parser = argparse.ArgumentParser(description='manual to this converter script')
#parser.add_argument('--model_param',type=str,default = None)
#parser.add_argument('--model_param',type=str,default = None)
#parser.add_argument('--model_json',type=str,default = None)
#parser.add_argument('--model_json',type=str,default = None)
...
@@ -429,7 +429,7 @@ def Run_Converter(model_param,model_json,seetanet_model):
...
@@ -429,7 +429,7 @@ def Run_Converter(model_param,model_json,seetanet_model):
graph
=
load_graph
(
model_json
)
graph
=
load_graph
(
model_json
)
converter
=
Converter
(
graph
,
arg_params
,
aux_params
)
converter
=
Converter
(
graph
,
arg_params
,
aux_params
)
logging
.
info
(
'start to convert model parameters'
)
logging
.
info
(
'start to convert model parameters'
)
converter
.
convert
(
seetanet_model
)
converter
.
convert
(
seetanet_model
,
GP
)
logging
.
info
(
'convert success!!!'
)
logging
.
info
(
'convert success!!!'
)
return
true
return
true
except
Exception
,
e
:
except
Exception
,
e
:
...
...
Write
Preview
Markdown
is supported
Attach a file
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to post a comment