【Pytorch】分割的mask 独热编码转换 scatter 参数理解

scatter_(input, dim, index, src)将src中数据根据index中的索引按照dim的方向填进input中。

 

用于将数据转换为 one hot 独热编码时,代码如下

def to_one_hot(mask, n_class):
    """
    Transform a mask to one hot
    change a mask to n * h* w   n is the class
    Args:
        mask:
        n_class: number of class for segmentation
    Returns:
        y_one_hot: one hot mask
    """
    y_one_hot = torch.zeros((n_class, mask.shape[1], mask.shape[2]))
    y_one_hot = y_one_hot.scatter(0, mask, 1).long()
    return y_one_hot

mask 的尺寸为 1 * h *w,假设其矩阵内容大致如下:

    1    2    3    4    1
    2    0    2    0    0
    1    1    4    0    0
    3    3    0    0    0
    2    4    0    0    0

n_class 为分割图片的像素种类个数,即分割的类别个数,假定是5 (包括背景类别在内)。

首先创建一个 n * w * h的tensor张量,然后调用scatter函数,作用是在第0维度上。对于mask数据,其内容就是将要填充给tensor张量数据的index索引。所以当遇到2 时候,将会在第1维上,填充src的数据(即1),同位置在其余维度为0 。由此转换为独热编码。

你可能感兴趣的:(【Pytorch】分割的mask 独热编码转换 scatter 参数理解)