PyTorch Lightning教程二:验证、测试、checkpoint、早停策略

介绍:上一期介绍了如何利用PyTorch Lightning搭建并训练一个模型(仅使用训练集),为了保证模型可以泛化到未见过的数据上,数据集通常被分为训练和测试两个集合,测试集与训练集相互独立,用以测试模型的泛化能力。本期通过增加验证和测试集来达到该目的,同时,还引入checkpoint和早停策略,以得到模型最佳权重。

相关链接:https://lightning.ai/docs/pytorch/stable/levels/basic_level_2.html

训练集、验证集、测试集的使用

1.添加依赖,获取训练集和测试集

添加相应的依赖,同时使用MNIST数据集,获取训练和测试集

import torch.utils.data as data
from torchvision import datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# 加载数据(测试集,train=False)
transform = transforms.ToTensor()
train_set = datasets.MNIST(root="MNIST", download=True, train=True, transform=transform)
test_set = datasets.MNIST(root="MNIST", download=True, train=False, transform=transform)

2.实现并调用test_step

在定义LightningModule中,实现test_step方法;在外部,调用test方法

class LitAutoEncoder(pl.LightningModule):
    def training_step(self, batch, batch_idx):
        ...

    def test_step(self, batch, batch_idx): # 测试,该方法与training_step相似
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        test_loss = F.mse_loss(x_hat, x)
        self.log("test_loss", test_loss)

# 初始化Trainer
trainer = Trainer()

# 执行test方法
trainer.test(model, dataloaders=DataLoader(test_set))

3.实现并调用验证集

通常使用torch.utils.data中的方法,将训练集中的一部分数据化为验证集

# 训练集中的20%数据划为验证集
train_set_size = int(len(train_set) * 0.8)
valid_set_size = len(train_set) - train_set_size

# 拆分,使用data.random_split方法
seed = torch.Generator().manual_seed(42)
train_set, valid_set = data.random_split(train_set, [train_set_size, valid_set_size], generator=seed)

与测试集一样,需要在定义LightningModule中,实现validation_step方法;在外部,调用fit方法

class LitAutoEncoder(pl.LightningModule):
    def training_step(self, batch, batch_idx):
        ...

    def validation_step(self, batch, batch_idx):
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        val_loss = F.mse_loss(x_hat, x)
        self.log("val_loss", val_loss)
    
    def test_step(self, batch, batch_idx):
        ...
# 调用torch.utils.data中的DataLoader对训练和测试集进行封装
train_loader = DataLoader(train_set)
valid_loader = DataLoader(valid_set)

# 在fit方法中,引入valid_loader,即验证集
trainer = Trainer()
trainer.fit(model, train_loader, valid_loader)

checkpoint

checkpoint有两个作用,一是能得到每一次epoch后的模型权重,能得到最佳表现的权重;二是能够在中断或停止后,继续在当前checkpoint处,继续训练。在Lightning中的checkpoint,包含模型的整个内部状态,这与普通的PyTorch不同,即使在最复杂的分布式训练环境中,Lightning也可以保存恢复模型所需的一切。包含以下状态:

  • 16-bit scaling factor (若使用16精度训练)
  • Current epoch
  • Global step
  • LightningModule’s state_dict
  • State of all optimizers
  • State of all learning rate schedulers
  • State of all callbacks (for stateful callbacks)
  • State of datamodule (for stateful datamodules)
  • The hyperparameters (init arguments) with which the model was created
  • The hyperparameters (init arguments) with which the datamodule was created
  • State of Loops

保存与调用方法

# 保存方法,可自定义default_root_dir路径,若不设置路径,将会自动保存
trainer = Trainer(default_root_dir="some/path/")

# 调用方法
model = MyLightningModule.load_from_checkpoint("/path/to/checkpoint.ckpt")
model.eval()	# disable randomness, dropout, etc...
y_hat = model(x)

调用,还可以使用torch的方法

checkpoint = torch.load(checkpoint, map_location=lambda storage, loc: storage)
print(checkpoint["hyper_parameters"])
# {"learning_rate": the_value, "another_parameter": the_other_value}

也可以实现重现,例如模型LitModel(in_dim=32, out_dim=10)

