击败GANs的新生成式模型:score-based model(diffusion model)原理、网络结构、应用、代码、实验、展望

前言:在近两年的NeurIPS、ICCV、CVPR等顶会中,出现了二三十篇score-based generative models相关的论文,这是一种全新的生成式模型。特别是一些论文直接喊出了beat GANs(打败GANs)的口号,全新的生成方式和部分领域领先GANs、VAE的生成效果,让越来越多的人感兴趣并投身于研究中。

  • 会不会是下一个GANs?能否解决目前GANs遇到的问题?
  • 和现有的生成式模型相比有哪些优点?哪些缺点?
  • 目前的网络结构是怎样?
  • 如何用代码实现?
  • 常用的数据集有哪些?
  • 常用的评价指标有哪些?
  • 能应用到哪些领域?
  • 遇到了哪些问题?
  • 发展的瓶颈有哪些?
  • 未来的发展会怎样?

本文就这些问题进行探讨。

目录

原理概述

为什么叫做scored-based?

郎之万动力学

score-based models与diffusion model

三维点云重建任务

网络结构

UNet

Denoising Score Matching

GANs、DPM、DDPM

GANs优点

GANs缺点

DDPM/DPM优点

DDPM/DPM缺点

常用评价指标

常用数据集

一维草图

二维图片

三维模型

应用领域

参考:


原理概述

从数据中估计分数函数,并使用朗之万(Langevin)动力学生成新的样本。因此,scored-based model和diffusion model的核心物理背景都是Langevin动力学。

因为在没有训练数据的区域,估计的分数函数是不准确的,当采样轨迹遇到这些区域时,Langevin动力学可能不能正确收敛。作为补救,用不同强度的高斯噪声对数据进行扰动,并联合估计所有噪声扰动数据分布的得分函数。在推理过程中,将所有噪声尺度的信息与Langevin动力学相结合,从每个噪声扰动分布中依次采样。

和GANs相比,最显著的优势是:

  • 不需要对抗训练的样本质量,不需要进行对抗训练。众所周知,GANs训练难一直是业界难题。主要是因为GANs这种implicit generative models的最大问题是需要对抗训练,而这种训练的方法通常会很不稳定。(PS:scored-based模型的训练也不简单)
  • 灵活的模型架构。
  • 精确的对数似然计算。
  • 不需要再训练模型的逆问题求解。 train后的模型即可参与sampling重建,不需要像StyleGAN的模型训练一个feature网络。

为什么叫做scored-based?

和GANs、VAE一样,scored-based也是implicit generative models隐式生成模型,需要确保易处理的规则化常数(这个后面会提到)以便方便的计算likelihood,而这通常意味着网络结构有较大限制,即无法像NAS那样任意组织和设计网络结构。或者必须依赖于替代的objectives来在训练过程中,近似最大似然(approximate maximum likelihood training)。

但是scored-based对log PDF的梯度进行建模得到一个名为分数函数的量,不需要处理类似likelihood-based models的规则化常数。

这个分数函数被称为:,我们的任务就是最小化模型和数据分布之间的Fisher散度:

郎之万动力学

Langevin dynamics仅通过使用分数函数来对真实数据分布 P ( x )进行马尔科夫链蒙特卡洛(Markov Chain Monte Carlo)的采样。迭代过程如下:

score-based models与diffusion model

scored-based models和diffusion models的原理上大同小异,感兴趣的同学可以参看本系列的上一篇文章:

 《Diffusion Model扩散模型与深度学习(附Python示例)》

这篇文章着重讲了从物理背景到深度学习的过程、数学推导和一般扩散过程的代码示例。本文不再赘述这方面。

三维点云重建任务

1. 一个条件生成问题,因为所考虑的马尔可夫链生成的点云的条件是一些形状潜在的点。这种条件自适应导致的训练和抽样方案与之前对扩散概率模型的研究有显著不同。
2. 二维图像相关DDPM不能直接推广到点云,这是由于三维空间中的点的采样模式是不规则的,而不是图像下方的规则网格结构。
3. 由于点云是由三维空间中的离散点组成的,将这些点视为与热浴接触的非平衡热力学系统中的粒子。在热浴的作用下,粒子的位置以它们扩散并最终扩散到空间的方式随机演化。
4. 通过在每个时间步骤添加噪声,将粒子的初始分布转化为简单的噪声分布。
5. 通过扩散过程将点云的点分布与噪声分布连接起来。为了对点云生成中的点分布进行建模,考虑了反向扩散过程,该过程从噪声分布中恢复了目标点的分布。
6. 将这种反向扩散过程建模为一个马尔可夫链,将噪声分布转换为目标分布。目标是学习它的过渡核,使马尔可夫链可以重建所需的形状。此外,由于马尔可夫链的目的是对点分布进行建模,仅靠马尔可夫链无法生成各种形状的点云。为此,引入了一个形状潜势作为过渡核的条件。在生成设置中,形状潜在遵循一个先验分布,通过标准化流参数化它,以增强模型的表达能力。在自编码的情况下,对形状潜势进行端到端学习。
7. 将训练目标表述为在形状潜势的条件下,使点云的似然值的变分下界最大化,并将其进一步表述为易于处理的封闭表达式。

