Run-length encoding and decoding of semantic segmentation labels in Python

In SAM (Segment Anything Model), mask labels take up a large amount of memory during computation, so storing them as a run-length encoding (RLE) saves memory. Below is how the official repo implements it.
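
For context, the uncompressed RLE format expected by pycocotools stores counts as the lengths of alternating runs of 0s and 1s over the column-major flattened mask, and the first count always refers to 0s. A toy illustration (not from the repo):

# Toy 1-D examples of the counts convention:
# [0, 0, 1, 1, 1, 0]  ->  counts = [2, 3, 1]   (two 0s, three 1s, one 0)
# [1, 1, 0, 0]        ->  counts = [0, 2, 2]   (zero 0s, two 1s, two 0s)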
Encoding
tensor is one batch of masks (the segmentation maps generated from a batch of prompts)

from typing import Any, Dict, List

import torch


def mask_to_rle_pytorch(tensor: torch.Tensor) -> List[Dict[str, Any]]:
    """
    Encodes masks to an uncompressed RLE, in the format expected by
    pycoco tools.
    """
    # Put in fortran order and flatten h,w
    b, h, w = tensor.shape
    tensor = tensor.permute(0, 2, 1).flatten(1)

    # Compute change indices
    diff = tensor[:, 1:] ^ tensor[:, :-1]  # XOR adjacent elements: True where the value changes
    change_indices = diff.nonzero()  # positions of those change points

    # Encode run length
    out = []
    for i in range(b):
        cur_idxs = change_indices[change_indices[:, 0] == i, 1]
        cur_idxs = torch.cat(
            [
                torch.tensor([0], dtype=cur_idxs.dtype, device=cur_idxs.device),
                cur_idxs + 1,
                torch.tensor([h * w], dtype=cur_idxs.dtype, device=cur_idxs.device),
            ]
        )
        btw_idxs = cur_idxs[1:] - cur_idxs[:-1]  # run lengths: distance between consecutive change points
        # The RLE convention counts runs of 0s first; if the mask starts
        # with a 1, prepend an explicit zero-length run of 0s so the
        # alternation still begins at 0.
        counts = [] if tensor[i, 0] == 0 else [0]
        counts.extend(btw_idxs.detach().cpu().tolist())
        out.append({"size": [h, w], "counts": counts})
    return out
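
A quick usage sketch (the shapes and values here are made up for illustration, assuming the imports above):

masks = torch.zeros((1, 2, 3), dtype=torch.bool)  # batch of one 2x3 mask
masks[0, :, 1] = True                             # middle column set to 1
rles = mask_to_rle_pytorch(masks)
print(rles)  # [{'size': [2, 3], 'counts': [2, 2, 2]}]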

Decoding

from typing import Any, Dict

import numpy as np


def rle_to_mask(rle: Dict[str, Any]) -> np.ndarray:
    """Compute a binary mask from an uncompressed RLE."""
    h, w = rle["size"]
    mask = np.empty(h * w, dtype=bool)
    idx = 0
    parity = False
    for count in rle["counts"]:
        mask[idx : idx + count] = parity
        idx += count
        parity ^= True  # flip after each run: runs alternate 0/1, starting with 0 (False)
    mask = mask.reshape(w, h)
    return mask.transpose()  # Put in C order
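
As a sanity check, decoding should reproduce the encoder's input (continuing the illustrative values from the encoding sketch above):

decoded = rle_to_mask(rles[0])  # numpy bool array of shape (2, 3)
assert np.array_equal(decoded, masks[0].numpy())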
