TensorRT部署(图像分类)之onnx生成(第一讲)

import torch
import torchvision
import netron


class Classifier(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = torchvision.models.resnet34(pretrained=False)
        self.backbone.load_state_dict(torch.load("resnet34.pth", map_location=None))
    def forward(self, x):
        feature = self.backbone(x)
        probability = torch.softmax(feature, dim=1)
        return probability
        
dummy = torch.zeros(1, 3, 224, 224)
model = Classifier().eval()
with torch.no_grad():
    model(dummy)
torch.onnx.export(
    model, dummy,
    "classifier.onnx",
    input_names=["image"], 
    output_names=["prob"], 
    dynamic_axes={"image": {0: "batch"}, "prob": {0: "batch"}},
)
netron.start("classifier.onnx")

注意事项:保证权重pth文件已经下载,下载pth文件一直失败(难受)torchvision.models加载其他模型亦是类似。强烈建议自己搭建几个模型,有助于理解模型细节,batch维度设置成动态的,有助于多线程推理。

1)netron.start("classifier.onnx")可视化onnx模型结构

TensorRT部署(图像分类)之onnx生成(第一讲)_第1张图片

 TensorRT部署(图像分类)之onnx生成(第一讲)_第2张图片

 

你可能感兴趣的:(深度学习,人工智能,pytorch)