Segment Anything Model批量检测图像

最近看到了Segment Anything Model,发现不需要配置太多的东西就能跑起来
介绍说明的网址Segment Anything Model

Segment Anything Model批量检测图像_第1张图片

可以从github下载代码code,提供了三个模型
Segment Anything Model批量检测图像_第2张图片

  • vit_b的大小是358m
  • vit_l的大小是1.2g,
  • vit_h的大小是2.4g

我已经下载了源代码和最小的模型,有需要的话:百度云链接

接下来是加载最小的那个模型,检测自己的图像

Segment Anything Model批量检测图像_第3张图片
需要运行的是amg.py这个文件。我按照官方给的提示对这个文件进行了一点修改,没有修改之前不会在图像上显示检测的结果,这个修改后的文件也放在百度云链接中了
Segment Anything Model批量检测图像_第4张图片

  • 在pycharm的右上角哪里点击编辑配置
    在形参那里需要添加几个命令行参数
--input 检测图片的路径 --output 检测结果的输出路径 --model-type vit_b --checkpoint 权重文件的存放路径

举个例子

--input /home/image/test.png --output /home/image --model-type vit_b --checkpoint home/image/sam_vit_b_01ec64.pth

在这里插入图片描述

然后直接运行amg.py即可,下面是我运行的结果
Segment Anything Model批量检测图像_第5张图片

Segment Anything Model批量检测图像_第6张图片

        # for mask
        # base = os.path.basename(t)
        # base = os.path.splitext(base)[0]
        # save_base = os.path.join(args.output, base)
        # if output_mode == "binary_mask":
        #     os.makedirs(save_base, exist_ok=False)
        #     write_masks_to_folder(masks, save_base)
        # else:
        #     save_file = save_base + ".json"
        #     with open(save_file, "w") as f:
        #         json.dump(masks, f)

不注释掉的话检测得到的掩码二值图像都会保存下来

下面的代码是用于处理一批图像,并保存下来

def main(args: argparse.Namespace) -> None:
    print("Loading model...")
    sam = sam_model_registry[args.model_type](checkpoint=args.checkpoint)
    _ = sam.to(device=args.device)
    # output_mode = "coco_rle" if args.convert_to_rle else "binary_mask"
    output_mode = 'binary_mask'
    amg_kwargs = get_amg_kwargs(args)
    generator = SamAutomaticMaskGenerator(sam, output_mode=output_mode, **amg_kwargs)

    floader_path = "/home/jin/图片/dete"  # 这里为一批图像所在的文件夹
    file_path = os.listdir(floader_path)
    for im in file_path:
        args.input = os.path.join(floader_path, im)
        if not os.path.isdir(args.input):
            targets = [args.input]
        else:
            targets = [
                f for f in os.listdir(args.input) if not os.path.isdir(os.path.join(args.input, f))
            ]
            targets = [os.path.join(args.input, f) for f in targets]

        os.makedirs(args.output, exist_ok=True)

        for t in targets:
            print(f"Processing '{t}'...")
            image = cv2.imread(t)
            if image is None:
                print(f"Could not load '{t}' as an image, skipping...")
                continue
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            masks = generator.generate(image)


            plt.imshow(image)  # 这里自行选择加不加
            for mask in masks:
                show_mask(mask['segmentation'], plt.gca())
                # show_box(mask['bbox'],plt.gca())
            # plt.axis('off') 保存轴线或不保存
            plt.savefig('/home/jin/图片/segment-anything-main/scripts/'+im)  # 这里要替换为自己的路径
            plt.close()
    print("Done!")
  • 把main函数替换为这个,可以批量处理文件夹下的图像,需要改下plt.savefig(‘/home/jin/图片/segment-anything-main/scripts/’+im)这里的是存储路径,floader_path就是待检测的图像路径
  • 加上plt.imshow(image)这一句,效果如下,保存的是覆盖在原始图像的掩码结果
    Segment Anything Model批量检测图像_第7张图片
  • 如果不加上plt.imshow(image)就是单纯的保存掩码图像,效果如下

Segment Anything Model批量检测图像_第8张图片

你可能感兴趣的:(python,计算机视觉)