https://blog.csdn.net/m0_46246301/article/details/130392008?spm=1001.2014.3001.5501
上一期介绍了SAM的基本使用,包括安装和三种推理方式(point2mask、bbox2mask、point_bbox2mask)。
本期的内容将分为以下几个方面:
- SAM自动化生成mask
- 压缩保存mask
- 超像素分割算法改进SAM(目前效果不佳,但可能以后能做出来)
自动化生成mask的代码在SAM github中的automatic_mask_generator_example.ipynb文件已经给出,在此我们作了一定的简化。
首先,准备导入依赖库和必要的可视化函数。
import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2
import os
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
def show_mask(mask, ax, random_color=False):
    """Draw a segmentation mask on ``ax`` as a semi-transparent RGBA overlay.

    Args:
        mask: Array whose last two dimensions are (height, width); it is
            reshaped to ``(height, width, 1)`` before being coloured.
        ax: Object exposing ``imshow`` (e.g. a matplotlib Axes).
        random_color: If true, use a random RGB colour; otherwise a fixed
            blue. The alpha channel is 0.6 in both cases.
    """
    if not random_color:
        rgba = np.array([30/255, 144/255, 255/255, 0.6])
    else:
        rgba = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    height, width = mask.shape[-2:]
    # Broadcasting (H, W, 1) * (1, 1, 4) yields an (H, W, 4) image that is
    # fully transparent wherever the mask is zero.
    overlay = mask.reshape(height, width, 1) * rgba.reshape(1, 1, -1)
    ax.imshow(overlay)
def show_points(coords, labels, ax, marker_size=375):
    """Scatter prompt points on ``ax``: green stars for label 1, red for label 0.

    Args:
        coords: Array of shape (N, 2) holding (x, y) point coordinates.
        labels: Array of N point labels; 1 selects positive points and 0
            selects negative points.
        ax: Object exposing ``scatter`` (e.g. a matplotlib Axes).
        marker_size: Marker size forwarded as ``s`` to ``scatter``.
    """
    groups = [
        (coords[labels == 1], 'green'),
        (coords[labels == 0], 'red'),
    ]
    shared_style = dict(marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
    for points, colour in groups:
        ax.scatter(points[:, 0], points[:, 1], color=colour, **shared_style)
def show_box(box, ax):
    """Draw a green, unfilled rectangle for ``box`` on ``ax``.

    Args:
        box: Indexable (x0, y0, x1, y1) giving two opposite corners.
        ax: Object exposing ``add_patch`` (e.g. a matplotlib Axes).
    """
    left, top = box[0], box[1]
    width = box[2] - box[0]
    height = box[3] - box[1]
    # A fully transparent face colour keeps only the outline visible.
    outline = plt.Rectangle(
        (left, top), width, height,
        edgecolor='green', facecolor=(0,0,0,0), lw=2,
    )
    ax.add_patch(outline)
def show_anns(anns):
    """Overlay every annotation mask on the current matplotlib axes.

    Masks are drawn from largest to smallest ``'area'`` so that small masks
    end up on top of large ones. Each mask gets a random RGB colour and an
    alpha of 0.35 inside the mask (0 outside).

    Args:
        anns: Sequence of dicts with at least an ``'area'`` value and a 2-D
            ``'segmentation'`` array. Nothing is drawn when it is empty.
    """
    if len(anns) == 0:
        return
    by_area_desc = sorted(anns, key=lambda ann: ann['area'], reverse=True)
    axes = plt.gca()
    axes.set_autoscale_on(False)
    for ann in by_area_desc:
        seg = ann['segmentation']
        rgb = np.random.random((1, 3)).tolist()[0]
        colour_layer = np.ones((seg.shape[0], seg.shape[1], 3))
        colour_layer[:, :] = rgb  # broadcast the colour over every pixel
        # The scaled mask is stacked as the alpha channel.
        axes.imshow(np.dstack((colour_layer, seg * 0.35)))
接着,初始化SAM模型,确定使用的模型参数文件和cuda,构建自动生成相应mask的模型对象。
import sys
sys.path.append("..")
from segment_anything import sam_model_registry, SamPredictor
# Checkpoint file and the matching model variant.
sam_checkpoint = "../sam_vit_h_4b8939.pth"
model_type = "vit_h"
# Prefer the GPU, but fall back to the CPU so the example still runs on
# machines without CUDA (a hard-coded "cuda" makes sam.to() fail there).
device = "cuda" if torch.cuda.is_available() else "cpu"
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)
# For prompt-based inference use instead: predictor = SamPredictor(sam)
# Generator that produces masks for a whole image without prompts.
mask_generator = SamAutomaticMaskGenerator(sam)
然后,只需要读取图像进行推理即可,并将其可视化出来。
image_path = 'images/dog.jpg'
image = cv2.imread(image_path)
# cv2.imread does not raise on a missing or unreadable file; it returns None,
# which would otherwise surface as an opaque error inside cvtColor.
if image is None:
    raise OSError(f"cv2.imread could not load image: {image_path}")
