首先,代码需要导入实现模型、数据处理和训练所需的各个包:
import torch
from torch import nn
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
import pytorch_lightning as pl
PyTorch Lightning 的核心是继承 pl.LightningModule
,在这里我们通过子类化创建了一个简单的模型类 LitModel
。
__init__
方法def __init__(self):
super().__init__()
self.flatten = nn.Flatten()
self.net = nn.Sequential(
nn.Linear(28*28, 128),
nn.ReLU(),
nn.Linear(128, 10)
)
self.loss_fn = nn.CrossEntropyLoss()
这里使用 nn.CrossEntropyLoss 作为损失函数,适用于多分类问题。forward
方法def forward(self, x):
x = self.flatten(x)
return self.net(x)
training_step
方法def training_step(self, batch, batch_idx):
x, y = batch
pred = self(x)
loss = self.loss_fn(pred, y)
self.log("train_loss", loss) # 自动记录训练损失
return loss
batch
中获取输入 x
和标签 y
。forward
方法计算预测结果 pred
。self.log
自动记录训练损失,便于后续监控和日志分析。validation_step
方法def validation_step(self, batch, batch_idx):
x, y = batch
pred = self(x)
loss = self.loss_fn(pred, y)
self.log("val_loss", loss) # 自动记录验证损失
return loss
test_step
方法def test_step(self, batch, batch_idx):
x, y = batch
pred = self(x)
loss = self.loss_fn(pred, y)
self.log("test_loss", loss) # 自动记录测试损失
return loss
configure_optimizers
方法def configure_optimizers(self):
return torch.optim.Adam(self.parameters(), lr=1e-3)
使用 LightningDataModule 能够使数据预处理、划分和加载更加模块化,便于在多个训练阶段(训练、验证、测试)中复用同一数据处理流程。
__init__
方法def __init__(self, batch_size=32):
super().__init__()
self.batch_size = batch_size
prepare_data
方法def prepare_data(self):
# 下载数据集
MNIST(root="data", train=True, download=True)
MNIST(root="data", train=False, download=True)
setup
方法def setup(self, stage=None):
# 数据预处理和划分
transform = ToTensor()
mnist_full = MNIST(root="data", train=True, transform=transform)
self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])
self.mnist_test = MNIST(root="data", train=False, transform=transform)
ToTensor()
将图像转换为张量。random_split
将训练集划分为 55000 个训练样本和 5000 个验证样本。def train_dataloader(self):
return DataLoader(self.mnist_train, batch_size=self.batch_size, shuffle=True)
def val_dataloader(self):
return DataLoader(self.mnist_val, batch_size=self.batch_size)
def test_dataloader(self):
return DataLoader(self.mnist_test, batch_size=self.batch_size)
shuffle=True
确保数据在每个 epoch 中被打乱。在 if __name__ == "__main__":
块中,完成了模型与数据模块的实例化,并利用 Lightning 提供的 Trainer 完成训练和测试。
dm = MNISTDataModule(batch_size=32)
model = LitModel()
trainer = pl.Trainer(
max_epochs=3, # 训练 3 个 epoch
accelerator="gpu", # 指定使用 GPU
devices=[0],
)
devices=[0]
表示使用第 0 个 GPU(如果有多个 GPU,可根据需求调整)。trainer.fit(model, datamodule=dm)
trainer.test(model, datamodule=dm)
training_step
和 validation_step
,并利用 DataModule 提供的各个 DataLoader。test_step
(如果有定义)或者复用 validation_step
来计算测试指标。通过这份代码,你可以学到以下关键点:
这种清晰分离模型与数据逻辑的设计,不仅使代码结构更清晰,也方便在不同场景下复用和扩展。希望这个教程能帮助你更好地理解 PyTorch Lightning 的使用方法,并在项目中灵活应用这种高效的训练流程。
import torch
from torch import nn
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
import pytorch_lightning as pl
# 1. 定义 LightningModule 子类
class LitModel(pl.LightningModule):
    """A minimal MNIST classifier implemented as a LightningModule.

    Architecture: flatten -> Linear(28*28, 128) -> ReLU -> Linear(128, 10).
    The three step methods share one loss computation via ``_shared_step``.
    """

    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.net = nn.Sequential(
            nn.Linear(28 * 28, 128),
            nn.ReLU(),
            nn.Linear(128, 10),
        )
        # Cross-entropy expects raw logits plus integer class labels,
        # which suits this multi-class (10-digit) problem.
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, x):
        """Return class logits of shape (batch, 10) for a batch of images."""
        x = self.flatten(x)
        return self.net(x)

    def _shared_step(self, batch):
        # Common forward pass + loss, reused by train/val/test steps.
        x, y = batch
        pred = self(x)
        return self.loss_fn(pred, y)

    def training_step(self, batch, batch_idx):
        loss = self._shared_step(batch)
        self.log("train_loss", loss)  # auto-logged for later monitoring
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self._shared_step(batch)
        self.log("val_loss", loss)  # auto-logged validation loss
        return loss

    def test_step(self, batch, batch_idx):
        loss = self._shared_step(batch)
        self.log("test_loss", loss)  # auto-logged test loss
        return loss

    def configure_optimizers(self):
        """Use Adam with the tutorial's fixed learning rate of 1e-3."""
        return torch.optim.Adam(self.parameters(), lr=1e-3)
# 2. 准备数据模块
class MNISTDataModule(pl.LightningDataModule):
    """Encapsulates MNIST download, train/val/test splitting, and DataLoaders.

    Keeping data logic in a LightningDataModule lets the same pipeline be
    reused across the fit/validate/test stages.
    """

    def __init__(self, batch_size=32):
        super().__init__()
        self.batch_size = batch_size

    def prepare_data(self):
        # Download only — Lightning runs this on a single process, so no
        # state should be assigned here.
        MNIST(root="data", train=True, download=True)
        MNIST(root="data", train=False, download=True)

    def setup(self, stage=None):
        # Preprocess and split. Lightning may call setup() once per stage
        # (fit, then test), so the split is seeded to guarantee the SAME
        # 55000/5000 partition every time — otherwise validation samples
        # from one call could leak into training samples of another.
        transform = ToTensor()
        mnist_full = MNIST(root="data", train=True, transform=transform)
        self.mnist_train, self.mnist_val = random_split(
            mnist_full,
            [55000, 5000],
            generator=torch.Generator().manual_seed(42),
        )
        self.mnist_test = MNIST(root="data", train=False, transform=transform)

    def train_dataloader(self):
        # shuffle=True re-shuffles training data every epoch.
        return DataLoader(self.mnist_train, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=self.batch_size)
# 3. 训练模型
if __name__ == "__main__":
    # Instantiate the data module and the model.
    dm = MNISTDataModule(batch_size=32)
    model = LitModel()

    # Create the trainer: 3 epochs; "auto" selects a GPU when one is
    # available and falls back to CPU otherwise, so this tutorial script
    # also runs on machines without a GPU (the original hard-coded
    # accelerator="gpu" crashes on CPU-only hosts).
    trainer = pl.Trainer(
        max_epochs=3,
        accelerator="auto",
        devices="auto",
    )
    trainer.fit(model, datamodule=dm)

    # Evaluate on the held-out test set using the same data module.
    trainer.test(model, datamodule=dm)