Spearman’s correlation coefficient--斯皮尔曼相关系数pytorch与numpy实现

文章目录

    • Spearman’s correlation介绍
    • Pytorch实现
    • Numpy实现

Spearman’s correlation介绍

斯皮尔曼等级相关(Spearman’s correlation coefficient for ranked data)主要用于解决名称数据和顺序数据相关的问题。适用于两列变量,而且具有等级变量性质具有线性关系的资料。由英国心理学家、统计学家斯皮尔曼根据积差相关的概念推导而来,一些人把斯皮尔曼等级相关看做积差相关的特殊形式。

公式如下:
在这里插入图片描述

Pytorch实现

矩阵运算实现,运行简便快捷,变量名字可自行替换。输入logits即可

def compute_rank_correlation(att, grad_att):
    """
    Function that measures Spearman’s correlation coefficient between target logits and output logits:
    att: [n, m]
    grad_att: [n, m]
    """
    def _rank_correlation_(att_map, att_gd):
        n = torch.tensor(att_map.shape[1])
        upper = 6 * torch.sum((att_gd - att_map).pow(2), dim=1)
        down = n * (n.pow(2) - 1.0)
        return (1.0 - (upper / down)).mean(dim=-1)

    att = att.sort(dim=1)[1]
    grad_att = grad_att.sort(dim=1)[1]
    correlation = _rank_correlation_(att.float(), grad_att.float())
    return correlation

Numpy实现

这里调用函数前,请保证输入的maps都已经转成了rank的形式

def rank_correlation(att_map, att_gd):
	"""
	Function that measures Spearman’s correlation coefficient between target and output:
	"""
	n = att_map.shape[1]
	upper = 6 *np.sum(np.square(att_gd - att_map), axis=-1)
	down = n*(np.square(n)-1)
	return np.mean(1 - (upper/down))

你可能感兴趣的:(pytorch,python,numpy,pytorch)