Commit 695e554e by BoxuanXu

add new file to get parameters via Flask and drive the conversion function

#!/usr/bin/env python
# -*- coding: UTF-8 -*-
__website__ = "www.seetatech.com"
__author__ = "seetatech"
__editor__ = "xuboxuan"
__Date__ = "20170807"

from converter import Run_Converter
import logging
from flask import Flask, request

app = Flask(__name__)

# initialize logging
logging.basicConfig(
    level=logging.INFO,
    format="[%(asctime)s] %(name)s:%(levelname)s: %(message)s"
)
# get the parameters from the request and drive the conversion
@app.route('/', methods=['POST'])
def Driver_Convert():
    logging.info("begin convert")
    # get the two parameters from the POST form
    model_params = request.form['model_params']
    model_json = request.form['model_json']
    print(model_params)
    print(model_json)
    if model_params == '' or model_json == '':
        logging.info('got empty model_params or model_json')
        return 'missing model_params or model_json'
    else:
        logging.info("The model_params:%s" % model_params)
        logging.info("The model_json:%s" % model_json)
        Run_Converter(model_params, model_json)
        return 'convert finished'

if __name__ == '__main__':
    app.run('0.0.0.0')
DATA_NAME = 'data'
OUTPUT_LAYER = ''
#MODEL_PARAM = '/home/dev01/workshop/projects/MXNet2SeetaNet/model-0015.params'
#MODEL_JSON = '/home/dev01/workshop/projects/MXNet2SeetaNet/model-symbol.json'
LOAD_EPOCH = 15
INPUT_DATA_CHANNEL = 3
INPUT_DATA_HEIGHT = 248
INPUT_DATA_WIDTH = 248
protoc -I=./ --python_out=./ HolidayCNN_proto.proto
python Drive_Converter.py
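With the conversion service running, a POST request carrying the two form fields triggers the conversion. A minimal client sketch, assuming the server is reachable on Flask's default port 5000 and using hypothetical placeholder paths for the MXNet model files:

import requests

# hypothetical paths to the exported MXNet model files
payload = {
    'model_params': '/path/to/model-0015.params',
    'model_json': '/path/to/model-symbol.json',
}
resp = requests.post('http://localhost:5000/', data=payload)
print(resp.status_code, resp.text)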
# -*- coding: UTF-8 -*-
import json
import logging

import config as cfg


def fuzzy_query(name_list, key_words):
    for name in name_list:
        suffix = name.split('_')[-1]
        if suffix == key_words:
            return name
    return None
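# Example with hypothetical parameter names: fuzzy_query matches on the suffix
# after the last '_', so fuzzy_query(['bn0_gamma', 'bn0_beta'], 'gamma') returns
# 'bn0_gamma' and fuzzy_query(['bn0_beta'], 'gamma') returns None.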
class Layer(object):
    def __init__(self):
        self.name = ''
        self.bottom_name = set()
        self.bottom_idx = set()
        self.top_name = set()
        self.top_idx = set()
        self.param = set()  # param name
        self.type = 'null'
        self.attr = None
class Graph(object):
    def __init__(self):
        self.layer_idx = {}   # name -> idx
        self.idx_layer = {}   # idx -> name
        self.name_layer = {}  # name -> layer
        self.idx_count = 0

    def add_name_layer(self, name, layer):
        if name in self.name_layer:
            logging.info('name %s: is already in name_layer' % name)
            return
        self.name_layer[name] = layer

    def get_layer(self, name):
        layer = self.name_layer.get(name, None)
        if layer is None:
            raise KeyError(name)
        return layer

    def get_idx(self, layer_name):
        idx = self.layer_idx.get(layer_name, None)
        if idx is None:
            raise ValueError(layer_name)
        return idx

    def get_name(self, idx):
        name = self.idx_layer.get(idx, None)
        if name is None:
            raise ValueError(idx)
        return name

    def get_root_layer(self):
        pass

    def get_all_layers(self):
        return self.name_layer.values()

    def __iter__(self):
        return iter(self.name_layer.values())
def separate_layer_param(sym_name_list):
    layer_output_name = []
    param_name = []
    for sym_name in sym_name_list:
        if sym_name.split('_')[-1] == 'output':
            idx = sym_name.rfind('_')
            # layer_output_name.append(sym_name[:idx])
            layer_output_name.append(sym_name)
        else:
            param_name.append(sym_name)
    return layer_output_name, param_name
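# Example with hypothetical MXNet symbol names: names ending in '_output' are
# treated as layer outputs, everything else as parameters, e.g.
#   separate_layer_param(['data', 'conv0_weight', 'conv0_output', 'bn0_output'])
#   -> (['conv0_output', 'bn0_output'], ['data', 'conv0_weight'])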
def set_top_name(graph):
    for layer in graph:
        for l in graph:
            if layer.name in l.bottom_name:
                layer.top_name.add(l.name)
        # the bottom-most data layer must have two top idx
        if layer.name == cfg.DATA_NAME:
            layer.top_name.add(cfg.DATA_NAME)
def set_idx(graph):
    try:
        from Queue import Queue
    except ImportError:
        from queue import Queue
    name_set = set()
    q = Queue()
    q.put(graph.get_layer(cfg.DATA_NAME))
    while q.qsize() > 0:
        layer = q.get()
        layer_name = layer.name
        flag = False
        for b in layer.bottom_name:
            # make sure every input blob of a layer has already been assigned
            # an index before the layer itself is processed
            if b not in name_set:
                q.put(layer)
                flag = True
                break
        if flag:
            continue
        if layer_name in name_set:
            continue
        else:
            name_set.add(layer_name)
            graph.layer_idx[layer_name] = graph.idx_count
            graph.idx_layer[graph.idx_count] = layer_name
            if layer_name == cfg.DATA_NAME:
                # idx = 1 is reserved for the label blob
                graph.idx_count += 2
            else:
                graph.idx_count += 1
            for n in layer.top_name:
                q.put(graph.get_layer(n))
    for layer in graph:
        for b in layer.bottom_name:
            layer.bottom_idx.add(graph.get_idx(b))
        # for t in layer.top_name:
        layer.top_idx.add(graph.layer_idx[layer.name])
        if layer.name == cfg.DATA_NAME:
            layer.top_idx.add(1)
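# Worked example on a hypothetical chain data -> conv0 -> fc0 (with
# cfg.DATA_NAME == 'data'): after set_top_name/set_idx the blob indices are
#   data : top_idx {0, 1}          (index 1 is reserved for the label blob)
#   conv0: bottom_idx {0}, top_idx {2}
#   fc0  : bottom_idx {2}, top_idx {3}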
def remove_layer(graph, type='Flatten'):
    '''The layer being removed must have exactly one input, i.e. len(bottom_name) == 1'''
    remove_layer_list = []
    for layer in graph:
        if layer.type == type:
            remove_layer_list.append(layer)
            layer_name = layer.name
            bottom_name = layer.bottom_name.pop()
            if len(layer.bottom_name) > 0:
                raise ValueError
            # rewire every consumer of the removed layer to its single input
            for l in graph:
                if layer_name in l.bottom_name:
                    l.bottom_name.remove(layer_name)
                    l.bottom_name.add(bottom_name)
    for l in remove_layer_list:
        logging.info('remove: {}'.format(l.name))
        graph.name_layer.pop(l.name)
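# For example, with a hypothetical chain conv0 -> flatten0 -> fc0, removing the
# Flatten layer rewires fc0.bottom_name from {'flatten0'} to {'conv0'} and then
# drops 'flatten0' from the graph.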
def mxnetbn_to_bn_scale(graph):
    '''Split an MXNet bn layer into the separate bn and scale layers used by SeetaNet (Caffe)'''
    bn_layers = []
    for layer in graph.get_all_layers():
        if layer.type == 'BatchNorm':
            bn_layers.append(layer)
    for layer in bn_layers:
        new_name = 'Scale-%s' % layer.name
        scale_layer = Layer()
        scale_layer.type = 'Scale'
        scale_layer.name = new_name
        graph.add_name_layer(new_name, scale_layer)
        scale_layer.bottom_name.add(layer.name)
        # consumers of the bn layer now read from the new scale layer instead
        for l in graph.get_all_layers():
            if layer.name in l.bottom_name and l.type != 'Scale':
                l.bottom_name.remove(layer.name)
                l.bottom_name.add(scale_layer.name)
        # move gamma/beta from the bn layer to the scale layer
        gamma_name = fuzzy_query(layer.param, 'gamma')
        beta_name = fuzzy_query(layer.param, 'beta')
        layer.param.remove(gamma_name)
        layer.param.remove(beta_name)
        scale_layer.param.add(gamma_name)
        scale_layer.param.add(beta_name)
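# For example, a hypothetical BatchNorm layer 'bn0' feeding 'relu0' becomes
# bn0 -> Scale-bn0 -> relu0, with the 'bn0_gamma'/'bn0_beta' parameters moved
# onto the Scale-bn0 layer.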
def construct_graph(json_file):
    graph = Graph()
    with open(json_file) as f:
        symbol = json.load(f)
    node_num = len(symbol['nodes'])
    nodes = symbol['nodes']
    for i in range(node_num):
        node = nodes[i]
        if str(node['op']) == 'null' and str(node['name']) != cfg.DATA_NAME:
            continue
        layer_name = node['name']
        layer = Layer()
        layer.name = layer_name
        layer.type = node['op']
        # layer.top_name = '%s_output' % node['name']
        layer.attr = node.get('attr', None)
        graph.add_name_layer(layer.name, layer)
        for input in node['inputs']:
            input_node = nodes[input[0]]
            if str(input_node['op']) != 'null' or (str(input_node['name']) == cfg.DATA_NAME):
                layer.bottom_name.add(str(input_node['name']))
            if str(input_node['op']) == 'null' and str(input_node['name']) != cfg.DATA_NAME:
                # parameter node (weight/bias/etc.) attached to this layer
                layer.param.add(str(input_node['name']))
                if not str(input_node['name']).startswith(str(layer_name)):
                    raise NotImplementedError('shared param is not implemented')
        if layer_name == cfg.OUTPUT_LAYER:
            break
    return graph
def load_graph(json_file):
    logging.info('construct graph')
    graph = construct_graph(json_file)
    logging.info('remove Flatten layer')
    # SeetaNet (Caffe) has no Flatten layer, so remove it
    remove_layer(graph)
    # copy_data_layer(graph)
    logging.info('split bn layer to bn layer and scale layer')
    mxnetbn_to_bn_scale(graph)
    set_top_name(graph)
    set_idx(graph)
    logging.info('load graph over')
    return graph


# if __name__ == '__main__':
#     graph = load_graph(cfg.MODEL_JSON_FILE)
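# A minimal usage sketch (hypothetical path; the printed indices are the blob
# indices assigned by set_idx):
#     graph = load_graph('/path/to/model-symbol.json')
#     for layer in graph:
#         print(layer.name, layer.type, sorted(layer.bottom_idx), sorted(layer.top_idx))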