torch.pairwise_distance(): 计算特征图之间的像素级欧氏距离

文章目录

    • torch.pairwise_distance(x1, x2)
    • 使用示例1
    • 使用示例2
    • 正确性检查
      • 程序1
      • 程序2

torch.pairwise_distance(x1, x2)

这个API可用于计算特征图之间的像素级的距离,输入x1维度为[N,C,H,W],输入x2的维度为[M,C,H,W]。可以通过torch.pairwise_distance(x1, x2)来计算得到像素级距离。

其中要求N==M or N==1 or M==1

这个API我在官方文档没有搜到,而是在通过一篇文章的github源码偶然得知,通过自己的尝试从而总结,如有不全面,还请见谅。

使用示例1

已有模板特征T,其维度为[1,C,H,W],想要计算特征图F(维度为[1, C, H, W])与模板特征之间每个像素点(共HxW个像素)的距离。代码示例如下:

t = torch.randn(1,3,3,3)
f = torch.randn(4,3,3,3)

dist_matrix = torch.pairwise_distance(t, f)
print(dist_matrix.shape)
# torch.Size([4, 3, 3])

使用示例2

已有像素级模板特征T,其维度为[1,C,1,1],想要计算特征图F(维度为[1, C, H, W])中每个像素(共HxW个像素)与模板像素特征的距离。代码示例如下:

	t = torch.randn(1,3,1,1)
	f = torch.randn(4,3,3,3)
	
	dist_matrix = torch.pairwise_distance(t, f)
	print(dist_matrix.shape)
	# torch.Size([4, 3, 3])

还有许多不同的用法,这里不再叙述

正确性检查

因为没有找到对应的官方文档,因此自己写了一些检测程序。代码如下:

程序1

    x = torch.from_numpy(np.array([1,1,1])).float().view(-1,3).unsqueeze(-1).unsqueeze(-1)
    y = torch.from_numpy(np.array([[[3,3,3],[1,1,1]],
                                 [[1,1,1],[1,1,1]]])).float().permute(2,0,1).unsqueeze(0)
    # print(x.shape,'\n',x)
    # print(y.shape,'\n',y)
    dist_matrix = torch.pairwise_distance(x, y)
    print(dist_matrix)

构造x和y,维度上:x为[1,3,1,1],y为[1,3,2,2]。其中y[0,0]与模板像素差距比较大,其它像素位置上与模板像素相同。

输出:

torch.Size([1, 3, 1, 1]) 	# x.shape
 tensor([[[[1.]],	# x

         [[1.]],

         [[1.]]]])
         
torch.Size([1, 3, 2, 2]) 	# y.shape
 tensor([[[[3., 1.],	# y
          [1., 1.]],

         [[3., 1.],
          [1., 1.]],

         [[3., 1.],
          [1., 1.]]]])
          
tensor([[[3.4641e+00, 1.7321e-06],	# dist_matrix
         [1.7321e-06, 1.7321e-06]]])

可以看到除了[0,0]位置上值比较大,其他都接近于0.

程序2

	x = torch.from_numpy(np.array([[1,1,1], [3,3,3]])).float().view(-1,3).unsqueeze(-1).unsqueeze(-1)
    y = torch.from_numpy(np.array([[[3,3,3],[1,1,1]],
                                 [[1,1,1],[1,1,1]]])).float().permute(2,0,1).unsqueeze(0)
    print(x.shape,'\n',x)
    print(y.shape,'\n',y)
    dist_matrix = torch.pairwise_distance(x, y)
    print(dist_matrix)

构造x和y,维度上:x为[2,3,1,1],y为[1,3,2,2]。其中y[0,0]与模板像素特征[0]差距比较大,其它像素位置上与模板像素[0]相同,y[0,0]与模板像素特征[1]相同,其它像素位置上与模板像素[1]差距较大。

torch.Size([2, 3, 1, 1]) 	# x.shape
 tensor([[[[1.]],	# x

         [[1.]],

         [[1.]]],


        [[[3.]],

         [[3.]],

         [[3.]]]])
         
torch.Size([1, 3, 2, 2]) 	# y.shape
 tensor([[[[3., 1.],	# y
          [1., 1.]],

         [[3., 1.],
          [1., 1.]],

         [[3., 1.],
          [1., 1.]]]])
          
tensor([[[3.4641e+00, 1.7321e-06],	# dist_matrix
         [1.7321e-06, 1.7321e-06]],

        [[1.7321e-06, 3.4641e+00],
         [3.4641e+00, 3.4641e+00]]])

可以看到distance_matrix[0]除了[0,0]位置上值比较大,其他都接近于0,而distance_matrix[1]的[0,0]位置上为0。

你可能感兴趣的:(快乐ML/DL,深度学习,计算机视觉,pytorch)