行人重识别mAP的源代码

 			def mean_ap(
        	distmat,
        	query_ids=None,
        	gallery_ids=None,
        	query_cams=None,
        	gallery_cams=None,
        	average=True):
    			m, n = distmat.shape
              # Sort and find correct matches
              indices = np.argsort(distmat, axis=1)
              matches = (gallery_ids[indices] == query_ids[:, np.newaxis])
              # Compute AP for each query
              aps = np.zeros(m)
              is_valid_query = np.zeros(m)
              for i in range(m):
                # Filter out the same id and same camera
                valid = ((gallery_ids[indices[i]] != query_ids[i]) |
                         (gallery_cams[indices[i]] != query_cams[i]))
               # 去掉那些id相同,并且摄像头相同的样例 !((gallery_ids[indices[i]]==query_ids[i])&(gallery_cams[indices[i]] == 		query_cams[i]))
                y_true = matches[i, valid]
                y_score = -distmat[i][indices[i]][valid]
                if not np.any(y_true): continue
                is_valid_query[i] = 1
                aps[i] = average_precision_score(y_true, y_score)
                #sklearn 当中用来计算AP的函数,y_true 代表真实的标签 y_score代表检索分数
              if len(aps) == 0:
                raise RuntimeError("No valid query")
              if average:
                return float(np.sum(aps)) / np.sum(is_valid_query)
              return aps, is_valid_query
          
	average_precision_score 的计算方式见 scikit-learn 官方文档：https://scikit-learn.org/stable/modules/generated/sklearn.metrics.average_precision_score.html

你可能感兴趣的:(行人重识别)