伪标签

什么是伪标签

伪标签是将置信度较高的测试数据添加到训练数据中的过程。伪标签一共有5个步骤。

  1. 使用**训练集数据(Train1)**训练一个模型。
  2. 使用训练好的模型预测测试集数据
  3. 将预测的置信度较高的样本加入到训练集中。
  4. 使用新的训练集训练一个新的模型
  5. 使用新的模型去预测测试集数据

1. 建立第一个模型

正常建立模型即可

2. 预测测试集

正常测试即可

3. 增加伪标签数据到训练集

将所有预测的置信度Pr(y=1|x)>0.99Pr(y=0|x)>0.99的加入到训练集中。

4.训练一个新的模型

然后使用新的数据集去训练新的模型。

5. 预测测试数据

预测测试数据,然后提交

为什么伪标签能够起作用

QDA可以更好地理解伪标签的工作原理。QDA的工作原理是利用p维空间中的点来寻找超椭球体。随着点的增多,QDA可以更好地估计每个椭球面的中心和形状(从而更好地进行预测)。
伪标签可以帮助所有类型的模型,因为所有模型都可以可视化为寻找目标=1和目标=0在p维空间的形状。更多的点可以更好地估计形状。

参考链接

I’m overfitting and I know it

你可能感兴趣的:(pytorch,深度学习)