[半监督学习] FlexMatch: Boosting Semi-Supervised Learning with Curriculum Pseudo Labeling

在 FixMatch 中, 对所有类别使用预定义的常量阈值来选择有助于训练的未标记数据, 因此无法考虑不同类别的不同学习状态和学习难度, UDA 也是如此. 为解决这个问题, 提出了课程伪标签(Curriculum Pseudo Labeling, CPL), 这是一种根据模型的学习状态来利用未标记数据的课程学习方法. CPL 的核心是在不同时刻灵活地调整不同类别的阈值.

论文地址: FlexMatch: Boosting Semi-Supervised Learning with Curriculum Pseudo Labeling
代码地址: https://github.com/TorchSSL/TorchSSL
会议: NeurIPS 2021
任务: 分类

FlexMatch 使用了 CPL, CPL 是一种课程学习(Curriculum Learning)策略, 考虑到半监督学习中不同的学习状态, CPL 将预定义的阈值替换为灵活的阈值. FlexMatch 只需不到 FixMatch 训练时间的1/5就可以达到最终精度.

课程学习(Curriculum Learning)

根据样本的难易程度, 给不同难度的训练样本分配不同的权重. 初始阶段, 给简单样本的权重最高, 随着训练过程的持续, 较难样本的权重将会逐渐被调高. 将权重动态分配的过程称之为课程(Curriculum), 课程初始阶段简易样本居多, 课程末尾阶段样本难度增加, 即"先易后难".

针对不同的实际问题可以设置不同的样本难易程度评价标准. 例如对于一个原始样本, 对其进行强扰动后, 样本的就由简单变向复杂.

课程伪标签(Curriculum Pseudo Labeling, CPL)

根据学习状态来动态确定阈值并非易事. 最理想的方法是计算每个类的评估准确度并使用它们来缩放阈值:
τ t ( c ) = a t ( c ) ⋅ τ (1) \tau_t(c)=a_t(c) \cdot\tau \tag{1} τt(c)=at(c)τ(1)
其中 τ t ( c ) \tau_t(c) τt(c) t t t 时刻 c c c 类别的灵活阈值, a t ( c ) a_t(c) at(c) 是相应的评估精度. 由于不能在模型学习过程中使用评估集, 因此必须从训练集中分离一个额外的验证集来进行准确性评估. 但是在 SSL 中, 标记数据原本就十分稀缺, 不能再剥离一部分出去. 其次, 为了在训练过程中动态调整阈值, 必须连续在每个时刻 t t t 进行准确度评估, 这将大大减慢训练速度.