网络结构

UNet

unet在医疗领域大名鼎鼎,优点是能够学到更丰富维度的信息,一定要好好看一看原始论文:《U-Net: Convolutional Networks for Biomedical Image Segmentation》。UNet模型使用了一堆剩余层和下采样卷积,然后是一堆剩余层和上采样卷积,用跳过连接将空间大小相同的层连接起来。此外使用了一个单头的16 *16分辨率的全局注意层,并在每个残差块中添加嵌入时间步长的投影。

首次在score-based model中使用unet的是论文:Ho, J., Jain, A., and Abbeel, P. Denoising diffusion probabilistic models, 2020

后续的大部分工作都是在这篇论文提出的网络结构上修修补补,经典的unet model class代码如下,复用的时候直接继承即可。

class UNetModel(nn.Module):

    def __init__(
        self,
        in_channels,
        model_channels,
        out_channels,
        num_res_blocks,
        attention_resolutions,
        dropout=0,
        channel_mult=(1, 2, 4, 8),
        conv_resample=True,
        dims=2,
        # dims=1,
        num_classes=None,
        use_checkpoint=False,
        num_heads=1,
        num_heads_upsample=-1,
        use_scale_shift_norm=False,
    ):
        super().__init__()

        if num_heads_upsample == -1:
            num_heads_upsample = num_heads

        self.in_channels = in_channels
        self.model_channels = model_channels
        self.out_channels = out_channels
        self.num_res_blocks = num_res_blocks
        self.attention_resolutions = attention_resolutions
        self.dropout = dropout
        self.channel_mult = channel_mult
        # self.channel_mult = (1, 2, 4, 8)
        self.conv_resample = conv_resample
        self.num_classes = num_classes
        self.use_checkpoint = use_checkpoint
        self.num_heads = num_heads
        self.num_heads_upsample = num_heads_upsample

        time_embed_dim = model_channels * 4
        self.time_embed = nn.Sequential(
            linear(model_channels, time_embed_dim),
            SiLU(),
            linear(time_embed_dim, time_embed_dim),
        )

        if self.num_classes is not None:
            self.label_emb = nn.Embedding(num_classes, time_embed_dim)

        self.input_blocks = nn.ModuleList(
            [
                TimestepEmbedSequential(
                    conv_nd(dims, in_channels, model_channels, 3, padding=1)
                )
            ]
        )
        input_block_chans = [model_channels]
        ch = model_channels
        ds = 1
        for level, mult in enumerate(channel_mult):
            for _ in range(num_res_blocks):
                layers = [
                    ResBlock(
                        ch,
                        time_embed_dim,
                        dropout,
                        out_channels=mult * model_channels,
                        dims=dims,
                        use_checkpoint=use_checkpoint,
                        use_scale_shift_norm=use_scale_shift_norm,
                    )
                ]
                ch = mult * model_channels
                if ds in attention_resolutions:
                    layers.append(
                        AttentionBlock(
                            ch, use_checkpoint=use_checkpoint, num_heads=num_heads
                        )
                    )
                self.input_blocks.append(TimestepEmbedSequential(*layers))
                input_block_chans.append(ch)
            if level != len(channel_mult) - 1:
                self.input_blocks.append(
                    TimestepEmbedSequential(Downsample(ch, conv_resample, dims=dims))
                )
                input_block_chans.append(ch)
                ds *= 2

        self.middle_block = TimestepEmbedSequential(
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
            AttentionBlock(ch, use_checkpoint=use_checkpoint, num_heads=num_heads),
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
        )

        self.output_blocks = nn.ModuleList([])
        for level, mult in list(enumerate(channel_mult))[::-1]:
            for i in range(num_res_blocks + 1):
                layers = [
                    ResBlock(
                        ch + input_block_chans.pop(),
                        time_embed_dim,
                        dropout,
                        out_channels=model_channels * mult,
                        dims=dims,
                        use_checkpoint=use_checkpoint,
                        use_scale_shift_norm=use_scale_shift_norm,
                    )
                ]
                ch = model_channels * mult
                if ds in attention_resolutions:
                    layers.append(
                        AttentionBlock(
                            ch,
                            use_checkpoint=use_checkpoint,
                            num_heads=num_heads_upsample,
                        )
                    )
                if level and i == num_res_blocks:
                    layers.append(Upsample(ch, conv_resample, dims=dims))
                    ds //= 2
                self.output_blocks.append(TimestepEmbedSequential(*layers))

        self.out = nn.Sequential(
            normalization(ch),
            SiLU(),
            zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
        )

    def convert_to_fp16(self):
        """
        Convert the torso of the model to float16.
        """
        self.input_blocks.apply(convert_module_to_f16)
        self.middle_block.apply(convert_module_to_f16)
        self.output_blocks.apply(convert_module_to_f16)

    def convert_to_fp32(self):
        """
        Convert the torso of the model to float32.
        """
        self.input_blocks.apply(convert_module_to_f32)
        self.middle_block.apply(convert_module_to_f32)
        self.output_blocks.apply(convert_module_to_f32)

    @property
    def inner_dtype(self):
        """
        Get the dtype used by the torso of the model.
        """
        return next(self.input_blocks.parameters()).dtype

    def forward(self, x, timesteps, y=None):
        """
        Apply the model to an input batch.

        :param x: an [N x C x ...] Tensor of inputs.
        :param timesteps: a 1-D batch of timesteps.
        :param y: an [N] Tensor of labels, if class-conditional.
        :return: an [N x C x ...] Tensor of outputs.
        """
        assert (y is not None) == (
            self.num_classes is not None
        ), "must specify y if and only if the model is class-conditional"

        hs = []
        emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))

        if self.num_classes is not None:
            assert y.shape == (x.shape[0],)
            emb = emb + self.label_emb(y)

        h = x.type(self.inner_dtype)
        # 此时h是和每一个batch的数据size一样
        # print(f"h size befor is {h.size()}")
        # 下采样
        for module in self.input_blocks:
            h = module(h, emb)  # 卷积+池化
            # print(f"h size after is {h.size()}")
            hs.append(h)
        # 连接层
        h = self.middle_block(h, emb)
        # 上采样
        for module in self.output_blocks:
            hs_temp = hs.pop()
            # print(f"h size is {h.size()}; hs.pop() size is {hs_temp.size()}")
            # if (h.size()[2] != hs_temp.size()[2]) or (h.size()[3] != hs_temp.size()[3]):
            #     # 一般h size大于hs size
            #     # temp_shape = (h.size()[0]*h.size()[1]*h.size()[2]*h.size()[3]) / (hs_temp.size()[0]*hs_temp.size()[2]*hs_temp.size()[3])
            #     continue
            # cat_in = th.cat([h, hs.pop()], dim=1)
            cat_in = th.cat([h, hs_temp], dim=1)
            h = module(cat_in, emb)
        h = h.type(x.dtype)
        return self.out(h)

    def get_feature_vectors(self, x, timesteps, y=None):
        """
        Apply the model and return all of the intermediate tensors.

        :param x: an [N x C x ...] Tensor of inputs.
        :param timesteps: a 1-D batch of timesteps.
        :param y: an [N] Tensor of labels, if class-conditional.
        :return: a dict with the following keys:
                 - 'down': a list of hidden state tensors from downsampling.
                 - 'middle': the tensor of the output of the lowest-resolution
                             block in the model.
                 - 'up': a list of hidden state tensors from upsampling.
        """
        hs = []
        emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
        if self.num_classes is not None:
            assert y.shape == (x.shape[0],)
            emb = emb + self.label_emb(y)
        result = dict(down=[], up=[])
        h = x.type(self.inner_dtype)
        for module in self.input_blocks:
            h = module(h, emb)
            hs.append(h)
            result["down"].append(h.type(x.dtype))
        h = self.middle_block(h, emb)
        result["middle"] = h.type(x.dtype)
        for module in self.output_blocks:
            cat_in = th.cat([h, hs.pop()], dim=1)
            h = module(cat_in, emb)
            result["up"].append(h.type(x.dtype))
        return result

