pytorch 转tensorflow注意

最近在作pytorch 模型转tensorflow,通过onnx 中间转换和容易,但是再转换时有一个注意事项,即如何处理batch

pytorch 模型转tensorflow: https://www.jianshu.com/p/3e5623696a8e

通过 onnx 手动修改batch 为动态值: model_onnx.graph.input[0].type.tensor_type.shape.dim[0].dim_param ='?'

这下就可以使用batch 了。

最后别忘了用transform_graph 压缩下模型大小: https://www.jianshu.com/p/d2637646cda1

self.model.load_state_dict(model_dict)

example = torch.ones(1,1,112,112).cuda() #限定好的tensor 输入大小

# traced_script_module = torch.jit.trace(self.model, example)

# traced_script_module.save('./lt_model.pt')

torch.onnx.export(self.model, example,'./model_simple.onnx',input_names=['input'],

output_names=['output'])

model_onnx = onnx.load('./model_simple.onnx')

model_onnx.graph.input[0].type.tensor_type.shape.dim[0].dim_param ='?'

tf_rep = prepare(model_onnx)

print(tf_rep.tensor_dict)

tf_rep.export_graph('./lt_model.pb')

你可能感兴趣的:(pytorch 转tensorflow注意)