解读nerf_pytorch中的get_rays和get_rays_np函数

source code from yenchenlin:
https://github.com/yenchenlin/nerf-pytorch

作者对于numpy的各种操作出神入化,其精炼程度令人叹为观止。本文总结其中两个函数的物理模型意义与(尤其是)矩阵计算意义,作为学习记录。

物理模型意义

详见:
https://zhuanlan.zhihu.com/p/593204605/
中的《3D空间射线怎么构造》。

矩阵计算意义(重点)

比较get_rays和get_rays_np可以发现,前者是在pytorch中、后者实在numpy中的同一操作(所以后者函数名以“np”结尾)。因此我们选择其中一个进行研究即可(get_rays):

def get_rays(H, W, K, c2w):
    i, j = torch.meshgrid(torch.linspace(0, W-1, W), torch.linspace(0, H-1, H))  # pytorch's meshgrid has indexing='ij'
    i = i.t()
    j = j.t()
    dirs = torch.stack([(i-K[0][2])/K[0][0], -(j-K[1][2])/K[1][1], -torch.ones_like(i)], -1)
    # Rotate ray directions from camera frame to the world frame
    rays_d = torch.sum(dirs[..., np.newaxis, :] * c2w[:3,:3], -1)  # dot product, equals to: [c2w.dot(dir) for dir in dirs]
    # Translate camera frame's origin to the world frame. It is the origin of all rays.
    rays_o = c2w[:3,-1].expand(rays_d.shape)
    return rays_o, rays_d

接下来进行我学习花费良久的逐行解释——

输入参数

调用该函数的函数有相关注解:

H: int. Height of image in pixels.
W: int. Width of image in pixels.
c2w: array of shape [3, 4]. Camera-to-world transformation matrix.

K则是一个(3x3)矩阵。

第一行

    i, j = torch.meshgrid(torch.linspace(0, W-1, W), torch.linspace(0, H-1, H)) 

作者给了一句尾注:

pytorch’s meshgrid has indexing=‘ij’

torch.linspace(0,W-1,W)

的意思,从0到W-1取一共W个点,弄成一个行向量。同理,

torch.linspace(0,H-1,H)

从0到W-1取一共W个点,弄成一个行向量。然后把两个放入torch的meshgrid,就可以得到一个以第一个参数为而重复的矩阵,以及一个以第二个参数为而重复的矩阵。注意,这一点和numpy的meshgrid是恰恰相反的(无语)。所以这就解释了第二行和第三行(numpy的计算相对更加符合思考的惯用形式,相对):

i = i.t()
j = j.t()

另外,和numpy的meshgrid相比,torch的meshgrid自带默认的类似前者的“indexing=‘xy’”的功能。

第四行

  dirs = torch.stack([(i-K[0][2])/K[0][0], -(j-K[1][2])/K[1][1], -torch.ones_like(i)], -1)

这里,K[0][2])/K[0][0]只是一个标量,而i是一个(H,W)的矩阵,那这样就意味着广播的介入,因此i-K[0][2])/K[0][0]就是一个(H,W)的矩阵:它意思是每个像素点的横坐标都根据https://zhuanlan.zhihu.com/p/593204605/
中的《3D空间射线怎么构造》的公式计算好了。显然,每个像素点的纵坐标也同样通过一个(H,W)的矩阵 -(j-K[1][2])/K[1][1]得到了。类似地,z坐标的情况则是 -torch.ones_like(i)

好,那么torch.stack在这里是要做什么呢?观察其axis参数为-1。我参考了https://blog.csdn.net/weixin_44201525/article/details/109769214的讲法,特别是:

axis为0,表示它堆叠方向为第0维,堆叠的内容为数组第0维的数据。前面说了第0维是相对于堆叠的数组而言的,而这里数组的第0维其实就是整个3×4的数组(其中第1维为行,第2维为某一行中的一个值,这里有一个层层深入的感觉),所以就是以整个3×4的数组为堆叠内容在第0维上进行堆叠,等到的结果就是一个3×3×4的新数组。再通俗一点,就是将a,b,c分别作为堆叠内容进行堆叠得到3×3×4的输出。

以及

和刚才的解释一样,axis为1表示堆叠的方向为3×4数组的第1维(行),堆叠内容也为3×4数组的第1维的数据。而3×4的数组的第1维就是它的行,以数组a为例,它的堆叠数据分别是[0 1 2 3],[ 4 5 6 7],[ 8 9 10 11]。

意思就是说,根据层层深入的思想,axis=-1,也就是最后一个维度,那么就可以理解为,你通过这个维度之前的一层层维度,深入到了这个维度,然后开始堆叠。

所以我们看,这里dirs就是一个像素点一个像素点地“堆叠”,其中每个像素点的信息就是它的xyz坐标。

第五、六行

    # Rotate ray directions from camera frame to the world frame
    rays_d = torch.sum(dirs[..., np.newaxis, :] * c2w[:3,:3], -1)  # dot product, equals to: [c2w.dot(dir) for dir in dirs]

这里确实十分令人头疼!
如前所述,dirs的维度在上一行应该是(W,H,3),其中最后的3“遍历”每个点的x、y、z。现在,

dirs[...,np.newaxis,:]

可以得到一个(W, H, 1, 3)的矩阵。那么它和c2w[:3,:3]的关系是啥?阅读https://blog.csdn.net/qq_51352578/article/details/125074264 学习numpy的广播机制,可以知道,它这里插入一个新的1维度,可以让逐点乘法*得以完成。但是这也不是乱加的axis。我们要问,这个操作的物理意义是啥?
事实上,答案和上一行的解读类似。就是说,你插入了一个newaxis,那么广播的时候你就自己复制了之前维度的东西。在咱的场景里,这个之前的维度不是别的,正是一个个像素点!事实上,c2w[:3,:3] 即3列分别表达关于x轴、y轴、z轴的信息
(参见 c2w矩阵的值直接描述了相机坐标系的朝向和原点 )。这里的*运算可理解为:


> (。。。)  点 点 点    *    c2w(3,3)
>           口 口 口
> 			口 口 口
> 			口 口 口

然后sum就是按列求和(其中同一个点被案列复制了三遍,这就是加了个newaxis的效果!,有转置的特性)。这也符合作者注释里面说的:

dot product, equals to: [c2w.dot(dir) for dir in dirs]

即每个dir是锁定了横坐标的点坐标数据,然后被c2w左乘。

第七、八行

Translate camera frame's origin to the world frame. It is the origin of all rays.
    rays_o = c2w[:3,-1].expand(rays_d.shape)

参看expand
也就是指定维度的一种广播

第九行

return rays_o, rays_d

不难知道此时返回的两个都是维度为(H,W,3)

你可能感兴趣的:(pytorch,python,深度学习)