TensorRT(三)TensorRT实现batch批处理(Python版 )

原文地址:

ZhouJianGuo|个人博客https://www.zhoujianguo.ltd/#/fore/article?id=145

一、 前言

在我的上一篇文章中,已经讲述了TensorRT如何使用,并且是支持对单个数据的输入以及处理并且输出结果,然而在实际应用中,我们往往是需要将多个输入数据构造为一个batch,一次性喂入深度学习模型并进行预测,并获得对应的结果。

同样的,TensorRT也是支持batch批处理输入数据的,并且也能够使检测速度有一定的提升。

本文将会简单地介绍如何实现batch批处理预测,至于具体细节,请参考我的上一篇文章

二、导出ONNX

我们在导出模型的时候,需要将模型输入的batch参数声明为动态参数,例如,hrnet的输入数据维度为(1, 3 , 384, 288),第一个维度为batch的大小,第二个维度为RGB三个色彩通道,第三个维度为图片的高,第四个维度为图片的宽,因此这里我们在导出onnx模型时,需要将第一个维度的参数声明为动态参数,具体代码如下:

# 定义输入名称,list结构,可能有多个输入
input_names = ['input']
# 定义输出名称,list结构,可能有多个输出
output_names = ['output']
# 声明动态维度,这里我们把input的第0维度赋名为batch_size
dynamic_axes = {
            'input': {0: 'batch_size'}
        }
 # 构造输入,用以onnx验证
input = torch.randn(2, 3, 384, 288, requires_grad=True)
torch.onnx.export(model, input, output_path,
                          export_params=True,
                          opset_version=10,
                          do_constant_folding=True,
                          input_names=input_names,
                          output_names=output_names,
                          dynamic_axes=dynamic_axes)

导出完成后,我们可以得到一个onnx文件。我们可以将该onnx文件拖到在线网站https://lutzroeder.github.io/netron/查看onnx的结构,如图所示

TensorRT(三)TensorRT实现batch批处理(Python版 )_第1张图片

三、模型构造阶段

这一步我们同样使用tensorrt自带的trtexec.exe实现利用onnx模型构造trt的模型

注意,需要在这一阶段确定你想要tensorrt的batch大小

首先将上一步导出的onnx文件拉到trtexec.exe所在的目录下,并且在cmd控制台中运行以下命令,其中需要--shapes参数以确定动态参数具体的值,乘号为字母x

trtexec --fp16 --shapes=input:32x3x384x288 --onnx=xxxx.onnx --saveEngine=xxxx.trt

四、执行阶段

在执行阶段,唯一需要注意的是TensorRT的输入必须是严格固定的batch_size大小,即每次输入到trt模型时,输入必须是batch_size大小严格等于构造阶段的输入,因此对于小于batch_size大小的数据需要0填充处理 

你可能感兴趣的:(TensorRT,batch,深度学习)