深度学习之格式转换笔记(一):模型文件pt转换成onnx格式详解

常见的模型文件包括后缀名为.pt,.pth,.pkl的模型文件,而这几种模型文件并非格式上有区别而是后缀不同而已,保存模型文件往往用的是torch.save(),后缀不同只是单纯因为每个人喜好不同而已。通常用的是pth和pt。
保存
orch.save(model.state_dict(), mymodel.pth)#只保存模型权重参数,不保存模型结构

调用
model = My_model(*args, **kwargs) #这里需要重新模型结构,
pthfile = r’绝对路径’
loaded_model = torch.load(pthfile, map_location=‘cpu’)
model.load_state_dict(loaded_model[‘model’])
model.eval() #不启用 BatchNormalization 和 Dropout,不改变权值

from nn.mobilenetv3 import mobilenetv3_large,mobilenetv3_large_full,mobilenetv3_small
import torch
from nn.models import DarknetWithShh
from hyp import hyp

def convert_onnx():
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model_path = 'weights/mbv3_large_75_light_final.pt' #这是我们要转换的模型
    backone = mobilenetv3_large(width_mult=0.75)#mobilenetv3_small()  mobilenetv3_small(width_mult=0.75)  mobilenetv3_large(width_mult=0.75)
    model = DarknetWithShh(backone, hyp,light_head=True).to(device)

    model.load_state_dict(torch.load(model_path, map_location=device)['model'])

    model.to(device)
    model.eval()
    dummy_input = torch.randn(1, 3, 32, 32).to(device)#输入大小   #data type nchw
    onnx_path = 'weights/mbv3_large_75_light_final.onnx'
    torch.onnx.export(model, dummy_input, onnx_path, input_names=['input'], output_names=['output'],opset_version=11)
    print('convert retinaface to onnx finish!!!')

if __name__ == "__main__" :
    convert_onnx()

转换结果

深度学习之格式转换笔记(一):模型文件pt转换成onnx格式详解_第1张图片
在这里插入图片描述

你可能感兴趣的:(深度学习,深度学习,python,pytorch)