pytorch模型转换为torch2trt模型

安装TensorRT

安装torch2trt

转换模型

torch转torch2trt

减少显存占用,建议在模型转换完成后,将模型保存,使用时直接加载转换后的模型。模型不使用.eval(),可能会报[TRT] [E] Could not implicitly convert NumPy data type: i64 to TensorRT.错误

model = torchvision.models.mobilenet_v2(pretrained=True).eval().cuda()
x = torch.ones((1, 3, 255, 255)).cuda()
model_trt = torch2trt(model, [x])
torch.save(net_trt.state_dict(), 'mobilenet_trt.pth')

torch2trt模型保存并加载

torch.save(net_trt.state_dict(), 'mobilenet_trt.pth')
model_trt = torch2trt.TRTModule()
model_trt.load_state_dict(torch.load('mobilenet_trt.pth'))

测试速度提升

model = mobilenet_v1(pretrain=True, model_path='mobilenet_v1.pth').eval().cuda()
x = torch.ones((1, 3, 255, 255)).cuda()
net_trt = TRTModule()
net_trt.load_state_dict(torch.load('mobilenet_trt.pth'))
torch.cuda.synchronize()
start = time.time()
result = model(x)
torch.cuda.synchronize()
end = time.time() 
print(end - start)
torch.cuda.synchronize()
start = time.time()
y = net_trt(x)
torch.cuda.synchronize()
end = time.time()
print(end - start)

你可能感兴趣的:(torch学习,pytorch)