Python 使用 Detectron2 进行目标检测 (Detectron2, CenterNet2, Detic)

代码说明

代码主要是一个用来演示如何使用 Detectron2 进行目标检测的脚本。它可以从摄像头或视频文件中读取图像,并应用指定的配置文件进行目标检测。其中,Detectron2 结合了 CenterNet2 和 Detic 进行目标检测。
 


 
 

主要库介绍

Detectron2

Detectron2是由Facebook AI Research开发的一个用于目标检测和实例分割的开源库。它提供了一系列预训练模型和灵活的配置系统,使得用户可以方便地进行模型训练和推理。

主要功能:

  • 支持多种目标检测和分割任务,如Faster R-CNN、Mask R-CNN等。
  • 提供模块化设计,易于扩展和定制。
  • 高效的训练和推理性能,支持GPU加速。

 

CenterNet2

CenterNet2是一个基于单阶段检测器CenterNet的目标检测框架。它通过在CenterNet的基础上添加更多特性和优化,进一步提升了检测性能和准确率。

主要功能:

  • 单阶段检测,速度快。
  • 通过热力图进行目标中心点定位。
  • 支持多个检测任务,如物体检测、姿态估计等。

 

Detic

Detic是一个用于开放词汇表目标检测的库。它允许用户在推理时使用自定义的词汇表,而不需要重新训练模型。这使得它在处理新的目标类别时更加灵活和高效。

主要功能:

  • 支持开放词汇表目标检测。
  • 可以在推理时动态加载自定义词汇表。
  • 兼容Detectron2框架,易于集成和使用。

 
 

主要代码

# Copyright (c) Facebook, Inc. and its affiliates.
import argparse
import glob
import multiprocessing as mp
import numpy as np
import os
import tempfile
import time
import warnings
import cv2
import tqdm
import sys
import mss

# 导入Detectron2配置和工具
from detectron2.config import get_cfg
from detectron2.data.detection_utils import read_image
from detectron2.utils.logger import setup_logger

sys.path.insert(0, 'third_party/CenterNet2/')  # 添加CenterNet2库路径
from centernet.config import add_centernet_config  # 导入CenterNet2配置函数
from detic.config import add_detic_config  # 导入Detic配置函数

from detic.predictor import VisualizationDemo  # 导入Detic可视化演示类

# Fake a video capture object OpenCV style - half width, half height of first screen using MSS
# 模拟一个OpenCV风格的视频捕获对象 - 使用MSS库截取屏幕的一部分
class ScreenGrab:
    def __init__(self):
        self.sct = mss.mss()  # 初始化MSS库
        m0 = self.sct.monitors[0]  # 获取主屏幕信息
        self.monitor = {'top': 0, 'left': 0, 'width': m0['width'] / 2, 'height': m0['height'] / 2}  # 设置捕获屏幕的区域(宽度和高度减半)

    def read(self):
        img =  np.array(self.sct.grab(self.monitor))  # 捕获屏幕区域图像
        nf = cv2.cvtColor(img, cv2.COLOR_BGRA2BGR)  # 转换图像颜色格式
        return (True, nf)

    def isOpened(self):
        return True  # 模拟摄像头打开状态
    def release(self):
        return True  # 模拟摄像头释放资源

# 常量定义
WINDOW_NAME = "Detic"

def setup_cfg(args):
    cfg = get_cfg()  # 获取Detectron2默认配置
    if args.cpu:
        cfg.MODEL.DEVICE = "cpu"  # 设置使用CPU
    add_centernet_config(cfg)  # 添加CenterNet2配置
    add_detic_config(cfg)  # 添加Detic配置
    cfg.merge_from_file(args.config_file)  # 从文件合并配置
    cfg.merge_from_list(args.opts)  # 从命令行参数合并配置
    # 设置模型的阈值
    cfg.MODEL.RETINANET.SCORE_THRESH_TEST = args.confidence_threshold
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = args.confidence_threshold
    cfg.MODEL.PANOPTIC_FPN.COMBINE.INSTANCES_CONFIDENCE_THRESH = args.confidence_threshold
    cfg.MODEL.ROI_BOX_HEAD.ZEROSHOT_WEIGHT_PATH = 'rand'  # 后续加载
    if not args.pred_all_class:
        cfg.MODEL.ROI_HEADS.ONE_CLASS_PER_PROPOSAL = True  # 每个提案只检测一个类别
    cfg.freeze()  # 冻结配置
    return cfg

