定义:torch.nn.functional.one_hot(tensor, num_classes=- 1) → LongTensor
描述:Takes LongTensor with index values of shape () and returns a tensor of shape (, num_classes) that have zeros everywhere except where the index of last dimension matches the corresponding value of the input tensor, in which case it will be 1.
译文:
使用index值为shape()的LongTensor,并返回一个shape (, num_classes)的张量,除了最后一个维度的索引与输入张量的对应值相匹配的地方外,其他地方都是零,在这种情况下,它将是1。
说人话: 就是在你给定一个张量的时候,可以对你给的张量进行编码,这里分两种情况
import torch
from torch.nn import functional as F
x = torch.tensor([1, 2, 3, 8, 5])
# 定义一个张量输入,因为此时有 5 个数值,且最大值为8,
# 所以我们可以得到 y 的输出结果的形状应该为 shape=(5,9);5行9列
y = F.one_hot(x) # 只有一个参数张量x
print(f'x = {x}') # 输出 x
print(f'x_shape = {x.shape}') # 查看 x 的形状
print(f'y = {y}') # 输出 y
print(f'y_shape = {y.shape}') # 输出 y 的形状
我们可以看出来,所得的结果为 X 中每个张量里面的值为 Y 结果中的序号为 1 的地方;
比如: X 中第 4 个值表示为 8 的值,可以看到 Y 中第 4 行的 8 个(下标从 0 开始)
x = tensor([1, 2, 3, 8, 5])
x_shape = torch.Size([5])
y = tensor([[0, 1, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 1, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 1, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 0, 1, 0, 0, 0]])
y_shape = torch.Size([5, 9])
import torch
from torch.nn import functional as F
x = torch.tensor([1, 2, 3, 8, 5])
# 定义一个张量输入,因为此时有 5 个数值,且最大值为8,且 类别数为 12
# 所以我们可以得到 y 的输出结果的形状应该为 shape=(5,12);5行12列
y = F.one_hot(x, 12) # 一个参数张量x, 12 为类别数,其中 12 > max{x}
print(f'x = {x}') # 输出 x
print(f'x_shape = {x.shape}') # 查看 x 的形状
print(f'y = {y}') # 输出 y
print(f'y_shape = {y.shape}') # 输出 y 的形状
x = tensor([1, 2, 3, 8, 5])
x_shape = torch.Size([5])
y = tensor([[0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0],
[0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0]])
y_shape = torch.Size([5, 12])