我们训练好深度学习模型后,这时其仍然需在特定的深度学习框架下运行,往往不能进行高性能推理。
NVIDIA提供了一套高效推理的框架——TensorRT,可将已训练好的模型转为TensorRT引擎格式,然后进行高效推理。
对于Pytorch用户而言,该技术路线为:pytorch model-->onnx file-->TensorRT engine。
因此,我们需要做的只有三步:
关于TensorRT的介绍网上资料较多,这里就不再赘述。下面将结合这三个步骤对整个过程进行简单介绍 。
详细的代码文件我已整理到GitHub(https://github.com/qq995431104/Pytorch2TensorRT.git),欢迎大家参考并给个star~~
目录
1、Pytorch to ONNX
2、ONNX to TensorRT
3、推理
这一步比较简单,只要你的模型中所有OP均被ONNX支持,即可利用Pytorch中的ONN库进行转换。参考如下代码:
import torch
def get_model():
""" Define your own model and return it
:return: Your own model
"""
pass
def get_onnx(model, onnx_save_path, example_tensor):
example_tensor = example_tensor.cuda()
_ = torch.onnx.export(model, # model being run
example_tensor, # model input (or a tuple for multiple inputs)
onnx_save_path,
verbose=False, # store the trained parameter weights inside the model file
training=False,
do_constant_folding=True,
input_names=['input'],
output_names=['output']
)
if __name__ == '__main__':
model = get_model()
onnx_save_path = "onnx/resnet50_2.onnx"
example_tensor = torch.randn(1, 3, 288, 512, device='cuda')
# 导出模型
get_onnx(model, onnx_save_path, example_tensor)
需要提供的有:加载好的Pytorch模型、一个输入样例。其中,模型需要按照自己的方式导入并加载模型,输入样例的格式为BCHW,B为batch_size,CHW为通道、高、宽,CHW的值需要与你自己的模型相匹配,否则后面转换成功后输出结果也不对。
如果出现“RuntimeError: ONNX export failed: Couldn't export Python operator XXXX”错误提示,说明你的模型中有ONNX不支持的OP,可以尝试升级Pytorch版本,或者编写自定义op,这部分尚未进行研究,后续有了进展会更新上来。
现在已经有了ONNX文件了,需要利用TensorRT提供的OnnxParser解析该文件,同理:Caffe对应的有CaffPaser、TensorFlow的UFF格式对应的有UffParser。
先使用TensorRT创建一个builder,然后创建一个network,然后利用对应的Parser将ONNX文件加载进去;
接着,对builder指定一些参数设置,如:max_batch_size、max_workspace_size;
如需转为特定格式,如fp16或int8,需指定相应参数:fp16_mode或int8_mode设为True;
对于Int8格式,需要:
myCalibrator.py
.示例代码如下:
def ONNX2TRT(args, calib=None):
''' convert onnx to tensorrt engine, use mode of ['fp32', 'fp16', 'int8']
:return: trt engine
'''
assert args.mode.lower() in ['fp32', 'fp16', 'int8'], "mode should be in ['fp32', 'fp16', 'int8']"
G_LOGGER = trt.Logger(trt.Logger.WARNING)
with trt.Builder(G_LOGGER) as builder, builder.create_network() as network, \
trt.OnnxParser(network, G_LOGGER) as parser:
builder.max_batch_size = args.batch_size
builder.max_workspace_size = 1 << 30
if args.mode.lower() == 'int8':
assert (builder.platform_has_fast_int8 == True), "not support int8"
builder.int8_mode = True
builder.int8_calibrator = calib
elif args.mode.lower() == 'fp16':
assert (builder.platform_has_fast_fp16 == True), "not support fp16"
builder.fp16_mode = True
print('Loading ONNX file from path {}...'.format(args.onnx_file_path))
with open(args.onnx_file_path, 'rb') as model:
print('Beginning ONNX file parsing')
parser.parse(model.read())
print('Completed parsing of ONNX file')
print('Building an engine from file {}; this may take a while...'.format(args.onnx_file_path))
engine = builder.build_cuda_engine(network)
print("Created engine success! ")
# 保存计划文件
print('Saving TRT engine file to path {}...'.format(args.engine_file_path))
with open(args.engine_file_path, "wb") as f:
f.write(engine.serialize())
print('Engine file has already saved to {}!'.format(args.engine_file_path))
return engine
推理过程就完全独立于我们原先模型所依赖的框架了。
基本过程如下:
根据引擎文件反序列化为TensorRT引擎的示例代码如下:
def loadEngine2TensorRT(filepath):
G_LOGGER = trt.Logger(trt.Logger.WARNING)
# 反序列化引擎
with open(filepath, "rb") as f, trt.Runtime(G_LOGGER) as runtime:
engine = runtime.deserialize_cuda_engine(f.read())
return engine
推理过程示例如下:
# 通过engine文件创建引擎
engine = loadEngine2TensorRT('path_to_engine_file')
# 准备输入输出数据
img = Image.open('XXX.jpg')
img = D.transform(img).unsqueeze(0)
img = img.numpy()
output = np.empty((1, 2), dtype=np.float32)
#创建上下文
context = engine.create_execution_context()
# 分配内存
d_input = cuda.mem_alloc(1 * img.size * img.dtype.itemsize)
d_output = cuda.mem_alloc(1 * output.size * output.dtype.itemsize)
bindings = [int(d_input), int(d_output)]
# pycuda操作缓冲区
stream = cuda.Stream()
# 将输入数据放入device
cuda.memcpy_htod_async(d_input, img, stream)
# 执行模型
context.execute_async(batch_size=1, bindings, stream.handle, None)
# 将预测结果从从缓冲区取出
cuda.memcpy_dtoh_async(output, d_output, stream)
# 线程同步
stream.synchronize()
print(output)
*更多详细内容,请参阅GitHub仓库:https://github.com/qq995431104/Pytorch2TensorRT.git