为解决上述问题, CPL 使用另一种方法来估计学习状态, 它不引入额外的推理过程, 也不需要额外的验证集. 其关键假设是, 通过预测属于该类且高于阈值的样本数量来反映一个类的学习效果, 然后使用它们来调整阈值 τ τ \tau_τ ττ. 如下图所示:
[半监督学习] FlexMatch: Boosting Semi-Supervised Learning with Curriculum Pseudo Labeling_第1张图片
如果一个类具有较少样本且其预测置信度达到阈值, 则称其具有较大的学习难度或较差的学习状态:
σ t ( c ) = ∑ n = 1 N 1 ( max ⁡ ( p m , t ( y ∣ u n ) ) > τ ) ⋅ 1 ( arg max ⁡ ( p m , t ( y ∣ u n ) ) = c ) (2) \sigma_t(c)=\sum_{n=1}^N \mathbb{1}(\max(p_{m,t}(y\vert u_n))>\tau) \cdot \mathbb{1}(\argmax(p_{m,t}(y\vert u_n))=c) \tag{2} σt(c)=n=1N1(max(pm,t(yun))>τ)1(argmax(pm,t(yun))=c)(2)
其中 σ t ( c ) \sigma_t(c) σt(c) 时是在所有样本中对高于固定阈值且属于类别 c c c 的样本的数目, 反映了类 c c c t t t 时刻的学习效果. p m , t ( y ∣ u n ) p_{m,t}(y\vert u_n) pm,t(yun) 是模型在 t t t 时刻对未标记数据 u n u_n un 的预测, N N N 是未标记数据的总数. 当未标记数据集是平衡的(即属于不同类别的未标记数据的数量相等或接近)时, 较大的 σ t ( c ) \sigma_t(c) σt(c) 表示更好的学习效果. 通过对 σ t ( c ) \sigma_t(c) σt(c) 应用以下归一化使其范围在 0 到 1 之间, 然后可以使用它来缩放固定阈值 τ \tau τ:
β t ( c ) = σ t ( c ) max ⁡ c σ t (3) \beta_t(c)=\frac{\sigma_t(c)}{\underset{c}{\max}\sigma_t} \tag{3} βt(c)=cmaxσtσt(c)(3)
τ t ( c ) = β t ( c ) ⋅ τ (4) \tau_t(c)=\beta_t(c) \cdot \tau \tag{4} τt(c)=βt(c)τ(4)
随着学习的进行, 学习状态良好的类的阈值会提高, 以选择性地提取更高质量的样本. 最终, 当所有类都达到可靠的准确度时, 阈值都将接近 τ \tau τ. 不过阈值并不总是增长态, 如果未标记的数据在后面的迭代中被分类到不同的类别, 阈值也可能会降低. 这个新阈值用于计算 FlexMatch 中的无监督损失, 可以表示为:
L u , t = 1 μ B ∑ b = 1 μ B 1 ( max ⁡ ( q b ) ≥ τ t ) H ( q ^ b , p m ( y ∣ A ( u b ) ) ) (5) \mathcal{L}_{u,t}=\frac{1}{\mu B} \sum_{b=1}^{\mu B} \mathbb{1}(\max(q_b)\geq \tau_t) \mathrm{H}(\hat{q}_b,p_m(y\vert \mathcal{A}(u_b))) \tag{5} Lu,t=μB1b=1μB1(max(qb)τt)H(q^b,pm(yA(ub)))(5)
其中 q b = p m ( y ∣ α ( u b ) ) q_b=p_m(y\vert \alpha(u_b)) qb=pm(yα(ub)), 这份损失的形式结构与 FixMatch 基本一致. 最后, FlexMatch 中的损失表示为有监督和无监督损失的加权组合:
L t = L s + λ L u , t (6) \mathcal{L}_t=\mathcal{L}_s+\lambda\mathcal{L}_{u,t} \tag{6} Lt=Ls+λLu,t(6)
其中 L s \mathcal{L}_s Ls 为有监督损失:
L s = 1 B ∑ b = 1 B H ( y b , p m ( y ∣ α ( x b ) ) ) (7) \mathcal{L}_{s}=\frac{1}{B} \sum_{b=1}^{B}\mathrm{H}(y_b,p_m(y\vert \alpha(x_b))) \tag{7} Ls=B1b=1BH(yb,pm(yα(xb)))(7)

其他

为避免早阶段训练可能出现的盲目预测, 将式(3)改写为:
β t ( c ) = σ t ( c ) max ⁡ { max ⁡ c σ t , N − ∑ c σ t } (8) \beta_t(c)=\frac{\sigma_t(c)}{\max \{ \underset{c}{\max}\sigma_t,N-\underset{c}{\sum}\sigma_t \}\tag{8}} βt(c)=max{cmaxσt,Ncσt}σt(c)(8)
这确保了在训练开始时, 所有估计的学习效果从 0 逐渐上升, 直到未使用的未标记数据的数量 N − ∑ c σ t N-\underset{c}{\sum}\sigma_t Ncσt 不再占主导地位.

同时, 还提出一个非线性映射函数 M \mathcal{M} M, 当 β t ( c ) \beta_t(c) βt(c) 均匀地从 0 到 1 范围内变化时, 使阈值具有非线性的增加曲线:
τ t ( c ) = M ( β t ( c ) ) ⋅ τ (9) \tau_t(c)=\mathcal{M}(\beta_t(c)) \cdot \tau \tag{9} τt(c)=M(βt(c))τ(9)
显然, 如果 M \mathcal{M} M 为恒等函数时, 式(9)与式(4)相同. 并且映射函数是单调递增的, 最大值不大于 1 / τ 1/\tau 1/τ. 在文献中, 选择凸函数 M ( x ) = x 2 − x \mathcal{M}(x) = \frac{x}{2−x} M(x)=2xx 作为映射函数.

FlexMatch 完整算法如下:
[半监督学习] FlexMatch: Boosting Semi-Supervised Learning with Curriculum Pseudo Labeling_第2张图片
代码地址: https://github.com/TorchSSL/TorchSSL

你可能感兴趣的:(论文,机器学习,深度学习,人工智能)