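Plotting P-R curves for cross-modal retrieval

Plotting P-R and top-K curves in cross-modal retrieval

Some leftover code from my grad-school days, shared here.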

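Contents

  • Plotting P-R and top-K curves in cross-modal retrieval
    • 1. First, the functions that compute P and R
      • 1.1 Hashing methods
      • 1.2 Real-valued methods
    • 2. Then, two examples of calling them
      • 2.1 Hashing call
      • 2.2 Real-valued call
      • 2.3 Saving as .npy files
      • 2.4 Converting .npy to .csv
    • 3. Drawing the plot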

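1. First, the functions that compute P and R

Real-valued and hashing methods for cross-modal retrieval compute P and R slightly differently: hashing methods rank the database by the Hamming distance between binary codes, while real-valued methods rank it by the cosine distance between feature vectors.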
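Both functions below use the standard definitions for a query q, where Rel(q) is the set of database items sharing at least one label with q and Ret(q) is the retrieved set. Only Ret(q) differs: for hashing it is every item within Hamming radius d of the query code, swept over d = 0..num_bit; for real-valued features it is the K nearest items by cosine distance, swept over K = 1..N (N = database size).

```latex
P(q) = \frac{|\mathrm{Rel}(q) \cap \mathrm{Ret}(q)|}{|\mathrm{Ret}(q)|}, \qquad
R(q) = \frac{|\mathrm{Rel}(q) \cap \mathrm{Ret}(q)|}{|\mathrm{Rel}(q)|}

\text{hashing:}\quad \mathrm{Ret}_d(q) = \{\, x : d_H(b_q, b_x) \le d \,\}, \quad d = 0, 1, \dots, \text{num\_bit}

\text{real-valued:}\quad \mathrm{Ret}_K(q) = \text{the } K \text{ items closest to } q \text{ in cosine distance}, \quad K = 1, \dots, N
```

The per-query values are then averaged at each d (or K). In the code below, the hashing version averages only over queries whose precision at that radius is nonzero, while the real-valued version averages over all queries, counting queries with no relevant item as 0.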

1.1 Hashing methods

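    import torch

    def calc_hamming_dist(B1, B2):
        # Not defined in the original post; this version ASSUMES codes take values in {-1, +1}.
        # For K-bit codes, Hamming distance = (K - <b1, b2>) / 2.
        q = B2.shape[1]
        if B1.dim() < 2:
            B1 = B1.unsqueeze(0)
        return 0.5 * (q - B1.mm(B2.t()))  # shape (1, num_retrieval)

    def pr_curve(qB, rB, query_label, retrieval_label):
        # qB: (num_query, num_bit) query codes; rB: (num_retrieval, num_bit) database codes
        # query_label / retrieval_label: multi-hot label matrices as float tensors
        num_query = qB.shape[0]
        num_bit = qB.shape[1]
        P = torch.zeros(num_query, num_bit + 1)
        R = torch.zeros(num_query, num_bit + 1)
        # iterate over query samples
        for i in range(num_query):
            # an item is relevant if it shares at least one label with the query
            gnd = (query_label[i].unsqueeze(0).mm(retrieval_label.t()) > 0).float().squeeze()
            # number of relevant items in the whole retrieval database
            tsum = torch.sum(gnd)
            if tsum == 0:
                continue
            hamm = calc_hamming_dist(qB[i, :], rB)
            # row d marks the items within Hamming radius d, for d = 0..num_bit
            tmp = (hamm <= torch.arange(0, num_bit + 1).reshape(-1, 1).float().to(hamm.device)).float()
            total = tmp.sum(dim=-1)
            total = total + (total == 0).float() * 0.1  # avoid division by zero
            t = gnd * tmp
            count = t.sum(dim=-1)  # relevant items retrieved at each radius
            p = count / total
            r = count / tsum
            P[i] = p.cpu()
            R[i] = r.cpu()
        # at each radius, average over the queries whose precision is nonzero
        mask = (P > 0).float().sum(dim=0)
        mask = mask + (mask == 0).float() * 0.1
        P = P.sum(dim=0) / mask
        R = R.sum(dim=0) / mask
        return P, R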
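The post calls calc_hamming_dist without defining it. When codes are stored as {-1, +1} vectors (an assumption; 0/1 codes would need a different formula), a common choice is d = (K − b1·b2) / 2: matching bits add +1 to the inner product and differing bits add −1, so b1·b2 = K − 2d. The sketch below checks that identity against a direct bit count; both function names are hypothetical and NumPy is used only to keep it small.

```python
import numpy as np

def hamming_from_inner_product(b1, b2):
    """Pairwise Hamming distances between {-1, +1} codes (rows of b1 vs rows of b2).

    Matching bits contribute +1 to the inner product and differing bits -1,
    so b1 . b2 = K - 2 * d, i.e. d = (K - b1 . b2) / 2.
    """
    k = b2.shape[1]
    return 0.5 * (k - b1 @ b2.T)

def hamming_by_counting(b1, b2):
    """Reference implementation: count the positions where the codes differ."""
    return (b1[:, None, :] != b2[None, :, :]).sum(axis=-1)
```

The inner-product form turns the whole distance computation into one matrix product, which is why it suits batched GPU code.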

1.2 Real-valued methods

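    import numpy as np
    import scipy.spatial
    import torch

    def pr_curve(qB, rB, label):
        # qB, rB: NumPy feature matrices (num_query x dim, num_retrieval x dim)
        # label: torch float tensor of multi-hot labels; queries and retrieval items are
        # paired (e.g. image i and text i), so one label matrix serves both sides
        num_query = qB.shape[0]
        topK = rB.shape[0]
        # topK = 50  # or evaluate only the first 50 positions
        P, R = [], []
        dist = scipy.spatial.distance.cdist(qB, rB, 'cosine')
        Rank = np.argsort(dist)  # per query: retrieval indices by ascending distance
        Gnd = (label.mm(label.transpose(0, 1)) > 0).type(torch.float32)
        for k in range(1, topK + 1):  # iterate over K in top-K
            p = np.zeros(num_query)
            r = np.zeros(num_query)
            for it in range(num_query):
                gnd = Gnd[it]
                gnd_all = gnd.sum()  # number of relevant items in the whole database
                if gnd_all == 0:
                    continue
                asc_id = Rank[it][:k]
                gnd = gnd[asc_id]
                gnd_r = gnd.sum()  # number of relevant items in the top K
                p[it] = gnd_r / k
                r[it] = gnd_r / gnd_all

            P.append(np.mean(p))
            R.append(np.mean(r))
        S = np.arange(1, topK + 1)  # K values (1..topK), the x-axis of a top-K curve
        S = S.tolist()
        return P, R, S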
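The double loop above re-slices the ranking for every K, which costs on the order of N² operations per query for a database of N items. Sorting once and taking a cumulative sum of the relevance flags yields precision@K and recall@K for every K in a single pass. A minimal NumPy sketch, assuming separate query and retrieval label matrices given as NumPy arrays (pass the same labels twice to mirror the version above); the name pr_topk_curve is hypothetical.

```python
import numpy as np
from scipy.spatial.distance import cdist

def pr_topk_curve(q_feat, r_feat, q_labels, r_labels):
    """Mean precision@K and recall@K for K = 1..N via one sort and a cumulative sum.

    Queries with no relevant item contribute 0, matching the loop version.
    """
    dist = cdist(q_feat, r_feat, "cosine")                  # (Q, N)
    rank = np.argsort(dist, axis=1)                         # ascending distance
    gnd = (q_labels @ r_labels.T > 0).astype(np.float64)    # (Q, N) relevance flags
    hits = np.cumsum(np.take_along_axis(gnd, rank, axis=1), axis=1)  # relevant in top-K
    n_rel = gnd.sum(axis=1, keepdims=True)                  # relevant items per query
    k = np.arange(1, gnd.shape[1] + 1)
    valid = n_rel > 0
    precision = np.where(valid, hits / k, 0.0)
    recall = np.where(valid, hits / np.maximum(n_rel, 1), 0.0)
    return precision.mean(axis=0), recall.mean(axis=0), k
```

Memory grows with Q × N, so for very large databases the queries can be processed in chunks and the per-chunk sums combined before averaging.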

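2. Then, two examples of calling them

The two are called in much the same way. The main thing to watch is the input types: the hashing version expects torch tensors for both codes and labels, while the real-valued version takes NumPy feature matrices plus a torch label tensor.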

2.1 Hashing call

    p_i2t, r_i2t = pr_curve(qBX, rBY, query_labels, db_labels)
    p_t2i, r_t2i = pr_curve(qBY, rBX, query_labels, db_labels)
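The hashing pr_curve multiplies both codes and labels with torch's .mm, so they must be floating-point tensors of the same dtype, and an inner-product Hamming distance expects codes in {-1, +1}. A sketch of preparing them from NumPy network outputs, assuming the hash layer's continuous outputs are binarized at 0 and ties (exact zeros) are mapped to +1; to_hash_inputs is a hypothetical helper name.

```python
import numpy as np

def to_hash_inputs(features, labels):
    """Hypothetical helper: binarize continuous hash-layer outputs to {-1, +1}
    and cast multi-hot labels, both as float32 arrays."""
    # threshold at 0; an exact 0 becomes +1 so every bit is +-1 (np.sign would give 0)
    codes = np.where(np.asarray(features) >= 0, 1.0, -1.0).astype(np.float32)
    labels = np.asarray(labels, dtype=np.float32)
    return codes, labels
```

The results can then be wrapped without copying, e.g. qBX = torch.from_numpy(codes) and query_labels = torch.from_numpy(labels), before calling pr_curve.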

2.2 Real-valued call

    i2tp,i2tr,i2t_S = pr_curve(view1_feature, view2_feature, torch.tensor(input_data_par['label_test']))
    t2ip,t2ir,t2i_S = pr_curve(view2_feature, view1_feature, torch.tensor(input_data_par['label_test']))

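2.3 Saving as .npy files

You could also save the results in another format directly. When I wrote this I didn't find a way to do that, and I wasn't comfortable working with .npy files, so I added an extra conversion step.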

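    import os
    import numpy as np
    # path: output directory (placeholder); np.save appends the .npy extension
    # the hashing pr_curve returns torch tensors, hence .numpy();
    # for the real-valued version (Python lists) use np.asarray(i2tp) etc. instead
    np.save(os.path.join(path, 'P_i2t'), p_i2t.numpy())
    np.save(os.path.join(path, 'R_i2t'), r_i2t.numpy())
    np.save(os.path.join(path, 'P_t2i'), p_t2i.numpy())
    np.save(os.path.join(path, 'R_t2i'), r_t2i.numpy())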

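2.4 Converting .npy to .csv

    import os
    import numpy as np
    import pandas as pd
    # path: the directory holding the .npy files saved above
    npfile = np.load(os.path.join(path, 'P_i2t.npy'))
    # convert the loaded array into a DataFrame
    np_to_csv = pd.DataFrame(data=npfile)
    # write it to P_i2t.csv in the same directory (the first column is the index)
    np_to_csv.to_csv(os.path.join(path, 'P_i2t.csv'))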
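As noted in 2.3, the .npy step can be skipped. A minimal sketch, assuming pandas, that writes the recall and precision of one curve side by side into a single CSV; save_pr_csv is a hypothetical helper name, and it accepts lists, NumPy arrays, or CPU torch tensors.

```python
import numpy as np
import pandas as pd

def save_pr_csv(precision, recall, csv_path):
    """Write one P-R curve to CSV with named 'recall' and 'precision' columns."""
    df = pd.DataFrame({
        "recall": np.asarray(recall, dtype=float),
        "precision": np.asarray(precision, dtype=float),
    })
    df.to_csv(csv_path, index=False)  # no index column, so columns can be read by name
    return df
```

Reading it back with pd.read_csv(csv_path) then gives both series by column name (df["recall"], df["precision"]) instead of relying on usecols=[1].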

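3. Drawing the plot

    import os
    import matplotlib.pyplot as plt
    import pandas as pd
    path_a = 'path_to_method_a_results'  # placeholder: directory with method a's CSV files
    path_b = 'path_to_method_b_results'  # placeholder: directory with method b's CSV files
    # usecols=[1] skips the index column written by to_csv
    a_i2t_p = pd.read_csv(os.path.join(path_a, 'P_i2t.csv'), usecols=[1])
    a_i2t_r = pd.read_csv(os.path.join(path_a, 'R_i2t.csv'), usecols=[1])
    b_i2t_p = pd.read_csv(os.path.join(path_b, 'P_i2t.csv'), usecols=[1])
    b_i2t_r = pd.read_csv(os.path.join(path_b, 'R_i2t.csv'), usecols=[1])
    # draw the P-R curve
    fig = plt.figure(figsize=(5, 5))
    plt.grid(linestyle="--")  # dashed background grid
    ax = plt.gca()
    ax.spines['top'].set_visible(False)    # hide the top border
    ax.spines['right'].set_visible(False)  # hide the right border
    # markevery sets the spacing between markers, marker the marker shape, linestyle the line style; all optional.
    # Keep markevery below the number of points (only num_bit + 1 for the hashing version).
    plt.plot(a_i2t_r, a_i2t_p, color="r", label="a", linewidth=1.5, linestyle="-", marker='o', markevery=270)
    plt.plot(b_i2t_r, b_i2t_p, color="lightgreen", label="b", linewidth=1.5, linestyle="--", marker='*', markevery=270)
    plt.grid(True)
    plt.xlim(0, 1)  # x-axis range, adjustable
    plt.ylim(0, 1)  # y-axis range, adjustable
    plt.xlabel('recall')
    plt.ylabel('precision')
    # plt.title("Image2Text", fontsize=12, fontweight='bold')  # default font size is 12
    plt.legend(loc=0, numpoints=1)
    leg = plt.gca().get_legend()
    ltext = leg.get_texts()
    plt.setp(ltext, fontsize=10, fontweight='bold')  # legend font size and weight
    plt.savefig('Image2Text.png')  # the curves above are image-to-text (i2t)
    plt.show()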
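The title also mentions top-K curves, and the real-valued pr_curve already returns what one needs: the k-th entry of its precision list is precision over the first k retrieved items, for k = 1..topK. A minimal sketch using matplotlib's object-oriented API; plot_topk_curve is a hypothetical helper name, and the Agg backend line is an assumption for running without a display (drop it to keep your default backend and use plt.show()).

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; remove this line for an interactive window
import matplotlib.pyplot as plt

def plot_topk_curve(curves, out_path, max_k=None):
    """Plot precision@K against K for several methods and save the figure.

    curves: dict mapping a method name to (ks, precisions), where ks are the
    K values and precisions the matching precision@K values.
    """
    fig, ax = plt.subplots(figsize=(5, 5))
    for name, (ks, precisions) in curves.items():
        if max_k is not None:  # optionally zoom in on the first max_k positions
            ks, precisions = ks[:max_k], precisions[:max_k]
        ax.plot(ks, precisions, linewidth=1.5, label=name)
    ax.set_xlabel("K")
    ax.set_ylabel("precision@K")
    ax.set_ylim(0, 1)
    ax.grid(linestyle="--")
    ax.legend(loc="best")
    fig.savefig(out_path)
    plt.close(fig)  # free the figure; the returned object stays inspectable
    return fig
```

For example, plot_topk_curve({"a": (list(range(1, len(i2tp) + 1)), i2tp)}, "topk_i2t.png", max_k=500) would plot the first 500 positions of the image-to-text results.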
