Given an RGB video or image, the goal is to segment out the mask of a specified target. We reproduce two zero-shot open-source projects: GroundingDINO from IDEA Research and SAM from Facebook (Meta AI). First, the object detector GroundingDINO takes a text prompt describing the target and returns a bounding box for it. That box is then passed to SAM as a prompt, which segments out the target mask. Example results are shown below (test data from the VolumeDeform dataset):
For the text prompt "white shirt", GroundingDINO computes the box: "shirt_000500": "[194.23726, 2.378189, 524.09503, 441.5135]".
In practice, GroundingDINO's per-image inference is slower than SAM's. Both models return multiple predictions, and even when the highest-confidence result is selected, either model can still predict inaccurately.
GroundingDINO extends a closed-set detector DINO by performing vision-language modality fusion at multiple phases, including a feature enhancer, a language-guided query selection module, and a cross-modality decoder. Such a deep fusion strategy effectively improves open-set object detection.
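For reference, the detection step that produced the box above boils down to a few lines. Below is a minimal sketch following the usage shown in the GroundingDINO README (the config/checkpoint paths and thresholds are the README defaults; the frame path is a placeholder). Note that predict returns boxes normalized in (cx, cy, w, h) format, so they need to be scaled and converted to pixel-space xyxy before being handed to SAM:

import torch
from torchvision.ops import box_convert
from groundingdino.util.inference import load_model, load_image, predict

model = load_model("groundingdino/config/GroundingDINO_SwinT_OGC.py",
                   "weights/groundingdino_swint_ogc.pth")
image_source, image = load_image("frame-000500.color.png")

boxes, logits, phrases = predict(model=model, image=image,
                                 caption="white shirt",
                                 box_threshold=0.35, text_threshold=0.25)

# predict returns normalized cxcywh boxes; convert to absolute xyxy for SAM
h, w, _ = image_source.shape
boxes_xyxy = box_convert(boxes * torch.tensor([w, h, w, h]),
                         in_fmt="cxcywh", out_fmt="xyxy").numpy()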
Reproducing both projects is straightforward: follow the GitHub READMEs to set up the environments and run the demos. Alternatively, an all-in-one project such as Grounded-Segment-Anything can be used directly. When there are many images to segment, modify GroundingDINO's demo.sh and demo/inference_on_a_image.py to save the detection results to a JSON file.
Changes in demo/inference_on_a_image.py:
import json

# Modify plot_boxes_to_image so that it also returns the box coordinates
image_with_box, mask, box_coor = plot_boxes_to_image(image_pil, pred_dict)
# obj is the target name, i is the index of the current image
obj = 'shirt'
data = {f'{obj}_{str(i).zfill(6)}': str(list(box_coor.cpu().detach().numpy()))}
# Merge with previously saved boxes; start from an empty dict on the first frame
try:
    with open("box.json", "r", encoding="utf-8") as f:
        old_data = json.load(f)
except FileNotFoundError:
    old_data = {}
old_data.update(data)
with open("box.json", "w", encoding="utf-8") as f:
    json.dump(old_data, f, indent=4)
SAM then reads the box information back from box.json (one entry per frame, in the format shown above) and uses the box, rather than clicked points, as its input prompt.
Test script:
import os
import glob
import json

import cv2
import numpy as np
import matplotlib.pyplot as plt
from segment_anything import SamPredictor, sam_model_registry

coords = []  # point prompts collected by the mouse-click handler


def show_mask(mask, ax, random_color=False):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)


def show_points(coords, labels, ax, marker_size=375):
    pos_points = coords[labels == 1]
    neg_points = coords[labels == 0]
    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*',
               s=marker_size, edgecolor='white', linewidth=1.25)
    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*',
               s=marker_size, edgecolor='white', linewidth=1.25)


def show_box(box, ax):
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green',
                               facecolor=(0, 0, 0, 0), lw=2))


def on_click(event):
    global coords
    if event.button == 1:
        x, y = event.xdata, event.ydata
        print(f"Left click: x={x:.2f}, y={y:.2f}")
        coords.append([x, y])
        # if len(coords) == 2:
        #     fig.canvas.mpl_disconnect(cid)
    elif event.button == 3:
        print("Right click")


def get_mask(image, mask_id=1, click_coords=False, choose_mask=False, box=None):
    global coords  # the click handler appends into the module-level list
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    if click_coords:
        # Collect point prompts interactively
        fig, ax = plt.subplots()
        fig.set_size_inches(30, 20)  # width and height in inches
        ax.imshow(image)
        cid = fig.canvas.mpl_connect('button_press_event', on_click)
        plt.show()
    else:
        coords = []  # no point prompts; rely on the box prompt alone

    # Note: the SAM model is reloaded on every call; for long sequences,
    # hoist this block out of get_mask and reuse the predictor
    sam_checkpoint = "sam_vit_h_4b8939.pth"
    model_type = "vit_h"
    device = "cuda"
    sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
    sam.to(device=device)
    predictor = SamPredictor(sam)
    predictor.set_image(image)

    input_point = np.array(coords)
    input_label = np.array([1] * len(coords))
    if len(coords) == 0:
        input_point = None
        input_label = None
    input_box = box

    masks, scores, logits = predictor.predict(
        point_coords=input_point,
        point_labels=input_label,
        box=input_box[None, :] if input_box is not None else None,
        multimask_output=True)

    if choose_mask:
        # Show the three candidate masks with their scores, then pick one by index
        plt.figure(figsize=(60, 20))
        for j in range(3):
            plt.subplot(1, 3, j + 1)
            plt.imshow(image)
            show_mask(masks[j], plt.gca())
            plt.title(f"Mask {j}, Score: {scores[j]:.3f}", fontsize=18)
        plt.show()
        mask_id = int(input())  # type the index of the preferred mask

    mask = masks[mask_id]
    mask = np.tile(np.expand_dims(mask, axis=-1), 3)
    mask_data = np.where(mask, 255, 0).astype(np.uint8)
    if click_coords:
        coords.clear()
    return mask_data


if __name__ == '__main__':
    obj = 'shirt'
    color_path = f'/Data/VolumeDeformData/{obj}/data/'
    mask_path = f'/Data/VolumeDeformData/{obj}/mask/'
    if not os.path.exists(mask_path):
        os.makedirs(mask_path)

    img_paths = []
    for extension in ["jpg", "png", "jpeg"]:
        img_paths += glob.glob(os.path.join(color_path, "*.{}".format(extension)))

    json_path = 'GroundingDINO-main/box.json'
    with open(json_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    # data/ pairs each color frame with a depth frame, hence the // 2
    for i in range(len(img_paths) // 2):
        img_name = f'frame-{str(i).zfill(6)}.color.png'
        img = cv2.imread(color_path + img_name)
        key = f'{obj}_{str(i).zfill(6)}'
        # Box strings look like "[x0, y0, x1, y1]"; parse back into floats
        box = np.array(list(map(float, data[key][1:-1].split(','))))
        mask = get_mask(img, mask_id=2, click_coords=False, choose_mask=False, box=box)
        cv2.imwrite(mask_path + str(i).zfill(6) + '.png', mask)
        print(img_name)
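One tweak worth considering (my suggestion, not part of the original script): when choose_mask=False, instead of hard-coding mask_id=2, let SAM's own predicted quality scores pick the output. A minimal helper:

import numpy as np

def pick_best_mask(masks: np.ndarray, scores: np.ndarray) -> np.ndarray:
    # masks:  (N, H, W) boolean masks from predictor.predict(...)
    # scores: (N,) predicted mask-quality (IoU) scores
    return masks[int(np.argmax(scores))]

This avoids the fixed index silently picking a poor mask on frames where the ranking of the three candidates changes.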