from scipy.spatial import distance
# 以下两种方式视情况选择
scipy.spatial.distance.pdist()
scipy.spatial.distance.cdist()
在神经网络的训练过程中,应用以上工具包需要把torch.tensor转变成numpy格式再计算,存在两个缺点:一是耗时,格式变来变去,而且从GPU迁移到CPU再返回到GPU;二是会造成梯度丢失。
import torch
import torch.tensor as tensor
import torch.nn.functional as F
a = tensor([[1., 1., 1.],
[2., 2., 2.],
[3., 3., 3.],
[4., 4., 4.]]) #建立tensor
d=F.pdist(a, p=2)
print(d)
"""
输出:tensor([1.7321, 3.4641, 5.1962, 1.7321, 3.4641, 1.7321])
"""
import torch
import torch.tensor as tensor
"""
自定义的距离矩阵函数
"""
def pdists(A, squared = False, eps = 1e-8):
prod = torch.mm(A, A.t())
norm = prod.diag().unsqueeze(1).expand_as(prod)
res = (norm + norm.t() - 2 * prod).clamp(min = 0)
if squared:
return res
else:
res = res.clamp(min = eps).sqrt()
return res
"""应用示例"""
a = tensor([[1., 2., 3.],
[4., 5., 6.],
[7., 8., 9.],
[10., 11., 12.]])
c=pdists(a, squared = False)
print(c)
"""打印结果
tensor([[1.0000e-04, 5.1962e+00, 1.0392e+01, 1.5588e+01],
[5.1962e+00, 1.0000e-04, 5.1962e+00, 1.0392e+01],
[1.0392e+01, 5.1962e+00, 1.0000e-04, 5.1962e+00],
[1.5588e+01, 1.0392e+01, 5.1962e+00, 1.0000e-04]])
"""
在torch.nn.functional.pdist的文档介绍中有这么一句话:
简单翻译:计算输入中每对行向量之间的p范数距离。 这与torch.norm(input[:, None] - input, dim=2, p=p)的对角线以外的上部三角形部分相同。 如果行是连续的,此功能将更快。
这句话暗示:torch.norm函数可用于计算距离矩阵,而且可以选择L1、L2范数或者其他范数。
应用示例:
import torch
import torch.tensor as tensor
a = tensor([[1., 1., 1.],
[2., 2., 2.],
[3., 3., 3.],
[4., 4., 4.]]) #建立tensor
b=torch.norm(a[:, None]-a, dim=2, p=2)
print(b)
"""
tensor([[0.0000, 1.7321, 3.4641, 5.1962],
[1.7321, 0.0000, 1.7321, 3.4641],
[3.4641, 1.7321, 0.0000, 1.7321],
[5.1962, 3.4641, 1.7321, 0.0000]])
"""
对应的,可以把torch.norm封装成新的pdist函数:
import torch
import torch.tensor as tensor
"""函数封装"""
def pdist(a,dim=2, p=2):
dist_matrix = torch.norm(a[:, None]-a, dim, p)
return dist_matrix
import torch
def cosinematrix(A):
prod = torch.mm(A, A.t())#分子
norm = torch.norm(A,p=2,dim=1).unsqueeze(0)#分母
cos = prod.div(torch.mm(norm.t(),norm))
return cos
# 使用
d_matrix=cosinematrix(inputs)
文章参考:pytorch不用for循环计算一个矩阵各行之间的L1 、L2范数距离和余弦距离_小鱼的代码世界-CSDN博客_pytorch计算距离矩阵