NeRF基础代码解析

embedders

对position(采样点位置)和view direction(视角方向)做频率编码(frequency / positional encoding),将低维输入映射为高维的embedding。

class FreqEmbedder(nn.Module):
    """Frequency (positional) encoding of a position or a view direction.

    Every input vector ``x`` of size ``in_dim`` is mapped to the concatenation
    ``[x, sin(f_0 x), cos(f_0 x), ..., sin(f_{K-1} x), cos(f_{K-1} x)]`` where
    ``K = multi_res`` (the leading ``x`` is present only when
    ``include_input`` is true).

    Args:
        in_dim: size of the last dimension of the input.
        multi_res: number of frequency bands ``K``.
        use_log_bands: if true the bands are ``2**0, 2**1, ..., 2**(K-1)``;
            otherwise they are linearly spaced between ``2**0`` and
            ``2**(K-1)``.
        include_input: whether the raw input is kept as the first chunk of
            the embedding.

    Attributes:
        out_dim: size of the last dimension of the embedding, i.e.
            ``K * 2 * in_dim`` (+ ``in_dim`` when ``include_input``).
        embed_fns: list of callables applied to the input, in output order.
        num_embed_fns: ``len(embed_fns)``.
    """

    def __init__(self, in_dim=3, multi_res=10, use_log_bands=True, include_input=True):
        super().__init__()
        self.in_dim = in_dim
        self.num_freqs = multi_res
        # The highest exponent is multi_res - 1 so that multi_res bands span
        # 2**0 .. 2**(multi_res - 1), e.g. 1 .. 512 for the default of 10.
        self.max_freq_log2 = multi_res - 1
        self.use_log_bands = use_log_bands
        self.periodic_fns = [torch.sin, torch.cos]
        self.include_input = include_input
        self.embed_fns = None
        self.out_dim = None
        self.num_embed_fns = None
        self.create_embedding_fn()

    def create_embedding_fn(self):
        """Build ``embed_fns`` and compute ``out_dim`` / ``num_embed_fns``."""
        self.embed_fns = []
        # Default configuration: 10 bands * 2 functions * 3 dims = 60.
        self.out_dim = self.num_freqs * len(self.periodic_fns) * self.in_dim
        if self.include_input:
            self.embed_fns.append(lambda x: x)
            self.out_dim += self.in_dim  # 63 with the defaults

        if self.use_log_bands:
            freq_bands = 2. ** torch.linspace(0., self.max_freq_log2, steps=self.num_freqs)
        else:
            freq_bands = torch.linspace(2. ** 0, 2. ** self.max_freq_log2, steps=self.num_freqs)

        # Plain Python floats keep the closures independent of any tensor device.
        for freq in freq_bands.tolist():
            for p_fn in self.periodic_fns:
                # Bind p_fn / freq as defaults to avoid late-binding closures.
                self.embed_fns.append(lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq))
        self.num_embed_fns = len(self.embed_fns)

    def forward(self, x):
        """Encode ``x``.

        Args:
            x: tensor of shape ``[..., in_dim]``, xyz or view direction.

        Returns:
            Tensor of shape ``[..., out_dim]``:
            ``[x, sin(x), cos(x), sin(2x), cos(2x), ..., sin(512x), cos(512x)]``
            with the default arguments.
        """
        return torch.cat([embed_fn(x) for embed_fn in self.embed_fns], dim=-1)

NeRFBackbone

position和view经过embedding后,得到特征向量。再输入到NeRFBackbone网络中,得到sigma和color输出。

