公式:
解释:不同的位数越多,汉明距越大
代码:
def calc_hamming_dist(B1, B2):
q = B2.shape[1] #哈希码位数
if len(B1.shape) < 2:
B1 = B1.unsqueeze(0)
distH = 0.5 * (q - B1.mm(B2.t())) #计算汉明码距离 公式4
return distH
输入:
B1: [1, -1, 1, 1]
B2:[1, -1, 1, -1], [-1, -1, 1, -1], [-1, -1, 1, -1], [1, 1, -1, -1], [-1, 1, -1, -1], [1, 1, -1, 1]]
输出:[1., 2., 2., 3., 4., 2.]
假设原始相关向量为:[1,1,0,0,1,0](1表示与它相关,0表示与它不相关),那么一共三个相关的。按照预测的哈希进行汉明码排序后,为[0,0,1,1,1,0]。则MAP=(1/3+2/4+3/5)/3=0.4778。再例,原始相关向量为:[1,1,0,0,1,0]。按照预测的哈希进行汉明码排序后,为[1,0,1,1,0]。则MAP=(1/1+2/3+3/4)/3=0.9167。
如果k为2,原始相关向量为[1,1,0,0,1,0],重排后为[1,0,1,1,0],则MAP=(1/1+2/3)/2=0.8333。
def calc_map_k(qB, rB, query_label, retrieval_label, k=None):
# qB:查询集 范围{-1,+1}
# rB:检索集 范围{-1,+1}
# query_label: 查询标签
# retrieval_label: 检索标签
num_query = query_label.shape[0] #查询个数
map = 0.
if k is None:
k = retrieval_label.shape[0] #如果不指定k,k将是全部检索个数。对于flickr25k数据集,即18015
for iter in range(num_query):
#每个查询标签乘以检索标签的转置,只要有相同标签,该位置就是1
gnd = (query_label[iter].unsqueeze(0).mm(retrieval_label.t()) > 0).type(torch.float).squeeze()
tsum = torch.sum(gnd) #真实相关的数据个数
print("相关个数:",tsum)
if tsum == 0:
continue
hamm = calc_hamming_dist(qB[iter, :], rB)
_, ind = torch.sort(hamm) #ind :已排序的汉明距,在未排序中的位置
ind.squeeze_()
print("原始 gnd:",gnd)
print("ind :", ind)
gnd = gnd[ind] #按照预测的顺序重排
print("重排后gnd:", gnd)
total = min(k, int(tsum)) #取k和tsum的最小值,这句应该没啥用
#如果有三个相关的,则count是[1,2,3]
count = torch.arange(1, total + 1).type(torch.float).to(gnd.device)
#取出重排后非0元素的位置
tindex = torch.nonzero(gnd)[:total].squeeze().type(torch.float) + 1.0
print("count:",count)
print("tindex:",tindex)
map += torch.mean(count / tindex)
print("map:",map)
map = map / num_query
return map
输出:
相关个数:tensor(3.)
原始 gnd: tensor([1., 1., 0., 0., 1., 0.])
ind : tensor([5, 3, 4, 0, 1, 2])
重排后gnd: tensor([0., 0., 1., 1., 1., 0.])
count: tensor([1., 2., 3.])
tindex: tensor([3., 4., 5.])
map: tensor(0.4778)
特点:相关性,只要有相同标签的就算。如果有n条相关的数据,你只需要把这n条全部找出来,这n条数据内部的顺序不考虑。