官方文档
paper-v1
paper-v2
搬运自此处
ray matching
在Nerf的训练中,我们往往需要首先定义一个rays_o(射线的origin)以及rays_d(射线的方向),然后通过取不同的t值(预定义或者其他方式),获取这条射线上的不同的点:
pts = rays_o + rays_d * t # we may probably have 1024 t
然而,在实际情况中,这条射线上的很多点所代表的空间可能是empty的,其对应的density为0,这些点对这条射线的rendering完全没有任何作用。假设我们在训练的时候可以跳过这些空的区域,我们就可以减少每条线上所采样的点,理论上进而加速训练。为了实现这个想法,nerfacc中,我们可以将region of interest
划分成一个个网格(occupancy grid
),每个网格储存这个grid
是occupied or empty
的信息。这样,在训练的时候,假设当前采样的点落在empty
的grid
内,我们就可以忽略这个采样点,加速模型的训练。(具体实现请看文章下半部分的Occupancy Grid)
除此之外,对于一条射线,我们假设它打到了一个物体(比如说一个面墙),那么理论上,我们就不需要关注这面墙后面的点,因为我们对该条射线所预期的颜色即为墙的颜色,墙后面的信息我们并不清楚,也不重要,理所当然的,也可以省去墙后的点。实现该想法的方法很简单,nerfacc允许我们设定一个 T T T 值(Transmittance)的阈值,比如 T < 1 ∗ e − 4 T<1*e^{−4} T<1∗e−4
。在投射的过程中,我们会计算每一个点的density,从而算出他所对应的 T T T 值,假设该点的 T T T 值小于阈值,代表该射线遇到了遮挡(或者说打到了一个东西),该点之后的其他点可以被省去。
所以,总的来说,nerfacc在算法上,采用跳过空区域与提前终止射线在遮挡区域,通过减少每条线上点的数量,来加速模型的训练。
这里主要是硬件层面如何高效的实现以上算法,读者可以去paper里面细读,这里不做赘述(作者硬件方面是个小白)。
在上文中提到的Occupancy Grid有一个明显的问题,就是当场景变得越来越大的时候,所需要的grid数量也会暴涨,这会给内存带来极大的负担。于是nerfacc采用了Mip-Nerf 360的思想,在查询occupancy grid的时候,通过一个非线性函数将无边界(Unbounded)的大场景map到有限的grid中,从而实现对于大场景,也可以使用occupancy grid加速训练。
GPU优化过的,更快的渲染方式。
下面来讲一下nerfacc库中几个比较重要的类与函数。(以下代码参考了instant-nsr-pl)
这个类就是我们上文中提到的,用来跳过empty区域的occupancy grid,先看一下它一般是怎么定义的:
from nerfacc import ContractionType, OccupancyGrid
# define the bounding box for Region of Interest
self.scene_aabb = torch.as_tensor(
[-self.config.radius, -self.config.radius, -self.config.radius, self.config.radius, self.config.radius,
self.config.radius], dtype=torch.float32)
# define the contraction_type for scene contraction
self.contraction_type = ContractionType.AABB # or ContractionType.UN_BOUNDED_SPHERE, ContractionType.UN_BOUNDED_TANH
# create Occupancy Grid
self.occupancy_grid = OccupancyGrid(
roi_aabb=self.scene_aabb,
resolution=256, # if res is different along different axis, use [256,128,64]
contraction_type=self.contraction_type)
分为三步:
至此,我们已经定义好了我们的Occupancy Grid
的基本特征。接下来,我们需要自己写一个函数,用来更新这个occupancy grid(我们一开始也不知道哪里是空的,所以需要随着模型的训练,同时评估并更新每一个网格的信息)
一个例子:
def occ_eval_fn(x):
density, _ = self.nerf_network(x)
return density * self.render_step_size
self.occupancy_grid.every_n_step(step=global_step, occ_eval_fn=occ_eval_fn)
接下来,我们看看occupancy grid
是如何在训练中被引入的。
这个函数里面实现了上述所说的跳过空区域与提前终止射线在遮挡区域的算法(注意,这个函数对input是不可导的)
例子(参考官方文档):
import torch
from nerfacc import OccupancyGrid, ray_marching, unpack_info
device = "cuda:0"
batch_size = 128
rays_o = torch.rand((batch_size, 3), device=device)
rays_d = torch.randn((batch_size, 3), device=device)
rays_d = rays_d / rays_d.norm(dim=-1, keepdim=True)
# Ray marching with near far plane.
ray_indices, t_starts, t_ends = ray_marching(
rays_o, rays_d, near_plane=0.1, far_plane=1.0, render_step_size=1e-3
)
# Ray marching with aabb.
scene_aabb = torch.tensor([0.0, 0.0, 0.0, 1.0, 1.0, 1.0], device=device)
ray_indices, t_starts, t_ends = ray_marching(
rays_o, rays_d, scene_aabb=scene_aabb, render_step_size=1e-3
)
# Ray marching with per-ray t_min and t_max.
t_min = torch.zeros((batch_size,), device=device)
t_max = torch.ones((batch_size,), device=device)
ray_indices, t_starts, t_ends = ray_marching(
rays_o, rays_d, t_min=t_min, t_max=t_max, render_step_size=1e-3
)
# Ray marching with aabb and skip areas based on occupancy grid.
scene_aabb = torch.tensor([0.0, 0.0, 0.0, 1.0, 1.0, 1.0], device=device)
grid = OccupancyGrid(roi_aabb=[0.0, 0.0, 0.0, 0.5, 0.5, 0.5]).to(device)
ray_indices, t_starts, t_ends = ray_marching(
rays_o, rays_d, scene_aabb=scene_aabb, grid=grid, render_step_size=1e-3
)
可以看到,该函数还是很简单的,大体上为这个思路:
该函数的返回值有三个,分别是ray_indices,t_starts与t_ends
。他们的shape分别为(n_samples,),(n_samples,1)与(n_samples,1)。其中,n_samples代表在这次ray_marching的过程中,我们一个取了多少点(包括所有的ray)。
打个比方,我们一共射出三条线,通过最开始提到的优化算法,其中一条取了10个点,另外两条都取了5个点,那么这个n_samples就是10+5+5=20,而这个ray_indices就是表示每一个点是属于这三条线的哪一条(0,1,2)。t_starts与t_ends则是表示各点之间的间隔,我们可以通过以下代码获得最后取点的具体坐标:
# Convert t_starts and t_ends to sample locations.
t_mid = (t_starts + t_ends) / 2.0
sample_locs = rays_o[ray_indices] + t_mid * rays_d[ray_indices]
有了这些点后,我们就可以进行最后的rendering,nerfacc为我们提供了优化GPU加速后的渲染函数。
例子(来自官方文档):
rays_o = torch.rand((128, 3), device="cuda:0")
rays_d = torch.randn((128, 3), device="cuda:0")
rays_d = rays_d / rays_d.norm(dim=-1, keepdim=True)
ray_indices, t_starts, t_ends = ray_marching(
rays_o, rays_d, near_plane=0.1, far_plane=1.0, render_step_size=1e-3)
def rgb_sigma_fn(t_starts, t_ends, ray_indices):
# This is a dummy function that returns random values.
rgbs = torch.rand((t_starts.shape[0], 3), device="cuda:0")
sigmas = torch.rand((t_starts.shape[0], 1), device="cuda:0")
return rgbs, sigmas
colors, opacities, depths = rendering(
t_starts, t_ends, ray_indices, n_rays=128, rgb_sigma_fn=rgb_sigma_fn)
print(colors.shape, opacities.shape, depths.shape)
#torch.Size([128, 3]) torch.Size([128, 1]) torch.Size([128, 1])
可见,对于rendering函数,我们需要提供在上一步ray_marching中获得的t_starts, t_end, ray_indices与射线的数量以及一个rgb_sigma_fn。
这个rgb_sigma_fn是什么呢?很简单,就是一个query函数,即输入t_starts, t_end, ray_indices,获得各点的rgb与density。
To be more specific, 这个函数可以这样写(来自instant-nsr-pl):
def rgb_sigma_fn(t_starts, t_ends, ray_indices):
ray_indices = ray_indices.long()
t_origins = rays_o[ray_indices]
t_dirs = rays_d[ray_indices]
positions = t_origins + t_dirs * (t_starts + t_ends) / 2. #获得每个点的坐标
density, feature = self.geometry(positions)
rgb = self.texture(feature, t_dirs) # 输入到函数中,获得rgb与density
return rgb, density
此外,如果我们去看这个rendering函数的源代码,我们发现它其实是一连串nerfacc提供的python API所构建而成。我们如果需要render出其他不同的output(例如深度图的variance),便可以不直接采用rendering,而是采用调用python api的方法。
下面是rendering函数的另一种写法(简要版)
rgbs, sigmas = rgb_sigma_fn(t_starts, t_ends, ray_indices) #计算rgb与sigma(density)
weights = render_weight_from_density(
t_starts,
t_ends,
sigmas,
ray_indices=ray_indices,
n_rays=n_rays,
) # 通过调用nerfacc的API,计算出每个点对应的weights
#通过累计不同的value,获得不同的output
colors = accumulate_along_rays(
weights, ray_indices, values=rgbs, n_rays=n_rays)
opacities = accumulate_along_rays(
weights, ray_indices, values=None, n_rays=n_rays)
depths = accumulate_along_rays(
weights,
ray_indices,
values=(t_starts + t_ends) / 2.0,
n_rays=n_rays,)
具体的API可以见python API,里面提供了非常灵活的函数,可以让我们根据自己的网络,去render不同的output。
最后,获得的colors/opacities/depths就可以去做loss,然后更新网络啦!