文本生成图像之AttnGAN

一个新系列文本生成图像,这个是之前一直在研究的东西,有一些idea,但gan的训练有点坑而且很费时。先记录下来,放这。

AttnGAN

一、github与论文链接

github链接: [https://github.com/cn-boop/At...]()

论文链接:AttnGAN: Fine-Grained Text to Image Generation with Attentional Generative Adversarial Networks

二、阅读总结

1.Abstract

在本文中作者提出了一个 Attentional Generative Ad-
versarial Network(AttnGAN),一种attention-driven的多stage的细粒度文本到图像生成器。
并借助一个深层注意多模态相似模型(deep attentional multimodal similarity model)来训练该生成器。
它首次表明 the layered attentional GAN 能够自动选择单词级别的condition来生成图像的不同部分。

2.Model Structure

  • attentional generative network

该部分使用了注意力机制来生成图像中的子区域,并且在生成每个子区域时还考虑了文本中与该子区域最相关的词。

  • Deep Attentional Multimodal Similarity Model (DAMSM)

该部分用来计算生成的图像与文本的匹配程度。用来训练生成器。

3.Pipeline

  • 输入的文本通过一个Text Encoder 得到 sentence feature 和word features
  • 用sentence feature 生成一个低分辨率的图像I0
  • 基于I0 加入 word features 和setence feature 生成更高分辨率细粒度的图像

三、代码详解

1.attentional generative network

  • step1:使用text_encoder的得到sentence features 和word features.
  • step2:sentence features(sentence embedding)提取condition ,然后与z结合产生低分辨率的图像以及对应的图像特征h0.
  • step3: 每一层低分辨图像的特征被用来生成下一层的高分辨图像特征

word feature(e,大小为DT) 与h0 (大小为D’N)通过attention model ,输出大小为D’*N的张量

1.text_encoder

text_encoder = RNN_ENCODER(self.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM
class RNN_ENCODER(nn.Module):
    def __init__(self, ntoken, ninput=300, drop_prob=0.5,
             nhidden=128, nlayers=1, bidirectional=True):
    super(RNN_ENCODER, self).__init__()
    ......
得到的word_emb为[2,256,18]   sent_emb:[2,256]
num_words = words_embs.size(2) #每句话18个词

2.image_encoder


image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)# 256
self.emb_features = conv1x1(768, self.nef)  #self.enf:256
self.emb_cnn_code = nn.Linear(2048, self.nef)
def forward(self, x):
    features = None
    # --> fixed-size input: batch x 3 x 299 x 299
    x = nn.Upsample(size=(299, 299), mode='bilinear')(x)
    ......
image_encoder用于最后提取256*256图像的图像特征,输入是256*256的图像,输出是2048维的向量。采用的是inceptionv3的网络架构

3.generative network

第一个生成器

 if cfg.TREE.BRANCH_NUM > 0:
self.h_net1 = INIT_STAGE_G(ngf * 16, ncf)
self.img_net1 = GET_IMAGE_G(ngf)
class INIT_STAGE_G(nn.Module):  #可以看出  和stackgan的第一个生成器类似
def forward(self, z_code, c_code):
    """
    :param z_code: batch x cfg.GAN.Z_DIM
    :param c_code: batch x cfg.TEXT.EMBEDDING_DIM
    :return: batch x ngf/16 x 64 x 64
    """
    ......
输入sent_emb,得到[2,32,64,64]的h_code1和[2,3,64,64]的fake_img1。(这时候还没用到word embedding)

后面的生成器


if cfg.TREE.BRANCH_NUM > 1:
    self.h_net2 = NEXT_STAGE_G(ngf, nef, ncf)
    self.img_net2 = GET_IMAGE_G(ngf)
class NEXT_STAGE_G(nn.Module):
    def __init__(self, ngf, nef, ncf):
        super(NEXT_STAGE_G, self).__init__()
        self.gf_dim = ngf
        self.ef_dim = nef
        self.cf_dim = ncf
        ......
输入上一层的h_code以及c_code,word_embs,mask.mask是个啥不知道,在论文和源码中均未找到解释。大小为[2,18].返回[2,32,128,128]的h_code2和[2,18,64,64]的attn

生成器Loss Function


def generator_loss(netsD, image_encoder, fake_imgs, real_labels,
               words_embs, sent_emb, match_labels,
               cap_lens, class_ids):
    numDs = len(netsD)
    batch_size = real_labels.size(0)
    logs = ''
    # Forward
    errG_total = 0
    ......

3.discriminator network

class D_NET64(nn.Module):
    def __init__(self, b_jcu=True):
        super(D_NET64, self).__init__()
        ndf = cfg.GAN.DF_DIM  #64
        nef = cfg.TEXT.EMBEDDING_DIM  #256
        self.img_code_s16 = encode_image_by_16times(ndf)
        if b_jcu:
            self.UNCOND_DNET = D_GET_LOGITS(ndf, nef, bcondition=False)
                ......

辨别器Loss Function

def discriminator_loss(netD, real_imgs, fake_imgs, conditions,
                   real_labels, fake_labels):
    # Forward
    real_features = netD(real_imgs)
    fake_features = netD(fake_imgs.detach())
    # loss
    #
    cond_real_logits = netD.COND_DNET(real_features, conditions)
    ......

2.Deep Attentional Multimodal Similarity Model (DAMSM)

DAMSM Structure

  • text_encoder

    是一个双向LSTM.输出sentence embedding和word embedding (e :D*T)
  • image encoder

    
    它的输出是一个2048维的向量,代表了整个图像的特征:f'
  • s=e.transpose()*v,

这样把图像和句子结合在了一起.s(i,j)代表了句子中的第i个单词和第图像中第j个区域的相关性.
最后s和h结合得到相关损失.

DAMSM Loss Function


Q:generative pictures             D:text description
def evaluate(dataloader, cnn_model, rnn_model, batch_size):
        cnn_model.eval()
       rnn_model.eval()
    s_total_loss = 0
    w_total_loss = 0
    for step, data in enumerate(dataloader, 0):
        real_imgs, captions, cap_lens, \
            class_ids, keys = prepare_data(data)
            words_features, sent_code = cnn_model(real_imgs[-1])
        # nef = words_features.size(1)
        # words_features = words_features.view(batch_size, nef, -1)
        hidden = rnn_model.init_hidden(batch_size)
        words_emb, sent_emb = rnn_model(captions, cap_lens, hidden)
        w_loss0, w_loss1, attn = words_loss(words_features, words_emb, labels,
                                            cap_lens, class_ids, batch_size)



你可能感兴趣的:(算法,自然语言处理,深度学习,机器学习,计算机视觉)