k 折交叉验证（代码）

# K-fold cross-validation: shuffle the data, split it into k folds, and for
# each fold train a fresh model on the other k-1 folds and score it on the
# held-out fold. The k validation scores are averaged into one estimate.
# Afterwards a final model is trained on all of `data` and scored on `test_data`.
#
# NOTE(review): `data`, `test_data` and `get_model` are not defined in this
# snippet. They are assumed to come from surrounding code. `get_model()` is
# assumed to return an untrained model with `train(...)` and `evaluate(...)`,
# where `evaluate` returns a number. `data` is assumed to be a list or a
# NumPy array. Confirm against the caller.
import numpy as np

k = 4
num_validation_samples = len(data) // k
if num_validation_samples == 0:
    # With fewer samples than folds, every validation fold would be empty.
    raise ValueError(
        f"need at least k={k} samples for {k}-fold cross-validation, "
        f"got {len(data)}"
    )

# Shuffle in place so each fold is a random subset of the data.
np.random.shuffle(data)

validation_scores = []
for fold in range(k):
    start = num_validation_samples * fold
    # The last fold also takes the len(data) % k leftover samples, so every
    # sample is used for validation exactly once.
    end = len(data) if fold == k - 1 else num_validation_samples * (fold + 1)

    validation_data = data[start:end]
    head, tail = data[:start], data[end:]
    # On NumPy arrays `+` adds element-wise instead of joining the slices,
    # so concatenate explicitly. Plain lists keep list concatenation.
    if isinstance(data, list):
        training_data = head + tail
    else:
        training_data = np.concatenate([head, tail])

    model = get_model()  # fresh, untrained model for every fold
    model.train(training_data)
    validation_score = model.evaluate(validation_data)
    validation_scores.append(validation_score)

# Cross-validation estimate: the mean of the per-fold validation scores.
validation_score = np.average(validation_scores)

# Train the final model on all non-test data, then evaluate it on the test set.
model = get_model()
model.train(data)
test_score = model.evaluate(test_data)

 

你可能感兴趣的：（深度学习）