下面是一个测试用例,会逐一解读代码:
import os
import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2
print("PyTorch version:", torch.__version__)
print("CUDA is available:", torch.cuda.is_available())
!git clone https://github.com/SysCV/sam-hq.git
os.chdir('sam-hq')
!export PYTHONPATH=$(pwd)
from segment_anything import sam_model_registry, SamPredictor
os:提供与操作系统交互的函数。
numpy(导入为 np):一个用于数值计算的Python库。
torch:主要用于使用PyTorch,一个流行的深度学习框架的库。
matplotlib.pyplot(导入为 plt):用于绘制图表和可视化数据的库。
cv2:OpenCV库,用于计算机视觉任务,如图像处理和计算机视觉算法。
PyTorch版本可以通过torch.__version__
获得,而torch.cuda.is_available()则判断CUDA是否可用。
使用Git克隆了一个名为 “sam-hq” 的GitHub仓库。!git clone 表示执行命令行命令来克隆仓库。然后使用os.chdir()将当前工作目录更改为 “sam-hq”。
export 命令用于设置环境变量,$(pwd) 返回当前目录的路径。
从 “segment_anything” 模块中导入了 sam_model_registry 和 SamPredictor。这些模块可能是自定义的,位于 “sam-hq” 仓库中的 “segment_anything” 文件夹中。
!mkdir pretrained_checkpoint
!wget https://huggingface.co/lkeab/hq-sam/resolve/main/sam_hq_vit_l.pth
!mv sam_hq_vit_l.pth pretrained_checkpoint
使用命令行命令mkdir在当前工作目录下创建一个名为 “pretrained_checkpoint” 的目录。
使用命令行命令wget从指定的URL下载文件。在这里,它从 https://huggingface.co/lkeab/hq-sam/resolve/main/sam_hq_vit_l.pth 下载文件。
使用命令行命令mv将文件 “sam_hq_vit_l.pth” 移动到 “pretrained_checkpoint” 目录下。mv命令接受两个参数,第一个参数是要移动的文件名,第二个参数是目标目录的路径。
综合起来,这部分代码的作用是在当前工作目录下创建 “pretrained_checkpoint” 目录,并从指定URL下载文件 “sam_hq_vit_l.pth”,然后将该文件移动到 “pretrained_checkpoint” 目录下。
def show_mask(mask, ax, random_color=False):
if random_color:
color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
else:
color = np.array([30/255, 144/255, 255/255, 0.6])
h, w = mask.shape[-2:]
mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
ax.imshow(mask_image)
这行代码定义了一个名为 show_mask 的函数,它接受三个参数:
根据 random_color 参数的值选择颜色。如果 random_color 为 True,则生成一个随机颜色,否则使用默认颜色。随机颜色是一个包含三个随机数和一个固定值的数组,而默认颜色是一个预定义的颜色(蓝色)。
将遮罩数组变换成一个与之对应的遮罩图像,并使用颜色数组对遮罩图像进行着色。最后,使用 Matplotlib 的 imshow 函数在指定的轴对象上显示遮罩图像。
综合起来,这个函数的目的是将给定的遮罩数组转换为可视化的遮罩图像,并将其显示在指定的 Matplotlib 轴对象上。
def show_points(coords, labels, ax, marker_size=375):
pos_points = coords[labels==1]
neg_points = coords[labels==0]
ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
这行代码定义了一个名为 show_points 的函数,它接受四个参数:
根据点的标签将点分为正样本和负样本。它使用布尔索引从 coords 和 labels 数组中选择正样本和负样本。
使用 Matplotlib 的 scatter 函数在指定的轴对象上绘制点。它分别绘制了正样本和负样本的点。正样本用绿色表示,负样本用红色表示。marker=‘*’ 指定了点的标记形状为星号,s=marker_size 指定了点的大小,edgecolor=‘white’ 和 linewidth=1.25 设置了点的边缘颜色和边缘宽度。
综合起来,这个函数的目的是根据给定的点坐标和标签在指定的 Matplotlib 轴对象上绘制正样本和负样本的点。
def show_box(box, ax):
x0, y0 = box[0], box[1]
w, h = box[2] - box[0], box[3] - box[1]
ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2))
这行代码定义了一个名为 show_box 的函数,它接受两个参数:
从边界框数组中提取了左上角坐标 (x0, y0) 和宽度 w 、高度 h。
使用 Matplotlib 的 Rectangle 函数创建一个矩形补丁,并将其添加到指定的轴对象中。该矩形补丁的位置由左上角坐标 (x0, y0) 和宽度 w 、高度 h 确定。edgecolor=‘green’ 设置矩形的边缘颜色为绿色,facecolor=(0,0,0,0) 设置矩形的填充颜色为透明,lw=2 设置矩形的边缘宽度为2。
综合起来,这个函数的目的是在指定的 Matplotlib 轴对象上绘制边界框,根据给定的边界框信息,绘制一个绿色的矩形框。
def show_res(masks, scores, input_point, input_label, input_box, image):
for i, (mask, score) in enumerate(zip(masks, scores)):
plt.figure(figsize=(10,10))
plt.imshow(image)
show_mask(mask, plt.gca())
if input_box is not None:
box = input_box[i]
show_box(box, plt.gca())
if (input_point is not None) and (input_label is not None):
show_points(input_point, input_label, plt.gca())
print(f"Score: {score:.3f}")
plt.axis('off')
plt.show()
这行代码定义了一个名为 show_res 的函数,它接受六个参数:
使用循环迭代预测的遮罩数组和分数数组。对于每个遮罩和分数,它执行以下操作:
综合起来,这个函数的目的是在图像上显示预测的遮罩、输入的边界框、输入的点以及预测的分数。
def show_res_multi(masks, scores, input_point, input_label, input_box, image):
plt.figure(figsize=(10, 10))
plt.imshow(image)
for mask in masks:
show_mask(mask, plt.gca(), random_color=True)
for box in input_box:
show_box(box, plt.gca())
for score in scores:
print(f"Score: {score:.3f}")
plt.axis('off')
plt.show()
这行代码定义了一个名为 show_res_multi 的函数,它接受六个参数:
执行以下操作:
综合起来,这个函数的目的是在图像上显示多个预测的遮罩、输入的边界框和相应的分数。
sam_checkpoint = "pretrained_checkpoint/sam_hq_vit_l.pth"
model_type = "vit_l"
device = "cuda"
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)
predictor = SamPredictor(sam)
这段代码主要进行了以下操作:
"pretrained_checkpoint/sam_hq_vit_l.pth"。
定义了变量 model_type,指定了模型类型 “vit_l”。
定义了变量 device,指定了设备类型 “cuda”,即使用 GPU 运行。
使用 sam_model_registry 字典根据模型类型从中获取对应的模型类,并传入预训练模型的路径 sam_checkpoint 创建了一个 sam 模型实例。
将 sam 模型移动到指定的设备上,即 GPU,使用 to(device=device) 方法。
创建了一个 SamPredictor 实例,将 sam 模型作为参数传入,用于进行预测。
综合起来,这段代码加载了预训练的 SAM 模型,将其移动到 GPU 上,并创建了一个SamPredictor 实例,用于使用该模型进行预测。
image = cv2.imread('demo/input_imgs/example0.png')
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
input_box = np.array([[4,13,1007,1023]])
input_point, input_label = None, None
predictor.set_image(image)
masks, scores, logits = predictor.predict(
point_coords=input_point,
point_labels=input_label,
box = input_box,
multimask_output=False,
hq_token_only= False,
)
show_res(masks,scores,input_point, input_label, input_box, image)
这段代码执行了以下操作:
multimask_output=False 表示只输出单个遮罩。
hq_token_only=False 表示不仅输出高质量遮罩。
返回的结果包括预测的遮罩 masks、分数 scores 和逻辑值 logits。
综合起来,这段代码加载了输入图像,并使用预测器 predictor 进行了预测,并将预测结果显示在图像上。
image = cv2.imread('demo/input_imgs/example1.png')
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
input_box = np.array([[306, 132, 925, 893]])
input_point, input_label = None, None
predictor.set_image(image)
masks, scores, logits = predictor.predict(
point_coords=input_point,
point_labels=input_label,
box = input_box,
multimask_output=False,
hq_token_only= True,
)
show_res(masks,scores,input_point, input_label, input_box, image)
image = cv2.imread('demo/input_imgs/example2.png')
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
input_point = np.array([[495,518],[217,140]])
input_label = np.ones(input_point.shape[0])
input_box = None
predictor.set_image(image)
masks, scores, logits = predictor.predict(
point_coords=input_point,
point_labels=input_label,
box = input_box,
multimask_output=False,
hq_token_only= True,
)
show_res(masks,scores,input_point, input_label, input_box, image)
image = cv2.imread('demo/input_imgs/example3.png')
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
input_point = np.array([[221,482],[498,633],[750,379]])
input_label = np.ones(input_point.shape[0])
input_box = None
predictor.set_image(image)
masks, scores, logits = predictor.predict(
point_coords=input_point,
point_labels=input_label,
box = input_box,
multimask_output=False,
hq_token_only= False,
)
show_res(masks,scores,input_point, input_label, input_box, image)
image = cv2.imread('demo/input_imgs/example4.png')
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
input_box = np.array([[64,76,940,919]])
input_point, input_label = None, None
predictor.set_image(image)
masks, scores, logits = predictor.predict(
point_coords=input_point,
point_labels=input_label,
box = input_box,
multimask_output=False,
hq_token_only= True,
)
show_res(masks,scores,input_point, input_label, input_box, image)
image = cv2.imread('demo/input_imgs/example5.png')
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
input_point = np.array([[373,363], [452, 575]])
input_label = np.ones(input_point.shape[0])
input_box = None
predictor.set_image(image)
masks, scores, logits = predictor.predict(
point_coords=input_point,
point_labels=input_label,
box = input_box,
multimask_output=False,
hq_token_only= False,
)
show_res(masks,scores,input_point, input_label, input_box, image)
image = cv2.imread('demo/input_imgs/example6.png')
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
input_box = np.array([[181, 196, 757, 495]])
input_point, input_label = None, None
predictor.set_image(image)
masks, scores, logits = predictor.predict(
point_coords=input_point,
point_labels=input_label,
box = input_box,
multimask_output=False,
hq_token_only= False,
)
show_res(masks,scores,input_point, input_label, input_box, image)
image = cv2.imread('demo/input_imgs/example7.png')
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# multi box input
input_box = torch.tensor([[45,260,515,470], [310,228,424,296]],device=predictor.device)
transformed_box = predictor.transform.apply_boxes_torch(input_box, image.shape[:2])
input_point, input_label = None, None
predictor.set_image(image)
masks, scores, logits = predictor.predict_torch(
point_coords=input_point,
point_labels=input_label,
boxes=transformed_box,
multimask_output=False,
hq_token_only=False,
)
masks = masks.squeeze(1).cpu().numpy()
scores = scores.squeeze(1).cpu().numpy()
input_box = input_box.cpu().numpy()
show_res_multi(masks, scores, input_point, input_label, input_box, image)