StyleGAN2代码阅读笔记

源代码地址:https://github.com/NVlabs/stylegan2-ada-pytorch

  • 这是一篇代码阅读笔记,顾名思义是对代码进行阅读,讲解的笔记。对象是styleGAN2的pytorch版本的代码,在github上有一个开源库。一边笔记方便我回顾,一边也对深度学习初学者有一些阅读理解代码的示例作用吧。代码毕竟是基本功,看到我发的一些代码,评论区问的一些问题实在是代码基础不行。
  • 读一个代码的目的很多样,有些只是想看它的预处理,有些只是想看他的模型,不同目的阅读方式和详略程度也不同。我这里是为了全部读懂给女朋友讲解的,所以会以全部读懂为目的来进行阅读和笔记。

train.py

  • 起始点有很多种,我习惯从主函数开始读。train.py是主程序文件,直接拉到最后可以看到,主函数是从main函数开始的,那就找到main函数,438行开始

train.py/main()

  • 482行看名字就知道是输出日志用的,不重要,可以先跳过
  • 486行,可以看到一些主要参数是在这里设置的,后续要找再来看,这里先跳过
  • 491-498行,设置了一些输出路径,也是跳过,有需要再来看
  • 520-524行,将训练设置保存到一个json文件中
  • 526行到533行是程序的主体部分,在这里启用了多线程,运行subprocess_fn函数,因此下一步就看这个函数。这里稍微展开说一下这个多线程是怎么回事,就是利用了torch.multiprocessing实现了每个GPU分配一个线程,并且多线程之间是用spawn方式创建的。也就是说,你有多少个GPU,就会同时运行多少个subprocess_fn函数,并且spawn方式意味着这些线程都有独立的python解释器程序,资源是复制的,有自己的独立内存而非全部共享内存,而529行是指定了一个临时路径用来给这些线程进行交流,在这个路径下实现需要共享的部分变量。

train.py/subprocess_fn()

  • 367行到380行是torch分布式训练的一些初始化设置
  • 主程序在training_loop模块的training_loop方法中,接下来就跳到这里

training/training_loop.py

training/training_loop.py/training_loop()

  • 这里终于遇到第一个关键点,136行,数据集,调用了construct_class_by_name函数。后面会讲解construct_class_by_name这个函数,这里只需要知道它是个根据输入的参数,返回一个根据参数确定的类的方法即可。最近越来越多的深度学习代码使用这种包装方式,本质上就算想用字符串来调用类,又为了代码统一和简洁,包装得一层接一层,读起来是真的麻烦,而且类名隐藏起来了,甚至无法用vscode的智能追踪来找这里用到的到底是什么类。
  • 这里直接说,training_set根据train.py的107行,是training.dataset.ImageFolderDataset类的对象
  • 以及150行和151行,G和D根据train.py的176 177行分别是training.networks.Generator类和training.networks.Discriminator类的对象。而G_ema是G的一个指数移动平均版本,在训练过程中,G的参数会随着step而更新,而G_ema是G的迭代过程中各个时期的参数的指数移动平均版本,相比G,G_ema的变化更加柔和,这是个常用的技巧。
  • 154到159行的代码加载了模型的参数,可能是预训练的也可能是训练中止接着跑的。
  • 175行 augment_pipe 是train.py 287行 training.augment.AugmentPipe类的对象
  • 180到190进行了多线程训练的模型包装
  • 192到214行定义了训练的几个阶段。这里展开解释一下,GAN的训练策略相比普通模型稍微有些复杂,训练是分阶段的,每个iteration通常要分别训练G和D,并且在训练G的时候,D的参数要固定,训练D的时候,G的参数要固定。这段代码定义了4个阶段:Gmain Greg Dmain Dreg。
  • 195行 loss 根据train.py 187行,是training.loss.StyleGAN2Loss类的对象
  • 199行 opt 根据train.py 185行,是torch.optim.Adam类的对象
  • 216-227行从训练集中采样了一些图片进行可视化(可做debug用),同时也将还没训练的G的输出也做了可视化(可用于检查resume是否加载或者pretrain模型的初始性能)
  • 259行开始训练
  • 260行获取真实图像和对应的label,261-262行归一化图像并划分图像和label到各个GPU
  • 263行-264行生成随机向量作为Generator的输入,并划分到各个GPU
  • 265-267行并从训练集中随机采样label作为条件label,并划分到各个GPU
  • 270行,依次迭代前面提到的4个阶段
  • 278行,把这个阶段需要训练的module设为计算梯度(如训练G的时候,设G的requires_grad为True,而D的为Flase)
  • 281-284行,根据当前所处阶段,为每个GPU分别计算损失。每个阶段的损失介绍损失的时候会展开。
  • 287-294行,根据当前所处阶段,更新待训练的参数(如Gmain和Greg阶段就只更新G的参数),并且把之前设为True的requires_grad改回去,然后进入下一阶段,直到4个阶段全部完成。
  • 296-305行,为G计算指数移动平均,从而更新G_ema的参数
  • 311-315行,根据训练过程的损失,调整数据增强策略的参数,具体在下面介绍数据增强的时候会展开。
  • 318-320行,这里是设置了continue条件,使得每迭代4000(kimg_per_tick*1000)张图片才会运行322行以后的内容一次。实现方式是cur_nimg会一直增加,而tick_start_nimg只有在下面的代码会被设置为cur_nimg,这样一旦运行了一次下面的代码,下次判断小于号就会成立,直到cur_nimg增加了4000使得小于号不成立,然后又会运行一次下面的代码。而done条件是因为,break出循环之前需要运行一次下面的代码,所以设置了当迭代图像数满足图像总数的1000倍的时候,就要退出了,这时候不管是不是每4000次的间隔到了,我都要往下走。
  • 341行设置了另一种退出的方法(代码里似乎没有设abort_fn所以应该这一段代码是没有用到),可以为training_loop传一个有效的abort_fn,使得如果准确率等满足条件返回True,从而不需要跑满1000epoch可以退出。
  • 348-350行为当前iteration生成的图片保存到本地,因为这段代码在322行之后,所以每迭代4000张图片才会生成一次。
  • 353-367行保存了模型参数,同理也是4000张图片才保存一次。
  • 370-379行计算了指标,后续会展开介绍
  • 381-389行和322-338行都说计算运行时间和存储消耗的,就跳过了
  • 391-406行都是输出日志的,跳过
  • 414行是while True循环的唯一退出点。如果运行完成了,就从这里退出循环,结束训练。

