SSL论文笔记:FixMatch: Simplifying Semi-Supervised Learning with Consistency and Confidence

文章目录

  • Abstraction
  • FixMatch
    • Algorithm
    • Augmentation
    • Additional important factors
  • Related work
  • Experiments
    • CIFAR10/100, SVNH
    • Barely supervised learning
  • Ablation
    • Sharpening and Thresholding
    • Augmentation strategy
    • Others

Abstraction

两种常用半监督学习领域的方法组合:一致性正则和伪标签(更准确地是artificial labels)

FixMatch方法:

  1. 首先利用模型在未标注数据的弱增强版本上生成一个伪标签;(只保留模型可以高置信度预测的伪标签);

  2. 然后训练该模型,以预测同一图像输入时的强增强版本时的伪标签

SSL论文笔记:FixMatch: Simplifying Semi-Supervised Learning with Consistency and Confidence_第1张图片

Contribution:

  • SSL领域的SOTA结果
  • 探索了extremely-scarce-labels的情况(一类一个样本),并根据实验结果创建了有趣的几组数据集(从无代表性到有代表性样本)
  • 丰富的消融实验,并囊括了新SSL方法提出时很少提及的基础实验选择(如优化器和学习率策略)

个人觉得其有效的两个原因:

  • strong data augmentation
  • ignoring low-confidence predictions

FixMatch

Algorithm

主要创新性来自于一致性正则和伪标签两种成分的结合,以及在执行一致性正则化时分别使用弱和强增强。

SSL论文笔记:FixMatch: Simplifying Semi-Supervised Learning with Consistency and Confidence_第2张图片

监督损失 ℓ s \ell_{s} s和无监督损失 ℓ u \ell_{u} u都是CE损失:
ℓ s = 1 B ∑ b = 1 B H ( p b , p m ( y ∣ α ( x b ) ) ) \ell_{s}=\frac{1}{B} \sum_{b=1}^{B} \mathrm{H}\left(p_{b}, p_{\mathrm{m}}\left(y \mid \alpha\left(x_{b}\right)\right)\right) s=B1b=1BH(pb,pm(yα(xb)))

q b = p m ( y ∣ α ( u b ) ) q ^ b = arg ⁡ max ⁡ ( q b ) ℓ u = 1 μ B ∑ k = 1 μ B 1 ( max ⁡ ( q b ) ≥ τ ) H ( q ^ b , p m ( y ∣ A ( u b ) ) ) q_{b}=p_{\mathrm{m}}\left(y \mid \alpha(u_{b})\right) \\ \hat{q}_{b}=\arg \max \left(q_{b}\right) \\ \ell_{u}=\frac{1}{\mu B} \sum_{k=1}^{\mu B} 1\left(\max \left(q_{b}\right) \geq \tau\right) \mathrm{H}\left(\hat{q}_{b}, p_{\mathrm{m}}\left(y \mid \mathcal{A}\left(u_{b}\right)\right)\right) qb=pm(yα(ub))q^b=argmax(qb)u=μB1k=1μB1(max(qb)τ)H(q^b,pm(yA(ub)))

和伪标签《Pseudo-label: The simple and efficient semi-supervised learning method for deep neural networks》一文最大的不同是artificial label是在弱增强图片上得到的,且l计算的loss是强增强图片的模型预测和伪标签之间的。

Note:实操中,构建 U \mathcal{U} U时,将标记数据也作为未标记数据一部分。

在FixMatch中,不需要如一般SSL算法中逐步增加无监督损失的权重( λ u \lambda_{u} λu),猜测是由于置信度 τ \tau τ的存在使得训练早期,少有无标注数据置信度高被作为伪标签;而随着训练进行,模型预测更具置信度,伪标签的出现更频繁。即本文认为伪标记中使用的阈值具有权重调节相似的作用。

Augmentation

  • Weak augmentation: standard flip-and-shift augmentation

  • Strong augmentation:(源自于AutoAugment变种,因其不适用于SSL设置)

RandAugment:有单独的一篇,以及UDA中也是这么做的

