半监督学习方法的分类:
用给定的SSL方法训练一个深度学习模型,性能测试结果在原始带标签的测试集上。这为对比和消融实验的设计起指导作用,若要评估或者比较SSL的在真实数据中的性能,需注意以下几点。
目前许多研究的假设都是利用聚类假设而进行训练的,这些方法都基于一个概念,即如果给一个无标签样本增加扰动,那么扰动数据的预测和原数据的预测不会有明显的改变,在聚类假设下,具有不同真实标签的数据点应当在低密度区域分隔开,因此,某样本在扰动后的预测结果发生类别变化的可能性也该很小。
更正式地说,通过一致性正则化,我们倾向于对相似数据点给出一致预测的函数 f θ \ f_{\theta} fθ。因此,与其最小化在输入空间的零维数据点上的分类成本,正则化的模型使每个数据点周围的流形上的成本最小化,使决策边界远离未标记的数据点,并平滑数据所在的流形[193]。这意思就说扰动数据和原数据认为是“相似数据”。
对于无标签数据 D u \ D_u Du,一致性正则化的目的是最小化原数据和扰动数据输出之间的距离,距离衡量指标有MSE,KL散度,JS散度。
在上一节中,在保持集群假设的设置中,我们强制执行预测的一致性,以将决策边界推到低密度区域,来避免对来自具有不同类的同一聚类的样本进行分类,这违反了聚类的假设。另一种执行这一点的方法是使网络做出置信度(低熵)预测,对于未标记的数据,而不管预测的类如何,阻止决策边界通过数据点附近,否则它将被迫产生低置信度的预测。这是通过添加一个损失项来实现的,以最小化预测函数 f θ ( x ) \ f_{\theta}(x) fθ(x)的熵。
在我看来就是在训练的时候同时输入有标签数据和无标签数据对模型进行训练,两部分损失,一部分就是常规的交叉熵损失,一部分还是一致性准则,扰动样本和原样本的输出距离应该近。
代理标签方法是一类SSL算法,它在未标记的数据上生成代理标签,使用预测函数本身或它的某些变体,而不需要任何监督。这些代理标签与标记数据一起用作目标,提供一些额外的训练信息,即使产生的标签往往是嘈杂的或弱的,并不能反映地面的真相。
这些方法主要可以分为2类(这似乎就是伪标签的思路):
在自训练中,1)少量的标签数据 D l \ D_l Dl先被用来训练模型,再用这个训练好的模型来给未标注数据 D u \ D_u Du指派伪标签。因此,给定一个未标注数据 x \ x x,用此模型先预测其在各类上的概率分布,然后再成对添加数据和伪标签 ( x , arg max f θ ( x ) ) \ (x,\arg\max{f_\theta(x)}) (x,argmaxfθ(x))到训练集中,这里有个前提就是最大概率值应该大于某个阈值 τ \ \tau τ。
第二阶段是使用未标注数据集 D u \ D_u Du的增强数据训练模型,并且利用这个模型又反过来标注未标注数据集 D u \ D_u Du,这个过程需要不断的重复直至模型无法再标注出高置信度的样本。
其他的启发式方法可以用来决定保留哪些代理标签的样本,例如2)使用相对置信度而不是绝对置信度,其中,在一个epoch中具有最高置信度的样本要进行排序,选出其中前n个具有伪标签的样本到训练集 D l \ D_l Dl中。自训练(Self-training)与熵最小化(Entropy Minimization)相似。在这两种情况下,网络都训练输出高置信度的预测。
这种方法的主要缺点是模型无法纠正自己的错误,任何偏差和错误的分类都可以迅速放大,从而导致在未标记的数据点上产生置信度高但错误的代理标签。(这个确实存在,我在实验中似乎产生了大量的错误伪标签)
Billion-scale semi-supervised learning for image classifification
Self-training with Noisy Student improves ImageNet classifification
先举着两个例子,知道自训练是什么东西就行,其他的改进方法就需要继续阅读文献找。
多视图训练在实际应用中是很常见的,视图可以是来自原数据的不同观测手段比如:图像的颜色信息,纹理信息。多视图训练之目的在于学习一个不同的预测函数,这个预测函数是对原数据的 x \ x x的某个视图 v i ( x ) \ v_i(x) vi(x)进行建模(相当于说一个视图一个函数,多个视图就有多个预测函数需要训练)。然后对所有的预测函数进行联合优化,最终增强模型的泛化能力。理想情况下,各视图相互补充,以便生成的模型可以协作以提高彼此的性能。
协同训练要求原数据点 x \ x x可以用两种条件独立的视图进行表示,并且两个视图各自可以充分地用于训练一个好的模型。在标记数据集 D l \ D_l Dl的特定视图上进行训练得到两个预测函数 f θ 1 f_{\theta_1} fθ1和 f θ 2 \ f_{\theta_2} fθ2后,开始标注代理标签的进程。在每次迭代中,若 f θ j \ f_{\theta_j} fθj对某未标注数据的预测输出所对应的概率值高于了某阈值 τ \ \tau τ,那么这个未标注数据则被加入到 f θ i \ f_{\theta_i} fθi的的训练集中。所以说,其中一个模型是拿来提供标注的,另外一个是用这个伪标签数据监督训练的。
像这种训练策略一般是用在多模态数据上,比如RGB-D数据,还有图像-文本数据,他们各自就是不同的视图,所以可以使用协同训练的策略。
但是实际上图像分类这些任务只有一种数据视图,所以在实践当中是用两种不同的分类器或者不同的参数配置。两个视图 v 1 ( x ) v_1(x) v1(x)和 v 2 ( x ) \ v_2(x) v2(x)可以通过注入噪声和应用不同的数据增强来生成。如对抗性扰动生成不同的视图:Deep co-training for semi-supervised image recognition.
我觉得这部分叫“三个训练”比较合适,这个训练策略的思路是应用三个不同的模型,这三个模型首先都要在有标记训练集 D l \ D_l Dl上训练。然后再用这三个模型对未标记数据集进行预测。生成伪标签的策略是:如果预测结果中有两个保持一致,那么这个数据就加入到剩下那个模型的训练集中。如果没有任何数据点被添加到任何模型的训练集上,那么训练就会停止。所以这看起来是造了三个数据集。这个方法的缺陷是计算占用的资源会特别大。
多视图训练就差不多了解到这里。总结来看,要么就是数据增强或者噪声形成多组增强数据,用阈值造伪标签,要么就是多整几个模型平行训练造伪标签。
目前出现的工作多为统一型方法,其目的在于将当前主要的SSL方法(前面那些思路)统一到一个框架中去,以实现更好的性能。 以下是Match系列:MixMatch, ReMixMatch, FixMatch.
数据增强:对一个batch内有标签数据进行增强,对无标签数据进行 K \ K K次叠加增强,生成 K \ K K个无标签数据的增强样本序列。
标签猜测:给无标签数据造伪标签。还是用在有标签数据集上训练的网络进行预测,只不过这里的预测是对 K \ K K个无标签数据增强的样本序列进行 K \ K K次预测。这 K \ K K个预测肯定都是一个关于各个类别概率的矢量,然后再把这些取平均,得到一个平均类别概率矢量,通过这个平均预测得到伪标签(从后面的叙述来看,这里所谓的标签其实是一个概率分布),所以这个伪标签的值就是所有 K \ K K个增强样本的伪标签。
锐化(Sharpening):为了让模型可产生更高置信度的预测并且最小化输出分布的熵,第二步产生的代理标签(在C个类上的概率分布),需要用类别分布的temperature进行锐化调整。
( y ^ ) k = ( y ^ ) k 1 T ∑ k = 1 C ( y ^ ) k 1 T (\hat{y})_k=\frac{(\hat{y})_k^{\frac{1}{T}}}{\sum_{k=1}^{C}(\hat{y})_k^{\frac{1}{T}}} (y^)k=∑k=1C(y^)kT1(y^)kT1
这里面的 k \ k k是对应类别概率的下标,锐化操作都是用在未标记数据集中。这个 T \ T T是放在了每个概率的右上角的,所以这是一个非线性变换,相当于自变量是 1 / T \ 1/T 1/T的对数函数了。
MixUp:上述的操作最终会形成两个新的增强batch。其中一个batch是有标签样本的增强 L \ L L,另外一个batch是无标签样本及其锐化后的概率分布标签 U \ U U。需要注意的是无标签batch里面的样本是有 K \ K K个增强样本的,所以是原本体量的 K \ K K倍,并且无标签样本集也用这些增强样本替换掉了。最后一步是,混合这两个batch的中的样本,形成一个新的batch W = S h u f f l e ( C o n c a t ( L , U ) ) \ W=Shuffle(Concat(L,U)) W=Shuffle(Concat(L,U))。可以看到还用上了随机shuffle。在这之后还要再切成两截,第一截 W 1 \ W_1 W1和 L \ L L一样长,第二截 W 2 \ W_2 W2和 U \ U U一样长。
然后再使用mixup函数,这是一种数据增强的手段,一开始我还以为我看错了这两个公式,查了原文确实是这样mixup的,当然,数据和标签是同步mixup的,可以查mixup的原文。
L ′ = M i x U p ( L , W 1 ) L'=MixUp(L,W_1) L′=MixUp(L,W1)
U ′ = M i x U p ( U , W 2 ) U'=MixUp(U,W_2) U′=MixUp(U,W2)
构建好了连个数据集后,对于 L ′ \ L' L′数据集,使用CE损失进行监督训练损失,对于 U ′ \ U' U′数据集则使用一致性损失(MSE)。因而损失占两部分:
l o s s = l o s s s + w ⋅ l o s s u loss=loss_s+w\cdot loss_u loss=losss+w⋅lossu
(生成式模型略)