浅析 Semi-Supervised Learning 中的 consistency 问题

浅析 Semi-Supervised Learning 中的 Consistency 问题

    • 传统半监督学习简述:
    • 现有半监督学习的问题 —— Individual Consistency
    • 实现方法
    • 总结

传统半监督学习简述:

区别于全监督学习,半监督学习针对训练集标记不完整的情况:仅仅部分数据具有标签,然而大量数据是没有标签的。因此,目前半监督学习的关键问题在于如何充分地挖掘没有标签数据的价值。主流的半监督学习方法有下面几种:

  1. Self-Training 方法。这是一种很直观的思路:既然大量数据是没有标签的,那么能否对这些数据生成一些伪标签 Pseuo Labels,对这些伪标签数据的训练从而利用原始的无标签数据。
  2. Adversarial-Learning-based 方法。这类方法基于一种假设,无标签数据通常具有和有标签数据在某种程度上类似的潜在标签。所以很自然地,可以采用 GAN 的图像模拟思路来进行对 Labeled Data 进行类似于 Unlabeled Data 的数据增强,进而利用 Unlabeled Data 的潜在知识。
  3. Consistency-based 方法。 这类方法的核心思路在于 consistency loss,对于进过扰动的 unlabeled data,模型应该对其做出一致性的预测 —— 可以理解成一种利用 unlabeled data 进行网络正则化的方法。 其中经典的算法有 Π-model,Temporal Ensembling 和 Mean Teacher。这些内容具体可以参见 飘入东湖的鱼的知乎专栏。本篇博客的后续讨论都是基于 Mean Teacher 模型上的。

现有半监督学习的问题 —— Individual Consistency

目前大多数半监督学习方法都是基于 consistency-enforcing strategy,利用无标签数据对网络进行正则化,要求预测结果对于输入扰动和网络参数扰动具有一致性。具体来说,给定一个输入样本,对其进行一定程度的扰动 (如添加 Gaussian noise),使得网络对于这些样本具有相似的预测结果。
  这类方法的局限性在于 没有考虑样本和样本之间的关系 —— 这些关系能够有助于从无标签的样本中提取语义信息。如下图所示, 传统半监督学习考虑 individual consistency,将每个样本当成独立的个体考虑,仅仅考虑它和对应扰动之后的样本之间的对应关系。除此之外,我们能否进一步考虑样本之间的关系一致性 (Relation Consistency)在添加扰动之后其 relation consistency 也应该保持 —— 最小化,从而确保 high-level semantic information 也能够被学习到,进而确保学习的鲁棒性和高判别性。
浅析 Semi-Supervised Learning 中的 consistency 问题_第1张图片

实现方法

浅析 Semi-Supervised Learning 中的 consistency 问题_第2张图片

  • 上图的骨架结构是传统的 mean-teacher 框架,其中包括对于 student model 的有监督损失函数 L s L_s Ls (cross entropy loss),和上述提到的 individual consistency loss L c L_c Lc (这里采用的是 mse loss)
  • L s r c L_{src} Lsrc (Sample Relation Consistency)。这里考虑一个 mini-batch 内样本的关系一致性,简单来说就是计算在全连接层之前的 feature map 之间的 similarity。给定 batch size = B B B,即可得到尺寸为 B × B B \times B B×B 的 similarity matrix。对于 student model 和 teacher model 得到的 similarity matrix 计算其之间的差异,作为 L s r c L_{src} Lsrc。进而优化 L s r c L_{src} Lsrc 即可达到对于样本关系一致性的约束。

总结

  • 从方法核心的角度来看,这个方法很类似于对于 feature map 的一致性约束,只不过这里是先对 feature map 计算相似性,然后再对相似性做了一致性的约束。所以文章后续也有讨论,这样通过约束相似性的方式是优于直接约束 feature map。

你可能感兴趣的:(机器学习,深度学习,人工智能,深度学习,机器学习)