在 Pytorch 中实现 early stopping

写在前面:写该博客是为了记录自己这段时间里的学习收获,重在思路,若有疑问,欢迎评论,若发现该博客有错误,欢迎指出。这两天终于有了空闲时间,所以就多写点。该博客使用了这个 g i t h u b github github 仓库中提供的 p y t o r c h _ t o o l s pytorch\_tools pytorch_tools,而且该仓库中还有这个工具的使用案例,建议读者前往查看。在此感谢作者的代码分享。

实现

有了 p y t o r c h _ t o o l s pytorch\_tools pytorch_tools 工具后,使用 e a r l y   s t o p p i n g early\ stopping early stopping 就很简单了。

先从该工具类中导入 E a r l y S t o p p i n g EarlyStopping EarlyStopping .

# import EarlyStopping
from pytorchtools import EarlyStopping
import torch.utils.data as Data	# 用于创建 DataLoader
import torch.nn as nn

为了方便描述,这里还是会使用一些伪代码,如果你想阅读详细案例的话,不用犹豫直接看上述工具自带的案例代码。

model = yourModel()	# 伪
# 指定损失函数,可以是其他损失函数,根据训练要求决定
criterion = nn.CrossEntropyLoss()	# 交叉熵,注意该损失函数对自动对批量样本的损失取平均
# 指定优化器,可以是其他
optimizer = torch.optim.Adam(model.parameters())
# 初始化 early_stopping 对象
patience = 20	# 当验证集损失在连续20次训练周期中都没有得到降低时,停止模型训练,以防止模型过拟合
early_stopping = EarlyStopping(patience, verbose=True)	# 关于 EarlyStopping 的代码可先看博客后面的内容

batch_size = 64	# 或其他,该参数属于超参,对于如何选择超参,你可以参考下我的上一篇博客
n_epochs = 100	# 可以设置大一些,毕竟你是希望通过 early stopping 来结束模型训练
#----------------------------------------------------------------
# 训练模型,直到 epoch == n_epochs 或者触发 early_stopping 结束训练
for epoch in range(1, n_epochs + 1):

	# 建立训练数据的 DataLoader
	training_dataset = Data.TensorDataset(X_train, y_train)
    # 把dataset放到DataLoader中
    data_loader = Data.DataLoader(
        dataset=training_dataset,
        batch_size=batch_size,	# 批量大小
        shuffle=True	# 是否打乱数据顺序
    )
    #---------------------------------------------------
    model.train()	# 设置模型为训练模式
    # 按小批量训练
	for batch, (data, target) in enumerate(data_loader):
		optimizer.zero_grad()	# 清楚所有参数的梯度
		output = model(data)	# 输出模型预测值
		loss = criterion(output, target)	# 计算损失
		loss.backward()	# 计算损失对于各个参数的梯度
		optimizer.step()	# 执行单步优化操作:更新参数
	#----------------------------------------------------
	model.eval() # 设置模型为评估/测试模式
	# 一般如果验证集不是很大的话,模型验证就不需要按批量进行了,但要注意输入参数的维度不能错
	valid_output = model(X_val)
	valid_loss = criterion(valid_output, y_val)	# 注意这里的输入参数维度要符合要求,我这里为了简单,并未考虑这一点

	early_stopping(valid_loss, model)
	# 若满足 early stopping 要求
	if early_stopping.early_stop:
		print("Early stopping")
		# 结束模型训练
		break
# 获得 early stopping 时的模型参数
model.load_state_dict(torch.load('checkpoint.pt'))	

以下是 p y t o r c h _ t o o l s pytorch\_tools pytorch_tools 工具的代码:

import numpy as np
import torch

class EarlyStopping:
    """Early stops the training if validation loss doesn't improve after a given patience."""
    def __init__(self, patience=7, verbose=False, delta=0):
        """
        Args:
            patience (int): How long to wait after last time validation loss improved.
                            Default: 7
            verbose (bool): If True, prints a message for each validation loss improvement. 
                            Default: False
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
                            Default: 0
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = np.Inf
        self.delta = delta

    def __call__(self, val_loss, model):

        score = -val_loss

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            self.counter += 1
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        '''Saves model when validation loss decrease.'''
        if self.verbose:
            print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        torch.save(model.state_dict(), 'checkpoint.pt')	# 这里会存储迄今最优模型的参数
        self.val_loss_min = val_loss

结束

总结完了,不过这些代码还是需要读者根据自己的模型做出改动。希望这篇博客对你会有所帮助,再一次,欢迎指出错误。

你可能感兴趣的:(神经网络)