torch.meshgrid

torch.meshgrid(*tensors)

  • tensors: 两个一维向量,如果是0维,当作1维处理

创建网格坐标
Creates grids of coordinates specified by the 1D inputs in attr:tensors.
This is helpful when you want to visualize data over some range of inputs.

返回:两个矩阵

  • 第一个矩阵行相同,列是第一个向量的各个元素
  • 第二个矩阵列相同,行是第二个向量的各个元素

直接看两个例子理解,第一个展示输入输出,第二个便于理解为什么可以创建网格坐标

第一个

x = torch.tensor([1, 2, 3])
y = torch.tensor([4, 5, 6, 7])
grid_x, grid_y = torch.meshgrid(x, y)
print("grid_x: ", grid_x)
print("grid_y: ", grid_y)
print(torch.equal(torch.cat(tuple(torch.dstack([grid_x, grid_y]))), torch.cartesian_prod(x, y)))
'''
grid_x:  
tensor([[1, 1, 1, 1],
        [2, 2, 2, 2],
        [3, 3, 3, 3]])
grid_y:  
tensor([[4, 5, 6, 7],
        [4, 5, 6, 7],
        [4, 5, 6, 7]])
True
'''

第二个演示绘图

可以看到 x , y x,y x,y就直接定义了其横纵坐标,下面的 x , y , z x,y,z x,y,z均为2维

xs = torch.linspace(-5, 5, steps=100)
ys = torch.linspace(-5, 5, steps=100)
x, y = torch.meshgrid(xs, ys)
z = torch.sin(torch.sqrt(x * x + y * y))
ax = plt.axes(projection='3d')
ax.plot_surface(x.numpy(), y.numpy(), z.numpy())
plt.show()

torch.meshgrid_第1张图片

你可能感兴趣的:(PyTorch,opencv,pytorch,深度学习)