# 使用 in_dim=32, out_dim=10
model = LitModel.load_from_checkpoint(PATH)
# 使用 in_dim=128, out_dim=10
model = LitModel.load_from_checkpoint(PATH, in_dim=128, out_dim=10)

Lightning和PyTorch完全兼容

checkpoint = torch.load(CKPT_PATH)
encoder_weights = checkpoint["encoder"]
decoder_weights = checkpoint["decoder"]

设置checkpoint不可见

trainer = Trainer(enable_checkpointing=False)

如果想全部重新恢复

model = LitModel()
trainer = Trainer()

自动恢复所有相关参数 model, epoch, step, LR schedulers, etc…

trainer.fit(model, ckpt_path="some/path/to/my_checkpoint.ckpt")

早停策略

EarlyStopping Callback

在Lightning中,早停回调步骤如下:

  • Import EarlyStopping callback. 载入EarlyStopping回调方法
  • Log the metric you want to monitor using log() method. 加载日志方法
  • Init the callback, and set monitor to the logged metric of your choice. 设置monitor
  • Set the mode based on the metric needs to be monitored. 设置mode
  • Pass the EarlyStopping callback to the Trainer callbacks flag. 调入EarlyStropping
from lightning.pytorch.callbacks.early_stopping import EarlyStopping

class LitModel(LightningModule):
    def validation_step(self, batch, batch_idx):
        loss = ...
        self.log("val_loss", loss)

model = LitModel()
trainer = Trainer(callbacks=[EarlyStopping(monitor="val_loss", mode="min")])
trainer.fit(model)

# 也可以使用自定义的EarlyStopping策略
early_stop_callback = EarlyStopping(monitor="val_accuracy", min_delta=0.00, patience=3, verbose=False, mode="max")
trainer = Trainer(callbacks=[early_stop_callback])
# EarlyStopping的文档链接https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.EarlyStopping.html#lightning.pytorch.callbacks.EarlyStopping

注意

  • EarlyStopping默认在一次Validation后调用,但是Validation可以自定义多少次epoch后进行一次验证,例如check_val_every_n_epoch and val_check_interval

完整代码

# coding:utf-8
import torch
import torch.nn as nn
import torch.utils.data as data
from torchvision import datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch.nn.functional as F
import lightning as L

# --------------------------------
# Step 1: 定义模型
# --------------------------------
class LitAutoEncoder(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(28 * 28, 128), nn.ReLU(), nn.Linear(128, 3))
        self.decoder = nn.Sequential(nn.Linear(3, 128), nn.ReLU(), nn.Linear(128, 28 * 28))

    def training_step(self, batch, batch_idx):
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = F.mse_loss(x_hat, x)
        self.log("train_loss", loss)
        return loss

    def test_step(self, batch, batch_idx):  # 测试,该方法与training_step相似
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        test_loss = F.mse_loss(x_hat, x)
        self.log("test_loss", test_loss)

    def validation_step(self, batch, batch_idx):
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        val_loss = F.mse_loss(x_hat, x)
        self.log("val_loss", val_loss)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

    def forward(self, x):
        # forward 定义了一次 预测/推理 行为
        embedding = self.encoder(x)
        return embedding
# --------------------------------
# Step 2: 加载数据+模型
# --------------------------------
transform = transforms.ToTensor()
train_set = datasets.MNIST(root="MNIST", download=True, train=True, transform=transform)
test_set = datasets.MNIST(root="MNIST", download=True, train=False, transform=transform)

# 训练集中的20%数据划为验证集
train_set_size = int(len(train_set) * 0.8)
valid_set_size = len(train_set) - train_set_size

# 拆分,使用data.random_split方法
seed = torch.Generator().manual_seed(42)
train_set, valid_set = data.random_split(train_set, [train_set_size, valid_set_size], generator=seed)
train_loader = DataLoader(train_set)
valid_loader = DataLoader(valid_set)

autoencoder = LitAutoEncoder()
# --------------------------------
# Step 3: 训练+验证+测试
# --------------------------------
# 训练+验证
trainer = L.Trainer(default_root_dir="some/path/")	# 这里自定义需要保存的路径
trainer.fit(autoencoder, train_loader, valid_loader)

# 测试
trainer.test(autoencoder, dataloaders=DataLoader(test_set))

你可能感兴趣的:(pytorch,pytorch,人工智能,python)