【图像分割】【深度学习】SAM官方Pytorch代码-各模块的功能解析

【图像分割】【深度学习】SAM官方Pytorch代码-各功能模块解析

Segment Anything:建立了迄今为止最大的分割数据集,在1100万张图像上有超过1亿个掩码,模型的设计和训练是灵活的,其重要的特点是Zero-shot(零样本迁移性)转移到新的图像分布和任务,一个图像分割新的任务、模型和数据集。SAM由三个部分组成:一个强大的图像编码器(Image encoder)计算图像嵌入,一个提示编码器(Prompt encoder)嵌入提示,然后将两个信息源组合在一个轻量级掩码解码器(Mask decoder)中来预测分割掩码。本博客将大致讲解SAM各模块的功能。

文章目录

  • 【图像分割】【深度学习】SAM官方Pytorch代码-各功能模块解析
  • 前言
  • 模型加载
  • SamPredictor类
    • __init__
    • reset_image
    • set_image
    • set_torch_image
    • predict
    • predict_torch
    • get_image_embedding
    • device
  • ResizeLongestSide类
    • __init__
    • apply_image
    • apply_coords
    • apply_boxes
    • get_preprocess_shape
  • 总结


前言

在详细解析SAM代码之前,首要任务是成功运行SAM代码【win10下参考教程】,后续学习才有意义。本博客将大致讲解各个子模块的功能代码,暂时不会详细讲解神经网络的代码部分。


模型加载

博主以【SAM官方代码示例】为例,源码提供了3种不同大小的模型。

# 选择合适的模型以及加载对应权重
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)

sam_model_registry函数在segment_anything/build_sam.py文件内定义
SAM的3种模型通过字典形式保存。

sam_model_registry = {
    "default": build_sam_vit_h,
    "vit_h": build_sam_vit_h,
    "vit_l": build_sam_vit_l,
    "vit_b": build_sam_vit_b,
}

sam_model_registry中的3种模型结构是一致的,部分参数不同导致模型的大小有别。

def build_sam_vit_h(checkpoint=None):
    return _build_sam(
        encoder_embed_dim=1280,
        encoder_depth=32,
        encoder_num_heads=16,
        encoder_global_attn_indexes=[7, 15, 23, 31],
        checkpoint=checkpoint,
    )

def build_sam_vit_l(checkpoint=None):
    return _build_sam(
        encoder_embed_dim=1024,
        encoder_depth=24,
        encoder_num_heads=16,
        encoder_global_attn_indexes=[5, 11, 17, 23],
        checkpoint=checkpoint,
    )

def build_sam_vit_b(checkpoint=None):
    return _build_sam(
        encoder_embed_dim=768,
        encoder_depth=12,
        encoder_num_heads=12,
        encoder_global_attn_indexes=[2, 5, 8, 11],
        checkpoint=checkpoint,
    )

最后是_build_sam方法,完成了sam模型的初始化以及权重的加载,这里可以注意到sam模型由三个神经网络模块组成:ImageEncoderViT(Image encoder)、PromptEncoder和MaskDecoder。具体的参数的作用和意义在后续的神经网络的具体的学习中讲解。

def _build_sam(
    encoder_embed_dim,
    encoder_depth,
    encoder_num_heads,
    encoder_global_attn_indexes,
    checkpoint=None,
):
    prompt_embed_dim = 256
    image_size = 1024
    vit_patch_size = 16
    image_embedding_size = image_size // vit_patch_size
    sam = Sam(
        image_encoder=ImageEncoderViT(
            depth=encoder_depth,
            embed_dim=encoder_embed_dim,
            img_size=image_size,
            mlp_ratio=4,
            norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
            num_heads=encoder_num_heads,
            patch_size=vit_patch_size,
            qkv_bias=True,
            use_rel_pos=True,
            global_attn_indexes=encoder_global_attn_indexes,
            window_size=14,
            out_chans=prompt_embed_dim,
        ),
        prompt_encoder=PromptEncoder(
            embed_dim=prompt_embed_dim,
            image_embedding_size=(image_embedding_size, image_embedding_size),
            input_image_size=(image_size, image_size),
            mask_in_chans=16,
        ),
        mask_decoder=MaskDecoder(
            num_multimask_outputs=3,
            transformer=TwoWayTransformer(
                depth=2,
                embedding_dim=prompt_embed_dim,
                mlp_dim=2048,
                num_heads=8,
            ),
            transformer_dim=prompt_embed_dim,
            iou_head_depth=3,
            iou_head_hidden_dim=256,
        ),
        pixel_mean=[123.675, 116.28, 103.53],
        pixel_std=[58.395, 57.12, 57.375],
    )
    sam.eval()
    if checkpoint is not None:
        with open(checkpoint, "rb") as f:
            state_dict = torch.load(f)
        sam.load_state_dict(state_dict)
    return sam

