改进了 MixMatch 半监督学习算法, 引入了两种新技术: 分布对齐(Distribution Alignment)和强增广锚点 (Augmentation Anchoring). 分布对齐鼓励未标记数据预测的边际分布接近真实标签的边际分布. 强增广锚点将输入的多个强增强版本输入到模型中, 并鼓励每个输出接近同一输入的弱增强版本的预测.
论文地址: ReMixMatch: Semi-Supervised Learning with Distribution Alignment and Augmentation Anchoring
代码地址: https://github.com/google-research/remixmatch
会议: ICLR 2020
任务: 分类
在 MixMatch 的基础上, 原作者自己提出了改进版本: ReMixMatch. 关于 MixMatch 的介绍, 可以参考上一篇文章: https://blog.csdn.net/by6671715/article/details/122766432?spm=1001.2014.3001.5501.
ReMixMatch 与 MixMatch 主要的区别在于, ReMixMatch 改进了两个地方: 分布对齐, 强增广锚点.
Distribution Alignment 强制要求未标记数据的预测集合与提供的标记数据的分布相匹配, 即根据有标签数据的标签分布, 对无标签的"猜测"标签进行对齐.
Distribution Alignment 可描述如下: 在训练过程中, 我们保持模型对未标记数据的预测的平均值, 称之为 p ~ ( y ) \tilde{p}(y) p~(y), 即移动平均分布. 给定模型对未标记示例 u u u 的预测 q = p m o d e l ( y ∣ u ; θ ) q=p_{model}(y \vert u; \theta) q=pmodel(y∣u;θ), 我们通过比率 p ( y ) / p ~ ( y ) p(y)/ \tilde{p}(y) p(y)/p~(y) 对 q q q 进行缩放, 其中 p ( y ) p(y) p(y) 为标签数据的标签分布, 然后重新规范化结果以形成有效的概率分布: q ~ = N o r m a l i z e ( q × p ( y ) / p ~ ( y ) ) \tilde{q}=\mathrm{Normalize}(q \times p(y)/ \tilde{p}(y)) q~=Normalize(q×p(y)/p~(y)), 其中 N o r m a l i z e ( x ) i = x i / ∑ j x j \mathrm{Normalize}(x)_i=x_i/\sum_j x_j Normalize(x)i=xi/∑jxj. 然后, 我们使用 q ~ \tilde{q} q~ 作为 u u u 的标签猜测, 并像 MixMatch 一样进行锐化和其他处理.
在 MixMatch 中, 通过对无标签数据做 K K K 次增强后取平均得到猜测标签, 再与通过对标签数据做 1 1 1 次增强后的结果做一致性正则.
在 ReMixMatch 中, 对同一无标签数据使用弱增强和强增强, 前者直接指定为猜测标签, 后者再与前者做一致性正则.
在少标签 SSL 情况下, AutoAugment, RandAugment 方法存在一些问题, 因此, 开发了 CTAugment, 一种设计高性能增强策略的替代方法. 与 RandAugment 一样, CTAugment 还对变换进行统一随机采样, CTAugment 不需要在有监督的代理任务上进行优化, 并且没有敏感的超参数, 因此可以直接将其包含在半监督模型中, 以在半监督学习中进行更积极的数据增强实验.
ReMixMatch 算法描述如下:
ReMixMatch 算法同时还输出 U ^ 1 \hat{\mathcal{U}}_1 U^1, 它由每个未标记图像的大幅增强版本及其猜测标签组成. U ^ 1 \hat{\mathcal{U}}_1 U^1 还用于两个额外的损失项, 除了提高稳定性外, 还提供了性能的轻微提升, 损失函数如下:
∑ x , p ∈ X ′ H ( p , p m o d e l ( y ∣ x ; θ ) ) + λ U ∑ u , q ∈ U ′ H ( q , p m o d e l ( y ∣ u ; θ ) ) + λ U ^ 1 ′ ∑ u , q ∈ U 1 ′ H ( q , p m o d e l ( y ∣ u ; θ ) ) + λ r ∑ u ∈ U 1 ′ H ( r , p m o d e l ( r ∣ R o t a t e ( u , r ) ; θ ) ) \sum_{x,p\in\mathcal{X}'} \mathrm{H}(p,p_{model}(y\vert x;\theta))+\lambda_{\mathcal{U}}\sum_{u,q\in\mathcal{U}'} \mathrm{H}(q,p_{model}(y\vert u;\theta))+\lambda_{\mathcal{\hat{U}_1'}}\sum_{u,q\in\mathcal{U}_1'} \mathrm{H}(q,p_{model}(y\vert u;\theta))+\lambda_{r}\sum_{u\in\mathcal{U}_1'} \mathrm{H}(r,p_{model}(r\vert \mathrm{Rotate}(u,r);\theta)) x,p∈X′∑H(p,pmodel(y∣x;θ))+λUu,q∈U′∑H(q,pmodel(y∣u;θ))+λU^1′u,q∈U1′∑H(q,pmodel(y∣u;θ))+λru∈U1′∑H(r,pmodel(r∣Rotate(u,r);θ))
将自监督学习(Self-supervised learning)的思想应用于 SSL 可以产生强大的性能. 所以通过旋转每个图像 u ∈ U 1 ′ ^ u \in \hat{\mathcal{U}_1'} u∈U1′^ , R o t a t e ( u , r ) \mathrm{Rotate}(u,r) Rotate(u,r) 来整合这个想法, 其中从 r r r 均匀地采样旋转角度 r ∼ { 0 , 90 , 180 , 270 } r \sim \{0,90,180,270\} r∼{0,90,180,270}, 然后要求模型作为四类分类问题预测旋转量.