FixMatch 是对现有 SSL 方法的简化. FixMatch 首先对弱增强的未标记图像生成伪标签, 接着, 对同一图像进行强增强后, 再计算其预测分布, 最后计算强增强的预测与伪标签之间的交叉熵损失.
论文地址: FixMatch: Simplifying Semi-Supervised Learning with Consistency and Confidence
代码地址: https://github.com/google-research/fixmatch
会议: NeurIPS 2020
任务: 分类
FixMatch 是 SSL 两种方法的组合: 一致性正则化和伪标签. 它的新颖之处在于这两种方法的组合以及在执行一致性正则化时使用单独的弱增强和强增强.
FixMatch 简要示意图如下:
将弱增强图像输入模型, 当某一预测类别概率高于阈值(虚线)时, 预测将转换为 one-hot 伪标签. 然后, 计算模型对同一图像的强增强的预测. 计算强增强的预测与伪标签之间的交叉熵损失.
文中符号系统如下:
一致性正则化及伪标签方法简要介绍如下:
Consistency regularization. 关于一致性正则化, 核心就是基于平滑假设, 模型对于对增强后数据的预测应与原始数据预测的结果一致.
Pseudo-labeling. 即利用模型本身来获取未标记数据的人工标签. 更具体地说, p b p_b pb 的伪标签 q b q_b qb 可以分别定义为基于锐化的连续分布(软)或基于 arg max \argmax argmax 操作的独热分布(硬). 在本文里, 人工标签一般指"硬"标签, 并且只保留最大类别概率高于预定阈值的情况. pseudo-labeling 使用如下损失函数:
1 μ B ∑ b = 1 μ b ( max ( q b ) ≥ τ ) H ( q ^ b , q b ) (1) \frac{1}{\mu B} \sum_{b=1}^{\mu b}(\max(q_b) \geq \tau)\mathrm{H}(\hat{q}_b,q_b) \tag{1} μB1b=1∑μb(max(qb)≥τ)H(q^b,qb)(1)
其中 q b = p m ( y ∣ u b ) q_b=p_m(y\vert u_b) qb=pm(y∣ub), q ^ b = arg max ( q b ) \hat{q}_b=\argmax(q_b) q^b=argmax(qb), τ \tau τ 为阈值. 鼓励模型的预测是对未标记数据的低熵, 或者说是高置信度.
FixMatch 的损失函数由两个交叉熵损失项组成: 应用于标记数据的监督损失 ℓ s \ell_s ℓs 和无监督损失 ℓ u \ell_u ℓu. 具体来说, ℓ s \ell_s ℓs 只是弱增强标记示例上的标准交叉熵损失:
ℓ s = 1 B ∑ b = 1 B B ( p b , p m ( y ∣ α ( x b ) ) ) (2) \ell_s=\frac{1}{B} \sum_{b=1}^B \mathrm{B}(p_b,p_m(y\vert \alpha(x_b))) \tag{2} ℓs=B1b=1∑BB(pb,pm(y∣α(xb)))(2)
FixMatch 为每个未标记的示例计算一个人工标签, 然后将其用于标准交叉熵损失. 为了获得人工标签, 首先在给定未标记图像的弱增强版本的情况下计算模型的预测类别分布: q b = p m ( y ∣ α ( u b ) ) q_b =p_m(y \vert \alpha(u_b)) qb=pm(y∣α(ub)). 然后, 使用 q ^ b = arg max ( q b ) \hat{q}_b = \argmax(q_b) q^b=argmax(qb) 作为伪标签, 与 u b u_b ub 的强增强版本做交叉熵损失:
ℓ u = 1 μ B ∑ b = 1 μ B ( max ( q b ) ≥ τ ) H ( q ^ b , p m ( y ∣ A ( u b ) ) ) (3) \ell_u=\frac{1}{\mu B} \sum_{b=1}^{\mu B} (\max(q_b)\geq \tau) \mathrm{H}(\hat{q}_b,p_m(y\vert \mathcal{A}(u_b))) \tag{3} ℓu=μB1b=1∑μB(max(qb)≥τ)H(q^b,pm(y∣A(ub)))(3)
综上, FixMatch 的损失函数定义为: l o s s = ℓ s + λ ℓ u loss=\ell_s+\lambda\ell_u loss=ℓs+λℓu. 完整的算法如下:
FixMatch 利用了两种增强: “弱"和"强”.
一些其他重要因素会影响 SSL 的性能, 例如: architecture, optimizer, training schedule 等. 经过实验, 文中发现正则化尤为重要. 在所有的模型和实验中, 使用简单的权重衰减正则化. 同时发现使用 Adam 优化器会导致更差的性能, 而使用 SGD 则没有这种情况, 另外, 使用 SGD 和使用 Nesterov 之间没有存在实质性差异. 对于学习率, 使用余弦学习率衰减. 它将学习率设置为 η cos 7 π k 16 K \eta \cos \frac{7\pi k}{16K} ηcos16K7πk, 其中 η \eta η 是初始学习率, k k k 是当前训练步长, K K K 是总学习率训练步骤. 最后, 使用模型参数的指数移动平均值(EMA)报告最终性能.
FixMatch 可以很容易地使用 SSL 文献中的技术进行扩展. 例如, 来自 ReMixMatch 的增强锚定和分布对齐. 此外, 可以用与模态无关的增强策略, 例如 MixUp 或对抗性扰动代替 FixMatch 中的强增强. 对抗性扰动在 VAT, Adversarial Dropout 中已经应用. MixUp 也在 MixMatch, ICT 中成功应用.