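This starts a new series on text-to-image generation. I have been working on this topic for a while and have a few ideas, but GAN training is finicky and very time-consuming, so I am writing my notes down here for now.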
AttnGAN
I. GitHub and paper links
GitHub link: [https://github.com/cn-boop/At...]()
Paper: AttnGAN: Fine-Grained Text to Image Generation with Attentional Generative Adversarial Networks
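II. Reading summary
1. Abstract
In this paper the authors propose the Attentional Generative Adversarial Network (AttnGAN), an attention-driven, multi-stage generator for fine-grained text-to-image synthesis.
The generator is trained with the help of a deep attentional multimodal similarity model (DAMSM).
The paper shows for the first time that a layered attentional GAN can automatically select word-level conditions for generating different parts of the image.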
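2. Model Structure
- attentional generative network
This component uses an attention mechanism to generate sub-regions of the image. When generating each sub-region, it also attends to the words in the text that are most relevant to that sub-region.
- Deep Attentional Multimodal Similarity Model (DAMSM)
This component measures how well a generated image matches the text, and that measure is used to train the generator.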
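3. Pipeline
- The input text passes through a Text Encoder, which produces a sentence feature and word features.
- The sentence feature is used to generate a low-resolution image I0.
- Starting from I0, the word features and the sentence feature are added to generate a higher-resolution, fine-grained image.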
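III. Code walkthrough
1. attentional generative network
- step1: Use text_encoder to obtain the sentence features and word features.
- step2: Extract the condition from the sentence features (sentence embedding), then combine it with the noise vector z to produce a low-resolution image and the corresponding image features h0.
- step3: The image features of each lower-resolution stage are used to generate the image features of the next, higher-resolution stage.
The word features (e, of size D×T) and h0 (of size D'×N) pass through the attention model, which outputs a tensor of size D'×N.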
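The attention step that maps e (D×T) and h (D'×N) to a D'×N tensor can be written out as follows. This follows the word-context formulation in the AttnGAN paper as I understand it; the projection matrix U and the index conventions are notation introduced here, not taken from the notes above.

```latex
e' = U e, \qquad U \in \mathbb{R}^{D' \times D}, \quad
e \in \mathbb{R}^{D \times T}, \quad h \in \mathbb{R}^{D' \times N}

s'_{j,i} = h_j^{\top} e'_i, \qquad
\beta_{j,i} = \frac{\exp(s'_{j,i})}{\sum_{k=1}^{T} \exp(s'_{j,k})}, \qquad
c_j = \sum_{i=1}^{T} \beta_{j,i}\, e'_i
```

Here h_j is the j-th column of h (one image sub-region) and e'_i is the i-th projected word. Stacking the N context vectors c_j as columns gives the D'×N output.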
1. text_encoder
# Instantiation
text_encoder = RNN_ENCODER(self.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)
# Class definition (excerpt)
class RNN_ENCODER(nn.Module):
    def __init__(self, ntoken, ninput=300, drop_prob=0.5,
                 nhidden=128, nlayers=1, bidirectional=True):
        super(RNN_ENCODER, self).__init__()
        # ......
The resulting words_embs has shape [2, 256, 18] and sent_emb has shape [2, 256].
num_words = words_embs.size(2)  # 18 words per sentence
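To make these shapes concrete, here is a minimal sketch of a bidirectional LSTM text encoder. `TinyTextEncoder` is a hypothetical stand-in and not the repository's `RNN_ENCODER`; the vocabulary size of 50 and the random captions are placeholders chosen only so the shapes match the ones quoted above.

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence


class TinyTextEncoder(nn.Module):
    """Hypothetical stand-in for RNN_ENCODER: a bidirectional LSTM."""

    def __init__(self, ntoken, ninput=300, nhidden=256):
        super().__init__()
        # Each direction produces half of the final embedding dimension.
        self.embed = nn.Embedding(ntoken, ninput)
        self.rnn = nn.LSTM(ninput, nhidden // 2, batch_first=True,
                           bidirectional=True)

    def forward(self, captions, cap_lens):
        emb = self.embed(captions)                         # [B, T, ninput]
        packed = pack_padded_sequence(emb, cap_lens, batch_first=True)
        output, (h_n, _) = self.rnn(packed)
        output, _ = pad_packed_sequence(output, batch_first=True)
        words_embs = output.transpose(1, 2)                # [B, nhidden, T]
        # Concatenate the final forward and backward hidden states.
        sent_emb = h_n.transpose(0, 1).reshape(captions.size(0), -1)
        return words_embs, sent_emb


encoder = TinyTextEncoder(ntoken=50)
captions = torch.randint(1, 50, (2, 18))   # 2 captions, 18 tokens each
cap_lens = torch.tensor([18, 18])          # lengths, sorted in descending order
words_embs, sent_emb = encoder(captions, cap_lens)
print(words_embs.shape, sent_emb.shape)
```

The per-word outputs become the word features and the final hidden states become the sentence feature, which is why the two tensors share the 256-dimensional axis.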
2. image_encoder
image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)  # 256
# Excerpt from inside CNN_ENCODER
self.emb_features = conv1x1(768, self.nef)  # self.nef: 256
self.emb_cnn_code = nn.Linear(2048, self.nef)
def forward(self, x):
    features = None
    # --> fixed-size input: batch x 3 x 299 x 299
    x = nn.Upsample(size=(299, 299), mode='bilinear')(x)
    # ......
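image_encoder extracts image features from the final 256×256 image. Its input is a 256×256 image, which forward first upsamples to 299×299, and the backbone produces a 2048-dimensional vector, which emb_cnn_code then projects to nef dimensions. The network architecture is Inception-v3.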
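The two projection layers shown above can be exercised in isolation to check the shapes. This is a sketch under assumptions: random tensors stand in for the Inception-v3 activations, the 17x17 size of the 768-channel map is assumed, and `conv1x1` is assumed to be a plain 1x1 `nn.Conv2d` without bias.

```python
import torch
import torch.nn as nn

nef = 256  # cfg.TEXT.EMBEDDING_DIM in the code above

# The two projection heads, rebuilt outside CNN_ENCODER.
emb_features = nn.Conv2d(768, nef, kernel_size=1, bias=False)  # assumed conv1x1
emb_cnn_code = nn.Linear(2048, nef)

# Stand-ins for backbone activations (assumed shapes, random values):
# a 768-channel 17x17 intermediate map and the 2048-d pooled vector.
local_map = torch.randn(2, 768, 17, 17)
pooled = torch.randn(2, 2048)

words_features = emb_features(local_map)   # region-level features
sent_code = emb_cnn_code(pooled)           # whole-image feature
print(words_features.shape, sent_code.shape)
```

Both outputs end up with nef channels, so they can be compared directly with the word and sentence embeddings from the text encoder.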
3. generative network
First generator
if cfg.TREE.BRANCH_NUM > 0:
    self.h_net1 = INIT_STAGE_G(ngf * 16, ncf)
    self.img_net1 = GET_IMAGE_G(ngf)
class INIT_STAGE_G(nn.Module):  # similar to the first-stage generator in StackGAN
    def forward(self, z_code, c_code):
        """
        :param z_code: batch x cfg.GAN.Z_DIM
        :param c_code: batch x cfg.TEXT.EMBEDDING_DIM
        :return: batch x ngf/16 x 64 x 64
        """
        # ......
Given sent_emb as input, this stage produces h_code1 of shape [2, 32, 64, 64] and fake_img1 of shape [2, 3, 64, 64]. Word embeddings are not used yet at this point.
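A simplified first stage that reproduces these shapes might look like the sketch below. `TinyInitStage` and `up_block` are hypothetical names, the 100-dimensional z and c vectors are assumed sizes, and the random c_code merely stands in for the condition derived from sent_emb; the real `INIT_STAGE_G` may differ in its layers.

```python
import torch
import torch.nn as nn


def up_block(in_ch, out_ch):
    """Double the spatial size, then mix channels with a 3x3 convolution."""
    return nn.Sequential(
        nn.Upsample(scale_factor=2, mode="nearest"),
        nn.Conv2d(in_ch, out_ch, 3, padding=1, bias=False),
        nn.BatchNorm2d(out_ch),
        nn.ReLU(inplace=True),
    )


class TinyInitStage(nn.Module):
    """Hypothetical, simplified first stage: (z, c) -> 64x64 features + image."""

    def __init__(self, ngf=512, z_dim=100, c_dim=100):
        super().__init__()
        self.ngf = ngf
        self.fc = nn.Sequential(nn.Linear(z_dim + c_dim, ngf * 4 * 4), nn.ReLU())
        # Four upsampling blocks: 4 -> 8 -> 16 -> 32 -> 64, halving channels.
        self.ups = nn.Sequential(*[up_block(ngf // 2 ** i, ngf // 2 ** (i + 1))
                                   for i in range(4)])
        # Plays the role of GET_IMAGE_G: features -> RGB values in [-1, 1].
        self.to_img = nn.Sequential(nn.Conv2d(ngf // 16, 3, 3, padding=1),
                                    nn.Tanh())

    def forward(self, z_code, c_code):
        x = self.fc(torch.cat((c_code, z_code), dim=1))
        x = x.view(-1, self.ngf, 4, 4)
        h_code = self.ups(x)
        return h_code, self.to_img(h_code)


net = TinyInitStage()
h_code1, fake_img1 = net(torch.randn(2, 100), torch.randn(2, 100))
print(h_code1.shape, fake_img1.shape)
```

With ngf = 512 the feature map leaves the stage with ngf/16 = 32 channels, matching the `batch x ngf/16 x 64 x 64` return shape in the docstring.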
Subsequent generators
if cfg.TREE.BRANCH_NUM > 1:
    self.h_net2 = NEXT_STAGE_G(ngf, nef, ncf)
    self.img_net2 = GET_IMAGE_G(ngf)
class NEXT_STAGE_G(nn.Module):
    def __init__(self, ngf, nef, ncf):
        super(NEXT_STAGE_G, self).__init__()
        self.gf_dim = ngf  # generator feature channels
        self.ef_dim = nef  # text embedding dimension
        self.cf_dim = ncf  # condition dimension
        # ......
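The inputs are the previous stage's h_code together with c_code, word_embs, and mask. I do not know what mask is and could not find an explanation in either the paper or the source code; its shape is [2, 18]. The stage returns h_code2 of shape [2, 32, 128, 128] and attn of shape [2, 18, 64, 64].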
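The text above leaves mask unexplained. One plausible reading, offered as an assumption rather than as the repository's documented behavior, is that its [batch, words] shape marks padded word positions so the attention can ignore them. The sketch below shows word-level attention under that assumption; `word_attention` is a hypothetical helper, and the word features are assumed to be already projected to the same channel count as h_code.

```python
import torch
import torch.nn.functional as F


def word_attention(h_code, words_embs, mask=None):
    """h_code: [B, D, H, W]; words_embs: [B, D, T]; mask: [B, T] bool."""
    B, D, H, W = h_code.shape
    T = words_embs.size(2)
    queries = h_code.reshape(B, D, H * W).transpose(1, 2)    # [B, N, D]
    scores = torch.bmm(queries, words_embs)                   # [B, N, T]
    if mask is not None:
        # Assumed meaning: True marks a padded word that must get zero weight.
        scores = scores.masked_fill(mask.unsqueeze(1), float("-inf"))
    attn = F.softmax(scores, dim=2)                           # softmax over words
    context = torch.bmm(words_embs, attn.transpose(1, 2))     # [B, D, N]
    return context.reshape(B, D, H, W), attn.transpose(1, 2).reshape(B, T, H, W)


h_code1 = torch.randn(2, 32, 64, 64)
words = torch.randn(2, 32, 18)           # assumed already projected to D = 32
mask = torch.zeros(2, 18, dtype=torch.bool)
mask[1, 10:] = True                      # second caption has only 10 real words
context, attn = word_attention(h_code1, words, mask)
print(context.shape, attn.shape)
```

Under this reading, the attn tensor of shape [2, 18, 64, 64] holds one 64x64 attention map per word, and the maps for padded words are all zero.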
Generator loss function
def generator_loss(netsD, image_encoder, fake_imgs, real_labels,
                   words_embs, sent_emb, match_labels,
                   cap_lens, class_ids):
    numDs = len(netsD)  # one discriminator per image resolution
    batch_size = real_labels.size(0)
    logs = ''
    # Forward
    errG_total = 0
    # ......
4. discriminator network
class D_NET64(nn.Module):
    def __init__(self, b_jcu=True):
        super(D_NET64, self).__init__()
        ndf = cfg.GAN.DF_DIM  # 64
        nef = cfg.TEXT.EMBEDDING_DIM  # 256
        self.img_code_s16 = encode_image_by_16times(ndf)
        if b_jcu:
            self.UNCOND_DNET = D_GET_LOGITS(ndf, nef, bcondition=False)
        # ......
Discriminator loss function
def discriminator_loss(netD, real_imgs, fake_imgs, conditions,
                       real_labels, fake_labels):
    # Forward
    real_features = netD(real_imgs)
    # detach() keeps this loss from backpropagating into the generator
    fake_features = netD(fake_imgs.detach())
    # loss
    cond_real_logits = netD.COND_DNET(real_features, conditions)
    # ......
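The excerpt above stops before the loss terms are combined. As an illustration only, and not a reconstruction of the omitted code, the sketch below shows how conditional and unconditional real/fake terms could be combined with binary cross-entropy. It uses plain Python so the arithmetic is visible; `toy_discriminator_loss` and the equal weighting of the terms are assumptions.

```python
import math


def bce(prob, target):
    """Binary cross-entropy for a single probability in (0, 1)."""
    return -(target * math.log(prob) + (1 - target) * math.log(1 - prob))


def mean_bce(probs, target):
    return sum(bce(p, target) for p in probs) / len(probs)


def toy_discriminator_loss(cond_real, cond_fake, uncond_real, uncond_fake):
    """Each argument is a list of discriminator output probabilities.

    cond_*   : scores for (image, text) pairs, as from a COND_DNET-like head
    uncond_* : scores for images alone, as from an UNCOND_DNET-like head
    """
    # Real images should be scored 1, generated images 0.
    real_err = (mean_bce(cond_real, 1.0) + mean_bce(uncond_real, 1.0)) / 2.0
    fake_err = (mean_bce(cond_fake, 0.0) + mean_bce(uncond_fake, 0.0)) / 2.0
    return real_err + fake_err  # assumed equal weighting


good = toy_discriminator_loss([0.9, 0.8], [0.1, 0.2], [0.9, 0.9], [0.1, 0.1])
bad = toy_discriminator_loss([0.2, 0.1], [0.9, 0.8], [0.1, 0.1], [0.9, 0.9])
chance = toy_discriminator_loss([0.5], [0.5], [0.5], [0.5])
print(good < bad, chance)
```

A discriminator that scores real images high and generated images low gets a smaller loss than one that does the opposite, and a discriminator that always outputs 0.5 sits at 2·ln 2.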
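2. Deep Attentional Multimodal Similarity Model (DAMSM)
DAMSM Structure
- text_encoder
A bidirectional LSTM. It outputs the sentence embedding and the word embeddings (e, of size D×T).
- image encoder
Its output is a 2048-dimensional vector that represents the whole image: f'.
- s = e^T · v
This ties the image and the sentence together: s(i, j) is the relevance between the i-th word of the sentence and the j-th region of the image. (v is not defined above; in the paper it denotes the image region features projected into the same D-dimensional space as the word embeddings.)
Finally, s and h are combined to obtain the matching loss.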
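The similarity matrix s = e^T · v can be computed in a couple of lines. In the sketch below, the sizes D = 256, T = 18 and N = 289 (a 17x17 grid of regions) are assumed, the inputs are random, and the follow-up steps (attention over regions, then a cosine score per word) are a simplified version of one way to turn s into a matching score, not necessarily what the repository does.

```python
import torch
import torch.nn.functional as F

D, T, N = 256, 18, 289            # embedding dim, words, image regions (assumed)
e = torch.randn(D, T)             # word embeddings from the text encoder
v = torch.randn(D, N)             # region features projected to D dims (assumed)

s = e.t() @ v                     # [T, N]; s[i, j] relates word i to region j

# Simplified follow-up (assumption): attend over regions for each word and
# build a region-context vector per word.
alpha = F.softmax(s, dim=1)       # each word's weights over the N regions
c = v @ alpha.t()                 # [D, T]; column i is the context for word i
relevance = F.cosine_similarity(c, e, dim=0)   # [T], one score per word
print(s.shape, c.shape, relevance.shape)
```

Each entry of relevance says how well the image regions attended by a word agree with that word's embedding; aggregating these scores over words gives an image-sentence matching score.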
DAMSM Loss Function
Q: generated images; D: text descriptions
def evaluate(dataloader, cnn_model, rnn_model, batch_size):
    cnn_model.eval()
    rnn_model.eval()
    s_total_loss = 0
    w_total_loss = 0
    for step, data in enumerate(dataloader, 0):
        real_imgs, captions, cap_lens, \
            class_ids, keys = prepare_data(data)
        words_features, sent_code = cnn_model(real_imgs[-1])
        # nef = words_features.size(1)
        # words_features = words_features.view(batch_size, nef, -1)
        hidden = rnn_model.init_hidden(batch_size)
        words_emb, sent_emb = rnn_model(captions, cap_lens, hidden)
        w_loss0, w_loss1, attn = words_loss(words_features, words_emb, labels,
                                            cap_lens, class_ids, batch_size)
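With Q the images and D the text descriptions, a batch-wise matching loss can be sketched as a softmax cross-entropy over pairwise similarities, computed once from images to texts and once from texts to images. This is a plain-Python illustration under assumptions: `matching_loss` is a hypothetical helper, cosine similarity is used as the matching score, and the smoothing factor gamma = 10.0 is an assumed value.

```python
import math


def cosine(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def matching_loss(img_codes, txt_codes, gamma=10.0):
    """Return (image->text loss, text->image loss), averaged over the batch."""
    n = len(img_codes)
    sim = [[gamma * cosine(q, d) for d in txt_codes] for q in img_codes]

    def nll(rows):
        # Cross-entropy where the correct match for row i is column i.
        total = 0.0
        for i, row in enumerate(rows):
            log_norm = math.log(sum(math.exp(x) for x in row))
            total += log_norm - row[i]
        return total / n

    cols = [list(col) for col in zip(*sim)]   # transpose for text -> image
    return nll(sim), nll(cols)


imgs = [[1.0, 0.0], [0.0, 1.0]]
matched = matching_loss(imgs, [[0.9, 0.1], [0.1, 0.9]])
swapped = matching_loss(imgs, [[0.1, 0.9], [0.9, 0.1]])
print(matched, swapped)
```

When each image is most similar to its own description the loss is small in both directions; swapping the descriptions makes both losses large. Returning two values mirrors the paired outputs w_loss0 and w_loss1 in the code above, though whether they correspond to these two directions is an assumption.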