Segment Anything(2)

Segment Anything(2)

https://blog.csdn.net/m0_46246301/article/details/130392008?spm=1001.2014.3001.5501

上一期介绍了SAM的基本使用,包括了安装、推理(point2mask、bbox2mask、point_bbox2mask)的介绍。

本期的内容将分为以下几个方面:

  1. SAM自动化生成mask
  2. 压缩保存mask
  3. 超像素分割算法改进SAM(目前效果不佳,但可能以后能做出来)

文章目录

  • Segment Anything(2)
    • 1 SAM自动化生成mask
    • 2 压缩保存mask
    • 3 超像素分割算法改进SAM

1 SAM自动化生成mask

自动化生成mask的代码在SAM github中的automatic_mask_generator_example.ipynb文件已经给出,在此我们作了一定的简化。

首先,准备导入依赖库和必要的可视化函数。

import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2
import os
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor

def show_mask(mask, ax, random_color=False):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30/255, 144/255, 255/255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)
    
def show_points(coords, labels, ax, marker_size=375):
    pos_points = coords[labels==1]
    neg_points = coords[labels==0]
    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)   
    
def show_box(box, ax):
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2))    

def show_anns(anns):
    if len(anns) == 0:
        return
    sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
    ax = plt.gca()
    ax.set_autoscale_on(False)
    polygons = []
    color = []
    for ann in sorted_anns:
        m = ann['segmentation']
        img = np.ones((m.shape[0], m.shape[1], 3))
        color_mask = np.random.random((1, 3)).tolist()[0]
        for i in range(3):
            img[:,:,i] = color_mask[i]
        ax.imshow(np.dstack((img, m*0.35)))

接着,初始化SAM模型,确定使用的模型参数文件和cuda,构建自动生成相应mask的模型对象。

import sys
sys.path.append("..")
from segment_anything import sam_model_registry, SamPredictor

sam_checkpoint = "../sam_vit_h_4b8939.pth"
model_type = "vit_h"

device = "cuda"

sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)
# predictor = SamPredictor(sam)
mask_generator = SamAutomaticMaskGenerator(sam)

然后,只需要读取图像进行推理即可,并将其可视化出来。

image = cv2.imread('images/dog.jpg')
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
masks = mask_generator.generate(image)
plt.figure(figsize=(20,20))
plt.imshow(image)
show_anns(masks)
plt.axis('off')
plt.show() 

但是,它的推理结果中,大多数像素都被预测为多个Label了,我们可以用如下代码进行测试,可以看出,平均每个像素对应了接近2个label,即包含在两个mask中。

# 测试-是否一个像素对应了多个label
sum = 0
for i,mask_info in enumerate(masks):
    mask = mask_info['segmentation'].astype('int')
    sum += np.count_nonzero(mask==1)
print(sum)
print(image.shape[0] * image.shape[1])

# 759084
# 427200

2 压缩保存mask

为了节省空间,我们需要将所有mask尽可能合并到一个mask内,实现代码如下:

### 1. 合并所有mask到一个all_mask内
all_mask = np.zeros_like(masks[0]['segmentation']) #初始化
# 遍历每个mask,依次以i+1作为label存入all_mask
for i,mask_info in enumerate(masks):
    mask = mask_info['segmentation'].astype('int')
    all_mask = np.where(mask==1, i+1, all_mask)

3 超像素分割算法改进SAM

SAM在分割边缘的时候很容易效果不佳,为了增强边缘分割的效果,可以考虑引入超像素分割算法。

对每个超像素块,用出现次数最多的SAM的label赋值。

  1. 上节已介绍:将SAM推理出的所有mask的label,合并至一个all_mask内。
  2. 投票:利用SAM推理出的label-mask(all_mask),对每个超像素块(labels_mask)投票赋值。
def updating(masks, number_slic, labels_mask):
    # INPUT
    # masks: SAM输出的mask信息
    # number_slic: SLIC超像素块的个数
    # labels_mask: SLIC的mask,每个超像素块对应一个SLIC-label,在该函数中替换为SAM-label
    
    ### 1. 合并所有mask到一个all_mask内
    all_mask = np.zeros_like(masks[0]['segmentation']) #初始化
    # 遍历每个mask,依次以i+1作为label存入all_mask
    for i,mask_info in enumerate(masks):
        mask = mask_info['segmentation'].astype('int')
        all_mask = np.where(mask==1, i+1, all_mask)
    
    ### 2. 遍历所有的超像素块,统计出现最多的SAM-label,作为该块的label
    for i in range(number_slic):
        inds = np.argwhere(labels_mask==i)  # 获取当前块的全部索引
        cur_dict = {}
        for x,y in inds:  # 遍历该超像素块的每个像素
            if all_mask[x,y] not in cur_dict.keys():
                cur_dict[all_mask[x,y]] = 0
            else:
                cur_dict[all_mask[x,y]] += 1
        if len(cur_dict)==0:  
            final_label = 0  # 该超像素块不含任何SAM-label
        else:    # 选择出现最多的SAM-label
            final_label = [k for k, v in cur_dict.items() if v == max(cur_dict.values())][0]
        for x,y in inds:  # 替换
            labels_mask[x,y] = final_label
    
    # 3. 对SAM输出的mask进行更新
    for i,mask_info in enumerate(masks):
        new_mask = np.zeros_like(mask_info['segmentation'])
        new_mask = np.where(labels_mask==i+1, 1, new_mask)
        masks[i]['segmentation'] = new_mask.astype('bool')
    return masks

def ROI_SLIC_superpixel_seg2(img_path):
    # 读取图片
    img = cv2.imread(img_path)
    img_copy = img.copy()
#     roi = img[y1:y2, x1:x2]

    # 对整个图片进行超像素分割
    img_copy = cv2.cvtColor(img_copy, cv2.COLOR_BGR2RGB)
    slic = cv2.ximgproc.createSuperpixelSLIC(img_copy, algorithm=cv2.ximgproc.SLIC, region_size=20, ruler=30.0)
    slic.iterate(10)  # 迭代10次,生成超像素
    labels = slic.getLabels()  # 获取超像素标签
    number_slic = slic.getNumberOfSuperpixels()  # 获取超像素数目
    mask_slic = slic.getLabelContourMask()              #获取Mask,超像素边缘Mask==1
    mask_inv_slic = cv2.bitwise_not(mask_slic)
    img_copy = cv2.bitwise_and(img_copy,img_copy,mask =  mask_inv_slic)  #在原图上绘制超像素边界

    # 显示结果
    plot_func(img_copy)
    
    return mask_slic, labels, number_slic

def plot_func(img):
    plt.figure(figsize=(10, 10))
    plt.imshow(img)
    plt.axis('off')
    plt.show()

利用SLIC超像素分割对SAM进行更新:

# SLIC
mask_slics, labels_mask, number_slic = ROI_SLIC_superpixel_seg2('images/dog.jpg')
# 对SAM更新
masks = updating(masks, number_slic, labels_mask)
# 绘图
plt.figure(figsize=(20,20))
plt.imshow(image)
show_anns(masks)
plt.axis('off')
plt.show() 

你可能感兴趣的:(计算机视觉,数据可视化,机器学习,计算机视觉,pytorch,人工智能)