pytorch lightning

背景

众所周知,pytorch是近年热门的深度学习框架之一,与tensorflow相比,普遍认识是pytorch更适合学界,方便学者快速实践深度模型,各类研究论文中,pytorch的算法实现更多。但是,pytorch创建网络确实很快,但在训练时依然有很多繁琐的细节。下面这些情况是否是你在调试pytorch项目的常态,如果是,不如尝试下pytorch lightning,它对于了解pytorch的人来说可以很快实践,并从pytorch项目快速转换成lightning风格。目前大火的图片生成stable-diffusion项目正是lightning实现的。

pytorch 调试的麻烦之处

1.神经网络net很快搭建好了,为了训练网络必须写好dataloader,train loop,loss function,最后是test loop。
2. 以上功能写完后,开始测试,首先调通train loop部分,模型可以训练了,结果在validation部分代码有问题(shape 不匹配,device 不匹配,type不匹配等),结果这次的epoch就相当于白费计算和时间。等找到val部分的问题,又开始训练。

对于train, val 阶段代码的调试,通常将dataset类中添加使用长度data_len设置,只加载部分数据集以快速调试完整代码
3. 第二步调试通过后,一切正常,结果模型精度不高,然后想新的方法设计的新的网络再训练,这下多个模型的训练结果如果不好好整理很容易混淆。另外,由于net经过修改后,再使用旧的训练结果很容易出现 “missing keys” 键不匹配问题
这种情况通常利用config文件配置模型和训练参数。
4. 训练过程因异常中断,想恢复训练,但没有保存训练参数,模型加载继续训练时训练参数仍是epoch=0的情况。比如设置学习率lr从1e-6逐步上升到1e-2再线性下降到0。由于没有保存学习率设置,恢复训练时学习率又从1e-6开始,并且epoch从0开始。
这种情况通常用另外一个文件保存学习率lr, epoch,best score等重要参数,模型加载时一并读取。
5. 你有一块gpu,为了使用其加速训练,必须将tensor,model放置在同一个设置下,你必须仔细检测每个变量与模型。
6. 你有多块gpu,多gpu训练必须使用DataParallel或DistributedDataParallel,使用这些api在模型实例化时必须增加一步调用.module,此时为了兼容单gpu训练又得加上判断
7. gpu显存太小,模型难以训练,只能使用半精度,这时得在各个forward和tensor计算的地方加上混合精度设置。
8. 想要实时查看训练集精度,在val阶段计算精度,然后用tensorboard,wandb工具记录,这时你得学会这些工具怎么调用。
9. 想要复现别人的代码,不同项目风格差异很大,阅读、调试及复用相当麻烦。极端的情况是只有main.py、net.py和dataset.py文件,这些逻辑过于集中,好的情况是按data,model,utils,configs,main.py分块,每个目录的__init__.py提供每个包的接口,在main.py中使用configs不同配置实例化各个组件。
10. 想要计算模型自个部分的参数量与计算量,得使用第三方包计算。
以上是个人使用pytorch过程中觉得麻烦的地方,如果你也遇到了,不妨试试pytorch lightning。
11. 数据集拆分,

pytorch lightning介绍

lightning 是pytorch的轻量级高层API,类似keras之于tensorflow。它利用hook将主要逻辑拆分成不同step,如training_step,validation_step, test_step等,只需为你的模型重写这些需要的方法实现相应的逻辑,给入数据集加载器和创建的模型以实例化Trainer,然后就可以调用fit()训练。下面我们看看在lightning中以上问题怎么处理:

  1. 模型训练与测试完整逻辑。复用原生pytorch实现的dataloader和network,将network的父类换成pl.LightningModule, 在training_step中forward,然后计算损失并返回。val,test的step中只预测。
  2. 代码完整调试。lightning在开始训练前为先执行2个step的验证过程,即先执行两遍validation_step,防止先训练出错。
  3. 模型修改。lightning有cli工具,可以以配置文件的形式实例化模型,训练时自动保存config和超参数。见cli和save_hyperparameters。
  4. 训练恢复。lightning自动保存模型权重和参数,恢复时只需在创建trainer时指定ckpt_path。
  5. 在创建trainer时指定设置和策略,见GPU,避免在多处文件中设置。
  6. 同上。
  7. 半精度。创建trainer时指定precision参数。
  8. lightning 支持tensorboard,wandb,comet,mlflow等多种logger,在创建trainer时设置logger参数,然后在不同的step中调用self.log(‘val/acc’, acc, logger=True),即可将计算的acc记录到logger中。更多设置见logger
  9. 复现别人的代码。参考ORGANIZE PYTORCH INTO LIGHTNING,即1中的过程。
  10. 使用自带工具profiler自动分析
  11. 固定seed,可以在Trainer中只给一个train_loader,然后设置拆分比例,种子固定,拆分的结果也就固定。
    其它特性:
  12. 分步机制。lighning为LightningModule类创建了不同的hook,对应step、epoch、fit等阶段的处理逻辑,比如训练的单步step为training_step(计算loss),batch使用前/后的处理,可以增加预/后处理操作on_train_batch_start,on_train_batch_end,每个epoch的操作为on_train_epoch_start和on_train_epoch_end,可以在start中初始化acc计算对象,在end中统计acc用于记录或保存部分图片可视化。val和test类似。
  13. 回调机制。可以创建自己的callback方法,在不同hook中增加自己的处理逻辑,只需在Trainer创建时传入。lightning自带的callback有ModelCheckpoint用于模型保存的策略。LearningRateFinder搜索合适的学习率等。
    参考:
    lightning 官方文档
    lightning 开源项目

你可能感兴趣的:(pytorch,深度学习)