What each file and folder in the Neural-Singular-Hessian repository does

Root-level files

  • .gitignore: Specifies the files and folders that Git should ignore, keeping unnecessary items such as temporary files and logs out of version control.
  • LICENSE: The open-source license for the repository, which governs how the code may be used, distributed, and modified.
  • README.md: The project documentation, covering an overview, usage instructions, dependencies, and the training and testing steps, so users can get started quickly.
  • env.yaml: A Conda environment file listing the required Python packages and their versions, so the full dependency environment can be created in one command.


Folders and their contents

models folder

This folder contains the model definitions and loss functions used in the project.

  • __init__.py: Package initializer that marks the folder as a Python package.
  • convolutionalfeature.py: Defines model components related to convolutional features, presumably used to extract convolutional features from the input data.
  • custom.py: Likely holds custom layers, loss functions, or other bespoke components.
  • filmsiren.py: Implements a FiLM-SIREN style model; the FiLMLayer class applies a linear transform to the input and modulates the result with a frequency and phase shift.

```python
import torch
from torch import nn


class FiLMLayer(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.layer = nn.Linear(input_dim, hidden_dim)

    def forward(self, x, freq, phase_shift):
        """
        sin(freq * lin(x) + phase_shift)
        Args:
            x (B, ..., input_dim)
            freq: (B, ..., hidden_dim)
            phase_shift: (B, ..., hidden_dim)
        Returns:
            output (B, ..., hidden_dim)
        """
        x = self.layer(x)
        assert freq.shape == phase_shift.shape
        if freq.ndim == (x.ndim - 1):
            freq = freq.unsqueeze(-2)
            phase_shift = phase_shift.unsqueeze(-2)

        return torch.sin(freq * x + phase_shift)
```
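A quick shape check of the FiLM modulation (a standalone sketch; the class is restated so the snippet runs on its own, and the dimensions are arbitrary): per-sample `freq` and `phase_shift` of shape (B, hidden_dim) are unsqueezed so they broadcast over the point dimension.

```python
import torch
from torch import nn

class FiLMLayer(nn.Module):
    # same structure as the FiLMLayer in filmsiren.py
    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.layer = nn.Linear(input_dim, hidden_dim)

    def forward(self, x, freq, phase_shift):
        x = self.layer(x)
        if freq.ndim == (x.ndim - 1):
            # (B, hidden_dim) -> (B, 1, hidden_dim): broadcast over the points axis
            freq = freq.unsqueeze(-2)
            phase_shift = phase_shift.unsqueeze(-2)
        return torch.sin(freq * x + phase_shift)

layer = FiLMLayer(input_dim=3, hidden_dim=8)
x = torch.randn(2, 100, 3)   # 2 samples, 100 query points each
freq = torch.randn(2, 8)     # one modulation vector per sample
phase = torch.randn(2, 8)
out = layer(x, freq, phase)  # (2, 100, 8), values in [-1, 1] due to sin
```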

  • losses.py: Defines the loss functions used throughout the project, such as eikonal_loss, relax_eikonal_loss, and latent_rg_loss, plus the MorseLoss class that combines the different loss terms.

```python
import torch
from torch import nn


def latent_rg_loss(latent_reg, device):
    # compute the VAE latent representation regularization loss
    if latent_reg is not None:
        reg_loss = latent_reg.mean()
    else:
        reg_loss = torch.tensor([0.0], device=device)

    return reg_loss


class MorseLoss(nn.Module):
    def __init__(self, weights=None, loss_type='siren_wo_n_w_morse', div_decay='none',
                 div_type='l1', bidirectional_morse=True, udf=False):
        super().__init__()
        if weights is None:
            weights = [3e3, 1e2, 1e2, 5e1, 1e2, 1e1]
        self.weights = weights  # sdf, inter, normal, eikonal, div, morse
        self.loss_type = loss_type
        self.div_decay = div_decay
        self.div_type = div_type
        self.use_morse = 'morse' in self.loss_type
        self.bidirectional_morse = bidirectional_morse
        self.udf = udf
```
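latent_rg_loss covers both branches cleanly: with a latent regularizer it returns the mean penalty, and without one it returns a zero tensor on the requested device. Restated standalone as a quick check:

```python
import torch

def latent_rg_loss(latent_reg, device):
    # mean of per-sample latent penalties; zero when no regularizer is used
    if latent_reg is not None:
        return latent_reg.mean()
    return torch.tensor([0.0], device=device)

with_reg = latent_rg_loss(torch.tensor([1.0, 3.0]), 'cpu')  # mean of [1, 3] -> 2.0
without_reg = latent_rg_loss(None, 'cpu')                   # -> tensor([0.])
```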

  • ocn subfolder: Likely contains code related to Convolutional Occupancy Networks.
    • overfit_network.py: Defines the network used to overfit a single shape.
    • shape_network.py: Defines the network used for shape space learning.

utils folder

This folder holds general-purpose utility functions that support development and running of the project.

  • utils.py: Contains many general utilities, such as logging, model parameter counting, random seed setup, and surface extraction.

```python
import numpy as np
import trimesh
from skimage import measure


def log_losses(writer, epoch, batch_idx, num_batches, loss_dict, batch_size):
    # helper function to log losses to a tensorboardX writer
    fraction_done = (batch_idx + 1) / num_batches
    iteration = (epoch + fraction_done) * num_batches * batch_size
    for loss in loss_dict.keys():
        writer.add_scalar(loss, loss_dict[loss].item(), iteration)
    return iteration


def surface_extraction_single(ndf, grad, b_max, b_min, resolution):
    # single-threaded surface extraction, adapted from CAP-UDF
    v_all = []
    t_all = []
    threshold = 0.005  # skip cells far from the surface to accelerate extraction
    v_num = 0
    for i in range(resolution - 1):
        for j in range(resolution - 1):
            for k in range(resolution - 1):
                ndf_loc = ndf[i:i + 2]
                ndf_loc = ndf_loc[:, j:j + 2, :]
                ndf_loc = ndf_loc[:, :, k:k + 2]
                if np.min(ndf_loc) > threshold:
                    continue
                grad_loc = grad[i:i + 2]
                grad_loc = grad_loc[:, j:j + 2, :]
                grad_loc = grad_loc[:, :, k:k + 2]

                # flip the sign of corners whose gradient opposes the reference
                # corner's, turning the unsigned field into a locally signed one
                res = np.ones((2, 2, 2))
                for ii in range(2):
                    for jj in range(2):
                        for kk in range(2):
                            if np.dot(grad_loc[0][0][0], grad_loc[ii][jj][kk]) < 0:
                                res[ii][jj][kk] = -ndf_loc[ii][jj][kk]
                            else:
                                res[ii][jj][kk] = ndf_loc[ii][jj][kk]

                if res.min() < 0:
                    vertices, triangles, _, _ = measure.marching_cubes(
                        res, 0.0)
                    vertices[:, 0] += i
                    vertices[:, 1] += j
                    vertices[:, 2] += k
                    triangles += v_num
                    v_all.append(vertices)
                    t_all.append(triangles)

                    v_num += vertices.shape[0]

    v_all = np.concatenate(v_all)
    t_all = np.concatenate(t_all)

    mesh = trimesh.Trimesh(v_all, t_all)
    mesh.remove_duplicate_faces()
    mesh.remove_degenerate_faces()
    mesh.fill_holes()

    return mesh.vertices, mesh.faces
```
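The key trick in surface_extraction_single is the gradient-based sign flip: an unsigned distance field never crosses zero, so corners whose gradient opposes a reference corner get their values negated, creating a zero crossing that marching cubes can mesh. A toy 2×2×2 cell (all values made up) illustrates this:

```python
import numpy as np

# A cell of unsigned distances: all positive, so the raw UDF has no zero crossing.
ndf_loc = np.full((2, 2, 2), 0.1)

# Gradients point toward the surface; corners on opposite sides of it disagree.
grad_loc = np.zeros((2, 2, 2, 3))
grad_loc[0] = [0.0, 0.0, 1.0]   # lower corners: gradient along +z
grad_loc[1] = [0.0, 0.0, -1.0]  # upper corners: gradient along -z

res = np.ones((2, 2, 2))
for ii in range(2):
    for jj in range(2):
        for kk in range(2):
            # negate corners whose gradient opposes the reference corner's
            if np.dot(grad_loc[0][0][0], grad_loc[ii][jj][kk]) < 0:
                res[ii][jj][kk] = -ndf_loc[ii][jj][kk]
            else:
                res[ii][jj][kk] = ndf_loc[ii][jj][kk]

# res now mixes +0.1 and -0.1, so res.min() < 0 and the cell gets meshed
```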

  • utils_mp.py: Likely contains multiprocessing utilities for parallelizing the heavier computations.
  • visualizations.py: Visualization helpers, such as plotting cut planes and generating meshes, for inspecting training results.

data folder

  • sdf subfolder: Holds the input data for the surface reconstruction task, such as point cloud files (.xyz and .ply).
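.xyz point clouds are plain text with one point per line; a minimal sketch of reading one with NumPy (the file and its contents below are made up for illustration):

```python
import os
import tempfile

import numpy as np

# Write a tiny made-up .xyz file: x y z per line
# (some datasets append normals as nx ny nz on the same line).
content = "0.0 0.0 0.0\n1.0 0.0 0.0\n0.0 1.0 0.0\n"
path = os.path.join(tempfile.mkdtemp(), "toy.xyz")
with open(path, "w") as f:
    f.write(content)

points = np.loadtxt(path)  # one row per point -> shape (3, 3)
```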

shapespace folder

This folder contains the code and data for shape space learning.

  • dfaust_dataset.py: Defines the dataset class for the DFaust dataset, responsible for loading and processing it, including sampling and preprocessing the point clouds.

```python
import os

import numpy as np
import torch

# methods of the DFaust dataset class defined in this file

def load_points(self, index):
    return np.load(os.path.join(self.dataset_path, self.npyfiles_mnfld[index].strip()))

def __getitem__(self, index):
    point_set_mnlfld = torch.from_numpy(self.load_points(index)).float()  # (250000, 6): xyz + normal xyz

    random_idx = torch.randperm(point_set_mnlfld.shape[0])[:self.n_points]
    point_set_mnlfld = torch.index_select(point_set_mnlfld, 0, random_idx)  # (n_points, 6)

    mnfld_points = point_set_mnlfld[:, :self.d_in]

    if self.with_normals:
        normals = point_set_mnlfld[:, -self.d_in:]  # todo adjust to case when we get no sigmas
    else:
        normals = torch.empty(0)

    # Q_far: uniform samples over the full grid
    nonmnfld_points = np.random.uniform(-self.grid_range, self.grid_range,
                                        size=(self.n_points, 3)).astype(np.float32)  # (n_points, 3)
    nonmnfld_points = torch.from_numpy(nonmnfld_points).float()

    # Q_near: perturb each surface point by the distance to its 50th nearest neighbor
    dist = torch.cdist(mnfld_points, mnfld_points)
    sigmas = torch.topk(dist, k=51, dim=1, largest=False)[0][:, -1:]  # (n_points, 1)
    near_points = (mnfld_points + sigmas * torch.randn(mnfld_points.shape[0],
                                                       mnfld_points.shape[1]))
    return {'mnfld_points': mnfld_points, 'mnfld_n': normals, 'nonmnfld_points': nonmnfld_points,
            'near_points': near_points, 'indices': index, 'name': self.npyfiles_mnfld[index]}
```
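The Q_near sampling in __getitem__ scales the Gaussian perturbation of each surface point by the distance to its 50th nearest neighbor (k=51 in topk because distance 0 to the point itself is included). The same sigma computation can be sketched in NumPy on toy data:

```python
import numpy as np

rng = np.random.default_rng(0)
mnfld_points = rng.uniform(-1, 1, size=(200, 3)).astype(np.float32)

# full pairwise distance matrix; sorted column 50 mirrors
# torch.topk(dist, k=51, largest=False)[0][:, -1:]
dist = np.linalg.norm(mnfld_points[:, None, :] - mnfld_points[None, :, :], axis=-1)
sigmas = np.sort(dist, axis=1)[:, 50:51]  # (200, 1): distance to the 50th neighbor

# denser regions get smaller sigmas, so near-surface samples stay near the surface
near_points = mnfld_points + sigmas * rng.standard_normal(mnfld_points.shape).astype(np.float32)
```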

  • log_conv_all_half_eikonal_1e-4_amsgrad_200_epoch_cos_1e-6_1500_2 folder: Likely stores training artifacts such as log files and model checkpoints.
  • pl_conv_finetune_shapespace.py: Script that fine-tunes the shape space model: it first runs inference with a pretrained model, then fine-tunes the network and writes out the final results.
  • pl_conv_train_shapespace.py: Script that trains the shape space model, defining the core training logic: data loading, model definition, loss computation, and optimizer setup.

```python
import os
import time

import torch
import pytorch_lightning as pl

# ShapeNetwork, MorseLoss, utils, args, and log_file are defined elsewhere in the script


class BaseTrainer(pl.LightningModule):
    def __init__(self, args):
        super().__init__()
        self.args = args
        self.learning_rate = args.lr
        self.net = ShapeNetwork(decoder_hidden_dim=args.decoder_hidden_dim,
                                decoder_n_hidden_layers=args.decoder_n_hidden_layers)

        self.criterion = MorseLoss(weights=args.loss_weights, loss_type=args.loss_type, div_decay=args.morse_decay,
                                   div_type=args.morse_type, bidirectional_morse=args.bidirectional_morse)

    def training_step(self, batch, batch_idx):
        self.net.train()
        self.net.zero_grad(set_to_none=True)
        mnfld_points, mnfld_n_gt, nonmnfld_points, near_points, indices = batch['mnfld_points'], batch[
            'mnfld_n'], batch['nonmnfld_points'], batch['near_points'], batch['indices']
        mnfld_points.requires_grad_()
        nonmnfld_points.requires_grad_()
        near_points.requires_grad_()

        output_pred = self.net(nonmnfld_points, mnfld_points,
                               near_points=near_points if self.args.morse_near else None)

        loss_dict, _ = self.criterion(output_pred, mnfld_points, nonmnfld_points,
                                      near_points=near_points if self.args.morse_near else None)
        for key, value in loss_dict.items():
            self.log(key, value, on_step=True, logger=True)
        self.log('total_loss', loss_dict['loss'], on_step=True, logger=True, prog_bar=True)
        if self.local_rank == 0 and self.global_rank == 0 and batch_idx % 30 == 0:
            weights = self.criterion.weights
            utils.log_string("Weights: {}".format(weights), log_file)
            utils.log_string(
                'Epoch: {}, Loss: {:.5f} = L_Mnfld: {:.5f} + L_NonMnfld: {:.5f} + L_Nrml: {:.5f} + '
                'L_Eknl: {:.5f} + L_Div: {:.5f} + L_Morse: {:.5f} + L_Latent: {:.10f}'.format(
                    self.current_epoch, loss_dict["loss"].item(),
                    weights[0] * loss_dict["sdf_term"].item(),
                    weights[1] * loss_dict["inter_term"].item(),
                    weights[2] * loss_dict["normals_loss"].item(),
                    weights[3] * loss_dict["eikonal_term"].item(),
                    weights[4] * loss_dict["div_loss"].item(),
                    weights[5] * loss_dict['morse_term'].item(),
                    weights[6] * loss_dict['latent_reg_term'].item()),
                log_file)
            utils.log_string(
                'Unweighted L_s : L_Mnfld: {:.5f},  L_NonMnfld: {:.5f},  L_Nrml: {:.5f},  '
                'L_Eknl: {:.5f}, L_Morse: {:.5f}, L_Latent: {:.10f}'.format(
                    loss_dict["sdf_term"].item(), loss_dict["inter_term"].item(),
                    loss_dict["normals_loss"].item(), loss_dict["eikonal_term"].item(),
                    loss_dict['morse_term'].item(), loss_dict['latent_reg_term'].item()),
                log_file)
        return {'loss': loss_dict['loss'], 'mnfld': mnfld_points[:1]}

    def training_epoch_end(self, outputs):
        mnfld = outputs[0]['mnfld']
        self.net.eval()
        if self.global_rank == 0 and self.local_rank == 0:
            with torch.no_grad():
                t0 = time.time()
                out_dir = "{}/vis_results/".format(self.args.logdir)
                os.makedirs(out_dir, exist_ok=True)
                global_feat = self.net.encoder.encode(mnfld)
                try:
                    pred_mesh = utils.implicit2mesh(decoder=self.net, mods=None, feat=global_feat,
                                                    grid_res=128, get_mesh=True,
                                                    device=next(self.net.parameters()).device)
                    pred_mesh.export(os.path.join(out_dir, "pred_mesh_{}.ply".format(self.current_epoch)))
                except Exception as e:
                    print('Can not plot')
                    print(e)
                print('Plot took {:.3f}s'.format(time.time() - t0))
        # update the morse loss weight according to the decay schedule
        curr_epoch = self.current_epoch
        self.criterion.update_morse_weight(curr_epoch, self.args.num_epochs, self.args.decay_params)

    def configure_optimizers(self):
        # set up Adam with a cosine-annealing warm-restart schedule
        optimizer = torch.optim.Adam(self.trainer.model.parameters(), lr=self.learning_rate, amsgrad=True)
        lr_sch = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=150 * 10, T_mult=2, eta_min=1e-6)

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": lr_sch,
                "interval": "step",
                "frequency": 1,
            },
        }
```
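The CosineAnnealingWarmRestarts schedule used in configure_optimizers decays the learning rate from its base value toward eta_min over T_0 steps, then restarts at the base value with the cycle length doubled (T_mult=2). A small demonstration with a shortened T_0 so the restart is visible within a few steps:

```python
import torch

p = torch.nn.Parameter(torch.zeros(1))
opt = torch.optim.Adam([p], lr=5e-4, amsgrad=True)
# T_0=10 for brevity; the script above uses T_0=150*10 with "interval": "step"
sch = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(opt, T_0=10, T_mult=2, eta_min=1e-6)

lrs = []
for step in range(20):
    lrs.append(opt.param_groups[0]['lr'])
    opt.step()
    sch.step()
# lr decays over the first 10 steps, then jumps back to the base lr at the restart
```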

  • shapespace_dfaust_args.py: Parses the command-line arguments for the shape space learning task, defining training parameters such as the learning rate and batch size.
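Such an args script is typically a thin argparse wrapper. A minimal sketch; the flag names and defaults below are illustrative assumptions, not the real shapespace_dfaust_args.py:

```python
import argparse

# Hypothetical subset of the flags such a script exposes; the real file
# defines its own names and defaults.
parser = argparse.ArgumentParser(description='shape space training arguments (sketch)')
parser.add_argument('--lr', type=float, default=5e-4, help='learning rate')
parser.add_argument('--batch_size', type=int, default=8)
parser.add_argument('--num_epochs', type=int, default=200)
parser.add_argument('--logdir', type=str, default='./log')

args = parser.parse_args([])  # parse defaults, as if no CLI flags were given
```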

assets folder

  • RP.jpg: A project image, likely used in the documentation or for presentation.

surface_reconstruction folder

This folder contains the code for the surface reconstruction task.

  • recon_dataset.py: Defines the dataset class for surface reconstruction, responsible for loading and processing its input data.
  • run_sdf_recon.py: Script that runs surface reconstruction on every *.xyz and *.ply file under ./data/sdf/input.
  • surface_recon_args.py: Parses the command-line arguments for the surface reconstruction task, defining training parameters such as the learning rate and batch size.
  • train_surface_reconstruction.py: Training script for surface reconstruction, containing the core training logic: model definition, data loading, optimizer setup, loss computation, and the training loop.
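As a rough illustration of what such a training script contains, here is the skeleton of the loop: data, model, optimizer, loss, backprop. This is a minimal sketch with a toy problem (fitting a small MLP to the analytic SDF of a sphere), not the repo's actual pipeline, which wires in recon_dataset.py, the repo's network, and MorseLoss instead.

```python
import torch
from torch import nn

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 1))
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

losses = []
for step in range(200):
    pts = torch.rand(256, 3) * 2 - 1              # random query points in [-1, 1]^3
    sdf_gt = pts.norm(dim=1, keepdim=True) - 0.5  # ground-truth SDF of a sphere (r=0.5)
    loss = ((model(pts) - sdf_gt) ** 2).mean()    # plain MSE as a stand-in for MorseLoss
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
    losses.append(loss.item())
```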
