本笔记本展示了如何使用SAM 2在视频中进行交互式分割。它将涵盖以下内容:
如果使用Jupyter在本地运行,请首先根据安装说明在您的环境中安装 segment-anything-2。
# 下载SAM2项目
!git clone https://github.com/facebookresearch/segment-anything-2.git
# 安装SAM2
%cd segment-anything-2
!pip install -e .
# 为了避免后续出现从SAM2导入_C.so失败,需要安装扩建
%cd segment-anything-2
!python setup.py build_ext --inplace
%cd checkpoints
# 若提示未获得权限请用该指令下载 !chmod +x /content/segment-anything-2-old/checkpoints/download_ckpts.sh
!./download_ckpts.sh
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
# 使用 bfloat16 数据类型计算
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
# 检查 GPU 设备架构是否在第八代以上
if torch.cuda.get_device_properties(0).major >= 8:
# 为Ampere GPU打开tfloat32支持 (参见https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
# %cd /content/segment-anything-2/checkpoints
from sam2.build_sam import build_sam2_video_predictor
# 使用的 SAM2 模型的路径(在checkpoints目录下共有四个模型,这里是用large,如果在本地运行可以换小型模型)
sam2_checkpoint = "../checkpoints/sam2_hiera_large.pt"
# 模型配置文件
model_cfg = "sam2_hiera_l.yaml"
# 使用指定模型和配置文件构建视频预测器
predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint)
def show_mask(mask, ax, obj_id=None, random_color=False):
# 对掩膜mask使用随机生成颜色或预定义的颜色
if random_color:
color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0) # 生成带透明度的随机颜色
else:
cmap = plt.get_cmap("tab10") # 使用预定义的色图
cmap_idx = 0 if obj_id is None else obj_id # 如果 obj_id 为 None,使用第一个颜色,否则使用 obj_id 指定的颜色
color = np.array([*cmap(cmap_idx)[:3], 0.6]) # 从色图获取颜色,并设置透明度为0.6
h, w = mask.shape[-2:] # 获取掩模的高度和宽度
mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1) # 通过掩模与颜色相乘创建图像
ax.imshow(mask_image) # 在提供的轴上显示彩色掩模图像
def show_points(coords, labels, ax, marker_size=200):
# 根据标签将坐标分为正点和负点
pos_points = coords[labels == 1] # 正点
neg_points = coords[labels == 0] # 负点
# 用绿色星号标记绘制正点
ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
# 用红色星号标记绘制负点
ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
我们假设视频存储为一系列 JPEG 帧,文件名格式为
对于您的自定义视频,您可以使用 ffmpeg (https://ffmpeg.org/) 提取 JPEG 帧,命令如下:
ffmpeg -i
其中 -q:v 参数生成高质量的 JPEG 帧,-start_number 0 参数告诉 ffmpeg 从 00000.jpg 开始编号 JPEG 文件。
若使用项目示例视频,直接从下述代码段运行即可
# %cd /content/segment-anything-2
%cd ../
# 'video_dir' 是包含 JPEG 帧的目录,文件名格式为 '.jpg'
video_dir = "./notebooks/videos/bedroom"
# 扫描此目录中所有 JPEG 帧的文件名
frame_names = [
p for p in os.listdir(video_dir) # 列出目录中的所有文件
if os.path.splitext(p)[-1] in [".jpg", ".jpeg",