Using Early Stopping in PyTorch

Table of Contents

  • Purpose and Workflow of Early Stopping
  • Early Stopping Strategies
  • PyTorch Usage Example
  • References

Purpose and Workflow of Early Stopping

Purpose: prevent the model from overfitting. Because a deep learning model can keep iterating indefinitely, we want to stop training at the point where it is about to overfit, or where further training brings negligible improvement.

The workflow is as follows:

  1. Split the dataset into three parts: training data (the largest share), validation data (the smallest share; around 10%-20% is usually enough), and test data (the second largest share)
  2. Run the model on the training set to obtain the training loss $Loss_{train}$
  3. Then run the model on the validation set; this is evaluation rather than training, so no backpropagation is needed, and it yields the validation loss $Loss_{valid}$
  4. The early stopping strategy uses $Loss_{train}$ and $Loss_{valid}$ to decide whether training should be interrupted (the full PyTorch example at the end of this article implements this workflow)

Early Stopping Strategies

Early stopping strategies are all stated in terms of the training set and the validation set:

  1. Commonly used strategies:

    ♣ If the training loss and validation loss show no meaningful decrease for several consecutive epochs, stop early
    ♣ If the validation loss rises instead of falling for n consecutive epochs (typically n = 3), stop early

  2. Thresholding on the generalization loss

    ♣ Record the minimum validation loss seen so far, and compare the current validation loss against that minimum
    ♣ Compute the generalization loss with the formula $GL(t) = 100 \cdot \left( \frac{E_{va}(t)}{E_{opt}(t)} - 1 \right)$, where $E_{va}(t)$ is the validation loss at epoch $t$ and $E_{opt}(t)$ is the minimum validation loss observed up to epoch $t$
    ♣ Stop training once this generalization loss exceeds a threshold

  3. Thresholding on the quotient of generalization loss and training progress: we usually assume that overfitting sets in once the training loss can barely decrease any further, and forcing it down at that point risks overfitting. Therefore:

    ♣ Fix a window of k training epochs and measure how much the average training loss over the window exceeds the minimum training loss within it
    (formula: $P_k(t) = 1000 \cdot \left( \frac{\sum_{t'=t-k+1}^{t} E_{tr}(t')}{k \cdot \min_{t'=t-k+1}^{t} E_{tr}(t')} - 1 \right)$, where $E_{tr}(t')$ is the training loss at epoch $t'$)
    ♣ Then combine this with the generalization loss above and compute the quotient $\frac{GL(t)}{P_k(t)}$
    ♣ Stop training when this quotient exceeds a threshold (all three strategies are sketched in code right after this list)
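
The three criteria above translate directly into code. The following is a minimal sketch in plain Python, not the implementation used later in this article: the function names are mine, and the example loss values and any thresholds are placeholders you would tune.

def should_stop(valid_losses, patience=3):
    """Strategy 1: stop once the validation loss has failed to improve
    for `patience` consecutive epochs relative to the best value so far."""
    best, counter = float('inf'), 0
    for loss in valid_losses:
        if loss < best:
            best, counter = loss, 0
        else:
            counter += 1
            if counter >= patience:
                return True
    return False

def generalization_loss(valid_losses):
    """Strategy 2: GL(t) = 100 * (E_va(t) / E_opt(t) - 1)."""
    e_va = valid_losses[-1]    # current validation loss E_va(t)
    e_opt = min(valid_losses)  # minimum validation loss seen so far, E_opt(t)
    return 100.0 * (e_va / e_opt - 1.0)

def training_progress(train_losses, k=5):
    """Strategy 3: P_k(t) = 1000 * (sum of the last k training losses
    / (k * min of the last k training losses) - 1)."""
    window = train_losses[-k:]
    return 1000.0 * (sum(window) / (len(window) * min(window)) - 1.0)

# Example: the validation loss starts rising while the training loss stalls
valid = [0.90, 0.70, 0.60, 0.61, 0.62, 0.63]
train = [0.80, 0.60, 0.50, 0.49, 0.49, 0.49]
print(should_stop(valid, patience=3))                         # True: 3 rises in a row
print(generalization_loss(valid))                             # ≈ 5.0
print(generalization_loss(valid) / training_progress(train))  # quotient GL(t)/P_k(t)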

PyTorch Usage Example

We follow the early stopping strategy from the project https://github.com/Bjarten/early-stopping-pytorch

The EarlyStopping class is defined in: https://github.com/Bjarten/early-stopping-pytorch/blob/master/pytorchtools.py

A complete deep learning example that uses it follows:

import torch
import torch.nn as nn
import os
from sklearn.datasets import make_regression
from torch.utils.data import Dataset, DataLoader
import numpy as np


class EarlyStopping:  # utility class from the repo above; feel free to move it into its own module
    """Early stops the training if validation loss doesn't improve after a given patience."""

    def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt', trace_func=print):
        """
        Args:
            patience (int): How long to wait after last time validation loss improved.
                            Default: 7
            verbose (bool): If True, prints a message for each validation loss improvement.
                            Default: False
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
                            Default: 0
            path (str): Path for the checkpoint to be saved to.
                            Default: 'checkpoint.pt'
            trace_func (function): trace print function.
                            Default: print
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = np.inf  # np.Inf was removed in NumPy 2.0; np.inf works everywhere
        self.delta = delta
        self.path = path
        self.trace_func = trace_func

    def __call__(self, val_loss, model):

        score = -val_loss

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            self.counter += 1
            self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        '''Saves the model when the validation loss decreases.'''
        if self.verbose:
            self.trace_func(
                f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss


class MyDataSet(Dataset):  # wraps the arrays as a PyTorch Dataset
    def __init__(self, train_x, train_y, sample):
        self.train_x = train_x
        self.train_y = train_y
        self._len = sample

    def __getitem__(self, item: int):
        return self.train_x[item], self.train_y[item]

    def __len__(self):
        return self._len


def get_data():
    """构造数据"""
    sample = 20000
    data_x, data_y = make_regression(n_samples=sample, n_features=100)  # 生成数据集
    train_data_x = data_x[:int(sample * 0.8)]
    train_data_y = data_y[:int(sample * 0.8)]
    valid_data_x = data_x[int(sample * 0.8):]
    valid_data_y = data_y[int(sample * 0.8):]
    train_loader = DataLoader(MyDataSet(train_data_x, train_data_y, len(train_data_x)), batch_size=10)
    valid_loader = DataLoader(MyDataSet(valid_data_x, valid_data_y, len(valid_data_x)), batch_size=10)
    return train_loader, valid_loader


class LinearRegressionModel(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(LinearRegressionModel, self).__init__()
        self.linear = nn.Linear(input_dim, output_dim)  # number of input features, number of output features

    def forward(self, x):
        out = self.linear(x)
        return out


def main():
    train_loader, valid_loader = get_data()
    model = LinearRegressionModel(input_dim=100, output_dim=1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
    criterion = nn.MSELoss()
    early_stopping = EarlyStopping(patience=4, verbose=True)  # early stopping helper

    # start training the model
    for epoch in range(1000):
        # normal training pass
        print("Epoch {}".format(epoch))
        model.train()
        train_loss_list = []
        for train_x, train_y in train_loader:
            optimizer.zero_grad()
            outputs = model(train_x.float())
            loss = criterion(outputs.flatten(), train_y.float())
            loss.backward()
            train_loss_list.append(loss.item())
            optimizer.step()
        print("训练loss:{}".format(np.average(train_loss_list)))
        # 早停策略判断
        model.eval()
        with torch.no_grad():
            valid_loss_list = []
            for valid_x, valid_y in valid_loader:
                outputs = model(valid_x.float())
                loss = criterion(outputs.flatten(), valid_y.float())
                valid_loss_list.append(loss.item())
            avg_valid_loss = np.average(valid_loss_list)
            print("验证集loss:{}".format(avg_valid_loss))
            early_stopping(avg_valid_loss, model)
            if early_stopping.early_stop:
                print("此时早停!")
                break
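
    # After the loop ends (early stop or all 1000 epochs), reload the weights
    # from the best epoch, which EarlyStopping saved to 'checkpoint.pt' (its default path)
    model.load_state_dict(torch.load('checkpoint.pt'))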


if __name__ == '__main__':
    main()
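
With patience=4 and the default delta=0, this run stops after four consecutive epochs in which the validation loss fails to reach a new minimum; along the way, EarlyStopping keeps writing the best weights to checkpoint.pt, which is why main() reloads that file once the loop ends.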

References

  • Early Stopping in Deep Learning (深度学习技巧之Early Stopping(早停法)): https://www.datalearner.com/blog/1051537860479157
  • early-stopping-pytorch:https://github.com/Bjarten/early-stopping-pytorch
