P a p e r \rm Paper Paper >> T r a i n i n g G e n e r a t i v e A d v e r s a r i a l N e t w o r k s w i t h L i m i t e d D a t a \rm Training~Generative~Adversarial~Networks~with~Limited~Data Training Generative Adversarial Networks with Limited Data
C o d e s \rm Codes Codes >>tensorflow
/pytorch
B l o g s \rm Blogs Blogs >> C S D N : \rm CSDN: CSDN: T r a i n i n g G e n e r a t i v e A d v e r s a r i a l N e t w o r k s w i t h L i m i t e d D a t a \rm Training~Generative~Adversarial~Networks~with~Limited~Data Training Generative Adversarial Networks with Limited Data
1. 1. 1. 计算 L a d v L_{adv} Ladv 时,前馈真实样本或伪样本给鉴别器 D D D;
## 对生成伪样本
fake_logits = self.run_D(fake_img, fake_c)
## 对真实训练样本
real_logits = self.run_D(real_img, real_c)
2. 2. 2. run_D
的具体细节为:
def run_D(self, img, c):
## 使用定义的含 p 的数据增强流水线作 aug
img = self.augment_pipe(img)
logits = self.D(img, c)
return logits
3. 3. 3. 记录 L a d v L_{adv} Ladv 和 E [ s i g n ( D t r a i n ) ] {\mathbb E}[\rm sign (D_{train})] E[sign(Dtrain)] (这里使用的是 WGAN
,所以 L a d v L_{adv} Ladv 计算方式比较简单, min / max \min/\max min/max ?_logits
即可)
training_stats.report('Loss/scores/real', real_logits)
training_stats.report('Loss/signs/real' , real_logits.sign()) ##
4. 4. 4. 具体的,每一个被 r e p o r t e d \rm reported reported 的状态(统计数据 s t a t i s t i c \rm statistic statistic)被记录了 3 个统计量:
## `elems` 是一个形参,这里考虑 tensor `logits`
moments = torch.stack([
torch.ones_like(elems).sum(), ## 记录数量 (count)
elems.sum(), ## 求和
## 计算 E[sign(D_train)] 只需要前两个统计量,即:moments[1]/moments[0]
elems.square().sum(),
])
5. 5. 5. 累计前面 4. 4. 4. 记录的数据 moments of real_logits
,累积(会使用到 _moments.add_(moments)
) N = 4 N=4 N=4 次迭代( i t e r / m i n i b a t c h \rm iter/minibatch iter/minibatch)
6. 6. 6. 通过获取 E [ s i g n ( D t r a i n ) ] \mathbb E[\rm sign(D_{train})] E[sign(Dtrain)] 来动态更新 p
# Execute ADA heuristic.
if (ada_stats is not None) and \ ## 是否使用 ADA 这一项技术
(batch_idx % ada_interval == 0): ## N 值
ada_stats.update()
adjust = np.sign(ada_stats['Loss/signs/real'] - ada_target) \ ## ada_target 是 r_t 的阈值,文中设置是 0.6
## ada_stats['Loss/signs/real'] = moments_of_real_logits[1]/moments_of_real_logits[0]
* (batch_size * ada_interval) / (ada_kimg * 1000) ## 增益,the gain := (BxN)/SCALE, `B` is batch size, `N` is # of batches; all-in-all, it is FIXED.
## 更新 p 值
augment_pipe.p.copy_((augment_pipe.p + adjust)\ ## D 偏强,则 adjust 为正,Aug 强度适当增大;D 偏弱,则 adjust 为负,Aug 强度适当减弱
.max(misc.constant(0, device=device))) ## clip/truncate,限制概率在有效范围
引用脚本文件 ./training/augment.py
,直接初始化 nn.Module
模块——
import augment
aug_pipe = augment.AugmentPipe()
## input --type=torch.tensor --size=(N,C,H,W)
aug_input = aug_pipe(input)
B T W \rm BTW BTW,这个项目对于 pytorch
多线程、多卡并行训练编程有非常好的借鉴性,木奉!