在训练深度神经网络时,如果训练时间较长,我们通常希望在训练过程中定期保存模型的参数,以便稍后从该点恢复训练或进行推理。PyTorch Lightning 提供了 ModelCheckpoint
回调函数来帮助我们自动保存模型参数。
在本文中,我们将探讨如何使用 PyTorch Lightning 训练模型并使用 ModelCheckpoint
自动从训练过程中保存模型的参数。
pytorch lightning 提供了保存 checkpoint API https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.ModelCheckpoint.html#lightning.pytorch.callbacks.ModelCheckpoint
利用 **every_n_train_steps 、train_time_interval 、every_n_epochs **设置保存 checkpoint 的按照步数、时间、epoch数来保存 checkpoints,注意三者互斥,如果要同时实现对应的功能需要创建多个 MODELCHECKPOINT
利用 save_top_k
设置保存所有的模型,因为不是通过 monitor 的形式保存的,所以 save_top_k
只能设置为 -1,0,1,分别表示保存所有的模型,不保存模型和保存最后一个模型
首先,我们需要准备数据集和模型。这里我们使用 PyTorch 的 FashionMNIST
数据集和一个简单的卷积神经网络作为模型。
import torch.nn as nn
import torch.optim as optim
from torchvision.datasets import FashionMNIST
from torchvision import transforms
from torch.utils.data import DataLoader
import pytorch_lightning as pl
class Net(pl.LightningModule):
def __init__(self, num_classes=10):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
self.fc1 = nn.Linear(64 * 7 * 7, 128)
self.fc2 = nn.Linear(128, num_classes)
def forward(self, x):
x = self.pool(nn.functional.relu(self.conv1(x)))
x = self.pool(nn.functional.relu(self.conv2(x)))
x = x.view(-1, 64 * 7 * 7)
x = nn.functional.relu(self.fc1(x))
x = self.fc2(x)
return x
def cross_entropy_loss(self, logits, labels):
return nn.functional.cross_entropy(logits, labels)
def training_step(self, batch, batch_idx):
x, y = batch
logits = self(x)
loss = self.cross_entropy_loss(logits, y)
self.log('train_loss', loss, on_step=True, on_epoch=True, logger=True)
return loss
def validation_step(self, batch, batch_idx):
x, y = batch
logits = self(x)
loss = self.cross_entropy_loss(logits, y)
self.log('val_loss', loss, logger=True)
return loss
def configure_optimizers(self):
optimizer = optim.Adam(self.parameters(), lr=0.01)
return optimizer
# 准备数据集
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,)),
])
train_set = FashionMNIST(".", train=True, transform=transform, download=True)
train_loader = DataLoader(train_set, batch_size=64, shuffle=True, num_workers=2)
val_set = FashionMNIST(".", train=False, transform=transform, download=True)
val_loader = DataLoader(val_set, batch_size=64, shuffle=False, num_workers=2)
接下来,我们使用 Trainer
类来训练模型。
# 训练模型
trainer = pl.Trainer(
gpus=1,
max_epochs=5,
callbacks=[pl.callbacks.ModelCheckpoint(every_n_train_steps=60, save_top_k=-1)]
)
model = Net(num_classes=10)
trainer.fit(model, train_loader, val_loader)
上面的代码将训练模型 5 个 epoch,并在每训练 60 步(batch)时保存一个 checkpoint。ModelCheckpoint
回调函数的 save_top_k
参数为 -1,表示保存所有 checkpoint。
当我们需要从保存的 checkpoint 恢复模型时,可以使用 Trainer
类的 resume_from_checkpoint
参数:
trainer = pl.Trainer(
gpus=1,
max_epochs=5,
callbacks=[pl.callbacks.ModelCheckpoint(every_n_train_steps=60, save_top_k=-1)],
resume_from_checkpoint='path/to/checkpoint.ckpt'
)
model = Net(num_classes=10)
trainer.fit(model, train_loader, val_loader)