- .gitignore: Specifies the files and folders Git should ignore, keeping unnecessary artifacts such as temporary files and logs out of version control.
- LICENSE: Contains the open-source license for the codebase, which governs how the code may be used, distributed, and modified.
- README.md: The project's documentation, covering an overview, usage instructions, dependencies, and the training and testing steps, so users can get started quickly.
- env.yaml: A Conda environment file listing the Python packages and versions the project depends on, so the full environment can be created in one step (e.g., `conda env create -f env.yaml`).
models folder
This folder contains the model definitions and loss functions used by the project.
- __init__.py: Package initializer that marks the folder as a Python package.
- convolutionalfeature.py: Defines model components related to convolutional features, presumably used to extract convolutional features from the input data.
- custom.py: Likely contains custom model layers, loss functions, or other bespoke components.
- filmsiren.py: Implements models based on the FiLM-SIREN architecture. The FiLMLayer class defines the core layer: the input passes through a linear transform and is then modulated by a frequency and a phase shift.
```python
class FiLMLayer(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.layer = nn.Linear(input_dim, hidden_dim)

    def forward(self, x, freq, phase_shift):
        """
        sin(freq * lin(x) + phase_shift)
        Args:
            x: (B, ..., input_dim)
            freq: (B, ..., hidden_dim)
            phase_shift: (B, ..., hidden_dim)
        Returns:
            output: (B, ..., hidden_dim)
        """
        x = self.layer(x)
        assert freq.shape == phase_shift.shape
        if freq.ndim == (x.ndim - 1):
            freq = freq.unsqueeze(-2)
            phase_shift = phase_shift.unsqueeze(-2)
        return torch.sin(freq * x + phase_shift)
```
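For orientation, here is a minimal forward pass through FiLMLayer. The shapes are illustrative assumptions, not values taken from the repository:

```python
import torch

layer = FiLMLayer(input_dim=3, hidden_dim=256)
x = torch.randn(8, 1024, 3)         # batch of 8 point sets, 1024 points each
freq = torch.randn(8, 256)          # one modulation vector per sample
phase_shift = torch.randn(8, 256)
out = layer(x, freq, phase_shift)   # freq/phase get a points dim and broadcast
print(out.shape)                    # torch.Size([8, 1024, 256])
```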
- losses.py: Defines the loss functions used throughout the project, such as eikonal_loss, relax_eikonal_loss, and latent_rg_loss, as well as the MorseLoss class that combines the different loss terms.
```python
def latent_rg_loss(latent_reg, device):
    # compute the VAE latent representation regularization loss
    if latent_reg is not None:
        reg_loss = latent_reg.mean()
    else:
        reg_loss = torch.tensor([0.0], device=device)
    return reg_loss


class MorseLoss(nn.Module):
    def __init__(self, weights=None, loss_type='siren_wo_n_w_morse', div_decay='none',
                 div_type='l1', bidirectional_morse=True, udf=False):
        super().__init__()
        if weights is None:
            weights = [3e3, 1e2, 1e2, 5e1, 1e2, 1e1]
        self.weights = weights  # sdf, intern, normal, eikonal, div
        self.loss_type = loss_type
        self.div_decay = div_decay
        self.div_type = div_type
        self.use_morse = True if 'morse' in self.loss_type else False
        self.bidirectional_morse = bidirectional_morse
        self.udf = udf
```
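The eikonal_loss itself is not excerpted above. As a reference point only, a standard eikonal penalty (the constraint that an SDF's gradient has unit norm) can be sketched as follows; the function name and the squared form are assumptions, and the repository's version may differ:

```python
import torch

def eikonal_loss_sketch(grads):
    # grads: (..., 3) gradients of the predicted implicit function w.r.t. input points.
    # Penalize deviation of the gradient norm from 1: E[(|grad f(x)| - 1)^2].
    return ((grads.norm(2, dim=-1) - 1.0) ** 2).mean()
```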
ocn subfolder: Likely contains code related to Convolutional Occupancy Networks.
- overfit_network.py: Defines the network used to overfit a single shape.
- shape_network.py: Defines the network used for shape-space learning.
utils folder
This folder contains general-purpose utility functions that support developing and running the project.
- utils.py: A collection of general utilities, including logging, model parameter counting, random-seed setup, and surface extraction.
```python
def log_losses(writer, batch_idx, num_batches, loss_dict, batch_size, epoch):
    # helper function to log losses to tensorboardx writer
    fraction_done = (batch_idx + 1) / num_batches
    iteration = (epoch + fraction_done) * num_batches * batch_size
    for loss in loss_dict.keys():
        writer.add_scalar(loss, loss_dict[loss].item(), iteration)
    return iteration


def surface_extraction_single(ndf, grad, b_max, b_min, resolution):
    # single-threaded surface extraction, adapted from CAP-UDF
    v_all = []
    t_all = []
    threshold = 0.005  # accelerate extraction
    v_num = 0
    for i in range(resolution - 1):
        for j in range(resolution - 1):
            for k in range(resolution - 1):
                ndf_loc = ndf[i:i + 2]
                ndf_loc = ndf_loc[:, j:j + 2, :]
                ndf_loc = ndf_loc[:, :, k:k + 2]
                if np.min(ndf_loc) > threshold:
                    continue
                grad_loc = grad[i:i + 2]
                grad_loc = grad_loc[:, j:j + 2, :]
                grad_loc = grad_loc[:, :, k:k + 2]

                # flip the sign of the UDF sample when its gradient opposes the
                # corner gradient, so marching cubes sees a zero crossing
                res = np.ones((2, 2, 2))
                for ii in range(2):
                    for jj in range(2):
                        for kk in range(2):
                            if np.dot(grad_loc[0][0][0], grad_loc[ii][jj][kk]) < 0:
                                res[ii][jj][kk] = -ndf_loc[ii][jj][kk]
                            else:
                                res[ii][jj][kk] = ndf_loc[ii][jj][kk]

                if res.min() < 0:
                    vertices, triangles, _, _ = measure.marching_cubes(res, 0.0)
                    vertices[:, 0] += i
                    vertices[:, 1] += j
                    vertices[:, 2] += k
                    triangles += v_num
                    v_all.append(vertices)
                    t_all.append(triangles)
                    v_num += vertices.shape[0]

    v_all = np.concatenate(v_all)
    t_all = np.concatenate(t_all)
    mesh = trimesh.Trimesh(v_all, t_all)
    mesh.remove_duplicate_faces()
    mesh.remove_degenerate_faces()
    mesh.fill_holes()
    return mesh.vertices, mesh.faces
```
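A hedged sketch of how surface_extraction_single might be driven. The grid shapes are assumptions (an (R, R, R) unsigned-distance grid plus an (R, R, R, 3) gradient grid, typically obtained by querying the trained network on a regular lattice); the random placeholders below exist only to make the call runnable:

```python
import numpy as np

R = 32  # keep small: the extraction loop is pure Python and O(R^3)
ndf = np.random.rand(R, R, R).astype(np.float32)        # placeholder UDF samples
grad = np.random.randn(R, R, R, 3).astype(np.float32)   # placeholder gradients
verts, faces = surface_extraction_single(ndf, grad, b_max=1.0, b_min=-1.0, resolution=R)
```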
- utils_mp.py: Likely contains multiprocessing-related utilities to improve the program's parallelism.
- visualizations.py: Provides visualization utilities, such as plotting cut planes and generating meshes, to help visualize training results.
data folder
sdf subfolder: Holds the input data for the surface reconstruction task, such as point-cloud files (.xyz and .ply).
shapespace folder
This folder contains the code and data for shape-space learning.
- dfaust_dataset.py: Defines the dataset class for the DFaust dataset; it loads and processes the data, including sampling and preparing the point clouds.
```python
def load_points(self, index):
    return np.load(os.path.join(self.dataset_path, self.npyfiles_mnfld[index].strip()))

def __getitem__(self, index):
    point_set_mnlfld = torch.from_numpy(self.load_points(index)).float()  # (250000, 6): xyz, normal xyz
    random_idx = torch.randperm(point_set_mnlfld.shape[0])[:self.n_points]
    point_set_mnlfld = torch.index_select(point_set_mnlfld, 0, random_idx)  # (pnts, 6)
    mnfld_points = point_set_mnlfld[:, :self.d_in]

    if self.with_normals:
        normals = point_set_mnlfld[:, -self.d_in:]  # todo adjust to case when we get no sigmas
    else:
        normals = torch.empty(0)

    # Q_far: uniform samples in the volume
    nonmnfld_points = np.random.uniform(-self.grid_range, self.grid_range,
                                        size=(self.n_points, 3)).astype(np.float32)  # (n_points, 3)
    nonmnfld_points = torch.from_numpy(nonmnfld_points).float()

    # Q_near: manifold points perturbed by the distance to their 50th nearest neighbor
    dist = torch.cdist(mnfld_points, mnfld_points)
    sigmas = torch.topk(dist, k=51, dim=1, largest=False)[0][:, -1:]  # (n_points, 1)
    near_points = (mnfld_points +
                   sigmas * torch.randn(mnfld_points.shape[0], mnfld_points.shape[1]))

    return {'mnfld_points': mnfld_points, 'mnfld_n': normals,
            'nonmnfld_points': nonmnfld_points, 'near_points': near_points,
            'indices': index, 'name': self.npyfiles_mnfld[index]}
```
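A hedged usage sketch: the dataset class name and constructor arguments below are assumptions inferred from the fields referenced in the excerpt (dataset_path, n_points, with_normals, ...), not the repository's exact signature:

```python
from torch.utils.data import DataLoader

train_set = DFaustDataSet(dataset_path='./data/dfaust', n_points=15000, with_normals=True)
train_loader = DataLoader(train_set, batch_size=16, shuffle=True, num_workers=4)
batch = next(iter(train_loader))
print(batch['mnfld_points'].shape)   # expected (16, 15000, 3)
```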
log_conv_all_half_eikonal_1e-4_amsgrad_200_epoch_cos_1e-6_1500_2 folder: Likely stores training logs, model checkpoints, and similar artifacts.
- pl_conv_finetune_shapespace.py: Script for finetuning the shape-space model; it first loads a pretrained model for inference, then finetunes the network and writes out the final results.
- pl_conv_train_shapespace.py: Training script for the shape-space model; it defines the core logic for data loading, model construction, loss computation, and optimizer setup.
```python
class BaseTrainer(pl.LightningModule):
    def __init__(self, args):
        super().__init__()
        self.args = args
        self.learning_rate = args.lr
        self.net = ShapeNetwork(decoder_hidden_dim=args.decoder_hidden_dim,
                                decoder_n_hidden_layers=args.decoder_n_hidden_layers)
        self.criterion = MorseLoss(weights=args.loss_weights, loss_type=args.loss_type,
                                   div_decay=args.morse_decay, div_type=args.morse_type,
                                   bidirectional_morse=args.bidirectional_morse)

    def training_step(self, batch, batch_idx):
        self.net.train()
        self.net.zero_grad(set_to_none=True)
        mnfld_points, mnfld_n_gt, nonmnfld_points, near_points, indices = \
            batch['mnfld_points'], batch['mnfld_n'], batch['nonmnfld_points'], \
            batch['near_points'], batch['indices']
        mnfld_points.requires_grad_()
        nonmnfld_points.requires_grad_()
        near_points.requires_grad_()

        output_pred = self.net(nonmnfld_points, mnfld_points,
                               near_points=near_points if self.args.morse_near else None)
        loss_dict, _ = self.criterion(output_pred, mnfld_points, nonmnfld_points,
                                      near_points=near_points if self.args.morse_near else None)

        for key, value in loss_dict.items():
            self.log(key, value, on_step=True, logger=True)
        self.log('total_loss', loss_dict['loss'], on_step=True, logger=True, prog_bar=True)

        if self.local_rank == 0 and self.global_rank == 0 and batch_idx % 30 == 0:
            weights = self.criterion.weights
            utils.log_string("Weights: {}".format(weights), log_file)
            utils.log_string('Epoch: {}, Loss: {:.5f} = L_Mnfld: {:.5f} + '
                             'L_NonMnfld: {:.5f} + L_Nrml: {:.5f} + L_Eknl: {:.5f} + '
                             'L_Div: {:.5f} + L_Morse: {:.5f} + L_Latent: {:.10f}'.format(
                self.current_epoch, loss_dict["loss"].item(),
                weights[0] * loss_dict["sdf_term"].item(),
                weights[1] * loss_dict["inter_term"].item(),
                weights[2] * loss_dict["normals_loss"].item(),
                weights[3] * loss_dict["eikonal_term"].item(),
                weights[4] * loss_dict["div_loss"].item(),
                weights[5] * loss_dict['morse_term'].item(),
                weights[6] * loss_dict['latent_reg_term'].item()), log_file)
            utils.log_string('Unweighted L_s : L_Mnfld: {:.5f}, '
                             'L_NonMnfld: {:.5f}, L_Nrml: {:.5f}, L_Eknl: {:.5f}, '
                             'L_Morse: {:.5f}, L_Latent: {:.10f}'.format(
                loss_dict["sdf_term"].item(), loss_dict["inter_term"].item(),
                loss_dict["normals_loss"].item(), loss_dict["eikonal_term"].item(),
                loss_dict['morse_term'].item(), loss_dict['latent_reg_term'].item()), log_file)

        return {'loss': loss_dict['loss'], 'mnfld': mnfld_points[:1]}

    def training_epoch_end(self, outputs):
        mnfld = outputs[0]['mnfld']
        self.net.eval()
        if self.global_rank == 0 and self.local_rank == 0:
            with torch.no_grad():
                t0 = time.time()
                out_dir = "{}/vis_results/".format(self.args.logdir)
                os.makedirs(out_dir, exist_ok=True)
                global_feat = self.net.encoder.encode(mnfld)
                try:
                    pred_mesh = utils.implicit2mesh(decoder=self.net, mods=None, feat=global_feat,
                                                    grid_res=128, get_mesh=True,
                                                    device=next(self.net.parameters()).device)
                    pred_mesh.export(os.path.join(out_dir,
                                                  "pred_mesh_{}.ply".format(self.current_epoch)))
                except Exception as e:
                    print('Can not plot')
                    print(e)
                print('Plot took {:.3f}s'.format(time.time() - t0))
        # update weights
        curr_epoch = self.current_epoch
        self.criterion.update_morse_weight(curr_epoch, self.args.num_epochs, self.args.decay_params)

    def configure_optimizers(self):
        # Setup Adam optimizer with cosine-annealing warm restarts
        optimizer = torch.optim.Adam(self.trainer.model.parameters(),
                                     lr=self.learning_rate, amsgrad=True)
        lr_sch = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=150 * 10,
                                                                      T_mult=2, eta_min=1e-6)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": lr_sch,
                "interval": "step",
                "frequency": 1,
            },
        }
```
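The hooks above (training_epoch_end, self.log with on_step) follow the pre-2.0 PyTorch Lightning API. A minimal launch sketch, assuming an args namespace from shapespace_dfaust_args.py and a train_loader like the one sketched earlier:

```python
import pytorch_lightning as pl

model = BaseTrainer(args)
trainer = pl.Trainer(max_epochs=args.num_epochs, accelerator='gpu', devices=1)
trainer.fit(model, train_dataloaders=train_loader)
```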
- shapespace_dfaust_args.py: Handles the command-line arguments for the shape-space learning task, defining the training parameters such as the learning rate and batch size.
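As an illustration of what such an argument script typically looks like; the flag names and defaults here are hypothetical, not the repository's actual values:

```python
import argparse

def get_args():
    parser = argparse.ArgumentParser(description='shape-space training arguments (illustrative)')
    parser.add_argument('--lr', type=float, default=1e-4, help='learning rate')
    parser.add_argument('--batch_size', type=int, default=16)
    parser.add_argument('--num_epochs', type=int, default=200)
    parser.add_argument('--logdir', type=str, default='./log')
    return parser.parse_args()
```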
assets folder
- RP.jpg: An image related to the project, presumably used in the documentation or for presentation.
surface_reconstruction folder
This folder contains the code for the surface reconstruction task.
- recon_dataset.py: Defines the dataset class for surface reconstruction, which loads and prepares the reconstruction data.
- run_sdf_recon.py: Script that runs surface reconstruction on every *.xyz and *.ply file under ./data/sdf/input (see the enumeration sketch after this list).
- surface_recon_args.py: Handles the command-line arguments for the surface reconstruction task, defining the training parameters such as the learning rate and batch size.
- train_surface_reconstruction.py: Training script for surface reconstruction, containing the core training logic: model definition, data loading, optimizer setup, loss computation, and the training loop.
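The enumeration sketch referenced above: a hedged guess at how run_sdf_recon.py plausibly collects its inputs; the actual script may structure this differently:

```python
import glob
import os

input_dir = './data/sdf/input'
paths = sorted(glob.glob(os.path.join(input_dir, '*.xyz')) +
               glob.glob(os.path.join(input_dir, '*.ply')))
for path in paths:
    print('reconstructing', path)  # placeholder for the per-shape reconstruction call
```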