def get_parser():
    parser = argparse.ArgumentParser(description="Detectron2 demo for builtin configs")  # 创建参数解析器
    parser.add_argument(
        "--config-file",
        default="configs/quick_schedules/mask_rcnn_R_50_FPN_inference_acc_test.yaml",
        metavar="FILE",
        help="path to config file",  # 配置文件路径
    )
    parser.add_argument("--webcam", help="Take inputs from webcam.")  # 摄像头输入
    parser.add_argument("--cpu", action='store_true', help="Use CPU only.")  # 使用CPU
    parser.add_argument("--video-input", help="Path to video file.")  # 视频文件输入路径
    parser.add_argument(
        "--input",
        nargs="+",
        help="A list of space separated input images; "
        "or a single glob pattern such as 'directory/*.jpg'",  # 输入图像列表或模式
    )
    parser.add_argument(
        "--output",
        help="A file or directory to save output visualizations. "
        "If not given, will show output in an OpenCV window.",  # 输出结果保存路径
    )
    parser.add_argument(
        "--vocabulary",
        default="lvis",
        choices=['lvis', 'openimages', 'objects365', 'coco', 'custom'],
        help="Vocabulary type for Detic",  # 词汇表类型
    )
    parser.add_argument(
        "--custom_vocabulary",
        default="",
        help="Path to custom vocabulary file",  # 自定义词汇表路径
    )
    parser.add_argument("--pred_all_class", action='store_true')  # 预测所有类别
    parser.add_argument(
        "--confidence-threshold",
        type=float,
        default=0.5,
        help="Minimum score for instance predictions to be shown",  # 置信度阈值
    )
    parser.add_argument(
        "--opts",
        help="Modify config options using the command-line 'KEY VALUE' pairs",
        default=[],
        nargs=argparse.REMAINDER,
    )
    return parser

def test_opencv_video_format(codec, file_ext):
    with tempfile.TemporaryDirectory(prefix="video_format_test") as dir:  # 创建临时目录
        filename = os.path.join(dir, "test_file" + file_ext)
        writer = cv2.VideoWriter(
            filename=filename,
            fourcc=cv2.VideoWriter_fourcc(*codec),
            fps=float(30),
            frameSize=(10, 10),
            isColor=True,
        )
        [writer.write(np.zeros((10, 10, 3), np.uint8)) for _ in range(30)]  # 写入空白帧
        writer.release()
        if os.path.isfile(filename):
            return True  # 如果文件存在,返回True
        return False  # 否则返回False

