# 本文代码来自 kwea123/nerf_pl(https://github.com/kwea123/nerf_pl)
# Perform model inference to get rgb and raw sigma.
# The rows of xyz_ are run through the model in slices of `chunk` rows
# (presumably to bound peak memory), and the per-slice outputs are
# concatenated back together afterwards.
# NOTE(review): xyz_, dir_embedded, embedding_xyz, model, chunk,
# weights_only, N_rays and N_samples_ come from surrounding code not shown
# here; presumably xyz_ holds N_rays * N_samples_ sample positions, as the
# .view() calls below require — confirm.
B = xyz_.shape[0]  # total number of rows to push through the model
out_chunks = []
for i in range(0, B, chunk):
    # Embed positions by chunk
    xyz_embedded = embedding_xyz(xyz_[i:i + chunk])
    if not weights_only:
        # Append the matching slice of direction embeddings along dim 1.
        xyzdir_embedded = torch.cat([xyz_embedded,
                                     dir_embedded[i:i + chunk]], 1)
    else:
        # Only the position embedding is fed; the model is called with
        # sigma_only=True below.
        xyzdir_embedded = xyz_embedded
    out_chunks.append(model(xyzdir_embedded, sigma_only=weights_only))

# out_chunks is a list of tensors; torch.cat joins them along dim 0.
out = torch.cat(out_chunks, 0)
if weights_only:
    sigmas = out.view(N_rays, N_samples_)
else:
    # Last dim has 4 channels: 3 rgb values followed by 1 sigma.
    rgbsigma = out.view(N_rays, N_samples_, 4)
    rgbs = rgbsigma[..., :3]  # (N_rays, N_samples_, 3)
    sigmas = rgbsigma[..., 3]  # (N_rays, N_samples_)
关于 out = torch.cat(out_chunks, 0) 的注解
注意:虽然这里只出现了一个变量 out_chunks,但它本身是一个由 tensor 组成的 list。torch.cat 要求传入的是 tensor 的序列(如 list 或 tuple);如果直接传入单个 tensor,执行 cat 函数会报错。下面的例子演示了先把两个 tensor 放进 list、再用 cat 拼接:
# Two 2x4 integer tensors with identical contents.
A = torch.tensor([[1, 2, 3, 4], [8, 6, 5, 3]])
B = torch.tensor([[1, 2, 3, 4], [8, 6, 5, 3]])
# torch.cat takes a sequence of tensors, so the inputs are collected in a list.
C = [A, B]
# Concatenate along dim 0 (rows): two 2x4 tensors become one 4x4 tensor.
D = torch.cat(C, dim=0)
print(D)
输出:
tensor([[1, 2, 3, 4],
        [8, 6, 5, 3],
        [1, 2, 3, 4],
        [8, 6, 5, 3]])
class ClassName(SomeModuleName):
    """Template for a class that subclasses SomeModuleName.

    ClassName and SomeModuleName appear to be placeholders to be replaced
    with real names.
    """

    def __init__(self, OherParameters):
        # OherParameters is a placeholder for the constructor's own
        # arguments (presumably a typo for "OtherParameters").
        # Call the parent class's __init__. super(ClassName, self) is the
        # explicit two-argument form; inside a method in Python 3 it is
        # equivalent to the zero-argument super().
        super(ClassName, self).__init__()