pytorch模型转tflite【以EfficientNet-BTS为例】

步骤

使用pytorch转tflite需要经过:pytorch -> onnx -> tensorflow -> tflite

配置环境

# ONNX-TensorFlow:  1.8.0   [pip install onnx-tf==1.8.0]
# ONNX:             1.8.0   [pip install onnx==1.8.0]
## TensorFlow:      2.4.0   [pip install tensorflow==2.4.0]
# tf-nightly:       2.9.0-dev20220223   [pip install tf-nightly]
# PyTorch:          1.8.0   [pip install torch==1.8.0 ]

环境配置上的一些问题:

  • 使用Tensorflow 2.4.0 会在onnx导出pb文件时报错,参考链接。应当使用tf-nightly。issue中推荐使用tf-nightly 2.4.0,测试发现使用最新版本2.9.0也可以解决问题。
  • 使用Pytorch 1.7.0 时会出现Cat等冗余op维度不匹配的问题。导出的onnx模型无法正确inference。使用Pytorch1.8可以规避这个问题
  • onnx与tf的版本对应可以参考链接。

Pytorch转onnx

os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

import onnx
from onnx_tf.backend import prepare
import tensorflow as tf
from onnxsim import simplify
import onnxruntime as ort
import numpy as np


if __name__ == '__main__':
    model = Model()
    # Converting model to ONNX
    for _ in model.modules():
        _.training = False

    test_arr = np.random.randn(1, 3, 480, 640).astype(np.float32)
    sample_input = torch.tensor(test_arr)
    # sample_input = torch.randn(1, 3, 480, 640)
    input_nodes = ['input']
    output_nodes = ['output']

    model(sample_input)

    torch.onnx.export(model, sample_input, "model.onnx", export_params=True, input_names=input_nodes,
                      output_names=output_nodes, opset_version=11)
  • 此处注意opset_version=11,如果设置opset_version=10 / 9 会出现一些op不支持的问题,例如upsample_bilinear。
  • 模型输入大小应当与原始模型输入大小一致,如果想动态适应,可以修改export中dynamic_axis参数
  • Gpu应当设置为不可用,使得全部导出过程在CPU上运行。

onnx模型测试

    model = onnx.load("model.onnx")
    ort_session = ort.InferenceSession('model.onnx')
    onnx_outputs = ort_session.run(None, {'input': test_arr})
    print('Export ONNX!')
  • 如果可以正常通过,证明onnx可以正确导出。
  • 测试时可以和原模型输出对照一下,观察是否存在误差。

onnx模型简化

    onnx_model = onnx.load("model.onnx")
    model_simp, check = simplify(onnx_model)
    assert check, "Simplified ONNX model could not be validated"
  • 模型简化使用的是onnx-simplify工具
  • 模型简化可以去除一些在模型转化过程中产生的冗余Op,例如Concat / SUB 

onnx转tensorflow

    output = prepare(model_simp)
    output.export_graph("tf_model/")
    print('Export tf_model!')
  • onnx转Tensorflow过程中可能会遇到一些Op无法转化的问题,例如interpolate函数,align_corners应当设置为True,然后重新导出onnx。参考链接

tensorflow转tflite

    converter = tf.lite.TFLiteConverter.from_saved_model("tf_model")
    tflite_model = converter.convert()
    open("model.tflite", "wb").write(tflite_model)
    print('Export tf lite model!')
  • 转换时候可能会存在一些问题。安装tf-nightly可以解决。

Onnx和Tflite模型可以通过Netron工具可视化查看。

你可能感兴趣的:(pytorch,深度学习,tensorflow,模型部署)