Commit d35d74a4 by BoxuanXu

finish the convert stream

1 parent 073bfb99
Showing with 141 additions and 24 deletions
#!/usr/bin/env python #!/usr/bin/env python
#_*_ coding: UTF-8 _*_ #_*_ coding: UTF-8 _*_
import pymysql.cursors
import logging
import subprocess
from libpywrap import *
from converter import Run_Converter
from flask import Flask,request
__website__ = "www.seetatech.com" __website__ = "www.seetatech.com"
__author__ = "seetatech" __author__ = "seetatech"
__editor__ = "xuboxuan" __editor__ = "xuboxuan"
__Date__ = "20170807" __Date__ = "20170807"
#Create an instance of the flask class
app = Flask(__name__)
#initlization fdfs_client
stbf_stcnf("/etc/fdfs/client.conf")
from converter import Run_Converter download_path="."
import logging
from flask import Flask,request
app = Flask(__name__)
#initlization the logging #initlization the logging
logging.basicConfig( logging.basicConfig(
...@@ -20,25 +29,130 @@ logging.basicConfig( ...@@ -20,25 +29,130 @@ logging.basicConfig(
format="[%(asctime)s] %(name)s:%(levelname)s: %(message)s" format="[%(asctime)s] %(name)s:%(levelname)s: %(message)s"
) )
db_atlas = pymysql.connect("192.168.1.15","defaultUser","magician","seetaAtlas")
curl_atlas = db_atlas.cursor()
curl_atlas_exe = db_atlas.cursor()
def get_path_from_db(modelid,seetanet_model):
curl_atlas.execute("select path from seetaAtlas.model where id='%s'" % modelid)
result_atlas = curl_atlas.fetchall()
if len(result_atlas) != 1:
logging('get wrong mxnet model path');
#return
else:
path_name = str(result_atlas[0])
path_list = path_name.split("::")
if len(path_list) != 2:
logging.info('get wrong parameters!')
else:
params_name_path = path_list[0]
params_name_path = params_name_path[2:len(params_name_path)]
#substr file name from path
params_name_index = params_name_path.rfind('/',0,len(path_list[0]))
params_name = params_name_path[params_name_index+1:len(path_list[0])]
graph_name_path = path_list[1]
graph_name_path = graph_name_path[0:len(graph_name_path)-3]
#substr file name from path
graph_name_index = graph_name_path.rfind('/',0,len(path_list[1]))
graph_name = graph_name_path[params_name_index+1:len(path_list[1])]
logging.info("The params's path is %s"% params_name_path)
logging.info("The model graph's path is %s"% graph_name_path)
#download file from fastdfs
if(stbf_down(params_name_path,download_path)):
logging.info("download params file: %s success" % params_name)
else:
return None,None
#download file from fastdfs
if(stbf_down(graph_name_path,download_path)):
logging.info("download params file: %s success" % params_name)
else:
return None,None
#Driver converter model by params and model graph
Run_Converter(params_name,graph_name,seetanet_model)
#remove params file and graph file
try:
cmd_str = "rm -rf " + params_name + " " + graph_name
print(cmd_str)
subprocess.check_call(cmd_str,shell=True)
except subprocess.CalledProcessError as err:
logging.info("shell command error!")
return None,None
return params_name,graph_name
def upload_filetoFastDFS(params_name, graph_name,seetanet_model):
#upload seetanet file to fastdfs
seetanet_model_id = stbf_up(seetanet_model)
if(seetanet_model_id):
logging.info("upload seetanet file success,seetanet model_id is %s" % seetanet_model_id)
else:
logging.info("upload seetanet file failed!")
return None
return seetanet_model_id
def update_pkg_info_db(sdk_fid,seetanet_model_id):
#seetanet_model_id = "group2/M00/00/04/wKgB61l-xXKARFT0AAANtlf6yRo13.json"
#update pkg info table
logging.info("begin insert model_id ,sdk_fid,mxnet_modelid into pkg_info table")
sql_cmd = "insert into seetaAtlas.pkg_info (modelid,sdk_fid,stmodel_fid) values(" + str(mxnet_modelid) +",\'" + sdk_fid + "\',\'" + seetanet_model_id + "\')"
curl_atlas_exe.execute(sql_cmd)
db_atlas.commit()
#get parameters and driver the transition function #get parameters and driver the transition function
@app.route('/',methods=['POST']) @app.route('/',methods=['POST'])
def Dirver_Convert(model_params,model_json): def Dirver_Convert():
logging.info("get begin conver"); #get parameter modelid from post stream
#get two parameters modelid=request.form['modelid']
model_params=request.form['model_params'] logging.info("We get modelid :%s from post stream,Start conversion:" % modelid)
model_json=request.form['model_json']
try:
print(model_params) seetanet_model = "model_" + str(modelid) + ".data"
print(model_json)
params_name,graph_name = get_path_from_db(modelid,seetanet_model)
if model_params=='' and model_json=='':
logging.info('get null params or json')
return if params_name is None or graph_name is None:
logging.info("get wrong params file")
else:
seetanet_model_id = upload_filetoFastDFS(params_name, graph_name,seetanet_model)
if seetanet_model_id is None:
logging.info("upload filed")
else: else:
logging.info("The model_params:%s"% model_params) sdk_fid = "group2/M00/00/04/wKgB61l-xXKARFT0AAANtlf6yRo13.json"
logging.info("The model_json:%s"% model_json) update_pkg_info_db(sdk_fid,seetanet_model_id)
Run_Converter(model_params,model_json)
logging.info("successfully,Finish!")
finally:
db_atlas.close()
return "finish"
if __name__ == '__main__': if __name__ == '__main__':
app.run('0.0.0.0') app.run(host='0.0.0.0')
...@@ -338,8 +338,8 @@ class Converter(object): ...@@ -338,8 +338,8 @@ class Converter(object):
fo.write(length.tobytes()) fo.write(length.tobytes())
fo.write(str.encode(param_name)) fo.write(str.encode(param_name))
def convert(self): def convert(self,seetanet_model):
fo = open('model.data', 'wb') fo = open(seetanet_model, 'wb')
# 先写入blob的名称,再写入layer的名称,由于这里将blob和layer近似等同起来,所有写入的都是layer的name。 # 先写入blob的名称,再写入layer的名称,由于这里将blob和layer近似等同起来,所有写入的都是layer的name。
# 唯一的区别是,blob中多一个label blob # 唯一的区别是,blob中多一个label blob
self.write_layer_names(fo, True) self.write_layer_names(fo, True)
...@@ -400,7 +400,7 @@ def load_checkpoint(params, network_struct): ...@@ -400,7 +400,7 @@ def load_checkpoint(params, network_struct):
#function created by xuboxuan@20170807 #function created by xuboxuan@20170807
#if __name__ == '__main__': #if __name__ == '__main__':
def Run_Converter(model_param,model_json): def Run_Converter(model_param,model_json,seetanet_model):
#parser = argparse.ArgumentParser(description='manual to this converter script') #parser = argparse.ArgumentParser(description='manual to this converter script')
#parser.add_argument('--model_param',type=str,default = None) #parser.add_argument('--model_param',type=str,default = None)
#parser.add_argument('--model_json',type=str,default = None) #parser.add_argument('--model_json',type=str,default = None)
...@@ -411,7 +411,10 @@ def Run_Converter(model_param,model_json): ...@@ -411,7 +411,10 @@ def Run_Converter(model_param,model_json):
graph = load_graph(model_json) graph = load_graph(model_json)
converter = Converter(graph, arg_params, aux_params) converter = Converter(graph, arg_params, aux_params)
logging.info('start to convert model parameters') logging.info('start to convert model parameters')
converter.convert() print(model_param)
print(model_json)
print(seetanet_model)
#converter.convert(seetanet_model)
logging.info('convert success!!!') logging.info('convert success!!!')
Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!