Denoising Score Matching

unet运用于这一领域时间较晚,最早开山鼻祖论文2020年才发表。在此之前,业界普遍使用的是去噪分数匹配。

这一方法首先通过分数匹配去噪来学习分数函数,直观上,这意味着训练神经网络(称为评分网络)去噪被高斯噪声模糊的图像。一个关键点是使用多个噪声尺度来干扰数据,以便评分网络既能捕获粗粒度图像特征,也能捕获细粒度图像特征。然而,如何选择这些噪声尺度是一个非常棘手的问题。

其次,通过运行Langevin动力学生成样本,从白噪声入手,利用评分网络将白噪声逐步降噪成图像。

GANs、DPM、DDPM

GANs优点

1. 采样wall-clock更快

GANs缺点

1. 很难训练,在没有仔细选择超参数和正则化器的情况下崩溃。
2. gan能够以多样性换取保真度,产生高质量的样本,但不覆盖整个分布。
3. 由于对抗损失,GANs的训练过程可能不稳定。自回归模型假设生成顺序是不自然的,可能会限制模型的灵活性。

DDPM/DPM优点

DDPM = DPM + denoising score matching(denoising autoencoders)

1. 捕获了更多的多样性,而且通常比gan更容易缩放和训练。
2. 分布覆盖、固定的训练目标和易于扩展。

