二分类任务中 TP、TN、FP、FN 及相关评价指标的计算

def calculate_confusion_matrix_for_binary_classes(preds, labels):
    """Count per-class TP/TN/FP/FN for a binary (0/1) labelling.

    Args:
        preds: Predicted class labels, a torch tensor or numpy array expected
            to hold 0/1 values. Elements with any other value are not counted
            in any cell.
        labels: Ground-truth class labels, same type and shape as ``preds``.

    Returns:
        A pair of lists ``[TP, TN, FP, FN]`` of Python ints. The first list
        treats class 0 as the positive class and the second treats class 1 as
        the positive class.

    Raises:
        ValueError: If ``preds`` and ``labels`` have different shapes.
    """
    # An explicit raise (rather than `assert`) keeps this check active under
    # `python -O`. Without it, mismatched shapes could broadcast silently and
    # corrupt every count.
    if preds.shape != labels.shape:
        raise ValueError("预测和标签的形状必须相同")

    # The four cells of the 2x2 confusion matrix (prediction vs. ground truth).
    # `.sum().item()` works for both torch tensors and numpy arrays.
    p0_l0 = ((preds == 0) & (labels == 0)).sum().item()
    p1_l1 = ((preds == 1) & (labels == 1)).sum().item()
    p0_l1 = ((preds == 0) & (labels == 1)).sum().item()
    p1_l0 = ((preds == 1) & (labels == 0)).sum().item()

    # Class 0 positive: TP=p0_l0, TN=p1_l1, FP=p0_l1, FN=p1_l0.
    # Class 1 positive: the same four cells with the roles swapped.
    return [p0_l0, p1_l1, p0_l1, p1_l0], [p1_l1, p0_l0, p1_l0, p0_l1]


def _binary_change_metrics(class0, class1):
    """Compute binary change-detection metrics from accumulated confusion counts.

    Args:
        class0: ``[TP, TN, FP, FN]`` with class 0 ("no change") as positive.
        class1: ``[TP, TN, FP, FN]`` with class 1 ("change") as positive.

    Returns:
        Tuple ``(Pre, IoU_nc, IoU_c, Recall, F1, Kappa, OA)``. Precision,
        recall, F1 and Kappa are computed from the class-1 ("change") counts.
    """
    eps = np.spacing(1)  # guards every division against a zero denominator
    tp, tn, fp, fn = class1
    total = tp + tn + fp + fn

    precision = tp / (tp + fp + eps)
    recall = tp / (tp + fn + eps)
    f1 = 2.0 * precision * recall / (precision + recall + eps)
    iou_nc = class0[0] / (class0[0] + class0[2] + class0[3] + eps)
    iou_c = tp / (tp + fp + fn + eps)
    oa = (tp + tn) / (total + eps)

    # Cohen's kappa. Chance agreement =
    # P(pred=1)P(gt=1) + P(pred=0)P(gt=0).
    chance = ((tp + fp) * (tp + fn) + (tn + fn) * (tn + fp)) / (total ** 2 + eps)
    kappa = (oa - chance) / (1.0 - chance + eps)
    return precision, iou_nc, iou_c, recall, f1, kappa, oa


tbar = tqdm(loader, ncols=100)
# Not used by the code shown here; kept in case surrounding code reads them.
total_inter, total_union = 0, 0
total_correct, total_label = 0, 0
# Running confusion counts, each [TP, TN, FP, FN].
# total_class1 treats class 0 ("no change") as positive;
# total_class2 treats class 1 ("change") as positive.
total_class1 = [0, 0, 0, 0]
total_class2 = [0, 0, 0, 0]
# Running binary counts for the change class (always equal to total_class2).
total_TP = total_TN = total_FP = total_FN = 0

for batch in tbar:
    # TODO placeholder: the original snippet elides the per-batch work
    # (forward pass, argmax, ...) that produces the predicted labels `preds`
    # and the ground truth `labels` (same shape, 0/1 values).
    # Replace this unpacking with that code.
    preds, labels = batch

    class1, class2 = calculate_confusion_matrix_for_binary_classes(preds, labels)
    total_class1 = [acc + cur for acc, cur in zip(total_class1, class1)]
    total_class2 = [acc + cur for acc, cur in zip(total_class2, class2)]

    # Use the class-1 ("change") counts directly. Adding the two classes'
    # counts together would make TP == TN (all correct pixels) and
    # FP == FN (all wrong pixels). Pre, Recall, F1 and OA would then all
    # collapse to plain accuracy, and Kappa to 2*OA - 1.
    total_TP, total_TN, total_FP, total_FN = total_class2

    Pre, IOU_nc, IOU_c, Recall, F1, Kappa, OA = _binary_change_metrics(total_class1, total_class2)
    tbar.set_description(
        'Pre:{:.4f},IoU(nc):{:.4f},IoU(c):{:.4f},Rec:{:.4f},F1:{:.4f},Kap:{:.4f},OA:{:.4f}'.format(
            Pre, IOU_nc, IOU_c, Recall, F1, Kappa, OA
        )
    )

上面是一段在变化检测(二分类)任务背景下计算多项评价指标的代码,代码写得比较繁琐,希望能对大家有所帮助。

你可能感兴趣的:(分类,数据挖掘,人工智能)