【Pytorch】多标签分类,分类指标计算,按阈值,或者概率最大的前top个标签【sigmoid + BCELoss】

from: https://github.com/Sun-DongYang/Pytorch/blob/master/multiLabel/multiLabel.py

多标签计算准确率的方式:按阈值,或者概率最大的前top个标签

# 计算准确率——方式1
# 设定一个阈值,当预测的概率值大于这个阈值,则认为这幅图像中含有这类标签
def calculate_acuracy_mode_one(model_pred, labels):
    # 注意这里的model_pred是经过sigmoid处理的,sigmoid处理后可以视为预测是这一类的概率
    # 预测结果,大于这个阈值则视为预测正确
    accuracy_th = 0.5
    pred_result = model_pred > accuracy_th
    pred_result = pred_result.float()
    pred_one_num = torch.sum(pred_result)
    if pred_one_num == 0:
        return 0, 0
    target_one_num = torch.sum(labels)
    true_predict_num = torch.sum(pred_result * labels)
    # 模型预测的结果中有多少个是正确的
    precision = true_predict_num / pred_one_num
    # 模型预测正确的结果中,占所有真实标签的数量
    recall = true_predict_num / target_one_num

    return precision.item(), recall.item()

# 计算准确率——方式2
# 取预测概率最大的前top个标签,作为模型的预测结果
def calculate_acuracy_mode_two(model_pred, labels):
    # 取前top个预测结果作为模型的预测结果
    precision = 0
    recall = 0
    top = 5
    # 对预测结果进行按概率值进行降序排列,取概率最大的top个结果作为模型的预测结果
    pred_label_locate = torch.argsort(model_pred, descending=True)[:, 0:top]
    for i in range(model_pred.shape[0]):
        temp_label = torch.zeros(1, model_pred.shape[1])
        temp_label[0,pred_label_locate[i]] = 1
        target_one_num = torch.sum(labels[i])
        true_predict_num = torch.sum(temp_label * labels[i])
        # 对每一幅图像进行预测准确率的计算
        precision += true_predict_num / top
        # 对每一幅图像进行预测查全率的计算
        recall += true_predict_num / target_one_num
    return precision, recall

你可能感兴趣的:(【Pytorch学习】,pytorch,分类,人工智能)