这个API可用于计算特征图之间的像素级的距离,输入x1维度为[N,C,H,W]
,输入x2的维度为[M,C,H,W]
。可以通过torch.pairwise_distance(x1, x2)
来计算得到像素级距离。
其中要求
N==M
orN==1
orM==1
这个API我在官方文档没有搜到,而是在通过一篇文章的github源码偶然得知,通过自己的尝试从而总结,如有不全面,还请见谅。
已有模板特征T
,其维度为[1,C,H,W]
,想要计算特征图F
(维度为[1, C, H, W]
)与模板特征之间每个像素点(共HxW个像素)的距离。代码示例如下:
t = torch.randn(1,3,3,3)
f = torch.randn(4,3,3,3)
dist_matrix = torch.pairwise_distance(t, f)
print(dist_matrix.shape)
# torch.Size([4, 3, 3])
已有像素级模板特征T
,其维度为[1,C,1,1]
,想要计算特征图F
(维度为[1, C, H, W]
)中每个像素(共HxW个像素)与模板像素特征的距离。代码示例如下:
t = torch.randn(1,3,1,1)
f = torch.randn(4,3,3,3)
dist_matrix = torch.pairwise_distance(t, f)
print(dist_matrix.shape)
# torch.Size([4, 3, 3])
还有许多不同的用法,这里不再叙述
因为没有找到对应的官方文档,因此自己写了一些检测程序。代码如下:
x = torch.from_numpy(np.array([1,1,1])).float().view(-1,3).unsqueeze(-1).unsqueeze(-1)
y = torch.from_numpy(np.array([[[3,3,3],[1,1,1]],
[[1,1,1],[1,1,1]]])).float().permute(2,0,1).unsqueeze(0)
# print(x.shape,'\n',x)
# print(y.shape,'\n',y)
dist_matrix = torch.pairwise_distance(x, y)
print(dist_matrix)
构造x和y,维度上:x为[1,3,1,1]
,y为[1,3,2,2]
。其中y[0,0]与模板像素差距比较大,其它像素位置上与模板像素相同。
输出:
torch.Size([1, 3, 1, 1]) # x.shape
tensor([[[[1.]], # x
[[1.]],
[[1.]]]])
torch.Size([1, 3, 2, 2]) # y.shape
tensor([[[[3., 1.], # y
[1., 1.]],
[[3., 1.],
[1., 1.]],
[[3., 1.],
[1., 1.]]]])
tensor([[[3.4641e+00, 1.7321e-06], # dist_matrix
[1.7321e-06, 1.7321e-06]]])
可以看到除了[0,0]位置上值比较大,其他都接近于0.
x = torch.from_numpy(np.array([[1,1,1], [3,3,3]])).float().view(-1,3).unsqueeze(-1).unsqueeze(-1)
y = torch.from_numpy(np.array([[[3,3,3],[1,1,1]],
[[1,1,1],[1,1,1]]])).float().permute(2,0,1).unsqueeze(0)
print(x.shape,'\n',x)
print(y.shape,'\n',y)
dist_matrix = torch.pairwise_distance(x, y)
print(dist_matrix)
构造x和y,维度上:x为[2,3,1,1]
,y为[1,3,2,2]
。其中y[0,0]与模板像素特征[0]差距比较大,其它像素位置上与模板像素[0]相同,y[0,0]与模板像素特征[1]相同,其它像素位置上与模板像素[1]差距较大。
torch.Size([2, 3, 1, 1]) # x.shape
tensor([[[[1.]], # x
[[1.]],
[[1.]]],
[[[3.]],
[[3.]],
[[3.]]]])
torch.Size([1, 3, 2, 2]) # y.shape
tensor([[[[3., 1.], # y
[1., 1.]],
[[3., 1.],
[1., 1.]],
[[3., 1.],
[1., 1.]]]])
tensor([[[3.4641e+00, 1.7321e-06], # dist_matrix
[1.7321e-06, 1.7321e-06]],
[[1.7321e-06, 3.4641e+00],
[3.4641e+00, 3.4641e+00]]])
可以看到distance_matrix[0]除了[0,0]位置上值比较大,其他都接近于0,而distance_matrix[1]的[0,0]位置上为0。