Pytorch\Onnx\TensorRT

前言

TensorRT是NVIDIA推出的一款高效深度学习模型推理框架,其包括了深度学习推理优化器和运行时,能够让深度学习推理应用拥有低时延和高吞吐的优点。

Pytorch\Onnx\TensorRT_第1张图片

本质上来讲,就是通过采用对模型中的部分算子进行融合、对特定尺寸的算子选用更好的实现方法,以及使用混合精度等方式,最终加速整个网络的推理速度。
Pytorch\Onnx\TensorRT_第2张图片

在使用PyTorch训练得到网络模型后,我们希望在模型部署时通过TensorRT加速模型推理,那么可以先将PyTorch模型转为ONNX,然后再讲ONNX转为TensorRT的engine。

实现步骤

PyTorch模型转为ONNX

具体过程可参考 PyTorch模型转ONNX格式_TracelessLe的专栏-CSDN博客

ONNX转TensorRT的engine

  • trtexec 命令
  • https://github.com/NVIDIA/TensorRT/blob/master/samples/trtexec/README.md
trtexec --onnx=net_bs8_v1_simple.onnx --tacticSources=-cublasLt,+cublas --workspace=2048 --fp16 --saveEngine=net_bs8_v1.engine  --verbose

--onnx指定ONNX文件路径
--tacticSources指定使用的方法库
--workspace指定工作空间大小,单位是MB
--fp16 开启FP16模式
--saveEngine 指定生成的engine的保存路径
--verbose 打开verbose模式,更多打印信息。

方法二:基于Python API的engine生成

__author__ = 'TracelessLe'

import os
import tensorrt as trt

ONNX_SIM_MODEL_PATH = 'net_bs8_v1_simple.onnx'
TENSORRT_ENGINE_PATH_PY = 'net_bs8_v1_fp16_py.engine'

def build_engine(onnx_file_path, engine_file_path, flop=16):
    trt_logger = trt.Logger(trt.Logger.VERBOSE)  # trt.Logger.ERROR
    builder = trt.Builder(trt_logger)
    network = builder.create_network(
        1 << (int)(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    )
    
    parser = trt.OnnxParser(network, trt_logger)
    # parse ONNX
    with open(onnx_file_path, 'rb') as model:
        if not parser.parse(model.read()):
            print('ERROR: Failed to parse the ONNX file.')
            for error in range(parser.num_errors):
                print(parser.get_error(error))
            return None
    print("Completed parsing ONNX file")
    builder.max_workspace_size = 2 << 30
    # default = 1 for fixed batch size
    builder.max_batch_size = 1
    # set mixed flop computation for the best performance
    if builder.platform_has_fast_fp16 and flop == 16:
        builder.fp16_mode = True

    if os.path.isfile(engine_file_path):
        try:
            os.remove(engine_file_path)
        except Exception:
            print("Cannot remove existing file: ",
                engine_file_path)

    print("Creating Tensorrt Engine")

    config = builder.create_builder_config()
    config.set_tactic_sources(1 << int(trt.TacticSource.CUBLAS))
    config.m

你可能感兴趣的:(python,pytorch,深度学习)