领域泛化(Domain Generalization)是机器学习和计算机视觉中的一个重要概念,它指的是模型能够从一个或多个源领域(source domains)学习到的知识或模式,成功地应用到与训练时未见过的目标领域(target domain)上,即使这些领域之间存在分布差异。简单来说,领域泛化就是希望模型能够“举一反三”,不仅限于在特定数据集或特定环境下表现良好,而是能够跨越不同的环境或数据集依然保持稳定的性能。
领域泛化问题并非空穴来风,而是有很强的现实背景。例如,在特定的医疗应用中,由于进行手术这一操作的昂贵和不可重复性,我们无法收集到足够多的手术数据;在老人日常的跌倒检测问题中,真实的跌倒数据无法通过大量实验来收集,更不必说需要收集所有年龄的老人跌倒数据;在跨数据的行为识别场景中,无法收集到所有位置情况下的传感器数据。这些真实的应用启发我们要构建一个具有强泛化能力的模型以便在不同的应用场景中部署。
以PACS数据集为例介绍领域泛化问题。训练集包含若干来自三个领域的数据:简笔画(sketch)、卡通画(cartoon)、以及艺术画(art painting)。领域泛化要求我们只依赖给定的三个领域数据训练出有强泛化能力的模型,以便在未知的领域“如照片(photos)]上具有好的表现图。
领域泛化(Domain generalization)给定M个训练的源领域数据 S t r a i n = S i ∣ i = 1 , … , M S_{train} = {S^i | i = 1,…,M} Strain=Si∣i=1,…,M,其中第i个领域数据被表示为 S i = { ( x j i , y j i ) } j = 1 n i S^i=\{{(x^i_j,y^i_j)}\}^{n_i}_{j=1} Si={(xji,yji)}j=1ni。这些源领域数据分布各不相同: P X Y i ≠ P X Y j y , 1 < i ≠ j ≤ M P^i_{XY} ≠ P^j_{XY}y,1PXYi=PXYjy,1<i=j≤M。领域泛化的目标是从这M个源领域数据中学习一个具有强泛化能力的预测函数h:X→y,使其在一个未知的测试数据 S t e s t S_{test} Stest(即 S t e s t S_{test} Stest在训练过程中不可访问且 P X Y t e s t ≠ P X Y i P^{test}_{XY} ≠ P^i_{XY} PXYtest=PXYi for i ∈ {1,…,M})上具有最小的误差: m i n h E ( x , y ) ∈ S t e s t [ l ( h ( x ) , y ) ] min_hE_(x,y){\in}S_{test}[l(h(x),y)] minhE(x,y)∈Stest[l(h(x),y)]
其中E和l(,)分别是期望和损失函数。
领域泛化方法主要可以分为以下三个大类。
数据操作(Data manipulation):此大类方法专注于对输入的数据进行操作,以此来辅助学习具有泛化能力的表征。
此大类下主要有两类方法:
1、数据增强(Dataaugmentation):数据增强方法利用增强,领域随机和变化对数据进行一定程度的增强;
2、数据生成:数据生成则生成一些辅助样
表示学习(Representation learning):此大类方法在领域泛化中十分流行。
其包含两大类方法:
1、领域不变特征学习(Domain-invariant representa-tion learning):其中,领域不变特征学习通过核学习,对抗学习,显式特征对齐及不变风险最小化等方式学习泛化表征
2、特征解耦(Feature disentanglement ):特征解耦则试图将特征解耦为领域共享和领域特异的特征。
学习策略(Learning strategy):此大类方法包含一些常用的学习范式:
1、集成学习(Ensemble learning ):集成学习利用集成的思想从诸多模型中学习具有强泛化能力的模型;
2、元学习(Meta-learning):以元学习为基础的方法则利用元学习自发地从构造的任务中学习元知识
3、其他范式:其他学习范式则包括自监督和梯度操作等。
机器学习模型的泛化能力通常依赖于训练数据的数量和多样性。给定一个训练数据集,我们可以对数据进行一系列操作来生成额外的训练数据,以此增强模型的泛化能力。这是最简单直接的领域泛化方法。基于数据操作的领域泛化方法的目的是增加训练数据的数量和多样性。此类方法可以被表示为 m i n h E x , y [ l ( h ( x ) , y ) ] + E x ’ , y [ l ( h ( x ) , y ) ] min_hE_{x,y}[l(h(x),y)]+E_{x’,y}[l(h(x),y)] minhEx,y[l(h(x),y)]+Ex’,y[l(h(x),y)]。
其中 x’=mani(x)表示使用一个特定的函数mani(·)进行的数据操作。根据此函数的不同点,我们将数据操作分为两大类:数据增强(data augmentation ) 和数据生成(data generation )。特别地,数据生成中有一类非常重要的方法:基于Mixup 的方法。
数据增强方法是增强机器学习模型泛化能力的最有效方法之一。经典的增强操作包括对输入数据进行反转、旋转、缩放、裁剪、添加噪声等操作。这些操作已被当下的深度学习方法所广泛使用(本书代码部分的dataloader函数基本都使用了一定的数据增强手段)。不失一般性,这些方法也可以用于领域泛化问题中,此时mani(·)函数则表示特定的数据增强操作。
在常用的数据增强之外,领域随机法(Domain Randomization)是一种数据生成的简单有效手段。其目的是为了在训练过程中,通过有限的训练数据尽可能去模拟复杂多变的新数据、新环境,以此使得模型可以对不同的环境和数据具有强鲁棒性。领域随机法主要通过如下的变换来生成数据(常用于图像数据中):
●训练数据的数量和形状
●训练数据的位置和纹理
●照相机的视角和光照
●环境中的光照强度和位置
●添加到数据中的随机噪声的类型和内容
事实上,领域随机法也是深度学习方法常用的数据增强技巧。在领域泛化问题中,这种方法尤其值得注意。因为领域泛化解决的就是用有限的、不同分布的数据去模拟尽可能丰富的使用场景。
基于Mixup的数据生成方法是一种简单且有效的数据生成手段。Mixup通过对任意两对输入数据进行线性插值来生成新的训练数据。线性插值通过一个Beta分布来控制两对样本的权重。
Mixup的核心思想是将两个随机选取的训练样本进行线性插值,生成一个新的样本用于训练。这种方法通过增加训练数据的多样性,使得模型能够更好地泛化到未见过的数据上。Mixup的具体步骤包括:
1、数据准备:准备已经标注好的训练数据集。
2、随机选取样本:从训练数据集中随机选取两个样本X1和X2,以及它们对应的标签y1和y2。
3、生成插值系数:随机生成一个插值系数α,其取值范围通常在[0,1]之间。这个系数可以通过均匀分布U(0,1)或Beta分布Beta*(α1,α2)来生成,其中α1和α2是Beta分布的超参数。
4、样本插值:使用插值系数α对选取的两个样本进行线性插值,生成新的样本X。数学公式表示为:
X = a X 1 + ( 1 − α ) X 2 X=aX_1+(1−α)X_2 X=aX1+(1−α)X2
5、标签插值:同样地,对应的标签也需要进行线性插值,生成新的标签y。数学公式表示为:
y = α y 1 + ( 1 − α ) y 2 y=αy_1+(1−α)y_2 y=αy1+(1−α)y2
注意,对于分类任务,如果标签是one-hot编码的,这里的插值将产生一个新的概率分布作为标签。
6、模型训练:使用新生成的样本X和标签y来训练深度学习模型。
优点:
缺点:
特征学习一直以来是机器学习的研究重点,同样也是进行领域泛化的重要武器。我们将预测函数h分解为h=fog,其中g为一表示学习函数,f则为分类器。因此,表示学习可以被形式化为 m i n f , g E x , y l ( f ( g ( x ) ) , y ) + λ l r e g min_{f,g}E_{x,y}l(f(g(x)),y)+{\lambda}l_{reg} minf,gEx,yl(f(g(x)),y)+λlreg
其中reg表示特定的正则项。入为可调超参数。许多方法的重点便是如何更好地设计表示学习函数g和其正则项 l r e g l_{reg} lreg
在领域不变成分分析(DICA)中,核方法被用于寻找一个特征变换,使得在这个变换空间中,来自不同领域(或称为域)的数据之间的分布差异最小化。这种变换旨在提取出领域间共通的、不随领域变化而变化的特征,即领域不变成分。
在DICA中,核方法可以通过引入核函数来改进传统的线性变换方法,使其能够处理非线性关系。具体而言,可以使用核技巧将原始数据映射到一个高维特征空间,在这个空间中执行DICA算法,以提取出更加复杂的领域不变成分。这种方法增强了DICA处理非线性数据的能力,提高了其在实际应用中的泛化性能。
特征解耦是一种启发式的特征分解方法,它通常具有明确的物理含义或实验背景,并基于特定任务或数据特性进行操作。其目的是分离那些在实验者看来是耦合的数据特征,这些耦合可能由数据的特定属性或生成过程引入。特征解耦的目标在于提高模型的性能或可解释性,例如,在面部识别中分离身份特征与表情特征,或在目标检测任务中分离分类和回归的特征。
具体方法包括但不限于:
1、模型架构设计:通过设计具有特定结构的神经网络模型,如多任务学习网络、生成对抗网络(GAN)等,来实现对特征的分离和解耦。
2、损失函数设计:通过设计合理的损失函数来引导模型学习解耦后的特征表示。例如,使用互信息最大化或最小化等方法来确保解耦后的特征相互独立。
3、数据预处理:在数据输入模型之前,通过预处理步骤(如数据增强、特征选择等)来降低数据间的耦合性,为后续的解耦操作提供基础。
集成学习通过将多个基学习器(如分类器、回归器等)的预测结果进行综合,以获得比单一学习器更准确的预测。常见的集成学习方法包括Bagging、Boosting和Stacking等。这些方法通过不同的方式训练基学习器,并在最终预测时将它们的结果进行融合。
在领域泛化问题中,集成学习方法可以通过以下几种方式应用:
对于每个源域,可以训练一个特定的学习器(如分类器)。这些学习器在各自的领域内具有良好的性能。在测试时,可以将这些领域特定的学习器的预测结果进行融合,以获得对目标域样本的预测。这种方法假设不同源域的知识可以通过集成来泛化到目标域。
除了领域特定的学习器外,还可以引入一个或多个领域无关的学习器。这些学习器不依赖于特定的源域,而是尝试学习一种更通用的表示。在测试时,将领域特定的学习器和领域无关的学习器的预测结果进行融合,可以进一步提高模型的泛化能力。
Stacking是一种特殊的集成学习方法,它使用一个元学习器来融合多个基学习器的预测结果。在领域泛化问题中,可以将多个领域特定的学习器和/或领域无关的学习器的预测结果作为输入,训练一个元学习器来进行最终的预测。这种方法可以进一步挖掘不同学习器之间的互补性,提高预测的准确性和稳定性。
基于元学习的领域泛化方法旨在通过元学习的机制来提高模型在不同领域上的泛化能力。以下是一些主要的方法和思路:
通过元学习,模型可以学习到一个领域不变性的特征表示。这种特征表示能够最小化目标域与多个源域之间的差距,从而提供一个独立于域的表示,使其在新的目标域上也表现良好。
假定任何域都由一个底层的全局共享因子和一个特定于域的组件组成。通过在源域训练期间分解这些组件,可以提取域无关的组件作为一个模型。这个模型可能包含源域与目标域的共同信息,因此有可能在新的目标域上工作。
在元学习框架下,可以设计特定的算法来优化领域泛化性能。例如,MLDG(Meta-Learning Domain Generalization)是一种通过元学习来解决领域泛化问题的算法。它既可以用于监督学习也可以用于强化学习,通过元训练阶段和元测试阶段的交替进行,来优化模型在不同领域上的性能。
元学习使得模型能够在面对新领域时,通过少量的样本或数据快速适应并表现出良好的性能。这种能力对于实际应用中的领域泛化问题尤为重要。
书籍参考:王晋栋《迁移学习导论》