使用python利用segment anything进行图像全景分割

参考自github:segment-anything/automatic_mask_generator_example.ipynb at main · facebookresearch/segment-anything · GitHub官网:https://segment-anything.com/

github源文件网址:GitHub - facebookresearch/segment-anything: The repository provides code for running inference with the SegmentAnything Model (SAM), links for downloading the trained model checkpoints, and example notebooks that show how to use the model.安装配置步骤

代码需要python>=3.8, 以及pytorch>=1.7torchvision>=0.8。请按照此处的说明安装 PyTorch 和 TorchVision 依赖项。强烈建议安装支持 CUDA 的 PyTorch 和 TorchVision。

安装段任何东西:

pip install git+https://github.com/facebookresearch/segment-anything.git

或者在本地克隆存储库并安装

git clone [email protected]:facebookresearch/segment-anything.git
cd segment-anything; pip install -e .

以下可选依赖项对于掩码后处理、以 COCO 格式保存掩码、示例笔记本以及以 ONNX 格式导出模型是必需的。jupyter还需要运行示例笔记本。

pip install opencv-python pycocotools matplotlib onnxruntime onnx

首先下载一个模型checkpoint。然后只需几行就可以使用该模型从给定的提示中获取掩码:

 
from segment_anything import SamPredictor, sam_model_registry
sam = sam_model_registry[""](checkpoint="")
predictor = SamPredictor(sam)
predictor.set_image()
masks, _, _ = predictor.predict()

此外,可以从命令行为图像生成遮罩:

python scripts/amg.py --checkpoint  --model-type  --input  --output 

模型权重点

该模型的三种模型版本具有不同的骨干尺寸。这些模型可以通过运行来实例化

from segment_anything import sam_model_registry
sam = sam_model_registry[""](checkpoint="")

单击下面的链接下载相应模型类型的权重,分别是h代表了2.6G,l代表1.2G,b代表375M,注意如果显存不够是无法跑大权重的。

查看自己显存的代码:

grep -i --color memory /var/log/Xorg.0.log
  • defaultvit_h:ViT-H SAM 型号。
  • vit_l: ViT-L SAM 模型。
  • vit_b: ViT-B SAM 型号。

单张图片分割

使用python代码对单张图片进行全景切割:

import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2

import sys
sys.path.append("..")
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor

# 改变sam_checkpoint,model_type,device为你想要的模型
sam_checkpoint = "sam_vit_l_0b3195.pth"
model_type = "vit_l"
device = "cuda"

sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)

mask_generator = SamAutomaticMaskGenerator(sam)

def show_anns(anns):
    if len(anns) == 0:
        return
    sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
    ax = plt.gca()
    ax.set_autoscale_on(False)

    img = np.ones((sorted_anns[0]['segmentation'].shape[0], sorted_anns[0]['segmentation'].shape[1], 4))
    img[:,:,3] = 0
    for ann in sorted_anns:
        m = ann['segmentation']
        color_mask = np.concatenate([np.random.random(3), [0.35]])
        img[m] = color_mask

    ax.imshow(img)

# 在image处加入要处理的图片路径
image = cv2.imread('/*.jpg')
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

masks = mask_generator.generate(image)

plt.figure(figsize=(20,20))
plt.imshow(image)
show_anns(masks)
plt.axis('off')
plt.show() 

批量图片处理

批量处理一个文件夹中的图片:

# 输入和输出的路径
input_dir = "路径"
output_dir = "路径"

# 循环读取文件并保存
for image_path in glob.glob(os.path.join(input_dir, "*.jpg")):
    # 读取图片
    image = cv2.imread(image_path)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    # 生成mask
    masks = mask_generator.generate(image)

    # 保存带有
    output_path = os.path.join(output_dir, os.path.basename(image_path))
    save_anns(masks, image, output_path)

你可能感兴趣的:(深度学习,pytorch,python)