# OpenCV loads images as BGR; convert to RGB before inference and plotting.
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
masks = mask_generator.generate(image)
plt.figure(figsize=(20,20))
plt.imshow(image)
show_anns(masks)
plt.axis('off')
plt.show()
但是,在它的推理结果中,许多像素同时属于多个mask。可以用如下代码验证:所有mask的像素总数约为图像像素数的1.78倍(759084 / 427200),即平均每个像素被接近2个mask覆盖。
# Check whether a pixel is covered by more than one mask: compare the summed
# mask pixel count with the number of pixels in the image.
# The accumulator is deliberately not called `sum`, which would shadow the
# builtin for all code that runs afterwards.
total_mask_pixels = 0
for mask_info in masks:
    total_mask_pixels += np.count_nonzero(mask_info['segmentation'])
print(total_mask_pixels)
print(image.shape[0] * image.shape[1])
# Example output:
# 759084
# 427200
为了节省空间,我们需要将所有mask尽可能合并到一个mask内(重叠的像素只保留最后写入的那个mask的label),实现代码如下:
### 1. Merge all masks into a single label image `all_mask`
# Allocate an integer array explicitly: zeros_like() on a boolean mask would
# create a boolean array that can only hold labels through implicit promotion.
all_mask = np.zeros(masks[0]['segmentation'].shape, dtype=int)
# Mask i is stored with label i + 1 (0 means "no mask"). Where masks overlap,
# the later mask overwrites the earlier one. Assigning in place avoids
# allocating a new full-size array for every mask.
for i, mask_info in enumerate(masks):
    all_mask[mask_info['segmentation'].astype('int') == 1] = i + 1
SAM在物体边缘处的分割效果往往不佳,为了增强边缘分割的效果,可以考虑引入超像素分割算法。
对每个超像素块,用出现次数最多的SAM的label赋值。
def updating(masks, number_slic, labels_mask):
    """Snap SAM masks to SLIC superpixel boundaries by majority vote.

    Every superpixel receives the SAM label that covers most of its pixels
    (label 0 means "no SAM mask"; ties go to the lowest label). Each SAM mask
    is then rebuilt as the union of the superpixels that voted for it.

    The vote is written into a separate array. Overwriting ``labels_mask``
    while still selecting superpixels from it mixes SLIC ids with SAM labels:
    a superpixel relabelled to SAM label ``k`` would be picked up again as
    "superpixel k" later in the loop and corrupt the result.

    Args:
        masks: List of SAM annotation dicts, each with a 2-D
            ``'segmentation'`` array of identical shape. Modified in place.
        number_slic: Number of SLIC superpixels; SLIC labels are expected to
            lie in ``range(number_slic)``. Pixels outside that range are
            treated as background (label 0).
        labels_mask: Integer ndarray of per-pixel SLIC labels with the same
            shape as the masks. On return it is overwritten in place with the
            winning SAM label of each pixel.

    Returns:
        The same ``masks`` list, with every ``'segmentation'`` replaced by a
        boolean array aligned to the superpixel boundaries.
    """
    if len(masks) == 0:
        return masks

    # 1. Merge all masks into one label image: mask i -> label i + 1.
    #    Where masks overlap, the later mask wins.
    all_mask = np.zeros(np.shape(masks[0]['segmentation']), dtype=np.int64)
    for i, mask_info in enumerate(masks):
        all_mask[np.asarray(mask_info['segmentation']).astype('int') == 1] = i + 1

    # 2. Majority vote per superpixel. Encode each (SLIC label, SAM label)
    #    pair as one integer so a single bincount yields the full vote table.
    slic_labels = np.asarray(labels_mask)
    n_sam_labels = len(masks) + 1  # includes background label 0
    in_range = (slic_labels >= 0) & (slic_labels < number_slic)
    pair_index = (
        slic_labels[in_range].astype(np.int64) * n_sam_labels + all_mask[in_range]
    )
    votes = np.bincount(pair_index, minlength=number_slic * n_sam_labels)
    votes = votes.reshape(number_slic, n_sam_labels)
    # argmax returns the first maximum, so ties resolve to the lowest label
    # and a superpixel without pixels falls back to background (0).
    winners = votes.argmax(axis=1)

    sam_labels = np.zeros(all_mask.shape, dtype=np.int64)
    sam_labels[in_range] = winners[slic_labels[in_range]]
    # Keep the documented side effect: labels_mask now holds SAM labels.
    labels_mask[...] = sam_labels

    # 3. Rebuild every SAM mask from the relabelled superpixels.
    for i, mask_info in enumerate(masks):
        mask_info['segmentation'] = sam_labels == i + 1
    return masks
