Trainer--学习笔记

  1. 在训练神经网络的时候通常都会写一个训练代码块,通过这个代码块的执行开始训练网络
  2. 学习神经网络模型时写的训练代码块, 整个流程就是在写一个脚本文件:
    1. 定义数据加载器
    2. 定义优化器
    3. 定义损失函数
    4. 模型实例化
    5. 循环读取数据,开始迭代训练
  3. 形成一个抽象的类Trainer,这样能够在一定程度上提高代码的复用性、可读性以及可扩展性
  4. 动态的将需要的功能模块“注册”到Trainer类中,而不需要去修改Trainer最原始的定义,实现可扩展的功能,对Trainer类进行升级,使他能够具备插件化处理的功能。
  5. 定义了一个插件队列的字典, 它保存不同时机调用的插件序列, 那么问题来了,一般会在什么时候调用这些插件呢?
    1. 在每次获取到数据之后,训练之前 对数据进行不同处理?
    2. 在完成一次backward操作之后 显示当前loss 或者accuracy
    3. 在完成每次batch or epoch 后保存模型或者修改学习率?
  6. 定义了四种类别的插件:
    1. iteration:一般是在完成一个batch 训练之后进行的事件调用序列(一般不改动网络或者优化器,如:计算准确率)调用序列;
    2. batch 在进行batch 训练之前需要进行的事件调用序列
    3. epoch 完成一个epoch 训练之后进行的事件调用序列
    4. update 完成一个batch训练之后进行的事件(涉及到对网络或者优化器的改动,如:学习率的调整)
    5. 注意,iteration 跟update 两种插件调用的时候传入的参数不一样,iteration 会传入batch output,loss 等训练过程中的数据, 而update传入的的model,方便对网络的修改
  7. pytorch lightning:

    1. 自动early stopping,自动batch_size, leaning rate搜索

    2. 当模型比较复杂时,为了提高代码的可读性,建议先使用torch.nn.Module构建网络,再使用pytorch_lightning.LightningModule对其进行包装

    3. 使用Pytorch进行开发时,到这里模型定义就结束了,其余的训练、验证的具体实现会放在train.py等训练代码中;但Pytorch Lightning模型实现的是整个系统,所以训练的细节也会在这个类中实现

    4. 当模型比较复杂时,为了提高代码的可读性,建议先使用torch.nn.Module构建网络,再使用pytorch_lightning.LightningModule对其进行包装

你可能感兴趣的:(pytorch,深度学习,tensorflow,cnn)