Code: https://github.com/xxcheng0708/Pytorch_Retinaface_Accelerate
This post describes how to speed up the data preprocessing stage of the PyTorch RetinaFace code while keeping inference in pure PyTorch; it does not cover ONNX or TensorRT deployment. The approach applies when you load a batch of same-resolution images from disk and run RetinaFace face detection on them, and it delivers roughly a 30% performance improvement. For deploying pytorch_retinaface with TensorRT, see https://github.com/wang-xinyu/tensorrtx/tree/master/retinaface.
First, the performance results before and after optimization:
| Resolution | Before: FPS | Before: total time (s) | Before: avg time (ms) | After: FPS | After: total time (s) | After: avg time (ms) | Improvement |
| --- | --- | --- | --- | --- | --- | --- | --- |
| 1920 x 1080 | 5.92 | 134 | 168 | 8.84 | 90 | 113 | 32.7% |
| 1280 x 720 | 13.08 | 256 | 76 | 19.49 | 172 | 51 | 32.8% |
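The average-latency and improvement columns follow directly from the measured FPS; a quick check in Python (values taken from the table above, with small rounding differences expected):

# Average latency is the reciprocal of FPS; the improvement is the relative
# drop in average latency. FPS values are taken from the table above.
for res, fps_before, fps_after in [("1920x1080", 5.92, 8.84), ("1280x720", 13.08, 19.49)]:
    avg_before = 1000.0 / fps_before  # ms per frame
    avg_after = 1000.0 / fps_after
    gain = (avg_before - avg_after) / avg_before
    print("{}: {:.0f} ms -> {:.0f} ms ({:.1%} faster)".format(res, avg_before, avg_after, gain))
# 1920x1080: 169 ms -> 113 ms (33.0% faster)
# 1280x720: 76 ms -> 51 ms (32.9% faster)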
Model inference time comes from three main stages (a rough way to time each stage separately is sketched after this list):
1. Data preprocessing: reading the data, format conversion, normalization, dimension expansion, and so on.
2. Model prediction: the model's forward pass.
3. Postprocessing: operations on the raw outputs, such as NMS in object detection.
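To see where the time actually goes, each stage can be bracketed with torch.cuda.synchronize() and time.perf_counter(). A minimal sketch, assuming a CUDA device; preprocess, model, and postprocess are hypothetical placeholders for the three stages:

import time
import torch

def timed(stage_fn, *args):
    # Synchronize before and after so queued GPU work is included in the measurement
    torch.cuda.synchronize()
    start = time.perf_counter()
    out = stage_fn(*args)
    torch.cuda.synchronize()
    return out, (time.perf_counter() - start) * 1000.0  # ms

# preprocess, model, and postprocess are hypothetical callables for the three stages:
# batch, t_pre = timed(preprocess, raw_image)
# preds, t_fwd = timed(model, batch)
# dets, t_post = timed(postprocess, preds)
# print("preprocess {:.1f} ms, forward {:.1f} ms, postprocess {:.1f} ms".format(t_pre, t_fwd, t_post))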
In the open-source detect.py (https://github.com/biubug6/Pytorch_Retinaface/blob/master/detect.py), the data preprocessing stage consists of the following steps:
1. Read the image with OpenCV.
2. Convert the image data from uint8 to float32.
3. Normalize the image by subtracting the per-channel mean; this is the main cost.
4. Swap the image axes to [C, H, W], convert to a tensor, and expand the dimensions to [1, C, H, W].
5. Move the tensor to the GPU and run model inference.
All of these operations run on the CPU, which makes them slow:
image_path = "./curve/test.jpg"
img_raw = cv2.imread(image_path, cv2.IMREAD_COLOR)  # BGR, uint8
img = np.float32(img_raw)  # uint8 -> float32
img -= (104, 117, 123)  # per-channel mean subtraction (the main cost)
img = img.transpose(2, 0, 1)  # [H, W, C] -> [C, H, W]
img = torch.from_numpy(img).unsqueeze(0)  # -> [1, C, H, W]
img = img.to(device)  # move to the GPU for inference
In an earlier post on GPU-accelerating torchvision.transforms to speed up preprocessing at inference time, we covered the read_image function introduced in torchvision 0.8.0, which reads an image directly into a [C, H, W] tensor. That tensor can be moved straight to the GPU, where normalization, dimension expansion, and the rest of the preprocessing run, so the whole stage naturally speeds up. The converted preprocessing consists of the following steps:
1. Read the image with read_image into a [C, H, W] tensor and move it to the GPU.
2. Normalize the data on the GPU with torchvision.transforms.Normalize.
3. Expand the dimensions to [1, C, H, W], then run model inference.
# read_image returns channels in RGB order, but RetinaFace expects BGR,
# so [[2, 1, 0], :, :] swaps the channel order
img = read_image(image_path, torchvision.io.ImageReadMode.RGB).to(device)[[2, 1, 0], :, :].float()
img = torchvision.transforms.Normalize(mean=[104.0, 117.0, 123.0], std=[1.0, 1.0, 1.0])(img)
img = img.unsqueeze(0)
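With std fixed at 1.0, Normalize reduces to the same per-channel mean subtraction as img -= (104, 117, 123) above. A small sketch to confirm the two pipelines agree, assuming a test image at the placeholder path used earlier:

import cv2
import numpy as np
import torch
import torchvision
from torchvision.io import read_image

image_path = "./curve/test.jpg"  # placeholder test image

# Original CPU pipeline: BGR uint8 -> float32 -> mean subtraction -> CHW tensor
cpu = np.float32(cv2.imread(image_path, cv2.IMREAD_COLOR)) - (104, 117, 123)
cpu = torch.from_numpy(cpu.transpose(2, 0, 1))

# New pipeline (run on the CPU here just for the comparison)
gpu_style = read_image(image_path, torchvision.io.ImageReadMode.RGB)[[2, 1, 0], :, :].float()
gpu_style = torchvision.transforms.Normalize(mean=[104.0, 117.0, 123.0], std=[1.0, 1.0, 1.0])(gpu_style)

# Should be near zero; the two JPEG decoders can differ by a few gray levels
print((cpu - gpu_style).abs().max())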
The complete inference and detection code, demo.py, is as follows:
# coding:utf-8
import os
import cv2
import torch
from torch import nn
import torch.backends.cudnn as cudnn
import numpy as np
from data import cfg_mnet, cfg_re50
from layers.functions.prior_box import PriorBox
from utils.nms.py_cpu_nms import py_cpu_nms
from models.retinaface import RetinaFace
from utils.box_utils import decode, decode_landm
from imutils.video import FPS
import torchvision
from torchvision.io import read_image
# from utils.timer import print_execute_info
class RetinaFaceDetector(object):
def __init__(self, trained_model, network, use_cpu=False, confidence_threshold=0.02, top_k=5000,
nms_threshold=0.4, keep_top_k=750, vis_thres=0.6, im_height=720, im_width=1280):
super(RetinaFaceDetector, self).__init__()
self.trained_model = trained_model
self.network = network
self.use_cpu = use_cpu
self.confidence_threshold = confidence_threshold
self.top_k = top_k
self.nms_threshold = nms_threshold
self.keep_top_k = keep_top_k
self.vis_thres = vis_thres
self.im_height = im_height
self.im_width = im_width
self.device = torch.device("cpu" if self.use_cpu else "cuda")
self.norm = torchvision.transforms.Normalize(mean=[104.0, 117.0, 123.0], std=[1.0, 1.0, 1.0])
torch.set_grad_enabled(False)
self.cfg = None
if self.network == "mobile0.25":
self.cfg = cfg_mnet
elif self.network == "resnet50":
self.cfg = cfg_re50
self.net = RetinaFace(cfg=self.cfg, phase="test")
self.load_model(self.trained_model, self.use_cpu)
self.net.eval()
print(self.net)
cudnn.benchmark = True
self.net = self.net.to(self.device)
self.resize = 1
self.scale = torch.Tensor([self.im_width, self.im_height, self.im_width, self.im_height])
self.scale = self.scale.to(self.device)
self.scale1 = torch.Tensor([
self.im_width, self.im_height,
self.im_width, self.im_height,
self.im_width, self.im_height,
self.im_width, self.im_height,
self.im_width, self.im_height
])
self.scale1 = self.scale1.to(self.device)
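        # Precompute the prior (anchor) boxes once for the configured input size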
self.priorbox = PriorBox(self.cfg, image_size=(self.im_height, self.im_width))
self.priors = self.priorbox.forward()
self.priors = self.priors.to(self.device)
self.prior_data = self.priors.data
def check_keys(self, model, pretrained_state_dict):
ckpt_keys = set(pretrained_state_dict.keys())
model_keys = set(model.state_dict().keys())
used_pretrained_keys = model_keys & ckpt_keys
unused_pretrained_keys = ckpt_keys - model_keys
missing_keys = model_keys - ckpt_keys
print('Missing keys:{}'.format(len(missing_keys)))
print('Unused checkpoint keys:{}'.format(len(unused_pretrained_keys)))
print('Used keys:{}'.format(len(used_pretrained_keys)))
assert len(used_pretrained_keys) > 0, 'load NONE from pretrained checkpoint'
return True
def remove_prefix(self, state_dict, prefix):
''' Old style model is stored with all names of parameters sharing common prefix 'module.' '''
print('remove prefix \'{}\''.format(prefix))
f = lambda x: x.split(prefix, 1)[-1] if x.startswith(prefix) else x
return {f(key): value for key, value in state_dict.items()}
def load_model(self, pretrained_path, load_to_cpu):
print('Loading pretrained model from {}'.format(pretrained_path))
if load_to_cpu:
pretrained_dict = torch.load(pretrained_path, map_location=lambda storage, loc: storage)
else:
            pretrained_dict = torch.load(pretrained_path, map_location=self.device)
if "state_dict" in pretrained_dict.keys():
pretrained_dict = self.remove_prefix(pretrained_dict['state_dict'], 'module.')
else:
pretrained_dict = self.remove_prefix(pretrained_dict, 'module.')
self.check_keys(self.net, pretrained_dict)
self.net.load_state_dict(pretrained_dict, strict=False)
# @print_execute_info
def detect(self, img):
_, im_height, im_width = img.shape
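        # Rebuild the scale tensors and prior boxes if the input resolution changed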
if im_height != self.im_height or im_width != self.im_width:
self.im_height = im_height
self.im_width = im_width
self.scale = torch.Tensor([self.im_width, self.im_height, self.im_width, self.im_height])
self.scale = self.scale.to(self.device)
self.scale1 = torch.Tensor([
self.im_width, self.im_height,
self.im_width, self.im_height,
self.im_width, self.im_height,
self.im_width, self.im_height,
self.im_width, self.im_height
])
self.scale1 = self.scale1.to(self.device)
self.priorbox = PriorBox(self.cfg, image_size=(self.im_height, self.im_width))
self.priors = self.priorbox.forward()
self.priors = self.priors.to(self.device)
self.prior_data = self.priors.data
        img = img.to(self.device)[[2, 1, 0], :, :].float()  # RGB -> BGR for RetinaFace
        img = self.norm(img)  # per-channel mean subtraction on the GPU
        img = img.unsqueeze(0)  # [C, H, W] -> [1, C, H, W]
loc, conf, landms = self.net(img)
boxes = decode(loc.data.squeeze(0), self.prior_data, self.cfg['variance'])
boxes = boxes * self.scale / self.resize
boxes = boxes.cpu().numpy()
scores = conf.squeeze(0).data.cpu().numpy()[:, 1]
landms = decode_landm(landms.data.squeeze(0), self.prior_data, self.cfg['variance'])
landms = landms * self.scale1 / self.resize
landms = landms.cpu().numpy()
# ignore low scores
inds = np.where(scores > self.confidence_threshold)[0]
boxes = boxes[inds]
landms = landms[inds]
scores = scores[inds]
# keep top-K before NMS
        order = scores.argsort()[::-1][:self.top_k]
boxes = boxes[order]
landms = landms[order]
scores = scores[order]
# do NMS
dets = np.hstack((boxes, scores[:, np.newaxis])).astype(np.float32, copy=False)
        keep = py_cpu_nms(dets, self.nms_threshold)
# keep = nms(dets, args.nms_threshold,force_cpu=args.cpu)
dets = dets[keep, :]
landms = landms[keep]
# keep top-K faster NMS
dets = dets[:self.keep_top_k, :]
landms = landms[:self.keep_top_k, :]
dets = np.concatenate((dets, landms), axis=1)
return dets
if __name__ == '__main__':
import shutil
detector = RetinaFaceDetector(trained_model="./weights/Resnet50_Final.pth", network="resnet50",
im_height=720, im_width=1280)
fps = FPS()
fps.start()
data_path = "./images"
output_path = "./outputs"
    if os.path.exists(output_path):
        shutil.rmtree(output_path)
    os.makedirs(output_path)
for image_name in os.listdir(data_path):
image_path = os.path.join(data_path, image_name)
img = read_image(image_path, mode=torchvision.io.ImageReadMode.RGB)
results = detector.detect(img)
fps.update()
# save results
        if False:  # set to True to save images with drawn detections
img_raw = cv2.imread(image_path)
for b in results:
if b[4] < detector.vis_thres:
continue
text = "{:.4f}".format(b[4])
b = list(map(int, b))
cv2.rectangle(img_raw, (b[0], b[1]), (b[2], b[3]), (0, 0, 255), 2)
cx = b[0]
cy = b[1] + 12
cv2.putText(img_raw, text, (cx, cy), cv2.FONT_HERSHEY_DUPLEX, 0.5, (255, 255, 255))
# landms
cv2.circle(img_raw, (b[5], b[6]), 1, (0, 0, 255), 4)
cv2.circle(img_raw, (b[7], b[8]), 1, (0, 255, 255), 4)
cv2.circle(img_raw, (b[9], b[10]), 1, (255, 0, 255), 4)
cv2.circle(img_raw, (b[11], b[12]), 1, (0, 255, 0), 4)
cv2.circle(img_raw, (b[13], b[14]), 1, (255, 0, 0), 4)
# save image
cv2.imwrite(os.path.join(output_path, image_name), img_raw)
fps.stop()
print("duration time: {} s, fps: {}".format(fps.elapsed(), fps.fps()))