[toc]
Background
Our automated tests need to bypass the Geetest (极验) captcha in order to log in. After a feasibility study, we settled on YOLOv5 as the detector for the captcha slider pieces.
However, YOLOv5 does not ship as a service, so the inference capability has to be packaged and exposed as a service ourselves. This article documents that process.
Goals
- Low barrier to entry: the service should be usable with minimal setup by consuming projects
- Low maintenance cost: many projects may integrate with it, so adapting to newly trained models must require little maintenance effort
- Fast response: it is called from UI automation scripts, so responses should be as fast as possible (within 1 s)
Implementation
Refactoring the inference script
YOLOv5 ships with an inference script, detect.py, meant for checking results after training. It does not fully fit our needs and cannot be used directly, mainly because:
1. The model is reloaded on every invocation, which takes a long time; used directly as a service, the response time would be unacceptable.
2. There is no return value; the output is an annotated image for manual verification, so the location of detected targets is not exposed programmatically.
For these reasons, after reading through the inference script, the following changes were made:
1. Restructure the script from plain functions into a class with an initializer, so that the model is loaded once at initialization and later calls skip that loading cost.
2. Add a method on the class that directly returns the coordinates of the detected targets, so the web service can call it.
The rest of the code is left unchanged (including writing the annotated image to disk, which is mainly useful for debugging). The full script:
import re
import threading
# YOLOv5 by Ultralytics, GPL-3.0 license
"""
Run inference on images, videos, directories, streams, etc.
Usage:
    $ python path/to/detect.py --weights yolov5s.pt --source 0  # webcam
                                                    img.jpg  # image
                                                    vid.mp4  # video
                                                    path/  # directory
                                                    path/*.jpg  # glob
                                                    'https://youtu.be/Zgi9g1ksQHc'  # YouTube
                                                    'rtsp://example.com/media.mp4'  # RTSP, RTMP, HTTP stream
"""
import argparse
import os
import sys
from pathlib import Path

import cv2
import torch
import torch.backends.cudnn as cudnn

FILE = Path(__file__).resolve()
ROOT = FILE.parents[0]  # YOLOv5 root directory
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))  # add ROOT to PATH
ROOT = Path(os.path.relpath(ROOT, Path.cwd()))  # relative

from models.common import DetectMultiBackend
from utils.datasets import IMG_FORMATS, VID_FORMATS, LoadImages, LoadStreams
from utils.general import (LOGGER, check_file, check_img_size, check_imshow, check_requirements, colorstr,
                           increment_path, non_max_suppression, print_args, scale_coords, strip_optimizer, xyxy2xywh)
from utils.plots import Annotator, colors, save_one_box
from utils.torch_utils import select_device, time_sync


class yolo_detect():
    _instance_lock = threading.Lock()

    def __init__(self):
        print("yolo server now start initing.....")
        self.weights = ROOT / 'best.pt'  # model.pt path(s)
        # self.source = '',  # file/dir/URL/glob, 0 for webcam
        self.imgsz = (640, 640)  # inference size (height, width)
        self.conf_thres = 0.25  # confidence threshold
        self.iou_thres = 0.45  # NMS IOU threshold
        self.max_det = 1000  # maximum detections per image
        self.device = 'cpu'  # cuda device, i.e. 0 or 0,1,2,3 or cpu
        self.view_img = False  # show results
        self.save_txt = False  # save results to *.txt
        self.save_conf = False  # save confidences in --save-txt labels
        self.save_crop = False  # save cropped prediction boxes
        self.nosave = False  # do not save images/videos
        self.classes = None  # filter by class: --class 0, or --class 0 2 3
        self.agnostic_nms = False  # class-agnostic NMS
        self.augment = False  # augmented inference
        self.visualize = False  # visualize features
        self.update = False  # update all models
        self.project = ROOT / 'runs/detect'  # save results to project/name
        self.name = 'exp'  # save results to project/name
        self.exist_ok = False  # existing project/name ok, do not increment
        self.line_thickness = 3  # bounding box thickness (pixels)
        self.hide_labels = False  # hide labels
        self.hide_conf = False  # hide confidences
        self.half = False  # use FP16 half-precision inference
        self.dnn = False  # use OpenCV DNN for ONNX inference

        # Directories
        self.save_dir = increment_path(Path(self.project) / self.name, exist_ok=self.exist_ok)  # increment run
        (self.save_dir / 'labels' if self.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True)  # make dir

        # Load model
        device = select_device(self.device)
        self.model = DetectMultiBackend(self.weights, device=device, dnn=self.dnn)
        self.stride, self.names, self.pt, jit, onnx, engine = self.model.stride, self.model.names, self.model.pt, self.model.jit, self.model.onnx, self.model.engine
        self.imgsz = check_img_size(self.imgsz, s=self.stride)  # check image size

        # Half
        self.half &= (self.pt or jit or engine) and device.type != 'cpu'  # half precision only supported by PyTorch on CUDA
        if self.pt or jit:
            self.model.model.half() if self.half else self.model.model.float()

    def detect(self, source):
        detect_result = {
            'target_info_list': [],
        }
        source = str(source)
        save_img = not self.nosave and not source.endswith('.txt')  # save inference images
        is_file = Path(source).suffix[1:] in (IMG_FORMATS + VID_FORMATS)
        is_url = source.lower().startswith(('rtsp://', 'rtmp://', 'http://', 'https://'))
        webcam = source.isnumeric() or source.endswith('.txt') or (is_url and not is_file)
        if is_url and is_file:
            source = check_file(source)  # download

        # Dataloader
        if webcam:
            view_img = check_imshow()
            cudnn.benchmark = True  # set True to speed up constant image size inference
            dataset = LoadStreams(source, img_size=self.imgsz, stride=self.stride, auto=self.pt)
            bs = len(dataset)  # batch_size
        else:
            dataset = LoadImages(source, img_size=self.imgsz, stride=self.stride, auto=self.pt)
            bs = 1  # batch_size
        vid_path, vid_writer = [None] * bs, [None] * bs

        # Run inference
        self.model.warmup(imgsz=(1, 3, *self.imgsz), half=self.half)  # warmup
        dt, seen = [0.0, 0.0, 0.0], 0
        for path, im, im0s, vid_cap, s in dataset:
            t1 = time_sync()
            im = torch.from_numpy(im).to(self.device)
            im = im.half() if self.half else im.float()  # uint8 to fp16/32
            im /= 255  # 0 - 255 to 0.0 - 1.0
            if len(im.shape) == 3:
                im = im[None]  # expand for batch dim
            t2 = time_sync()
            dt[0] += t2 - t1

            # Inference
            self.visualize = increment_path(self.save_dir / Path(path).stem, mkdir=True) if self.visualize else False
            pred = self.model(im, augment=self.augment, visualize=self.visualize)
            t3 = time_sync()
            dt[1] += t3 - t2

            # NMS
            pred = non_max_suppression(pred, self.conf_thres, self.iou_thres, self.classes, self.agnostic_nms, max_det=self.max_det)
            dt[2] += time_sync() - t3

            # Second-stage classifier (optional)
            # pred = utils.general.apply_classifier(pred, classifier_model, im, im0s)

            # Process predictions
            for i, det in enumerate(pred):  # per image
                seen += 1
                if webcam:  # batch_size >= 1
                    p, im0, frame = path[i], im0s[i].copy(), dataset.count
                    s += f'{i}: '
                else:
                    p, im0, frame = path, im0s.copy(), getattr(dataset, 'frame', 0)
                p = Path(p)  # to Path
                save_path = str(self.save_dir / p.name)  # im.jpg
                txt_path = str(self.save_dir / 'labels' / p.stem) + (
                    '' if dataset.mode == 'image' else f'_{frame}')  # im.txt
                s += '%gx%g ' % im.shape[2:]  # print string
                gn = torch.tensor(im0.shape)[[1, 0, 1, 0]]  # normalization gain whwh
                imc = im0.copy() if self.save_crop else im0  # for save_crop
                annotator = Annotator(im0, line_width=self.line_thickness, example=str(self.names))
                if len(det):
                    # Rescale boxes from img_size to im0 size
                    det[:, :4] = scale_coords(im.shape[2:], det[:, :4], im0.shape).round()

                    # Print results
                    for c in det[:, -1].unique():
                        n = (det[:, -1] == c).sum()  # detections per class
                        s += f"{n} {self.names[int(c)]}{'s' * (n > 1)}, "  # add to string

                    # Write results
                    for *xyxy, conf, cls in reversed(det):
                        if self.save_txt:  # Write to file
                            xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist()  # normalized xywh
                            line = (cls, *xywh, conf) if self.save_conf else (cls, *xywh)  # label format
                            with open(txt_path + '.txt', 'a') as f:
                                f.write(('%g ' * len(line)).rstrip() % line + '\n')
                        if save_img or self.save_crop or self.view_img:  # Add bbox to image
                            c = int(cls)  # integer class
                            label = None if self.hide_labels else (self.names[c] if self.hide_conf else f'{self.names[c]} {conf:.2f}')
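                            # str(xyxy) holds the four corner coordinates and label[-3:] the last two
                            # confidence digits (e.g. '.97'), so the digit regex below collects
                            # [x1, y1, x2, y2, confidence] for the response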
                            infos = str(xyxy) + label[-3:]
                            print(infos)
                            detect_result['target_info_list'].append(re.findall("([0-9]+)", infos))
                            annotator.box_label(xyxy, label, color=colors(c, True))
                            if self.save_crop:
                                save_one_box(xyxy, imc, file=self.save_dir / 'crops' / self.names[c] / f'{p.stem}.jpg', BGR=True)

                # Print time (inference-only)
                LOGGER.info(f'{s}Done. ({t3 - t2:.3f}s)')

                # Stream results
                im0 = annotator.result()
                if self.view_img:
                    cv2.imshow(str(p), im0)
                    cv2.waitKey(1)  # 1 millisecond

                # Save results (image with detections)
                if save_img:
                    if dataset.mode == 'image':
                        cv2.imwrite(save_path, im0)
                    else:  # 'video' or 'stream'
                        if vid_path[i] != save_path:  # new video
                            vid_path[i] = save_path
                            if isinstance(vid_writer[i], cv2.VideoWriter):
                                vid_writer[i].release()  # release previous video writer
                            if vid_cap:  # video
                                fps = vid_cap.get(cv2.CAP_PROP_FPS)
                                w = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
                                h = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
                            else:  # stream
                                fps, w, h = 30, im0.shape[1], im0.shape[0]
                                save_path += '.mp4'
                            vid_writer[i] = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
                        vid_writer[i].write(im0)

        # Print results
        t = tuple(x / seen * 1E3 for x in dt)  # speeds per image
        LOGGER.info(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image at shape {(1, 3, *self.imgsz)}' % t)
        if self.save_txt or save_img:
            s = f"\n{len(list(self.save_dir.glob('labels/*.txt')))} labels saved to {self.save_dir / 'labels'}" if self.save_txt else ''
            LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}{s}")
        if self.update:
            strip_optimizer(self.weights)  # update model (to fix SourceChangeWarning)
        return detect_result

    def __new__(cls, *args, **kwargs) -> object:
        """
        Use __new__ to implement the singleton pattern.
        :param args:
        :param kwargs:
        :return:
        """
        if not hasattr(yolo_detect, "_instance"):
            with yolo_detect._instance_lock:
                if not hasattr(yolo_detect, "_instance"):
                    yolo_detect._instance = object.__new__(cls)
        return yolo_detect._instance


if __name__ == '__main__':
    y = yolo_detect()
    y.detect("1640749331.png")
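As a rough sanity check that the one-time initialization pays off, construction and detection can be timed separately (a minimal sketch, reusing the sample image from the __main__ block above):
import time

t0 = time.time()
y = yolo_detect()                 # the model is loaded here, once
print(f"init:   {time.time() - t0:.3f}s")

t0 = time.time()
y.detect("1640749331.png")        # later calls skip the model load
print(f"detect: {time.time() - t0:.3f}s")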
Building the backend service
In principle the modified script can already be used locally as-is. But YOLO is deep-learning based and needs a full environment set up before it can run, so having every project run it locally would make reuse expensive. The script is therefore exposed as a web service instead, and since it is Python, the web service is built on Django from the Python stack.
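For reference, the service host roughly needs the YOLOv5 repository's own dependencies plus Django (and requests for the test client further below); exact versions depend on your environment:
pip install -r requirements.txt    # YOLOv5's own dependency list
pip install django requests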
The main steps are:
1. Create a new Django project in PyCharm.
2. In settings.py, comment out 'django.middleware.csrf.CsrfViewMiddleware' in MIDDLEWARE so that POST requests from other hosts are not rejected by the CSRF check (see the snippet below).
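In settings.py this amounts to something like the following (the surrounding entries are Django's defaults and stay untouched):
MIDDLEWARE = [
    'django.middleware.security.SecurityMiddleware',
    'django.contrib.sessions.middleware.SessionMiddleware',
    'django.middleware.common.CommonMiddleware',
    # 'django.middleware.csrf.CsrfViewMiddleware',  # disabled so external POSTs are not rejected
    'django.contrib.auth.middleware.AuthenticationMiddleware',
    'django.contrib.messages.middleware.MessageMiddleware',
    'django.middleware.clickjacking.XFrameOptionsMiddleware',
]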
3. Add the route mapping in urls.py:
# urls.py (the admin/path imports are part of Django's default project template)
from django.contrib import admin
from django.urls import path

from yolo import views as yolo_view

urlpatterns = [
    path('admin/', admin.site.urls),
    path(r'up_file', yolo_view.up_file.as_view()),
]
4. Add the view handling the route in views.py:
import json
import os

from django.http import HttpResponse
from django.views import View

# Import the refactored detector class; the module name here is hypothetical,
# adjust it to wherever the script above actually lives in the project.
from yolo.yolo_server import yolo_detect

yolo = yolo_detect()  # build the singleton once at startup so the model is loaded a single time


class up_file(View):
    def post(self, request):
        try:
            file = request.FILES.get('file', '')
            print(file)
            current_path = os.path.dirname(__file__)  # current directory
            print(current_path)
            file_path = os.path.join(current_path, 'tempdata')
            print(file_path)
            if not os.path.exists(file_path):  # create the folder if it does not exist
                os.mkdir(file_path)
            save_file_path = os.path.join(file_path, file.name)
            with open(save_file_path, 'wb') as fp:  # write the uploaded image to disk
                for i in file.chunks():
                    fp.write(i)
            _msg = yolo.detect(save_file_path)
            return HttpResponse(json.dumps(_msg))
        except Exception as e:
            print(e)
            return HttpResponse(json.dumps({'status': False, 'msg': 'Error: {}'.format(e)}))
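Instantiating yolo_detect() at module import time front-loads the model load to server startup, so the first request is not hit with the initialization cost; and because the class implements a singleton through __new__, any additional yolo_detect() calls elsewhere simply return the same already-initialized instance.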
5. Start the service (see the command below).
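For local testing the Django development server is enough; binding to 0.0.0.0 makes it reachable from other machines, and port 8000 matches the URL used by the test script below:
python manage.py runserver 0.0.0.0:8000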
Test script
Once the service is up, a small test script exercises it:
import requests, os, time
def post_pic(url: str, path: str) -> dict:
    """
    Upload an image to the server and return the response.
    :param url: server address
    :param path: local path of the image
    :return: the server's response parsed as a dict
    """
    if not os.path.exists(path):
        print("File does not exist!")
        return {}
    with open(path, 'rb') as fp:
        result = requests.post(url, files={'file': fp})
    data = result.json()
    return data


def parse_response(response: dict) -> list:
    """
    Post-process the upload response.
    :param response: the dict returned by post_pic()
    :return: the filtered detection list; an empty list if recognition failed
    """
    if response:
        # 0. extract the detection list
        taglist = response['target_info_list']
        # 1. keep only detections with confidence >= 95
        real_tag_list = []
        for tag in taglist:
            if int(tag[-1]) >= 95:
                real_tag_list.append(tag)
        if len(real_tag_list) == 2:
            # 2. exactly two detections means recognition succeeded; return them
            return real_tag_list
        else:
            # 3. any other count means this recognition is unreliable; discard it and return an empty list
            return []
    else:
        # an empty response means the upload or recognition failed; return an empty list
        return []


if __name__ == '__main__':
    url = 'http://172.31.183.153:8000/up_file'
    path = '333333.png'
    t1 = time.time()
    tag_list = parse_response(post_pic(url, path))
    print(tag_list)
    difftime = time.time() - t1
    print(difftime)
Results
After a request, the response body looks like this:
{'target_info_list': [['120', '933', '295', '1082', '97'], ['631', '932', '800', '1078', '98']]}
target_info_list is the list of recognition results; each inner list describes one detected slider piece. The first four numbers are position information: the x and y coordinates of the box's top-left corner followed by the x and y coordinates of its bottom-right corner. The fifth number is the confidence, i.e. how likely the program considers the box to be a slider.
From the x coordinates of the two slider boxes, the sliding distance can then be computed, as sketched below.
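For example, using the response above, a minimal sketch of that calculation (which box is the gap and which is the movable piece depends on how the training data was labeled):
resp = {'target_info_list': [['120', '933', '295', '1082', '97'], ['631', '932', '800', '1078', '98']]}
box_a, box_b = resp['target_info_list']
# each entry is [x1, y1, x2, y2, confidence]; the slide distance is the
# horizontal offset between the left edges of the two boxes
distance = abs(int(box_b[0]) - int(box_a[0]))
print(distance)  # 511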
In practice, uploading an image of roughly 300 KB returns a response in about 600 ms.