论文中SAM的结构示意图:
【图像分割】【深度学习】SAM官方Pytorch代码-各模块的功能解析_第1张图片

SamPredictor类

sam模型被封装在SamPredictor类的对象中,方便使用。

predictor = SamPredictor(sam)
predictor.set_image(image)

image_encoder操作在set_image时就已经执行了,而不是在predic时

SamPredictor类在segment_anything/predictor.py文件:

init

初始化了mask预测模型sam,以及数据处理工具对象,重置了图片相关数据信息(ResizeLongestSide)。

    def __init__(
        self,
        sam_model: Sam,
    ) -> None:
        super().__init__()
        # sam mask预测模型
        self.model = sam_model
        # 用于数据预处理
        self.transform = ResizeLongestSide(sam_model.image_encoder.img_size)
        # 图片相关数据信息
        self.reset_image()

reset_image

self.is_image_set与 self.features息息相关,self.features保存图片经过Image encoder后的特征数据,self.is_image_set是一个信号信息,用来表示self.features是否已经保存了特征数据,在刚初始化时,self.features是none,self.is_image_set便是false。

def reset_image(self) -> None:
    # 图像设置flag
    self.is_image_set = False
    # 图像编码特征
    self.features = None
    self.orig_h = None
    self.orig_w = None
    self.input_h = None
    self.input_w = None

set_image

首先确认输入是否是RGB或BGR三通道图像,将BGR图像统一为RGB,而后并对图像尺寸(apply_image)和channel顺序作出调整满足神经网络的输入要求。

def set_image(
    self,
    image: np.ndarray,
    image_format: str = "RGB",
) -> None:
    # 图像不是['RGB', 'BGR']格式则报错
    assert image_format in [
        "RGB",
        "BGR",
    ], f"image_format must be in ['RGB', 'BGR'], is {image_format}."
    # H,W,C
    if image_format != self.model.image_format:
        image = image[..., ::-1]            # H,W,C中 C通道的逆序RGB-->BGR

    # Transform the image to the form expected by the model 改变图像尺寸
    input_image = self.transform.apply_image(image)
    # torch 浅拷贝 转tensor
    input_image_torch = torch.as_tensor(input_image, device=self.device)
    # permute H,W,C-->C,H,W
    # contiguous 连续内存
    # [None, :, :, :] C,H,W -->1,C,H,W
    input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :]
    self.set_torch_image(input_image_torch, image.shape[:2])

set_torch_image

用padding填补缩放后的图片,在H和W满足神经网络需要的标准尺寸,而后通过image_encoder模型获得图像特征数据并保存在self.features中,同时self.is_image_set设为true。

注意image_encoder过程不是在predict_torch时与Prompt encoder过程和Mask decoder过程一同执行的,而是在set_image时就已经执行了。
【图像分割】【深度学习】SAM官方Pytorch代码-各模块的功能解析_第2张图片

    def set_torch_image(
        self,
        transformed_image: torch.Tensor,
        original_image_size: Tuple[int, ...],
    ) -> None:
        # 满足输入是四个维度且为B,C,H,W
        assert (
            len(transformed_image.shape) == 4
            and transformed_image.shape[1] == 3
            and max(*transformed_image.shape[2:]) == self.model.image_encoder.img_size
        ), f"set_torch_image input must be BCHW with long side {self.model.image_encoder.img_size}."

        self.reset_image()
        # 原始图像的尺寸
        self.original_size = original_image_size
        # torch图像的尺寸
        self.input_size = tuple(transformed_image.shape[-2:])
        # torch图像进行padding
        input_image = self.model.preprocess(transformed_image)
        # image_encoder网络模块对图像进行编码
        self.features = self.model.image_encoder(input_image)
        # 图像设置flag
        self.is_image_set = True

这里可以暂时不考虑image_encoder模型的代码细节。

predict

predict对输入到模型中进行预测的数据(标记点apply_coords和标记框apply_boxes)进行一个预处理,并接受和处理模型返回的预测结果。

