跨模态搜索——MAP计算过程

跨模态搜索——MAP计算过程

1.汉明距计算:

公式:

解释:不同的位数越多,汉明距越大

代码:

def calc_hamming_dist(B1, B2):
    q = B2.shape[1]   #哈希码位数
    if len(B1.shape) < 2:
        B1 = B1.unsqueeze(0)
    distH = 0.5 * (q - B1.mm(B2.t()))  #计算汉明码距离 公式4
    return distH

输入:

B1: [1, -1, 1, 1]

B2:[1, -1, 1, -1], [-1, -1, 1, -1], [-1, -1, 1, -1], [1, 1, -1, -1], [-1, 1, -1, -1], [1, 1, -1, 1]]

输出:[1., 2., 2., 3., 4., 2.]

2. MAP计算

假设原始相关向量为:[1,1,0,0,1,0](1表示与它相关,0表示与它不相关),那么一共三个相关的。按照预测的哈希进行汉明码排序后,为[0,0,1,1,1,0]。则MAP=(1/3+2/4+3/5)/3=0.4778。再例,原始相关向量为:[1,1,0,0,1,0]。按照预测的哈希进行汉明码排序后,为[1,0,1,1,0]。则MAP=(1/1+2/3+3/4)/3=0.9167。

如果k为2,原始相关向量为[1,1,0,0,1,0],重排后为[1,0,1,1,0],则MAP=(1/1+2/3)/2=0.8333。

def calc_map_k(qB, rB, query_label, retrieval_label, k=None):
    # qB:查询集  范围{-1,+1}
    # rB:检索集  范围{-1,+1}
    # query_label: 查询标签
    # retrieval_label: 检索标签
    num_query = query_label.shape[0]  #查询个数
    map = 0.
    if k is None:
        k = retrieval_label.shape[0]  #如果不指定k,k将是全部检索个数。对于flickr25k数据集,即18015
    for iter in range(num_query):
        #每个查询标签乘以检索标签的转置,只要有相同标签,该位置就是1
        gnd = (query_label[iter].unsqueeze(0).mm(retrieval_label.t()) > 0).type(torch.float).squeeze()
        tsum = torch.sum(gnd)   #真实相关的数据个数
        print("相关个数:",tsum)
        if tsum == 0:
            continue
        hamm = calc_hamming_dist(qB[iter, :], rB)
        _, ind = torch.sort(hamm) #ind :已排序的汉明距,在未排序中的位置
        ind.squeeze_()
        print("原始 gnd:",gnd)
        print("ind    :", ind)
        gnd = gnd[ind]  #按照预测的顺序重排
        print("重排后gnd:", gnd)
        total = min(k, int(tsum))  #取k和tsum的最小值,这句应该没啥用
        #如果有三个相关的,则count是[1,2,3]
        count = torch.arange(1, total + 1).type(torch.float).to(gnd.device)
        #取出重排后非0元素的位置
        tindex = torch.nonzero(gnd)[:total].squeeze().type(torch.float) + 1.0
        print("count:",count)
        print("tindex:",tindex)
        map += torch.mean(count / tindex)
        print("map:",map)
    map = map / num_query
    return map

输出:

相关个数:tensor(3.)
原始 gnd: tensor([1., 1., 0., 0., 1., 0.])
ind    : tensor([5, 3, 4, 0, 1, 2])
重排后gnd: tensor([0., 0., 1., 1., 1., 0.])
count: tensor([1., 2., 3.])
tindex: tensor([3., 4., 5.])
map: tensor(0.4778)

特点:相关性,只要有相同标签的就算。如果有n条相关的数据,你只需要把这n条全部找出来,这n条数据内部的顺序不考虑。

 

 

 

 

你可能感兴趣的:(跨模态检索,python,算法,机器学习)