torch.meshgrid()和np.meshgrid()的区别

np.meshgrid()函数常用于生成二维网格,比如图像的坐标点。
pytorch中也有一个类似的函数torch.meshgrid(),功能也类似,但是两者的用法有区别,使用时需要注意(刚踩坑,因此记录一下。。。)

比如我要生成一张图像(h=6, w=10)的xy坐标点,看下两者的实现方式:

np.meshgrid()

>>> import numpy as np
>>> h = 6
>>> w = 10
>>> xs, ys = np.meshgrid(np.arange(w), np.arange(h))
>>> xs.shape
(6, 10)
>>> ys.shape
(6, 10)
>>> xs
array([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
       [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
       [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
       [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
       [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
       [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]])
>>> ys
array([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
       [2, 2, 2, 2, 2, 2, 2, 2, 2, 2],
       [3, 3, 3, 3, 3, 3, 3, 3, 3, 3],
       [4, 4, 4, 4, 4, 4, 4, 4, 4, 4],
       [5, 5, 5, 5, 5, 5, 5, 5, 5, 5]])
>>> xys = np.stack([xs, ys], axis=-1)
>>> xys.shape
(6, 10, 2)

torch.meshgrid()

>>> import torch
>>> h = 6
>>> w = 10
>>> ys,xs = torch.meshgrid(torch.arange(h), torch.arange(w))
>>> xs.shape
torch.Size([6, 10])
>>> ys.shape
torch.Size([6, 10])
>>> xs
tensor([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]])
>>> ys
tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [2, 2, 2, 2, 2, 2, 2, 2, 2, 2],
        [3, 3, 3, 3, 3, 3, 3, 3, 3, 3],
        [4, 4, 4, 4, 4, 4, 4, 4, 4, 4],
        [5, 5, 5, 5, 5, 5, 5, 5, 5, 5]])
>>> xys = torch.stack([xs, ys], dim=-1)
>>> xys.shape
torch.Size([6, 10, 2])

从python交互式窗口可以清晰的看出numpy和pytorch中meshgrid()函数的区别,就不用文字总结了,自己体会哈哈哈。

你可能感兴趣的:(torch.meshgrid()和np.meshgrid()的区别)