every blog every motto: You can do more than you think.
记录torch.pairwise_distance
生成代码:
t = torch.randn(1)
f = torch.randn(1)
计算代码,下同,不重复
dist_matrix = torch.pairwise_distance(t, f)
print('t.shape: ', t.shape, ' t: ', t)
print('-' * 50)
print('f.shape: ', f.shape, ' f: ', f)
print('-' * 50)
print('dist_matrix.shape: ', dist_matrix.shape, ' dis_martix: ', dist_matrix)
值为:
r e s = ( a − b ) 2 res = \sqrt{(a-b)^2} res=(a−b)2
即:
( − 1.0594 − 0.4943 ) 2 = 1.5537 \sqrt{(-1.0594 - 0.4943)^2} = 1.5537 (−1.0594−0.4943)2=1.5537
注意: 输出的维度,0维,即一个标量
t = torch.randn(3)
f = torch.randn(3)
计算过程与上述相同,即对应元素相减平方和后再开方
t = torch.randn(2)
f = torch.randn(3)
报错如下
t = torch.randn(1)
f = torch.randn(3)
虽然元素个数不同,但依然可以计算。计算过程:
元素个数为1的元素依次与f中每个元素依次进行之前的计算步骤,即相减平方和后再开方,可自行验证。
说明: 类似进行了numpy中boradcasting操作
元素相同时,对应元素与相减后平方和再开方
元素不相同时,其中一个元素个数为1才可进行计算,否则报错
t = torch.randn(2, 3)
f = torch.randn(2, 3)
对输出进行了调整,
最内维的元素与前面的计算过程类似,即,对应元素相减平方和在开方
现在我们的输出维度是2
t = torch.randn(4, 3)
f = torch.randn(2, 3)
t = torch.randn(1, 3)
f = torch.randn(2, 3)
t = torch.randn(2, 4)
f = torch.randn(2, 3)
t = torch.randn(2, 1)
f = torch.randn(2, 3)
t = torch.randn(2, 5, 2, 3)
f = torch.randn(2, 5, 2, 3)
t = torch.randn(2, 3, 2, 3)
f = torch.randn(2, 5, 2, 3)
t = torch.randn(2, 1, 2, 3)
f = torch.randn(2, 5, 2, 3)
t = torch.randn(2, 5, 2, 4)
f = torch.randn(2, 5, 2, 3)
t = torch.randn(2, 5, 2, 1)
f = torch.randn(2, 5, 2, 3)
同2.3
[1] https://blog.csdn.net/qq_36560894/article/details/112199266#commentBox
[2] https://pytorch.org/docs/stable/generated/torch.nn.PairwiseDistance.html#torch.nn.PairwiseDistance