# This function can be found at https://github.com/ultralytics/yolov5/blob/master/export.py
def export_engine(model, im, file, half, dynamic, simplify, workspace=4, verbose=False, prefix=colorstr('TensorRT:')):
    # YOLOv5 TensorRT export https://developer.nvidia.com/tensorrt
    assert im.device.type != 'cpu', 'export running on CPU but must be on GPU, i.e. `python export.py --device 0`'
    try:
        import tensorrt as trt
    except Exception:
        if platform.system() == 'Linux':
            check_requirements('nvidia-tensorrt', cmds='-U --index-url https://pypi.ngc.nvidia.com')
        import tensorrt as trt
    if trt.__version__[0] == '7':  # TensorRT 7 handling https://github.com/ultralytics/yolov5/issues/6012
        grid = model.model[-1].anchor_grid
        model.model[-1].anchor_grid = [a[..., :1, :1, :] for a in grid]
        export_onnx(model, im, file, 12, dynamic, simplify)  # opset 12
        model.model[-1].anchor_grid = grid
    else:  # TensorRT >= 8
        check_version(trt.__version__, '8.0.0', hard=True)  # require tensorrt>=8.0.0
        export_onnx(model, im, file, 12, dynamic, simplify)  # opset 12
    onnx = file.with_suffix('.onnx')
    LOGGER.info(f'\n{prefix} starting export with TensorRT {trt.__version__}...')
    assert onnx.exists(), f'failed to export ONNX file: {onnx}'
    f = file.with_suffix('.engine')  # TensorRT engine file
    logger = trt.Logger(trt.Logger.INFO)
    if verbose:
        logger.min_severity = trt.Logger.Severity.VERBOSE
    builder = trt.Builder(logger)
    config = builder.create_builder_config()
    config.max_workspace_size = workspace * 1 << 30
    # config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace << 30)  # fix TRT 8.4 deprecation notice
    flag = (1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
    network = builder.create_network(flag)
    parser = trt.OnnxParser(network, logger)
    if not parser.parse_from_file(str(onnx)):
        raise RuntimeError(f'failed to load ONNX file: {onnx}')
    inputs = [network.get_input(i) for i in range(network.num_inputs)]
    outputs = [network.get_output(i) for i in range(network.num_outputs)]
    for inp in inputs:
        LOGGER.info(f'{prefix} input "{inp.name}" with shape{inp.shape} {inp.dtype}')
    for out in outputs:
        LOGGER.info(f'{prefix} output "{out.name}" with shape{out.shape} {out.dtype}')
    if dynamic:
        if im.shape[0] <= 1:
            LOGGER.warning(f'{prefix} WARNING ⚠️ --dynamic model requires maximum --batch-size argument')
        profile = builder.create_optimization_profile()
        for inp in inputs:
            profile.set_shape(inp.name, (1, *im.shape[1:]), (max(1, im.shape[0] // 2), *im.shape[1:]), im.shape)
        config.add_optimization_profile(profile)
    LOGGER.info(f'{prefix} building FP{16 if builder.platform_has_fast_fp16 and half else 32} engine as {f}')
    if builder.platform_has_fast_fp16 and half:
        config.set_flag(trt.BuilderFlag.FP16)
    with builder.build_engine(network, config) as engine, open(f, 'wb') as t:
        t.write(engine.serialize())
    return f, None
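For reference, export_engine is not normally called directly; it is reached through export.py's command line. A typical invocation (with yolov5s.pt as an example weights file) looks like:

python export.py --weights yolov5s.pt --include engine --device 0 --half

Adding --dynamic together with a --batch-size greater than 1 produces a dynamic-shape engine. The rest of this post walks through the function piece by piece.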
assert im.device.type != 'cpu', 'export running on CPU but must be on GPU, i.e. `python export.py --device 0`'
try:
    import tensorrt as trt
except Exception:
    if platform.system() == 'Linux':
        check_requirements('nvidia-tensorrt', cmds='-U --index-url https://pypi.ngc.nvidia.com')
    import tensorrt as trt
This block first tries to import the tensorrt library. If the import fails and the system is Linux, the check_requirements function installs nvidia-tensorrt from NVIDIA's package index (https://pypi.ngc.nvidia.com), and the tensorrt library is then imported again.
if trt.__version__[0] == '7':  # TensorRT 7 handling https://github.com/ultralytics/yolov5/issues/6012
    grid = model.model[-1].anchor_grid
    model.model[-1].anchor_grid = [a[..., :1, :1, :] for a in grid]
    export_onnx(model, im, file, 12, dynamic, simplify)  # opset 12
    model.model[-1].anchor_grid = grid
else:  # TensorRT >= 8
    check_version(trt.__version__, '8.0.0', hard=True)  # require tensorrt>=8.0.0
    export_onnx(model, im, file, 12, dynamic, simplify)  # opset 12
onnx = file.with_suffix('.onnx')
export_onnx(model, im, file, 12, dynamic, simplify) # opset 12
The export_onnx function exports the YOLOv5 model to ONNX format, which TensorRT then parses to build the engine.
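Before TensorRT parses it, the intermediate ONNX file can be sanity-checked on its own. A minimal sketch, assuming the onnx Python package is installed; the file path is a hypothetical example, and this check is not part of export_engine itself:

import onnx

onnx_path = 'yolov5s.onnx'            # hypothetical path of the exported file
model_onnx = onnx.load(onnx_path)     # load the exported graph
onnx.checker.check_model(model_onnx)  # raises if the graph is structurally invalid

The export then continues with the TensorRT-specific setup: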
LOGGER.info(f'\n{prefix} starting export with TensorRT {trt.__version__}...')
assert onnx.exists(), f'failed to export ONNX file: {onnx}'
f = file.with_suffix('.engine')  # TensorRT engine file
logger = trt.Logger(trt.Logger.INFO)
builder = trt.Builder(logger)
config = builder.create_builder_config()
config.max_workspace_size = workspace * 1 << 30
config.max_workspace_size = workspace * 1 << 30
This line sets max_workspace_size, the maximum workspace size, on the TensorRT builder configuration object config:
- 1 << 30 shifts the binary value 1 left by 30 bits. A left shift multiplies by the corresponding power of two, so 1 << 30 equals 2^30.
- Multiplying workspace by 2^30 converts it to bytes, the basic unit of computer storage. (Note that * binds more tightly than << in Python, so workspace * 1 << 30 evaluates as (workspace * 1) << 30.)
- The result therefore sets the workspace to workspace GiB. You can tune this value to your system's available memory and the model's complexity so that TensorRT has enough scratch memory while building the engine.
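A quick numeric sketch (plain Python, no TensorRT needed) with the function's default workspace=4; the commented-out set_memory_pool_limit line in the full source above is the TensorRT 8.4+ replacement for the same setting:

workspace = 4                          # GiB, the default argument of export_engine
size_in_bytes = workspace * 1 << 30    # evaluates as (4 * 1) << 30 == 4 * 2**30 == 4294967296
print(size_in_bytes)                   # 4294967296 bytes, i.e. 4 GiB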
flag = (1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
network = builder.create_network(flag)
parser = trt.OnnxParser(network, logger)
if not parser.parse_from_file(str(onnx)):
    raise RuntimeError(f'failed to load ONNX file: {onnx}')
flag = (1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
Here a flag is created by left-shifting 1 into the bit position represented by trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH. Passing this flag to create_network makes TensorRT build the network in explicit-batch mode, where the batch dimension is part of every tensor shape rather than implied; this is the mode the ONNX parser requires.
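A tiny illustration of the bit arithmetic in plain Python; the numeric position of the enum member is assumed here purely for the example:

EXPLICIT_BATCH_BIT = 0          # assumed position of the enum member, for illustration only
flag = 1 << EXPLICIT_BATCH_BIT  # == 1, i.e. bit 0 set
print(bin(flag))                # 0b1; additional creation flags would be OR-ed into the same mask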
inputs = [network.get_input(i) for i in range(network.num_inputs)]
outputs = [network.get_output(i) for i in range(network.num_outputs)]
for inp in inputs:
    LOGGER.info(f'{prefix} input "{inp.name}" with shape{inp.shape} {inp.dtype}')
for out in outputs:
    LOGGER.info(f'{prefix} output "{out.name}" with shape{out.shape} {out.dtype}')
if dynamic:
    if im.shape[0] <= 1:
        LOGGER.warning(f'{prefix} WARNING ⚠️ --dynamic model requires maximum --batch-size argument')
    profile = builder.create_optimization_profile()
    for inp in inputs:
        profile.set_shape(inp.name, (1, *im.shape[1:]), (max(1, im.shape[0] // 2), *im.shape[1:]), im.shape)
    config.add_optimization_profile(profile)
profile.set_shape(inp.name, (1, *im.shape[1:]), (max(1, im.shape[0] // 2), *im.shape[1:]), im.shape)
This call sets the input shapes of the TensorRT dynamic-shape optimization profile. Step by step:
- profile is the optimization profile (an IOptimizationProfile) created by builder.create_optimization_profile().
- inp.name is the name of the current input tensor.
- (1, *im.shape[1:]) is the minimum shape: batch size 1, with the remaining dimensions taken from im.
- (max(1, im.shape[0] // 2), *im.shape[1:]) is the optimum shape that TensorRT tunes its kernels for: half of im's batch size (but at least 1), remaining dimensions unchanged.
- im.shape, the shape of the example input tensor, is the maximum shape the engine must support.
The purpose of this line is to give the dynamic engine a range of valid input shapes, so that at inference time it can accept batches of different sizes between the minimum and the maximum; this is exactly what a --dynamic export is for. A concrete example is sketched below.
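As a concrete (hypothetical) example, suppose export was run with --dynamic --batch-size 4 on 640x640 images, so im has shape (4, 3, 640, 640). The profile then covers batch sizes 1 through 4:

im_shape = (4, 3, 640, 640)                            # assumed example shape of `im`
min_shape = (1, *im_shape[1:])                         # (1, 3, 640, 640)
opt_shape = (max(1, im_shape[0] // 2), *im_shape[1:])  # (2, 3, 640, 640)
max_shape = im_shape                                   # (4, 3, 640, 640)
# profile.set_shape(inp.name, min_shape, opt_shape, max_shape)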
LOGGER.info(f'{prefix} building FP{16 if builder.platform_has_fast_fp16 and half else 32} engine as {f}')
if builder.platform_has_fast_fp16 and half:
    config.set_flag(trt.BuilderFlag.FP16)
with builder.build_engine(network, config) as engine, open(f, 'wb') as t:
    t.write(engine.serialize())
return f, None
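Finally, if the platform supports fast FP16 and --half was requested, the FP16 builder flag is set; the engine is then built, serialized, and written to the .engine file. To close the loop, here is a minimal sketch (not part of export.py) of loading the serialized engine back for inference, assuming TensorRT 8.x Python bindings; the file name is a hypothetical example, and input/output buffers still have to be allocated separately (e.g. with PyCUDA or torch):

import tensorrt as trt

logger = trt.Logger(trt.Logger.INFO)
with open('yolov5s.engine', 'rb') as f, trt.Runtime(logger) as runtime:
    engine = runtime.deserialize_cuda_engine(f.read())  # rebuild the engine from its serialized form
context = engine.create_execution_context()             # per-inference state; I/O bindings attach to this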