[PyTorch] torch2trt

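1. Introduction
torch2trt is an easy-to-use PyTorch-to-TensorRT converter: a single call to torch2trt(model, [example_input]) returns a TensorRT-backed module that is called like the original model.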

2. Installation

sudo apt-get install libprotobuf* protobuf-compiler ninja-build
git clone https://github.com/NVIDIA-AI-IOT/torch2trt
cd torch2trt
python setup.py install --plugins # run this inside your virtual environment
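
After the install step it can help to confirm that the package is visible to the interpreter of the active virtual environment. The snippet below is a minimal sketch using only the standard library; the helper name is_installed is hypothetical and is not part of torch2trt. It only checks that the module can be found, not that TensorRT itself works.

```python
import importlib.util


def is_installed(module_name: str) -> bool:
    """Return True if a top-level module with this name can be found on the import path."""
    # find_spec returns None when no top-level module with that name is found.
    return importlib.util.find_spec(module_name) is not None


if __name__ == "__main__":
    # Run this with the same interpreter that ran "python setup.py install".
    for name in ("torch", "torch2trt"):
        print("%s installed: %s" % (name, is_installed(name)))
```

If torch2trt shows up as not installed, the install command was most likely run with a different Python interpreter than the one used here.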

3. Usage

# -*- coding: utf-8 -*-
import torchvision
import torch
from torch2trt import torch2trt

data = torch.randn((1, 3, 224, 224)).cuda().half()
model = torchvision.models.resnet18(pretrained=True).cuda().half().eval()
output = model(data)

# pytorch -> tensorrt
model_trt = torch2trt(model, [data], fp16_mode=True)
output_trt = model_trt(data)

# compare
print('max error: %f' % float(torch.max(torch.abs(output - output_trt))))
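# average the squared differences first: float() only accepts a single-element tensor
print('mse: %f' % float(torch.mean((output - output_trt) ** 2)))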

# save tensorrt model
torch.save(model_trt.state_dict(), "resnet18_trt.pth")

# load tensorrt model
from torch2trt import TRTModule
model_trt = TRTModule()
model_trt.load_state_dict(torch.load('resnet18_trt.pth'))
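
The comparison step in the script reports two numbers: the maximum absolute error and the mean squared error between the PyTorch output and the TensorRT output. To make clear what each number measures, here is a dependency-free sketch on plain Python lists. The function name compare_outputs and the sample values are made up for illustration; in the real script the same quantities are computed on tensors.

```python
from typing import Sequence, Tuple


def compare_outputs(reference: Sequence[float],
                    candidate: Sequence[float]) -> Tuple[float, float]:
    """Return (max absolute error, mean squared error) of two equal-length sequences."""
    if len(reference) != len(candidate):
        raise ValueError("outputs must have the same length")
    if len(reference) == 0:
        raise ValueError("outputs must not be empty")
    # Element-wise differences between the two outputs.
    diffs = [r - c for r, c in zip(reference, candidate)]
    max_error = max(abs(d) for d in diffs)        # worst single element
    mse = sum(d * d for d in diffs) / len(diffs)  # average squared difference
    return max_error, mse


if __name__ == "__main__":
    # Illustrative values only, not real model outputs.
    max_error, mse = compare_outputs([1.0, 2.0, 3.0], [1.0, 2.5, 2.0])
    print("max error: %f" % max_error)
    print("mse: %f" % mse)
```

The maximum error flags a single badly converted element, while the mean squared error summarizes the overall drift, so it is reasonable to look at both when running in fp16 mode.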
