以下是读研期间留下的代码，在此分享。
跨模态检索中，实值方法和哈希方法计算精确率（P）与召回率（R）的方式略有不同：哈希方法按汉明距离半径（0 到码长）划分检索结果，实值方法则按余弦距离排序后截取 top-K。
import torch
import numpy as np
def pr_curve(qB, rB, query_label, retrieval_label):
    """Precision/recall curve over Hamming radii for binary hash codes.

    For each query, every retrieval item whose Hamming distance is <= d counts
    as retrieved at radius d (d = 0 .. num_bit). Precision and recall are
    computed per radius and then averaged over queries.

    Args:
        qB: query hash codes, indexed as (num_query, num_bit).
        rB: retrieval-set hash codes, passed to ``calc_hamming_dist``.
        query_label: query label matrix; row i is multiplied against
            ``retrieval_label.t()``, so both need the same number of columns.
        retrieval_label: retrieval-set label matrix.

    Returns:
        (P, R): tensors of length num_bit + 1 holding mean precision and mean
        recall at Hamming radius 0 .. num_bit.
    """
    n_queries, n_bits = qB.shape[0], qB.shape[1]
    precision = torch.zeros(n_queries, n_bits + 1)
    recall = torch.zeros(n_queries, n_bits + 1)
    # Column vector of radii 0..n_bits, broadcast against the distances below.
    radii = torch.arange(0, n_bits + 1).reshape(-1, 1).float()

    for q in range(n_queries):
        # An item is relevant when its label vector has a positive dot product
        # with the query's (i.e. they share a label, for multi-hot labels).
        relevant = (query_label[q].unsqueeze(0).mm(retrieval_label.t()) > 0).float().squeeze()
        # Number of relevant items in the whole retrieval set.
        n_relevant = torch.sum(relevant)
        if n_relevant == 0:
            continue  # no ground truth for this query: its rows stay zero

        dist = calc_hamming_dist(qB[q, :], rB)
        # within[d, j] == 1 iff item j lies within Hamming radius d
        # (assumes dist is a 1-D tensor of per-item distances -- TODO confirm).
        within = (dist <= radii.to(dist.device)).float()
        n_hits = (relevant * within).sum(dim=-1)
        n_retrieved = within.sum(dim=-1)
        # Nothing retrieved at a radius -> divide by 0.1 instead of 0; the hit
        # count is 0 there too, so precision comes out as 0.
        n_retrieved = n_retrieved.masked_fill(n_retrieved == 0, 0.1)
        precision[q] = n_hits / n_retrieved
        recall[q] = n_hits / n_relevant

    # NOTE(review): both curves are averaged only over queries whose precision
    # at that radius is > 0. This also drops queries that retrieved items but
    # no relevant ones, and recall reuses the same denominator. Confirm intended.
    n_counted = (precision > 0).float().sum(dim=0)
    n_counted = n_counted.masked_fill(n_counted == 0, 0.1)
    return precision.sum(dim=0) / n_counted, recall.sum(dim=0) / n_counted
import numpy as np
import scipy.spatial
import torch
def pr_curve(qB, rB, label):
    """Precision/recall at every cut-off K for real-valued features.

    Retrieval items are ranked by cosine distance to each query. P@K and R@K
    are computed for K = 1 .. rB.shape[0] and averaged over all queries.
    Queries with no relevant item contribute 0 to both averages.

    Args:
        qB: query features, shape (num_query, dim); numpy array or torch tensor.
        rB: retrieval features, shape (num_db, dim); numpy array or torch tensor.
        label: label matrix (torch tensor or array-like) used for BOTH sides.
            Relevance of retrieval item j to query i is
            ``label[i] . label[j] > 0``, so row j must describe retrieval item j
            (query and retrieval sets are paired, e.g. two views of one test set).

    Returns:
        (P, R, S): lists with the mean precision and mean recall at each K, and
        the K values themselves (1 .. rB.shape[0]).
    """
    # cdist needs host arrays; also accept (GPU / grad-tracking) torch tensors.
    if torch.is_tensor(qB):
        qB = qB.detach().cpu().numpy()
    if torch.is_tensor(rB):
        rB = rB.detach().cpu().numpy()
    label = torch.as_tensor(label).float()

    num_query = qB.shape[0]
    topK = rB.shape[0]

    dist = scipy.spatial.distance.cdist(qB, rB, 'cosine')
    rank = np.argsort(dist)  # per query: retrieval indices, nearest first
    gnd = (label.mm(label.transpose(0, 1)) > 0).cpu().numpy().astype(np.float64)
    gnd = gnd[:num_query]
    # Relevant items per query in the whole retrieval set.
    gnd_all = gnd.sum(axis=1, keepdims=True)
    # hits[i, k-1] = number of relevant items among query i's top-k results.
    hits = np.cumsum(np.take_along_axis(gnd, rank, axis=1), axis=1)

    ks = np.arange(1, topK + 1)
    p = hits / ks
    # A query with no relevant item has hits == 0 everywhere, so max(., 1)
    # only avoids 0/0 and yields recall 0, matching precision 0 for it.
    r = hits / np.maximum(gnd_all, 1)

    P = list(p.mean(axis=0))
    R = list(r.mean(axis=0))
    S = ks.tolist()  # the K values actually evaluated (1-based)
    return P, R, S
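两个版本的用法基本相同，需要注意两点：一是两个函数同名（`pr_curve`），不能在同一个模块里同时定义——下面前两行调用的是哈希版本（4 个参数），后两行调用的是实值版本（3 个参数）；二是输入的数据格式可能需要转换，例如标签需为 torch 张量，实值版本的特征需能被 `scipy.spatial.distance.cdist` 处理。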
# Hashing variant (4 arguments): query codes of one modality against the other
# modality's retrieval codes; returns (P, R) over Hamming radii.
p_i2t, r_i2t = pr_curve(qBX, rBY, query_labels, db_labels)
p_t2i, r_t2i = pr_curve(qBY, rBX, query_labels, db_labels)

# Real-valued variant (3 arguments): both views of the paired test set share
# one label matrix, so build it once and reuse it for both directions.
test_labels = torch.tensor(input_data_par['label_test'])
i2tp, i2tr, i2t_S = pr_curve(view1_feature, view2_feature, test_labels)
t2ip, t2ir, t2i_S = pr_curve(view2_feature, view1_feature, test_labels)
也可以考虑直接保存为其他格式。当时写的时候没找到合适的方法，又不太擅长处理 npy 文件，所以先存成 npy，再多做一步转换为 CSV：
import os
import numpy as np
# Save each hashing P/R curve as <path>/<name>.npy
# (np.save appends the ".npy" extension itself).
for curve_name, curve in (('P_i2t', p_i2t), ('R_i2t', r_i2t),
                          ('P_t2i', p_t2i), ('R_t2i', r_t2i)):
    np.save(os.path.join(path, curve_name), curve.numpy())
import numpy as np
import pandas as pd
import os

# Set `path` to the full path of one saved .npy file (e.g. .../P_i2t.npy).
npfile = np.load(path)
# Wrap the array in a DataFrame so pandas can write it as CSV.
np_to_csv = pd.DataFrame(data=npfile)
# Write the CSV next to the .npy file with the same base name
# (P_i2t.npy -> P_i2t.csv). The DataFrame index becomes the first CSV column.
np_to_csv.to_csv(os.path.splitext(path)[0] + '.csv')
import matplotlib.pyplot as plt
import pandas as pd
import os

# Directories holding each method's CSV curves -- replace the placeholders.
# os.path.join avoids backslash escapes such as "\r" inside 'path\r_i2t.csv'.
a_dir = r'path'
b_dir = r'path'  # TODO: point at method b's results; as written both curves read the same files
out_dir = r'path'

# usecols=[1]: column 0 of each CSV is the DataFrame index written by to_csv.
# File names match those saved with np.save ('P_i2t', 'R_i2t').
a_i2t_p = pd.read_csv(os.path.join(a_dir, 'P_i2t.csv'), usecols=[1])
a_i2t_r = pd.read_csv(os.path.join(a_dir, 'R_i2t.csv'), usecols=[1])
b_i2t_p = pd.read_csv(os.path.join(b_dir, 'P_i2t.csv'), usecols=[1])
b_i2t_r = pd.read_csv(os.path.join(b_dir, 'R_i2t.csv'), usecols=[1])

# Draw the P-R curves.
plt.figure(figsize=(5, 5))
plt.grid(linestyle="--")  # dashed background grid
ax = plt.gca()
ax.spines['top'].set_visible(False)  # hide the top border
ax.spines['right'].set_visible(False)  # hide the right border
# markevery = marker spacing, marker = marker shape, linestyle = line style (all optional)
plt.plot(a_i2t_r, a_i2t_p, color="r", label="a", linewidth=1.5, linestyle="-", marker='o', markevery=270)
plt.plot(b_i2t_r, b_i2t_p, color="lightgreen", label="b", linewidth=1.5, linestyle="--", marker='*', markevery=270)
plt.grid(True)
plt.xlim(0, 1)  # x-axis range, adjustable
plt.ylim(0, 1)  # y-axis range, adjustable
plt.xlabel('recall')
plt.ylabel('precision')
# plt.title("Image2Text", fontsize=12, fontweight='bold')  # default font size is 12
plt.legend(loc=0, numpoints=1)
leg = plt.gca().get_legend()
ltext = leg.get_texts()
plt.setp(ltext, fontsize=10, fontweight='bold')  # legend font size and weight
# The curves plotted here are the image->text (i2t) results, so name the file accordingly.
plt.savefig(os.path.join(out_dir, 'Image2Text.png'))
plt.show()