随机在一堆变换中为mini-batch中每个样本选取 N N N种变换形式,对于失真度 M M M,不采用全局量,在每个训练步骤(而不是使用一个固定的全局值)从预定义的范围内随机采样一个数量级,对于半监督训练来说效果更好,这与UDA中的用法相似。

CTAugment:ReMixMatch采用,使用控制理论中的思想来消除对AutoAugment中增强学习的需求

有一组18种可能的变换,变换的幅度值被划分为bin,每个bin被分配一个权重。最初,所有bin的权重均为1。现在从该集合中以相等的概率随机选择两个变换,形成变换序列,这类似于RandAugment。对于每个变换,根据归一化的bin权重随机选择一个幅值bin。有标签样本通过这两个转换得到了增强,并传递给模型以进行预测根据模型预测值与实际标签的接近程度,更新这些变换的bin权重。因此,它学会选择具有较高的机会来预测正确的标签的模型,从而在网络容差范围内进行增强。

因此,我们看到,与RandAugment不同,CTAugment可以在训练过程中动态学习每个变换的幅度。因此,我们无需在某些受监督的代理任务上对其进行优化,并且它没有敏感的超参数可优化。因此,这非常适合缺少标签数据的半监督环境。

Additional important factors

使用了简单的权重衰减正则化,由于和Adam优化器一起用有问题,所以使用了带动量的标准sgd

使用cosine learning rate decay

使用模型参数的EMA

Related work

单纯的伪标签算法【22】不具竞争性,但将人工标签作为流程的一部分是近期一些SSL算法采用的方式,同时可认为EntMin是伪标签导致的一种技巧。

近期,实验表明使用强数据增强策略可以产生更好的结果,这些强增强版本样本几乎在数据分布之外,但却有益于SSL。

作者认为FixMatch可以作为UDA和ReMixMatch的一种简化,这两者也可以看作是使用一个弱增强样本去生成一个人工标签并且和强增强样本执行一致性正则化,具体可看Table 1:

SSL论文笔记:FixMatch: Simplifying Semi-Supervised Learning with Consistency and Confidence_第3张图片

Experiments

CIFAR10/100, SVNH

除了ReMixMatch外,没有其他工作考虑每个类仅有少于25的标注数据,本文考虑了每个类仅有4个标注样本的情况。

SSL论文笔记:FixMatch: Simplifying Semi-Supervised Learning with Consistency and Confidence_第4张图片

除了cifar100上ReMixMatch最佳,其他都是FixMatch最佳。在将ReMixMatch的各成分移植到FixMatch的过程中,发现最重要的一项是分布对齐(Distribution Alignment),它鼓励模型以等概率emit所有类别。将DA组合到FixMatch达到了超越ReMixMatch本身的效果。

FixMatch使用RandAugment和CTAugment在大部分情况下性能相近,除了每类只有4个标注样本的情况,这可以被结果存在高方差解释。

Barely supervised learning

CIFAR10: only one example per class

发现分类器性能与给定数据集中标注样本的质量有很大关系,选择低质量样本作为标注样本会使得模型很难有效学习某些特定类。作者据此给了从高质量到低质量八个数据集。(每个数据集十个样本)

Ablation

Sharpening and Thresholding

Pseudo-label一般在半监督中特指hard-label

直接利用argmax其实是对标sharpen操作去达到一个最小化熵效果的,不过FixMatch这里还用了阈值去排除一些标签不那么可靠的,而这个阈值方法也是可以和sharpen操作结合的。

实验表明sharpen操作并不能在存在阈值的情况下显著提升性能。【似乎是指直接根据阈值one-hot就挺好?】

Augmentation strategy

  • 衡量了一下CutOut用在RandAugment和CTAugment之后的作用

  • 数据增强同时用于伪标签的生成与预测阶段:

如果生成伪标签使用强增强,则模型在训练初期发散,因此认为伪标签需要用弱增强数据;

如果使用弱增强数据进行预测,训练不稳定并且准确率逐渐崩溃,因此认为在训练时的模型预测使用强增强比较好。

Others

可以结合实验自己感悟,在此不作赘述。

你可能感兴趣的:(论文心得等,深度学习,人工智能)