yolov5 torch转tensorrt详解【推荐】

转化函数

# 可以在https://github.com/ultralytics/yolov5/blob/master/export.py里面找到
def export_engine(model, im, file, half, dynamic, simplify, workspace=4, verbose=False, prefix=colorstr('TensorRT:')):
    # YOLOv5 TensorRT export https://developer.nvidia.com/tensorrt
    assert im.device.type != 'cpu', 'export running on CPU but must be on GPU, i.e. `python export.py --device 0`'
    try:
        import tensorrt as trt
    except Exception:
        if platform.system() == 'Linux':
            check_requirements('nvidia-tensorrt', cmds='-U --index-url https://pypi.ngc.nvidia.com')
        import tensorrt as trt

    if trt.__version__[0] == '7':  # TensorRT 7 handling https://github.com/ultralytics/yolov5/issues/6012
        grid = model.model[-1].anchor_grid
        model.model[-1].anchor_grid = [a[..., :1, :1, :] for a in grid]
        export_onnx(model, im, file, 12, dynamic, simplify)  # opset 12
        model.model[-1].anchor_grid = grid
    else:  # TensorRT >= 8
        check_version(trt.__version__, '8.0.0', hard=True)  # require tensorrt>=8.0.0
        export_onnx(model, im, file, 12, dynamic, simplify)  # opset 12
    onnx = file.with_suffix('.onnx')

    LOGGER.info(f'\n{prefix} starting export with TensorRT {trt.__version__}...')
    assert onnx.exists(), f'failed to export ONNX file: {onnx}'
    f = file.with_suffix('.engine')  # TensorRT engine file
    logger = trt.Logger(trt.Logger.INFO)
    if verbose:
        logger.min_severity = trt.Logger.Severity.VERBOSE

    builder = trt.Builder(logger)
    config = builder.create_builder_config()
    config.max_workspace_size = workspace * 1 << 30
    # config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace << 30)  # fix TRT 8.4 deprecation notice

    flag = (1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
    network = builder.create_network(flag)
    parser = trt.OnnxParser(network, logger)
    if not parser.parse_from_file(str(onnx)):
        raise RuntimeError(f'failed to load ONNX file: {onnx}')

    inputs = [network.get_input(i) for i in range(network.num_inputs)]
    outputs = [network.get_output(i) for i in range(network.num_outputs)]
    for inp in inputs:
        LOGGER.info(f'{prefix} input "{inp.name}" with shape{inp.shape} {inp.dtype}')
    for out in outputs:
        LOGGER.info(f'{prefix} output "{out.name}" with shape{out.shape} {out.dtype}')

    if dynamic:
        if im.shape[0] <= 1:
            LOGGER.warning(f'{prefix} WARNING ⚠️ --dynamic model requires maximum --batch-size argument')
        profile = builder.create_optimization_profile()
        for inp in inputs:
            profile.set_shape(inp.name, (1, *im.shape[1:]), (max(1, im.shape[0] // 2), *im.shape[1:]), im.shape)
        config.add_optimization_profile(profile)

    LOGGER.info(f'{prefix} building FP{16 if builder.platform_has_fast_fp16 and half else 32} engine as {f}')
    if builder.platform_has_fast_fp16 and half:
        config.set_flag(trt.BuilderFlag.FP16)
    with builder.build_engine(network, config) as engine, open(f, 'wb') as t:
        t.write(engine.serialize())
    return f, None

步骤 1: 导入库和检查 GPU 可用性

assert im.device.type != 'cpu', 'export running on CPU but must be on GPU, i.e. `python export.py --device 0`'
try:
    import tensorrt as trt
except Exception:
    if platform.system() == 'Linux':
        check_requirements('nvidia-tensorrt', cmds='-U --index-url https://pypi.ngc.nvidia.com')
    import tensorrt as trt
  • 确保模型在 GPU 上运行,如果在 CPU 上运行,抛出异常。
  • 尝试导入 tensorrt 库,如果失败并且系统是 Linux,通过 check_requirements 函数安装 nvidia-tensorrt
  • 再次尝试导入 tensorrt 库。

步骤 2: 处理 TensorRT 版本 7 的兼容性

if trt.__version__[0] == '7':
    grid = model.model[-1].anchor_grid
    model.model[-1].anchor_grid = [a[..., :1, :1, :] for a in grid]
    export_onnx(model, im, file, 12, dynamic, simplify)  # opset 12
    model.model[-1].anchor_grid = grid
else:
    check_version(trt.__version__, '8.0.0', hard=True)  # require tensorrt>=8.0.0
    export_onnx(model, im, file, 12, dynamic, simplify)  # opset 12
  • 如果 TensorRT 版本是 7,调整 YOLOv5 模型的锚点网格,导出 ONNX 文件,然后恢复原始的锚点网格。
  • 如果 TensorRT 版本大于等于 8,检查 TensorRT 版本是否满足要求(至少 8.0.0),然后导出 ONNX 文件。

步骤 3: 将模型导出为 ONNX 格式

onnx = file.with_suffix('.onnx')
export_onnx(model, im, file, 12, dynamic, simplify)  # opset 12
  • 指定 ONNX 文件的路径,并调用 export_onnx 函数将 YOLOv5 模型导出为 ONNX 格式。

步骤 4: 初始化 TensorRT 组件

LOGGER.info(f'\n{prefix} starting export with TensorRT {trt.__version__}...')
assert onnx.exists(), f'failed to export ONNX file: {onnx}'
f = file.with_suffix('.engine')  # TensorRT 引擎文件
logger = trt.Logger(trt.Logger.INFO)
  • 记录 TensorRT 版本信息。
  • 确保 ONNX 文件存在。
  • 指定 TensorRT 引擎文件的路径。
  • 初始化 TensorRT 的日志记录器。

步骤 5: 创建 TensorRT 构建器和配置

builder = trt.Builder(logger)
config = builder.create_builder_config()
config.max_workspace_size = workspace * 1 << 30
  • 创建 TensorRT 构建器。
  • 创建构建器配置对象。
  • 配置最大工作空间大小。
补充说明:
config.max_workspace_size = workspace * 1 << 30

这行代码设置了 TensorRT 构建配置对象 config 的最大工作空间大小max_workspace_size:

  • 1 << 30 表示将二进制数 1 左移 30 位。在计算机中,左移操作相当于乘以 2 的指定次方。因此,1 << 30 相当于 2 的 30 次方,即 2^30。

  • workspace 乘以 2^30 就是将其转换为字节。这是因为在计算机存储中,通常使用字节为基本单位。

在这里,workspace * 1 << 30 计算出的值将工作空间大小设置为 workspace GB。你可以根据系统的内存情况和模型的复杂性调整此值,以确保在构建 TensorRT 引擎时有足够的内存可用。

步骤 6: 创建 TensorRT 网络和 ONNX 解析器

flag = (1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
network = builder.create_network(flag)
parser = trt.OnnxParser(network, logger)
if not parser.parse_from_file(str(onnx)):
    raise RuntimeError(f'failed to load ONNX file: {onnx}')
  • 创建 TensorRT 网络,启用显式批处理。
  • 使用 ONNX 解析器解析 ONNX 文件,构建 TensorRT 网络。
补充说明:
flag = (1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))

这里创建了一个标志 flag,使用位运算左移的方式将 1 移动到 trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH 这个标志所表示的位置上。这个标志表示在创建网络时使用显式批处理。

步骤 7: 显示输入和输出信息

inputs = [network.get_input(i) for i in range(network.num_inputs)]
outputs = [network.get_output(i) for i in range(network.num_outputs)]
for inp in inputs:
    LOGGER.info(f'{prefix} input "{inp.name}" with shape{inp.shape} {inp.dtype}')
for out in outputs:
    LOGGER.info(f'{prefix} output "{out.name}" with shape{out.shape} {out.dtype}')
  • 获取 TensorRT 网络的输入和输出信息。
  • 打印输入和输出的名称、形状和数据类型。

步骤 8: 处理动态 TensorRT 优化

if dynamic:
    if im.shape[0] <= 1:
        LOGGER.warning(f'{prefix} WARNING ⚠️ --dynamic model requires maximum --batch-size argument')
    profile = builder.create_optimization_profile()
    for inp in inputs:
        profile.set_shape(inp.name, (1, *im.shape[1:]), (max(1, im.shape[0] // 2), *im.shape[1:]), im.shape)
    config.add_optimization_profile(profile)
  • 如果启用动态优化,创建优化配置文件。
  • 设置输入的形状,以便在不同批次大小下进行优化。
补充说明:
profile.set_shape(inp.name, (1, *im.shape[1:]), (max(1, im.shape[0] // 2), *im.shape[1:]), im.shape)

用于设置 TensorRT 动态优化配置文件的输入形状。让我们逐步解释这行代码:

  • profile 是 TensorRT 中的优化配置文件(trt.OptimizationProfile)。
  • inp.name 是当前输入张量的名称。
  • (1, *im.shape[1:]) 设置了最小的输入形状,其中批次大小(batch size)为 1,其余维度与 im 的形状相同。
  • (max(1, im.shape[0] // 2), *im.shape[1:]) 设置了最大的输入形状,其中批次大小(batch size)为 im.shape[0] // 2,其余维度与 im 的形状相同。
  • im.shape 是当前输入张量的形状。

这行代码的目的是为动态 TensorRT 模型创建一个优化配置文件,并设置输入形状的范围,以便在运行时适应不同批次大小的输入。这对于处理动态批次大小的模型非常有用,允许模型在训练和推理中适应不同大小的输入数据。

步骤 9: 构建 TensorRT 引擎

LOGGER.info(f'{prefix} building FP{16 if builder.platform_has_fast_fp16 and half else 32} engine as {f}')
if builder.platform_has_fast_fp16 and half:
    config.set_flag(trt.BuilderFlag.FP16)
with builder.build_engine(network, config) as engine, open(f, 'wb') as t:
    t.write(engine.serialize())
  • 记录正在构建的 TensorRT 引擎的精度信息(FP16 或 FP32)。
  • 如果支持 FP16 且指定使用 FP16,则设置相应标志。
  • 使用构建器、配置和网络构建 TensorRT 引擎。
  • 将引擎序列化并写入指定的文件。

步骤 10: 返回引擎文件路径

return f, None
  • 最终,函数返回 TensorRT 引擎文件的路径。

你可能感兴趣的:(YOLO)