Pytorch-Lightning中的训练器--Trainer

Pytorch-Lightning中的训练器—Trainer

Trainer()

常用参数

参数名称 含义 默认值 接受类型
callbacks 添加回调函数或回调函数列表 None(ModelCheckpoint默认值) Union[List[Callback], Callback, None]
enable_checkpointing 是否使用callbacks True bool
enable_progress_bar 是否显示进度条 True bool
enable_model_summary 是否打印模型摘要 True bool
gpus 使用的gpu数量(int)或gpu节点列表(liststr) None(不使用GPU) Union[int, str, List[int], None]
precision 指定训练精度 32(full precision) Union[int, str]
default_root_dir 模型保存和日志记录默认根路径 None(os.getcwd()) Optional[str]
logger 设置日志记录器(支持多个),若没设置logger的save_dir,则使用default_root_dir True(默认日志记录) Union[LightningLoggerBase, Iterable[LightningLoggerBase], bool]
max_epochs 最多训练轮数(指定为**-1可以设置为无限次**) None(1000) Optional[int]
min_epochs 最少训练轮数。当有Early Stop时使用 None(1) Optional[int]
max_steps 最大网络权重更新次数 -1(禁用) Optional[int]
min_steps 最少网络权重更新次数 None(禁用) Optional[int]
weights_save_path 权重保存路径(优先级高于default_root_dir),ModelCheckpoint未定义路径时将使用该路径 None(default_root_dir) Optional[str]
log_every_n_steps 更新n次网络权重后记录一次日志 50 int
auto_scale_batch_size 在进行任何训练前自动搜索最佳batch_size并保存到模型的self.bacth_size中,str参数表示搜索策略 False Union[str, bool]
auto_lr_find 自动搜索最佳学习率并存储到self.lrself.learing_rate,str参数表示学习率参数的属性名 False Union[str, bool]
auto_select_gpus 自动寻找合适的GPU,对于GPU独占模式非常有用 False bool
limit_train_batches
limit_test_batches
limit_val_batches

limit_predict_batches
使用训练/测试/验证/预测数据的百分比.如果数据过多,或正在调试可以使用。 1.0 Union[int, float] (float = 比例, int = num_batches).
fast_dev_run 如果设定为true,会只执行一个batch的train, val 和 test,然后结束。仅用于debug False bool
accumulate_grad_batches 每k次batches累计一次梯度 None(无梯度累计) Union[int, Dict[int, int], None]
check_val_every_n_epoch 每n个train epoch执行一次验证 1 int
num_sanity_val_steps 开始训练前加载n个验证数据进行测试,k=-1时加载所有验证数据 2 int

额外的解释

  • 这里max_steps/min_steps中的step就是指的是优化器的step(),优化器每step()一次就会更新一次网络权重
  • 梯度累加(Gradient Accumulation):受限于显存大小,一些训练任务只能使用较小的batch_size,但一般batch-size越大(一定范围内)模型收敛越稳定效果相对越好;梯度累加可以先累加多个batch的梯度再进行一次参数更新,相当于增大了batch_size

Trainer.fit()

参数详解

参数名称 含义 默认值
model LightningModule实例
train_dataloaders 训练数据加载器 None
val_dataloaders 验证数据加载器 None
ckpt_path ckpt文件路径(从这里文件恢复训练) None
datamodule LightningDataModule实例 None

ckpt_path参数详解(从之前的模型恢复训练)

​ 使用该参数指定一个模型ckpt文件(需要保存整个模型,而不是仅仅保存模型权重),Trainer将从ckpt文件的下一个epoch继续训练。

示范

net = MyNet(...)
trainer = pl.Trainer(...)
# 假设模型保存在./ckpt中
trainer.fit(net, train_iter, val_iter, ckpt_path='./ckpt/myresult.ckpt')

使用注意

  • 请不要使用Trainer()中的resume_from_checkpoint参数,该参数未来将被丢弃,请使用Trainer.fit()的ckpt_path参数

Trainer.test()Trainer.validate()

参数详解

参数名称 含义 默认值
model LightningModule实例 None
verbose 是否打印测试结果 True
dataloaders 测试数据加载器(可以使torch.utils.data.DataLoader) None
ckpt_path ckpt文件路径(从这里文件恢复训练) None
datamodule LightningDataModule实例 None
  • Returns:测试/验证期间相关度量值的字典列表(列表长度等于测试/验证数据加载器个数),比如validation/test_step(), validation/test_epoch_end(),中的回调钩子
  • ckpt_path:如果设置了该参数则会使用该ckpt文件中的权重,否则如果模型已经训练完毕则使用当前权重,其他情况如果配置了checkpoint callbacks则加载该checkpoint callbacks对应的最佳模型

Trainer.predict()

参数详解

参数名称 含义 默认值
model LightningModule实例 None
dataloaders 数据加载器 None
ckpt_path ckpt文件路径(从这里文件恢复训练) None
datamodule LightningDataModule实例 None
return_predictions 是否返回结果,目前不支持设置 None(True)
  • ckpt_path:使用该ckpt文件中的权重,如果为None如果模型已经训练完毕则使用当前权重,其他情况如果配置了checkpoint callbacks则加载该checkpoint callbacks对应的最佳模型

使用注意

  • preict()中会禁用日志功能

Trainer.tune()

功能解释

  • 对模型超参数进行调整

常用参数

参数名称 含义 默认值
model LightningModule实例
train_dataloaders 训练数据加载器 None
val_dataloaders 验证数据加载器 None
datamodule LightningDataModule实例 None
scale_batch_size_kwargs 传递给scale_batch_size()的参数 None
lr_find_kwargs 传递给lr_find()的参数 None

使用注意

  • auto_lr_find标志当且仅当执行trainer.tune(model)代码时工作

其他注意点

  • .test()若非直接调用,不会运行。
  • .test()会自动load最优模型。
  • model.eval() and torch.no_grad() 在进行测试时会被自动调用。
  • 默认情况下,Trainer()运行于CPU上。

Trainer属性

  • callback_metrics
  • current_epoch
  • logger
  • logged_metrics
  • log_dir
  • is_global_zero
  • progress_bar_metrics

你可能感兴趣的:(Python,pytorch,深度学习,人工智能)