NeRF 源码分析解读(五)

NeRF 源码分析解读(五)

在之前的博客中我们介绍了光线的模拟方法以及如何在光线上生成空间中的 3D 点。有了这些 3D 点的坐标,根据论文中提出的模型,将坐标连同视图方向作为输入,得到每个点对应的颜色 RGB 以及 密度 σ \sigma σ 。然后根据体积渲染公式对这条光线上的点进行累积积分,得到光线的颜色,光线的颜色即对应于相应的像素点的颜色。
NeRF 源码分析解读(五)_第1张图片

我们在 分析解读(四)中得到了光线上 3D 点的位置以及该点对应的颜色:

def render_rays():
	
	...
	# 生成光线上每个采样点的位置
	pts = rays_o[...,None,:] + rays_d[...,None,:] * z_vals[...,:,None]  # [N_rays, N_samples, 3]
	# 将光线上的每个点投入到 MLP 网络 network_fn 中前向传播得到每个点对应的 (RGB,A)
	raw = network_query_fn(pts, viewdirs, network_fn)

	# 对这些离散点进行体积渲染,即进行积分操作
	rgb_map, disp_map, acc_map, weights, depth_map = raw2outputs(raw, z_vals, rays_d, raw_noise_std, white_bkgd, pytest=pytest)

raw2outputs() 函数将离散的点进行积分,得到对应的像素颜色。下面我们对该函数进行分析:

def raw2outputs(raw, z_vals, rays_d, raw_noise_std=0, white_bkgd=False, pytest=False):

	raw2alpha = lambda raw, dists, act_fn=F.relu : 1. - torch.exp(-act_fn(raw) * dists)

	dists = z_vals[..., 1:] - z_vals[..., :-1]  # 计算两点Z轴之间的距离
	dists = torch.cat([dists, torch.Tensor([1e10]).expand(dists[..., :1].shape).to(device)], -1)  # [N_rays, N_samples]
	dists = dists * torch.norm(rays_d[..., None, :], dim=-1)  # 将 Z 轴之间的距离转换为实际距离

	rgb = torch.sigmoid(raw[..., :3])  # [N_rays, N_samples, 3]  每个点的 RGB 值

	...

	alpha = raw2alpha(raw[..., 3] + noise, dists)  # [N_rays, N_samples]  即透明度

	weights = alpha * torch.cumprod(torch.cat([torch.ones((alpha.shape[0], 1)), 1.-alpha + 1e-10], -1), -1)[:, :-1] 
	
	rgb_map = torch.sum(weights[...,None] * rgb, -2)  # [N_rays, 3]

	...

可以看到我们首先定义了一个匿名函数 raw2alpha 。论文中提到,空间中离散点的体积渲染公式表示为:
C ^ ( r ) = ∑ i = 1 N T i ( 1 − exp ⁡ ( − σ i δ i ) ) c i ,  where  T i = exp ⁡ ( − ∑ j = 1 i − 1 σ j δ j ) \hat{C}(\mathbf{r})=\sum_{i=1}^{N} T_{i}\left(1-\exp \left(-\sigma_{i} \delta_{i}\right)\right) \mathbf{c}_{i}, \text { where } T_{i}=\exp \left(-\sum_{j=1}^{i-1} \sigma_{j} \delta_{j}\right) C^(r)=i=1NTi(1exp(σiδi))ci, where Ti=exp(j=1i1σjδj)
raw2alpha 代表体渲染公式中的 1 − e x p ( − σ ∗ δ ) 1 - exp(-\sigma * \delta ) 1exp(σδ) 计算每个点的透明度。那么 1-alpha 即代表 1 − ( 1 − e x p ( − σ ∗ δ ) ) = e x p ( − σ ∗ δ ) 1-(1 - exp(-\sigma * \delta )) = exp(-\sigma * \delta ) 1(1exp(σδ))=exp(σδ)。如此一来 代码中的

torch.cumprod(torch.cat([torch.ones((alpha.shape[0], 1)), 1.-alpha + 1e-10], -1), -1)

即代表公式中的 T i T_i Ti ,即 weights 实际上代表的是渲染公式中的 T i ( 1 − exp ⁡ ( − σ i δ i ) ) T_{i}\left(1-\exp \left(-\sigma_{i} \delta_{i}\right)\right) Ti(1exp(σiδi))