def predict(
    self,
    # 标记点的坐标
    point_coords: Optional[np.ndarray] = None,
    # 标记点的标签
    point_labels: Optional[np.ndarray] = None,
    # 标记框的坐标
    box: Optional[np.ndarray] = None,
    # 输入的mask
    mask_input: Optional[np.ndarray] = None,
    # 输出多个mask供选择
    multimask_output: bool = True,
    # ture 返回掩码logits, false返回阈值处理的二进制掩码。
    return_logits: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    # 假设没有设置图像,报错
    if not self.is_image_set:
        raise RuntimeError("An image must be set with .set_image(...) before mask prediction.")

    # Transform input prompts 
    # 输入提示转换为torch
    coords_torch, labels_torch, box_torch, mask_input_torch = None, None, None, None

    if point_coords is not None:
        # 标记点坐标对应的标记点标签不能为空
        assert (
            point_labels is not None
        ), "point_labels must be supplied if point_coords is supplied."
        # 图像改变了原始尺寸,所以对应的点位置也会发生改变
        point_coords = self.transform.apply_coords(point_coords, self.original_size)
        # 标记点坐标和标记点标签 np-->tensor
        coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=self.device)
        labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=self.device)
        # 增加维度
        # coords_torch:N,2-->1,N,2
        # labels_torch: N-->1,N
        coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :]
    if box is not None:
        # 图像改变了原始尺寸,所以对应的框坐标位置也会发生改变
        box = self.transform.apply_boxes(box, self.original_size)
        # 标记框坐标 np-->tensor
        box_torch = torch.as_tensor(box, dtype=torch.float, device=self.device)
        # 增加维度 N,4-->1,N,4
        box_torch = box_torch[None, :]
    if mask_input is not None:
        # mask np-->tensor
        mask_input_torch = torch.as_tensor(mask_input, dtype=torch.float, device=self.device)
        # 增加维度 1,H,W-->B,1,H,W
        mask_input_torch = mask_input_torch[None, :, :, :]
    # 输入数据预处理完毕,可以输入到网络中 
    masks, iou_predictions, low_res_masks = self.predict_torch(
        coords_torch,
        labels_torch,
        box_torch,
        mask_input_torch,
        multimask_output,
        return_logits=return_logits,
    )
    # 因为batchsize为1,压缩维度
    # mask
    masks = masks[0].detach().cpu().numpy()
    # score
    iou_predictions = iou_predictions[0].detach().cpu().numpy()
    low_res_masks = low_res_masks[0].detach().cpu().numpy()
    return masks, iou_predictions, low_res_masks

源码在segment_anything/modeling/sam.py内

    def postprocess_masks(
        self,
        masks: torch.Tensor,
        input_size: Tuple[int, ...],
        original_size: Tuple[int, ...],
    ) -> torch.Tensor:
        # mask上采样到与输入到模型中的图片尺寸一致
        masks = F.interpolate(
            masks,
            (self.image_encoder.img_size, self.image_encoder.img_size),
            mode="bilinear",
            align_corners=False,
        )
        masks = masks[..., : input_size[0], : input_size[1]]
        # mask resize 到与未做处理的原始图片尺寸一致
        masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False)
        return masks

predict_torch

输入数据经过预处理后输入到模型中预测结果。

Prompt encoder过程和Mask decoder过程是在predict_torch时执行的。
【图像分割】【深度学习】SAM官方Pytorch代码-各模块的功能解析_第3张图片