dnnlib/util.py/construct_class_by_name()

  • 这个函数只有两行,调用了call_func_by_name函数并以其返回值作为自身的返回值。call_func_by_name函数定义在279行,调用了get_obj_by_name函数,并进一步调用得到的func_obj,以func_obj的返回值作为call_func_by_name的返回值。所以这里其实就是调用了get_obj_by_name函数得到了类,func_obj保存的就是得到的类,然后实例化并返回,所以返回的是类的实例化对象。
  • get_obj_by_name函数在273行,调用了get_module_from_obj_name和get_obj_from_module。有点绕,其实是因为,name是xx.yy.zz的格式,zz才是类名,xx.yy是模块名,所以先调用222行的get_module_from_obj_name从xx.yy.zz中提取出xx.yy和zz,然后再借助get_obj_from_module函数从xx.yy模块中调用zz类。
  • get_module_from_obj_name函数的核心就在231-239行,231-232行其实就是给出根据“.”的位置对字符串划分成两部分的全部可能,所以如果是xx.yy.zz就会被拆成xx和yy.zz或者xx.yy和zz。然后在235到239行,对每种可能性都进行尝试,尝试从xx.yy中import zz,尝试从xx中import yy.zz,因为用的是try,试不出来可以继续,直到试出来,就知道正确的划分方法是什么。
  • 而get_obj_from_module函数是通过269行的getattr函数来获取模块中的类的。

training/dataset.py

  • ImageFolderDataset类定义在training/dataset.py的154行,是同文件24行Dataset类的子类,一般看__getitem__函数即可。返回值有image和label。image根据210-220行的重写,是一个CHW的unit8(0-255)的np array。label是onehot的float32的np array

training/networks.py

  • Generator类定义在training/networks.py的477行。476行是一个装饰器,意思是调用training.networks.Generator的时候,实质上返回的是persistence.persistent_class(Generator),这个装饰器只是为这个类添加了一些辅助功能,不影响接下来的理解,所以先跳过,后续会解释这个装饰器,先接着看模型
  • 模型由两个子模块组成:MappingNetworkSynthesisNetwork

