Reading the NeRF PyTorch code: ray generation and shuffling (shuffle_ray)

Before training NeRF, rays must first be generated:
A single image produces rays of shape (2, 378, 504, 3), computed in get_rays_np.
36 images produce rays of shape (36, 2, 378, 504, 3).

 # [N, ro+rd, H, W, 3]
rays = np.stack([get_rays_np(H, W, K, p) for p in poses[:, :3, :4]], 0) 

The idea behind computing a ray r(t) = o + t*d:
The direction d is obtained by unprojecting pixel (i, j) to a 3D point (Xc, Yc, Zc) on the normalized plane in camera coordinates, then rotating it into world coordinates via the pose:

Xc = (i - cx) / f
Yc = (cy - j) / f
Zc = -1
[Xw, Yw, Zw] = R * [Xc, Yc, Zc]  ## R is the camera-to-world (c->w) rotation matrix

The ray origin o is simply the translation vector t from the pose matrix.
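As a quick sanity check (a minimal sketch with made-up intrinsics and pose, not code from the repo): at the principal point (cx, cy) the camera-frame direction is (0, 0, -1), so the world-frame direction is just the negated third column of R:

import numpy as np

# Hypothetical intrinsics and pose, purely for illustration.
f, cx, cy = 500.0, 252.0, 189.0
R = np.eye(3)                    # camera-to-world rotation
t = np.array([0.0, 0.0, 4.0])    # camera center in world coordinates

i, j = cx, cy                    # the principal point pixel
d_cam = np.array([(i - cx) / f, (cy - j) / f, -1.0])
d_world = R @ d_cam              # -> [0, 0, -1]: looking down the -z axis
o_world = t                      # ray origin = pose translation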


def get_rays_np(H, W, K, c2w):
    # 1. Generate the pixel sample grid (i, j) for the image.
    i, j = np.meshgrid(np.arange(W, dtype=np.float32), np.arange(H, dtype=np.float32), indexing='xy')
    # 2. Unproject each pixel to a direction on the normalized plane in the camera frame.
    dirs = np.stack([(i-K[0][2])/K[0][0], -(j-K[1][2])/K[1][1], -np.ones_like(i)], -1)
    # 3. Rotate ray directions from camera frame to the world frame.
    rays_d = np.sum(dirs[..., np.newaxis, :] * c2w[:3,:3], -1)  # dot product, equals to: [c2w.dot(dir) for dir in dirs]
    # 4. Translate camera frame's origin to the world frame. It is the origin of all rays.
    rays_o = np.broadcast_to(c2w[:3,-1], np.shape(rays_d))
    return rays_o, rays_d
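Called on one of the 378×504 images above, it returns per-pixel origins and directions (a usage sketch; the intrinsics and pose here are placeholders):

H, W = 378, 504
K = np.array([[500., 0., W / 2.],
              [0., 500., H / 2.],
              [0., 0., 1.]])         # placeholder intrinsics
c2w = np.eye(4)[:3, :4]              # placeholder camera-to-world pose

rays_o, rays_d = get_rays_np(H, W, K, c2w)
print(rays_o.shape, rays_d.shape)    # (378, 504, 3) (378, 504, 3)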

Next, gather all the rays: with, say, 32 training images, each at resolution (H, W), there are 32 * H * W rays in total. They are shuffled and then consumed in groups of batch_size = 1024, so one batch is a tensor of shape (1024, 3, 3). The 3 at dim=1 holds the three pieces of information: o + d + rgb (see the sketch below).
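Roughly how train() builds and shuffles this flat ray buffer (a condensed sketch following nerf-pytorch; i_train is the list of training-image indices):

# rays: [N, 2, H, W, 3] (ro+rd), images: [N, H, W, 3]
rays_rgb = np.concatenate([rays, images[:, None]], 1)   # [N, ro+rd+rgb, H, W, 3]
rays_rgb = np.transpose(rays_rgb, [0, 2, 3, 1, 4])      # [N, H, W, ro+rd+rgb, 3]
rays_rgb = np.stack([rays_rgb[i] for i in i_train], 0)  # keep training images only
rays_rgb = np.reshape(rays_rgb, [-1, 3, 3])             # [N_train*H*W, ro+rd+rgb, 3]
rays_rgb = rays_rgb.astype(np.float32)
np.random.shuffle(rays_rgb)                             # shuffle all rays once up front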

Reading the render function:

Inside train(), render is called:

    #####  Core optimization loop  #####
        rgb, disp, acc, extras = render(H, W, K, chunk=args.chunk, rays=batch_rays,verbose=i < 10, retraw=True,**render_kwargs_train)

The code's default batch size is 1024.
batch_rays holds the origin o and direction d for each of the 1024 rays; its shape is (2, 1024, 3).
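How one shuffled batch is split into batch_rays and the ground-truth colors, again as a sketch following the repo's train() (N_rand is the batch size):

batch = rays_rgb[i_batch:i_batch + N_rand]  # [B, 2+1, 3]: o, d, rgb per ray
batch = torch.transpose(batch, 0, 1)        # [2+1, B, 3]
batch_rays, target_s = batch[:2], batch[2]  # rays: (2, B, 3); rgb targets: (B, 3)
i_batch += N_rand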

def batchify_rays(rays_flat, chunk=1024 * 32, **kwargs):
    """Render rays in smaller minibatches to avoid OOM.
    During training batch = 1024 while chunk = 1024 * 32, so rays_flat[i:i + chunk]
    still returns all 1024 rays: rays_flat only has 1024 entries along dim 0.
    """
    all_ret = {}
    for i in range(0, rays_flat.shape[0], chunk):

        ret = render_rays(rays_flat[i:i + chunk], **kwargs)
        for k in ret:
            if k not in all_ret:
                all_ret[k] = []
            all_ret[k].append(ret[k])

    all_ret = {k: torch.cat(all_ret[k], 0) for k in all_ret}
    return all_ret
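A quick check of the loop bounds confirms this: with a training batch the loop body runs once, while a full 378×504 image is split into six chunks.

# Training: 1024 rays, chunk 32768 -> a single pass through render_rays.
list(range(0, 1024, 1024 * 32))        # [0]
# Full-image rendering: 378*504 = 190512 rays -> six chunks.
list(range(0, 378 * 504, 1024 * 32))   # [0, 32768, 65536, 98304, 131072, 163840]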
    
def render(H, W, K, chunk=1024 * 32, rays=None, c2w=None, ndc=True,
           near=0., far=1.,
           use_viewdirs=False, c2w_staticcam=None,
           **kwargs):
    """Render rays
    Args:
      H: int. Height of image in pixels.
      W: int. Width of image in pixels.
      K: array of shape [3, 3]. Camera intrinsics matrix (focal length and principal point).
      chunk: int. Maximum number of rays to process simultaneously. Used to
        control maximum memory usage. Does not affect final results.
      rays: array of shape [2, batch_size, 3]. Ray origin and direction for
        each example in batch.
      c2w: array of shape [3, 4]. Camera-to-world transformation matrix.
      ndc: bool. If True, represent ray origin, direction in NDC coordinates.
      near: float or array of shape [batch_size]. Nearest distance for a ray.
      far: float or array of shape [batch_size]. Farthest distance for a ray.
      use_viewdirs: bool. If True, use viewing direction of a point in space in model.
      c2w_staticcam: array of shape [3, 4]. If not None, use this transformation matrix for 
       camera while using other c2w argument for viewing directions.
    Returns:
      rgb_map: [batch_size, 3]. Predicted RGB values for rays.
      disp_map: [batch_size]. Disparity map. Inverse of depth.
      acc_map: [batch_size]. Accumulated opacity (alpha) along a ray.
      extras: dict with everything returned by render_rays().
    """
    if c2w is not None:
        # special case to render full image
        rays_o, rays_d = get_rays(H, W, K, c2w)
    else:
        # use provided ray batch
        rays_o, rays_d = rays

    if use_viewdirs:
        # provide ray directions as input
        viewdirs = rays_d
        if c2w_staticcam is not None:
            # special case to visualize effect of viewdirs
            rays_o, rays_d = get_rays(H, W, K, c2w_staticcam)

        # normalize the ray directions to unit length
        viewdirs = viewdirs / torch.norm(viewdirs, dim=-1, keepdim=True)
        viewdirs = torch.reshape(viewdirs, [-1, 3]).float()

    sh = rays_d.shape  # [..., 3]
    if ndc:
        # for forward facing scenes
        rays_o, rays_d = ndc_rays(H, W, K[0][0], 1., rays_o, rays_d)

    # Create ray batch
    rays_o = torch.reshape(rays_o, [-1, 3]).float()
    rays_d = torch.reshape(rays_d, [-1, 3]).float()

    near, far = near * torch.ones_like(rays_d[..., :1]), far * torch.ones_like(rays_d[..., :1])
    rays = torch.cat([rays_o, rays_d, near, far], -1)  # (1024, 8): [rays_o, rays_d, near, far]
    if use_viewdirs:
        rays = torch.cat([rays, viewdirs], -1)  # (1024, 11): [rays_o, rays_d, near, far, viewdirs]

    # Render and reshape
    all_ret = batchify_rays(rays, chunk, **kwargs)
    for k in all_ret:
        k_sh = list(sh[:-1]) + list(all_ret[k].shape[1:])
        all_ret[k] = torch.reshape(all_ret[k], k_sh)

    k_extract = ['rgb_map', 'disp_map', 'acc_map']
    ret_list = [all_ret[k] for k in k_extract]
    ret_dict = {k: all_ret[k] for k in all_ret if k not in k_extract}
    return ret_list + [ret_dict]

Importance Sampling:

The sampling here uses Inverse Transform Sampling to produce adaptive (importance) samples: treating the normalized weights as a PDF, we sample on the corresponding CDF, which yields sample points distributed according to the weight PDF.

Reference: https://www.cnblogs.com/heben/p/10908010.html
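A minimal self-contained illustration of inverse transform sampling (not code from the repo): draw from a discrete PDF by inverting its CDF with uniform random numbers.

import numpy as np

rng = np.random.default_rng(0)
weights = np.array([0.1, 0.5, 3.0, 1.0, 0.2])    # unnormalized per-bin weights
pdf = weights / weights.sum()
cdf = np.concatenate([[0.0], np.cumsum(pdf)])    # CDF with a leading zero

u = rng.random(100_000)                          # uniform samples on [0, 1)
idx = np.searchsorted(cdf, u, side='right') - 1  # invert the CDF
counts = np.bincount(idx, minlength=len(pdf)) / u.size
print(np.round(pdf, 3))                          # target distribution
print(np.round(counts, 3))                       # empirical counts match the pdf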

# Hierarchical sampling (section 5.2)
'''
    bins are the midpoints between the coarse sampling intervals; weights are the
    volume-rendering weights the coarse network assigned to the coarse samples.
    Intuition for the inverse transform: U is uniform on (0, 1), so most draws land on
    the steep part of the CDF (large span in y), which is exactly where the PDF (the
    weights) is large. Uniformly sampling the inverse of the CDF therefore produces
    samples distributed according to the PDF.
    Reference: https://www.cnblogs.com/heben/p/10908010.html
'''
def sample_pdf(bins, weights, N_samples, det=False, pytest=False):
    # Get pdf
    weights = weights + 1e-5  # prevent nans
    pdf = weights / torch.sum(weights, -1, keepdim=True)  # normalize each ray's sample weights into a PDF
    # Accumulate the PDF to obtain the CDF, prepending 0 as its starting value.
    cdf = torch.cumsum(pdf, -1)
    cdf = torch.cat([torch.zeros_like(cdf[...,:1]), cdf], -1)  # (batch, len(bins))

    # Take uniform samples
    if det:
        u = torch.linspace(0., 1., steps=N_samples)
        u = u.expand(list(cdf.shape[:-1]) + [N_samples])
    else:
        u = torch.rand(list(cdf.shape[:-1]) + [N_samples])  # (1024, N_samples), each entry uniform on [0, 1)

    # Pytest, overwrite u with numpy's fixed random numbers
    if pytest:
        np.random.seed(0)
        new_shape = list(cdf.shape[:-1]) + [N_samples]
        if det:
            u = np.linspace(0., 1., N_samples)
            u = np.broadcast_to(u, new_shape)
        else:
            u = np.random.rand(*new_shape)
        u = torch.Tensor(u)

    '''
    Each ray carries len(bins) CDF values, one per bin edge; the bins are the intervals
    between the coarse network's sample points. The goal is to draw N_samples new points
    from these bins for the fine network.
    inds_g = torch.stack([below, above], -1), shape (1024, N_samples, 2): the indices of
    the left and right edges of the interval each u falls into.
    cdf_g holds the CDF values at those left/right edges.
    bins_g holds the depths along the ray at those left/right edges.
    '''
    # Invert the CDF (inverse transform sampling)
    u = u.contiguous()  # u: the (1024, N_samples) uniform samples generated above
    inds = torch.searchsorted(cdf, u, right=True)  # index of the CDF interval each u lands in
    below = torch.max(torch.zeros_like(inds-1), inds-1)
    above = torch.min((cdf.shape[-1]-1) * torch.ones_like(inds), inds)
    inds_g = torch.stack([below, above], -1)  # (1024, N_samples, 2)

    # cdf_g = tf.gather(cdf, inds_g, axis=-1, batch_dims=len(inds_g.shape)-2)
    # bins_g = tf.gather(bins, inds_g, axis=-1, batch_dims=len(inds_g.shape)-2)
    matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]]
    cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g)    # CDF values at the interval edges
    bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g)  # depths at the interval edges

    # Linearly interpolate inside the interval: t is u's fractional position between
    # the two CDF edges, mapped onto the corresponding depth range.
    denom = (cdf_g[...,1]-cdf_g[...,0])
    denom = torch.where(denom<1e-5, torch.ones_like(denom), denom)  # guard against zero-width intervals
    t = (u-cdf_g[...,0])/denom
    samples = bins_g[...,0] + t * (bins_g[...,1]-bins_g[...,0])

    return samples
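In render_rays, sample_pdf is called on the midpoints of the coarse depths to place the fine-network samples (a sketch following the repo; weights come from the coarse pass):

# Midpoints between consecutive coarse depths serve as the bins.
z_vals_mid = .5 * (z_vals[..., 1:] + z_vals[..., :-1])
z_samples = sample_pdf(z_vals_mid, weights[..., 1:-1], N_importance,
                       det=(perturb == 0.), pytest=pytest)
z_samples = z_samples.detach()  # no gradients through the sample locations
# Merge and sort coarse + fine depths before querying the fine network.
z_vals, _ = torch.sort(torch.cat([z_vals, z_samples], -1), -1)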
