TensorRT is tightly coupled to the hardware it runs on, so it is strongly tied to the installed CUDA and cuDNN versions.
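Before starting, it helps to confirm which versions are in play. A minimal sanity-check sketch (assuming tensorrt, torch, and onnx are already installed in the environment):

import tensorrt as trt
import torch
import onnx

print('TensorRT:', trt.__version__)                 # must match the local CUDA/cuDNN install
print('CUDA (PyTorch build):', torch.version.cuda)  # CUDA version PyTorch was compiled against
print('cuDNN:', torch.backends.cudnn.version())
print('ONNX:', onnx.__version__)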
The following are the four steps for this example application:
1. Convert the pretrained image segmentation PyTorch model into ONNX. ONNX is an open standard for deep learning models that enables conversion between frameworks such as Caffe2, Chainer, PaddlePaddle, PyTorch, and MXNet.
2. Import the ONNX model into TensorRT. Importing the ONNX model means loading it from a file saved on disk and converting it from its native framework format into a TensorRT network.
3. Apply optimizations and generate an engine. TensorRT builds an optimized engine based on the input model, the target GPU platform, and other specified configuration parameters.
4. Perform inference on the GPU. Feed input data to TensorRT, run model inference, and measure the speed.
This step takes a PyTorch-trained model converted into the ONNX format as input and later uses it to populate a network object in TensorRT.
import numpy as np
import torch
import onnx
import onnxruntime as rt

# model is the pretrained segmentation model; input_size is its (N, C, H, W) input shape.
model = model.cuda()
model.eval()
_, c, h, w = input_size
dummy_input = torch.randn(1, c, h, w, device='cuda')
torch.onnx.export(model, dummy_input, "model.onnx", verbose=False, input_names=["input"], output_names=["output"])
# torch.onnx.export(model, dummy_input, model_onnx_path, input_names=inputs, output_names=outputs, dynamic_axes=dynamic_axes)
# model: the model to export.
# dummy_input: the model input; any non-Tensor arguments are hard-coded into the exported model, while any Tensor arguments become inputs of the exported model, in the order they appear in dummy_input.
# input_names/output_names: assign names to the graph's input/output nodes in order; only the count has to match, the names themselves are arbitrary.
# dynamic_axes: declares dynamic axes, so the sizes of the corresponding input dimensions can change at runtime.
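For example, a sketch of an export with dynamic spatial dimensions (the axis names 'height' and 'width' are illustrative labels, not required by the API):

# Mark axes 2 and 3 of the input and output as dynamic, so images of
# varying spatial size can be fed to the exported model.
dynamic_axes = {'input': {2: 'height', 3: 'width'},
                'output': {2: 'height', 3: 'width'}}
torch.onnx.export(model, dummy_input, 'model_dynamic.onnx',
                  input_names=['input'], output_names=['output'],
                  dynamic_axes=dynamic_axes)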
print('Loading the ONNX model and checking that the IR is well formed')
onnx_model = onnx.load('model.onnx')
onnx.checker.check_model(onnx_model)
# print('ONNX model graph...')
# model_graph = onnx.helper.printable_graph(onnx_model.graph)
print('Running onnxruntime inference...')
image = np.random.randn(1, 3, 1024, 1024).astype(np.float32)
sess = rt.InferenceSession('model.onnx')
input_name1 = sess.get_inputs()[0].name
print('Check that the PyTorch and ONNX outputs agree')
output = sess.run(None, {input_name1: image})[0]  # 1, C, 1024, 1024
image_tensor = torch.tensor(image).cuda()  # the model was moved to the GPU above
with torch.no_grad():
    output_tensor = model(image_tensor)
diff = torch.max(torch.abs(output_tensor.cpu() - torch.tensor(output)))
print('max difference:', diff.item())
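To also compare speed, not just numerical agreement, a simple timing sketch (the warm-up and iteration counts below are arbitrary choices, not from the original):

import time

def benchmark(fn, warmup=3, runs=10):
    # Average wall-clock time per call; warm-up runs absorb one-time setup cost.
    for _ in range(warmup):
        fn()
    start = time.perf_counter()
    for _ in range(runs):
        fn()
    return (time.perf_counter() - start) / runs

def torch_infer():
    with torch.no_grad():
        model(image_tensor)
    torch.cuda.synchronize()  # CUDA launches are asynchronous; wait for completion

onnx_ms = benchmark(lambda: sess.run(None, {input_name1: image})) * 1000
torch_ms = benchmark(torch_infer) * 1000
print('onnxruntime: %.2f ms/iter, PyTorch: %.2f ms/iter' % (onnx_ms, torch_ms))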
PyTorch natively supports exporting to the ONNX format:
import numpy as np
import torch
from torchvision import models
import onnx
import onnxruntime as rt

net = models.resnet18(pretrained=True)
net.eval()
dummy_input = torch.randn(1, 3, 224, 224)
torch.onnx.export(net, dummy_input, 'resnet18.onnx')  # a single function call does the export
# Load the ONNX model
model = onnx.load("resnet18.onnx")
# Check that the IR is well formed
onnx.checker.check_model(model)
# Print a human-readable representation of the graph; an ONNX model is essentially made up of three parts: inputs, initializers, and operators.
# To visualize an ONNX model, use Netron or VisualDL.
print(onnx.helper.printable_graph(model.graph))
# An ONNX-compatible runtime can run the unified ONNX model package: it parses, optimizes, and executes the model.
data = np.random.randn(1, 3, 224, 224).astype(np.float32)
sess = rt.InferenceSession('resnet18.onnx')
input_name = sess.get_inputs()[0].name
label_name = sess.get_outputs()[0].name
pred_onx = sess.run([label_name], {input_name: data})[0]
print(pred_onx)
print(np.argmax(pred_onx))
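To inspect those three parts directly, a short sketch over the in-memory graph (the attribute names come from the ONNX protobuf schema):

print([i.name for i in model.graph.input])        # graph inputs
print([w.name for w in model.graph.initializer])  # weights/initializers
print([n.op_type for n in model.graph.node])      # operators in topological order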
Use the TensorRT Backend For ONNX to quickly convert an ONNX model into a TensorRT engine, and then test how fast the converted engine runs:
import tensorrt as trt

TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
MAX_WORKSPACE_SIZE = 1 << 30  # 1 GiB of scratch memory for the builder
MAX_BATCH_SIZE = 1

def build_engine(model_file):
    with trt.Builder(TRT_LOGGER) as builder, builder.create_network() as network, trt.OnnxParser(network, TRT_LOGGER) as parser:
        builder.max_workspace_size = MAX_WORKSPACE_SIZE
        builder.max_batch_size = MAX_BATCH_SIZE
        with open(model_file, 'rb') as model:
            if not parser.parse(model.read()):
                for i in range(parser.num_errors):
                    print(parser.get_error(i))
                return None
        return builder.build_cuda_engine(network)
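A minimal sketch of the final step, running inference on the GPU with the built engine. It follows the same legacy (pre-TensorRT-8) API as the snippet above and uses pycuda for device buffers; the input shape assumes the 1x3x1024x1024 segmentation model exported earlier:

import numpy as np
import pycuda.autoinit  # creates and activates a CUDA context
import pycuda.driver as cuda

engine = build_engine('model.onnx')
with engine.create_execution_context() as context:
    # Host/device buffers for the engine's single input and single output.
    h_input = np.random.randn(1, 3, 1024, 1024).astype(np.float32)
    h_output = np.empty(trt.volume(engine.get_binding_shape(1)), dtype=np.float32)
    d_input = cuda.mem_alloc(h_input.nbytes)
    d_output = cuda.mem_alloc(h_output.nbytes)
    stream = cuda.Stream()
    # Copy the input to the GPU, run inference, copy the result back.
    cuda.memcpy_htod_async(d_input, h_input, stream)
    context.execute_async(bindings=[int(d_input), int(d_output)], stream_handle=stream.handle)
    cuda.memcpy_dtoh_async(h_output, d_output, stream)
    stream.synchronize()

Timing the copy/execute/copy sequence with the same benchmark helper as above gives the TensorRT figure to compare against the PyTorch and onnxruntime numbers.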