[半监督学习] FeatMatch: Feature-Based Augmentation for Semi-Supervised Learning

一些先进的半监督学习方法使用基于图像的转换增强和一致性正则化的组合策略. 在FeatMatch 中, 提出了一种新颖的基于学习特征的细化和增强方法, 该方法可产生各种复杂的转换集. 重要的是, 这些转换使用了通过聚类提取的类内和跨类原型表示中的信息. 这些转换与传统的基于图像的增强相结合, 被用作基于一致性的正则化损失的一部分.

论文地址: FeatMatch: Feature-Based Augmentation for Semi-Supervised Learning
代码地址: https://github.com/GT-RIPL/FeatMatch
会议: ECCV 2020
任务: 分类

FeatMatch 中提出: 通过从其他图像的特征中提取的代表性原型的 soft-attention 来学习细化和增强输入图像特征. 传统的基于图像的数据增强与基于特征的数据增强对比如下图所示:

[半监督学习] FeatMatch: Feature-Based Augmentation for Semi-Supervised Learning_第1张图片

基于特征的数据增强(Feature-Based Augmentation)

如下表所示, 基于特征的数据增强在 FeatMatch 之前还未有人提出, 更多的方法是一些基于图像的通用增强, 以及其他模型中所用到的集成方法等.
[半监督学习] FeatMatch: Feature-Based Augmentation for Semi-Supervised Learning_第2张图片
基于图像的增强已被证明是一种为基于一致性的 SSL 方法, 其生成图像不同视图. 然而, 传统的基于图像的增强存在以下两个限制:

  • 在图像空间中操作, 限制了图像内对纹理或几何的可能转换.
  • 在单个实例中操作, 无法使用其他实例的知识, 无论是在同一类别的内部还是外部.

一些使用 Mixup 的算法仅部分解决了第二个限制, 因为 mixup 仅在两个实例之间运行, 如 ICT MixMatch, ReMixMatch. 另一方面, Manifold Mixup 通过在特征空间中执行 Mixup 来接近第一个限制, 但仅限于两个样本的简单凸组合.

为了同时解决这两个限制, 提出一种新方法, 可以在抽象特征空间而不是图像空间中细化和增强图像特征. 为了有效地利用其他类的知识, 通过在特征空间中执行聚类来将每个类的信息浓缩成一个原型集合. 然后通过从所有类的原型传播的信息来细化和增强图像特征.

原型选择(Prototype Selection)

在特征空间中使用 K-Means 聚类来提取 p k p_k pk 个聚类作为每个类的原型集合. 但是, 这存在两个技术挑战:

  • 在 SSL 设置中, 大多数图像为未标记状态.
  • 即使所有标签都可用, 在运行 K K K-Means 之前从整个数据集中提取所有图像的特征仍然计算量很大.

为了解决这些问题, 在训练循环的每次迭代中存储网络已经生成的特征 f x i f_{xi} fxi 和伪标签 y ^ i \hat{y}_i y^i. K K K-Means 在每个 epoch 都进行原型提取, 最后, 特征细化和增强模块在训练循环中使用新提取的原型更新现有的原型. 基本过程如下图所示:

[半监督学习] FeatMatch: Feature-Based Augmentation for Semi-Supervised Learning_第3张图片

特征增强(Feature Augmentation)

