pth转onnx,onnx转tflite,亲测有效

1. pth2onnx

        pth模型转onnx模型采用torch框架现有转换函数:torch.onnx.export(),函数使用过程中注意要设置输入输出name,以及将batch维度设置为动态,以便后续onnx转tflite。

        具体转换代码如下:

def pth2onnx(pth_path, onnx_path):
    //加载模型
    model = MyModel()  //实例化modle对象
    model.load_state_dict(torch.load(pth_path))

    //模型转换
    dummy_input = torch.randn(1, 3, 256, 256).type(torch.FloatTensor)  #.to(self.device_num)
    dummy_input = dummy_input.to(next(model.parameters()).device)
    input_names = ["inputs"]
    output_names = ["outputs"]
    dynamic_axes = {"inputs":{0: "batch_size"}}  
    torch.onnx.export(model, 
                    dummy_input, 
                    onnx_path,            //设置onnx模型输出路径,例如:c:/xxx.onnx
                    export_params=True,
                    verbose=False,
                    input_names=input_names,
                    output_names=output_names,
                    dynamic_axes=dynamic_axes
                    )
    print("-----pth to onnx trans successed.")

2. onnx2tflite

        网上其他转换方案,虽然可以成功转换,但是存在一个问题,由于pth模型默认的数据存储是[b, c, h, w],转换之后的tflite也是[b, c, h, w],但是tflite的常规数据格式是[b, h, w, c],这个问题曾经让我排查了1天。

        最后借助开源项目完成转换,亲测有效,先上连接,再次感谢这位大佬。

        转换代码如下:

import sys
sys.path.append("./onnx2tflite_MPolaris")
from converter import onnx_converter


def onnx2tflite(onnx_path, tflite_path):
    onnx_converter(
        onnx_model_path = onnx_path,
        need_simplify = False,
        output_path = "./result/",
        target_formats = ['tflite'], # or ['keras'], ['keras', 'tflite']
        weight_quant = False,
        int8_model = False,
        int8_mean = None,
        int8_std = None,
        image_root = None
        )

你可能感兴趣的:(Pytorch,深度学习,pytorch,人工智能)