Model conversion: PyTorch -> ONNX -> SavedModel -> TFLite

Many current models are trained with the PyTorch framework. To deploy them on Android, the model usually has to be converted to a mobile format such as NCNN or TFLite (for example the Ultra-Fast-Lane-Detection and PFLD networks). The steps below walk through this conversion for the PyTorch-based Ultra-Fast-Lane-Detection network.

Environment setup

The following versions were used:
torch 1.6.0
tensorflow-gpu 2.3.1
onnx 1.8.0
onnx-tf 1.7.0
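
If you want to double-check that the installed packages match these versions, a quick sketch like the following (not part of the original post) prints what pip actually installed:

import pkg_resources

# print the installed versions of the packages used in this walkthrough
for pkg in ('torch', 'tensorflow-gpu', 'onnx', 'onnx-tf'):
    print(pkg, pkg_resources.get_distribution(pkg).version)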

Conversion code

The script below loads the CULane ResNet-18 weights, exports the network to ONNX, converts the ONNX model to a TensorFlow SavedModel with onnx-tf, and finally converts that to TFLite.

import torch
from model.model import parsingNet
from utils.common import merge_config
from torchsummary import summary


cls_num_per_lane = 18

model = parsingNet(pretrained=False, backbone='18', cls_dim=(201, cls_num_per_lane, 4),
                   use_aux=False)  # .cuda()  # auxiliary segmentation is not needed at inference time

state_dict = torch.load('culane_18.pth', map_location='cpu')['model']
compatible_state_dict = {}
# strip the 'module.' prefix left over from DataParallel training
for k, v in state_dict.items():
    if k.startswith('module.'):
        compatible_state_dict[k[7:]] = v
    else:
        compatible_state_dict[k] = v

model.load_state_dict(compatible_state_dict, strict=False)  # load the weights
model.eval()  # inference mode for export
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

summary(model, input_size=(3, 288, 800), device=str(device))  # print the layer summary
img = torch.zeros(1, 3, 288, 800).to(device)  # dummy input with the expected shape (N, C, H, W)
y = model(img)  # dry run to verify the forward pass

# ONNX export
try:
    import onnx

    print('\nStarting ONNX export with onnx %s...' % onnx.__version__)
    f = 'culane_18.onnx' # filename
    torch.onnx.export(model, img, f, verbose=False, opset_version=12, input_names=['input_0'],
                      output_names=['output_0/Reshape'] )

    # Checks
    onnx_model = onnx.load(f)  # load the exported onnx model
    onnx.checker.check_model(onnx_model)  # validate the onnx model
    print(onnx.helper.printable_graph(onnx_model.graph))  # print a human-readable graph
    print('ONNX export success, saved as %s' % f)
except Exception as e:
    print('ONNX export failure: %s' % e)
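
# Optional sanity check (not in the original script): if the onnxruntime package
# is installed, compare the ONNX model's output against the PyTorch output.
try:
    import numpy as np
    import onnxruntime as ort

    sess = ort.InferenceSession(f, providers=['CPUExecutionProvider'])
    onnx_out = sess.run(None, {'input_0': img.cpu().numpy()})[0]
    np.testing.assert_allclose(y.detach().cpu().numpy(), onnx_out, rtol=1e-3, atol=1e-4)
    print('ONNX output matches PyTorch output')
except Exception as e:
    print('onnxruntime check skipped or failed: %s' % e)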

try:
    from onnx_tf.backend import prepare
    onnx_model = onnx.load(f)  # reload the exported onnx model
    tf_exp = prepare(onnx_model)  # prepare tf representation
    tf_exp.export_graph("culane_18_saved_model")  # export the model
    print('saved_model export success, saved as culane_18_saved_model' )
except Exception as e:
    print('saved_model export failure: %s' % e)
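
# Optional check (not in the original script): reload the exported SavedModel and
# print its serving signatures to confirm the graph survived the onnx-tf conversion.
try:
    import tensorflow as tf
    loaded = tf.saved_model.load('culane_18_saved_model')
    print('SavedModel signatures:', list(loaded.signatures.keys()))
except Exception as e:
    print('SavedModel check failed: %s' % e)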

try:
    import tensorflow as tf
    converter = tf.lite.TFLiteConverter.from_saved_model('culane_18_saved_model')
    # SELECT_TF_OPS lets ops without TFLite builtins fall back to TensorFlow kernels (needs the Flex delegate at runtime)
    converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS]
    tflite_model = converter.convert()
    with open('culane_18.tflite','wb') as g:
        g.write(tflite_model)
    print('tflite export success, saved as culane_18.tflite' )
except Exception as e:
    print('tflite export failure: %s' % e)
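
Before deploying, it is worth loading the .tflite file once on the desktop and running a dummy inference. The snippet below is a minimal sketch, not part of the original post; because SELECT_TF_OPS was enabled, the Python interpreter needs the Flex delegate, which is bundled with the full tensorflow pip package.

import numpy as np
import tensorflow as tf

# load the converted model and run one dummy inference
interpreter = tf.lite.Interpreter(model_path='culane_18.tflite')
interpreter.allocate_tensors()

input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
print('input :', input_details[0]['shape'], input_details[0]['dtype'])
print('output:', output_details[0]['shape'], output_details[0]['dtype'])

dummy = np.zeros(input_details[0]['shape'], dtype=np.float32)
interpreter.set_tensor(input_details[0]['index'], dummy)
interpreter.invoke()
out = interpreter.get_tensor(output_details[0]['index'])
print('tflite output shape:', out.shape)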

To convert other models, simply replace the model definition and the corresponding weights with those of your own network.