这里其实很简单。
首先根据指数函数的性质: T i = exp ⁡ ( − ∑ j = 1 i − 1 σ j δ j ) = ∏ j = 1 i − 1 e x p ( − σ j δ j ) = ∏ j = 1 i − 1 ( 1 − a l p h a ) T_{i}=\exp \left(-\sum_{j=1}^{i-1} \sigma_{j} \delta_{j}\right) = \prod_{j=1}^{i-1} exp(-\sigma_{j} \delta_{j}) = \prod_{j=1}^{i-1}(1-alpha) Ti=exp(j=1i1σjδj)=j=1i1exp(σjδj)=j=1i1(1alpha) 。将论文中的公式略做变换使其能够和 alpha结合使用。

t = torch.cat([torch.ones(alpha.shape[0], 1), 1. - alpha + 1e-10], -1)
T = torch.cumprod(t, -1)  # 公式中的累乘
weights = alpha * T[:, :-1]

这里构造一个变量 t [1, 1-alpha] shape is [N_rays, N_samples+1] 。torch.cumprod(t, -1) 保持第一列不变,后面的列依次累乘前列。又因为公式中计算的是前 i-1 列的累积结果,所以舍去最后一列取 T[:, :-1]
同理rgb_map 代表了最终的渲染颜色,代码中关于其他的几个 disp_map, acc_map,depth_map 作为额外的信息输出,与NeRF原理关联性不大我们不再予以分析。由此我们得到了每个光线对应的像素值的颜色。我们继续对 render_rays() 函数进行分析:

def render_rays():
	
	...
	
	raw = network_query_fn(pts, viewdirs, network_fn)
	rgb_map, disp_map, acc_map, weights, depth_map = raw2outputs(raw, z_vals, rays_d, raw_noise_std, white_bkgd, pytest=pytest)

	# 分层采样的细采样阶段
    if N_importance > 0:
       rgb_map_0, disp_map_0, acc_map_0 = rgb_map, disp_map, acc_map

       z_vals_mid = .5 * (z_vals[...,1:] + z_vals[...,:-1])
       z_samples = sample_pdf(z_vals_mid, weights[...,1:-1], N_importance, det=(perturb==0.), pytest=pytest)  # 根据权重 weight 判断这个点在物体表面附近的概率,重新采样
       z_samples = z_samples.detach()

       z_vals, _ = torch.sort(torch.cat([z_vals, z_samples], -1), -1)
       pts = rays_o[...,None,:] + rays_d[...,None,:] * z_vals[...,:,None] # [N_rays, N_samples + N_importance, 3]  生成新的采样点坐标

       run_fn = network_fn if network_fine is None else network_fine
       #raw = run_network(pts, fn=run_fn)
       raw = network_query_fn(pts, viewdirs, run_fn)  # 生成新采样点的颜色密度

       rgb_map, disp_map, acc_map, weights, depth_map = raw2outputs(raw, z_vals, rays_d, raw_noise_std, white_bkgd, pytest=pytest)  # 生成细化的像素点的颜色

	ret = {'rgb_map' : rgb_map, 'disp_map' : disp_map, 'acc_map' : acc_map}
	...
	return ret

至此,render_rays() 函数分析完毕,那么调用 render_rays() 函数的 batchify_rays() 函数也走向了结束,返回所有光线对应的累积属性,再把这些属性返回给调用 batchify_rays() 函数的 render() 函数,render() 再把这些数据整理,返回给 train()
到此为止,代码中原理向的部分就已经解释完毕了,接来下就是神经网络训练步骤中的计算损失,损失反向传播以及更新梯度了:

def train():
	...
	
	# 4、开始训练
	for i in trange(start, N_iters):
		
		if use_batching:
			...
		else: 
			...
		
		rgb, disp, acc, extras = render(H, W, K, chunk=args.chunk, rays=batch_rays,
                                                verbose=i < 10, retraw=True,
                                                **render_kwargs_train)
		optimizer.zero_grad()
        img_loss = img2mse(rgb, target_s)  # 计算 MSE 损失
        trans = extras['raw'][...,-1]
        loss = img_loss
        psnr = mse2psnr(img_loss)  # 将损失转换为 PSNR 指标

        loss.backward()  # 损失反向传播
        optimizer.step()

        # NOTE: IMPORTANT!
        ###   update learning rate   ###  动态更新学习率
        decay_rate = 0.1
        decay_steps = args.lrate_decay * 1000
        new_lrate = args.lrate * (decay_rate ** (global_step / decay_steps))
        for param_group in optimizer.param_groups:
            param_group['lr'] = new_lrate
        ################################
        ...

至此,NeRF 源码解读完毕。NeRF 源码作为大多数衍生版本的基础,认真理解之后有助于提升其他代码的阅读速度,博客中解释的比较粗浅,有解释不到位的地方读者可以指出。希望能够和各位同学共同讨论NeRF的发展方向,共同进步!

完结撒花~

你可能感兴趣的:(NeRF,计算机视觉,人工智能,深度学习,python)