[StackGAN实践] [2]网络训练

网络结构

论文中的网络结构图如下,embedding的提取直接使用预训练好的text encoder进行提取(不是本文重点)。提出的StackGAN整个模型包含2个GAN网络,分别用于两个阶段:
Stage1 :embedding+ noise 为输入,利用GAN输出低分辨率的64x64大小的影像;
Stage2 :embedding+ Stage I的低分辨率生成影像 为输入,利用GAN输高分辨率的256x256大小的影像
[StackGAN实践] [2]网络训练_第1张图片
结合代码,stage I与stage II 的详细结构如下:

注意:其实代码中stage II 鉴别器输出的logit 有两种,分为condition 和uncondition,分别对应着有无引入embedding信息。(图中只显示了condition的logit输出)

每个阶段的GAN训练流程是相同的:

  1. 生成fake img;
  2. 训练鉴别器。考虑三种鉴别器输入,(1)real pairs:真实图像与对应的文本embedding,gt为 1;(2)wrong pairs:真实图像与不匹配的文本embedding,gt为 0;(3) fake pairs:生成图像与对应的文本embedding,gt为 0
  3. 训练生成器。鉴别器输入只考虑fake pairs,,gt为 1

网络训练

官方的pytorch实现有问题,stage I的生成器损失无法收敛。
在尝试1、提高D的lr同时降低G的lr 以及 反过来调整lr;2、提高G的通道数;3、使用改进的GAN损失函数形式 后,均无法正常收敛。

你可能感兴趣的:(StackGAN)