DDPM/DPM缺点

1. 采样的wall-clock time比gan慢。
2. 在视觉样本质量方面仍然存在不足。
3. 使用了多个去噪步骤(因此向前传递),它们在采样时间上仍然比gan慢。

常用评价指标

评价指标大部分文章都要对比GANs,所以和GANs运用的数据集相似。

  1. FID 《Gans trained by a two time-scale update rule converge to a local nash equilibrium.》 比IS能更好地捕捉多样性,比IS更好地符合人类的判断。 描述初始潜空间中两个图像分布之间距离的对称度量。

  2. Inception Score 《Improved techniques for training gans》 衡量了一个模型在捕获完整的ImageNet类分布的同时,仍然产生单个类的令人信服的样本的程度。这个度量的一个缺点是,它没有奖励覆盖整个分布或捕获类中的多样性,并且记住完整数据集的一小部分的模型仍然会有很高的IS。

  3. Precision 《Improved precision and recall metric for assessing generative models》 主要描述精度、模型保真度。

  4. recall 主要描述查全率、衡量多样性、分布覆盖率。

  5. retrieval
    用retrieval对比来说明重建效果也是常用的方法

常用数据集

一维草图

  • https://quickdraw.withgoogle.com/

二维图片

  • imagenet:ImageNet
  • LSUN lmdb
  • FFHQ
  • CelebA
  • cifar10,可以使用以下代码下载:
    import os
    import tempfile
    
    import torchvision
    from tqdm.auto import tqdm
    
    CLASSES = (
        "plane",
        "car",
        "bird",
        "cat",
        "deer",
        "dog",
        "frog",
        "horse",
        "ship",
        "truck",
    )
    
    
    def main():
        for split in ["train", "test"]:
            out_dir = f"cifar_{split}"
            if os.path.exists(out_dir):
                print(f"skipping split {split} since {out_dir} already exists.")
                continue
    
            print("downloading...")
            with tempfile.TemporaryDirectory() as tmp_dir:
                dataset = torchvision.datasets.CIFAR10(
                    root=tmp_dir, train=split == "train", download=True
                )
    
            print("dumping images...")
            os.mkdir(out_dir)
            for i in tqdm(range(len(dataset))):
                image, label = dataset[i]
                filename = os.path.join(out_dir, f"{CLASSES[label]}_{i:05d}.png")
                image.save(filename)
    
    
    if __name__ == "__main__":
        main()
    

三维模型

  • shapenet:ShapeNet简介和下载、binvox文件python示例_沉迷单车的追风少年-CSDN博客

应用领域

  • 音频建模
    DiffWave: A Versatile Diffusion Model for Audio Synthesis
    PriorGrad: Improving Conditional Denoising Diffusion Models with Data-Driven Adaptive Prior
  • 语音合成
    Grad-TTS: A Diffusion Probabilistic Model for Text-to-Speech
  • 时间序列预测
    Autoregressive Denoising Diffusion Models for Multivariate Probabilistic Time Series Forecasting
  • 二维图像生成
    Diffusion Models Beat GANs on Image Synthesis
    Improved Denoising Diffusion Probabilistic Models
    Denoising Diffusion Probabilistic Models
    Improved Techniques for Training Score-Based Generative Models
  • 三维点云重建
    Diffusion Probabilistic Models for 3D Point Cloud Generation

参考:

  • [生成模型新方向]: score-based generative models_g11d111的博客-CSDN博客
  • Diffusion Model扩散模型与深度学习(附Python示例)_沉迷单车的追风少年-CSDN博客
  • ShapeNet简介和下载、binvox文件python示例_沉迷单车的追风少年-CSDN博客
  • Yang Song | Generative Modeling by Estimating Gradients of the Data Distribution
  • Improved Techniques for Training Score-Based Generative Models
  • Diffusion Models Beat GANs on Image Synthesis

你可能感兴趣的:(深度学习,神经网络,pytorch)