class NeRFBackbone(nn.Module):
    """MLP mapping encoded position, condition and view direction to rgb + sigma.

    The density branch consumes ``[pos, cond]`` and produces a hidden feature
    plus a scalar ``sigma``. The color branch consumes that hidden feature
    concatenated with ``view`` and produces ``rgb``.

    Args:
        pos_dim: size of the encoded position.
        cond_dim: size of the condition feature.
        view_dim: size of the encoded view direction.
        hid_dim: width of the density branch (the color branch uses
            ``hid_dim // 2``).
        num_density_linears: number of hidden linear layers in the density
            branch.
        num_color_linears: number of hidden linear layers in the color branch.
        skip_layer_indices: indices ``i`` of density layers after which the
            branch input is concatenated back onto the hidden feature. Each
            index must be smaller than ``num_density_linears - 1`` because
            only the layer that follows is built with the wider input.
    """

    def __init__(self, pos_dim=3, cond_dim=64, view_dim=3, hid_dim=128, num_density_linears=8, num_color_linears=3, skip_layer_indices=(4,)):
        # nn.Module must be initialised before sub-modules can be assigned.
        super().__init__()
        self.pos_dim = pos_dim
        self.cond_dim = cond_dim
        self.view_dim = view_dim
        self.hid_dim = hid_dim
        self.out_dim = 4  # rgb + sigma
        self.num_density_linears = num_density_linears
        self.num_color_linears = num_color_linears
        self.skip_layer_indices = tuple(skip_layer_indices)

        density_input_dim = pos_dim + cond_dim
        density_layers = [nn.Linear(density_input_dim, hid_dim)]
        for i in range(num_density_linears - 1):
            # The layer following a skip index receives [input, hidden].
            in_features = hid_dim + density_input_dim if i in self.skip_layer_indices else hid_dim
            density_layers.append(nn.Linear(in_features, hid_dim))
        self.density_linears = nn.ModuleList(density_layers)
        self.density_out_linear = nn.Linear(hid_dim, 1)

        color_input_dim = view_dim + hid_dim
        self.color_linears = nn.ModuleList(
            [nn.Linear(color_input_dim, hid_dim // 2)]
            + [nn.Linear(hid_dim // 2, hid_dim // 2) for _ in range(num_color_linears - 1)]
        )
        self.color_out_linear = nn.Linear(hid_dim // 2, 3)

    def forward(self, pos, cond, view):
        """Predict raw rgb and sigma for every sample point.

        Args:
            pos: ``[bs, n_sample, pos_dim]``, encoding of position.
            cond: condition features, either ``[cond_dim]`` (shared by the
                whole batch) or ``[bs, cond_dim]`` (one per batch item). Any
                other rank is used as given and must already be
                concatenable with ``pos`` on the last dimension.
            view: ``[bs, view_dim]``, encoding of view direction.

        Returns:
            ``[bs, n_sample, 4]`` tensor, rgb in the first three channels and
            sigma in the last. No output activation is applied here.
        """
        bs, n_sample, _ = pos.shape
        if cond.dim() == 1:  # [cond_dim]
            cond = cond[None, None, :].expand([bs, n_sample, self.cond_dim])
        elif cond.dim() == 2:  # [bs, cond_dim]
            cond = cond[:, None, :].expand([bs, n_sample, self.cond_dim])

        # Broadcast the per-ray view direction to every sample on the ray.
        view = view[:, None, :].expand([bs, n_sample, self.view_dim])

        density_linear_input = torch.cat([pos, cond], dim=-1)
        h = density_linear_input
        for i, linear in enumerate(self.density_linears):
            h = F.relu(linear(h))
            if i in self.skip_layer_indices:
                h = torch.cat([density_linear_input, h], -1)
        sigma = self.density_out_linear(h)

        h = torch.cat([h, view], -1)
        for linear in self.color_linears:
            h = F.relu(linear(h))
        rgb = self.color_out_linear(h)
        return torch.cat([rgb, sigma], -1)

Ray Sampler

以一张 height = 1280、width = 720 的图像为例:先为每个像素生成一条从相机原点发出的光线(ray),再从中随机采样 4096 条用于训练。

def get_rays(H, W, focal, c2w, cx=None, cy=None):
    """Get the rays emitted from the camera through every pixel.

    Rays are expressed in world coordinates.

    Args:
        H: height of the image in pixels.
        W: width of the image in pixels.
        focal: focal length of the camera in pixels.
        c2w: camera-to-world tensor whose first three rows look like
            [[r11, r12, r13, t1],
             [r21, r22, r23, t2],
             [r31, r32, r33, t3]]
        cx: principal point along the width axis; defaults to ``W * 0.5``.
        cy: principal point along the height axis; defaults to ``H * 0.5``.

    Returns:
        rays_o: ``[H, W, 3]`` start point of every ray (the camera center in
            world coordinates, broadcast to all pixels).
        rays_d: ``[H, W, 3]`` direction of every ray; it is not normalised.
            Points on a ray are ``xyz = rays_o + rays_d * z_val`` where
            ``z_val`` is the distance along the ray.
    """
    if cx is None:
        cx = W * 0.5
    if cy is None:
        cy = H * 0.5

    # Build the pixel grid on the same device as c2w so the product below
    # does not mix devices. "ij" indexing yields [H, W] grids: j is the row
    # (height) coordinate and i the column (width) coordinate.
    rows = torch.linspace(0, H - 1, H, device=c2w.device)
    cols = torch.linspace(0, W - 1, W, device=c2w.device)
    j_pixels, i_pixels = torch.meshgrid(rows, cols, indexing="ij")

    # Camera-frame directions: x to the right, y up, looking down -z.
    directions = torch.stack(
        [(i_pixels - cx) / focal, -(j_pixels - cy) / focal, -torch.ones_like(i_pixels)],
        dim=-1,
    )  # [H, W, 3]
    # Rotate ray directions from the camera frame to the world frame
    # (equivalent to R @ d for every pixel).
    rays_d = torch.sum(directions[..., None, :] * c2w[:3, :3], dim=-1)  # [H, W, 3]
    # Origin of all rays: the camera center in world coordinates.
    rays_o = c2w[:3, -1].expand(rays_d.shape)
    return rays_o, rays_d

class BaseRaySampler:
    """Base class that picks a subset of the per-pixel rays of one image.

    Subclasses implement ``sample_rays`` to decide which pixels are used.

    Args:
        N_rays: default number of rays to sample per call.
    """

    def __init__(self, N_rays):
        super().__init__()
        self.N_rays = N_rays

    def __call__(self, H, W, focal, c2w):
        """Generate all rays of an ``H x W`` image and keep the sampled ones.

        Returns:
            rays_o: origins of the selected rays, indexed from the
                ``get_rays`` output by (row, column).
            rays_d: directions of the selected rays, indexed the same way.
            select_coords: the coordinates returned by ``sample_rays``;
                column 0 is used as the row index and column 1 as the
                column index.
        """
        rays_o, rays_d = get_rays(H, W, focal, c2w)
        select_coords = self.sample_rays(H, W)
        row_idx, col_idx = select_coords[:, 0], select_coords[:, 1]
        rays_o = rays_o[row_idx, col_idx]  # [N_rand, 3]
        rays_d = rays_d[row_idx, col_idx]  # [N_rand, 3]
        return rays_o, rays_d, select_coords

    def sample_rays(self, H, W, **kwargs):
        """Return the pixel coordinates to sample; must be overridden."""
        raise NotImplementedError
	
class UniformRaySampler(BaseRaySampler):
    """Ray sampler that draws pixels uniformly at random without replacement.

    Args:
        N_rays: default number of rays per call, used when ``n_rays`` is not
            given to ``sample_rays`` / ``__call__``.
    """

    def __init__(self, N_rays=None):
        super().__init__(N_rays=N_rays)

    def sample_rays(self, H, W, n_rays=None, rect=None, in_rect_percent=0.9, **kwargs):
        """Pick ``n_rays`` distinct pixel coordinates of an ``H x W`` image.

        Args:
            H: image height in pixels.
            W: image width in pixels.
            n_rays: number of pixels to draw; defaults to ``self.N_rays``.
            rect: optional region for region-weighted sampling. Not
                supported here, see Raises.
            in_rect_percent: share of rays meant for ``rect``; unused while
                ``rect`` sampling is unsupported.

        Returns:
            ``[n_rays, 2]`` long tensor of (row, column) coordinates.

        Raises:
            ValueError: if neither ``n_rays`` nor ``self.N_rays`` is set.
            NotImplementedError: if ``rect`` is given.
        """
        if n_rays is None:
            n_rays = self.N_rays
        if n_rays is None:
            raise ValueError("n_rays must be given when the sampler was built without N_rays")

        grid = torch.meshgrid(
            torch.linspace(0, H - 1, H), torch.linspace(0, W - 1, W), indexing="ij"
        )
        coords = torch.stack(grid, -1)  # [H, W, 2]
        coords = torch.reshape(coords, [-1, 2])  # [H * W, 2]

        if rect is not None:
            # NOTE(review): the original listing elides this branch with a
            # placeholder ("uniformly sample from rect region and out-rect,
            # respectively"). The layout of `rect` is not visible, so the
            # branch is left unimplemented instead of being guessed.
            raise NotImplementedError("rect-based ray sampling is not implemented")

        # Uniformly sample the whole image, without replacement.
        selected_inds = np.random.choice(coords.shape[0], size=[n_rays], replace=False)
        return coords[selected_inds].long()

    # The method was originally spelled `sample_ray`; keep that name working.
    sample_ray = sample_rays

    def __call__(self, H, W, focal, c2w, n_rays=None, selected_coords=None, rect=None, in_rect_percent=0.9, **kwargs):
        """Generate all rays of an image and keep the selected ones.

        Args:
            H, W, focal, c2w: forwarded to ``get_rays``.
            n_rays, rect, in_rect_percent: forwarded to ``sample_rays`` when
                ``selected_coords`` is not supplied.
            selected_coords: optional precomputed coordinates to reuse
                instead of sampling new ones; column 0 indexes rows and
                column 1 indexes columns.

        Returns:
            ``(rays_o, rays_d, selected_coords)`` for the selected pixels.
        """
        rays_o, rays_d = get_rays(H, W, focal, c2w)
        if selected_coords is None:
            selected_coords = self.sample_rays(
                H, W, n_rays=n_rays, rect=rect, in_rect_percent=in_rect_percent
            )
        row_idx, col_idx = selected_coords[:, 0], selected_coords[:, 1]
        rays_o = rays_o[row_idx, col_idx]
        rays_d = rays_d[row_idx, col_idx]
        return rays_o, rays_d, selected_coords

    def sample_pixels_from_img_with_select_coords(self, img, select_coords):
        """Return the pixels of ``img`` at the (row, column) ``select_coords``."""
        return img[select_coords[:, 0], select_coords[:, 1]]

你可能感兴趣的:(pytorch,人工智能,python)