mmpose PyTorch模型转TensorRT

文章目录

  • mmpose PyTorch模型转TensorRT
    • 1. github开源代码
    • 2. PyTorch模型转ONNX模型
  • 3. ONNX模型转TensorRT模型
    • 3.1 概述
    • 3.2 编译
    • 3.3 运行
  • 4. 推理结果

mmpose PyTorch模型转TensorRT

1. github开源代码

yolov5 TensorRT推理的开源代码位置在https://github.com/linghu8812/tensorrt_inference/tree/master/mmpose,PyTorch转onnx的代码是原作者的代码:pytorch2onnx.py,原作者仓库见https://github.com/open-mmlab/mmpose。

2. PyTorch模型转ONNX模型

首先通过命令git clone [email protected]:open-mmlab/mmpose.git clone mmpose的代码,然后按照install.md的步骤配置mmpose运行环境。完成环境配置后,按照tutorials/5_export_model.md导出ONNX文件,opset最好选择11以上。

python3 tools/pytorch2onnx.py ${CONFIG_FILE} ${CHECKPOINT_FILE} [--shape ${SHAPE}] \
    [--verify] [--show] [--output-file ${OUTPUT_FILE}] [--opset-version ${VERSION}]

样例:

python3 tools/pytorch2onnx.py configs/top_down/hrnet/coco/hrnet_w48_coco_256x192.py https://download.openmmlab.com/mmpose/top_down/hrnet/hrnet_w48_coco_256x192-b9e0b3ab_20200708.pth --output-file hrnet_w48_coco_256x192.onnx

PyTorch转ONNX模型的步骤,原作者已完成了开发,可以一键导出。

3. ONNX模型转TensorRT模型

3.1 概述

TensorRT模型即TensorRT的推理引擎,代码中通过C++实现。相关配置写在config.yaml文件中,如果存在engine_file的路径,则读取engine_file,否则从onnx_file生成engine_file,生成engine的代码从model.cpp类中继承。

config.yaml文件可以设置batch size,图像的size及模型的anchor等。

mmpose:
    onnx_file:      "../hrnet_w48_coco_256x192.onnx"
    engine_file:    "../hrnet_w48_coco_256x192.trt"
    BATCH_SIZE:     1
    INPUT_CHANNEL:  3
    IMAGE_WIDTH:    192
    IMAGE_HEIGHT:   256
    img_mean:       [0.485, 0.456, 0.406]
    img_std:        [0.229, 0.224, 0.225]
    num_key_points: 17
    skeleton:       [[15, 13], [13, 11], [16, 14], [14, 12], [11, 12], [5, 11],
                     [6, 12], [5, 6], [5, 7], [6, 8], [7, 9], [8, 10], [1, 2],
                     [0, 1], [0, 2], [1, 3], [2, 4], [3, 5], [4, 6]]
    point_thresh:   0.5

3.2 编译

通过以下命令对项目进行编译,生成mmpose_trt

mkdir build && cd build
cmake ..
make -j

3.3 运行

通过以下命令运行项目,得到推理结果

./mmpose_trt../config.yaml ../samples

4. 推理结果

推理结果如下图所示:
mmpose PyTorch模型转TensorRT_第1张图片

你可能感兴趣的:(TensorRT,Pytorch,深度学习,深度学习,pytorch,mmpose,人体关键点,TensorRT)