pytorch标签onehot编码_pytorch 01 关于分割任务中 onehot 编码转换的问题

在分割任务中,我们拿到的label通常是由数字类别组成的,但是在应用某些损失函数时,我们需要把label转换成 one—hot编码的形式。

例如:原始label维度 224*224*1(由数字0-2组成) ,为一个三类别的分割任务,在onehot编码后维度为 224*224*3,(可以看成3张224*224*1的切片)。

代码:

一:当维度为 N 1 *

one-hot后 N C *

def make_one_hot(input, num_classes):

"""Convert class index tensor to one hot encoding tensor.

Args:

input: A tensor of shape [N, 1, *]

num_classes: An int of number of class

Returns:

A tensor of shape [N, num_classes, *]

"""

shape = np.array(input.shape)

shape[1] = num_classes

shape = tuple(shape)

result = torch.zeros(shape)

result = result.scatter_(1, torch.LongTensor(input), 1)

return result

二:当维度为 1 *

one_hot后 N *

def make_one_hot(input, num_classes):

"""Convert class index tensor to one hot encoding tensor.

Args:

input: A tensor of shape [N, 1, *]

num_classes: An int of number of class

Returns:

A tensor of shape [N, num_classes, *]

"""

shape = np.array(input.shape)

shape[0] = num_classes

shape = tuple(shape)

result = torch.zeros(shape)

result = result.scatter_(0, torch.LongTensor(input), 1)

return result

* 代表图像大小 例如 224 x 224

本文地址:https://blog.csdn.net/wwwww_bw/article/details/107643179

如您对本文有疑问或者有任何想说的,请点击进行留言回复,万千网友为您解惑!

你可能感兴趣的:(pytorch标签onehot编码_pytorch 01 关于分割任务中 onehot 编码转换的问题)