选择出的新的原型集合后, 通过 soft-attention 对原型集进行特征细化和增强. 增强模块如下图所示:
[半监督学习] FeatMatch: Feature-Based Augmentation for Semi-Supervised Learning_第4张图片
首先通过学习函数 ϕ e \phi_e ϕe 将特征 f x f_x fx 和第 i i i 个原型特征 f p , i f_{p,i} fp,i 投影到嵌入空间, 分别为 e x = ϕ e ( f x ) e_x=\phi_e(f_x) ex=ϕe(fx) e p , i = ϕ e ( f p , i ) e_{p,i}=\phi_e(f_{p,i}) ep,i=ϕe(fp,i). 计算 e x e_x ex e p , i e_{p,i} ep,i 之间的注意力权重 w i w_i wi:
w i = s o f t m a x ( e x T e p , i ) (1) w_i= \mathrm{softmax}(e_x^\mathrm{T} e_{p,i}) \tag{1} wi=softmax(exTep,i)(1)
其中 softmax 进行标准化点积相似度操作. 然后, 特征细化和增强的信息可以表示为由注意力权重加权的原型特征之和:
f a = r e l u ( ϕ a ( [ e x , ∑ i w i e p , i ] ) ) (2) f_a=\mathrm{relu}(\phi_a([e_x,\sum_iw_ie_{p,i}])) \tag{2} fa=relu(ϕa([ex,iwiep,i]))(2)
其中 ϕ a \phi_a ϕa 为学习函数, [ ⋅ , ⋅ ] [·,·] [,]是沿特征维度的串联操作. 最后, 通过剩余连接对输入图像特征 f x f_x fx 进行优化:
g x = r e l u ( f x + ϕ r ( f a ) ) (3) g_x=\mathrm{relu}(f_x+\phi_r(f_a)) \tag{3} gx=relu(fx+ϕr(fa))(3)
其中 g x g_x gx 为细化的特征, ϕ r \phi_r ϕr 为可学习的函数.

上述注意力机制可以简单地推广到 multi-head attention. 对于同一张图像, 一个 Attention 获得一个表示空间, 如果多个 Attention, 则可以获得多个不同的表示空间. 在实践中, 为了获得更好的效果, 使用 multi-head attention. 为了简单起见, 将上述特征细化和增强过程 A u g F ( ⋅ ) AugF(·) AugF() 定义为 g x = A u g F ( f x ) g_x=AugF(f_x) gx=AugF(fx).

损失函数(Loss Function)

通过学习基于特征的增强, 可以在特征 f x f_x fx 和增强特征 g x g_x gx 之间应用一致性损失. 给定一个分类器 p = C l f ( f ) p=Clf(f) p=Clf(f), 文中发现 A u g F AugF AugF 能够细化输入特征以获得更好的表示, 从而生成更好的伪标签. 因此, 通过 p g = C l f ( g x ) p_g=Clf(g_x) pg=Clf(gx) 计算 g x g_x gx 上的伪标签 p g p_g pg. 基于特征的一致性损失可以计算为: L c o n = H ( p g , C l f ( f x ) ) \mathcal{L}_{con}=\mathcal{H}(p_g,Clf(f_x)) Lcon=H(pg,Clf(fx)).

受 ReMixMatch 的启发, 生成了一个弱增强图像 x x x 及其强增强副本 x ^ \hat{x} x^. 在经过基于特征的增强和细化的弱增强图像 x x x 上计算伪标签, 如 p g = C l f ( A u g F ( E n c ( x ) ) ) p_g=Clf(AugF(Enc(x))) pg=Clf(AugF(Enc(x))). 然后, 可以在强增强数据 x ^ \hat{x} x^ 上计算两个一致性损失, 一个应用了 A u g F AugF AugF, 另一个没有:
L c o n − g = H ( p g , C l f ( A u g F ( E n c ( x ^ ) ) ) ) (4) \mathcal{L}_{con-g}=\mathcal{H}(p_g,Clf(AugF(Enc(\hat{x})))) \tag{4} Lcong=H(pg,Clf(AugF(Enc(x^))))(4)
L c o n − f = H ( p g , C l f ( E n c ( x ^ ) ) ) (5) \mathcal{L}_{con-f}=\mathcal{H}(p_g,Clf(Enc(\hat{x}))) \tag{5} Lconf=H(pg,Clf(Enc(x^)))(5)
关于带标签 y y y 的数据 x x x, 其损失可表示为:
L c l f = H ( y , C l f ( A u g F ( E n c ( x ) ) ) ) (6) \mathcal{L}_{clf}=\mathcal{H}(y,Clf(AugF(Enc(x)))) \tag{6} Lclf=H(y,Clf(AugF(Enc(x))))(6)
综上, 总损失函数为:
L t o t a l = L c l f + λ g L c o n − g + λ f L c o n − f (7) \mathcal{L}_{total}=\mathcal{L}_{clf}+\lambda_g\mathcal{L}_{con-g}+\lambda_f\mathcal{L}_{con-f} \tag{7} Ltotal=Lclf+λgLcong+λfLconf(7)

代码地址: https://github.com/GT-RIPL/FeatMatch

你可能感兴趣的:(论文,机器学习,深度学习,人工智能)