【技术追踪】SAM(Segment Anything Model)代码解析与结构绘制之Prompt Encoder

  论文:Segment Anything
  代码:https://github.com/facebookresearch/segment-anything

  上一篇:【技术追踪】SAM(Segment Anything Model)代码解析与结构绘制之Image Encoder

  本篇示例依然采用上一篇的狗狗图像运行代码,预测部分代码如下:

# Prompts for the dog image, given in original-image pixel coordinates.
# A single click; its label says whether it is foreground (1) or background (0).
input_point = np.array([[1300, 800]])
input_label = np.array([1])
# Box as [x_min, y_min, x_max, y_max]:
# top-left corner (700, 400), bottom-right corner (1900, 1100).
input_box = np.array([[700, 400, 1900, 1100]])

# Combine both prompt types and ask for multiple candidate masks.
masks, scores, logits = predictor.predict(
    point_coords=input_point,
    point_labels=input_label,
    box=input_box,
    multimask_output=True,
)

1. Mask预测过程

(1)predict函数

位置:【segment_anything/predictor.py --> SamPredictor类 -->predict函数】
作用: 使用给定的prompt,调用predict_torch,预测mask与iou

def predict(
    self,
    point_coords: Optional[np.ndarray] = None,
    point_labels: Optional[np.ndarray] = None,
    box: Optional[np.ndarray] = None,
    mask_input: Optional[np.ndarray] = None,
    multimask_output: bool = True,
    return_logits: bool = False,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Predict masks for the current image from NumPy prompts.

    Point and box prompts are given in original-image pixel coordinates and
    are mapped into the resized model-input frame with ``self.transform``.
    Every prompt is converted to a tensor on ``self.device`` with a leading
    batch dimension of 1 and handed to ``self.predict_torch``. The first (and
    only) batch item of each output is returned as a NumPy array.

    Args:
        point_coords: (x, y) points, shape [N, 2], or None.
        point_labels: one label per point, shape [N] (1 = foreground,
            0 = background). Required whenever ``point_coords`` is given.
        box: (x1, y1, x2, y2) box, e.g. shape [1, 4] as in the article, or None.
        mask_input: 3-D mask prompt; a batch dimension is prepended, or None.
        multimask_output: forwarded to ``predict_torch``.
        return_logits: forwarded to ``predict_torch``.

    Returns:
        ``(masks, iou_predictions, low_res_masks)`` for batch item 0.

    Raises:
        RuntimeError: if no image has been set with ``set_image``.
        AssertionError: if ``point_coords`` is given without ``point_labels``.
    """
    if not self.is_image_set:
        raise RuntimeError("An image must be set with .set_image(...) before mask prediction.")

    # Transform input prompts
    coords_torch, labels_torch, box_torch, mask_input_torch = None, None, None, None

    # Point prompt.
    if point_coords is not None:
        assert (
            point_labels is not None
        ), "point_labels must be supplied if point_coords is supplied."
        # The image is resized for the model, so the points must be rescaled too.
        # Article example: original size (H, W) = (1365, 2048);
        # (1300, 800) -> (650, 400.29).
        point_coords = self.transform.apply_coords(point_coords, self.original_size)
        coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=self.device)
        labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=self.device)
        # Add the batch dimension: [N, 2] -> [1, N, 2] and [N] -> [1, N].
        coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :]

    # Box prompt.
    if box is not None:
        # Same rescaling as the points.
        # Article example: (700, 400, 1900, 1100) -> (350, 200.1465, 950, 550.4029).
        box = self.transform.apply_boxes(box, self.original_size)
        box_torch = torch.as_tensor(box, dtype=torch.float, device=self.device)  # [1, 4]
        box_torch = box_torch[None, :]  # add batch dimension -> [1, 1, 4]

    # Mask prompt.
    if mask_input is not None:
        mask_input_torch = torch.as_tensor(mask_input, dtype=torch.float, device=self.device)
        mask_input_torch = mask_input_torch[None, :, :, :]  # add batch dimension

    # Article example (multimask_output=True): masks [1, 3, 1365, 2048],
    # iou_predictions [1, 3], low_res_masks [1, 3, 256, 256].
    masks, iou_predictions, low_res_masks = self.predict_torch(
        coords_torch,
        labels_torch,
        box_torch,
        mask_input_torch,
        multimask_output,
        return_logits=return_logits,
    )

    # Drop the batch dimension and move results to NumPy.
    masks_np = masks[0].detach().cpu().numpy()
    iou_predictions_np = iou_predictions[0].detach().cpu().numpy()
    low_res_masks_np = low_res_masks[0].detach().cpu().numpy()
    return masks_np, iou_predictions_np, low_res_masks_np

   apply_coords函数: 对输入point进行坐标变换,将原图像 $[H, W]$ 上给定的坐标位置 $[x, y]$,映射到缩放后图像 $[H \times 1024 / W,\ 1024]$ 上的位置 $[X, Y]$(此处长边为 $W$,如示例中 $1365 \times 2048$ 的图像)。

  def apply_coords(self, coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray:
      """Map (x, y) coordinates from the original image onto the resized image.

      Args:
          coords: array whose last axis holds (x, y) pixel coordinates in the
              original image, e.g. shape [N, 2] (or [K, 2, 2] for box corners).
          original_size: original image size as (H, W).

      Returns:
          A new float array of the same shape in the resized frame. x is
          scaled by new_w / W and y by new_h / H. (new_h, new_w) comes from
          ``self.get_preprocess_shape(H, W, self.target_length)``; e.g.
          [H * 1024 / W, 1024] when W is the long side. The input array is
          not modified.
      """
      old_h, old_w = original_size
      new_h, new_w = self.get_preprocess_shape(
          original_size[0], original_size[1], self.target_length
      )
      # astype() returns a new array by default, so an extra deepcopy is unnecessary
      # and the in-place scaling below never touches the caller's array.
      coords = coords.astype(float)
      coords[..., 0] = coords[..., 0] * (new_w / old_w)
      coords[..., 1] = coords[..., 1] * (new_h / old_h)
      return coords

   apply_boxes函数: 调用 apply_coords函数进行box的坐标变换

def apply_boxes(self, boxes: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray:
    """Map (x1, y1, x2, y2) boxes onto the resized image.

    Each box is split into its two (x, y) corner points, the points are mapped
    with ``self.apply_coords``, and the result is flattened back to rows of
    four values.
    """
    corners = boxes.reshape(-1, 2, 2)  # one (2 corners x 2 coords) slab per box
    mapped_corners = self.apply_coords(corners, original_size)
    return mapped_corners.reshape(-1, 4)

(2)predict_torch函数

位置:【segment_anything/predictor.py --> SamPredictor类 -->predict_torch函数】
作用: 调用prompt_encoder实现prompt嵌入编码,调用mask_decoder实现mask预测

def predict_torch(
    self,
    point_coords: Optional[torch.Tensor],
    point_labels: Optional[torch.Tensor],
    boxes: Optional[torch.Tensor] = None,
    mask_input: Optional[torch.Tensor] = None,
    multimask_output: bool = True,
    return_logits: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Encode already-transformed tensor prompts and decode masks.

    The prompt encoder turns the prompts into sparse and dense embeddings. The
    mask decoder combines them with the cached image features
    (``self.features``) and the prompt encoder's dense positional encoding.
    The low-resolution masks are then upscaled to the original image size.

    Args:
        point_coords: point coordinates in the resized-input frame (as built
            by ``predict``, shape [1, N, 2]), or None.
        point_labels: labels matching ``point_coords``; only used when
            ``point_coords`` is given.
        boxes: box tensor (as built by ``predict``, shape [1, 1, 4]), or None.
        mask_input: mask prompt tensor, or None.
        multimask_output: forwarded to the mask decoder.
        return_logits: if False, masks are thresholded at
            ``self.model.mask_threshold`` into a boolean tensor.

    Returns:
        ``(masks, iou_predictions, low_res_masks)``.

    Raises:
        RuntimeError: if no image has been set with ``set_image``.
    """
    if not self.is_image_set:
        raise RuntimeError("An image must be set with .set_image(...) before mask prediction.")

    points = (point_coords, point_labels) if point_coords is not None else None

    # Embed prompts. Article example (1 point + 1 box, so no padding point):
    # sparse_embeddings [1, 3, 256] (1 point token + 2 box-corner tokens),
    # dense_embeddings [1, 256, 64, 64].
    sparse_embeddings, dense_embeddings = self.model.prompt_encoder(
        points=points,
        boxes=boxes,
        masks=mask_input,
    )

    # Predict masks.
    low_res_masks, iou_predictions = self.model.mask_decoder(
        image_embeddings=self.features,
        image_pe=self.model.prompt_encoder.get_dense_pe(),
        sparse_prompt_embeddings=sparse_embeddings,
        dense_prompt_embeddings=dense_embeddings,
        multimask_output=multimask_output,
    )

    # Upscale the masks to the original image resolution.
    masks = self.model.postprocess_masks(low_res_masks, self.input_size, self.original_size)

    if not return_logits:
        masks = masks > self.model.mask_threshold

    return masks, iou_predictions, low_res_masks

2. Prompt Encoder代码解析

(1)PromptEncoder类

位置:【segment_anything/modeling/prompt_encoder.py -->PromptEncoder类】
作用: 实现prompt输入嵌入编码

  先看PromptEncoder的 `__init__` 初始化函数和 `forward` 函数:

class PromptEncoder(nn.Module):
    """Encode point, box and mask prompts into SAM prompt embeddings.

    ``forward`` returns two tensors:
      * sparse embeddings, shape [B, N, embed_dim]: one token per point
        prompt (plus one padding token when points are given without a box)
        and two tokens (top-left, bottom-right corner) per box;
      * dense embeddings, shape [B, embed_dim, H_emb, W_emb]: the encoded
        mask prompt, or the learned ``no_mask_embed`` vector broadcast over
        the image-embedding grid when no mask is given.

    ``_get_batch_size`` and ``_get_device`` are not shown in this excerpt.
    """

    def __init__(
        self,
        embed_dim: int,
        image_embedding_size: Tuple[int, int],
        input_image_size: Tuple[int, int],
        mask_in_chans: int,
        activation: Type[nn.Module] = nn.GELU,
    ) -> None:
        """
        Args:
            embed_dim: prompt embedding width (256 in the article).
            image_embedding_size: (H, W) of the image-encoder output grid
                ([64, 64] in the article; the image encoder outputs [1, 256, 64, 64]).
            input_image_size: (H, W) of the model input image ([1024, 1024] in the article).
            mask_in_chans: hidden channel count of the mask down-scaling
                network (16 in the printed structure).
            activation: activation class used in the mask down-scaling network.
        """
        super().__init__()
        self.embed_dim = embed_dim
        self.input_image_size = input_image_size
        self.image_embedding_size = image_embedding_size
        # Random-Fourier positional encoding; its output width is
        # 2 * (embed_dim // 2), i.e. embed_dim for an even embed_dim.
        self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)

        # Learned point-type vectors: pos/neg point + 2 box corners.
        self.num_point_embeddings: int = 4
        point_embeddings = [nn.Embedding(1, embed_dim) for _ in range(self.num_point_embeddings)]
        self.point_embeddings = nn.ModuleList(point_embeddings)
        # Learned vector for padding ("not a point") tokens.
        self.not_a_point_embed = nn.Embedding(1, embed_dim)

        # Expected mask-prompt resolution: 4x the embedding grid ((256, 256) in the article).
        self.mask_input_size = (4 * image_embedding_size[0], 4 * image_embedding_size[1])
        # Two stride-2 convs downsample 4x in total (e.g. 256 -> 64); the final
        # 1x1 conv maps to embed_dim channels.
        self.mask_downscaling = nn.Sequential(
            nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2),
            LayerNorm2d(mask_in_chans // 4),
            activation(),
            nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2),
            LayerNorm2d(mask_in_chans),
            activation(),
            nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1),
        )
        # Learned vector used as the dense embedding when no mask prompt is given.
        self.no_mask_embed = nn.Embedding(1, embed_dim)

    def forward(
        self,
        points: Optional[Tuple[torch.Tensor, torch.Tensor]],
        boxes: Optional[torch.Tensor],
        masks: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Embed the given prompts.

        Args:
            points: ``(coords, labels)`` or None.
            boxes: box tensor or None.
            masks: mask prompt tensor or None.

        Returns:
            ``(sparse_embeddings, dense_embeddings)``.
        """
        bs = self._get_batch_size(points, boxes, masks)
        # Start from an empty [bs, 0, embed_dim] tensor so point and box tokens
        # can simply be concatenated along dim 1.
        sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device())

        # ------------ sparse embeddings ------------
        if points is not None:
            coords, labels = points
            # A padding token is appended only when no box is given.
            point_embeddings = self._embed_points(coords, labels, pad=(boxes is None))
            sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)
        if boxes is not None:
            box_embeddings = self._embed_boxes(boxes)
            sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1)
        # Article example (1 point + 1 box): point tokens [1, 1, 256]
        # + box-corner tokens [1, 2, 256] -> sparse_embeddings [1, 3, 256].

        # ------------ dense embeddings ------------
        if masks is not None:
            dense_embeddings = self._embed_masks(masks)
        else:
            # No mask: broadcast the learned vector over the embedding grid with
            # expand (a view, no copy): [1, C] -> [1, C, 1, 1] -> [bs, C, H_emb, W_emb].
            dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
                bs, -1, self.image_embedding_size[0], self.image_embedding_size[1]
            )

        return sparse_embeddings, dense_embeddings

  传送门:torch.nn.Embedding函数用法图解

   `forward` 的过程中主要完成了sparse_embeddings(由point和box嵌入向量组成)和dense_embeddings(由mask嵌入向量组成;无mask输入时由可学习的 no_mask_embed 向量扩展而成)两种向量嵌入。

  ① _embed_points函数:输入的坐标点 $[x, y] = (1300, 800)$ 经过映射变换后为 $[X, Y] = (650, 400.29)$,由 `self._embed_points` 函数完成嵌入:

def _embed_points(
    self,
    points: torch.Tensor,  # [[[650, 400.29]]]
    labels: torch.Tensor,  # [[1]]
    pad: bool,  # false
) -> torch.Tensor:
    
    points = points + 0.5  # Shift to center of pixel 移到像素中心=(650.5, 400.79)
    
    # 当没有box输入时, pad=ture
    if pad:
        padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device)  # size():[1,1,2]
        padding_label = -torch.ones((labels.shape[0], 1), device=labels.device)  # 是负数,size():[1,1]
        points = torch.cat([points, padding_point], dim=1)  # [1, 2, 2]
        labels = torch.cat([labels, padding_label], dim=1)  # [1, 2]
	
	# self.pe_layer = PositionEmbeddingRandom(embed_dim // 2) = PositionEmbeddingRandom(128)
    point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size)  # 点嵌入[1,2,256]
    # -------------------------------------------------------------------------------------
    # self.point_embeddings中预设四个点的可学习嵌入向量,分别为前景点,背景点,box的左上角和右下角坐标点
    # -------------------------------------------------------------------------------------
    # 当labels=-1, 输入点是非标记点, 设为非标记点, 加上非标记点权重
    point_embedding[labels == -1] = 0.0
    point_embedding[labels == -1] += self.not_a_point_embed.weight
    # 当labels=0, 输入点是背景点, 加上背景点权重
    point_embedding[labels == 0] += self.point_embeddings[0].weight
    # 当labels=1, 输入点是目标点, 加上目标点权重
    point_embedding[labels == 1] += self.point_embeddings[1].weight
    return point_embedding

  ② _embed_boxes函数:box的左上角与右下角点 $(700, 400, 1900, 1100)$ 经过映射变换后为 $(350, 200.1465, 950, 550.4029)$,由 `self._embed_boxes` 函数完成嵌入:

def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
    
    # (350, 200.1465, 950, 500.4029)->(350.5000, 200.6465, 950.5000, 550.9030)
    boxes = boxes + 0.5  # Shift to center of pixel  size()=[1,1,4]
    coords = boxes.reshape(-1, 2, 2)  # [1,1,4]->[1,2,2]
    corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size)  # [1,2,256]
    # 目标框起始点的和末位点分别加上权重
    corner_embedding[:, 0, :] += self.point_embeddings[2].weight  # 左上角点
    corner_embedding[:, 1, :] += self.point_embeddings[3].weight  # 右下角点
    return corner_embedding

  ③ _embed_masks函数:若有mask输入,由 `self._embed_masks` 函数完成嵌入:

def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor:
   
    mask_embedding = self.mask_downscaling(masks)
    return mask_embedding

  self.mask_downscaling结构:

(mask_downscaling): Sequential(
    (0): Conv2d(1, 4, kernel_size=(2, 2), stride=(2, 2))
    (1): LayerNorm2d()
    (2): GELU(approximate='none')
    (3): Conv2d(4, 16, kernel_size=(2, 2), stride=(2, 2))
    (4): LayerNorm2d()
    (5): GELU(approximate='none')
    (6): Conv2d(16, 256, kernel_size=(1, 1), stride=(1, 1))
  )

  结束了么,家人们!是不是在疑惑,还有最后一步了(ง •_•)ง,在 _embed_points 函数和 _embed_boxes 函数中均调用了随机位置嵌入PositionEmbeddingRandom类,以进行point的位置编码。可以理解为,每一个point的向量嵌入都由point的位置编码和可学习nn.Embedding预设权重相加组成(padding点例外:其位置编码被置零,只保留 not_a_point_embed 权重)。

(2)PositionEmbeddingRandom类

位置:【segment_anything/modeling/prompt_encoder.py -->PositionEmbeddingRandom类】
作用: 调用forward_with_coords将point归一化到[0,1],调用_pe_encoding完成位置编码

class PositionEmbeddingRandom(nn.Module):
    """Positional encoding using random spatial frequencies.

    A fixed Gaussian matrix of shape [2, num_pos_feats] projects 2-D
    coordinates. It is registered as a buffer, not a trainable parameter. The
    encoding is ``[sin(2*pi*p), cos(2*pi*p)]``, so each point gets
    ``2 * num_pos_feats`` channels.
    """

    def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
        super().__init__()
        # A missing or non-positive scale falls back to 1.0.
        if scale is None or scale <= 0.0:
            scale = 1.0
        self.register_buffer(
            "positional_encoding_gaussian_matrix",
            scale * torch.randn((2, num_pos_feats)),  # entries drawn from N(0, scale^2)
        )

    def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
        """Positionally encode points that are normalized to [0,1]."""
        # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
        # Map to [-1, 1] before projecting, e.g. (0.6353, 0.3914) -> (0.2705, -0.2172).
        coords = 2 * coords - 1
        coords = coords @ self.positional_encoding_gaussian_matrix  # [..., num_pos_feats]
        coords = 2 * np.pi * coords
        # outputs d_1 x ... x d_n x C shape, C = 2 * num_pos_feats
        return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)

    def forward(self, size: Tuple[int, int]) -> torch.Tensor:
        """Generate positional encoding for a grid of the specified size.

        Returns a [C, h, w] tensor encoding the centre of every grid cell, with
        x normalized by w and y by h.
        """
        h, w = size
        device: Any = self.positional_encoding_gaussian_matrix.device
        grid = torch.ones((h, w), device=device, dtype=torch.float32)
        # Cumulative sums give row indices 1..h and column indices 1..w;
        # subtracting 0.5 moves them to cell centres.
        y_embed = grid.cumsum(dim=0) - 0.5
        x_embed = grid.cumsum(dim=1) - 0.5
        y_embed = y_embed / h
        x_embed = x_embed / w
        pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1))  # [h, w, C]
        return pe.permute(2, 0, 1)  # C x H x W

    def forward_with_coords(
        self, coords_input: torch.Tensor, image_size: Tuple[int, int]
    ) -> torch.Tensor:
        """Positionally encode points that are not normalized to [0,1].

        Args:
            coords_input: pixel coordinates with (x, y) in the last dimension,
                shape [..., 2] (B x N x 2 when called from the prompt encoder).
                Not modified.
            image_size: (H, W); x is divided by W and y by H.

        Returns:
            Encoding of shape [..., 2 * num_pos_feats].
        """
        coords = coords_input.clone()
        # Integer input would be truncated by the in-place division below, so
        # promote it to float first; floating inputs keep their dtype here.
        if not coords.is_floating_point():
            coords = coords.to(torch.float)
        # Index the last axis with `...` so any number of leading dims is accepted.
        # Article example: (650.5, 400.79) / 1024 -> (0.6353, 0.3914).
        coords[..., 0] = coords[..., 0] / image_size[1]
        coords[..., 1] = coords[..., 1] / image_size[0]
        return self._pe_encoding(coords.to(torch.float))  # B x N x C

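  需要注意的是,PositionEmbeddingRandom类自身的forward并非没有用到:predict_torch中调用的 `self.model.prompt_encoder.get_dense_pe()` 正是通过 `self.pe_layer(self.image_embedding_size)` 调用该forward。它为 $64 \times 64$ 的图像嵌入网格生成 $[256, 64, 64]$ 的位置编码,再扩展为 $[1, 256, 64, 64]$,作为 image_pe 送入mask_decoder。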

3. Prompt Encoder结构绘制

(1)结构打印

PromptEncoder(
  (pe_layer): PositionEmbeddingRandom()
  (point_embeddings): ModuleList(
    (0-3): 4 x Embedding(1, 256)
  )
  (not_a_point_embed): Embedding(1, 256)
  (mask_downscaling): Sequential(
    (0): Conv2d(1, 4, kernel_size=(2, 2), stride=(2, 2))
    (1): LayerNorm2d()
    (2): GELU(approximate='none')
    (3): Conv2d(4, 16, kernel_size=(2, 2), stride=(2, 2))
    (4): LayerNorm2d()
    (5): GELU(approximate='none')
    (6): Conv2d(16, 256, kernel_size=(1, 1), stride=(1, 1))
  )
  (no_mask_embed): Embedding(1, 256)
)

(2)结构绘制

【技术追踪】SAM(Segment Anything Model)代码解析与结构绘制之Prompt Encoder_第1张图片

你可能感兴趣的:(prompt,深度学习,人工智能,SAM,大模型)