三维重建部分本质上是一个 2D 到 3D 的建模过程:利用 3D 点的位置 (x,y,z) 及方位视角 (θ,φ) 作为输入,通过多层感知机(MLP)建模该点对应的颜色 color(c) 及体素密度 volume density(σ),形成 3D 场景的“隐式表示”。
此处2D的输入包括 'llff'、'blender'、'LINEMOD'、'deepvoxels' 四种类型的数据集。其中的llff如前文所讲,是经过colmap处理之后N张图片形成的N*17矩阵,也就是论文中讲的5D输入(不同的数据集可能有不同的输入?没有全部仔细看过)。
# Load the Blender dataset; returns images, poses, render_poses, hwf = [H, W, focal]
# and i_split (index arrays for the train/val/test splits).
images, poses, render_poses, hwf, i_split = load_blender_data(args.datadir, args.half_res, args.testskip)
basedir是基路径,half_res默认为False,同时lego数据集中的默认训练图片规格为 800×800 ,当该参数为True,则需要将图片规格转化为 400×400;
testskip的值为N,则会从测试集和验证集中挑选 1/N数量作为测试集和验证集,相当于跳跃步长。
# Excerpt of load_blender_data (indentation and some lines are not shown in this listing).
def load_blender_data(basedir, half_res=False, testskip=1):
splits = ['train', 'val', 'test']
metas = {}
# Read the per-split metadata file transforms_<split>.json from basedir.
for s in splits:
with open(os.path.join(basedir, 'transforms_{}.json'.format(s)), 'r') as fp:
//meta:'camera_angle_x'相机的水平视场 (horizontal field of view),可以用于算焦距 (focal),'frames'里面有图片路径、图片的相机外参transform_matrix和旋转值rotation(未用到)
metas[s] = json.load(fp)
all_imgs = []
all_poses = []
# counts[k]..counts[k+1] is the index range of split k in the concatenated arrays.
counts = [0]
for s in splits:
meta = metas[s]
imgs = []
poses = []
# NOTE(review): as listed, `skip = testskip` always overwrites `skip = 1`, so the
# train split would also be subsampled; that line presumably belongs to an `else:`.
if s=='train' or testskip==0:
skip = 1
skip = testskip
# Take every skip-th frame of this split.
for frame in meta['frames'][::skip]:
fname = os.path.join(basedir, frame['file_path'] + '.png')
# NOTE(review): no append to imgs/poses (or to all_imgs/all_poses) is shown here;
# as listed those lists stay empty and np.concatenate below would fail.
imgs = (np.array(imgs) / 255.).astype(np.float32) # keep all 4 channels (RGBA)
poses = np.array(poses).astype(np.float32)
counts.append(counts[-1] + imgs.shape[0])
i_split = [np.arange(counts[i], counts[i+1]) for i in range(3)]
imgs = np.concatenate(all_imgs, 0)
poses = np.concatenate(all_poses, 0)
H, W = imgs[0].shape[:2]
# Focal length from the horizontal field of view: f = 0.5 * W / tan(0.5 * camera_angle_x).
camera_angle_x = float(meta['camera_angle_x'])
focal = .5 * W / np.tan(.5 * camera_angle_x)
# 40 render poses: theta swept over [-180, 180) degrees, phi = -30 degrees, radius 4.
render_poses = torch.stack([pose_spherical(angle, -30.0, 4.0) for angle in np.linspace(-180,180,40+1)[:-1]], 0)
def pose_spherical(theta, phi, radius):
    """Build a 4x4 camera-to-world matrix for a camera on a sphere around the origin.

    The camera is first translated by ``radius`` along +z, then rotated by
    ``phi`` degrees about the x axis, then by ``theta`` degrees about the y
    axis, and finally mapped through a fixed axis change (x -> -x, y <-> z).

    Args:
        theta: rotation about the y axis, in degrees.
        phi: rotation about the x axis, in degrees.
        radius: distance of the camera from the origin.

    Returns:
        torch.Tensor of shape (4, 4).
    """
    c2w = trans_t(radius)
    c2w = rot_phi(phi / 180. * np.pi) @ c2w
    c2w = rot_theta(theta / 180. * np.pi) @ c2w
    # Fixed axis change: (x, y, z) -> (-x, z, y).
    c2w = torch.Tensor(np.array([[-1, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 0, 0, 1]])) @ c2w
    return c2w


def trans_t(t):
    """Return the 4x4 homogeneous translation by ``t`` along the z axis."""
    return torch.Tensor([
        [1, 0, 0, 0],
        [0, 1, 0, 0],
        [0, 0, 1, t],
        [0, 0, 0, 1]]).float()


def rot_phi(phi):
    """Return the 4x4 homogeneous rotation by ``phi`` radians about the x axis."""
    return torch.Tensor([
        [1, 0, 0, 0],
        [0, np.cos(phi), -np.sin(phi), 0],
        [0, np.sin(phi), np.cos(phi), 0],
        [0, 0, 0, 1]]).float()


def rot_theta(th):
    """Return the 4x4 homogeneous rotation by ``th`` radians about the y axis."""
    return torch.Tensor([
        [np.cos(th), 0, -np.sin(th), 0],
        [0, 1, 0, 0],
        [np.sin(th), 0, np.cos(th), 0],
        [0, 0, 0, 1]]).float()
