A Walkthrough of the SAM Source Code
Importing the dependencies
import torch
from torch import nn
from torch.nn import functional as F
from typing import Any, Dict, List, Tuple
from .image_encoder import ImageEncoderViT
from .mask_decoder import MaskDecoder
from .prompt_encoder import PromptEncoder
The two helpers below are methods of the Sam class shown further down; they are listed first because they define how images enter and leave the model.
def postprocess_masks(
    self,
    masks: torch.Tensor,
    input_size: Tuple[int, ...],
    original_size: Tuple[int, ...],
) -> torch.Tensor:
    """
    Remove padding and upscale masks to the original image size.

    Arguments:
      masks (torch.Tensor): Batched masks from the mask decoder, in BxCxHxW format.
      input_size (tuple(int, int)): The size of the image input to the model, in (H, W) format. Used to remove padding.
      original_size (tuple(int, int)): The original size of the image before resizing for input to the model, in (H, W) format.

    Returns:
      (torch.Tensor): Batched masks in BxCxHxW format, where (H, W) is given by original_size.
    """
    # Upscale the low-res logits to the padded square resolution seen by the encoder.
    masks = F.interpolate(
        masks,
        (self.image_encoder.img_size, self.image_encoder.img_size),
        mode="bilinear",
        align_corners=False,
    )
    # Crop away the right/bottom padding added in preprocess().
    masks = masks[..., : input_size[0], : input_size[1]]
    # Resize back to the original image size.
    masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False)
    return masks
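A minimal standalone sketch of the three steps above, with illustrative sizes that are my own assumptions rather than values from the article: the 256x256 low-res logits are upscaled to the padded encoder resolution (1024 for the released SAM encoders), the padding is cropped away, and the result is resized to the original image size.

import torch
import torch.nn.functional as F

img_size = 1024                         # assumed encoder input resolution
low_res = torch.randn(1, 3, 256, 256)   # dummy BxCxHxW logits from the mask decoder
input_size = (1024, 768)      # image size after resizing the longest side, before padding
original_size = (1500, 1125)  # image size before any resizing

# Step 1: upscale to the padded square resolution the encoder saw.
masks = F.interpolate(low_res, (img_size, img_size), mode="bilinear", align_corners=False)
# Step 2: crop away the bottom/right padding added by preprocess().
masks = masks[..., : input_size[0], : input_size[1]]
# Step 3: resize to the original image size.
masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False)
print(masks.shape)  # torch.Size([1, 3, 1500, 1125])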
def preprocess(self, x: torch.Tensor) -> torch.Tensor:
    """Normalize pixel values and pad to a square input."""
    # Per-channel normalization with the pixel_mean / pixel_std buffers.
    x = (x - self.pixel_mean) / self.pixel_std

    # Pad the right and bottom edges up to the encoder's square input size.
    h, w = x.shape[-2:]
    padh = self.image_encoder.img_size - h
    padw = self.image_encoder.img_size - w
    x = F.pad(x, (0, padw, 0, padh))
    return x
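As a standalone illustration (sizes assumed, not from the article): F.pad with (0, padw, 0, padh) pads only the right and bottom edges, so the image content stays anchored at the top-left corner and the crop in postprocess_masks can undo the padding exactly.

import torch
import torch.nn.functional as F

img_size = 1024  # assumed encoder input resolution
pixel_mean = torch.tensor([123.675, 116.28, 103.53]).view(-1, 1, 1)
pixel_std = torch.tensor([58.395, 57.12, 57.375]).view(-1, 1, 1)

x = torch.rand(3, 1024, 768) * 255  # dummy image, longest side already resized to 1024
x = (x - pixel_mean) / pixel_std    # per-channel normalization

h, w = x.shape[-2:]
# Pad only on the right and bottom, so the content stays at the top-left.
x = F.pad(x, (0, img_size - w, 0, img_size - h))
print(x.shape)  # torch.Size([3, 1024, 1024])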
class Sam(nn.Module):
    mask_threshold: float = 0.0
    image_format: str = "RGB"

    def __init__(
        self,
        image_encoder: ImageEncoderViT,
        prompt_encoder: PromptEncoder,
        mask_decoder: MaskDecoder,
        pixel_mean: List[float] = [123.675, 116.28, 103.53],
        pixel_std: List[float] = [58.395, 57.12, 57.375],
    ) -> None:
        """
        SAM predicts object masks from an image and input prompts.
        """
        super().__init__()
        self.image_encoder = image_encoder
        self.prompt_encoder = prompt_encoder
        self.mask_decoder = mask_decoder
        # Non-persistent buffers: they move with .to(device) but are neither
        # trained nor written to the state dict.
        self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False)
        self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)

    @property
    def device(self) -> Any:
        return self.pixel_mean.device
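Before moving on to forward, a quick sketch (my own, not from the source) of why pixel_mean and pixel_std are registered as non-persistent buffers: they follow the module across devices, which is what makes the device property work, but they are not trained and do not appear in the checkpoint.

import torch
from torch import nn

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        # Same pattern as Sam: persistent=False keeps the buffer out of state_dict.
        self.register_buffer("pixel_mean", torch.tensor([123.675, 116.28, 103.53]).view(-1, 1, 1), False)

    @property
    def device(self) -> torch.device:
        return self.pixel_mean.device

m = Toy()
print(m.device)                         # cpu
print("pixel_mean" in m.state_dict())   # False: not saved with the weights
# m.to("cuda") would move pixel_mean along with the rest of the module.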
    @torch.no_grad()
    def forward(
        self,
        batched_input: List[Dict[str, Any]],
        multimask_output: bool,
    ) -> List[Dict[str, torch.Tensor]]:
        """
        Predicts masks end-to-end from provided images and prompts.
        If prompts are not known in advance, using SamPredictor is
        recommended over calling the model directly.

        Arguments:
          batched_input (list(dict)): A list over input images, each a dictionary with the following keys. A prompt key can be omitted if that prompt is not present.
            'image': The image as a torch tensor in 3xHxW format, already transformed for input to the model.
            'original_size': (tuple(int, int)) The original size of the image before transformation, as (H, W).
            'point_coords': (torch.Tensor) Batched point prompts for this image, with shape BxNx2. Already transformed to the input frame of the model.
            'point_labels': (torch.Tensor) Batched labels for the point prompts, with shape BxN.
            'boxes': (torch.Tensor) Batched box inputs, with shape Bx4. Already transformed to the input frame of the model.
            'mask_inputs': (torch.Tensor) Batched mask inputs to the model, in Bx1xHxW format.
          multimask_output (bool): Whether the model should predict multiple disambiguating masks or return a single mask.

        Returns:
          (list(dict)): A list over input images, where each element is a dictionary with the following keys.
            'masks': (torch.Tensor) Batched binary mask predictions, with shape BxCxHxW, where B is the number of input prompts, C is determined by multimask_output, and (H, W) is the original size of the image.
            'iou_predictions': (torch.Tensor) The model's predictions of mask quality, with shape BxC.
            'low_res_logits': (torch.Tensor) Low-resolution logits with shape BxCxHxW, where H=W=256. Can be passed as mask input to subsequent iterations of prediction.
        """
        # Preprocess all images and embed them in a single encoder pass.
        input_images = torch.stack([self.preprocess(x["image"]) for x in batched_input], dim=0)
        image_embeddings = self.image_encoder(input_images)

        outputs = []
        for image_record, curr_embedding in zip(batched_input, image_embeddings):
            if "point_coords" in image_record:
                points = (image_record["point_coords"], image_record["point_labels"])
            else:
                points = None
            # Encode whichever prompts are present for this image.
            sparse_embeddings, dense_embeddings = self.prompt_encoder(
                points=points,
                boxes=image_record.get("boxes", None),
                masks=image_record.get("mask_inputs", None),
            )
            # Decode low-resolution mask logits and their predicted IoU scores.
            low_res_masks, iou_predictions = self.mask_decoder(
                image_embeddings=curr_embedding.unsqueeze(0),
                image_pe=self.prompt_encoder.get_dense_pe(),
                sparse_prompt_embeddings=sparse_embeddings,
                dense_prompt_embeddings=dense_embeddings,
                multimask_output=multimask_output,
            )
            # Remove padding, upscale to the original image size, and binarize.
            masks = self.postprocess_masks(
                low_res_masks,
                input_size=image_record["image"].shape[-2:],
                original_size=image_record["original_size"],
            )
            masks = masks > self.mask_threshold
            outputs.append(
                {
                    "masks": masks,
                    "iou_predictions": iou_predictions,
                    "low_res_logits": low_res_masks,
                }
            )
        return outputs
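To tie the pieces together, here is a hedged end-to-end sketch of calling the model directly. It assumes the official segment_anything package is installed and a released checkpoint is available locally (the file name below is a placeholder); the dummy image stands in for one whose longest side has already been resized to sam.image_encoder.img_size, which in practice is done with ResizeLongestSide from segment_anything.utils.transforms.

import torch
from segment_anything import sam_model_registry  # assumes the official package

# Checkpoint path is a placeholder; any released SAM checkpoint works.
sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth")
sam.eval()

# Dummy image already resized so its longest side equals the encoder input size (1024).
image = torch.randint(0, 255, (3, 1024, 768), dtype=torch.float32)

batched_input = [{
    "image": image,
    "original_size": (1500, 1125),                      # (H, W) before resizing
    "point_coords": torch.tensor([[[512.0, 384.0]]]),   # BxNx2, in the input frame
    "point_labels": torch.tensor([[1]]),                # 1 = foreground point
}]

# forward is already wrapped in @torch.no_grad(), so no extra context is needed.
outputs = sam(batched_input, multimask_output=True)
print(outputs[0]["masks"].shape)            # (1, 3, 1500, 1125) binary masks
print(outputs[0]["iou_predictions"].shape)  # (1, 3) predicted mask quality
print(outputs[0]["low_res_logits"].shape)   # (1, 3, 256, 256) reusable as mask_inputs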