MMClassificatio 框架下 Pytorch模型转TensorRT

模型的加载

import torchvision.models as models
resnet34 = models.resnet34()
resnet34.load_state_dict(torch.load('latest.pth')['model'])

要解决的疑问

  • load_state_dict torch.load作用
    网络结构有了 这部分是在加载参数
  • dummy input作用
    给网络一个输入
  • 如果dynamic_axes 后面输入可以更改指定的维度
  • binding inputname outputname作用
    binding 每个engine有且只有两个binding,对应输入输出
    name可以理解为指针,在转onnx时候就指定根据这个指针拿到输入输出的内容
dummy_input=torch.randn(BATCH_SIZE, 3, 224, 224)
import torch.onnx
torch.onnx.export(resnet34, dummy_input, "rp_rec.onnx", verbose=False)

注意

torchvision和mmcls的Resnet模型不一样

resnet34 = models.resnet34()
resnet34.load_state_dict(torch.load('latest.pth')['model'])

模型必须和参数对应起来
不能用torchvision的模型加载mmcls的参数

Pytorch转TensorRT方法总结

采用mmclassification框架,根据网络推理时的输入指定网络输入dummy_input,看推理代码,如果网络允许某个维度有变化,那么可以设定dynamic_axes(某个维度定死了,就不要dynamic_axes),采用verify参数,对比模型的输出是否一致
步骤:在服务器上完成trt到onnx转换(configs等等不好往板卡放)
然后将deployment复制到板卡上,执行转trt代码

你可能感兴趣的:(工程项目,pytorch,深度学习,神经网络)