【YOLOv3 预测】YOLOv3图像输入->处理->输出全过程

文章目录

  • 1 数据输入输出代码详解
  • 2 yolo.detect_image详解
  • 3 感谢链接

1 数据输入输出代码详解

数据输入方式主要包括三种:

  • 读取本地一张图片(代码中predict状态)
  • 读取本地文件夹中所有图片(代码中dir_predict状态)
  • 视频输入状态(代码中video状态),视频输入又分为摄像头输入和本地视频输入两种,当video_path=0时表示检测摄像头。

将这三种输入方式归总到一个predict.py文件中,从全局角度看图像输入、处理、输出的全过程,其代码如下:

# -----------------------------------------------------------------------#
#   predict.py将单张图片预测、摄像头检测和目录遍历检测等功能
#   整合到了一个py文件中,通过指定mode进行模式的修改。
# -----------------------------------------------------------------------#
import time
import cv2
import numpy as np
from PIL import Image

from yolo import YOLO  # 从yolo.py文件中导入YOLO类,详见下一小节

if __name__ == "__main__":
    yolo = YOLO()  # 在这儿进行类的实例化,详见下一小节
    # ----------------------------------------------------------------------------------------------------------#
    #   mode用于指定测试的模式:
    #   'predict'表示单张图片预测,如果想对预测过程进行修改,如保存图片,截取对象等,可以先看下方详细的注释
    #   'video'表示视频检测,可调用摄像头或者视频进行检测,详情查看下方注释。
    #   'dir_predict'表示遍历文件夹进行检测并保存。默认遍历img文件夹,保存img_out文件夹,详情查看下方注释。
    # ----------------------------------------------------------------------------------------------------------#
    mode = "predict"
    # ----------------------------------------------------------------------------------------------------------#
    #   video_path用于指定视频的路径,当video_path=0时表示检测摄像头
    #   想要检测视频,则设置如video_path = "xxx.mp4"即可,代表读取出根目录下的xxx.mp4文件。
    #   video_save_path表示视频保存的路径,当video_save_path=""时表示不保存
    #   想要保存视频,则设置如video_save_path = "yyy.mp4"即可,代表保存为根目录下的yyy.mp4文件。
    #   video_fps用于保存的视频的fps
    #   video_path、video_save_path和video_fps仅在mode='video'时有效
    #   保存视频时需要ctrl+c退出或者运行到最后一帧才会完成完整的保存步骤。
    # ----------------------------------------------------------------------------------------------------------#
    video_path = 0      # video_path用于指定视频的路径,当video_path=0时表示检测摄像头
    video_save_path = ""
    video_fps = 25.0
    # -------------------------------------------------------------------------#
    #   test_interval用于指定测量fps的时候,图片检测的次数
    #   理论上test_interval越大,fps越准确。
    # -------------------------------------------------------------------------#
    test_interval = 100
    # -------------------------------------------------------------------------#
    #   dir_origin_path指定了用于检测的图片的文件夹路径
    #   dir_save_path指定了检测完图片的保存路径
    #   dir_origin_path和dir_save_path仅在mode='dir_predict'时有效
    # -------------------------------------------------------------------------#
    dir_origin_path = "img/"
    dir_save_path = "img_out/"

    if mode == "predict":
        '''
        1、如果想要进行检测完的图片的保存,利用r_image.save("img.jpg")即可保存,直接在predict.py里进行修改即可。 
        2、如果想要获得预测框的坐标,可以进入yolo.detect_image函数,在绘图部分读取top,left,bottom,right这四个值。
        3、如果想要利用预测框截取下目标,可以进入yolo.detect_image函数,在绘图部分利用获取到的top,left,bottom,right这四个值
        在原图上利用矩阵的方式进行截取。
        4、如果想要在预测图上写额外的字,比如检测到的特定目标的数量,可以进入yolo.detect_image函数,在绘图部分对predicted_class进行判断,
        比如判断if predicted_class == 'car': 即可判断当前目标是否为车,然后记录数量即可。利用draw.text即可写字。
        '''
        while True:
            img = input('Input image filename:')
            try:
                image = Image.open(img)
            except:
                print('Open Error! Try again!')
                continue
            else:
                r_image = yolo.detect_image(image)
                r_image.show()

    elif mode == "video":
        capture = cv2.VideoCapture(video_path)      # video_path用于指定视频的路径,当video_path=0时表示检测摄像头
        if video_save_path != "":
            fourcc = cv2.VideoWriter_fourcc(*'XVID')
            size = (int(capture.get(cv2.CAP_PROP_FRAME_WIDTH)), int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT)))
            out = cv2.VideoWriter(video_save_path, fourcc, video_fps, size)

        ref, frame = capture.read()
        if not ref:
            raise ValueError("未能正确读取摄像头(视频),请注意是否正确安装摄像头(是否正确填写视频路径)。")

        fps = 0.0
        while (True):
            t1 = time.time()
            # 读取某一帧
            ref, frame = capture.read()
            if not ref:
                break
            # 格式转变,BGRtoRGB
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            # 转变成Image
            frame = Image.fromarray(np.uint8(frame))
            # 进行检测
            frame = np.array(yolo.detect_image(frame))
            # RGBtoBGR满足opencv显示格式
            frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)

            fps = (fps + (1. / (time.time() - t1))) / 2
            print("fps= %.2f" % (fps))
            frame = cv2.putText(frame, "fps= %.2f" % (fps), (0, 40), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)

            cv2.imshow("video", frame)
            c = cv2.waitKey(1) & 0xff
            if video_save_path != "":
                out.write(frame)

            if c == 27:
                capture.release()
                break

        print("Video Detection Done!")
        capture.release()
        if video_save_path != "":
            print("Save processed video to the path :" + video_save_path)
            out.release()
        cv2.destroyAllWindows()

    elif mode == "dir_predict":
        import os
        from tqdm import tqdm

        img_names = os.listdir(dir_origin_path)
        for img_name in tqdm(img_names):    # tqdm用来在长循环中添加一个进度提示信息
            if img_name.lower().endswith(
                    ('.bmp', '.dib', '.png', '.jpg', '.jpeg', '.pbm', '.pgm', '.ppm', '.tif', '.tiff')):
                image_path = os.path.join(dir_origin_path, img_name)
                image = Image.open(image_path)
                r_image = yolo.detect_image(image)
                if not os.path.exists(dir_save_path):
                    os.makedirs(dir_save_path)
                r_image.save(os.path.join(dir_save_path, img_name))

    else:
        raise AssertionError("Please specify the correct mode: 'predict', 'video', or 'dir_predict'.")

