并不是拟合效果好就是最佳的模型,必须要能够被验证。当我们需要预测某个模型的真实“预测”效果时,我们需要具体的评价模型的预测效果,即使在训练集中有比较好的预测效果,在验证集中也不见得理想,这里涉及到模型的交叉验证和变量的相对重要性分析。
交叉验证:将一定比例的数据挑选出来作为训练样本,另外的样本保留,先在训练样本上获取回归方程,再到保留样本上预测。
一:简单的交叉验证:由于测试集和训练集是分开的,可以避免过拟合的现象,随机性太大,说服力不强
1、从全部的训练数据 S中随机选择 中随机选择 s的样例作为训练集 train,剩余的作为测试集 作为测试集 test。
2、通过对测试集训练 ,得到假设函数或者模型 。
3、在测试集对每一个样本根据假设函数或者模型,得到训练集的类标,求出分类正确率。
4、选择具有最大分类率的模型或者假设。
二:k折交叉验证 k-fold cross validation
K折交叉验证法(k-fold CV)是将观测集随机地分为k个大小基本一致的组,或者说折(fold),第一折作为验证集,然后在剩下的k-1个折上拟合模型,均方误差MSE1(响应变量Y为定性变量时则为错误率)由保留折的观测计算得出。重复这个步骤k次,每一次把不同折作为验证集,整个过程会得到k个测试误差的估计MSE1,MSE2,…, MSEk。K折CV估计由这些值求平均计算得到。
最常见的是k=5或k=10的情形,当k=n时便是我们说的留一交叉验证法(leave-one-out cross validation, LOOCV),也即LOOCV是k折CV的特例。
1、 将全部训练集 S分成 k个不相交的子集,假设 S中的训练样例个数为 m,那么每一个子 集有 m/k 个训练样例,,相应的子集称作 {s1,s2,…,sk}。
2、每次从分好的子集中里面,拿出一个作为测试集,其它k-1个作为训练集
3、根据训练训练出模型或者假设函数。
4、 把这个模型放到测试集上,得到分类率。
5、计算k次求得的分类率的平均值,作为该模型或者假设函数的真实分类率。
这个方法充分利用了所有样本。但计算比较繁琐,需要训练k次,测试k次。
model = SVHN_Model1()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), 0.001)
best_loss = 1000.0
use_cuda = False
if use_cuda:
model = model.cuda()
for epoch in range(2):
train_loss = train(train_loader, model, criterion, optimizer, epoch)
val_loss = validate(val_loader, model, criterion)
val_label = [''.join(map(str, x)) for x in val_loader.dataset.img_label]
val_predict_label = predict(val_loader, model, 1)
val_predict_label = np.vstack([
val_predict_label[:, :11].argmax(1),
val_predict_label[:, 11:22].argmax(1),
val_predict_label[:, 22:33].argmax(1),
val_predict_label[:, 33:44].argmax(1),
val_predict_label[:, 44:55].argmax(1),
]).T
val_label_pred = []
for x in val_predict_label:
val_label_pred.append(''.join(map(str, x[x!=10])))
val_char_acc = np.mean(np.array(val_label_pred) == np.array(val_label))
print('Epoch: {0}, Train loss: {1} \t Val loss: {2}'.format(epoch, train_loss, val_loss))
print('Val Acc', val_char_acc)
# 记录下验证集精度
if val_loss < best_loss:
best_loss = val_loss
# print('Find better model in Epoch {0}, saving model.'.format(epoch))
torch.save(model.state_dict(), './model.pt')