手把手带你快速入门超越GAN的Normalizing Flow

手把手带你快速入门超越GAN的Normalizing Flow

作者:Aryansh Omray,微软数据科学工程师,Medium技术博主

机器学习领域的一个基本问题就是如何学习复杂数据的表征是机器学习。

这项任务的重要性在于,现存的大量非结构化和无标签的数据,只有通过无监督式学习才能理解。密度估计、异常检测、文本总结、数据聚类、生物信息学、DNA建模等各方面的应用均需要完成这项任务。

多年来,研究人员发明了许多方法来学习大型数据集的概率分布,包括生成对抗网络(GAN)、变分自编码器(VAE)和Normalizing Flow等。

本文即向大家介绍Normalizing Flow这一为了克服GAN和VAE的不足而提出的方法。

手把手带你快速入门超越GAN的Normalizing Flow_第1张图片
Glow模型的输出样例 (Source)

GAN和VAE的能力本已十分惊人,它们都能通过简单的推理方法学习十分复杂的数据分布。

然而,GAN和VAE都缺乏对概率分布的精确评估和推理,这往往导致VAE中的模糊结果质量不高,GAN训练也面临着如模式崩溃和后置崩溃等挑战。

因此,Normalizing Flow应运而生,试图通过使用可逆函数来解决目前GAN和VAE存在的许多问题。

Normalizing Flow

简单地说,Normalizing Flow就是一系列的可逆函数,或者说这些函数的解析逆是可以计算的。例如,f(x)=x+2是一个可逆函数,因为每个输入都有且仅有一个唯一的输出,并且反之亦然,而f(x)=x²则不是一个可逆函数。这样的函数也被称为双射函数。

手把手带你快速入门超越GAN的Normalizing Flow_第2张图片
图源作者

从上图可以看出,Normalizing Flow可以将复杂的数据点(如MNIST中的图像)转化为简单的高斯分布,反之亦然。和GAN非常不一样的地方是,GAN输入的是一个随机向量,而输出的是一个图像,基于流(Flow)的模型则是将数据点转化为简单分布。在上图的MNIST一例中,我们从高斯分布中抽取随机样本,均可重新获得其对应的MNIST图像。

基于流的模型使用负对数可能性损失函数进行训练,其中p(z)是概率函数。下面的损失函数就是使用统计学中的变量变化公式得到的。

手把手带你快速入门超越GAN的Normalizing Flow_第3张图片
(Source)

Normalizing Flow的优势

与GAN和VAE相比,Normalizing Flow具有各种优势,包括:

  • Normalizing Flow模型不需要在输出中放入噪声,因此可以有更强大的局部方差模型(local variance model);
  • 与GAN相比,基于流的模型训练过程非常稳定,GAN则需要仔细调整生成器和判别器的超参数;
  • 与GAN和VAE相比,Normalizing Flow更容易收敛。

Normalizing Flow的不足

虽然基于流的模型有其优势,但它们也有一些缺点:

  • 基于流的模型在密度估计等任务上的表现不尽如人意;
  • 基于流的模型要求保留变换的体积(volume preservation over transformations),这往往会产生非常高维的潜在空间,通常会导致解释性变差;
  • 基于流的模型产生的样本通常没有GAN和VAE的好。

为了更好地理解Normalizing Flow,我们以Glow架构为例进行解释。Glow是OpenAI在2018年提出的一个基于流的模型。下图展示了Glow的架构。

手把手带你快速入门超越GAN的Normalizing Flow_第4张图片
Glow的架构(Source)

Glow架构由多个表层(superficial layers)组合而成。首先我们来看看Glow模型的多尺度框架。Glow模型由一系列的重复层(命名为尺度)组成。每个尺度包括一个挤压函数和一个流步骤,每个流步骤包含ActNorm、1x1 Convolution和Coupling Layer,流步骤后是分割函数。分割函数在通道维度上将输入分成两个相等的部分。其中一半进入之后的层,另一半则进入损失函数。分割是为了减少梯度消失的影响,梯度消失会在模型以端到端方式(end-to-end)训练时出现。

如下图所示,挤压函数(squeeze function)通过横向重塑张量,将大小为[c, h, w]的输入张量转换为大小为[4c, h/2, w/2]的张量。此外,在测试阶段可以采用重塑函数,将输入的[4c, h/2, w/2]重塑为大小为[c, h, w]的张量。

手把手带你快速入门超越GAN的Normalizing Flow_第5张图片
(Source)

其他层,如ActNorm、1x1 Convolution和Affine Coupling层,可以从下表理解。该表展示了每层的功能(包括正向和反向)。
手把手带你快速入门超越GAN的Normalizing Flow_第6张图片
(Source)

实现

在了解了Normalizing Flow和Glow模型的基础知识后,我们将介绍如何使用PyTorch实现该模型,并在MNIST数据集上进行训练。

Glow模型

首先,我们将使用PyTorch和nflows实现Glow架构。为了节省时间,我们使用nflows包含所有层的实现。

import torch
import torch.nn as nn
import torch.nn.functional as F
from nflows import transforms
import numpy as np
from torchvision.transforms.functional import resize
from nflows.transforms.base import Transform

class Net(nn.Module):

    def __init__(self, in_channel, out_channels):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_channel, 64, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, 1),
            nn.ReLU(inplace=True),
            ZeroConv2d(64, out_channels),
        )

    def forward(self, inp, context=None):
        return self.net(inp)