2 yolo.detect_image详解

包括模型初始化、模型推理得到预测结果,解码,非极大值抑制(确实是在解码后面),将结果可视化。
代码中包含各部分介绍及重要函数的参考链接。

import colorsys     # 用来给画框设置不同的颜色
import numpy as np
import torch
import torch.nn as nn
from PIL import ImageDraw, ImageFont
# yolov3网络结构,详解见参考链接https://blog.csdn.net/weixin_45377629/article/details/124080087
from nets.yolo import YoloBody      
#---------------------------------------------------------#
#   cvtColor将图像转换成RGB图像,防止灰度图在预测时报错。
#		代码仅支持RGB图像的预测,所有其它类型的图像都会转化成RGB
#   get_anchors获得先验框
#   get_classes获得类别
#   preprocess_input归一化
#   resize_image对输入图像进行resize
#---------------------------------------------------------#
from utils.utils import (cvtColor, get_anchors, get_classes, preprocess_input, resize_image)
# 对网络预测结构进行解码,详解见参考链接https://blog.csdn.net/weixin_45377629/article/details/124144913
from utils.utils_bbox import DecodeBox


class YOLO(object):
    _defaults = {
        # --------------------------------------------------------------------------#
        #   使用自己训练好的模型进行预测一定要修改model_path和classes_path!
        #   model_path指向logs文件夹下的权值文件,classes_path指向model_data下的txt
        #
        #   训练好后logs文件夹下存在多个权值文件,选择验证集损失较低的即可。
        #   验证集损失较低不代表mAP较高,仅代表该权值在验证集上泛化性能较好。
        #   如果出现shape不匹配,同时要注意训练时的model_path和classes_path参数的修改
        # --------------------------------------------------------------------------#
        "model_path": 'model_data/yolo_weights.pth',
        "classes_path": 'model_data/coco_classes.txt',
        # ---------------------------------------------------------------------#
        #   anchors_path代表先验框对应的txt文件,一般不修改。
        #   anchors_mask用于帮助代码找到对应的先验框,一般不修改。
        # ---------------------------------------------------------------------#
        "anchors_path": 'model_data/yolo_anchors.txt',
        "anchors_mask": [[6, 7, 8], [3, 4, 5], [0, 1, 2]],
        # ---------------------------------------------------------------------#
        #   输入图片的大小,必须为32的倍数。
        # ---------------------------------------------------------------------#
        "input_shape": [416, 416],
        # ---------------------------------------------------------------------#
        #   只有得分大于置信度的预测框会被保留下来
        # ---------------------------------------------------------------------#
        "confidence": 0.5,
        # ---------------------------------------------------------------------#
        #   非极大抑制所用到的nms_iou大小
        # ---------------------------------------------------------------------#
        "nms_iou": 0.3,
        # ---------------------------------------------------------------------#
        #   该变量用于控制是否使用letterbox_image对输入图像进行不失真的resize,
        #   在多次测试后,发现关闭letterbox_image直接resize的效果更好
        # ---------------------------------------------------------------------#
        "letterbox_image": False,
        # -------------------------------#
        #   是否使用Cuda
        #   没有GPU可以设置成False
        # -------------------------------#
        "cuda": True,
    }

    # ---------------------------------------------------#
    #   初始化YOLO
    # ---------------------------------------------------#
    def __init__(self, **kwargs):
        self.__dict__.update(self._defaults)
        for name, value in kwargs.items():
            setattr(self, name, value)

        # ---------------------------------------------------#
        #   获得种类和先验框的数量
        #	DecodeBox详解见:https://blog.csdn.net/weixin_45377629/article/details/124144913
        # ---------------------------------------------------#
        self.class_names, self.num_classes = get_classes(self.classes_path)
        self.anchors, self.num_anchors = get_anchors(self.anchors_path)
        self.bbox_util = DecodeBox(self.anchors, self.num_classes, (self.input_shape[0], self.input_shape[1]),
                                   self.anchors_mask)

        # ---------------------------------------------------#
        #   画框设置不同的颜色
        # ---------------------------------------------------#
        hsv_tuples = [(x / self.num_classes, 1., 1.) for x in range(self.num_classes)]
        self.colors = list(map(lambda x: colorsys.hsv_to_rgb(*x), hsv_tuples))
        self.colors = list(map(lambda x: (int(x[0] * 255), int(x[1] * 255), int(x[2] * 255)), self.colors))
        self.generate()  # 生成 yolov3 模型并加载权重

    # ---------------------------------------------------#
    #   生成模型,模型初始化,放到device上,加载模型参数,改成eval()模式
    # ---------------------------------------------------#
    def generate(self):
        # ---------------------------------------------------#
        #   建立yolov3模型,载入yolov3模型的权重
        # ---------------------------------------------------#
        self.net = YoloBody(self.anchors_mask, self.num_classes)
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.net.load_state_dict(torch.load(self.model_path, map_location=device))
        self.net = self.net.eval()
        print('{} model, anchors, and classes loaded.'.format(self.model_path))

        if self.cuda:
            self.net = nn.DataParallel(self.net)
            self.net = self.net.cuda()

    # ---------------------------------------------------#
    #   检测图片
    # ---------------------------------------------------#
    def detect_image(self, image):
        image_shape = np.array(np.shape(image)[0:2])
        # ---------------------------------------------------------#
        #   在这里将图像转换成RGB图像,防止灰度图在预测时报错。
        #   代码仅仅支持RGB图像的预测,所有其它类型的图像都会转化成RGB
        # ---------------------------------------------------------#
        image = cvtColor(image)
        # ---------------------------------------------------------#
        #   给图像增加灰条,实现不失真的resize
        #   也可以直接resize进行识别
        # ---------------------------------------------------------#
        image_data = resize_image(image, (self.input_shape[1], self.input_shape[0]), self.letterbox_image)
        # ---------------------------------------------------------#
        #   添加上batch_size维度,缺少这个维度,网络没法预测,preprocess_input表示数据归一化 image /= 255.0
        # ---------------------------------------------------------#
        image_data = np.expand_dims(np.transpose(preprocess_input(np.array(image_data, dtype='float32')), (2, 0, 1)), 0)

        with torch.no_grad():
            images = torch.from_numpy(image_data)  # 从numpy形式转成torch需要的形式
            if self.cuda:
                images = images.cuda()
            # ---------------------------------------------------------#
            #   将图像输入网络当中进行预测!    
            #	然后进行解码,可参考链接https://blog.csdn.net/weixin_45377629/article/details/124144913
            # ---------------------------------------------------------#
            outputs = self.net(images)
            outputs = self.bbox_util.decode_box(outputs)
            # ---------------------------------------------------------#
            #   将预测框进行堆叠,然后进行非极大抑制   功能:筛选出一定区域内,属于同一种类得分最大的方框
            #   此处可参考链接:https://blog.csdn.net/weixin_45377629/article/details/124202975
            # ---------------------------------------------------------#
            results = self.bbox_util.non_max_suppression(torch.cat(outputs, 1), self.num_classes, self.input_shape,
                                                         image_shape, self.letterbox_image, conf_thres=self.confidence,
                                                         nms_thres=self.nms_iou)

            if results[0] is None:
                return image

            top_label = np.array(results[0][:, 6], dtype='int32')  	# 预测框的种类
            top_conf = results[0][:, 4] * results[0][:, 5]  		# 预测框的置信度
            top_boxes = results[0][:, :4]  							# 预测框的坐标
		
		# ------------------------------#
		#   打印每个类别的数量
		# ------------------------------#
        print("top_label:", top_label)
        classes_num = np.zeros([self.num_classes])
        for i in range(self.num_classes):
        	num = np.sum(top_label == i):
        	if num > 0:
        		print(self.class_names[i], " : ", num)
        	classes_num[i] = num
        print("classes_num: ", classes_num)
        # ---------------------------------------------------------#
        #   设置字体与边框厚度
        # ---------------------------------------------------------#
        font = ImageFont.truetype(font='model_data/simhei.ttf',
                                  size=np.floor(3e-2 * image.size[1] + 0.5).astype('int32'))
        thickness = int(max((image.size[0] + image.size[1]) // np.mean(self.input_shape), 1))

        # ---------------------------------------------------------#
        #   图像绘制
        # ---------------------------------------------------------#
        for i, c in list(enumerate(top_label)):
            predicted_class = self.class_names[int(c)]
            box = top_boxes[i]
            score = top_conf[i]

            top, left, bottom, right = box

            top = max(0, np.floor(top).astype('int32'))
            left = max(0, np.floor(left).astype('int32'))
            bottom = min(image.size[1], np.floor(bottom).astype('int32'))
            right = min(image.size[0], np.floor(right).astype('int32'))

            label = '{} {:.2f}'.format(predicted_class, score)
            draw = ImageDraw.Draw(image)
            label_size = draw.textsize(label, font)
            label = label.encode('utf-8')
            print(label, top, left, bottom, right)

            if top - label_size[1] >= 0:
                text_origin = np.array([left, top - label_size[1]])
            else:
                text_origin = np.array([left, top + 1])

            for i in range(thickness):
                draw.rectangle([left + i, top + i, right - i, bottom - i], outline=self.colors[c])
            draw.rectangle([tuple(text_origin), tuple(text_origin + label_size)], fill=self.colors[c])
            draw.text(text_origin, str(label, 'UTF-8'), fill=(0, 0, 0), font=font)
            del draw

        return image

结果演示:
【YOLOv3 预测】YOLOv3图像输入->处理->输出全过程_第1张图片

3 感谢链接

https://www.bilibili.com/video/BV1Hp4y1y788?p=7
https://blog.csdn.net/weixin_44791964/article/details/105310627

你可能感兴趣的:(目标检测系列,深度学习,python,pytorch)