Exploring how to use torch.meshgrid

The function signature of torch.meshgrid is

torch.meshgrid(*tensors, indexing=None)

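indexing is a keyword argument of torch.meshgrid; it accepts 'ij' or 'xy'.

torch.meshgrid generates "grid data", for example the coordinates of every pixel in an image.

This article uses an image of height 4 and width 7 (H=4, W=7) as the running example.

How the return values differ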

import torch

H = 4
W = 7

H_arange = torch.arange(0, H)
W_arange = torch.arange(0, W)

print("\nH: 4,    W:7\n")
# 'ij' indexing: the output shape follows the input order, here [W, H]
grid_i, grid_j = torch.meshgrid(W_arange, H_arange, indexing='ij')
print("grid_i shape: ",grid_i.shape,"      grid_j shape: ",grid_j.shape)
print("grid_i:\n",grid_i,"\n","grid_j:\n",grid_j,'\n')


# 'xy' indexing: the first two output dimensions are swapped, here [H, W]
grid_x, grid_y = torch.meshgrid(W_arange, H_arange, indexing='xy')
print("grid_x shape: ",grid_x.shape,"      grid_y shape: ",grid_y.shape)
print("grid_x:\n",grid_x,"\n","grid_y:\n",grid_y,'\n')

indexing='ij'

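With indexing='ij', the returned shapes are:
grid_i shape: torch.Size([7, 4])
grid_j shape: torch.Size([7, 4])

Both grids have shape [W, H]:
dimension 0 is W and dimension 1 is H.
This matches the order in which W_arange and H_arange were passed to torch.meshgrid.

The corresponding output is: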


grid_i:
 tensor([[0, 0, 0, 0],
        [1, 1, 1, 1],
        [2, 2, 2, 2],
        [3, 3, 3, 3],
        [4, 4, 4, 4],
        [5, 5, 5, 5],
        [6, 6, 6, 6]]) 
 grid_j:
 tensor([[0, 1, 2, 3],
        [0, 1, 2, 3],
        [0, 1, 2, 3],
        [0, 1, 2, 3],
        [0, 1, 2, 3],
        [0, 1, 2, 3],
        [0, 1, 2, 3]]) 
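
One way to read the 'ij' result is that each output grid is simply the corresponding input broadcast to the full shape [W, H]. The sketch below checks this for the running example; the names expected_i and expected_j are introduced here for illustration only and are not part of the original code.

```python
import torch

H, W = 4, 7
W_arange = torch.arange(0, W)
H_arange = torch.arange(0, H)

grid_i, grid_j = torch.meshgrid(W_arange, H_arange, indexing='ij')

# With 'ij', grid_i[a, b] == W_arange[a] and grid_j[a, b] == H_arange[b].
expected_i = W_arange[:, None].expand(W, H)  # first input repeated along dim 1
expected_j = H_arange[None, :].expand(W, H)  # second input repeated along dim 0

print(torch.equal(grid_i, expected_i))
print(torch.equal(grid_j, expected_j))
```

Both comparisons should hold, which is another way of saying that the first input varies along dimension 0 and the second along dimension 1.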

indexing='xy'

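With indexing='xy', the returned shapes are:
grid_x shape: torch.Size([4, 7])
grid_y shape: torch.Size([4, 7])

Both grids have shape [H, W]:

dimension 0 is H and dimension 1 is W.
This is the reverse of the order in which W_arange and H_arange were passed to torch.meshgrid, because 'xy' swaps the first two output dimensions relative to 'ij'.

The corresponding output is: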

grid_x shape:  torch.Size([4, 7])       grid_y shape:  torch.Size([4, 7])
grid_x:
 tensor([[0, 1, 2, 3, 4, 5, 6],
        [0, 1, 2, 3, 4, 5, 6],
        [0, 1, 2, 3, 4, 5, 6],
        [0, 1, 2, 3, 4, 5, 6]]) 
 grid_y:
 tensor([[0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1],
        [2, 2, 2, 2, 2, 2, 2],
        [3, 3, 3, 3, 3, 3, 3]]) 
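
For two inputs, the relationship between the two modes can be stated compactly: the 'xy' grids are the transposes of the 'ij' grids. A minimal sketch that checks this on the running example (same_x and same_y are illustrative names):

```python
import torch

H, W = 4, 7
W_arange = torch.arange(0, W)
H_arange = torch.arange(0, H)

grid_i, grid_j = torch.meshgrid(W_arange, H_arange, indexing='ij')  # shape [W, H]
grid_x, grid_y = torch.meshgrid(W_arange, H_arange, indexing='xy')  # shape [H, W]

# Swapping dims 0 and 1 of the 'ij' grids should reproduce the 'xy' grids.
same_x = torch.equal(grid_x, grid_i.transpose(0, 1))
same_y = torch.equal(grid_y, grid_j.transpose(0, 1))
print(same_x, same_y)
```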

Generating pixel coordinates from the grids

indexing='ij'

Code:

# OpenCV Convention: uv first column then row
coords_ij_col_row = torch.stack([grid_i,grid_j], dim=-1).reshape(-1, 2)
print(coords_ij_col_row)

# Matrix Index Convention: first row then column
coords_ij_row_col = torch.stack([grid_j,grid_i], dim=-1).reshape(-1, 2)
print(coords_ij_row_col)

Output:

# OpenCV Convention: uv first column then row
tensor([[0, 0],
        [0, 1],
        [0, 2],
        [0, 3],
       ......
        [6, 0],
        [6, 1],
        [6, 2],
        [6, 3]]) 
# Matrix Index Convention: first row then column
tensor([[0, 0],
        [1, 0],
        [2, 0],
        [3, 0],
        ......
        [0, 6],
        [1, 6],
        [2, 6],
        [3, 6]])
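
The order in which the 'ij' coordinates are listed can be checked numerically. In the sketch below, which assumes the same H=4, W=7 setup, row k of coords_ij_col_row should be the pair (u, v) with k = u * H + v, i.e. the pixels are enumerated column by column. The names u, v and flat_index are illustrative.

```python
import torch

H, W = 4, 7
grid_i, grid_j = torch.meshgrid(torch.arange(0, W), torch.arange(0, H), indexing='ij')

# (column, row) pairs, flattened from a [W, H, 2] tensor
coords_ij_col_row = torch.stack([grid_i, grid_j], dim=-1).reshape(-1, 2)

u = coords_ij_col_row[:, 0]  # column index
v = coords_ij_col_row[:, 1]  # row index

# Column-by-column enumeration: position in the list equals u * H + v.
flat_index = u * H + v
print(torch.equal(flat_index, torch.arange(H * W)))
```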


indexing='xy'

# OpenCV Convention: uv first column then row
coords_xy_col_row = torch.stack([grid_x,grid_y], dim=-1).reshape(-1, 2)
print(coords_xy_col_row)

# Matrix Index Convention: first row then column
coords_xy_row_col = torch.stack([grid_y,grid_x], dim=-1).reshape(-1, 2)
print(coords_xy_row_col)

Output:

# OpenCV Convention: uv first column then row
tensor([[0, 0],
        [1, 0],
        [2, 0],
        [3, 0],
        [4, 0],
        [5, 0],
        [6, 0],
        ......
        [0, 3],
        [1, 3],
        [2, 3],
        [3, 3],
        [4, 3],
        [5, 3],
        [6, 3]])

# Matrix Index Convention: first row then column
tensor([[0, 0],
        [0, 1],
        [0, 2],
        [0, 3],
        [0, 4],
        [0, 5],
        [0, 6],
......
        [3, 0],
        [3, 1],
        [3, 2],
        [3, 3],
        [3, 4],
        [3, 5],
        [3, 6]])
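
A practical consequence of the 'xy' ordering is that the (row, column) list walks the image row by row, the same order as flattening an [H, W] tensor. The sketch below illustrates this with a stand-in image made of consecutive integers; image and pixels are hypothetical names used only for this example.

```python
import torch

H, W = 4, 7
grid_x, grid_y = torch.meshgrid(torch.arange(0, W), torch.arange(0, H), indexing='xy')

# (row, column) pairs, flattened from an [H, W, 2] tensor
coords_xy_row_col = torch.stack([grid_y, grid_x], dim=-1).reshape(-1, 2)

# Stand-in "image": pixel value equals its row-major position.
image = torch.arange(H * W).reshape(H, W)

# Gather pixels in the order given by the coordinate list.
rows = coords_xy_row_col[:, 0]
cols = coords_xy_row_col[:, 1]
pixels = image[rows, cols]

print(torch.equal(pixels, image.reshape(-1)))
```

If the comparison holds, the coordinate list can be used interchangeably with a flattened view of the image.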


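If you do not set the indexing argument explicitly, PyTorch defaults to indexing='ij'.

Omitting it does not raise an error; in the PyTorch versions that print the message below, it only emits a warning: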

UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument.
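
The default behaviour can be compared against an explicit indexing='ij' call. This is a sketch under the assumption that the installed PyTorch still accepts a call without indexing; as the warning text says, a later release may require the argument. Whether the warning is captured by Python's warnings module may also depend on the version, so the code only reports how many were recorded.

```python
import warnings
import torch

a = torch.arange(0, 7)
b = torch.arange(0, 4)

# Call without the indexing argument and record any warnings raised.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    default_0, default_1 = torch.meshgrid(a, b)

explicit_0, explicit_1 = torch.meshgrid(a, b, indexing='ij')

print("warnings recorded:", len(caught))
print(torch.equal(default_0, explicit_0), torch.equal(default_1, explicit_1))
```

Passing indexing explicitly avoids the warning and makes the intended convention visible to readers of the code.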

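Summary

To list coordinates as (row, column) pairs in row-by-row order (the matrix-index convention), with the inputs passed as (W_arange, H_arange), choose indexing='xy':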

grid_x, grid_y = torch.meshgrid(W_arange, H_arange ,indexing='xy')
# Matrix Index Convention: first row then column
coords_xy_row_col = torch.stack([grid_y,grid_x], dim=-1).reshape(-1, 2)
print(coords_xy_row_col)

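To list coordinates as (column, row) pairs (the OpenCV (u, v) convention), with the same inputs, choose indexing='ij'; the pairs are then listed column by column: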

grid_i, grid_j = torch.meshgrid(W_arange, H_arange ,indexing='ij') 
# OpenCV Convention: uv first column then row
coords_ij_col_row = torch.stack([grid_i,grid_j], dim=-1).reshape(-1, 2)
print(coords_ij_col_row)
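
The choice of mode interacts with the order of the inputs. As a sketch of an alternative not used in the article, passing (H_arange, W_arange) with indexing='ij' should produce the same row-by-row (row, column) list as the 'xy' recipe; grid_r, grid_c, coords_a and coords_b are illustrative names.

```python
import torch

H, W = 4, 7
H_arange = torch.arange(0, H)
W_arange = torch.arange(0, W)

# Recipe from the summary: inputs (W, H) with 'xy', then stack as (row, column).
grid_x, grid_y = torch.meshgrid(W_arange, H_arange, indexing='xy')
coords_a = torch.stack([grid_y, grid_x], dim=-1).reshape(-1, 2)

# Alternative: inputs (H, W) with 'ij'; the grids are already [H, W].
grid_r, grid_c = torch.meshgrid(H_arange, W_arange, indexing='ij')
coords_b = torch.stack([grid_r, grid_c], dim=-1).reshape(-1, 2)

print(torch.equal(coords_a, coords_b))
```

Either form works; what matters is that the input order, the indexing mode, and the stacking order are chosen together.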