def getGlowStep(num_channels, crop_size, i):
    mask = [1] * num_channels
    
    if i % 2 == 0:
        mask[::2] = [-1] * (len(mask[::2]))
    else:
        mask[1::2] = [-1] * (len(mask[1::2]))

    def getNet(in_channel, out_channels):
        return Net(in_channel, out_channels)

    return transforms.CompositeTransform([
        transforms.ActNorm(num_channels),
        transforms.OneByOneConvolution(num_channels),
        transforms.coupling.AffineCouplingTransform(mask, getNet)
    ])



def getGlowScale(num_channels, num_flow, crop_size):
    z = [getGlowStep(num_channels, crop_size, i) for i in range(num_flow)]
    return transforms.CompositeTransform([
        transforms.SqueezeTransform(),
        *z
    ])


def getGLOW():
    num_channels = 1 * 4
    num_flow = 32
    num_scale = 3
    crop_size = 28 // 2
    transform = transforms.MultiscaleCompositeTransform(num_scale)
    for i in range(num_scale):
        next_input = transform.add_transform(getGlowScale(num_channels, num_flow, crop_size),
                                             [num_channels, crop_size, crop_size])
        num_channels *= 2
        crop_size //= 2

    return transform

Glow_model = getGLOW()

我们可以用各种数据集来训练Glow模型,如MNIST、CIFAR-10、ImageNet等。本文为了演示方便,使用的是MNIST数据集。

MNIST这样的数据集可以很容易地从格物钛公开数据集平台获取,该平台包含了机器学习中所有常用的开放数据集,如分类、密度估计、物体检测和基于文本的分类数据集等。

手把手带你快速入门超越GAN的Normalizing Flow_第7张图片

要访问数据集,我们只需要在格物钛的平台上创建账户,就可以直接fork想要的数据集,可以直接下载或者使用格物钛提供的pipeline导入数据集。基本的代码和相关文档可在TensorBay的支持网页上获得。

手把手带你快速入门超越GAN的Normalizing Flow_第8张图片

结合格物钛TensorBay的Python SDK,我们可以很方便地导入MNIST数据集到PyTorch中:

from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms

from tensorbay import GAS
from tensorbay.dataset import Dataset as TensorBayDataset

class MNISTSegment(Dataset):

    def __init__(self, gas, segment_name, transform):
        super().__init__()
        self.dataset = TensorBayDataset("MNIST", gas)
        self.segment = self.dataset[segment_name]
        self.category_to_index = self.dataset.catalog.classification.get_category_to_index()
        self.transform = transform

    def __len__(self):
        return len(self.segment)

    def __getitem__(self, idx):
        data = self.segment[idx]
        with data.open() as fp:
            image_tensor = self.transform(Image.open(fp))

        return image_tensor, self.category_to_index[data.label.classification.category]

模型训练

模型训练可以通过下面的代码简单开始。该代码使用格物钛TensorBay提供的Pipeline创建数据加载器,其中的ACCESS_KEY可以在TensorBay的账户设置中获得。

from nflows.distributions import normal

ACCESS_KEY = "Accesskey-*****"
EPOCH = 100

to_tensor = transforms.ToTensor()
normalization = transforms.Normalize(mean=[0.485], std=[0.229])
my_transforms = transforms.Compose([to_tensor, normalization])

train_segment = MNISTSegment(GAS(ACCESS_KEY), segment_name="train", transform=my_transforms)
train_dataloader = DataLoader(train_segment, batch_size=4, shuffle=True, num_workers=4)

optimizer = torch.optim.Adam(Glow_model.parameters(), 1e-3)

for epoch in range(EPOCH):
    for index, (image, label) in enumerate(train_dataloader):
        if index == 0:
            image_size = image.shaape[2]
            channels = image.shape[1]
        image = image.cuda()
        output, logabsdet = Glow_model._transform(image)
        shape = output.shape[1:]
        log_z = normal.StandardNormal(shape=shape).log_prob(output)
        loss = log_z + logabsdet
        loss = -loss.mean()/(image_size * image_size * channels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        print(f"Epoch:{epoch+1}/{EPOCH} Loss:{loss}")

上面代码用的是MNIST数据集,要想使用其他数据集我们可以直接替换该数据集的数据加载器。

样例生成

模型训练完成之后,我们可以通过下面的代码来生成样例:

samples = Glow_model.sample(25)
display(samples)

使用nflows库之后,我们只需要用一行代码就可以生成样例,而display函数则能在一个网格中显示生成的样本。

手把手带你快速入门超越GAN的Normalizing Flow_第9张图片
用MNIST训练模型之后生成的样例

结语

本文向大家介绍了Normalizing Flow的基本知识,并与GAN和VAE进行了比较,同时向大家展示了Glow模型的基本工作方式。我们还讲解了如何简单实现Glow模型,并使用MNIST数据集进行训练。在格物钛公开数据集平台的帮助下,数据集访问变得十分便捷。

【关于格物钛】
格物钛智能科技专注打造人工智能新型基础设施,通过非结构化数据平台和公开数据集社区,帮助机器学习团队和个人更好地释放非结构化数据潜力,让AI应用开发更快、性能表现更优,持续为人工智能赋能千行百业、驱动产业升级、推进科技普惠打造坚实基础。目前已获得红杉、云启、真格、风和、耀途资本以及奇绩创坛的千万美金投资。

你可能感兴趣的:(手把手带你快速入门超越GAN的Normalizing Flow)