Introduction: The previous installment showed how to build and train a model with PyTorch Lightning using only a training set. To check that a model generalizes to unseen data, the dataset is usually split into a training set and a test set; the test set is kept independent of the training set and is used to measure generalization. This installment adds validation and test sets for that purpose, and also introduces checkpointing and early stopping to recover the best model weights.
Related link: https://lightning.ai/docs/pytorch/stable/levels/basic_level_2.html
Add the required imports, then load the MNIST dataset to obtain the training and test sets:
import torch.utils.data as data
from torchvision import datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
# load the data (the test set uses train=False)
transform = transforms.ToTensor()
train_set = datasets.MNIST(root="MNIST", download=True, train=True, transform=transform)
test_set = datasets.MNIST(root="MNIST", download=True, train=False, transform=transform)
Implement a test_step method inside the LightningModule; outside it, call trainer.test:
class LitAutoEncoder(pl.LightningModule):
    def training_step(self, batch, batch_idx):
        ...

    def test_step(self, batch, batch_idx):  # the test loop; mirrors training_step
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        test_loss = F.mse_loss(x_hat, x)
        self.log("test_loss", test_loss)

# initialize the Trainer
trainer = Trainer()
# run the test loop
trainer.test(model, dataloaders=DataLoader(test_set))
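In practice the test DataLoader is usually given an explicit batch size and worker count rather than the default batch_size=1; a minimal sketch (the values 256 and 2 are arbitrary choices, not from the original):

test_loader = DataLoader(test_set, batch_size=256, num_workers=2, shuffle=False)
trainer.test(model, dataloaders=test_loader)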
A validation set is usually carved out of the training set with the utilities in torch.utils.data:
# use 20% of the training set as the validation set
train_set_size = int(len(train_set) * 0.8)
valid_set_size = len(train_set) - train_set_size
# split with data.random_split, seeding the generator for reproducibility
seed = torch.Generator().manual_seed(42)
train_set, valid_set = data.random_split(train_set, [train_set_size, valid_set_size], generator=seed)
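A quick sanity check of the resulting split sizes (the expected numbers below assume MNIST's 60,000 training images):

print(len(train_set), len(valid_set))  # expected: 48000 12000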
As with the test set, implement a validation_step method inside the LightningModule; outside it, call fit:
class LitAutoEncoder(pl.LightningModule):
    def training_step(self, batch, batch_idx):
        ...

    def validation_step(self, batch, batch_idx):
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        val_loss = F.mse_loss(x_hat, x)
        self.log("val_loss", val_loss)

    def test_step(self, batch, batch_idx):
        ...

# wrap the training and validation sets with torch.utils.data.DataLoader
train_loader = DataLoader(train_set)
valid_loader = DataLoader(valid_set)

# pass valid_loader (the validation set) to fit
trainer = Trainer()
trainer.fit(model, train_loader, valid_loader)
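If you ever need to run the validation loop on its own (for example on a reloaded model), Lightning also provides a standalone validate method; a minimal sketch:

# runs only the validation loop and returns the logged metrics
trainer.validate(model, dataloaders=valid_loader)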
Checkpoints serve two purposes: they capture the model weights after each epoch, so the best-performing weights can be recovered, and they let training resume from where it left off after an interruption or stop. Unlike a plain PyTorch state_dict, a Lightning checkpoint contains the model's entire internal state; even in the most complex distributed training environments, Lightning saves everything needed to restore the model, including the model weights, current epoch and global step, optimizer and LR scheduler state, callback state, and the hyperparameters passed to the model.
# saving: default_root_dir sets a custom path; if unset, checkpoints are saved automatically under the current working directory
trainer = Trainer(default_root_dir="some/path/")
# loading a checkpoint back into the model
model = MyLightningModule.load_from_checkpoint("/path/to/checkpoint.ckpt")
model.eval() # disable randomness, dropout, etc...
y_hat = model(x)
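By default the last training epoch is checkpointed. To keep the best-performing weights instead, one option is to attach a ModelCheckpoint callback that watches the logged validation loss; a minimal sketch (the monitor key and path come from the examples above):

from lightning.pytorch.callbacks import ModelCheckpoint

# keep only the checkpoint with the lowest logged "val_loss"
checkpoint_callback = ModelCheckpoint(monitor="val_loss", mode="min", save_top_k=1)
trainer = Trainer(default_root_dir="some/path/", callbacks=[checkpoint_callback])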
A checkpoint file can also be loaded directly with plain torch methods:
checkpoint = torch.load("/path/to/checkpoint.ckpt", map_location=lambda storage, loc: storage)
print(checkpoint["hyper_parameters"])
# {"learning_rate": the_value, "another_parameter": the_other_value}
Loading also reproduces the model as it was saved; for a model created as LitModel(in_dim=32, out_dim=10), the stored init arguments are used by default and can be overridden at load time:
# uses the saved in_dim=32, out_dim=10
model = LitModel.load_from_checkpoint(PATH)
# overrides the saved values with in_dim=128, out_dim=10
model = LitModel.load_from_checkpoint(PATH, in_dim=128, out_dim=10)
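Note that load_from_checkpoint can only reproduce the init arguments if they were stored in the checkpoint, which is what calling self.save_hyperparameters() in __init__ does; a minimal, illustrative sketch of such a model:

class LitModel(L.LightningModule):
    def __init__(self, in_dim=32, out_dim=10):
        super().__init__()
        # records in_dim/out_dim under checkpoint["hyper_parameters"]
        self.save_hyperparameters()
        self.layer = nn.Linear(in_dim, out_dim)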
Lightning checkpoints are fully compatible with plain PyTorch:
checkpoint = torch.load(CKPT_PATH)
# the weights live under the "state_dict" key, prefixed by submodule name
encoder_weights = {k: v for k, v in checkpoint["state_dict"].items() if k.startswith("encoder.")}
decoder_weights = {k: v for k, v in checkpoint["state_dict"].items() if k.startswith("decoder.")}
Checkpointing can also be disabled entirely:
trainer = Trainer(enable_checkpointing=False)
To resume training in full from a checkpoint:
model = LitModel()
trainer = Trainer()
# automatically restores everything relevant: model, epoch, step, LR schedulers, etc.
trainer.fit(model, ckpt_path="some/path/to/my_checkpoint.ckpt")
In Lightning, early stopping takes the following steps:
1. Import the EarlyStopping callback.
2. Log the metric you want to monitor using the log() method.
3. Initialize the callback and set its monitor to the logged metric of your choice.
4. Set mode according to whether the monitored metric should be minimized or maximized.
5. Pass the EarlyStopping callback to the Trainer's callbacks flag.

from lightning.pytorch.callbacks.early_stopping import EarlyStopping
class LitModel(LightningModule):
    def validation_step(self, batch, batch_idx):
        loss = ...
        self.log("val_loss", loss)

model = LitModel()
trainer = Trainer(callbacks=[EarlyStopping(monitor="val_loss", mode="min")])
trainer.fit(model)
# a customized EarlyStopping policy can also be used
early_stop_callback = EarlyStopping(monitor="val_accuracy", min_delta=0.00, patience=3, verbose=False, mode="max")
trainer = Trainer(callbacks=[early_stop_callback])
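Whichever key monitor names must actually be logged somewhere, or the callback has nothing to watch; for the val_accuracy example above, the validation step would need to log it, roughly like this (the accuracy computation is an illustrative assumption for a classifier):

def validation_step(self, batch, batch_idx):
    x, y = batch
    logits = self(x)
    # fraction of correct predictions in this batch
    acc = (logits.argmax(dim=-1) == y).float().mean()
    self.log("val_accuracy", acc)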
# EarlyStopping documentation: https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.EarlyStopping.html#lightning.pytorch.callbacks.EarlyStopping
By default, the EarlyStopping callback runs at the end of every validation epoch; how often validation itself runs can be adjusted with the Trainer arguments check_val_every_n_epoch and val_check_interval.
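Both are ordinary Trainer arguments; for example (the values here are arbitrary):

# run validation only every 2 training epochs
trainer = Trainer(check_val_every_n_epoch=2)
# or run validation 4 times per training epoch
trainer = Trainer(val_check_interval=0.25)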
Putting it all together, the complete runnable script:
# coding:utf-8
import torch
import torch.nn as nn
import torch.utils.data as data
from torchvision import datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch.nn.functional as F
import lightning as L
# --------------------------------
# Step 1: Define the model
# --------------------------------
class LitAutoEncoder(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(28 * 28, 128), nn.ReLU(), nn.Linear(128, 3))
        self.decoder = nn.Sequential(nn.Linear(3, 128), nn.ReLU(), nn.Linear(128, 28 * 28))

    def training_step(self, batch, batch_idx):
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = F.mse_loss(x_hat, x)
        self.log("train_loss", loss)
        return loss

    def test_step(self, batch, batch_idx):  # the test loop; mirrors training_step
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        test_loss = F.mse_loss(x_hat, x)
        self.log("test_loss", test_loss)

    def validation_step(self, batch, batch_idx):
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        val_loss = F.mse_loss(x_hat, x)
        self.log("val_loss", val_loss)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

    def forward(self, x):
        # forward defines a single prediction/inference step
        embedding = self.encoder(x)
        return embedding
# --------------------------------
# Step 2: Load the data and model
# --------------------------------
transform = transforms.ToTensor()
train_set = datasets.MNIST(root="MNIST", download=True, train=True, transform=transform)
test_set = datasets.MNIST(root="MNIST", download=True, train=False, transform=transform)
# use 20% of the training set as the validation set
train_set_size = int(len(train_set) * 0.8)
valid_set_size = len(train_set) - train_set_size
# split with data.random_split, seeding the generator for reproducibility
seed = torch.Generator().manual_seed(42)
train_set, valid_set = data.random_split(train_set, [train_set_size, valid_set_size], generator=seed)
train_loader = DataLoader(train_set)
valid_loader = DataLoader(valid_set)
autoencoder = LitAutoEncoder()
# --------------------------------
# Step 3: Train + validate + test
# --------------------------------
# train + validate
trainer = L.Trainer(default_root_dir="some/path/")  # custom checkpoint save path
trainer.fit(autoencoder, train_loader, valid_loader)
# test
trainer.test(autoencoder, dataloaders=DataLoader(test_set))
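After fit returns, the path of the saved checkpoint can be read back from the trainer's checkpoint callback and reloaded; a minimal sketch (with the default callback this is simply the most recent checkpoint unless a monitor is configured, as in the ModelCheckpoint example earlier):

# reload the checkpoint recorded by the trainer's ModelCheckpoint callback
best_path = trainer.checkpoint_callback.best_model_path
autoencoder = LitAutoEncoder.load_from_checkpoint(best_path)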