torch.meshgrid(*tensors)
创建网格坐标
Creates grids of coordinates specified by the 1D inputs in attr:tensors.
This is helpful when you want to visualize data over some range of inputs.
返回:两个矩阵
直接看两个例子理解,第一个展示输入输出,第二个便于理解为什么可以创建网格坐标
第一个
x = torch.tensor([1, 2, 3])
y = torch.tensor([4, 5, 6, 7])
grid_x, grid_y = torch.meshgrid(x, y)
print("grid_x: ", grid_x)
print("grid_y: ", grid_y)
print(torch.equal(torch.cat(tuple(torch.dstack([grid_x, grid_y]))), torch.cartesian_prod(x, y)))
'''
grid_x:
tensor([[1, 1, 1, 1],
[2, 2, 2, 2],
[3, 3, 3, 3]])
grid_y:
tensor([[4, 5, 6, 7],
[4, 5, 6, 7],
[4, 5, 6, 7]])
True
'''
第二个演示绘图
可以看到 x , y x,y x,y就直接定义了其横纵坐标,下面的 x , y , z x,y,z x,y,z均为2维
xs = torch.linspace(-5, 5, steps=100)
ys = torch.linspace(-5, 5, steps=100)
x, y = torch.meshgrid(xs, ys)
z = torch.sin(torch.sqrt(x * x + y * y))
ax = plt.axes(projection='3d')
ax.plot_surface(x.numpy(), y.numpy(), z.numpy())
plt.show()