if __name__ == "__main__":
    mp.set_start_method("spawn", force=True)  # 设置多进程启动方式
    args = get_parser().parse_args()  # 解析命令行参数
    setup_logger(name="fvcore")  # 设置日志记录器
    logger = setup_logger()  # 获取日志记录器
    logger.info("Arguments: " + str(args))  # 记录参数信息

    cfg = setup_cfg(args)  # 设置配置

    demo = VisualizationDemo(cfg, args)  # 创建可视化演示对象

    if args.input:  # 如果有输入图像
        if len(args.input) == 1:
            args.input = glob.glob(os.path.expanduser(args.input[0]))  # 获取输入图像路径
            assert args.input, "The input path(s) was not found"
        for path in tqdm.tqdm(args.input, disable=not args.output):  # 遍历输入图像
            img = read_image(path, format="BGR")  # 读取图像
            start_time = time.time()
            predictions, visualized_output = demo.run_on_image(img)  # 运行检测
            logger.info(
                "{}: {} in {:.2f}s".format(
                    path,
                    "detected {} instances".format(len(predictions["instances"]))
                    if "instances" in predictions
                    else "finished",
                    time.time() - start_time,
                )
            )

            if args.output:  # 如果有输出路径
                if os.path.isdir(args.output):
                    assert os.path.isdir(args.output), args.output
                    out_filename = os.path.join(args.output, os.path.basename(path))  # 构建输出文件名
                else:
                    assert len(args.input) == 1, "Please specify a directory with args.output"
                    out_filename = args.output
                visualized_output.save(out_filename)  # 保存可视化结果
            else:
                cv2.namedWindow(WINDOW_NAME, cv2.WINDOW_NORMAL)  # 创建窗口
                cv2.imshow(WINDOW_NAME, visualized_output.get_image()[:, :, ::-1])  # 显示图像
                if cv2.waitKey(0) == 27:
                    break  # 按Esc键退出
    elif args.webcam:  # 如果从摄像头读取
        assert args.input is None, "Cannot have both --input and --webcam!"
        assert args.output is None, "output not yet supported with --webcam!"
        if args.webcam == "screen":
            cam = ScreenGrab()  # 使用屏幕捕获
        else:
            cam = cv2.VideoCapture(int(args.webcam))  # 打开摄像头
        for vis in tqdm.tqdm(demo.run_on_video(cam)):  # 遍历视频帧
            cv2.namedWindow(WINDOW_NAME, cv2.WINDOW_NORMAL)  # 创建窗口
            cv2.imshow(WINDOW_NAME, vis)  # 显示视频帧
            if cv2.waitKey(1) == 27:
                break  # 按Esc键退出
        cam.release()
        cv2.destroyAllWindows()  # 释放资源
    elif args.video_input:  # 如果从视频文件读取
        video = cv2.VideoCapture(args.video_input)  # 打开视频文件
        width = int(video.get(cv2.CAP_PROP_FRAME_WIDTH))  # 获取视频宽度
        height = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))  # 获取视频高度
        frames_per_second = video.get(cv2.CAP_PROP_FPS)  # 获取视频帧率
        num_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))  # 获取视频帧数
        basename = os.path.basename(args.video_input)  # 获取视频文件名
        codec, file_ext = (
            ("x264", ".mkv") if test_opencv_video_format("x264", ".mkv") else ("mp4v", ".mp4")
        )
        if codec == ".mp4v":
            warnings.warn("x264 codec not available, switching to mp4v")  # 切换到mp4v编码
        if args.output:
            if os.path.isdir(args.output):
                output_fname = os.path.join(args.output, basename)  # 输出文件路径
                output_fname = os.path.splitext(output_fname)[0] + file_ext
            else:
                output_fname = args.output
            assert not os.path.isfile(output_fname), output_fname
            output_file = cv2.VideoWriter(
                filename=output_fname,
                fourcc=cv2.VideoWriter_fourcc(*codec),
                fps=float(frames_per_second),
                frameSize=(width, height),
                isColor=True,
            )
        assert os.path.isfile(args.video_input)
        for vis_frame in tqdm.tqdm(demo.run_on_video(video), total=num_frames):
            if args.output:
                output_file.write(vis_frame)  # 写入视频文件
            else:
                cv2.namedWindow(basename, cv2.WINDOW_NORMAL)  # 创建窗口
                cv2.imshow(basename, vis_frame)  # 显示输出
                if cv2.waitKey(1) == 27:
                    break  # 按ESC退出
        video.release()  # 释放视频文件
        if args.output:
            output_file.release()  # 释放输出文件
        else:
            cv2.destroyAllWindows()  # 销毁所有窗口

 
 
 

你可能感兴趣的:(Python,AI,Ubuntu,python,目标检测,开发语言)