对position和view direction做embedding。
class FreqEmbedder(nn.Module):
    """Sinusoidal frequency encoding for positions or view directions.

    Every input coordinate ``x`` is mapped to ``sin(f * x)`` and
    ``cos(f * x)`` for each frequency ``f`` of a fixed band, optionally
    preceded by the raw input itself.

    Args:
        in_dim: Size of the last dimension of the input.
        multi_res: Number of frequency bands; also used as the log2 of the
            highest frequency.
        use_log_bands: If True, frequencies are evenly spaced in log2 space;
            otherwise they are evenly spaced in linear space.
        include_input: If True, the raw input is the first chunk of the
            output.
    """

    def __init__(self, in_dim=3, multi_res=10, use_log_bands=True, include_input=True):
        super().__init__()
        self.in_dim = in_dim
        self.num_freqs = multi_res
        # NOTE(review): the top exponent equals multi_res, so multi_res=10
        # yields 10 bands spread from 2**0 to 2**10 (non-integer exponents in
        # between). A top exponent of multi_res - 1 would give 1, 2, ..., 512.
        # Left unchanged so existing outputs stay the same -- confirm intent.
        self.max_freq_log2 = multi_res
        self.use_log_bands = use_log_bands
        self.periodic_fns = [torch.sin, torch.cos]
        self.include_input = include_input
        self.embed_fns = None
        self.out_dim = None
        self.num_embed_fns = None
        self.create_embedding_fn()

    def create_embedding_fn(self):
        """Build the list of embedding callables and compute ``out_dim``."""
        self.embed_fns = []
        # With the defaults: 10 frequencies * 2 functions * 3 dims = 60.
        self.out_dim = self.num_freqs * len(self.periodic_fns) * self.in_dim
        if self.include_input:
            self.embed_fns.append(lambda x: x)
            self.out_dim += self.in_dim  # 63 with the defaults

        if self.use_log_bands:
            freq_bands = 2. ** torch.linspace(0., self.max_freq_log2, steps=self.num_freqs)
        else:
            freq_bands = torch.linspace(2. ** 0, 2. ** self.max_freq_log2, steps=self.num_freqs)

        for freq in freq_bands:
            for p_fn in self.periodic_fns:
                # Bind p_fn and freq as defaults so each lambda keeps its own
                # values instead of the last ones produced by the loops.
                self.embed_fns.append(lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq))
        self.num_embed_fns = len(self.embed_fns)

    def forward(self, x):
        """Encode ``x``.

        Args:
            x: [..., in_dim], xyz or view direction.

        Returns:
            [..., out_dim] tensor laid out along the last dimension as
            ``[x, sin(f0*x), cos(f0*x), ..., sin(fk*x), cos(fk*x)]``
            (the leading ``x`` is present only when ``include_input``).
        """
        embed_lst = [embed_fn(x) for embed_fn in self.embed_fns]
        return torch.cat(embed_lst, dim=-1)
