今天放松一下,随便看看这个YOLOV5 的识别部分的代码是怎么做的,先前的话我们自己手动实现了一个非常简易的分类框架,HuClassfiy(已经上传Gitee,方便各位访问),那么这里的话想要使用YOLOV5做点好玩的,也必须要对整个的代码流程进行梳理。原理就不用说了,老复杂了,所以先从简单的来探索。
我们原来的实现这个detect的代码非常简单,后面会贴出,我注释后的detect代码
import argparse
from PIL import Image
from utils.DataSet.MyDataSet import MyDataSet
from utils.DataSet.TransformAtions import TransFormAtions
"""
这里不想写那么多东西,就是简单地去做一个测试就ok了。
其实做法就是在那个train里面的训练
"""
import argparse
import torch
from torch.utils.data import DataLoader
from models.LeNet import LeNet
from data.ModelConfig import *
import outProcess
def detect():
ways = opt.valid_imgs
transformations = TransFormAtions()
net = LeNet(classes=Classes)
state_dict_load = torch.load(opt.path_state_dict)
net.load_state_dict(state_dict_load)
if(ways):
test_data = MyDataSet(data_dir=opt.valid_dir, transform=transformations.valid_transform)
valid_loader = DataLoader(dataset=test_data, batch_size=1)
net.eval()
with torch.no_grad():
for i, data in enumerate(valid_loader):
# forward
inputs, labels = data
outputs = net(inputs)
_, predicted = torch.max(outputs.data, 1)
# 输出处理器
outProcess.Function(predicted.numpy()[0])
else:
#指定的是单张图片,少给我来奇奇怪怪的输入,这个版本容错很差滴!!!
path_img = opt.valid_dir
if(".jpg" not in path_img):
raise Exception("小爷打不开这图片")
image = Image.open(path_img)
image = transformations.valid_transform(image)
image = torch.reshape(image, (1, 3, 32, 32))
net.eval()
with torch.no_grad():
out = net(image)
outProcess.Function(out.argmax(1).item())
if __name__ == '__main__':
parser = argparse.ArgumentParser()
# False表示识别单张图片,True表示多张图片,此时指定路径即可。
parser.add_argument('--valid_imgs',type=bool,default=False)
parser.add_argument('--valid_dir', type=str, default=r'F:\projects\PythonProject\MyClassfication\mydata\train\100\1.jpg')
parser.add_argument('--path_state_dict', type=str, default='runs/train/epx2/weights/best.pth')
opt = parser.parse_args()
detect()
在YOLO V 5 里面也不复杂,也就比我多了100多行代码。
我们这边大致的流程就三个。
然后每一个环节都可以有很多细节优化啥的,由于俺们那个是很简陋的,所以没有哈。
好了,我们开始正式进入这个YOLOV5的实际环节。
我们先来这看到这个环节,这里一共是做了两件事情嘛,读取超参数,加载模型权重文件,加载驱动
这里可以注意到这个函数
这个的话不用想的那么复杂,就是这个玩意
目的就是返回一个 可以正常使用的驱动,要是我写的话,我压根不会管那么多,不行就玩命报错,然后输出日志文件。
然后第二步是加载数据,这个说实话,没什么好说的,分两个,一个是读取网络摄像头,一个是读取一张图片,或者视频,本地摄像头。这些逻辑处理细节不一样,但是结果都是一样的。
就是把数据给我封装的dataset里面,然后读取。
这里我们说说那个预测的格式。
我这里还是拿上次的一张图片做演示
这里有两个目标框,所以拿到的数据是这样的
我们发现pred 是一个长度为1,里面有两个list的玩意
之后我们发现最后一个直接是0
这个的话,是这样的
之后就是拿到东西之后处理。在yolo里面默认是实现了一个自己绘图的玩意。
当然有时候,我们不仅仅要这玩意,我们想要实现AI压枪的话还需要那啥。
import argparse
import time
from pathlib import Path
import cv2
import torch
import torch.backends.cudnn as cudnn
from numpy import random
from models.experimental import attempt_load
from utils.datasets import LoadStreams, LoadImages
from utils.general import check_img_size, check_requirements, check_imshow, non_max_suppression, apply_classifier, \
scale_coords, xyxy2xywh, strip_optimizer, set_logging, increment_path
from utils.plots import plot_one_box
from utils.torch_utils import select_device, load_classifier, time_synchronized
def detect(save_img=False):
# 读取初始化参数
source, weights, view_img, save_txt, imgsz = opt.source, opt.weights, opt.view_img, opt.save_txt, opt.img_size
save_img = not opt.nosave and not source.endswith('.txt') # save inference images
webcam = source.isnumeric() or source.endswith('.txt') or source.lower().startswith(
('rtsp://', 'rtmp://', 'http://', 'https://'))
# Directories
save_dir = Path(increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok)) # increment run
(save_dir / 'labels' if save_txt else save_dir).mkdir(parents=True, exist_ok=True) # make dir
# Initialize
set_logging()
device = select_device(opt.device)
half = device.type != 'cpu' # half precision only supported on CUDA
# Load model
#加载模型,这一块,weights是我们传入的参数,是我们权重文件地址
#注意到这个model就是我们Huclassfiy的net
model = attempt_load(weights, map_location=device) # load FP32 model
stride = int(model.stride.max()) # model stride,维度的变换步长,这个和YOLO的网络结构有关,先忽略
#imgsz是我们图片资源,对图片尺寸进行检查
imgsz = check_img_size(imgsz, s=stride) # check img_size
#Pytorch 模型加速,这个需要GPU加速,需要先加载模型权重的!!!
if half:
model.half() # to FP16
# Second-stage classifier
classify = False
if classify:
modelc = load_classifier(name='resnet101', n=2) # initialize
modelc.load_state_dict(torch.load('weights/resnet101.pt', map_location=device)['model']).to(device).eval()
# Set Dataloader
vid_path, vid_writer = None, None
#如果是网络摄像头的数据这样处理
if webcam:
view_img = check_imshow()
cudnn.benchmark = True # set True to speed up constant image size inference
dataset = LoadStreams(source, img_size=imgsz, stride=stride)
else:
#这部分是加载dataset 和我们那个也是类似的,只不过对于单张图片,我们直接转化为了一个tensor在HuClassFiy
dataset = LoadImages(source, img_size=imgsz, stride=stride)
# Get names and colors
names = model.module.names if hasattr(model, 'module') else model.names
colors = [[random.randint(0, 255) for _ in range(3)] for _ in names]
# Run inference
if device.type != 'cpu':
model(torch.zeros(1, 3, imgsz, imgsz).to(device).type_as(next(model.parameters()))) # run once
t0 = time.time()
for path, img, im0s, vid_cap in dataset:
#这里就和我们的那个进入验证是类似的了
#path 是你的图片路径
# img 自然是image转化为了tensor
#im0s 是做了一个转化img0 = cv2.imread(path) # BGR
#vid_cap 就是说这玩意是不是一个视频,我们读入图片当然不是所以是None
img = torch.from_numpy(img).to(device)
img = img.half() if half else img.float() # uint8 to fp16/32
img /= 255.0 # 0 - 255 to 0.0 - 1.0 #归一化
if img.ndimension() == 3:
img = img.unsqueeze(0)
# Inference
t1 = time_synchronized()
pred = model(img, augment=opt.augment)[0]
#这个是预测的结果,但是按照那个网络的工作原理,还需要进行NMS非极大值抑制筛选目标框框
# Apply NMS
pred = non_max_suppression(pred, opt.conf_thres, opt.iou_thres, classes=opt.classes, agnostic=opt.agnostic_nms)
t2 = time_synchronized()
print("预测结果是",pred)
#按照我们在YOLV1论文里面的推出,应该是有5个参数
#x,y,w,h,k可信度,但是这里要显示所以还有一个对应的条件概率
#所以应该有6个参数,但是对应参数k,我们的概率计算是需要k的,结合参数opt.iou_thres
#所以此时那个参数k应该是iou,之后对应的概率,这里最后通过debug我发现那个完整的参数是这样的
#左上角,右下角,然后可信度,然后所属类别,注意那里显示的是按照屏幕100%来的,我的笔记本是125%
#得到的坐标是需要除以1.25的
# Apply Classifier
if classify:
pred = apply_classifier(pred, modelc, img, im0s)
# Process detections
#这部分,就是我们的后置处理了。说实话,应该把这玩意拆开的,这个部分是给Opencv画图用的
for i, det in enumerate(pred): # detections per image
if webcam: # batch_size >= 1
p, s, im0, frame = path[i], '%g: ' % i, im0s[i].copy(), dataset.count
else:
p, s, im0, frame = path, '', im0s, getattr(dataset, 'frame', 0)
p = Path(p) # to Path
save_path = str(save_dir / p.name) # img.jpg
txt_path = str(save_dir / 'labels' / p.stem) + ('' if dataset.mode == 'image' else f'_{frame}') # img.txt
s += '%gx%g ' % img.shape[2:] # print string
gn = torch.tensor(im0.shape)[[1, 0, 1, 0]] # normalization gain whwh
if len(det):
# Rescale boxes from img_size to im0 size
det[:, :4] = scale_coords(img.shape[2:], det[:, :4], im0.shape).round()
# Print results
for c in det[:, -1].unique():
n = (det[:, -1] == c).sum() # detections per class
s += f"{n} {names[int(c)]}{'s' * (n > 1)}, " # add to string
# Write results
for *xyxy, conf, cls in reversed(det):
if save_txt: # Write to file
xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist() # normalized xywh
line = (cls, *xywh, conf) if opt.save_conf else (cls, *xywh) # label format
with open(txt_path + '.txt', 'a') as f:
f.write(('%g ' * len(line)).rstrip() % line + '\n')
if save_img or view_img: # Add bbox to image
label = f'{names[int(cls)]} {conf:.2f}'
# plot_one_box(xyxy, im0, label=label, color=colors[int(cls)], line_thickness=3)
im0 = plot_one_box(xyxy, im0, label=label, color=colors[int(cls)], line_thickness=3)
# Print time (inference + NMS)
print(f'{s}Done. ({t2 - t1:.3f}s)')
# Stream results
if view_img:
cv2.imshow(str(p), im0)
cv2.waitKey(1) # 1 millisecond
# Save results (image with detections)
if save_img:
if dataset.mode == 'image':
cv2.imwrite(save_path, im0)
else: # 'video' or 'stream'
if vid_path != save_path: # new video
vid_path = save_path
if isinstance(vid_writer, cv2.VideoWriter):
vid_writer.release() # release previous video writer
if vid_cap: # video
fps = vid_cap.get(cv2.CAP_PROP_FPS)
w = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
h = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
else: # stream
fps, w, h = 30, im0.shape[1], im0.shape[0]
save_path += '.mp4'
vid_writer = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
vid_writer.write(im0)
if save_txt or save_img:
s = f"\n{len(list(save_dir.glob('labels/*.txt')))} labels saved to {save_dir / 'labels'}" if save_txt else ''
print(f"Results saved to {save_dir}{s}")
print(f'Done. ({time.time() - t0:.3f}s)')
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--weights', nargs='+', type=str, default='runs/train/exp2/weights/best.pt', help='model.pt path(s)')
# http://admin:[email protected]:8081
parser.add_argument('--source', type=str, default=r'F:\projects\PythonProject\yolov5-5.0\mydata\images\003.jpg', help='source') # file/folder, 0 for webcam
parser.add_argument('--img-size', type=int, default=640, help='inference size (pixels)')
parser.add_argument('--conf-thres', type=float, default=0.25, help='object confidence threshold')
parser.add_argument('--iou-thres', type=float, default=0.45, help='IOU threshold for NMS')
parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
parser.add_argument('--view-img', action='store_true', help='display results')
parser.add_argument('--save-txt', action='store_true', help='save results to *.txt')
parser.add_argument('--save-conf', action='store_true', help='save confidences in --save-txt labels')
parser.add_argument('--nosave', action='store_true', help='do not save images/videos')
parser.add_argument('--classes', nargs='+', type=int, help='filter by class: --class 0, or --class 0 2 3')
parser.add_argument('--agnostic-nms', action='store_true', help='class-agnostic NMS')
parser.add_argument('--augment', action='store_true', help='augmented inference')
parser.add_argument('--update', action='store_true', help='update all models')
parser.add_argument('--project', default='runs/detect', help='save results to project/name')
parser.add_argument('--name', default='exp', help='save results to project/name')
parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
opt = parser.parse_args()
print(opt)
check_requirements(exclude=('pycocotools', 'thop'))
with torch.no_grad():
if opt.update: # update all models (to fix SourceChangeWarning)
for opt.weights in ['yolov5s.pt', 'yolov5m.pt', 'yolov5l.pt', 'yolov5x.pt']:
detect()
strip_optimizer(opt.weights)
else:
detect()
那么接下来我们要做的就是提取detect,把这个玩意套在我们自己的项目里面。为了后面便于使用这个yolo,我决定后面对这个玩意进行工程化规范,便于直接进行二次使用,开发。毕竟核心的话其实就和HuClassfiy一样,就那几个块。还是那句话,yolo的难点不在工程上,在原理实现上面…