Pytorch计算距离(例如欧式距离)torch.nn.PairwiseDistance

通常,我们计算欧式距离,例如[0,0]到[1,1]的距离为 2 \sqrt2 2

pdist = nn.PairwiseDistance(p=2)#p=2就是计算欧氏距离,p=1就是曼哈顿距离,例如上面的例子,距离是1.
input1 = torch.randn(100, 128)
input2 = torch.randn(100, 128)#上面两个形状要一样。
output = pdist(input1,input2)#计算各自每一行之间的欧式距离。
output.shape

torch.Size([100])

注意:

  1. 两个输入形状要一样。因为行之间需要相互计算距离。
  2. 形状必须是[N,D],不能是[D],后者需要改成[1,D]否则报错。

你可能感兴趣的:(Pytorch深入理解与实战,pytorch)