def ROI_SLIC_superpixel_seg2(img_path, region_size=20, ruler=30.0, num_iterations=10):
    """Run SLIC superpixel segmentation on a whole image and plot the result.

    The image is loaded with OpenCV, converted from BGR to RGB, segmented
    with ``cv2.ximgproc`` SLIC, and displayed via ``plot_func`` with the
    superpixel boundaries blacked out.

    Args:
        img_path: Path of the image file to segment.
        region_size: Superpixel size passed to ``createSuperpixelSLIC``.
        ruler: Smoothness factor passed to ``createSuperpixelSLIC``.
        num_iterations: Number of SLIC iterations to run.

    Returns:
        Tuple ``(mask_slic, labels, number_slic)``: the contour mask from
        ``getLabelContourMask``, the per-pixel superpixel labels, and the
        number of superpixels.

    Raises:
        OSError: If ``cv2.imread`` cannot load ``img_path``.
    """
    img = cv2.imread(img_path)
    # cv2.imread returns None instead of raising when the file is missing or
    # cannot be decoded; fail early with a clear message.
    if img is None:
        raise OSError(f"cv2.imread could not load image: {img_path}")
    # cvtColor returns a new array, so the loaded image is left untouched.
    img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    slic = cv2.ximgproc.createSuperpixelSLIC(
        img_rgb, algorithm=cv2.ximgproc.SLIC, region_size=region_size, ruler=ruler
    )
    slic.iterate(num_iterations)  # refine the superpixels
    labels = slic.getLabels()  # per-pixel superpixel label
    number_slic = slic.getNumberOfSuperpixels()
    # Contour mask: non-zero on superpixel borders, 0 elsewhere (per OpenCV docs).
    mask_slic = slic.getLabelContourMask()
    mask_inv_slic = cv2.bitwise_not(mask_slic)
    # Keep only non-border pixels, which draws the borders in black.
    outlined = cv2.bitwise_and(img_rgb, img_rgb, mask=mask_inv_slic)
    plot_func(outlined)
    return mask_slic, labels, number_slic
def plot_func(img):
    """Show ``img`` in a new 10x10 matplotlib figure with the axes hidden."""
    figure = plt.figure(figsize=(10, 10))
    axes = figure.gca()
    axes.imshow(img)
    axes.axis('off')
    plt.show()
利用SLIC超像素分割对SAM进行更新:
# Run SLIC superpixel segmentation on the same image (this also plots the
# superpixel boundaries).
mask_slics, labels_mask, number_slic = ROI_SLIC_superpixel_seg2('images/dog.jpg')
# Update the SAM masks using the SLIC superpixels.
masks = updating(masks, number_slic, labels_mask)
# Plot the updated masks on top of the image.
plt.figure(figsize=(20,20))
plt.imshow(image)
show_anns(masks)
plt.axis('off')
plt.show()