# When half_res is set, halve H, W and focal, and downsample every image with
# area interpolation into a 4-channel array.
if half_res:
H = H//2
W = W//2
focal = focal/2.
imgs_half_res = np.zeros((imgs.shape[0], H, W, 4))
for i, img in enumerate(imgs):
# cv2.resize takes the target size as (width, height).
imgs_half_res[i] = cv2.resize(img, (W, H), interpolation=cv2.INTER_AREA)
imgs = imgs_half_res
# imgs = tf.image.resize_area(imgs, [400, 400]).numpy()
# Return images, poses, render poses, [H, W, focal] and the split index arrays.
return imgs, poses, render_poses, [H, W, focal], i_split
images 的形状为 N×H×W×4,poses 为 N×4×4,render_poses 为 40×4×4,hwf 为 [H, W, focal](图像高、宽和焦距),i_split 是训练集、验证集、测试集的索引划分(顺序与 splits = ['train', 'val', 'test'] 一致)。
# Fixed near/far sampling bounds along each ray for this dataset.
near = 2.
far = 6.
# Alpha-composite RGBA images onto a white background: rgb * a + (1 - a).
if args.white_bkgd:
images = images[...,:3]*images[...,-1:] + (1.-images[...,-1:])
# Keep only the RGB channels (a no-op after the compositing above).
images = images[...,:3]
# Cast intrinsics to right types
H, W, focal = hwf
H, W = int(H), int(W)
hwf = [H, W, focal]
# Pinhole intrinsics with the principal point at the image center.
# NOTE(review): the closing `])` of this array literal is not shown in this excerpt.
if K is None:
K = np.array([
[focal, 0, 0.5*W],
[0, focal, 0.5*H],
[0, 0, 1]
接着就到了创建模型,创建模型这里细讲, render_kwargs_train, render_kwargs_test, start, grad_vars, optimizer = create_nerf(args)
NeRF 模型的作用是通过多层感知机(MLP)建模该点对应的颜色 color(c) 及体素密度 volume density(σ),形成 3D 场景的“隐式表示”,具体做法如上图。图中的 Positional encoding 源于作者的发现:让①中的 MLP 网络(F:(x,d) -> (c,σ))直接处理 (x,y,z,θ,φ) 输入,会导致渲染结果在表示颜色和几何形状的高频变化方面表现不佳,这表明深度网络偏向于学习低频函数。因此,在将 (x,y,z,θ,φ) 输入传递给网络之前,先使用高频函数将输入映射到更高维度的空间,可以更好地拟合包含高频变化的数据。
# Positional encoder for 3D sample positions; input_ch is the encoded width.
embed_fn, input_ch = get_embedder(args.multires, args.i_embed)
def get_embedder(multires, i=0):
    """Return a positional-encoding function for 3D inputs and its output width.

    Args:
        multires: number of frequency bands; frequencies are
            2**0, 2**1, ..., 2**(multires - 1).
        i: encoder selector. ``-1`` disables encoding and returns
            ``nn.Identity()`` with width 3; any other value uses ``Embedder``.

    Returns:
        (embed, out_dim): ``embed`` maps a tensor ``[..., 3]`` to
        ``[..., out_dim]`` with ``out_dim = 3 + 3 * 2 * multires``.
    """
    if i == -1:
        return nn.Identity(), 3

    embed_kwargs = {
        'include_input': True,
        'input_dims': 3,
        'max_freq_log2': multires - 1,
        'num_freqs': multires,
        'log_sampling': True,
        'periodic_fns': [torch.sin, torch.cos],
    }

    embedder_obj = Embedder(**embed_kwargs)

    def embed(x, eo=embedder_obj):
        return eo.embed(x)

    return embed, embedder_obj.out_dim


# Positional encoding (section 5.1)
class Embedder:
    """Concatenate the input with periodic features at several frequencies.

    Keyword args (all required):
        include_input: if True, the raw input is the first feature block.
        input_dims: size of the last input dimension.
        max_freq_log2: log2 of the highest frequency.
        num_freqs: number of frequencies.
        log_sampling: if True, frequencies are 2**linspace(0, max_freq_log2);
            otherwise linspace(2**0, 2**max_freq_log2).
        periodic_fns: functions applied to ``x * freq`` for every frequency.
    """

    def __init__(self, **kwargs):
        self.kwargs = kwargs
        # Build embed_fns/out_dim eagerly so `out_dim` is available right after construction.
        self.create_embedding_fn()

    def create_embedding_fn(self):
        """Populate ``self.embed_fns`` and ``self.out_dim`` from ``self.kwargs``."""
        embed_fns = []
        d = self.kwargs['input_dims']
        out_dim = 0
        if self.kwargs['include_input']:
            embed_fns.append(lambda x: x)
            out_dim += d

        max_freq = self.kwargs['max_freq_log2']
        n_freqs = self.kwargs['num_freqs']

        if self.kwargs['log_sampling']:
            # Log-spaced frequencies: 2**0, ..., 2**max_freq.
            freq_bands = 2. ** torch.linspace(0., max_freq, steps=n_freqs)
        else:
            # Linearly spaced frequencies between 2**0 and 2**max_freq.
            freq_bands = torch.linspace(2. ** 0., 2. ** max_freq, steps=n_freqs)

        for freq in freq_bands:
            for p_fn in self.kwargs['periodic_fns']:
                # Bind p_fn/freq as defaults to avoid late-binding closures.
                embed_fns.append(lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq))
                out_dim += d

        self.embed_fns = embed_fns
        self.out_dim = out_dim

    def embed(self, inputs):
        """Apply every embedding function and concatenate along the last dim."""
        return torch.cat([fn(inputs) for fn in self.embed_fns], -1)
这一段就是对位置 (x,y,z) 进行位置编码的编码器函数(视角方向用同一个函数、以 multires_views 作为频率数进行编码),以及模型的输入参数:netdepth 即论文中提到的 8,对应 8 层 MLP;netwidth 是 256;input_ch 是 63(3 + 3×2×10);output_ch 是 5;skips 为 [4],具体作用后文再看。args.use_viewdirs 为 True 表示输入包含方向信息,以 5D 向量作为输入,否则仅以位置信息作为 3D 输入。这里的 N_importance 即论文中在粗网络之后、基于粗网络输出的分布采样的 Nf 个点。粗网络来自 Hierarchical volume sampling 方法。
该部分指出在Volume Rendering中是在每条相机光线上的N个查询点密集地评估神经辐射场网络,这是低效的(仍然重复采样与渲染图像无关的自由空间和遮挡区域),于是提出一种分层体积采样的做法,同时优化一个“粗糙”的网络和一个“精细”的网络。
# View-direction encoder, built only when args.use_viewdirs; otherwise input_ch_views stays 0.
input_ch_views = 0
embeddirs_fn = None
if args.use_viewdirs:
embeddirs_fn, input_ch_views = get_embedder(args.multires_views, args.i_embed)
# 5 output channels when a fine network is used (N_importance > 0), else 4.
output_ch = 5 if args.N_importance > 0 else 4
# Layer indices after which the encoded position is concatenated again (skip connection).
skips = [4]
# Coarse network.
model = NeRF(D=args.netdepth, W=args.netwidth,
input_ch=input_ch, output_ch=output_ch, skips=skips,
input_ch_views=input_ch_views, use_viewdirs=args.use_viewdirs).to(device)
(0): Linear(in_features=63, out_features=256, bias=True)
(1): Linear(in_features=256, out_features=256, bias=True)
(2): Linear(in_features=256, out_features=256, bias=True)
(3): Linear(in_features=256, out_features=256, bias=True)
(4): Linear(in_features=256, out_features=256, bias=True)
(5): Linear(in_features=319, out_features=256, bias=True)
(6): Linear(in_features=256, out_features=256, bias=True)
(7): Linear(in_features=256, out_features=256, bias=True)
class NeRF(nn.Module):
    """NeRF MLP layers mapping encoded positions (and optionally view directions) to RGB + density.

    Args:
        D: number of linear layers applied to the encoded position.
        W: hidden width of those layers.
        input_ch: width of the encoded position input.
        input_ch_views: width of the encoded view-direction input.
        output_ch: width of ``output_linear`` (built only when ``use_viewdirs`` is False).
        skips: indices ``i`` after which the encoded position is concatenated to
            the features, so ``pts_linears[i + 1]`` takes ``W + input_ch`` inputs.
        use_viewdirs: if True, build separate density / feature / RGB heads.
    """

    def __init__(self, D=8, W=256, input_ch=3, input_ch_views=3, output_ch=4, skips=[4], use_viewdirs=False):
        super(NeRF, self).__init__()
        self.D = D
        self.W = W
        self.input_ch = input_ch
        self.input_ch_views = input_ch_views
        self.skips = skips
        self.use_viewdirs = use_viewdirs

        self.pts_linears = nn.ModuleList(
            [nn.Linear(input_ch, W)]
            + [nn.Linear(W, W) if i not in self.skips else nn.Linear(W + input_ch, W) for i in range(D - 1)])

        ### Implementation according to the official code release (https://github.com/bmild/nerf/blob/master/run_nerf_helpers.py#L104-L105)
        self.views_linears = nn.ModuleList([nn.Linear(input_ch_views + W, W // 2)])

        ### Implementation according to the paper
        # self.views_linears = nn.ModuleList(
        #     [nn.Linear(input_ch_views + W, W//2)] + [nn.Linear(W//2, W//2) for i in range(D//2)])

        if use_viewdirs:
            self.feature_linear = nn.Linear(W, W)
            self.alpha_linear = nn.Linear(W, 1)
            self.rgb_linear = nn.Linear(W // 2, 3)
        else:
            # Single joint output head; only needed when view directions are not used.
            self.output_linear = nn.Linear(W, output_ch)
前向推理部分,我们后续得到了输入再讲,在下一步中定义了细网络,lego.txt中给出N_importance是128,即对应 Nf=128 。和之前的网络相比,参数均没有变化
# Parameters of the coarse network, extended below by the fine network's parameters.
grad_vars = list(model.parameters())
model_fine = None
# Fine network (args.netdepth_fine / args.netwidth_fine), only when N_importance > 0.
if args.N_importance > 0:
model_fine = NeRF(D=args.netdepth_fine, W=args.netwidth_fine,
input_ch=input_ch, output_ch=output_ch, skips=skips,
input_ch_views=input_ch_views, use_viewdirs=args.use_viewdirs).to(device)
grad_vars += list(model_fine.parameters())
# Closure that evaluates a given network on (points, view directions) via run_network.
# NOTE(review): the remaining arguments of this call are not shown in this excerpt.
network_query_fn = lambda inputs, viewdirs, network_fn : run_network(inputs, viewdirs, network_fn,
def run_network(inputs, viewdirs, fn, embed_fn, embeddirs_fn, netchunk=1024*64):
    """Prepares inputs and applies network 'fn'.

    Positions are flattened to ``[-1, C]`` and encoded with ``embed_fn``. If
    ``viewdirs`` is given, each ray's direction is broadcast to all of its
    samples (``viewdirs[:, None].expand(inputs.shape)``), encoded with
    ``embeddirs_fn`` and appended to the position features. The network is
    then evaluated in chunks of ``netchunk`` rows.

    Args:
        inputs: sample positions; must be rank 3 ``[N_rays, N_samples, C]``
            when ``viewdirs`` is given (it is expanded to this shape).
        viewdirs: per-ray directions ``[N_rays, C]`` or None.
        fn: network to evaluate on the encoded features.
        embed_fn: encoder for positions.
        embeddirs_fn: encoder for directions (used only if ``viewdirs`` is not None).
        netchunk: maximum rows per network call.

    Returns:
        Tensor of shape ``list(inputs.shape[:-1]) + [out_dim]``.
    """
    inputs_flat = torch.reshape(inputs, [-1, inputs.shape[-1]])
    embedded = embed_fn(inputs_flat)

    if viewdirs is not None:
        input_dirs = viewdirs[:, None].expand(inputs.shape)
        input_dirs_flat = torch.reshape(input_dirs, [-1, input_dirs.shape[-1]])
        embedded_dirs = embeddirs_fn(input_dirs_flat)
        embedded = torch.cat([embedded, embedded_dirs], -1)

    outputs_flat = batchify(fn, netchunk)(embedded)
    outputs = torch.reshape(outputs_flat, list(inputs.shape[:-1]) + [outputs_flat.shape[-1]])
    return outputs
# One Adam optimizer over the parameters of both networks.
optimizer = torch.optim.Adam(params=grad_vars, lr=args.lrate, betas=(0.9, 0.999))
start = 0
basedir = args.basedir
expname = args.expname
# Load checkpoints
# NOTE(review): as listed, the second `ckpts = ...` always overwrites the ft_path choice;
# it presumably belongs to an `else:` branch.
if args.ft_path is not None and args.ft_path!='None':
ckpts = [args.ft_path]
ckpts = [os.path.join(basedir, expname, f) for f in sorted(os.listdir(os.path.join(basedir, expname))) if 'tar' in f]
print('Found ckpts', ckpts)
# Resume from the last checkpoint (sorted by file name) unless no_reload is set.
if len(ckpts) > 0 and not args.no_reload:
ckpt_path = ckpts[-1]
print('Reloading from', ckpt_path)
ckpt = torch.load(ckpt_path)
start = ckpt['global_step']
# Load model
# NOTE(review): the state-dict loading statements and the closing `}` of the dict below
# are not shown in this excerpt.
if model_fine is not None:
render_kwargs_train = {
'network_query_fn' : network_query_fn,# closure: evaluate a given network on 3D points + view directions
'perturb' : args.perturb,# > 0 jitters sample depths within their intervals (stratified sampling)
'N_importance' : args.N_importance,# number of fine samples per ray
'network_fine' : model_fine,# fine network
'N_samples' : args.N_samples,# number of coarse samples per ray
'network_fn' : model,# coarse network
'use_viewdirs' : args.use_viewdirs,# whether view directions are fed to the network
'white_bkgd' : args.white_bkgd,# composite onto a white background
'raw_noise_std' : args.raw_noise_std,# std of Gaussian noise added to the raw density in raw2outputs
# NDC only good for LLFF-style forward facing data
if args.dataset_type != 'llff' or args.no_ndc:
print('Not ndc!')
render_kwargs_train['ndc'] = False
render_kwargs_train['lindisp'] = args.lindisp
# Test-time kwargs: same as training but deterministic (no jitter, no density noise).
render_kwargs_test = {k : render_kwargs_train[k] for k in render_kwargs_train}
render_kwargs_test['perturb'] = False
render_kwargs_test['raw_noise_std'] = 0.
return render_kwargs_train, render_kwargs_test, start, grad_vars, optimizer
{'network_query_fn': . at 0x000002838FAF99D8>, 'perturb': 1.0, 'N_importance': 128, 'network_fine': NeRF(
(pts_linears): ModuleList(
(0): Linear(in_features=63, out_features=256, bias=True)
(1): Linear(in_features=256, out_features=256, bias=True)
(2): Linear(in_features=256, out_features=256, bias=True)
(3): Linear(in_features=256, out_features=256, bias=True)
(4): Linear(in_features=256, out_features=256, bias=True)
(5): Linear(in_features=319, out_features=256, bias=True)
(6): Linear(in_features=256, out_features=256, bias=True)
(7): Linear(in_features=256, out_features=256, bias=True)
(views_linears): ModuleList(
(0): Linear(in_features=283, out_features=128, bias=True)
(feature_linear): Linear(in_features=256, out_features=256, bias=True)
(alpha_linear): Linear(in_features=256, out_features=1, bias=True)
(rgb_linear): Linear(in_features=128, out_features=3, bias=True)
), 'N_samples': 64, 'network_fn': NeRF(
(pts_linears): ModuleList(
(0): Linear(in_features=63, out_features=256, bias=True)
(1): Linear(in_features=256, out_features=256, bias=True)
(2): Linear(in_features=256, out_features=256, bias=True)
(3): Linear(in_features=256, out_features=256, bias=True)
(4): Linear(in_features=256, out_features=256, bias=True)
(5): Linear(in_features=319, out_features=256, bias=True)
(6): Linear(in_features=256, out_features=256, bias=True)
(7): Linear(in_features=256, out_features=256, bias=True)
(views_linears): ModuleList(
(0): Linear(in_features=283, out_features=128, bias=True)
(feature_linear): Linear(in_features=256, out_features=256, bias=True)
(alpha_linear): Linear(in_features=256, out_features=1, bias=True)
(rgb_linear): Linear(in_features=128, out_features=3, bias=True)
), 'use_viewdirs': True, 'white_bkgd': True, 'raw_noise_std': 0.0, 'ndc': False, 'lindisp': False}
global_step = start
# Near/far sampling bounds collected in a dict.
# NOTE(review): the closing `}` of this dict is not shown in this excerpt.
bds_dict = {
'near' : near,
'far' : far,
# Move testing data to GPU
render_poses = torch.Tensor(render_poses).to(device)
# Short circuit if only rendering out from trained model
if args.render_only:
print('RENDER ONLY')
with torch.no_grad():
if args.render_test:
# render_test switches to test poses
images = images[i_test]
# Default is smoother render_poses path
# NOTE(review): as listed, `images = None` runs unconditionally and discards the test
# images selected above; it presumably belongs to an `else:` branch.
images = None
testsavedir = os.path.join(basedir, expname, 'renderonly_{}_{:06d}'.format('test' if args.render_test else 'path', start))
os.makedirs(testsavedir, exist_ok=True)
print('test poses shape', render_poses.shape)
rgbs, _ = render_path(render_poses, hwf, K, args.chunk, render_kwargs_test, gt_imgs=images, savedir=testsavedir, render_factor=args.render_factor)
print('Done rendering', testsavedir)
# Save the rendered frames as an MP4 at 30 fps.
imageio.mimwrite(os.path.join(testsavedir, 'video.mp4'), to8b(rgbs), fps=30, quality=8)
`use_batching = not args.no_batching`:表示是否从多张图中随机取用光线(batching)。特别地,对于 lego 数据集并没有采用该策略,所以这段代码在 lego 重建过程中不会运行,这里先跳过,其中的核心函数为 get_rays_np()。
# Prepare raybatch tensor if batching random rays
N_rand = args.N_rand
use_batching = not args.no_batching
if use_batching:
# For random ray batching
print('get rays')
rays = np.stack([get_rays_np(H, W, K, p) for p in poses[:,:3,:4]], 0) # [N, ro+rd, H, W, 3]
print('done, concats')
rays_rgb = np.concatenate([rays, images[:,None]], 1) # [N, ro+rd+rgb, H, W, 3]
rays_rgb = np.transpose(rays_rgb, [0,2,3,1,4]) # [N, H, W, ro+rd+rgb, 3]
rays_rgb = np.stack([rays_rgb[i] for i in i_train], 0) # train images only
rays_rgb = np.reshape(rays_rgb, [-1,3,3]) # [N_train*H*W, ro+rd+rgb, 3]
rays_rgb = rays_rgb.astype(np.float32)
print('shuffle rays')
# NOTE(review): the shuffle itself is not shown between the print above and `i_batch = 0`.
i_batch = 0
# Move training data to GPU
if use_batching:
images = torch.Tensor(images).to(device)
poses = torch.Tensor(poses).to(device)
if use_batching:
rays_rgb = torch.Tensor(rays_rgb).to(device)
# Total number of training iterations (the loop runs up to, not including, N_iters).
N_iters = 200000 + 1
print('TRAIN views are', i_train)
print('TEST views are', i_test)
print('VAL views are', i_val)
TRAIN views are [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47
48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71
72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95
96 97 98 99]
TEST views are [113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130
131 132 133 134 135 136 137]
VAL views are [100 101 102 103 104 105 106 107 108 109 110 111 112]
# Continue from the step after `start`.
start = start + 1
for i in trange(start, N_iters):
time0 = time.time()
# Sample random ray batch
if use_batching:
# Random over all images
batch = rays_rgb[i_batch:i_batch+N_rand] # [B, 2+1, 3*?]
batch = torch.transpose(batch, 0, 1)
batch_rays, target_s = batch[:2], batch[2]
i_batch += N_rand
# Reshuffle all rays once every ray has been used.
if i_batch >= rays_rgb.shape[0]:
print("Shuffle data after an epoch!")
rand_idx = torch.randperm(rays_rgb.shape[0])
rays_rgb = rays_rgb[rand_idx]
i_batch = 0
# Random from one image
# NOTE(review): as listed, the lines below run unconditionally; they presumably belong to
# the `else:` branch of `if use_batching:`.
img_i = np.random.choice(i_train)
target = images[img_i]
target = torch.Tensor(target).to(device)
pose = poses[img_i, :3,:4]
if N_rand is not None:
rays_o, rays_d = get_rays(H, W, K, torch.Tensor(pose)) # (H, W, 3), (H, W, 3)
# Ray helpers
def get_rays(H, W, K, c2w):
    """Compute one ray per pixel for a pinhole camera.

    Args:
        H, W: image height and width in pixels.
        K: 3x3 intrinsics; K[0][0]/K[1][1] are the focal lengths and
            K[0][2]/K[1][2] the principal point.
        c2w: camera-to-world matrix whose [:3, :3] is the rotation and
            [:3, -1] the camera origin.

    Returns:
        (rays_o, rays_d), each of shape (H, W, 3).
    """
    # meshgrid here uses 'ij' indexing, so transpose both grids to (H, W).
    u, v = torch.meshgrid(torch.linspace(0, W - 1, W), torch.linspace(0, H - 1, H))
    u = u.t()
    v = v.t()

    fx, fy = K[0][0], K[1][1]
    cx, cy = K[0][2], K[1][2]
    # Camera-frame direction: x right, y up (image rows point down), looking along -z.
    x_cam = (u - cx) / fx
    y_cam = -(v - cy) / fy
    z_cam = -torch.ones_like(u)
    dirs = torch.stack([x_cam, y_cam, z_cam], -1)

    # Broadcast (H, W, 1, 3) * (3, 3) and sum the last axis == rotation @ dir for each pixel.
    rotation = c2w[:3, :3]
    rays_d = (dirs[..., None, :] * rotation).sum(-1)
    # Every ray starts at the camera origin expressed in world coordinates.
    rays_o = c2w[:3, -1].expand(rays_d.shape)
    return rays_o, rays_d
# Pinhole intrinsics: focal length on the diagonal, principal point at (W/2, H/2).
# NOTE(review): the closing `])` of this array literal is not shown in this excerpt.
if K is None:
K = np.array([
[focal, 0, 0.5*W],
[0, focal, 0.5*H],
[0, 0, 1]
`i, j = torch.meshgrid(torch.linspace(0, W-1, W), torch.linspace(0, H-1, H))`:生成像素坐标网格。
torch.linspace(start, end, steps=100, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) → Tensor
torch.t()是一个类似于求矩阵的转置的函数,但是它要求输入的tensor结构维度<=2D。 这里参考大佬的博客。
(62条消息) 线性代数:转置矩阵(matrix transpose)和逆矩阵(matrix inverse)_逆矩阵和转置矩阵_羊羊2035的博客-CSDN博客
i = i.t() j = j.t()
NeRF代码解读-相机参数与坐标系变换 - 知乎 (zhihu.com)
通过上述的讨论,我们第一步是要先写出相机中心和像素点在相机坐标系的3D坐标。下面我们以OpenCV/Colmap的相机坐标系为例介绍。相机中心的坐标很明显就是[0,0,0]了。像素点的坐标可能复杂一点:首先3D像素点的x和y坐标是2D的图像坐标 (i, j)减去光心坐标 (cx,cy),然后z坐标其实就是焦距f (因为图像平面距离相机中心的距离就是焦距f)。
# Camera-frame ray directions: x = (i - cx) / fx, y = -(j - cy) / fy, z = -1.
dirs = torch.stack([(i-K[0][2])/K[0][0], -(j-K[1][2])/K[1][1], -torch.ones_like(i)], -1)
所以 K[0][0] 和 K[1][1] 就是 focal,也就是焦距;K[0][2] 是 cx,也就是 W/2;K[1][2] 是 cy,也就是 H/2。公式里的负号是因为 OpenCV/Colmap 的相机坐标系里相机的 Up/Y 朝下、相机光心朝向 +Z 轴,而 NeRF/OpenGL 相机坐标系里相机的 Up/Y 朝上、相机光心朝向 -Z 轴,所以这里代码在方向向量 dir 的第二和第三项乘了个负号。
tensor([[[-0.3600, 0.3600, -1.0000],
[-0.3582, 0.3600, -1.0000],
[-0.3564, 0.3600, -1.0000],
[ 0.3546, 0.3600, -1.0000],
[ 0.3564, 0.3600, -1.0000],
[ 0.3582, 0.3600, -1.0000]],
[[-0.3600, 0.3582, -1.0000],
[-0.3582, 0.3582, -1.0000],
[-0.3564, 0.3582, -1.0000],
[ 0.3546, 0.3582, -1.0000],
[ 0.3564, 0.3582, -1.0000],
[ 0.3582, 0.3582, -1.0000]],
[[-0.3600, 0.3564, -1.0000],
[-0.3582, 0.3564, -1.0000],
[-0.3564, 0.3564, -1.0000],
[ 0.3546, 0.3564, -1.0000],
[ 0.3564, 0.3564, -1.0000],
[ 0.3582, 0.3564, -1.0000]],
[[-0.3600, -0.3546, -1.0000],
[-0.3582, -0.3546, -1.0000],
[-0.3564, -0.3546, -1.0000],
[ 0.3546, -0.3546, -1.0000],
[ 0.3564, -0.3546, -1.0000],
[ 0.3582, -0.3546, -1.0000]],
[[-0.3600, -0.3564, -1.0000],
[-0.3582, -0.3564, -1.0000],
[-0.3564, -0.3564, -1.0000],
[ 0.3546, -0.3564, -1.0000],
[ 0.3564, -0.3564, -1.0000],
[ 0.3582, -0.3564, -1.0000]],
[[-0.3600, -0.3582, -1.0000],
[-0.3582, -0.3582, -1.0000],
[-0.3564, -0.3582, -1.0000],
[ 0.3546, -0.3582, -1.0000],
[ 0.3564, -0.3582, -1.0000],
[ 0.3582, -0.3582, -1.0000]]])
# Rotate ray directions from camera frame to the world frame rays_d = torch.sum(dirs[..., np.newaxis, :] * c2w[:3,:3], -1) # dot product, equals to: [c2w.dot(dir) for dir in dirs]
1.torch.sum(input, dtype=None)
2.torch.sum(input, list: dim, bool: keepdim=False, dtype=None) → Tensor
得以完成。c2w[:3,:3] 即3列分别表达关于x轴、y轴、z轴的信息,乘完以后变成400*400*3*3,torch.sum对3*3的列求和,就变成了3*1,这和坐标变换公式是一致,只能说非常巧妙了。我的理解是这个计算将相机坐标系下的3D空间射线方向矩阵,转换到世界坐标系下。最终的rays_d就是400*400*3,3是射线在世界坐标系下的射线方向
dot product, equals to: [c2w.dot(dir) for dir in dirs]
(65条消息) numpy广播机制_红烧code的博客-CSDN博客
# Translate camera frame's origin to the world frame. It is the origin of all rays. rays_o = c2w[:3,-1].expand(rays_d.shape)
rays_o, rays_d,rays_o即射线的原点,rays_d是射线的方向,总共有400*400个像素点,就有400*400个射线。
# For the first precrop_iters iterations, sample pixels only from a central
# crop of size 2*dH x 2*dW (precrop_frac of each half-dimension).
if i < args.precrop_iters:
dH = int(H//2 * args.precrop_frac)
dW = int(W//2 * args.precrop_frac)
# NOTE(review): a `torch.meshgrid(` call around the two linspace grids appears to be
# missing; the parentheses below do not balance as listed.
coords = torch.stack(
torch.linspace(H//2 - dH, H//2 + dH - 1, 2*dH),
torch.linspace(W//2 - dW, W//2 + dW - 1, 2*dW)
), -1)
if i == start:
print(f"[Config] Center cropping of size {2*dH} x {2*dW} is enabled until iter {args.precrop_iters}")
# NOTE(review): as listed, the full-image grid below replaces the cropped grid
# unconditionally; it presumably belongs to an `else:` branch.
coords = torch.stack(torch.meshgrid(torch.linspace(0, H-1, H), torch.linspace(0, W-1, W)), -1) # (H, W, 2)
coords = torch.reshape(coords, [-1,2]) # (H * W, 2)
# Pick N_rand distinct pixels and gather their rays and target colors (coords are (row, col)).
select_inds = np.random.choice(coords.shape[0], size=[N_rand], replace=False) # (N_rand,)
select_coords = coords[select_inds].long() # (N_rand, 2)
rays_o = rays_o[select_coords[:, 0], select_coords[:, 1]] # (N_rand, 3)
rays_d = rays_d[select_coords[:, 0], select_coords[:, 1]] # (N_rand, 3)
batch_rays = torch.stack([rays_o, rays_d], 0)
target_s = target[select_coords[:, 0], select_coords[:, 1]] # (N_rand, 3)
precrop_iters = 500 precrop_frac = 0.5
# Single-line rendition of the center-crop / full-image pixel grid selection (line breaks lost).
if i < args.precrop_iters: dH = int(H//2 * args.precrop_frac) dW = int(W//2 * args.precrop_frac) coords = torch.stack( torch.meshgrid( torch.linspace(H//2 - dH, H//2 + dH - 1, 2*dH), torch.linspace(W//2 - dW, W//2 + dW - 1, 2*dW) ), -1) if i == start: print(f"[Config] Center cropping of size {2*dH} x {2*dW} is enabled until iter {args.precrop_iters}") else: coords = torch.stack(torch.meshgrid(torch.linspace(0, H-1, H), torch.linspace(0, W-1, W)), -1) # (H, W, 2)
#numpy.random.choice(a, size=None, replace=True, p=None)
# Flatten the pixel grid, draw N_rand distinct pixel indices without replacement,
# and gather the corresponding rays and ground-truth colors (coords are (row, col)).
coords = torch.reshape(coords, [-1,2]) # (H * W, 2)
select_inds = np.random.choice(coords.shape[0], size=[N_rand], replace=False) # (N_rand,)
select_coords = coords[select_inds].long() # (N_rand, 2)
rays_o = rays_o[select_coords[:, 0], select_coords[:, 1]] # (N_rand, 3)
rays_d = rays_d[select_coords[:, 0], select_coords[:, 1]] # (N_rand, 3)
batch_rays = torch.stack([rays_o, rays_d], 0)
target_s = target[select_coords[:, 0], select_coords[:, 1]] # (N_rand, 3)
# Command-line option controlling how many rays are rendered per batch (default 1024*32).
parser.add_argument("--chunk", type=int, default=1024*32, help='number of rays processed in parallel, decrease if running out of memory')
##### Core optimization loop #####
# Render the sampled rays; retraw=True also returns the raw network outputs in extras.
# NOTE(review): the remaining arguments of this call are not shown in this excerpt.
rgb, disp, acc, extras = render(H, W, K, chunk=args.chunk, rays=batch_rays,
verbose=i < 10, retraw=True,
# Excerpt of render(): the signature is truncated in this listing.
def render(H, W, K, chunk=1024*32, rays=None, c2w=None, ndc=True,
near=0., far=1.,
use_viewdirs=False, c2w_staticcam=None,
if c2w is not None:
# special case to render full image
rays_o, rays_d = get_rays(H, W, K, c2w)
# use provided ray batch
# NOTE(review): as listed, this overwrites the rays computed from c2w; it presumably
# belongs to an `else:` branch.
rays_o, rays_d = rays
use_viewdirs: # provide ray directions as input,c2w_staticcam#special case to visualize effect of viewdirs,这个选项默认是无,翻译是静态相机,
【返回输入张量给定维dim 上每行的p范数】
torch.norm(input, p, dim, out=None, keepdim=False) → Tensor,这里在最后一维上求的就是 2 范数。用向量除以自身的 2 范数,其实是归一化操作,得到与原向量方向相同的单位向量,这样的运算也叫向量的单位化。reshape 的作用不大,因为 viewdirs 本身就是 [1024,3] 的矩阵。
(64条消息) 机器学习中的范数规则化之(一)L0、L1与L2范数_l范数_zouxy09的博客-CSDN博客
if use_viewdirs:
# provide ray directions as input
viewdirs = rays_d
if c2w_staticcam is not None:
# special case to visualize effect of viewdirs
# (view directions keep coming from the ray batch, but rays are recomputed from c2w_staticcam)
rays_o, rays_d = get_rays(H, W, K, c2w_staticcam)
# Normalize to unit length and flatten to [-1, 3].
viewdirs = viewdirs / torch.norm(viewdirs, dim=-1, keepdim=True)
viewdirs = torch.reshape(viewdirs, [-1,3]).float()
(64条消息) NeRF神经辐射场中关于光线从世界坐标系转换为NDC坐标系 Representing Scenes as Neural Radiance Fields for View Synthesis_出门吃三碗饭的博客-CSDN博客
sh = rays_d.shape # [..., 3]
if ndc:
# for forward facing scenes
rays_o, rays_d = ndc_rays(H, W, K[0][0], 1., rays_o, rays_d)
# Create ray batch
rays_o = torch.reshape(rays_o, [-1,3]).float()
rays_d = torch.reshape(rays_d, [-1,3]).float()
# Per-ray near/far columns, each [N, 1].
near, far = near * torch.ones_like(rays_d[...,:1]), far * torch.ones_like(rays_d[...,:1])
# Packed row layout: origin(3), direction(3), near(1), far(1), then viewdirs(3) if used.
rays = torch.cat([rays_o, rays_d, near, far], -1)
if use_viewdirs:
rays = torch.cat([rays, viewdirs], -1)
# Render and reshape
all_ret = batchify_rays(rays, chunk, **kwargs)
def batchify_rays(rays_flat, chunk=1024*32, **kwargs):
    """Render rays in smaller minibatches to avoid OOM.

    Args:
        rays_flat: packed rays, one ray per row.
        chunk: number of rays passed to ``render_rays`` per call.
        **kwargs: forwarded unchanged to ``render_rays``.

    Returns:
        dict mapping each key returned by ``render_rays`` to the per-chunk
        results concatenated along dim 0.
    """
    all_ret = {}
    for i in range(0, rays_flat.shape[0], chunk):
        ret = render_rays(rays_flat[i:i + chunk], **kwargs)
        # Collect every chunk's output per key; concatenated once at the end.
        for k, v in ret.items():
            all_ret.setdefault(k, []).append(v)
    return {k: torch.cat(v, 0) for k, v in all_ret.items()}
# Excerpt of render_rays: the remaining parameters and the closing quotes of the
# docstring are not shown in this listing.
def render_rays(ray_batch,
"""Volumetric rendering.
ray_batch: array of shape [batch_size, ...]. All information necessary
for sampling along a ray, including: ray origin, ray direction, min
dist, max dist, and unit-magnitude viewing direction.
network_fn: function. Model for predicting RGB and density at each point
in space.
network_query_fn: function used for passing queries to network_fn.
N_samples: int. Number of different times to sample along each ray.
retraw: bool. If True, include model's raw, unprocessed predictions.
lindisp: bool. If True, sample linearly in inverse depth rather than in depth.
perturb: float, 0 or 1. If non-zero, each ray is sampled at stratified
random points in time.
N_importance: int. Number of additional times to sample along each ray.
These samples are only passed to network_fine.
network_fine: "fine" network with same spec as network_fn.
white_bkgd: bool. If True, assume a white background.
raw_noise_std: ...
verbose: bool. If True, print more debugging info.
rgb_map: [num_rays, 3]. Estimated RGB color of a ray. Comes from fine model.
disp_map: [num_rays]. Disparity map. 1 / depth.
acc_map: [num_rays]. Accumulated opacity along each ray. Comes from fine model.
raw: [num_rays, num_samples, 4]. Raw predictions from model.
rgb0: See rgb_map. Output for coarse model.
disp0: See disp_map. Output for coarse model.
acc0: See acc_map. Output for coarse model.
z_std: [num_rays]. Standard deviation of distances along ray for each
首先将 ray_batch 的输入拆分为 rays_o、rays_d、near、far 和 viewdirs,形状分别是 [chunk, 3]、[chunk, 3]、[chunk, 1]、[chunk, 1] 和 [chunk, 3]。torch.reshape(ray_batch[...,6:8], [-1,1,2]) 其实就是将 (1024,2) 变成了 (1024,1,2)。
# Unpack the packed rays: origin [0:3], direction [3:6], near/far [6:8], viewdirs in the last 3 columns.
N_rays = ray_batch.shape[0]
rays_o, rays_d = ray_batch[:,0:3], ray_batch[:,3:6] # [N_rays, 3] each
# viewdirs exist only when rows are longer than 8 columns.
viewdirs = ray_batch[:,-3:] if ray_batch.shape[-1] > 8 else None
bounds = torch.reshape(ray_batch[...,6:8], [-1,1,2])
near, far = bounds[...,0], bounds[...,1] # [-1,1]
“图形学小白”友好的NeRF原理透彻讲解 - 知乎 (zhihu.com)
若通过给定pose,从NeRF的模型中获得一张输出图片,关键就是获得每一个图片每一个像素坐标的像素值。在NeRF的paper中,给定一个camera pose,要计算某个像素坐标 (x,y) 的像素。通俗来说,该点的像素计算方法为:从相机光心发出一条射线(camera ray)经过该像素坐标,途径三维场景很多点,这些“途径点”或称作“采样点”的某种累加决定了该像素的最终颜色。
数学上,它的颜色由下面的“体渲染公式”计算而出,其中C表示渲染出的像素点颜色,σ表示体素密度, r 和 d 分别表示camera ray上的距离和ray上的方向,r=o+dt, t表示在camera ray上采样点离相机光心的距离,T表示透射比,也叫光学厚度、介质透明度,c表示当前区域的粒子发光和内散射辐射强度,也就是表面的实际颜色。
将 [tn, tf]均匀划分为 N个区间,并在这 N个区域内随机采样得到 N个采样点,即 ti(i=1,…,N),进行求和得到颜色的估计值, σi和 δi分别表示当前区域粒子密度和区间步长 ,ci表示当前区域的粒子发光和内散射辐射强度, Ti表示透射比,括号中的公式表示透明度,公式没有考虑从相机发出的射线本身的强度。
所以, t_vals = torch.linspace(0., 1., steps=N_samples)相当于在(0,1)之间先均匀采样N个点,这里N_samples=64,lindisp是false,就是从深度采样,和公式是一样的。z_vals此时就是[1024,64],expand之后没变化,此时的z_vals就是公式中的ti
# N_samples evenly spaced fractions in [0, 1] (both ends included).
t_vals = torch.linspace(0., 1., steps=N_samples)
# Linear interpolation in depth between near and far.
if not lindisp:
z_vals = near * (1.-t_vals) + far * (t_vals)
# NOTE(review): as listed, this linear-in-inverse-depth (lindisp) formula overwrites the
# result above; it presumably belongs to an `else:` branch.
z_vals = 1./(1./near * (1.-t_vals) + 1./far * (t_vals))
z_vals = z_vals.expand([N_rays, N_samples])
接下来是在均匀区间内随机产生采样点的过程,这里 perturb = 1,并设 pytest 为 False。mids 是相邻采样点 z_vals 的中点,形状为 [chunk, N_samples-1];随后在两端分别补上整个区间的下界 t_n 和上界 t_f,得到 lower 和 upper,形状均为 [chunk, N_samples]。随机数组 t_rand 的形状为 [chunk, N_samples],元素取值在 [0, 1) 之间。z_vals = lower + (upper - lower) * t_rand,形状为 [chunk, N_samples]。需要注意的是,lower 和 upper 是由中点 mids 拼接而成的,所以原先 z_vals 中的一个 z_val 扰动后只能落在它两侧相邻中点所夹的区间 [mid_{i-1}, mid_i] 内(两端的点分别是 [t_n, mid] 和 [mid, t_f]),而不是原先相邻采样点之间的整个区间 [low, up]。
这里进行数据拼接 pts = rays_o[...,None,:] + rays_d[...,None,:] * z_vals[...,:,None],形状为 [chunk, N_samples, 3]。
(15 封私信 / 81 条消息) 为什么NeRF里可以仅凭位置和角度信息经过MLP预测出某点的rgb颜色? - 知乎 (zhihu.com)
# Stratified sampling: jitter each depth uniformly within [lower, upper] built from midpoints.
if perturb > 0.:
# get intervals between samples
mids = .5 * (z_vals[...,1:] + z_vals[...,:-1])
upper = torch.cat([mids, z_vals[...,-1:]], -1)
lower = torch.cat([z_vals[...,:1], mids], -1)
# stratified samples in those intervals
t_rand = torch.rand(z_vals.shape)
# Pytest, overwrite u with numpy's fixed random numbers
if pytest:
t_rand = np.random.rand(*list(z_vals.shape))
t_rand = torch.Tensor(t_rand)
z_vals = lower + (upper - lower) * t_rand
# Sample points along each ray: o + t * d.
pts = rays_o[...,None,:] + rays_d[...,None,:] * z_vals[...,:,None] # [N_rays, N_samples, 3]
# (same statement repeated in this listing)
pts = rays_o[...,None,:] + rays_d[...,None,:] * z_vals[...,:,None] # [N_rays, N_samples, 3]
# Excerpt of Embedder.create_embedding_fn: when include_input is set, the identity
# function is the first embedding and contributes input_dims output channels.
def create_embedding_fn(self):
embed_fns = []
d = self.kwargs['input_dims']
out_dim = 0
if self.kwargs['include_input']:
embed_fns.append(lambda x : x)
out_dim += d
# Excerpt of run_network (docstring closing quotes not shown): flatten positions
# to [-1, C] and encode them with embed_fn.
def run_network(inputs, viewdirs, fn, embed_fn, embeddirs_fn, netchunk=1024*64):
"""Prepares inputs and applies network 'fn'.
inputs_flat = torch.reshape(inputs, [-1, inputs.shape[-1]])
embedded = embed_fn(inputs_flat)
用于光线方向的单位化和位置编码,这里可以看出是给点的坐标加入了这个方向信息。那首先就是将输入的[1024,3]dirs扩展成[1024,64,3],再同样展开成[1024*64,3],然后我恍然大悟,原来viewdirs就是θ,φ信息, embeddirs_fn就是3+4*2*3,输出为[1024*64,27],最后再和位置信息cat成[1024*64,90]
# Broadcast each ray's direction to all of its samples, encode it, and append it
# to the position features.
if viewdirs is not None:
input_dirs = viewdirs[:,None].expand(inputs.shape)
input_dirs_flat = torch.reshape(input_dirs, [-1, input_dirs.shape[-1]])
embedded_dirs = embeddirs_fn(input_dirs_flat)
embedded = torch.cat([embedded, embedded_dirs], -1)
# Evaluate the network in chunks of netchunk rows.
outputs_flat = batchify(fn, netchunk)(embedded)
def batchify(fn, chunk):
    """Constructs a version of 'fn' that applies to smaller batches.

    Args:
        fn: callable applied to slices of a tensor along dim 0.
        chunk: maximum number of rows per call; ``None`` returns ``fn`` unchanged.

    Returns:
        A callable that applies ``fn`` to consecutive slices of at most
        ``chunk`` rows and concatenates the results along dim 0.
    """
    if chunk is None:
        return fn

    def ret(inputs):
        # torch.cat rejects an empty list, so pass empty inputs straight through.
        if inputs.shape[0] == 0:
            return fn(inputs)
        return torch.cat([fn(inputs[i:i + chunk]) for i in range(0, inputs.shape[0], chunk)], 0)

    return ret
前向推理首先就是将位置和方向信息分开,这里input_ch和input_ch_views根据之前的定义分别是63和27,input_pts就是位置信息,这里直接进行了8层MLP计算,在第5层的时候(skips=4),又引入了一次input_pts,最终输出是[1024*64,256],然后是对方向信息的特征提取,体素密度volume density(σ)层就直接输出,关键是rgb信息,首先h又经过了一层MLP,然后再引入input_views直接cat,再经过一层MLP将特征层缩小到128,最后再缩小到RGB3通道信息,最终RGB+σ作为输出返回,所以返回结果应该是[1024*64,4]
# Excerpt of NeRF.forward: split the input into encoded positions and view directions,
# then run the position MLP with ReLU, re-concatenating input_pts after each layer in skips.
def forward(self, x):
input_pts, input_views = torch.split(x, [self.input_ch, self.input_ch_views], dim=-1)
h = input_pts
for i, l in enumerate(self.pts_linears):
h = self.pts_linears[i](h)
h = F.relu(h)
if i in self.skips:
h = torch.cat([input_pts, h], -1)
(pts_linears): ModuleList(
(0): Linear(in_features=63, out_features=256, bias=True)
(1): Linear(in_features=256, out_features=256, bias=True)
(2): Linear(in_features=256, out_features=256, bias=True)
(3): Linear(in_features=256, out_features=256, bias=True)
(4): Linear(in_features=256, out_features=256, bias=True)
(5): Linear(in_features=319, out_features=256, bias=True)
(6): Linear(in_features=256, out_features=256, bias=True)
(7): Linear(in_features=256, out_features=256, bias=True)
# With view directions: density comes from alpha_linear; a feature vector is concatenated
# with the encoded view directions, passed through views_linears (ReLU), then rgb_linear.
if self.use_viewdirs:
alpha = self.alpha_linear(h)
feature = self.feature_linear(h)
h = torch.cat([feature, input_views], -1)
for i, l in enumerate(self.views_linears):
h = self.views_linears[i](h)
h = F.relu(h)
rgb = self.rgb_linear(h)
(views_linears): ModuleList(
(0): Linear(in_features=283, out_features=128, bias=True)
(feature_linear): Linear(in_features=256, out_features=256, bias=True)
(alpha_linear): Linear(in_features=256, out_features=1, bias=True)
(rgb_linear): Linear(in_features=128, out_features=3, bias=True)
# Output layout with view directions: [rgb(3), alpha(1)].
outputs = torch.cat([rgb, alpha], -1)
# NOTE(review): as listed, this overwrites the line above; it is the head used without
# view directions and presumably belongs to an `else:` branch.
outputs = self.output_linear(h)
return outputs
# (run_network) reshape the flat outputs back to inputs.shape[:-1] + [out_dim].
outputs = torch.reshape(outputs_flat, list(inputs.shape[:-1]) + [outputs_flat.shape[-1]])
推理完以后再回到 render_rays 函数。在得到模型输出的 raw 之后,通过 raw2outputs 函数得到各种结果图。raw2outputs 的输入有:raw;z_vals,也就是 64 个采样点在光线上的深度 t_i(相邻之差即步长 δ_i);rays_d(光线方向);raw_noise_std=0;white_bkgd=True;以及 pytest=pytest。
# Coarse pass: query the coarse network at the sample points, then composite along each ray.
raw = network_query_fn(pts, viewdirs, network_fn)
rgb_map, disp_map, acc_map, weights, depth_map = raw2outputs(raw, z_vals, rays_d, raw_noise_std, white_bkgd, pytest=pytest)
raw2outputs函数首先就定义了raw2alpha函数,这个函数的作用就是求解之前的透明度, 定义了输入是raw,dists,act_fn,计算公式是同括号里的一致,但是加了一个激活函数F.relu,也就是保证raw是正值,然后此时的ci其实就是各点的RGB值,σi是网络输出的密度值, δi是计算得到的各点之间的步长dists,所以计算就豁然开朗了
def raw2outputs(raw, z_vals, rays_d, raw_noise_std=0, white_bkgd=False, pytest=False):
    """Transforms model's predictions to semantically meaningful values.

    Args:
        raw: [num_rays, num_samples along ray, 4]. Prediction from model.
        z_vals: [num_rays, num_samples along ray]. Integration time.
        rays_d: [num_rays, 3]. Direction of each ray.
        raw_noise_std: std of Gaussian noise added to the raw density before
            activation (0 disables it).
        white_bkgd: if True, composite the result onto a white background.
        pytest: if True, draw the noise from numpy's RNG instead of torch's.

    Returns:
        rgb_map: [num_rays, 3]. Estimated RGB color of a ray.
        disp_map: [num_rays]. Disparity map. Inverse of depth map.
        acc_map: [num_rays]. Sum of weights along each ray.
        weights: [num_rays, num_samples]. Weights assigned to each sampled color.
        depth_map: [num_rays]. Estimated distance to object.
    """
    # alpha_i = 1 - exp(-relu(sigma_i) * delta_i); ReLU keeps the density non-negative.
    def raw2alpha(raw, dists, act_fn=F.relu):
        return 1. - torch.exp(-act_fn(raw) * dists)

    # delta_i = t_{i+1} - t_i; the last interval is given a very large length (1e10).
    dists = z_vals[..., 1:] - z_vals[..., :-1]
    dists = torch.cat([dists, torch.full_like(dists[..., :1], 1e10)], -1)  # [N_rays, N_samples]
    # Scale by |d| so distances are in world units even for unnormalized directions.
    dists = dists * torch.norm(rays_d[..., None, :], dim=-1)

    rgb = torch.sigmoid(raw[..., :3])  # [N_rays, N_samples, 3]

    noise = 0.
    if raw_noise_std > 0.:
        noise = torch.randn_like(raw[..., 3]) * raw_noise_std

        # Overwrite randomly sampled data if pytest
        if pytest:
            noise = np.random.rand(*list(raw[..., 3].shape)) * raw_noise_std
            noise = torch.tensor(noise, dtype=raw.dtype, device=raw.device)

    alpha = raw2alpha(raw[..., 3] + noise, dists)  # [N_rays, N_samples]

    # Transmittance T_i = prod_{j<i} (1 - alpha_j):
    #   X = a column of ones, [N_rays, 1];
    #   Y = cat([X, 1 - alpha + 1e-10]), [N_rays, N_samples + 1]; the 1e-10 keeps factors > 0;
    #   Z = cumprod(Y) along the samples; dropping its last column leaves T_i at position i.
    # The weights are alpha_i * T_i, shape [N_rays, N_samples].
    transmittance = torch.cumprod(
        torch.cat([torch.ones_like(alpha[..., :1]), 1. - alpha + 1e-10], -1), -1)[..., :-1]
    weights = alpha * transmittance

    rgb_map = torch.sum(weights[..., None] * rgb, -2)  # [N_rays, 3]
    depth_map = torch.sum(weights * z_vals, -1)
    acc_map = torch.sum(weights, -1)
    disp_map = 1. / torch.max(1e-10 * torch.ones_like(depth_map), depth_map / acc_map)

    if white_bkgd:
        # Fill the unaccumulated opacity (1 - acc) with white.
        rgb_map = rgb_map + (1. - acc_map[..., None])

    return rgb_map, disp_map, acc_map, weights, depth_map
再回到render_rays函数,N_importance=128,这里涉及到NeRF的Hierarchical volume sampling设计,首先使用分层抽样对第一组Nc位置进行采样,并在这些位置评估“粗糙”网络。给出这个“粗糙”网络的输出,然后我们对每条射线上的点进行更明智的采样,即利用反变换采样从这个分布中采样第二组Nf位置,然后将Nc+Nf采样得到的数据输入“精细”网络,并计算最终渲染的光线颜色C(r)。z_samples就是Nf的计算过程,sample_pdf函数就是反变换采样过程篇幅有限,后续再展开讲吧,detach 意为分离,对某个张量调用函数detach(),detach() 的作用是返回一个Tensor,它和原张量的数据相同,但requires_grad=False,也就意味着detach() 得到的张量不会具有梯度。这一性质即使我们修改其requires_grad 属性也无法改变。
# Hierarchical sampling: draw N_importance extra depths from the coarse weights, then
# re-render all samples (sorted by depth) with the fine network.
if N_importance > 0:
# Keep the coarse results; they are returned as rgb0/disp0/acc0.
rgb_map_0, disp_map_0, acc_map_0 = rgb_map, disp_map, acc_map
# The N_samples-1 midpoints are the bin edges; weights[...,1:-1] are the N_samples-2 interior weights.
z_vals_mid = .5 * (z_vals[...,1:] + z_vals[...,:-1])
# det=True (no random jitter) when perturb == 0.
z_samples = sample_pdf(z_vals_mid, weights[...,1:-1], N_importance, det=(perturb==0.), pytest=pytest)
# Detach so no gradient flows back through the sampled depths.
z_samples = z_samples.detach()
z_vals, _ = torch.sort(torch.cat([z_vals, z_samples], -1), -1)
pts = rays_o[...,None,:] + rays_d[...,None,:] * z_vals[...,:,None] # [N_rays, N_samples + N_importance, 3]
# Use the fine network if one was built, otherwise reuse the coarse one.
run_fn = network_fn if network_fine is None else network_fine
# raw = run_network(pts, fn=run_fn)
raw = network_query_fn(pts, viewdirs, run_fn)
rgb_map, disp_map, acc_map, weights, depth_map = raw2outputs(raw, z_vals, rays_d, raw_noise_std, white_bkgd, pytest=pytest)
ret = {'rgb_map' : rgb_map, 'disp_map' : disp_map, 'acc_map' : acc_map}
if retraw:
ret['raw'] = raw
if N_importance > 0:
ret['rgb0'] = rgb_map_0
ret['disp0'] = disp_map_0
ret['acc0'] = acc_map_0
# Population standard deviation of the fine sample depths along each ray.
ret['z_std'] = torch.std(z_samples, dim=-1, unbiased=False) # [N_rays]
# Excerpt of the batchify_rays loop body and return.
# NOTE(review): the `all_ret[k].append(ret[k])` step is not shown; as listed every list
# stays empty and torch.cat would fail.
ret = render_rays(rays_flat[i:i+chunk], **kwargs)
for k in ret:
if k not in all_ret:
all_ret[k] = []
all_ret = {k : torch.cat(all_ret[k], 0) for k in all_ret}
return all_ret
# Render and reshape
all_ret = batchify_rays(rays, chunk, **kwargs)
# Restore the original ray-batch shape sh[:-1] on every output.
for k in all_ret:
k_sh = list(sh[:-1]) + list(all_ret[k].shape[1:])
all_ret[k] = torch.reshape(all_ret[k], k_sh)
# Return [rgb_map, disp_map, acc_map, dict of all remaining outputs].
k_extract = ['rgb_map', 'disp_map', 'acc_map']
ret_list = [all_ret[k] for k in k_extract]
ret_dict = {k : all_ret[k] for k in all_ret if k not in k_extract}
return ret_list + [ret_dict]
img_loss = img2mse(rgb, target_s)# image reconstruction loss (fine output vs. target colors)
trans = extras['raw'][...,-1]
loss = img_loss
psnr = mse2psnr(img_loss)
# When a coarse output rgb0 exists, add its loss against the same targets.
if 'rgb0' in extras:
img_loss0 = img2mse(extras['rgb0'], target_s)
loss = loss + img_loss0
psnr0 = mse2psnr(img_loss0)
# NOTE(review): the backward pass and optimizer step are not shown in this excerpt.
### update learning rate ###
# Exponential decay: lr = lrate * 0.1 ** (global_step / (lrate_decay * 1000)).
decay_rate = 0.1
decay_steps = args.lrate_decay * 1000
new_lrate = args.lrate * (decay_rate ** (global_step / decay_steps))
for param_group in optimizer.param_groups:
param_group['lr'] = new_lrate
dt = time.time()-time0