MMPose预训练模型预测-Python API

导入依赖的包

import cv2
import numpy as np
from PIL import Image

import matplotlib.pyplot as plt
%matplotlib inline

import torch

import mmcv
from mmcv import imread
import mmengine
from mmengine.registry import init_default_scope

from mmpose.apis import inference_topdown
from mmpose.apis import init_model as init_pose_estimator
from mmpose.evaluation.functional import nms
from mmpose.registry import VISUALIZERS
from mmpose.structures import merge_data_samples

from mmdet.apis import inference_detector, init_detector

设定使用的设备

# Choose the compute device: the first CUDA GPU if one is available, else the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
print('device', device)

指定要处理的图像

# Input image for detection + pose estimation (a multi-person test image);
# the path is relative to the current working directory.
img_path = 'data/test/multi-person.jpeg'

构造目标检测模型(载入一个即可)

1. Faster R-CNN

# Option 1: Faster R-CNN object detector (ResNet-50 + FPN, COCO, going by the
# config/checkpoint names), weights downloaded from the OpenMMLab model zoo.
detector = init_detector(
    config='demo/mmdetection_cfg/faster_rcnn_r50_fpn_coco.py',
    checkpoint='https://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r50_fpn_1x_coco/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth',
    device=device,
)

2. RTMDet-M(RTMPose 项目提供的人体检测模型)

# Option 2: RTMDet-M person detector shipped with the RTMPose project.
# This is a *detection* model (config rtmdet_m_..._coco-person), not an
# RTMPose pose-estimation model.
# NOTE: running this cell after the Faster R-CNN cell rebinds `detector`;
# only one of the two detectors needs to be loaded.
# https://github.com/open-mmlab/mmpose/tree/dev-1.x/projects/rtmpose
detector = init_detector(
    'projects/rtmpose/rtmdet/person/rtmdet_m_640-8xb32_coco-person.py',
    'https://download.openmmlab.com/mmpose/v1/projects/rtmpose/rtmdet_m_8xb32-100e_coco-obj365-person-235e8209.pth',
    device=device
)

构造人体姿态估计模型

# Top-down heatmap pose model (HRNet-W32, COCO keypoints, 256x192 input per the
# config name). Setting test_cfg.output_heatmaps=True keeps the predicted
# heatmaps in the results so they can be read from pred_fields.heatmaps.
pose_estimator = init_pose_estimator(
    'configs/body_2d_keypoint/topdown_heatmap/coco/td-hm_hrnet-w32_8xb64-210e_coco-256x192.py',
    'https://download.openmmlab.com/mmpose/top_down/hrnet/hrnet_w32_coco_256x192-c78dce93_20200708.pth',
    device=device,
    cfg_options=dict(model=dict(test_cfg=dict(output_heatmaps=True))),
)

目标检测预测

# Set the registry default scope from the detector config (falling back to
# 'mmdet'); presumably needed so registry lookups resolve to mmdet modules.
init_default_scope(detector.cfg.get('default_scope', 'mmdet'))
# Run the object detector on the image
detect_result = inference_detector(detector, img_path)
# The bare expressions below only display output when run as notebook cells.
detect_result.keys()
# Predicted class labels
detect_result.pred_instances.labels
# Confidence scores
detect_result.pred_instances.scores
# Box coordinates: top-left x, top-left y, bottom-right x, bottom-right y
# detect_result.pred_instances.bboxes

置信度阈值过滤,获得最终目标检测预测结果

# Detections scoring at or below this threshold are discarded
CONF_THRES = 0.5
pred_instance = detect_result.pred_instances.cpu().numpy()
# One row per detection: [x1, y1, x2, y2, score]
bboxes = np.hstack((pred_instance.bboxes, pred_instance.scores.reshape(-1, 1)))
# Keep only class-0 detections (presumably "person") above the threshold
bboxes = bboxes[(pred_instance.labels == 0) & (pred_instance.scores > CONF_THRES)]
# Non-maximum suppression at IoU 0.3, then drop the score column
bboxes = bboxes[nms(bboxes, 0.3)][:, :4]
bboxes

预测-关键点

# Predict keypoints for each bbox (top-down: one pose result per box)
pose_results = inference_topdown(pose_estimator, img_path, bboxes)

# Number of per-box results (displayed only when run as a notebook cell)
len(pose_results)
# Merge the pose results of all bboxes into a single data sample
data_samples = merge_data_samples(pose_results)

data_samples.keys()

预测结果-关键点坐标

# Keypoint coordinates for every person: 17 keypoints each (COCO keypoint set);
# shape is (num_boxes, num_keypoints, 2), the last axis holding x and y.
data_samples.pred_instances.keypoints.shape

返回值的形状为 (框的数量, 关键点数量, 2),其中 2 对应 x、y 两个坐标

# (x, y) coordinates of every keypoint of the person at index 0;
# `[0]` on this 3-D array selects the same slice as `[0, :, :]`.
data_samples.pred_instances.keypoints[0]

查看序号为 0 的人的各关键点坐标

预测结果-关键点热力图

# Shape of the predicted heatmaps: one heatmap per keypoint type
data_samples.pred_fields.heatmaps.shape

# Index of the keypoint type to inspect (13 is presumably left_knee in COCO
# keypoint order -- confirm against pose_estimator.dataset_meta)
idx_point = 13
heatmap = data_samples.pred_fields.heatmaps[idx_point,:,:]

heatmap.shape

# Predicted heatmap of keypoint `idx_point` over the whole image
plt.imshow(heatmap)
plt.show()

MMPose 官方可视化工具 visualizer

# Visualizer drawing settings: skeleton line width and keypoint marker radius
pose_estimator.cfg.visualizer.line_width = 8
pose_estimator.cfg.visualizer.radius = 10
# Build the visualizer from the model config and give it the model's dataset
# metadata (presumably keypoint names / skeleton used for drawing).
visualizer = VISUALIZERS.build(pose_estimator.cfg.visualizer)
visualizer.set_dataset_meta(pose_estimator.dataset_meta)

# Read the image (BGR) and convert it to RGB before drawing
img = mmcv.imconvert(mmcv.imread(img_path), 'bgr', 'rgb')

img_output = visualizer.add_datasample(
    'result',
    img,
    data_sample=data_samples,
    draw_gt=False,              # predictions only, no ground truth
    draw_heatmap=True,          # also draw the predicted heatmaps
    draw_bbox=True,
    show_kpt_idx=True,          # label each keypoint with its index
    show=False,
    wait_time=0,
    out_file='outputs/B2.jpg',
)

img_output.shape

plt.figure(figsize=(10, 10))
plt.imshow(img_output)
plt.show()

你可能感兴趣的:(python,开发语言)