position 和 view 经过 embedding 后得到特征向量,再与条件特征 cond 一起输入 NeRFBackbone 网络,得到 sigma 和 color 输出。
class NeRFBackbone(nn.Module):
    """MLP that maps encoded position, condition and view to RGB and sigma.

    A stack of "density" linear layers processes ``[pos, cond]``; a single
    linear head on top of it predicts sigma. The density features are then
    concatenated with the view encoding and passed through the "color"
    layers, whose head predicts RGB.

    Args:
        pos_dim: Size of the encoded position.
        cond_dim: Size of the condition feature.
        view_dim: Size of the encoded view direction.
        hid_dim: Width of the density layers (color layers use half of it).
        num_density_linears: Number of layers in the density stack.
        num_color_linears: Number of layers in the color stack.
        skip_layer_indices: Indices of density layers after which the
            original ``[pos, cond]`` input is concatenated back in. Indices
            must be smaller than ``num_density_linears - 1`` so that the
            layer receiving the wider input exists.
    """

    def __init__(self, pos_dim=3, cond_dim=64, view_dim=3, hid_dim=128, num_density_linears=8, num_color_linears=3, skip_layer_indices=[4]):
        # nn.Module must be initialised before any sub-module is assigned.
        super().__init__()
        self.pos_dim = pos_dim
        self.cond_dim = cond_dim
        self.view_dim = view_dim
        self.hid_dim = hid_dim
        self.out_dim = 4  # rgb + sigma
        self.num_density_linears = num_density_linears
        self.num_color_linears = num_color_linears
        # Copy so the instance never aliases the shared default list.
        self.skip_layer_indices = list(skip_layer_indices)

        density_input_dim = pos_dim + cond_dim
        density_layers = [nn.Linear(density_input_dim, hid_dim)]
        for i in range(num_density_linears - 1):
            # Layer i + 1 consumes the output of layer i; if a skip follows
            # layer i, that output has been widened by the original input.
            in_features = hid_dim + density_input_dim if i in self.skip_layer_indices else hid_dim
            density_layers.append(nn.Linear(in_features, hid_dim))
        self.density_linears = nn.ModuleList(density_layers)
        self.density_out_linear = nn.Linear(hid_dim, 1)

        color_input_dim = view_dim + hid_dim
        self.color_linears = nn.ModuleList(
            [nn.Linear(color_input_dim, hid_dim // 2)] +
            [nn.Linear(hid_dim // 2, hid_dim // 2) for _ in range(num_color_linears - 1)]
        )
        self.color_out_linear = nn.Linear(hid_dim // 2, 3)

    def forward(self, pos, cond, view):
        """Predict RGB and sigma for every sample point.

        Args:
            pos: [bs, n_sample, pos_dim], encoding of position.
            cond: condition features, either [cond_dim] (shared by all rays)
                or [bs, cond_dim] (one per ray); any other rank is used
                as given.
            view: [bs, view_dim], encoding of view direction.

        Returns:
            [bs, n_sample, 4] tensor: RGB in the first three channels and
            sigma in the last one (raw linear outputs, no activation).
        """
        bs, n_sample, _ = pos.shape
        if cond.dim() == 1:  # [cond_dim]
            cond = cond[None, None, :].expand([bs, n_sample, self.cond_dim])
        elif cond.dim() == 2:  # [bs, cond_dim]
            cond = cond[:, None, :].expand([bs, n_sample, self.cond_dim])
        view = view[:, None, :].expand([bs, n_sample, self.view_dim])

        density_linear_input = torch.cat([pos, cond], dim=-1)
        h = density_linear_input
        for i, linear in enumerate(self.density_linears):
            h = F.relu(linear(h))
            if i in self.skip_layer_indices:
                h = torch.cat([density_linear_input, h], -1)
        sigma = self.density_out_linear(h)

        h = torch.cat([h, view], -1)
        for linear in self.color_linears:
            h = F.relu(linear(h))
        rgb = self.color_out_linear(h)

        return torch.cat([rgb, sigma], -1)
假设一张图的 height = 1280、width = 720,从这张图中采样 4096 条由相机原点发出的光线(ray)。
def get_rays(H, W, focal, c2w, cx=None, cy=None):
    """Get the rays emitted from the camera through every pixel.

    The rays are expressed in world coordinates.

    Args:
        H: height of the image in pixels.
        W: width of the image in pixels.
        focal: focal length of the camera in pixels.
        c2w: 3x4 camera-to-world matrix (tensor), laid out as
            [[r11, r12, r13, t1],
             [r21, r22, r23, t2],
             [r31, r32, r33, t3]]
        cx: principal point along the width axis; defaults to ``W * 0.5``.
        cy: principal point along the height axis; defaults to ``H * 0.5``.

    Returns:
        rays_o: [H, W, 3] start point of each ray (the camera center,
            broadcast to every pixel).
        rays_d: [H, W, 3] direction of each ray (not normalized). A point on
            a ray is ``xyz = rays_o + rays_d * z_val`` where ``z_val`` is the
            distance.
    """
    # Build the pixel grid on the same device as c2w so the multiplication
    # below does not mix devices.
    device = c2w.device
    j_pixels, i_pixels = torch.meshgrid(
        torch.linspace(0, H - 1, H, device=device),
        torch.linspace(0, W - 1, W, device=device),
        indexing="ij",  # j_pixels indexes rows (height), i_pixels columns (width)
    )
    if cx is None:
        cx = W * 0.5
    if cy is None:
        cy = H * 0.5

    # Camera-frame directions: x to the right, y up, looking along -z.
    directions = torch.stack(
        [(i_pixels - cx) / focal, -(j_pixels - cy) / focal, -torch.ones_like(i_pixels)],
        dim=-1,
    )  # [H, W, 3]
    # Rotate ray directions from the camera frame to the world frame
    # (equivalent to R @ d for every pixel).
    rays_d = torch.sum(directions[..., None, :] * c2w[:3, :3], dim=-1)
    # Origin of every ray: the camera center in world coordinates.
    rays_o = c2w[:3, -1].expand(rays_d.shape)
    return rays_o, rays_d
class BaseRaySampler:
    """Base class for samplers that pick a subset of the rays of an image.

    Subclasses implement ``sample_rays`` to choose pixel coordinates.

    Args:
        N_rays: Number of rays to sample (stored for use by subclasses).
    """

    def __init__(self, N_rays):
        super().__init__()
        self.N_rays = N_rays

    def __call__(self, H, W, focal, c2w):
        """Compute all rays of an H x W image and keep the sampled ones.

        Returns:
            rays_o: origins of the selected rays, indexed by the coordinates.
            rays_d: directions of the selected rays.
            select_coords: the coordinates returned by ``sample_rays``; column
                0 indexes the first (height) axis and column 1 the second
                (width) axis of the ray tensors.
        """
        rays_o, rays_d = get_rays(H, W, focal, c2w)
        select_coords = self.sample_rays(H, W)
        rays_o = rays_o[select_coords[:, 0], select_coords[:, 1]]  # [N_rand, 3]
        rays_d = rays_d[select_coords[:, 0], select_coords[:, 1]]  # [N_rand, 3]
        return rays_o, rays_d, select_coords

    def sample_rays(self, H, W, **kwargs):
        """Return the pixel coordinates to sample; must be overridden."""
        raise NotImplementedError
class UniformRaySampler(BaseRaySampler):
    """Ray sampler that draws pixel coordinates uniformly without replacement."""

    def __init__(self, N_rays=None):
        super().__init__(N_rays=N_rays)

    def sample_rays(self, H, W, n_rays=None, rect=None, in_rect_percent=0.9, **kwargs):
        """Pick ``n_rays`` distinct pixel coordinates of an H x W image.

        Args:
            H: image height in pixels.
            W: image width in pixels.
            n_rays: number of coordinates to draw; defaults to ``self.N_rays``.
            rect: optional region for region-aware sampling (not available,
                see below).
            in_rect_percent: share of rays meant to fall inside ``rect``;
                unused while rect sampling is unavailable.

        Returns:
            [n_rays, 2] long tensor of (row, column) coordinates.

        Raises:
            NotImplementedError: if ``rect`` is given.
        """
        if n_rays is None:
            n_rays = self.N_rays
        coords = torch.stack(
            torch.meshgrid(
                torch.linspace(0, H - 1, H),
                torch.linspace(0, W - 1, W),
                indexing="ij",
            ),
            -1,
        )  # [H, W, 2]
        coords = torch.reshape(coords, [-1, 2])  # [H * W, 2]
        if rect is None:
            # Uniformly sample the whole image.
            selected_inds = np.random.choice(coords.shape[0], size=[n_rays], replace=False)
            selected_coords = coords[selected_inds].long()
        else:
            # NOTE(review): the source only contained an elided placeholder
            # for sampling inside/outside `rect`; the rect format and the
            # sampling rule are not visible here, so fail loudly rather than
            # guess.
            raise NotImplementedError("rect-based ray sampling is not implemented")
        return selected_coords

    def __call__(self, H, W, focal, c2w, n_rays=None, selected_coords=None, rect=None, in_rect_percent=0.9, **kwargs):
        """Compute all rays and keep those at ``selected_coords``.

        If ``selected_coords`` is None, new coordinates are drawn with
        ``sample_rays``; otherwise the given coordinates are reused.

        Returns:
            rays_o, rays_d: origins and directions of the selected rays.
            selected_coords: the coordinates that were used.
        """
        rays_o, rays_d = get_rays(H, W, focal, c2w)
        if selected_coords is None:
            selected_coords = self.sample_rays(H, W, n_rays, rect, in_rect_percent)
        rays_o = rays_o[selected_coords[:, 0], selected_coords[:, 1]]
        rays_d = rays_d[selected_coords[:, 0], selected_coords[:, 1]]
        return rays_o, rays_d, selected_coords

    def sample_pixels_from_img_with_select_coords(self, img, select_coords):
        """Return the entries of ``img`` at the given (row, column) coordinates."""
        return img[select_coords[:, 0], select_coords[:, 1]]