Commit 9f92a31d by BoxuanXu

add multithread code

1 parent 23709c5b
...@@ -5,6 +5,7 @@ import pymysql.cursors ...@@ -5,6 +5,7 @@ import pymysql.cursors
import logging import logging
import subprocess import subprocess
from libpywrap import * from libpywrap import *
import requests import requests
...@@ -18,7 +19,7 @@ from jenkinsapi.jenkins import Jenkins ...@@ -18,7 +19,7 @@ from jenkinsapi.jenkins import Jenkins
from mysql_glock import Mysql_lock from mysql_glock import Mysql_lock
import GProgress as GP from GProgress import GProgress_Var
import Queue import Queue
__website__ = "www.seetatech.com" __website__ = "www.seetatech.com"
...@@ -62,7 +63,7 @@ curl_atlas_exe = db_atlas.cursor() ...@@ -62,7 +63,7 @@ curl_atlas_exe = db_atlas.cursor()
db_l = Mysql_lock() db_l = Mysql_lock()
lock_name = 'db_queue' lock_name = 'db_queue'
def get_path_from_db(modelid,seetanet_model): def get_path_from_db(modelid,seetanet_model,GP):
#curl_atlas.execute("select path from seetaAtlas.model where id='%s'" % modelid) #curl_atlas.execute("select path from seetaAtlas.model where id='%s'" % modelid)
ret = db_l.lock(lock_name,10) ret = db_l.lock(lock_name,10)
if ret != True: if ret != True:
...@@ -130,7 +131,7 @@ def get_path_from_db(modelid,seetanet_model): ...@@ -130,7 +131,7 @@ def get_path_from_db(modelid,seetanet_model):
GP.Post_return() GP.Post_return()
#Driver converter model by params and model graph #Driver converter model by params and model graph
result = Run_Converter(params_name,graph_name,seetanet_model) result = Run_Converter(params_name,graph_name,seetanet_model,GP)
if result is "true": if result is "true":
return None,None return None,None
...@@ -149,6 +150,8 @@ def upload_filetoFastDFS(params_name, graph_name,seetanet_model): ...@@ -149,6 +150,8 @@ def upload_filetoFastDFS(params_name, graph_name,seetanet_model):
return stmodel_fid return stmodel_fid
def get_info_from_queue(arg): def get_info_from_queue(arg):
GP = GProgress_Var()
while 1: while 1:
if not Info_Queue.empty(): if not Info_Queue.empty():
Info = Info_Queue.get() Info = Info_Queue.get()
...@@ -171,9 +174,9 @@ def get_info_from_queue(arg): ...@@ -171,9 +174,9 @@ def get_info_from_queue(arg):
#GP.set_progress_var(10) #GP.set_progress_var(10)
#GP.Post_return() #GP.Post_return()
seetanet_model = "model_" + str(modelid) + ".data" seetanet_model = "model_" + str(modelid) + "_" + str(threading.currentThread().ident) + ".data"
params_name,graph_name = get_path_from_db(modelid,seetanet_model) params_name,graph_name = get_path_from_db(modelid,seetanet_model,GP)
if params_name is None or graph_name is None: if params_name is None or graph_name is None:
...@@ -240,7 +243,7 @@ def get_info_from_queue(arg): ...@@ -240,7 +243,7 @@ def get_info_from_queue(arg):
GP.Post_return() GP.Post_return()
t = threading.Thread(target=get_info_from_queue,args=(1,)) #t = threading.Thread(target=get_info_from_queue,args=(1,))
#get parameters and driver the transition function #get parameters and driver the transition function
@app.route('/convert',methods=['POST']) @app.route('/convert',methods=['POST'])
...@@ -261,10 +264,12 @@ def Dirver_Convert(): ...@@ -261,10 +264,12 @@ def Dirver_Convert():
Post_Info = { "modelid": modelid, "output_layer" : output_layer,"pool_id" : pool_id} Post_Info = { "modelid": modelid, "output_layer" : output_layer,"pool_id" : pool_id}
Info_Queue.put(Post_Info) Info_Queue.put(Post_Info)
'''
if t.is_alive(): if t.is_alive():
t.stop() t.stop()
t = threading.Thread(target=get_info_from_queue,args=(1,)) t = threading.Thread(target=get_info_from_queue,args=(1,))
t.start() t.start()
'''
finally: finally:
return return_flag return return_flag
...@@ -272,6 +277,9 @@ def Dirver_Convert(): ...@@ -272,6 +277,9 @@ def Dirver_Convert():
if __name__ == '__main__': if __name__ == '__main__':
try: try:
for i in range(0,5):
logging.info("Begin Thread %d" % i)
t = threading.Thread(target=get_info_from_queue,args=(i,))
t.start(); t.start();
app.run(host='0.0.0.0') app.run(host='0.0.0.0')
finally: finally:
......
...@@ -9,28 +9,30 @@ __Date__ = "20170812" ...@@ -9,28 +9,30 @@ __Date__ = "20170812"
import requests import requests
#post_url = "http://192.168l.1.170:1234/API" #post_url = "http://192.168.1.244:5000/result"
post_url = "http://192.168.1.244:5000/result" post_url = "http://192.168.1.127:1234/API"
class GProgress_Var: class GProgress_Var:
posttype = None; def __init__(self):
progress = None; self.posttype = "";
pool_id = None; self.progress = "";
err_msg = None; self.pool_id = "";
self.err_msg = "";
def set_post_type_var(posttype): def set_post_type_var(self,posttype):
GProgress_Var.posttype = posttype self.posttype = posttype
def set_progress_var(progress): def set_progress_var(self,progress):
GProgress_Var.progress = progress self.progress = progress
def set_pool_id_var(pool_id): def set_pool_id_var(self,pool_id):
GProgress_Var.pool_id = pool_id self.pool_id = pool_id
def set_err_msg_var(err_msg): def set_err_msg_var(self,err_msg):
GProgress_Var.err_msg = err_msg self.err_msg = err_msg
def Post_return(): def Post_return(self):
post_return = { "posttype": GProgress_Var.posttype, "progress" : GProgress_Var.progress, "pool_id": GProgress_Var.pool_id, "err_msg":GProgress_Var.err_msg } post_return = { "posttype": self.posttype, "progress" : self.progress, "pool_id": self.pool_id, "err_msg": self.err_msg }
print(post_return)
requests.post(post_url, data=post_return) requests.post(post_url, data=post_return)
...@@ -5,7 +5,7 @@ from mxnet_graph import load_graph ...@@ -5,7 +5,7 @@ from mxnet_graph import load_graph
import numpy as np import numpy as np
import logging import logging
import config as cfg import config as cfg
import GProgress as GP import GProgress
logging.basicConfig( logging.basicConfig(
level=logging.INFO, level=logging.INFO,
format="[%(asctime)s] %(name)s:%(levelname)s: %(message)s" format="[%(asctime)s] %(name)s:%(levelname)s: %(message)s"
...@@ -344,7 +344,7 @@ class Converter(object): ...@@ -344,7 +344,7 @@ class Converter(object):
fo.write(length.tobytes()) fo.write(length.tobytes())
fo.write(str.encode(param_name)) fo.write(str.encode(param_name))
def convert(self,seetanet_model): def convert(self,seetanet_model,GP):
fo = open(seetanet_model, 'wb') fo = open(seetanet_model, 'wb')
# 先写入blob的名称,再写入layer的名称,由于这里将blob和layer近似等同起来,所有写入的都是layer的name。 # 先写入blob的名称,再写入layer的名称,由于这里将blob和layer近似等同起来,所有写入的都是layer的name。
# 唯一的区别是,blob中多一个label blob # 唯一的区别是,blob中多一个label blob
...@@ -415,7 +415,7 @@ def load_checkpoint(params, network_struct): ...@@ -415,7 +415,7 @@ def load_checkpoint(params, network_struct):
#function created by xuboxuan@20170807 #function created by xuboxuan@20170807
#if __name__ == '__main__': #if __name__ == '__main__':
def Run_Converter(model_param,model_json,seetanet_model): def Run_Converter(model_param,model_json,seetanet_model,GP):
#parser = argparse.ArgumentParser(description='manual to this converter script') #parser = argparse.ArgumentParser(description='manual to this converter script')
#parser.add_argument('--model_param',type=str,default = None) #parser.add_argument('--model_param',type=str,default = None)
#parser.add_argument('--model_json',type=str,default = None) #parser.add_argument('--model_json',type=str,default = None)
...@@ -429,7 +429,7 @@ def Run_Converter(model_param,model_json,seetanet_model): ...@@ -429,7 +429,7 @@ def Run_Converter(model_param,model_json,seetanet_model):
graph = load_graph(model_json) graph = load_graph(model_json)
converter = Converter(graph, arg_params, aux_params) converter = Converter(graph, arg_params, aux_params)
logging.info('start to convert model parameters') logging.info('start to convert model parameters')
converter.convert(seetanet_model) converter.convert(seetanet_model,GP)
logging.info('convert success!!!') logging.info('convert success!!!')
return true return true
except Exception, e: except Exception, e:
......
Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!