个人主页:十二月的猫-CSDN博客
系列专栏: 《PyTorch科研加速指南:即插即用式模块开发》-CSDN博客十二月的寒冬阻挡不了春天的脚步,十二点的黑夜遮蔽不住黎明的曙光
目录
1. 前言
2. 标签转独热编码函数
2.1 完整函数
2.2 函数功能解释
3. 实战示例
4. 总结
正在更新中
项目运行环境:
def label2onehot(logits, labels):
"""
将标签转换为 one-hot 编码形式。
参数:
logits (torch.Tensor 或 np.ndarray): 模型的输出 logits,形状通常为 (batch_size, num_classes)。
labels (list 或 torch.Tensor): 对应的标签,形状为 (batch_size,)。
返回:
np.ndarray: 转换后的 one-hot 编码,形状与 logits 相同。
"""
# 创建一个与 logits 形状相同的全零张量
label_onehot = torch.zeros_like(torch.tensor(logits))
# 使用 scatter_ 函数将 labels 转换为 one-hot 编码
label_onehot.scatter_(1, torch.tensor(labels).long().view(-1, 1), 1)
# 将 one-hot 编码的张量转换为 numpy 数组并返回
return label_onehot.numpy()
1. label_onehot = torch.zeros_like(torch.tensor(logits)):
logits = [[0.1, 0.2, 0.7], [0.9, 0.05, 0.05]] # 形状: (2, 3)
label_onehot = torch.zeros_like(torch.tensor(logits))
# 输出: tensor([[0., 0., 0.],
# [0., 0., 0.]])
模型输出的结果都是二维的:每一行是一个记录;每一列是对一个label的可能性评估
2. torch.tensor(labels).long().view(-1, 1):
labels = [2, 0] # 形状: (2,)
index_tensor = torch.tensor(labels).long().view(-1, 1)
# 输出: tensor([[2],
# [0]])
3. label_onehot.scatter_(1, index_tensor, 1):
label_onehot = torch.zeros(2, 3) # 形状: (2, 3)
index_tensor = torch.tensor([[2], [0]]) # 形状: (2, 1)
label_onehot.scatter_(1, index_tensor, 1)
# 输出: tensor([[0., 0., 1.],
# [1., 0., 0.]])
看起来二维的Tensor在列上是隔开的,但正如线代的矩阵,Tensor在列和行上都是相连的。
4. return label_onehot.numpy():
label_onehot = torch.tensor([[0., 0., 1.], [1., 0., 0.]])
onehot_numpy = label_onehot.numpy()
# 输出: array([[0., 0., 1.],
# [1., 0., 0.]], dtype=float32)
import torch
# 定义函数
def label2onehot(logits, labels):
label_onehot = torch.zeros_like(torch.tensor(logits))
label_onehot.scatter_(1, torch.tensor(labels).long().view(-1, 1), 1)
return label_onehot.numpy()
# 示例数据
logits = [[0.1, 0.2, 0.7], [0.9, 0.05, 0.05]] # 形状: (2, 3)
labels = [2, 0] # 形状: (2,)
# 调用函数
onehot_labels = label2onehot(logits, labels)
print("One-hot 编码结果:")
print(onehot_labels)
运行结果:
总结:
【如果想学习更多深度学习文章,可以订阅一下热门专栏】
如果想要学习更多pyTorch/python编程的知识,大家可以点个关注并订阅,持续学习、天天进步你的点赞就是我更新的动力,如果觉得对你有帮助,辛苦友友点个赞,收个藏呀~~~