关于pth转onnx以及检测onnx模型是否正确

目录

一、pth转onnx

二、检测onnx模型


一、pth转onnx

load_state_dict:指从一个字典对象中加载神经网络的参数。

state_dict:可用于保存模型参数、超参数以及优化器的状态信息。只有可学习参数的层,如卷积层、线性层等才有state_dict。

torch.save():用来加载torch.save()保存的模型文件。

model = MyModel() # MyModel为自己用的模型

保存整个模型:torch.save(model, 'best_weight.pth')

只保存训练好的权重:torch.save(model.state_dict(),'best_weight.pth')

torch.load() :用于加载模型。

pth若只包含权重参数:

注意:若直接 (model.state_dict(), '*.pth'),则会得出一个错误的pth。应先加载出pth权重文件,再加载神经网络的参数,即 model .load_state_dict(torch.load('*.pth'))

model = MyModel()

model .load_state_dict(torch.load('*.pth'))

model.eval() # 不启用BatchNormalization和Dropout层 

import torch
from net import MyModel# 自己的模型

model_path = './best_weight.pth' # 这里的 pth 只包含了权重参数

model = MyModel()

model.load_state_dict(torch.load(model_path))

model.eval()

# 在机器学习模型开发和测试中,通常需要创建一个测试数据集用来评估模型的性能和准确性。
dummy_input = torch.randn(1, 3, 640, 640) # 虚拟输入,模拟输入数据的格式和形状。

torch.onnx.export(model, dummy_input, 'best_weight.onnx', verbose=True, input_names=['input'],
                  output_names=['output'])

print('Successful!')

二、检测onnx模型

import os, sys

sys.path.append(os.getcwd())
import onnxruntime
import onnx
import cv2
import torch
import numpy as np
import torchvision.transforms as transforms


def to_numpy(tensor):
    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()


img = cv2.imread("./img/1.jpg")

img = cv2.resize(img, (640, 640), interpolation=cv2.INTER_CUBIC)

to_tensor = transforms.ToTensor()
img = to_tensor(img)
img = img.unsqueeze_(0)

onnx_model_path = 'best_weight.onnx'
rnet_session = onnxruntime.InferenceSession(onnx_model_path)

# compute ONNX Runtime output prediction
inputs = {rnet_session.get_inputs()[0].name: to_numpy(img)}
outs = rnet_session.run(None, inputs) # 推理

print(outs)

你可能感兴趣的:(功能类,深度学习,人工智能)