def predict_torch(
    self,
    point_coords: Optional[torch.Tensor],
    point_labels: Optional[torch.Tensor],
    boxes: Optional[torch.Tensor] = None,
    mask_input: Optional[torch.Tensor] = None,
    multimask_output: bool = True,
    return_logits: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    # 假设没有设置图像,报错
    if not self.is_image_set:
        raise RuntimeError("An image must be set with .set_image(...) before mask prediction.")
    # 绑定标记点和标记点标签
    if point_coords is not None:
        points = (point_coords, point_labels)
    else:
        points = None

    # ----- EPrompt encoder -----
    sparse_embeddings, dense_embeddings = self.model.prompt_encoder(
        points=points,
        boxes=boxes,
        masks=mask_input,
    )
    # ----- Prompt encoder -----

    # ----- Mask decoder -----
    low_res_masks, iou_predictions = self.model.mask_decoder(
        image_embeddings=self.features,
        image_pe=self.model.prompt_encoder.get_dense_pe(),
        sparse_prompt_embeddings=sparse_embeddings,
        dense_prompt_embeddings=dense_embeddings,
        multimask_output=multimask_output,
    )
    #  ----- Mask decoder -----

    # 上采样mask掩膜到原始图片尺寸
    # Upscale the masks to the original image resolution
    masks = self.model.postprocess_masks(low_res_masks, self.input_size, self.original_size)

    if not return_logits:
        masks = masks > self.model.mask_threshold
    return masks, iou_predictions, low_res_masks

这里可以暂时不考虑Prompt encoder和Mask decoder模型的代码细节。

get_image_embedding

获得图像image_encoder的特征。

    def get_image_embedding(self) -> torch.Tensor:
        if not self.is_image_set:
            raise RuntimeError(
                "An image must be set with .set_image(...) to generate an embedding."
            )
        assert self.features is not None, "Features must exist if an image has been set."
        return self.features

device

获得模型所使用的设备

def device(self) -> torch.device:
    return self.model.device

ResizeLongestSide类


ResizeLongestSide是专门用来处理图片、标记点和标记框的工具类。
ResizeLongestSide类在segment_anything/utils/transforms.py文件:

init

设置了所有输入到神经网络的标准图片尺寸

def __init__(self, target_length: int) -> None:
    self.target_length = target_length

apply_image


原图尺寸根据标准尺寸计算调整(get_preprocess_shape)得新尺寸。

def apply_image(self, image: np.ndarray) -> np.ndarray:
    target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length)
    # to_pil_image将numpy装变为PIL.Image,而后resize
    return np.array(resize(to_pil_image(image), target_size))

一个简单的示意图,通过计算获得与标准尺寸对应的缩放比例并缩放图片,后续通过padding补零操作(虚线部分),将所有图片的尺寸都变成标准尺寸。
【图像分割】【深度学习】SAM官方Pytorch代码-各模块的功能解析_第4张图片

不直接使用resize的目的是为了不破坏原图片中各个物体的比例关系。

apply_coords

图像改变了原始尺寸,对应的标记点坐标位置也要改变([get_preprocess_shape](#get_preprocess_shape))。

def apply_coords(self, coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray:
    old_h, old_w = original_size
    # 图像改变了原始尺寸,所以对应的标记点坐标位置也会发生改变
    new_h, new_w = self.get_preprocess_shape(
        original_size[0], original_size[1], self.target_length
    )
    # 深拷贝coords
    coords = deepcopy(coords).astype(float)
    # 改变对应标记点坐标
    coords[..., 0] = coords[..., 0] * (new_w / old_w)
    coords[..., 1] = coords[..., 1] * (new_h / old_h)
    return coords

apply_boxes

图像改变了原始尺寸,对应的标记框坐标位置也要改变([get_preprocess_shape](#get_preprocess_shape))。

def apply_boxes(self, boxes: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray:
    # 图像改变了原始尺寸,所以对应的框坐标位置也会发生改变
    # reshape: N,4-->N,2,2
    boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size)
    # reshape: N,2,2-->N,4
    return boxes.reshape(-1, 4)

get_preprocess_shape

    def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]:
        # H和W的长边(大值)作为基准,计算比例,缩放H W的大小
        scale = long_side_length * 1.0 / max(oldh, oldw)
        newh, neww = oldh * scale, oldw * scale
        # 四舍五入
        neww = int(neww + 0.5)
        newh = int(newh + 0.5)
        return (newh, neww)

总结

尽可能简单、详细的介绍SAM中各个子模块的功能代码,后续会讲解SAM中三个深度学习网络模块的代码。

强调一点,在预测过程中sam模型是被封装在SamPredictor类中,将sam的forward预测的流程分别拆解到SamPredictor类的不同方法中、分不同阶段进行。
sam中forward函数对Image encoder、Prompt encoder和Mask decoder三个操作是连续的,如下图所示:
【图像分割】【深度学习】SAM官方Pytorch代码-各模块的功能解析_第5张图片
源码暂未开源这部分,因此个人自觉forward只是训练过程中使用的,预测过程并未涉及,希望大家不要被搞晕,最后有大佬自己写train部分的代码话可以踢我一下。

你可能感兴趣的:(图像分割,深度学习,深度学习,pytorch,计算机视觉)