training/networks.py/MappingNetwork()

  • MappingNetwork定义在174行,从初始化函数看起,200行前面定义了一些变量的维度,201行定义了中间全连接层的维度。
  • 204行定义了第一个全连接层,当使用condition label的时候,对这个one hot的condition label进行embed,embed后的特征将和z连在一起作为后续网络的输入。
  • 205-209行定义了网络的主体全连接层
  • 211-212行定义了一个名为w_avg的变量,它并不会随梯度反向传播更新梯度,但会在一些特殊的时刻进行值的更新和被使用。
  • 这里的FullyConnectedLayer(89行)相比普通的全连接层的区别在于,当lr_multiplier不为1时(208行定义的就不为1,是0.01),这些层的参数的学习率和其它参数的学习率相比会乘以一个lr_multiplier(具体实现其实就是把参数直接乘以一个lr_multiplier再去用,实际效果就等同于学习率乘了一个倍数,因为计算这些参数的梯度的时候也是会因此乘以一个lr_multiplier导致step的时候步长会乘以一个lr_multiplier的)
  • 接下来看forward函数。219和222行都仅仅是检查向量的shape。normalize_2nd_moment函数看21行,其实就是先统计这些特征值的标准差(每个样本单独统计),接着除以标准差进行归一化。其实这么说不太准确,因为没有减去均值,仅仅是先平方,然后平均,然后开根,然后除(rsqrt是1/sqrt)。而20行的装饰器仅仅是使得torch.autograd.profiler.record_function能跟踪到这个函数而已。至于torch.autograd.profiler后续会介绍是个什么东西。
  • 然后在223-224行,c向量送进一个全连接层编码,归一化,然后和归一化后的z向量被concatenate到一起,作为后面全连接层的输入
  • 226-229行就是主体的mapping network,对合并的编码和z向量前向传播经过几层全连接层
  • 231-234行保存全连接网络的输出的移动平均(lerp是根据w_avg_beta对w_avg和x进行插值的函数)到w_avg变量中
  • 236-239行重复了num_ws份x,放在dimension 1上,也就是说现在shape是(B,num_ws,w_dim),具体num_ws是什么下面介绍SynthesisNetwork时会展开说明
  • 242-248行,查完整份代码没有看到哪里有把truncation_psi设为非1的值,所以理论上正常情况这部分代码是不会运行到的。看意思应该是利用w_avg对x进行进一步移动平均,这里的移动平均就是对x做了,影响的是x的值,前面的移动平均只是存下来而已,对实际训练过程不会有什么影响。之所以说是截断,是因为当x在训练过程中突然出现异常大或者异常小的值时,这段代码可以通过移动平均限制这些值不要偏离正常范围太远。

training/networks.py/SynthesisNetwork()

  • SynthesisNetwork定义在424行。首先看init函数,440行根据要生成的图片的分辨率,定义了各个block的resolution,依次是2的2,3,4,。。n次方,使得2的n次方最接近要生成的图片的分辨率。441行则定义了各个block的通道数为32768除以block的resolution,但最小是512。
  • 442定义了一个称为fp16_resolution的变量。FP16是一个降低运算量和内存占用的技巧,将32位浮点运算用半精度运算来近似。模型对分辨率最高的num_fp16_res个block进行FP16计算,所以这里是在算开始进行FP16计算的block的resolution。在448行当block的resolution大于等于这里算出来的fp16_resolution时,意味着这个block要进行FP16计算而非全精度的计算。
  • 445-455行定义了SynthesisNetwork的主体由堆叠的几个SynthesisBlock组成。这里还统计了num_ws,后续会解释这个是什么。
  • SynthesisBlock后面解释,先接着看forward函数,forward函数的输入是MappingNetwork的输出,即是全连接并repeat了num_ws遍后的编码特征,shape为(B,num_ws,w_dim)
  • 463-466行将输入的ws特征在dimension 1上拆成多份,每一份分给一个block。也就是说现在每个block的输入是(B,num_conv,w_dim),其实就是每个block分到了num_conv个重复的特征向量。最后一个block会得到 num_conv+num_torgb份(因为只有最后一个block的num_torgb不为0)
  • 468-471行则开始前向传播,每个block的输入是分得的ws和上一个block的输出(x和img),第一个block输入的(x和img)为None。最后一个block输出的img为SynthesisNetwork最终的输出

persistence.persistent_class

  • 这个函数在torch_utils/persistence.py文件的35行被定义。可以看到99行和130行,这个函数返回的是输入类的一个子类,这个子类为这个输入的类添加了一些功能,包括:保存类的初始化参数,为类添加打包函数(__reduce__方法的功能即为当代码被pickle打包时能输出正确的字符串等,展开说有点离题,具体可以自己去查,这里是暂时不需要了解的细节)

TODO:

training/networks.py/SynthesisBlock()

  • SynthesisBlock定义在329行

training.networks.Discriminator

training.augment.AugmentPipe

training.loss.StyleGAN2Loss

metric_main

misc.assert_shape

misc.profiled_function

如何对这份代码进行修改以实现自己的idea

你可能感兴趣的:(项目经验,pytorch,深度学习,python)