pth模型转onnx模型记录

记录一次pth转onnx

  • pth转onnx:

（图1：nanotrack 原项目中的 pth 转 onnx）
因为使用 nn.Sequential() 定义网络，只接受单输入单输出，所以要将原模型拆分为三部分：模板 T 的特征提取网络、图像 X 的特征提取网络、模型 head 部分。这三部分均要分别进行保存和转换。（转换过程中要设置好输入参数，方便下一步进行 RKNN 的转换。）

pth转onnx:

import argparse
import os
import torch
import sys 

sys.path.append(os.getcwd()) 

from nanotrack.core.config import cfg

from nanotrack.utils.model_load import load_pretrain
from nanotrack.models.model_builder import ModelBuilder

# Expose only GPU 0 to PyTorch. This has an effect only if it is set before
# CUDA is first initialised in this process.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Command-line options: the config file to merge into `cfg` and the
# checkpoint (.pth) whose weights are exported.
parser = argparse.ArgumentParser(description='lighttrack')
parser.add_argument(
    '--config',
    default='./models/config/config.yaml',
    type=str,
    help='config file',
)
parser.add_argument(
    '--snapshot',
    type=str,
    default='./models/snapshot/checkpoint_e26.pth',
    help='snapshot models to eval',
)

args = parser.parse_args()

def _export_onnx(net, inputs, path, input_names, output_names):
    """Trace ``net`` with the example ``inputs`` and write an ONNX file to ``path``.

    ``inputs`` is a single tensor, or a tuple of tensors that is passed to the
    module as positional arguments. The parent directory of ``path`` is
    created first, so that exporting into a not-yet-existing folder (e.g.
    ``./models/onnx/``) does not fail when the file is opened for writing.
    """
    out_dir = os.path.dirname(path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    torch.onnx.export(
        net,
        inputs,
        path,
        input_names=input_names,
        output_names=output_names,
        verbose=True,
    )


def main():
    """Export the NanoTrack sub-networks to three separate ONNX files.

    The full ModelBuilder is not exported as one graph. Instead, three
    single-purpose graphs with fixed input shapes are written:

    * the backbone traced on a 1x3x255x255 image (X branch),
    * the backbone traced on a 1x3x127x127 template (T branch),
    * the ``ban_head`` traced on a pair of feature maps (zf, xf).

    The input/output names are fixed here so the next (RKNN) conversion step
    can refer to them. Reads ``--config`` and ``--snapshot`` from the
    module-level ``args``.
    """
    cfg.merge_from_file(args.config)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Build the model once (after the config is merged), load weights, and
    # switch to inference mode before tracing.
    model = ModelBuilder()
    model = load_pretrain(model, args.snapshot)
    model.eval().to(device)

    backbone_net = model.backbone
    head_net = model.ban_head

    # Backbone on the image X branch (255x255 input).
    backbone_x = torch.randn([1, 3, 255, 255], device=device)
    _export_onnx(backbone_net, backbone_x, './nanotrack_backbone_x.onnx',
                 ['input'], ['output'])

    # Same backbone on the template T branch (127x127 input). ONNX shapes are
    # fixed at trace time, so this needs its own file.
    backbone_t = torch.randn([1, 3, 127, 127], device=device)
    _export_onnx(backbone_net, backbone_t, './nanotrack_backbone_t.onnx',
                 ['input0'], ['output0'])

    # Head: takes the template and search feature maps as two inputs.
    # NOTE(review): the original note says forward() in
    # nanotrack/models/model_builder.py was modified before exporting the head;
    # confirm that ban_head accepts (zf, xf) positionally.
    # The 48x8x8 / 48x16x16 shapes are assumed to match the backbone outputs for
    # the 127 / 255 inputs above - TODO confirm against the config.
    head_zf = torch.randn([1, 48, 8, 8], device=device)
    head_xf = torch.randn([1, 48, 16, 16], device=device)
    _export_onnx(head_net, (head_zf, head_xf), './models/onnx/nanotrack_head.onnx',
                 ['input1', 'input2'], ['output1', 'output2'])


if __name__ == '__main__':
    main()

标签：深度学习、python、人工智能