Google research出品
论文:https://arxiv.org/abs/2001.07685
官方代码:https://github.com/google-research/fixmatch
对于无标签的样本,FixMatch
:
FixMatch的核心:
主要贡献:(A+B=C的操作)
主要结合了pseudo label 和 consistency regularization两种数据增强方式实现方法。
- 首先,利用一张无标签样本,分别进行:
- “弱”增强(翻转、缩放)
- “强”增强(CutOut、CTAugment、RandAugment),
- 然后,通过model得到预测标签
- 并,通过标准交叉熵损失计算损失
- 注意:
- 上述**“弱“增强方式预测过程,需要设定一个阈值**
- 大于阈值的才计算loss,小于的就不计算
- 相当于,在前期训练阶段中,无标签样本损失可能一直是为0的
首先,将未标记图像的弱增强版本(顶部)输入模型中以获得其预测(红色框)。
当模型为高于阈值(虚线)的任何类别分配概率时,预测将转换为单伪标记。
然后,针对同一张图片的增强版本(底部)计算模型的预测。
训练该模型,使其通过标准的交叉熵损失,使其在强增强版本上的预测与伪标记匹配。
即使在无标签的样本被注入噪声之后,分类器也应该为其输出相同的类分布概率。即强制一个无标签的样本,应该被分类为与自身的增强 相同的分类
一致性正则化
利用未标记的数据,基于这样的假设,即当输入受到扰动的图像时,模型应该输出相似的预测。
模型通过标准监督分类损失,和以下损失函数对未标记数据进行训练:
∑ b = 1 μ B ∥ p m ( y ∣ α ( u b ) ) − p m ( y ∣ α ( u b ) ) ∥ 2 2 \sum _ { b = 1 } ^ { \mu B } \| p _ { m } ( y | \alpha ( u _ { b } ) ) - p _ { m } ( y | \alpha ( u _ { b } ) ) \| _ { 2 } ^ { 2 } ∑b=1μB∥pm(y∣α(ub))−pm(y∣α(ub))∥22
伪标记
1 μ B ∑ b = 1 μ B 1 ( max ( q b ) ≥ τ ) H ( q ^ b , q b ) \frac { 1 } { \mu B } \sum _ { b = 1 } ^ { \mu B } 1 ( \max ( q _ { b } ) \geq \tau ) H ( \hat { q } _ { b } , q _ { b } ) μB1∑b=1μB1(max(qb)≥τ)H(q^b,qb)
由两个交叉熵损失项组成:
损失函数由两个交叉熵损失项组成:一个监督损失项 l S l _ { S } lS, 一个无监督损失项 l U l _ { U } lU
X = { ( x b , p b ) : b ∈ ( 1 , … , B ) } X = \{ ( x _ { b } , p _ { b } ) : b \in ( 1 , \ldots , B ) \} X={(xb,pb):b∈(1,…,B)}
U = { u b : b ∈ ( 1 , … , μ B ) } U = \{ u _ { b } : b \in ( 1 , \ldots , \mu B ) \} U={ub:b∈(1,…,μB)} 表示一个Batch的未标记样本
p m ( y ∣ x ) p _ { m } ( y | x ) pm(y∣x) 表示模型对输入 x x x 预测的类别分布
将两个概率分布, p p p 和 q q q 之间的交叉熵表示为 H ( q , p ) H ( q , p ) H(q,p)
两种类型的增强: 强增强 A ( ⋅ ) A ( \cdot ) A(⋅) ; 弱增强表示为 α ( ⋅ ) α( \cdot ) α(⋅)
====>
对于有标签样本,FixMatch均采用弱增强,其损失函数为:
ℓ s = 1 B ∑ b = 1 B H ( p b , p m ( y ∣ α ( x b ) ) ) \ell _ { s } = \frac { 1 } { B } \sum _ { b = 1 } ^ { B } H ( p _ { b } , p _ { m } ( y | \alpha ( x _ { b } ) ) ) ℓs=B1∑b=1BH(pb,pm(y∣α(xb)))
"""
将有 / 无标签的 batch 拼接后输入模型
:inputs_x: 有标签数据
:inputs_u_w: 无标签数据的弱增强
:inputs_u_s: 无标签数据的强增强
"""
inputs = interleave(
paddle.concat((inputs_x, inputs_u_w, inputs_u_s)), 2 * args.mu + 1)
# 模型输出(全连接层分类预测)
logits = model(inputs)
logits = de_interleave(logits, 2 * args.mu + 1)
# 有标签数据的模型输出
logits_x = logits[:batch_size]
# 有标签预测的交叉熵损失
Lx = F.cross_entropy(logits_x, targets_x, reduction='mean')
对于无标签样本,FixMatch为每个无标签样本预测一个伪标签,然后用于计算交叉熵损失。
为了获得一个伪标签,首先输入无标签图像的弱增强版本 ,并得到模型预测的类概率分布: q b = p m ( y ∣ α ( μ b ) ) q _ { b } = p _ { m } ( y | \alpha ( \mu _ { b } ) ) qb=pm(y∣α(μb))
然后,使用 q ^ b = argmax ( q b ) \hat { q } _ { b } = \operatorname { argmax } ( q _ { b } ) q^b=argmax(qb) ,得到硬伪标签;
接着与 μ b \mu _ { b } μb 的强增强版本 得到的模型预测,计算一致性正则损失:
ℓ u = 1 μ B ∑ b = 1 μ B 1 ( max ( q b ) ≥ τ ) H ( q ^ b , p m ( y ∣ A ( u b ) ) ) \ell _ { u } = \frac { 1 } { \mu B } \sum _ { b = 1 } ^ { \mu B } 1 ( \max ( q _ { b } ) \geq \tau ) H ( \hat { q } _ { b } , p _ { m } ( y | A ( u _ { b } ) ) ) ℓu=μB1∑b=1μB1(max(qb)≥τ)H(q^b,pm(y∣A(ub)))
# 弱增强和强增强模型预测
logits_u_w, logits_u_s = logits[batch_size:].chunk(2)
# 对弱增强的模型输出使用 softmax + argmax 得到伪标签 targets_u
pseudo_label = F.softmax(logits_u_w.detach() / args.T, axis=-1)
targets_u = paddle.argmax(pseudo_label, axis=-1) # 利用 argmax 得到硬伪标签
# 通过阈值筛选伪标签
max_probs = paddle.max(pseudo_label, axis=-1)
mask = paddle.greater_equal(
max_probs,
paddle.to_tensor(args.threshold)).astype(paddle.float32)
# 无标签预测的交叉熵损失(一致性损失)
Lu = (F.cross_entropy(logits_u_s, targets_u,
reduction='none') * mask).mean()
# 两个损失加权相加
loss = Lx + args.lambda_u * Lu
FixMatch的总损失
总损失是两个损失函数的加权和:
l s + λ u l u l _ { s } + \lambda _ { u } l _ { u } ls+λulu
FixMatch利用两种增强: “弱”和“强
弱增强。是标准的翻转-移位增强策略
在水平方向上,随机翻转图像,概率为50%
在垂直和水平方向上,随机转换图像,概率最高为12.5%
“强”增强。尝试了两种基于自增强的方法
FixMatch和其他的SSL方法的关键区别在于
伪标签是基于弱增强图像预测的硬伪标签
而对于强增强图像的模型输出的全连接层预测,直接计算损失(不进行 argmax)
这对FixMatch的成功至关重要
UDA和MixMatch中用了sharpen构建软伪标签
sharpen 引入了一个超参数
但 并不是起到筛选伪标签的作用
FixMatch 的消融实验表明
在Mean-Teacher、MixMatch等SSL算法中
FixMatch利用了两种数据增强:“弱”和“强”
弱增强是标准的随机翻转和移位的数据增强策略。
对于弱增强,FIxMatch在有标签数据样本上以50%的概率进行水平翻转图像;
对于强增强,FixMatch与UDA一样
论文还研究了,弱增强和强增强的不同组合对伪标签生成的影响:
由于优化器采用了weight_decay
no_decay = ['bias', 'bn']
scheduler = get_cosine_schedule_with_warmup(args.lr, args.warmup, args.total_steps)
grouped_parameters = [
# 若网络层不包含 bias 或 BatchNorm,则应用 weight_decay
{'params': [p for n, p in model.named_parameters() if not any(
nd in n for nd in no_decay)], 'weight_decay': args.wdecay},
# 反之,则不用 weight_decay
{'params': [p for n, p in model.named_parameters() if any(
nd in n for nd in no_decay)], 'weight_decay': 0.0}
]
optimizer = optim.Momentum(learning_rate=scheduler,
momentum=0.9,
parameters=grouped_parameters,
use_nesterov=args.nesterov)
使用余弦学习速率衰减,衰减策略设置为
η cos ( 7 π k 16 K ) \eta \cos ( \frac { 7 \pi k } { 16 K } ) ηcos(16K7πk) ,其中 η 是初始学习率
def get_cosine_schedule_with_warmup(learning_rate, num_warmup_steps,
num_training_steps,
num_cycles=7. / 16.,
last_epoch=-1):
"""
借助 LambdaDecay 实现余弦学习率衰减
"""
def _lr_lambda(current_step):
if current_step < num_warmup_steps:
return float(current_step) / float(max(1, num_warmup_steps))
no_progress = float(current_step - num_warmup_steps) / \
float(max(1, num_training_steps - num_warmup_steps))
return max(0., math.cos(math.pi * num_cycles * no_progress))
return LambdaDecay(learning_rate=learning_rate,
lr_lambda=_lr_lambda,
last_epoch=last_epoch)
backbone网络架构:默认为 Wide ResNet-28-2
训练的超参数如下:
无标签损失权重 λ
初始学习率 η
优化器 momentum 参数 β ,weight_decay 参数 λ
伪标签阈值 τ
有 / 无标签样本比例 1: 7: μ
batch_size
尽管FixMatch非常简单,但它在各种标准的半监督学习benchmark上都达到了SOTA
为了得到最优超参数,该文章后面对超参做了大量的消融实验
在半监督学习算法日益复杂的发展中,FixMatch以出人意料的简单获得了SOTA性能
论文指出,由于这种简单性,我们能够彻底研究FixMatch是如何发挥作用的