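The signature of torch.meshgrid is:
torch.meshgrid(*tensors, indexing=None)
indexing is a keyword argument of torch.meshgrid. It takes either 'ij' or 'xy'.
torch.meshgrid generates "grid data", for example the coordinates of every pixel in an image.
This article uses an image of height 4 and width 7 (H=4, W=7) as its running example.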
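To make the two modes concrete before looking at real tensors, here is a minimal pure-Python sketch of what torch.meshgrid computes for two 1-D inputs. The helper `meshgrid2` is hypothetical (it is not part of PyTorch) and returns nested lists instead of tensors. It only mirrors the indexing rules described in this article.

```python
def meshgrid2(a, b, indexing="ij"):
    """Hypothetical pure-Python stand-in for torch.meshgrid(a, b, indexing=...).

    'ij': both outputs have len(a) rows and len(b) columns;
          out_a[r][c] == a[r], out_b[r][c] == b[c].
    'xy': both outputs have len(b) rows and len(a) columns;
          out_a[r][c] == a[c], out_b[r][c] == b[r].
    """
    if indexing == "ij":
        out_a = [[x for _ in b] for x in a]   # a varies down the rows
        out_b = [list(b) for _ in a]          # b varies along the columns
    elif indexing == "xy":
        out_a = [list(a) for _ in b]          # a varies along the columns
        out_b = [[y for _ in a] for y in b]   # b varies down the rows
    else:
        raise ValueError("indexing must be 'ij' or 'xy'")
    return out_a, out_b


H, W = 4, 7
W_arange = list(range(W))
H_arange = list(range(H))

grid_i, grid_j = meshgrid2(W_arange, H_arange, indexing="ij")
grid_x, grid_y = meshgrid2(W_arange, H_arange, indexing="xy")

print(len(grid_i), len(grid_i[0]))  # 7 4  -> shape [W, H]
print(len(grid_x), len(grid_x[0]))  # 4 7  -> shape [H, W]
print(grid_i[2])  # [2, 2, 2, 2]
print(grid_y[1])  # [1, 1, 1, 1, 1, 1, 1]
```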
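import torch

H = 4
W = 7
H_arange = torch.arange(0, H)  # row indices: tensor([0, 1, 2, 3])
W_arange = torch.arange(0, W)  # column indices: tensor([0, 1, ..., 6])
print("\nH: 4, W: 7\n")

# indexing='ij': output shape follows the input order (W_arange, H_arange) -> [W, H]
grid_i, grid_j = torch.meshgrid(W_arange, H_arange, indexing='ij')
print("grid_i shape: ", grid_i.shape, " grid_j shape: ", grid_j.shape)
print("grid_i:\n", grid_i, "\n", "grid_j:\n", grid_j, '\n')

# indexing='xy': the first two output dimensions are swapped -> [H, W]
grid_x, grid_y = torch.meshgrid(W_arange, H_arange, indexing='xy')
print("grid_x shape: ", grid_x.shape, " grid_y shape: ", grid_y.shape)
print("grid_x:\n", grid_x, "\n", "grid_y:\n", grid_y, '\n')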
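With indexing='ij', the returned tensors have these shapes:
grid_i shape: torch.Size([7, 4])
grid_j shape: torch.Size([7, 4])
Both have shape [W, H]: dimension 0 is W and dimension 1 is H.
This matches the order in which W_arange and H_arange were passed to torch.meshgrid.
In other words, grid_i[a, b] = W_arange[a] and grid_j[a, b] = H_arange[b].
The corresponding output is: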
grid_i:
tensor([[0, 0, 0, 0],
[1, 1, 1, 1],
[2, 2, 2, 2],
[3, 3, 3, 3],
[4, 4, 4, 4],
[5, 5, 5, 5],
[6, 6, 6, 6]])
grid_j:
tensor([[0, 1, 2, 3],
[0, 1, 2, 3],
[0, 1, 2, 3],
[0, 1, 2, 3],
[0, 1, 2, 3],
[0, 1, 2, 3],
[0, 1, 2, 3]])
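With indexing='xy', the returned tensors have these shapes:
grid_x shape: torch.Size([4, 7])
grid_y shape: torch.Size([4, 7])
Both have shape [H, W]: dimension 0 is H and dimension 1 is W.
This is the reverse of the order in which W_arange and H_arange were passed to torch.meshgrid.
In other words, grid_x[a, b] = W_arange[b] and grid_y[a, b] = H_arange[a].
The corresponding output is: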
grid_x shape: torch.Size([4, 7]) grid_y shape: torch.Size([4, 7])
grid_x:
tensor([[0, 1, 2, 3, 4, 5, 6],
[0, 1, 2, 3, 4, 5, 6],
[0, 1, 2, 3, 4, 5, 6],
[0, 1, 2, 3, 4, 5, 6]])
grid_y:
tensor([[0, 0, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1, 1],
[2, 2, 2, 2, 2, 2, 2],
[3, 3, 3, 3, 3, 3, 3]])
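The two printouts suggest a simple relationship: for two inputs, each 'xy' output is the transpose of the corresponding 'ij' output (in PyTorch terms, `torch.equal(grid_x, grid_i.T)` should hold). The PyTorch documentation describes the shape rule: 'ij' keeps the input order, while 'xy' makes the first dimension follow the second input and the second dimension follow the first. The pure-Python sketch below checks both points. `transpose` and `meshgrid_shape` are hypothetical helpers, and treating dimensions beyond the second as unchanged under 'xy' is an assumption here that mirrors NumPy's documented behavior.

```python
H, W = 4, 7

# 'ij' grids from (W_arange, H_arange): shape [W, H]
grid_i = [[w for _ in range(H)] for w in range(W)]
grid_j = [list(range(H)) for _ in range(W)]

# 'xy' grids from the same inputs: shape [H, W]
grid_x = [list(range(W)) for _ in range(H)]
grid_y = [[h for _ in range(W)] for h in range(H)]


def transpose(m):
    """Hypothetical helper: swap rows and columns of a rectangular nested list."""
    return [list(col) for col in zip(*m)]


def meshgrid_shape(lengths, indexing="ij"):
    """Hypothetical helper: shape shared by every meshgrid output.

    'ij' keeps the input lengths in order; 'xy' swaps the first two
    (dimensions beyond the second are assumed unchanged, as in NumPy).
    """
    if indexing not in ("ij", "xy"):
        raise ValueError("indexing must be 'ij' or 'xy'")
    shape = list(lengths)
    if indexing == "xy" and len(shape) >= 2:
        shape[0], shape[1] = shape[1], shape[0]
    return tuple(shape)


print(transpose(grid_i) == grid_x, transpose(grid_j) == grid_y)  # True True
print(meshgrid_shape([W, H], "ij"), meshgrid_shape([W, H], "xy"))  # (7, 4) (4, 7)
print(meshgrid_shape([W, H, 3], "xy"))  # (4, 7, 3)
```

For exactly two inputs, swapping the first two dimensions is the same as reversing them, which is why the 'xy' shape reads as the reverse of the input order.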
Code that stacks the 'ij' grids into lists of coordinate pairs:
# OpenCV convention (u, v): column first, then row
coords_ij_col_row = torch.stack([grid_i, grid_j], dim=-1).reshape(-1, 2)
print(coords_ij_col_row)
# Matrix-index convention: row first, then column
coords_ij_row_col = torch.stack([grid_j, grid_i], dim=-1).reshape(-1, 2)
print(coords_ij_row_col)
Output:
# OpenCV convention (u, v): column first, then row
tensor([[0, 0],
[0, 1],
[0, 2],
[0, 3],
......
[6, 0],
[6, 1],
[6, 2],
[6, 3]])
# Matrix-index convention: row first, then column
tensor([[0, 0],
[1, 0],
[2, 0],
[3, 0],
......
[0, 6],
[1, 6],
[2, 6],
[3, 6]])
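Why does the 'ij' listing start with [0, 0], [0, 1], [0, 2], [0, 3]? torch.stack([a, b], dim=-1) pairs up the elements at the same position, and reshape(-1, 2) then reads those pairs in row-major order, so the first index of the [W, H] grid (the column index u) changes slowest. Below is a pure-Python sketch of that pipeline. `stack_last_and_flatten` is a hypothetical helper that mimics the two PyTorch calls on nested lists; it is not a PyTorch function.

```python
def stack_last_and_flatten(a, b):
    """Mimic torch.stack([a, b], dim=-1).reshape(-1, 2) for two equal-shape 2-D nested lists.

    Pairs are emitted in row-major order: the grid's row index changes slowest.
    """
    return [[a[r][c], b[r][c]] for r in range(len(a)) for c in range(len(a[0]))]


H, W = 4, 7
# 'ij' grids built from (W_arange, H_arange): shape [W, H]
grid_i = [[w for _ in range(H)] for w in range(W)]
grid_j = [list(range(H)) for _ in range(W)]

coords_ij_col_row = stack_last_and_flatten(grid_i, grid_j)  # (u, v) = (column, row)
coords_ij_row_col = stack_last_and_flatten(grid_j, grid_i)  # (row, column)

print(coords_ij_col_row[:5])  # [[0, 0], [0, 1], [0, 2], [0, 3], [1, 0]]
print(coords_ij_row_col[:5])  # [[0, 0], [1, 0], [2, 0], [3, 0], [0, 1]]
print(len(coords_ij_col_row))  # 28 == H * W
```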
The same stacking applied to the 'xy' grids:
# OpenCV convention (u, v): column first, then row
coords_xy_col_row = torch.stack([grid_x, grid_y], dim=-1).reshape(-1, 2)
print(coords_xy_col_row)
# Matrix-index convention: row first, then column
coords_xy_row_col = torch.stack([grid_y, grid_x], dim=-1).reshape(-1, 2)
print(coords_xy_row_col)
Output:
# OpenCV convention (u, v): column first, then row
tensor([[0, 0],
[1, 0],
[2, 0],
[3, 0],
[4, 0],
[5, 0],
[6, 0],
......
[0, 3],
[1, 3],
[2, 3],
[3, 3],
[4, 3],
[5, 3],
[6, 3]])
# Matrix-index convention: row first, then column
tensor([[0, 0],
[0, 1],
[0, 2],
[0, 3],
[0, 4],
[0, 5],
[0, 6],
......
[3, 0],
[3, 1],
[3, 2],
[3, 3],
[3, 4],
[3, 5],
[3, 6]])
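Because the 'xy' grids have shape [H, W], flattening them walks the image in raster order (row by row), so the k-th pair has a closed form: row k // W, column k % W. The pure-Python sketch below, using H = 4 and W = 7 from this article, compares that formula against pairs built directly from the grids. The pairing comprehension stands in for torch.stack(..., dim=-1).reshape(-1, 2) and is an illustration, not PyTorch code.

```python
H, W = 4, 7

# 'xy' grids built from (W_arange, H_arange): shape [H, W]
grid_x = [list(range(W)) for _ in range(H)]
grid_y = [[h for _ in range(W)] for h in range(H)]

# Same pairing as torch.stack([...], dim=-1).reshape(-1, 2): row-major over the [H, W] grid
coords_xy_col_row = [[grid_x[r][c], grid_y[r][c]] for r in range(H) for c in range(W)]
coords_xy_row_col = [[grid_y[r][c], grid_x[r][c]] for r in range(H) for c in range(W)]

# Closed form: flat index k corresponds to row k // W and column k % W
formula_col_row = [[k % W, k // W] for k in range(H * W)]
formula_row_col = [list(divmod(k, W)) for k in range(H * W)]

print(coords_xy_col_row == formula_col_row)  # True
print(coords_xy_row_col == formula_row_col)  # True
print(coords_xy_row_col[7])  # [1, 0]: the first pixel of the second row
```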
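If indexing is omitted, PyTorch does not raise an error. Instead it emits a UserWarning and falls back to indexing='ij'. Passing indexing explicitly, as in the code that follows, avoids the warning and makes the grid layout unambiguous: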
UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument.
grid_x, grid_y = torch.meshgrid(W_arange, H_arange, indexing='xy')
# Matrix-index convention: row first, then column
coords_xy_row_col = torch.stack([grid_y, grid_x], dim=-1).reshape(-1, 2)
print(coords_xy_row_col)

grid_i, grid_j = torch.meshgrid(W_arange, H_arange, indexing='ij')
# OpenCV convention (u, v): column first, then row
coords_ij_col_row = torch.stack([grid_i, grid_j], dim=-1).reshape(-1, 2)
print(coords_ij_col_row)
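The two combinations above share one property that the other two pairings in this article lack: the first component of each pair varies slowest, so the coordinate list comes out already sorted. A pure-Python check of all four pairings follows. The `pairs` helper is hypothetical; it mimics torch.stack([a, b], dim=-1).reshape(-1, 2) on nested lists.

```python
H, W = 4, 7


def pairs(a, b):
    """Row-major pairing, as torch.stack([a, b], dim=-1).reshape(-1, 2) would produce."""
    return [[a[r][c], b[r][c]] for r in range(len(a)) for c in range(len(a[0]))]


grid_i = [[w for _ in range(H)] for w in range(W)]  # 'ij', shape [W, H]
grid_j = [list(range(H)) for _ in range(W)]
grid_x = [list(range(W)) for _ in range(H)]         # 'xy', shape [H, W]
grid_y = [[h for _ in range(W)] for h in range(H)]

combos = {
    "ij, (u, v)": pairs(grid_i, grid_j),
    "ij, (row, col)": pairs(grid_j, grid_i),
    "xy, (u, v)": pairs(grid_x, grid_y),
    "xy, (row, col)": pairs(grid_y, grid_x),
}
for name, coords in combos.items():
    # sorted() on a list of [a, b] lists compares pairs lexicographically
    print(name, coords == sorted(coords))
# ij, (u, v) True
# ij, (row, col) False
# xy, (u, v) False
# xy, (row, col) True
```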