StyleGAN-v2 ADA 的 pytorch 代码分析与实际使用

R e s o u r c e s \rm Resources Resources

  P a p e r \rm Paper Paper >> T r a i n i n g   G e n e r a t i v e   A d v e r s a r i a l   N e t w o r k s   w i t h   L i m i t e d   D a t a \rm Training~Generative~Adversarial~Networks~with~Limited~Data Training Generative Adversarial Networks with Limited Data
C o d e s \rm Codes Codes >>tensorflow/pytorch
B l o g s \rm Blogs Blogs  >> C S D N : \rm CSDN: CSDN: T r a i n i n g   G e n e r a t i v e   A d v e r s a r i a l   N e t w o r k s   w i t h   L i m i t e d   D a t a \rm Training~Generative~Adversarial~Networks~with~Limited~Data Training Generative Adversarial Networks with Limited Data

A n a l y s i s \rm Analysis Analysis

1. 1. 1. 计算 L a d v L_{adv} Ladv 时,前馈真实样本或伪样本给鉴别器 D D D;

## 对生成伪样本
fake_logits = self.run_D(fake_img, fake_c)
## 对真实训练样本
real_logits = self.run_D(real_img, real_c)

2. 2. 2. run_D 的具体细节为:

def run_D(self, img, c):
    ## 使用定义的含 p 的数据增强流水线作 aug
    img = self.augment_pipe(img) 
    logits = self.D(img, c)
    return logits

3. 3. 3. 记录 L a d v L_{adv} Ladv E [ s i g n ( D t r a i n ) ] {\mathbb E}[\rm sign (D_{train})] E[sign(Dtrain)] (这里使用的是 WGAN,所以 L a d v L_{adv} Ladv 计算方式比较简单, min ⁡ / max ⁡ \min/\max min/max ?_logits 即可)

training_stats.report('Loss/scores/real', real_logits)
training_stats.report('Loss/signs/real' , real_logits.sign()) ##  

4. 4. 4. 具体的,每一个被 r e p o r t e d \rm reported reported 的状态(统计数据 s t a t i s t i c \rm statistic statistic)被记录了 3 个统计量:

## `elems` 是一个形参,这里考虑 tensor `logits`
moments = torch.stack([
    torch.ones_like(elems).sum(), ## 记录数量 (count)
    elems.sum(),                  ## 求和
    ## 计算 E[sign(D_train)] 只需要前两个统计量,即:moments[1]/moments[0]
    elems.square().sum(),
])

5. 5. 5. 累计前面 4. 4. 4. 记录的数据 moments of real_logits,累积(会使用到 _moments.add_(moments) N = 4 N=4 N=4 次迭代( i t e r / m i n i b a t c h \rm iter/minibatch iter/minibatch
6. 6. 6. 通过获取 E [ s i g n ( D t r a i n ) ] \mathbb E[\rm sign(D_{train})] E[sign(Dtrain)] 来动态更新 p

# Execute ADA heuristic.
if (ada_stats is not None) and \ ## 是否使用 ADA 这一项技术
   (batch_idx % ada_interval == 0): ## N 值
    ada_stats.update()
    adjust = np.sign(ada_stats['Loss/signs/real'] - ada_target) \ ## ada_target 是 r_t 的阈值,文中设置是 0.6
                                                                  ## ada_stats['Loss/signs/real'] = moments_of_real_logits[1]/moments_of_real_logits[0]
             * (batch_size * ada_interval) / (ada_kimg * 1000)    ## 增益,the gain := (BxN)/SCALE, `B` is batch size, `N` is # of batches; all-in-all, it is FIXED.
    ## 更新 p 值
    augment_pipe.p.copy_((augment_pipe.p + adjust)\               ## D 偏强,则 adjust 为正,Aug 强度适当增大;D 偏弱,则 adjust 为负,Aug 强度适当减弱
                         .max(misc.constant(0, device=device)))   ## clip/truncate,限制概率在有效范围

U s a g e   o f   A u g P i p e \rm Usage~of~AugPipe Usage of AugPipe

引用脚本文件 ./training/augment.py,直接初始化 nn.Module 模块——

import augment

aug_pipe = augment.AugmentPipe()
## input --type=torch.tensor --size=(N,C,H,W)
aug_input = aug_pipe(input)

B T W \rm BTW BTW,这个项目对于 pytorch 多线程、多卡并行训练编程有非常好的借鉴性,木奉!

你可能感兴趣的:(走在填坑路上,深度学习,算法)