Commit a1bbf95f by BoxuanXu

fix convert net's bug

1 parent c0737563
Showing with 12 additions and 8 deletions
......@@ -237,7 +237,9 @@ class Converter(object):
pool.pool = hd.Holiday_PoolingParameter.MAX
else:
raise NotImplementedError
pad = parse_tuple(attr['pad'])
kernel = parse_tuple(attr['kernel'])
pool.kernel_height = kernel[0]
pool.kernel_width = kernel[1]
# TODO:现在强制VALID=True,后续有mxnet使用其他pool padding的方式时再修改
try:
stride = parse_tuple(attr['stride'])
......@@ -350,8 +352,10 @@ class Converter(object):
# 唯一的区别是,blob中多一个label blob
self.write_layer_names(fo, True)
self.write_layer_names(fo)
return_flag=None
try:
flag = 1;
for i in range(self.__graph.idx_count):
if i == 1: continue
layer = self.__graph.get_layer(self.__graph.get_name(i))
......@@ -377,7 +381,6 @@ class Converter(object):
length = np.array(len(ss), np.int64)
fo.write(length.tobytes())
fo.write(ss)
#post return progress
if(i > (flag * (self.__graph.idx_count / 8))):
GP.set_progress_var(flag * 10 + 10)
......@@ -389,12 +392,13 @@ class Converter(object):
length = np.array(len(ss), np.int64)
fo.write(length.tobytes())
fo.write(ss)
return True
return_flag=True
except Exception, e:
GP.set_err_msg_var(repr(e))
return None
return_flag = None
finally:
fo.close()
return return_flag
def test():
......@@ -416,7 +420,6 @@ def load_checkpoint(params, network_struct):
aux_params[name] = v
return (symbol, arg_params, aux_params)
except Exception, e:
logging.info('model load failed!!!')
return (None, None, None)
#function created by xuboxuan@20170807
......@@ -426,14 +429,15 @@ def Run_Converter(model_param,model_json,seetanet_model):
#parser.add_argument('--model_param',type=str,default = None)
#parser.add_argument('--model_json',type=str,default = None)
#args = parser.parse_args()
#model_param = "wKgB7Vm2iBaALVPdENbR4G1D6sc.params"
#model_json = "wKgB6lm2iBaAcq2PAAANvOLMPRI82.json"
#model_param = "wKgB6Fmo2w2ASRqpBky2APcM8zs.params"
#model_json = "wKgB6Vmo2w2AXrJbAAGjO2NrZLE75.json"
#seetanet_model = "model_test"
try:
sym, arg_params, aux_params = \
load_checkpoint(model_param, model_json)
if sym is None or arg_params is None or aux_params is None:
return None
logging.info('load module failed!!!')
return None
graph = load_graph(model_json)
if graph is None:
logging.info('load graph failed!!!')
......
Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!