pytorch-lighting使用

参考文档:LightningModule — PyTorch Lightning 1.7.5 documentation (pytorch-lightning.readthedocs.io)

在这里插入图片描述

 在这里插入图片描述

  pytorch_lightning中的训练器Trainer:

import pytorch_lightning as pl
class model(pl.LightningModule):
    pass
model1=model(参数2)
trainer = pl.Trainer(参数1)
trainer.fit(model1)

1.参数1详解:

参数名称 意义 默认值
max-epochs
最多训练轮数
callbacks

添加回调函数或回调函数列表

gpus
使用的gpu数量
accumulate_grad_batches
每k次batches累计一次梯度
logger
设置日志记录器(支持多个)
resume_from_checkpoint
gradient_clip_val
check_val_every_n_epoch
每n个train epoch执行一次验证
num_sanity_val_steps

开始训练前加载n个验证数据进行测试,

k=-1时加载所有验证数据

log_every_n_steps
更新n次网络权重后记录一次日志
flush_logs_every_n_steps
limit_train_batches
使用训练数据的百分比.支持0到1的浮点数和整数,比如0.1代表每个epoch只跑十分之一的数据 支持0到1的浮点数和整数
limit_val_batches
使用验证数据的百分比,10代表每个epoch只跑10个batches 支持0到1的浮点数和整数
limit_test_batches
使用测试数据的百分比. 支持0到1的浮点数和整数

2.lightningmodule方法详解:

也就是定义model类是会定义哪些函数

pytorch-lighting使用_第1张图片

3.trainer.fit参数详解

Trainer.fit(model, train_dataloaders=None, 
    val_dataloaders=None, datamodule=None, 
    ckpt_path=None)

其中model为实例化的pl.LightningModule

你可能感兴趣的:(pytorch,深度学习,python)