k折(k-fold)交叉验证

k折交叉验证是用来选择模型设计并调节超参数的。


首先,我们先来看一下k折交叉验证的使用背景

我在训练模型的时候,通过已有的数据训练得出一个模型,在这里用于训练的数据称为训练集。当我们把这个模型部署到真实环境去的时候,通过真实的数据得到结果,这个时候这些真实的数据就叫做测试集。

那么问题来了,我们在训练模型的时候怎么样评估模型的好坏呢?怎么样调节模型的超参数呢?

答:一般我们都是将训练用的数据分为训练集和验证集。训练集用来训练数据,验证集测试模型的好坏,据此不断调节模型的超参数。

那么问题又来了,在这里验证集是不参与训练的,而如果我们的数据比较少,想要充分利用训练集里的数据怎么办呢?只是使用固定的一个验证集会不会训练得出来的结果会不会不够公正

答:解决这两个问题的其中一个方法就是k折交叉验证了。


k折交叉验证的训练过程如下:

1. 将数据分成k份,并进行k次训练。每次训练将1份作为验证集,剩下的k-1份作为训练集,k次训练正好每1份都当了一次验证集。

2. 将每次训练的误差做平均得到平均误差。依据均误差来调节模型的超参数

3. 超参数固定好之后,用完整的数据集来重新训练模型

(注意:对于的每一次训练,模型都得要重新进行初始化,而并不是用上一次训练得到模型来进行下一次的训练。)

def get_k_fold_data(k, i, X, y):
'''它返回第i折交叉验证时所需要的训练和验证数据。'''
    assert k > 1
    fold_size = X.shape[0] // k
    X_train, y_train = None, None
    for j in range(k):
        idx = slice(j * fold_size, (j + 1) * fold_size)
        X_part, y_part = X[idx, :], y[idx]
        if j == i:
            X_valid, y_valid = X_part, y_part
        elif X_train is None:
            X_train, y_train = X_part, y_part
        else:
            X_train = nd.concat(X_train, X_part, dim=0)
            y_train = nd.concat(y_train, y_part, dim=0)
    return X_train, y_train, X_valid, y_valid


def k_fold(k, X_train, y_train, num_epochs,
           learning_rate, weight_decay, batch_size):
'''在 ? 折交叉验证中我们训练 ? 次并返回训练和验证的平均误差'''
    train_l_sum, valid_l_sum = 0, 0
    for i in range(k):
        data = get_k_fold_data(k, i, X_train, y_train)
        net = get_net()
        # 这里的*表示传入的参数是一个元组(tuple),对应的**传入的参数是一个字典
        train_ls, valid_ls = train(net, *data, num_epochs, learning_rate,
                                   weight_decay, batch_size)
        train_l_sum += train_ls[-1]
        valid_l_sum += valid_ls[-1]
        if i == 0:
            d2l.semilogy(range(1, num_epochs + 1), train_ls, 'epochs', 'rmse',
                         range(1, num_epochs + 1), valid_ls,
                         ['train', 'valid'])
        print('fold %d, train rmse %f, valid rmse %f'
              % (i, train_ls[-1], valid_ls[-1]))
    return train_l_sum / k, valid_l_sum / k

 

你可能感兴趣的:(学习笔记,深度学习-李牧)