pytorch lightning之快速调试

调试代码

此阶段主要测试各阶段代码是否有问题。

快速测试

fast_dev_run项可以配置train/val/test阶段的循环次数,跑完就停止代码,快速查看各流程代码正确性,避免train调试后训练又在val/test阶段出错,白白浪费时间和计算成本。

Trainer(fast_dev_run=7)# 每个阶段只循环7次,也可以设置为True,只循环5次。

使用部分数据测试

功能与上面类似,但是运行指定epochs的周期过程,只是train/val/test流程使用部分数据。

# 使用10%训练集 和 1% 验证集
trainer = Trainer(limit_train_batches=0.1, limit_val_batches=0.01)

# use 10 batches of train and 5 batches of val
trainer = Trainer(limit_train_batches=10, limit_val_batches=5)

validation_step()提前检查

设置num_sanity_val_steps,lightning会在开始训练前默认先执行再次validation_step,避免训练后验证阶段出错。

trainer = Trainer(num_sanity_val_steps=2)

打印网络模型信息

调用train.fit()后,lightning会自动打印模型信息,如下:

  | Name  | Type        | Params
----------------------------------
0 | net   | Sequential  | 132 K
1 | net.0 | Linear      | 131 K
2 | net.1 | BatchNorm1d | 1.0 K

也可以利用内置的callback ModelSummary打印子模块的信息。需要配置好callback后传入Trainer。

trainer = Trainer(callbacks=[ModelSummary(max_depth=-1)])# 打印深度为所有。

打印各模块输入输出的尺寸

在LightningModule中设置example_input_array属性

class LitModel(LightningModule):
    def __init__(self, *args, **kwargs):
        self.example_input_array = torch.Tensor(32, 1, 28, 28)
  | Name  | Type        | Params | In sizes  | Out sizes
--------------------------------------------------------------
0 | net   | Sequential  | 132 K  | [10, 256] | [10, 512]
1 | net.0 | Linear      | 131 K  | [10, 256] | [10, 512]
2 | net.1 | BatchNorm1d | 1.0 K  | [10, 512] | [10, 512]

你可能感兴趣的:(pytorch,深度学习,lightning)