Commit 9f92a31d by BoxuanXu

add multithread code

1 parent 23709c5b
......@@ -5,6 +5,7 @@ import pymysql.cursors
import logging
import subprocess
from libpywrap import *
import requests
......@@ -18,7 +19,7 @@ from jenkinsapi.jenkins import Jenkins
from mysql_glock import Mysql_lock
import GProgress as GP
from GProgress import GProgress_Var
import Queue
__website__ = "www.seetatech.com"
......@@ -62,7 +63,7 @@ curl_atlas_exe = db_atlas.cursor()
db_l = Mysql_lock()
lock_name = 'db_queue'
def get_path_from_db(modelid,seetanet_model):
def get_path_from_db(modelid,seetanet_model,GP):
#curl_atlas.execute("select path from seetaAtlas.model where id='%s'" % modelid)
ret = db_l.lock(lock_name,10)
if ret != True:
......@@ -130,7 +131,7 @@ def get_path_from_db(modelid,seetanet_model):
GP.Post_return()
#Driver converter model by params and model graph
result = Run_Converter(params_name,graph_name,seetanet_model)
result = Run_Converter(params_name,graph_name,seetanet_model,GP)
if result is "true":
return None,None
......@@ -149,6 +150,8 @@ def upload_filetoFastDFS(params_name, graph_name,seetanet_model):
return stmodel_fid
def get_info_from_queue(arg):
GP = GProgress_Var()
while 1:
if not Info_Queue.empty():
Info = Info_Queue.get()
......@@ -171,9 +174,9 @@ def get_info_from_queue(arg):
#GP.set_progress_var(10)
#GP.Post_return()
seetanet_model = "model_" + str(modelid) + ".data"
seetanet_model = "model_" + str(modelid) + "_" + str(threading.currentThread().ident) + ".data"
params_name,graph_name = get_path_from_db(modelid,seetanet_model)
params_name,graph_name = get_path_from_db(modelid,seetanet_model,GP)
if params_name is None or graph_name is None:
......@@ -240,7 +243,7 @@ def get_info_from_queue(arg):
GP.Post_return()
t = threading.Thread(target=get_info_from_queue,args=(1,))
#t = threading.Thread(target=get_info_from_queue,args=(1,))
#get parameters and driver the transition function
@app.route('/convert',methods=['POST'])
......@@ -261,10 +264,12 @@ def Dirver_Convert():
Post_Info = { "modelid": modelid, "output_layer" : output_layer,"pool_id" : pool_id}
Info_Queue.put(Post_Info)
'''
if t.is_alive():
t.stop()
t = threading.Thread(target=get_info_from_queue,args=(1,))
t.start()
'''
finally:
return return_flag
......@@ -272,7 +277,10 @@ def Dirver_Convert():
if __name__ == '__main__':
try:
t.start();
for i in range(0,5):
logging.info("Begin Thread %d" % i)
t = threading.Thread(target=get_info_from_queue,args=(i,))
t.start();
app.run(host='0.0.0.0')
finally:
curl_atlas_exe.close()
......
......@@ -9,28 +9,30 @@ __Date__ = "20170812"
import requests
#post_url = "http://192.168l.1.170:1234/API"
post_url = "http://192.168.1.244:5000/result"
#post_url = "http://192.168.1.244:5000/result"
post_url = "http://192.168.1.127:1234/API"
class GProgress_Var:
posttype = None;
progress = None;
pool_id = None;
err_msg = None;
def __init__(self):
self.posttype = "";
self.progress = "";
self.pool_id = "";
self.err_msg = "";
def set_post_type_var(posttype):
GProgress_Var.posttype = posttype
def set_post_type_var(self,posttype):
self.posttype = posttype
def set_progress_var(progress):
GProgress_Var.progress = progress
def set_progress_var(self,progress):
self.progress = progress
def set_pool_id_var(pool_id):
GProgress_Var.pool_id = pool_id
def set_pool_id_var(self,pool_id):
self.pool_id = pool_id
def set_err_msg_var(err_msg):
GProgress_Var.err_msg = err_msg
def set_err_msg_var(self,err_msg):
self.err_msg = err_msg
def Post_return():
post_return = { "posttype": GProgress_Var.posttype, "progress" : GProgress_Var.progress, "pool_id": GProgress_Var.pool_id, "err_msg":GProgress_Var.err_msg }
requests.post(post_url, data=post_return)
def Post_return(self):
post_return = { "posttype": self.posttype, "progress" : self.progress, "pool_id": self.pool_id, "err_msg": self.err_msg }
print(post_return)
requests.post(post_url, data=post_return)
......@@ -5,7 +5,7 @@ from mxnet_graph import load_graph
import numpy as np
import logging
import config as cfg
import GProgress as GP
import GProgress
logging.basicConfig(
level=logging.INFO,
format="[%(asctime)s] %(name)s:%(levelname)s: %(message)s"
......@@ -344,7 +344,7 @@ class Converter(object):
fo.write(length.tobytes())
fo.write(str.encode(param_name))
def convert(self,seetanet_model):
def convert(self,seetanet_model,GP):
fo = open(seetanet_model, 'wb')
# 先写入blob的名称,再写入layer的名称,由于这里将blob和layer近似等同起来,所有写入的都是layer的name。
# 唯一的区别是,blob中多一个label blob
......@@ -415,7 +415,7 @@ def load_checkpoint(params, network_struct):
#function created by xuboxuan@20170807
#if __name__ == '__main__':
def Run_Converter(model_param,model_json,seetanet_model):
def Run_Converter(model_param,model_json,seetanet_model,GP):
#parser = argparse.ArgumentParser(description='manual to this converter script')
#parser.add_argument('--model_param',type=str,default = None)
#parser.add_argument('--model_json',type=str,default = None)
......@@ -429,7 +429,7 @@ def Run_Converter(model_param,model_json,seetanet_model):
graph = load_graph(model_json)
converter = Converter(graph, arg_params, aux_params)
logging.info('start to convert model parameters')
converter.convert(seetanet_model)
converter.convert(seetanet_model,GP)
logging.info('convert success!!!')
return true
except Exception, e:
......
Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!