最近在看图像生成相关论文,记录一下学习内容。感觉只看论文有点干巴,所以理论代码一对一上。
VQGAN (Vector Quantized Generative Adversarial Network) 是一种基于 GAN 的生成模型,可以将图像或文本转换为高质量的图像。
VQGAN整体模型需要两步训练。
如上图所示,从一张输入图片开始(一般是RGB图片) x ∈ R H × W × 3 x \in \mathbb{R}^{H\times W×3} x∈RH×W×3,其通过CNN Encoder编码后得到中间特征变量 z ^ ∈ R h × w × n z \hat z \in \mathbb{R}^{h\times w×n_z} z^∈Rh×w×nz。这时再引入一个codebook,注意,如果是普通的AutoEncoder,则会将 z ^ \hat z z^ 直接送入解码器中进行图像重建。而在VQVAE/VQGAN中,会将 z ^ \hat z z^进行进一步离散化编码成 z q ∈ R h × w × n z z_q\in \mathbb{R}^{h\times w×n_z} zq∈Rh×w×nz。
具体做法为:预先生成一个离散数值的codebook Z = { z k } k = 1 K , z k ∈ R n z \mathcal Z=\{z_k\}_{k=1}^{K},z_k \in \mathbb{R}^{n_z} Z={zk}k=1K,zk∈Rnz,在 z ^ \hat z z^ 的每一个编码位置都去 Z \mathcal Z Z中去寻找其距离最近的code,生成具有相同维度的变量。特别注意,这里 z ^ , z q \hat z,z_q z^,zq和 Z \mathcal Z Z中的单个编码特征的维度都为 n z n_z nz。这一步离散编码的过程就叫做“quantization”, 也就是上面的那个公式。
这样一来,就可以在已经数值离散化的 z q z_q zq基础上使用CNN Decoder进行解码:
x ^ = G ( z q ) = G ( q ( E ( x ) ) ) \hat x=G(z_q)=G(q(E(x))) x^=G(zq)=G(q(E(x)))
整个过程的自监督损失如下:
L V Q ( E , G , Z ) = ∣ ∣ x − x ^ ∣ ∣ 2 + ∣ ∣ s g [ E ( x ) ] − z q ∣ ∣ 2 + ∣ ∣ s g ( z q ) − E ( x ) ∣ ∣ 2 \mathcal L_{VQ}(E,G,Z)=||x-\hat x||^2+||sg[E(x)]-z_q||^2+||sg(z_q)-E(x)||^2 LVQ(E,G,Z)=∣∣x−x^∣∣2+∣∣sg[E(x)]−zq∣∣2+∣∣sg(zq)−E(x)∣∣2其中,上式中的第一项 L r e c \mathcal L_{rec} Lrec 为重建损失(reconstruction loss) s g [ ⋅ ] sg[·] sg[⋅] 为梯度终止操作(stop-gradient operation),其目的在于保证神经网络梯度可以正常回传,而不受离散编码的影响。因此在codebook的搭建过程中,我们看到由 z ^ \hat z z^得到 z q z_q zq之后,先计算出公式中后两项损失,然后又增加了一步detach操作。
loss = torch.mean((z_q.detach() - z)**2) + self.beta * torch.mean((z_q - z.detach())**2)
z_q = z + (z_q - z).detach()
这么一来,在其后面计算 L r e c \mathcal L_{rec} Lrec,即公式的第一项中, z q z_q zq的梯度可以顺利复制到 z ^ \hat z z^上,而不受离散编码过程的干扰。除了这个重建过程使用的自监督损失外,还加入了GAN中的对抗loss。文章里没有具体写对抗loss的类型。通过源码可以发现使用的是hinge loss。对于判别器而言,其损失函数可以笼统地表示为:
L G A N ( { E , G , Z } , D ) = l o g D ( x ) + l o g ( 1 − D ( x ^ ) ) \mathcal L_{GAN}(\{E,G,\mathcal Z\}, D)=logD(x)+log(1-D(\hat x)) LGAN({E,G,Z},D)=logD(x)+log(1−D(x^))
所以总的误差可以写成:
L = L V Q + λ L G A N \mathcal L = \mathcal L_{VQ}+\lambda \mathcal L_{GAN} L=LVQ+λLGAN
总结来说就是:
x → z ^ → z q → x ^ x\to \hat z\to z_q\to \hat x x→z^→zq→x^
下面主要来看看这三部分的代码
CNN Encoder, CNN Decoder是一种基于UNet的代码结构,具体细节可以从原文中获取,这里不在细说
class Encoder(nn.Module):
def __init__(self, args):
super(Encoder, self).__init__()
channels = [128, 128, 128, 256, 256, 512]
attn_resolutions = [16]
num_res_blocks = 2
resolution = 256
layers = [nn.Conv2d(args.image_channels, channels[0], 3, 1, 1)]
for i in range(len(channels)-1):
in_channels = channels[i]
out_channels = channels[i + 1]
for j in range(num_res_blocks):
layers.append(ResidualBlock(in_channels, out_channels))
in_channels = out_channels
if resolution in attn_resolutions:
layers.append(NonLocalBlock(in_channels))
if i != len(channels)-2:
layers.append(DownSampleBlock(channels[i+1]))
resolution //= 2
layers.append(ResidualBlock(channels[-1], channels[-1]))
layers.append(NonLocalBlock(channels[-1]))
layers.append(ResidualBlock(channels[-1], channels[-1]))
layers.append(GroupNorm(channels[-1]))
layers.append(Swish())
layers.append(nn.Conv2d(channels[-1], args.latent_dim, 3, 1, 1))
self.model = nn.Sequential(*layers)
def forward(self, x):
return self.model(x)
具体的模块定义可以阅读源代码,这个都不难理解。
class Decoder(nn.Module):
def __init__(self, args):
super(Decoder, self).__init__()
channels = [512, 256, 256, 128, 128]
attn_resolutions = [16]
num_res_blocks = 3
resolution = 16
in_channels = channels[0]
layers = [nn.Conv2d(args.latent_dim, in_channels, 3, 1, 1),
ResidualBlock(in_channels, in_channels),
NonLocalBlock(in_channels),
ResidualBlock(in_channels, in_channels)]
for i in range(len(channels)):
out_channels = channels[i]
for j in range(num_res_blocks):
layers.append(ResidualBlock(in_channels, out_channels))
in_channels = out_channels
if resolution in attn_resolutions:
layers.append(NonLocalBlock(in_channels))
if i != 0:
layers.append(UpSampleBlock(in_channels))
resolution *= 2
layers.append(GroupNorm(in_channels))
layers.append(Swish())
layers.append(nn.Conv2d(in_channels, args.image_channels, 3, 1, 1))
self.model = nn.Sequential(*layers)
def forward(self, x):
return self.model(x)
我最开始看的时候,最不明白的地方就是这个codebook,一直在想,这兄弟是哪蹦出来的。其实就是另外定义的一个网络,说白了甚至算不上一个网络就是一个
nn.Embedding()
,还是之前没看VQVAE的锅。
class Codebook(nn.Module):
def __init__(self, args):
super(Codebook, self).__init__()
self.num_codebook_vectors = args.num_codebook_vectors
self.latent_dim = args.latent_dim
self.beta = args.beta
self.embedding = nn.Embedding(self.num_codebook_vectors, self.latent_dim)
self.embedding.weight.data.uniform_(-1.0 / self.num_codebook_vectors, 1.0 / self.num_codebook_vectors)
def forward(self, z):
z = z.permute(0, 2, 3, 1).contiguous()
z_flattened = z.view(-1, self.latent_dim)
d = torch.sum(z_flattened**2, dim=1, keepdim=True) + \
torch.sum(self.embedding.weight**2, dim=1) - \
2*(torch.matmul(z_flattened, self.embedding.weight.t()))
min_encoding_indices = torch.argmin(d, dim=1)
z_q = self.embedding(min_encoding_indices).view(z.shape)
loss = torch.mean((z_q.detach() - z)**2) + self.beta * torch.mean((z_q - z.detach())**2)
z_q = z + (z_q - z).detach()
z_q = z_q.permute(0, 3, 1, 2)
return z_q, min_encoding_indices, loss
经VQGAN得到的压缩图像与真实图像有一个本质性的不同:真实图像的像素值具有连续性,相邻的颜色更加相似,而压缩图像的像素值则没有这种连续性。
压缩图像的这一特性让寻找一个压缩图像生成模型变得异常困难。多数强大的真实图像生成模型(比如GAN)都是输出一个连续的浮点颜色值,再做一个浮点转整数的操作,得到最终的像素值。而对于压缩图像来说,这种输出连续颜色的模型都不适用了。而恰好,Transformer天生就支持建模离散的输出。在NLP中,每个单词都可以用一个离散的数字表示。Transformer会不断生成表示单词的数字,以达到生成句子的效果。
VQGAN的作者使用了自回归图像生成模型的常用做法,给图像的每个像素从左到右,从上到下规定一个顺序。有了先后顺序后,图像就可以被视为一个一维句子,可以用Transfomer生成句子的方式来生成图像了。在第i 步,Transformer会根据前i−1 个像素 s < i s_{s<i生成第 i i i 个像素 s i s_i si.
来看具体实现——训练过程:
现在进入第二步,这篇论文毕竟是个图像生成的任务,注意之前的三个零件已经训练好不动了,现在我们需要得到一组排列好的code,送进CNN Decoder中来实现图像生成。那么这组code怎么来的?这就是Transformer发挥作用的地方了。该工作使用的Transformer模型为著名的GPT-2。迁移到VQGAN中,即可理解为先预测一个code,再一步步地通过已经预测好的code去推断下一个code。
code都是从训练好的codebook Z \mathcal Z Z中寻找,就像写文章一样,你有词典了,现在你要从词典中一个字一个字的写成一篇新文章
为了训练Transformer,
假设被替换后的code组合的索引为modified_indices,原本 z q z_q zq的code索引为unmodified_indices,那么Transformer的学习过程即为:喂入modified_indices,通过训练学习重构出unmodified_indices。
L t r a n s f o r m e r = E x ∼ p ( x ) [ − l o g p ( s ) ] \mathcal L_{transformer}=\mathbb E_{x\sim p(x)}[-logp(s)] Ltransformer=Ex∼p(x)[−logp(s)]
代码具体实现如下:
"""
首先得到由x前传得到的unmodified_indices
"""
sos_tokens = torch.ones(x.shape[0], 1) * self.sos_token
# (B, 1), sos_token是一个整数,表示从第几个token开始预测,一般为0
mask = torch.bernoulli(self.pkeep * torch.ones(unmodified_indices.shape, device=unmodified_indices.device))
# (B, h*w), 元素都为0和1,0的是mask掉的元素,1是保留的元素(比例为pkeep)
mask = mask.round().to(dtype=torch.int64)
random_indices = torch.randint_like(indices, self.transformer.config.vocab_size)
# (B, h*w), 生成一些任意的indices,用来填充被遮挡的部分
modified_indices= mask * unmodified_indices+ (1 - mask) * random_indices
# (B, h*w), mask为1(未遮挡)部分仍然保留原始indices,mask为0(遮挡)部分用random_indices填充
modified_indices= torch.cat((sos_tokens, modified_indices), dim=1)
# (B, h*w+1),将0放到第一个indice前面
targets = unmodified_indices
logits, _ = self.transformer(modified_indices[:, :-1])
# logits: (B, h*w, num_codebook_vectors), 意思是h*w个indices处,预测出来的对应每一个codebook_vector的概率
"""
然后再由logits和targets之间计算交叉熵损失
"""
loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))
注意这是训练的过程,不是生成的过程。在VQGAN无条件生成图片的过程中,没有任何先验条件,CNN Encoder直接被弃用。我们需要得到一组排列好的code,送进CNN Decoder中来实现图像生成。