Task04: 模型训练与验证

0. 数据集搭建

  1. 训练集(Train Set):模型用于训练和调整模型参数;
  2. 验证集(Validation Set):用来验证模型精度和调整模型超参数;
  3. 测试集(Test Set):验证模型的泛化能力。

在数据建模比赛中,一般三者都已经分好了,即:训练集、验证集发放数据和标签,而测试集仅发放数据;而如果赛方没有提前划分验证集,则需要参赛人员自行划分,有以下划分方法

  1. 留出法(Hold-Out):直接对数据随机划分成两份,适用于数据量大的情况;
  2. 交叉验证法(Cross Validation,CV):对数据划分为若干折(等分),对每一折进行验证时,其余折用于训练,验证集精度以各份平均表示,适用于数据量一般的情况;
  3. 自助采样法(BootStrap):有放回采样获得训练集和验证集,适用于数据量较少的情况。

数据集划分的原则是:每份数据的标签分布都能代表整体分布

1. 模型训练与验证

主要步骤如下:

  • 构造训练集和验证集;
  • 每轮进行训练和验证,并根据最优验证集精度保存模型。

模型构建、数据集构建过程与上一task一致,这里不赘述。需要注意的是每次交替训练和验证时要切换模型的状态:

def train(train_loader, model, criterion, optimizer, epoch):
    # 切换模型为训练模式
    model.train()

    for i, (input, target) in enumerate(train_loader):
        c0, c1, c2, c3, c4, c5 = model(data[0])
        loss = criterion(c0, data[1][:, 0]) + \
                criterion(c1, data[1][:, 1]) + \
                criterion(c2, data[1][:, 2]) + \
                criterion(c3, data[1][:, 3]) + \
                criterion(c4, data[1][:, 4]) + \
                criterion(c5, data[1][:, 5])
        loss /= 6
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

def validate(val_loader, model, criterion):
    # 切换模型为预测模型
    model.eval()
    val_loss = []

    # 不记录模型梯度信息
    with torch.no_grad():
        for i, (input, target) in enumerate(val_loader):
            c0, c1, c2, c3, c4, c5 = model(data[0])
            loss = criterion(c0, data[1][:, 0]) + \
                    criterion(c1, data[1][:, 1]) + \
                    criterion(c2, data[1][:, 2]) + \
                    criterion(c3, data[1][:, 3]) + \
                    criterion(c4, data[1][:, 4]) + \
                    criterion(c5, data[1][:, 5])
            loss /= 6
            val_loss.append(loss.item())
    return np.mean(val_loss)

最后保存/加载最优模型

torch.save(model_object.state_dict(), 'model.pt')
model.load_state_dict(torch.load(' model.pt'))

2. 调lian参dan trick

  • http://www.lamda.nju.edu.cn/weixs/project/CNNTricks/CNNTricks.html
  • http://karpathy.github.io/2019/04/25/recipe/
一般流程

你可能感兴趣的:(Task04: 模型训练与验证)