NeRF-pl代码理解记录

#本文代码来自 kewa123/nerf-pl

将数据分割成好多个chunks,逐chunk将数据丢进model,逐个输出结果,最后将结果拼接起来

  • B:所有 batch 的数目,大小等于 N_rays*N_samples
  • weights_only : 如果为 true,则没有 用 weights 计算出RGB的过程。
  • out_chunks += [model(xyzdir_embedded, sigma_only=weights_only)] 是将 数据丢进 model,然后拼接起来
       # Perform model inference to get rgb and raw sigma
        B = xyz_.shape[0]
        out_chunks = []
        for i in range(0, B, chunk):
            # Embed positions by chunk
            xyz_embedded = embedding_xyz(xyz_[i:i+chunk])
            if not weights_only:
                xyzdir_embedded = torch.cat([xyz_embedded,
                                             dir_embedded[i:i+chunk]], 1)
            else:
                xyzdir_embedded = xyz_embedded
            out_chunks += [model(xyzdir_embedded, sigma_only=weights_only)]

        out = torch.cat(out_chunks, 0)
        if weights_only:
            sigmas = out.view(N_rays, N_samples_)
        else:
            rgbsigma = out.view(N_rays, N_samples_, 4)
            rgbs = rgbsigma[..., :3] # (N_rays, N_samples_, 3)
            sigmas = rgbsigma[..., 3] # (N_rays, N_samples_)

关于 out = torch.cat(out_chunks, 0) 的注解
注意虽然只出现了一个变量 out_chunks,但是它本身是 tensor 的 list。
如果用一个 tensor,执行 cat 函数会报错

A = torch.tensor([[1,2,3,4],[8,6,5,3]])
B = torch.tensor([[1,2,3,4],[8,6,5,3]])
C = [A ,B]
D = torch.cat(C,0)
print(D)

输出:

tensor([[1, 2, 3, 4],
        [8, 6, 5, 3],
        [1, 2, 3, 4],
        [8, 6, 5, 3]])

通用做法

class ClassName(SomeModuleName):
    def __init__(self, OherParameters):
        super(ClassName, self).__init__()

你可能感兴趣的:(计算机视觉/图形学,python,深度学习,pytorch)