pytorch lightning之验证与测试

训练

训练部分已在《入门篇》介绍。

验证集和测试集中评估模型

通常将数据集分为三部分,train/val/test,val集在训练时评估模型的泛化性,选择其中表现最好的checkpoint。test集只在模型训练完成后使用,用于评估模型的真实性能。

添加test流程

划分数据集

以下代码使用torchvision包内实现的MNIST。如果使用自定义的数据集,先用pytorch实现Dataset子类,再继承pl.LightningDataModule类,实现相应接口。

import torch.utils.data as data
from torchvision import datasets
import torchvision.transforms as transforms

# Load data sets
transform = transforms.ToTensor()
train_set = datasets.MNIST(root="MNIST", download=True, train=True, transform=transform)
test_set = datasets.MNIST(root="MNIST", download=True, train=False, transform=transform)

实现test_step()接口

在trainer.test()阶段会自动调用test_step方法,根据需要内部可以增加保存图片、评估模型等功能。

class LitAutoEncoder(pl.LightningModule):
    def training_step(self, batch, batch_idx):
        ...

    def test_step(self, batch, batch_idx):
        # this is the test loop
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        test_loss = F.mse_loss(x_hat, x)
        self.log("test_loss", test_loss)

测试

模型训练完成后,即可调用test()方法进入测试流程

from torch.utils.data import DataLoader

# initialize the Trainer
trainer = Trainer()

# 训练模型
trainer.fit(model, data)

# 训练完成后测试
trainer.test(model, dataloaders=DataLoader(test_set))

验证阶段validation的流程

与test 流程类似,实现validation_step()接口,可以配合on_validation_epoch_end()方法在计算所有样例后评估模型。

class LitAutoEncoder(pl.LightningModule):
    def training_step(self, batch, batch_idx):
        ...

    def validation_step(self, batch, batch_idx):
        # this is the validation loop
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        val_loss = F.mse_loss(x_hat, x)
        self.log("val_loss", val_loss)
        self.metric.update(x_hat, y)# metric是任务相关的评价方法,比如更新混淆矩阵
    
    def on_validation_epoch_step(self, batch, batch_idx):
        
        # 从混淆矩阵中计算tp,fp, tn, fn, acc, F1等指标 
        score = self.metric.get_scores()
        # 记录,横坐标为epoch
        self.log('val/F1', score['F1'], logger=True, on_epoch=True)

预测predict流程

实现predict_step方法,然后调用trainer.predict()

其它HOOK见LightningModule,了解LightningModule的接口基本就会用pl了。

你可能感兴趣的:(pytorch,lightning)