如何构建置信度掩码(代码)

import torch
import torch.nn.functional as F

# 假设 preds 是模型输出的预测结果,尺寸为 [2, 256, 256]
# 假设 labels 是真实标签图像,尺寸为 [256, 256],值为 0 或 1

# 假设的模型输出和标签(随机生成,实际使用时替换为真实数据)
preds = torch.rand(2, 256, 256)
labels = torch.randint(0, 2, (256, 256))

# 应用 softmax 来获取概率分布
probs = F.softmax(preds, dim=0)

# 选择每个像素点概率最高的类别
confidences, predictions = torch.max(probs, dim=0)

# 置信度阈值
tau = 0.6

# 生成置信度掩码
confidence_mask = confidences >= tau
filtered_labels = labels[confidence_mask]

confidence_mask = confidence_mask.unsqueeze(0).repeat(2, 1, 1)# 使用置信度掩码过滤掉不满足置信度要求的预测
filtered_probs = probs[confidence_mask]

data_part1 = filtered_probs[:int(filtered_probs.size(0)/2)]
data_part2 = filtered_probs[int(filtered_probs.size(0)/2):]

# 然后将这两部分堆叠起来形成所需的形状
reshaped_data = torch.stack((data_part1, data_part2), dim=1)
# 计算 CE 损失,只考虑置信度高的像素点

print(reshaped_data)
print(filtered_labels)

print(reshaped_data.size())
print(filtered_labels.size())
loss = F.cross_entropy(reshaped_data, filtered_labels)

print(loss)

对于CE损失函数,要求的尺寸预测图像要比GT多一维,所以在代码中特别地构建了reshape_data ,这样就可以计算损失了。

代码写得不是很规范,欢迎指正

你可能感兴趣的:(python,深度学习,人工智能)