Commit a1bbf95f by BoxuanXu

fix convert net's bug

1 parent c0737563
Showing with 11 additions and 7 deletions
...@@ -237,7 +237,9 @@ class Converter(object): ...@@ -237,7 +237,9 @@ class Converter(object):
pool.pool = hd.Holiday_PoolingParameter.MAX pool.pool = hd.Holiday_PoolingParameter.MAX
else: else:
raise NotImplementedError raise NotImplementedError
pad = parse_tuple(attr['pad']) kernel = parse_tuple(attr['kernel'])
pool.kernel_height = kernel[0]
pool.kernel_width = kernel[1]
# TODO:现在强制VALID=True,后续有mxnet使用其他pool padding的方式时再修改 # TODO:现在强制VALID=True,后续有mxnet使用其他pool padding的方式时再修改
try: try:
stride = parse_tuple(attr['stride']) stride = parse_tuple(attr['stride'])
...@@ -350,8 +352,10 @@ class Converter(object): ...@@ -350,8 +352,10 @@ class Converter(object):
# 唯一的区别是,blob中多一个label blob # 唯一的区别是,blob中多一个label blob
self.write_layer_names(fo, True) self.write_layer_names(fo, True)
self.write_layer_names(fo) self.write_layer_names(fo)
return_flag=None
try: try:
flag = 1; flag = 1;
for i in range(self.__graph.idx_count): for i in range(self.__graph.idx_count):
if i == 1: continue if i == 1: continue
layer = self.__graph.get_layer(self.__graph.get_name(i)) layer = self.__graph.get_layer(self.__graph.get_name(i))
...@@ -377,7 +381,6 @@ class Converter(object): ...@@ -377,7 +381,6 @@ class Converter(object):
length = np.array(len(ss), np.int64) length = np.array(len(ss), np.int64)
fo.write(length.tobytes()) fo.write(length.tobytes())
fo.write(ss) fo.write(ss)
#post return progress #post return progress
if(i > (flag * (self.__graph.idx_count / 8))): if(i > (flag * (self.__graph.idx_count / 8))):
GP.set_progress_var(flag * 10 + 10) GP.set_progress_var(flag * 10 + 10)
...@@ -389,12 +392,13 @@ class Converter(object): ...@@ -389,12 +392,13 @@ class Converter(object):
length = np.array(len(ss), np.int64) length = np.array(len(ss), np.int64)
fo.write(length.tobytes()) fo.write(length.tobytes())
fo.write(ss) fo.write(ss)
return True return_flag=True
except Exception, e: except Exception, e:
GP.set_err_msg_var(repr(e)) GP.set_err_msg_var(repr(e))
return None return_flag = None
finally: finally:
fo.close() fo.close()
return return_flag
def test(): def test():
...@@ -416,7 +420,6 @@ def load_checkpoint(params, network_struct): ...@@ -416,7 +420,6 @@ def load_checkpoint(params, network_struct):
aux_params[name] = v aux_params[name] = v
return (symbol, arg_params, aux_params) return (symbol, arg_params, aux_params)
except Exception, e: except Exception, e:
logging.info('model load failed!!!')
return (None, None, None) return (None, None, None)
#function created by xuboxuan@20170807 #function created by xuboxuan@20170807
...@@ -426,13 +429,14 @@ def Run_Converter(model_param,model_json,seetanet_model): ...@@ -426,13 +429,14 @@ def Run_Converter(model_param,model_json,seetanet_model):
#parser.add_argument('--model_param',type=str,default = None) #parser.add_argument('--model_param',type=str,default = None)
#parser.add_argument('--model_json',type=str,default = None) #parser.add_argument('--model_json',type=str,default = None)
#args = parser.parse_args() #args = parser.parse_args()
#model_param = "wKgB7Vm2iBaALVPdENbR4G1D6sc.params" #model_param = "wKgB6Fmo2w2ASRqpBky2APcM8zs.params"
#model_json = "wKgB6lm2iBaAcq2PAAANvOLMPRI82.json" #model_json = "wKgB6Vmo2w2AXrJbAAGjO2NrZLE75.json"
#seetanet_model = "model_test" #seetanet_model = "model_test"
try: try:
sym, arg_params, aux_params = \ sym, arg_params, aux_params = \
load_checkpoint(model_param, model_json) load_checkpoint(model_param, model_json)
if sym is None or arg_params is None or aux_params is None: if sym is None or arg_params is None or aux_params is None:
logging.info('load module failed!!!')
return None return None
graph = load_graph(model_json) graph = load_graph(model_json)
if graph is None: if graph is None:
......
Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!