两种常用半监督学习领域的方法组合:一致性正则和伪标签(更准确地是artificial labels)
FixMatch方法:
首先利用模型在未标注数据的弱增强版本上生成一个伪标签;(只保留模型可以高置信度预测的伪标签);
然后训练该模型,以预测同一图像输入时的强增强版本时的伪标签
Contribution:
个人觉得其有效的两个原因:
主要创新性来自于一致性正则和伪标签两种成分的结合,以及在执行一致性正则化时分别使用弱和强增强。
监督损失 ℓ s \ell_{s} ℓs和无监督损失 ℓ u \ell_{u} ℓu都是CE损失:
ℓ s = 1 B ∑ b = 1 B H ( p b , p m ( y ∣ α ( x b ) ) ) \ell_{s}=\frac{1}{B} \sum_{b=1}^{B} \mathrm{H}\left(p_{b}, p_{\mathrm{m}}\left(y \mid \alpha\left(x_{b}\right)\right)\right) ℓs=B1b=1∑BH(pb,pm(y∣α(xb)))
q b = p m ( y ∣ α ( u b ) ) q ^ b = arg max ( q b ) ℓ u = 1 μ B ∑ k = 1 μ B 1 ( max ( q b ) ≥ τ ) H ( q ^ b , p m ( y ∣ A ( u b ) ) ) q_{b}=p_{\mathrm{m}}\left(y \mid \alpha(u_{b})\right) \\ \hat{q}_{b}=\arg \max \left(q_{b}\right) \\ \ell_{u}=\frac{1}{\mu B} \sum_{k=1}^{\mu B} 1\left(\max \left(q_{b}\right) \geq \tau\right) \mathrm{H}\left(\hat{q}_{b}, p_{\mathrm{m}}\left(y \mid \mathcal{A}\left(u_{b}\right)\right)\right) qb=pm(y∣α(ub))q^b=argmax(qb)ℓu=μB1k=1∑μB1(max(qb)≥τ)H(q^b,pm(y∣A(ub)))
和伪标签《Pseudo-label: The simple and efficient semi-supervised learning method for deep neural networks》一文最大的不同是artificial label是在弱增强图片上得到的,且l计算的loss是强增强图片的模型预测和伪标签之间的。
Note:实操中,构建 U \mathcal{U} U时,将标记数据也作为未标记数据一部分。
在FixMatch中,不需要如一般SSL算法中逐步增加无监督损失的权重( λ u \lambda_{u} λu),猜测是由于置信度 τ \tau τ的存在使得训练早期,少有无标注数据置信度高被作为伪标签;而随着训练进行,模型预测更具置信度,伪标签的出现更频繁。即本文认为伪标记中使用的阈值具有权重调节相似的作用。
Weak augmentation: standard flip-and-shift augmentation
Strong augmentation:(源自于AutoAugment变种,因其不适用于SSL设置)
RandAugment:有单独的一篇,以及UDA中也是这么做的
随机在一堆变换中为mini-batch中每个样本选取 N N N种变换形式,对于失真度 M M M,不采用全局量,在每个训练步骤(而不是使用一个固定的全局值)从预定义的范围内随机采样一个数量级,对于半监督训练来说效果更好,这与UDA中的用法相似。
CTAugment:ReMixMatch采用,使用控制理论中的思想来消除对AutoAugment中增强学习的需求
有一组18种可能的变换,变换的幅度值被划分为bin,每个bin被分配一个权重。最初,所有bin的权重均为1。现在从该集合中以相等的概率随机选择两个变换,形成变换序列,这类似于RandAugment。对于每个变换,根据归一化的bin权重随机选择一个幅值bin。有标签样本通过这两个转换得到了增强,并传递给模型以进行预测根据模型预测值与实际标签的接近程度,更新这些变换的bin权重。因此,它学会选择具有较高的机会来预测正确的标签的模型,从而在网络容差范围内进行增强。
因此,我们看到,与RandAugment不同,CTAugment可以在训练过程中动态学习每个变换的幅度。因此,我们无需在某些受监督的代理任务上对其进行优化,并且它没有敏感的超参数可优化。因此,这非常适合缺少标签数据的半监督环境。
使用了简单的权重衰减正则化,由于和Adam优化器一起用有问题,所以使用了带动量的标准sgd
使用cosine learning rate decay
使用模型参数的EMA
单纯的伪标签算法【22】不具竞争性,但将人工标签作为流程的一部分是近期一些SSL算法采用的方式,同时可认为EntMin是伪标签导致的一种技巧。
近期,实验表明使用强数据增强策略可以产生更好的结果,这些强增强版本样本几乎在数据分布之外,但却有益于SSL。
作者认为FixMatch可以作为UDA和ReMixMatch的一种简化,这两者也可以看作是使用一个弱增强样本去生成一个人工标签并且和强增强样本执行一致性正则化,具体可看Table 1:
除了ReMixMatch外,没有其他工作考虑每个类仅有少于25的标注数据,本文考虑了每个类仅有4个标注样本的情况。
除了cifar100上ReMixMatch最佳,其他都是FixMatch最佳。在将ReMixMatch的各成分移植到FixMatch的过程中,发现最重要的一项是分布对齐(Distribution Alignment),它鼓励模型以等概率emit所有类别。将DA组合到FixMatch达到了超越ReMixMatch本身的效果。
FixMatch使用RandAugment和CTAugment在大部分情况下性能相近,除了每类只有4个标注样本的情况,这可以被结果存在高方差解释。
CIFAR10: only one example per class
发现分类器性能与给定数据集中标注样本的质量有很大关系,选择低质量样本作为标注样本会使得模型很难有效学习某些特定类。作者据此给了从高质量到低质量八个数据集。(每个数据集十个样本)
Pseudo-label一般在半监督中特指hard-label
直接利用argmax其实是对标sharpen操作去达到一个最小化熵效果的,不过FixMatch这里还用了阈值去排除一些标签不那么可靠的,而这个阈值方法也是可以和sharpen操作结合的。
实验表明sharpen操作并不能在存在阈值的情况下显著提升性能。【似乎是指直接根据阈值one-hot就挺好?】
衡量了一下CutOut用在RandAugment和CTAugment之后的作用
数据增强同时用于伪标签的生成与预测阶段:
如果生成伪标签使用强增强,则模型在训练初期发散,因此认为伪标签需要用弱增强数据;
如果使用弱增强数据进行预测,训练不稳定并且准确率逐渐崩溃,因此认为在训练时的模型预测使用强增强比较好。
可以结合实验自己感悟,在此不作赘述。