部分内容参考:寻找领域不变量:从生成模型到因果表征
在迁移学习和域自适应中,我们常常需要寻找领域不变的表征(Domain-invariant Representation),这种表示可被认为是学习到各领域之间的共性,并基于此共性进行迁移。而获取这个表征的过程就与深度学习中的“表征学习”联系紧密。生成模型,对比学习和最近流行的因果表征学习都可以视为获取良好的领域不变表征的工具。
生成模型的视角是在模型中引入隐变量(Latent Variable),而学到的隐变量为数据提供了一个隐含表示(Latent Representation)。生成模型描述了隐变量 z z z生成观测数据 x x x的过程:
用概率可以表示为: p θ ( x ) = ∑ z p θ ( x , z ) = ∑ z p θ ( z ) p θ ( x ∣ z ) p_{\theta}(x)=\sum_{z}p_{\theta}(x,z)=\sum_{z}p_{\theta}(z)p_{\theta}(x|z) pθ(x)=z∑pθ(x,z)=z∑pθ(z)pθ(x∣z)其中, p θ ( x ) p_{\theta}(x) pθ(x)表示在生成器模型参数为 θ \theta θ时, x x x出现的概率, p θ ( x , z ) p_{\theta}(x,z) pθ(x,z)则是 z z z和 x x x联合出现的概率, p θ ( x ∣ z ) p_{\theta}(x|z) pθ(x∣z)代表输入数据为 z z z,生成 x x x的概率。积分(求和)项 ∑ z p θ ( z ) p θ ( x ∣ z ) \sum_{z}p_{\theta}(z)p_{\theta}(x|z) ∑zpθ(z)pθ(x∣z)很难计算。
VAE的思想是采用变分近似后验分布 q ϕ ( z ∣ x ) q_{\phi}(z|x) qϕ(z∣x)(对应编码器),数据的生成过程 p θ ( x ∣ z ) p_{\theta}(x|z) pθ(x∣z)视为解码器。如下图所示:
变分自编码器的优化目标为最大化与 x x x关联的变分下界: m a x L V A E ( x ; θ , ϕ ) = − K L ( q ϕ ( z ∣ x ) ∣ ∣ p θ ( z ) ) + E z ∼ q ϕ ( z ∣ x ) ( l o g p θ ( x ∣ z ) ) maxL_{VAE}(x;\theta,\phi)=-KL(q_{\phi}(z|x)||p_{\theta}(z))+E_{z\sim q_{\phi}(z|x)}(logp_{\theta}(x|z)) maxLVAE(x;θ,ϕ)=−KL(qϕ(z∣x)∣∣pθ(z))+Ez∼qϕ(z∣x)(logpθ(x∣z))第一项使近似的 z z z后验分布 q ( z ∣ x ) q(z|x) q(z∣x)和先验分布 p θ ( z ) p_{\theta}(z) pθ(z)(人为设为高斯分布)尽可能接近(这样的目的是使解码器的输入尽可能服从高斯分布,从而使解码器对随机输入也有很好的输出),第二项为解码器的对数似然。
接下来是如何从近似后验分布 q ( z ∣ x ) q(z|x) q(z∣x)采样 z z z,因为 z z z不是由一个函数产生,而是由一个随机采样过程产生(它的输出会随我们每次查询而发生变化),故直接用一个神经网络 z = g ( x ) z=g(x) z=g(x)表示是不合理的,需要一个重参数化技巧: z = g ϕ ( ϵ , x ) = μ + σ ⊙ ϵ z=g_{\phi}(\epsilon,x)=\mu+\sigma\odot\epsilon z=gϕ(ϵ,x)=μ+σ⊙ϵ μ , σ = E n c o d e r ϕ ( x ) \mu,\sigma=Encoder_{\phi}(x) μ,σ=Encoderϕ(x) ϵ ∼ N ( 0 , I ) \epsilon\sim N(\textbf{0},\textbf{I}) ϵ∼N(0,I)这样就能使得 z z z来自随机采样,并通过反向传播训练。
这是在VAE基础上增加了条件信息 c c c(数据 x x x的标签信息):
因此,优化目标表示为: m a x L C V A E ( x , c ; θ , ϕ ) = − K L ( q ϕ ( z ∣ x , c ) ∣ ∣ p θ ( z ∣ c ) ) + E z ∼ q ϕ ( z ∣ x , c ) ( l o g p θ ( x ∣ z , c ) ) maxL_{CVAE}(x,c;\theta,\phi)=-KL(q_{\phi}(z|x,c)||p_{\theta}(z|c))+E_{z\sim q_{\phi}(z|x,c)}(logp_{\theta}(x|z,c)) maxLCVAE(x,c;θ,ϕ)=−KL(qϕ(z∣x,c)∣∣pθ(z∣c))+Ez∼qϕ(z∣x,c)(logpθ(x∣z,c))对于在MNIST训练完的VAE和CVAE,可以可视化隐向量 z ∼ q ϕ ( z ∣ x ) z\sim q_{\phi}(z|x) z∼qϕ(z∣x)和 z ∼ q ϕ ( z ∣ x , c ) z\sim q_{\phi}(z|x,c) z∼qϕ(z∣x,c):
可以看到CVAE的隐空间相比VAE的隐空间没有编码标签信息,而是编码其他关于数据 x x x的分布信息,这可以视为一种解耦的表征学习(disentangled representation learning):将不同标签的分布解耦开,比如0到9的10个分布都归一化到标准正态分布。
就领域自适应任务而言,训练生成模型获得了隐向量之后就已经完成目标,之后可以将隐向量拿到其它领域的任务中去用了。不过有时训练生成模型的最终目的还是为了生成原始数据。接下来对比两者的图像生成效果。
注意观察重参数的式子: z = g ϕ ( ϵ , x ) = μ + σ ⊙ ϵ z=g_{\phi}(\epsilon,x)=\mu+\sigma\odot\epsilon z=gϕ(ϵ,x)=μ+σ⊙ϵ我们首先要明白, z z z并不是服从标准正态分布的,它服从 N ( μ , σ ) N(\mu,\sigma) N(μ,σ)。在引入 K L ( q ϕ ( z ∣ x ) ∣ ∣ p θ ( z ) ) KL(q_{\phi}(z|x)||p_{\theta}(z)) KL(qϕ(z∣x)∣∣pθ(z))之前,VAE的学习目标仅有对数似然,换言之,此时关于编码器是未知的,为了规范化(正则化)编码器,VAE强行将所有 z z z的分布 q ϕ ( z ∣ x ) q_{\phi}(z|x) qϕ(z∣x)对齐到一个固定的分布即 p θ ( z ) p_{\theta}(z) pθ(z)。这样做的好处在于,将距离较远的簇的分布拉近,最终可以表现为所有 z z z的分布都比较集中,这可以确保整体分布的连续性。
因此,对于VAE的隐变量,不仅在整体上呈现标准正态分布,具体的局部簇也符合某个高斯分布(具体的局部簇对应一个category)。
本节内容来自:Federated Learning with Domain Generalization
首先提及联邦学习Federated Learning,联邦学习分为四个步骤:
FL使一组客户端能够在集中服务器的帮助下联合训练机器学习模型。客户端在训练期间不需要向服务器提交本地数据,因此客户端的本地训练数据受到保护。在FL中,分布式客户端独立收集其本地数据,因此每个客户端的数据集可能自然形成不同的源域。实际上,在多个源域上训练的模型在不可见的目标域上可能具有较差的泛化性能。
在FL中, K K K个不同的客户端协同训练一个机器学习模型来完成对象分类任务。每个客户端的数据形成一个不同的源域。该工作的目标是开发一个解决方案来学习一个分类器(在多个源域上),该分类器可以在Unseen domain上也能表现良好。
提取领域不变的表示也可以通过简单的 “特征提取器 + GAN 对抗训练” 来得到。如下图Client1设置了一个生成器根据随机噪声和标签编码(one-hot向量形式)来生成 “伪” 特征,并训练判别器来区分特征提取器得到的特征和"伪" 特征。此外,还采用了随机投影层来使得判别器更难区分这两种特征,使得对抗网络更稳定。
总结上面3个生成模型,我们可以发现其之所以可以学习域不变表征,是因为我们将隐变量的后验分布人为对齐到固定的先验分布,限制了特征的表征空间。
在自监督预训练中,我们要求该过程能够学习出一些对建模 p ( y ∣ x ) p(y|x) p(y∣x)(对应于下游分类任务)同样有用的特征。因为如果 y y y与 x x x的成因相关,则 p ( x ) p(x) p(x)应当和 p ( y ∣ x ) p(y|x) p(y∣x)存在联系,故试图找到变化潜在因素的自监督表示学习会非常有用。自然语言处理中的经典模型 BERT 便是基于自监督学习的思想。
对比学习属于自监督,通过构造anchor样本,正样本,负样本之间的关系来学习表征。对于任意anchor样本 x x x,我们用 x + x^{+} x+和 x − x^{-} x−分别表示正样本和负样本, f ( ⋅ ) f(\cdot) f(⋅)表示要训练的特征提取器。此时,学习目标为限制anchor与负样本的距离远大于与正样本的距离(此处的距离为在表征空间的距离): d ( f ( x ) , f ( x + ) ) ≥ d ( f ( x ) , f ( x − ) ) d(f(x),f(x^{+}))\geq d(f(x),f(x^{-})) d(f(x),f(x+))≥d(f(x),f(x−))其中, d ( ⋅ , ⋅ ) d(\cdot,\cdot) d(⋅,⋅)为距离的度量,常用的是余弦相似度: c o s ( a , b ) = a ⋅ b ∣ ∣ a ∣ ∣ × ∣ ∣ b ∣ ∣ cos(a,b)=\frac{a\cdot b}{||a||\times||b||} cos(a,b)=∣∣a∣∣×∣∣b∣∣a⋅b当 a , b a,b a,b归一化后,余弦相似度等价于向量内积。
此外,互信息也可以作为相似度的度量。在经典的 SimCLR 架构按照如下图所示的图像增强(比如旋转裁剪等)方式产生正样本:
如上图所示,它对每张输入的图片进行两次随机数据增强(如旋转剪裁等)来得到 x i , x j x_{i},x_{j} xi,xj。对于 x i x_{i} xi来说, x j x_{j} xj为其正样本,数据集剩下 N − 1 N-1 N−1个样本为负样本。
对比学习一般也是用来获取 embeddings,然后用于下游的有监督任务中。
依靠正样本和负样本的对比刺激,对比学习可以运用数据增强来捕捉正确的域不变特征。在泛化视角,这相比 “单纯的数据增强+直接监督学习正样本” 显得更加明确。
但我个人认为,对比学习可以学习泛化表征的另一个原因在于,对比学习侧重于学习样本之间的信息,这样做可以将样本均匀地在超球面空间上拉开距离,但是注意,前提是需要包含足够数据,模型才能学会将样本投影到同一个全局的超球面空间,这或许是CLIP这种大模型在大数据支撑下表现出色的原因。
因此,对比学习的两个特点应该是:
本节内容来自:Causality Inspired Representation Learning for Domain Generalization
前面提到在对比学习中可以运用数据增强来捕捉域不变特征,然而这种数据增强的框架也可以从因果表征学习的视角来看。因果推断中的因果不变量同样也可以对应到领域不变的表征。
原始数据 X X X由因果因子 S S S(图像本身的语义)和非因果因子 U U U(图像的风格)混合决定,且只有 S S S能够影响原始数据的类别标签。
我们不能将原始数据量化为 X = f ( S , U ) X=f(S,U) X=f(S,U),因为因果和非因果因子一般不能观测到并且不能被形式化。
我们的任务应该是将因果因子 S S S从原始数据提取出来,而这可以在因果干预 P ( Y ∣ d o ( U ) , S ) P(Y|do(U),S) P(Y∣do(U),S)的帮助下完成。下面将对CIRL进行具体描述。
为了提高泛化能力,人们提出了很多DG方法(Domain generalization),但目前存在一个固有问题,这些努力只是试图弥补OOD数据造成的问题,并对数据和标签之间的统计相关性建模,而没有解释潜在的因果机制。例如,在图像分类任务中,很可能所有长颈鹿都在草地上,表现出高度的统计依赖性,当目标域中的背景变化时,这很容易误导模型做出错误的预测。毕竟,长颈鹿的头部、颈部等特征才构成了长颈鹿。
CIRL引入结构因果模型(SCM)来形式化DG问题,旨在挖掘数据和标签之间的内在因果机制,并获得更好的泛化能力。具体来说,假设category-related信息作为因果因子(causal factors),因果因子与label的关系独立于domain。比如数字识别中的"shape"。而独立于category的信息被认为是非因果因子(non-causal factors),这通常是与domain相关的信息,例如数字识别中的"handwriting style"(笔迹风格)。每个原始数据 X X X由因果因子 S S S和非因果因子 U U U混合而成,只有前者对类别标签 Y Y Y产生因果影响,如下图所示。
我们的目标是从原始输入 X X X中提取因果因子 S S S,然后重建不变的因果机制,其可以基于因果干预 P ( Y ∣ d o ( U ) , S ) P(Y|do(U),S) P(Y∣do(U),S)完成。 d o ( ⋅ ) do(⋅) do(⋅)表示对变量的干预。
CIRL认为,因果因子 S S S应满足三个属性:
如下图a所示,与 U U U的混合使 S S S包含潜在的(underlying)非因果信息,而联合依赖因子分解(dependent factorization)时会使 S S S冗余,进一步导致遗漏一些潜在的因果信息。相反,图b中的因果因子 S S S是满足所有要求的理想因子。受此启发,提出了一种因果激励表示学习(CIRL,Causality Inspired Representation Learning)算法,强制所学习的表示具有图b的性质,然后利用表示的每个维度模拟因果因子的因式分解,从而具有更强的泛化能力。
简而言之,对于每个输入,我们首先利用因果干预模块(causal intervention module),通过生成带有扰动领域相关信息的新数据,将因果因子 S S S与非因果因子 U U U分开。与原始数据相比,生成的数据具有不同的非因果因子 U U U,但具有相同的因果因子 S S S,因此表征被强制保持不变。此外,利用因子分解模块(factorization module),使表征的每个维度联合独立,然后可以用来近似因果因子。此外,为了便于分类,使用对抗掩码模块,该模块反复检测包含相对较少因果信息的维度,并通过掩码器和表征生成器之间的对抗学习,迫使它们包含更多新颖的因果信息。
在开始了解具体方法前,我们需要了解两个原则:
对于原则1,我们将SCM形式化以描述DG问题: X : = f ( S , U , V 1 ) , S ⊥ U ⊥ V 1 X:=f(S,U,V_{1}),S\perp U\perp V_{1} X:=f(S,U,V1),S⊥U⊥V1 Y : = h ( S , V 2 ) = h ( g ( X ) , V 2 ) , V 1 ⊥ V 2 Y:=h(S,V_{2})=h(g(X),V_{2}),V_{1}\perp V_{2} Y:=h(S,V2)=h(g(X),V2),V1⊥V2其中, X , Y X,Y X,Y代表输入图像和对应的label, S S S代表影响 X X X和 Y Y Y的因果因子,比如与类别相关的信息,如数字识别中的"shape",而 U U U表示仅影响 X X X的非因果因子,通常是与领域相关的信息比如"style"。 V 1 V_1 V1、 V 2 V_2 V2是共同独立的无法解释的噪声变量( ⊥ \perp ⊥代表相互独立)。对于 f f f、 h h h、 g g g,它们可以看作是未知的结构函数。因此,对于任何分布 P ( X , Y ) P(X,Y) P(X,Y),如果因果因子 S S S给定,就存在一个条件分布 P ( Y ∣ S ) P(Y|S) P(Y∣S),比如:不变的因果机制(invariant causal mechanism)。基于上述讨论,如果我们能够获得因果因子,那么通过优化 h h h: h ∗ = arg min h E P [ l ( h ( g ( X ) ) , Y ) ] = arg min h E P [ l ( h ( S ) , Y ) ] h^{*}=\arg\min_{h}E_{P}[l(h(g(X)),Y)]=\arg\min_{h}E_{P}[l(h(S),Y)] h∗=arghminEP[l(h(g(X)),Y)]=arghminEP[l(h(S),Y)]其中, l ( ⋅ , ⋅ ) l(\cdot,\cdot) l(⋅,⋅)代表交叉熵损失。
不幸的是,因果因子 S S S并不是事先提供给我们的,我们得到的是原始图像 X X X,而这些图像通常是非结构化的。直接重建因果因子和机制是不切实际的,因为它们是不可观测和不明确的。然而,显而易见的是,因果因子仍然需要遵守某些要求。以前的工作声明,因果因子应该是共同独立的,如原则2所示。
因为 S S S代表所有因果因子的集合 { s 1 , s 2 , . . . , s N } \left\{s_{1},s_{2},...,s_{N}\right\} {s1,s2,...,sN},原则2告诉我们:
因此,我们可以将因果因子的联合分布分解为下列条件分布: P ( s 1 , s 2 , . . . , s N ) = ∏ i = 1 N P ( s i ∣ P A i ) P(s_{1},s_{2},...,s_{N})=\prod_{i=1}^{N}P(s_{i}|PA_{i}) P(s1,s2,...,sN)=i=1∏NP(si∣PAi)因此,我们强调,基于共同原因原则(原则1)中因果变量的定义和ICM原则(原则2)中因果机制的性质,因果因子 S S S应满足三个基本属性:
基于这3个属性,我们建议通过强迫网络的表示具有与因果因子相同的属性来学习因果表征,而不是直接理想化地重构因果因子。
CIRL由3个模块组成:causal intervention module(因果干预模块),causal factorization module(因果因子分解模块)和adversarial mask module(对抗掩码模块),CIRL框架如下图:
我们首先旨在通过因果干预将因果因子 S S S从非因果因子 U U U的混合中分离出来。我们已经知道因果因子 S S S应该对 U U U的干预保持不变,即 P ( S ∣ d o ( U ) ) P(S|do(U)) P(S∣do(U))不变。在DG的相关工作中,我们确实知道一些领域相关信息无法确定输入的类别,可以将其视为非因果因素,并通过一些技术捕获。例如,傅里叶变换有一个众所周知的特性:傅里叶谱的相位分量保留了原始信号的高级语义,而振幅分量包含低级统计信息。因此,我们通过干扰振幅信息而对 U U U进行干预,同时保持相位信息不变。形式上,给定原始输入图像 x o x^{o} xo,其傅里叶变换可以表示为: F ( x o ) = A ( x o ) × e − j × P ( x o ) \mathscr{F}(x^{o})=A(x^{o})\times e^{-j\times P(x^{o})} F(xo)=A(xo)×e−j×P(xo)其中, A ( x o ) , P ( x o ) A(x^{o}),P(x^{o}) A(xo),P(xo)分别表示振幅(amplitude)和相位(phase)分量。傅里叶变换 F ( ⋅ ) \mathscr{F}(\cdot) F(⋅)和其逆变换 F − 1 ( ⋅ ) \mathscr{F}^{-1}(\cdot) F−1(⋅)可以用FFT进行有效计算。然后,我们通过在原始图像 x o x^{o} xo的振幅谱和从任意源域随机采样的图像 ( x ′ ) o (x')^{o} (x′)o之间进行线性插值扰动振幅信息: A ^ ( x o ) = ( 1 − λ ) A ( x o ) + λ A ( ( x ′ ) o ) \widehat{A}(x^{o})=(1-\lambda)A(x^{o})+\lambda A((x')^{o}) A (xo)=(1−λ)A(xo)+λA((x′)o)其中, λ ∼ U ( 0 , η ) \lambda\sim U(0,\eta) λ∼U(0,η), η \eta η控制扰动强度。然后,我们将扰动振幅谱与原始相位分量结合,用傅里叶逆变换生成增强后的图像 x a x^{a} xa: F ( x a ) = A ^ ( x o ) × e − j × P ( x o ) , x a = F − 1 ( F ( x a ) ) \mathscr{F}(x^{a})=\widehat{A}(x^{o})\times e^{-j\times P(x^{o})},x^{a}=\mathscr{F}^{-1}(\mathscr{F}(x^{a})) F(xa)=A (xo)×e−j×P(xo),xa=F−1(F(xa))将CNN模型实现的表示生成器记为 g ^ ( ⋅ ) \widehat{g}(\cdot) g (⋅),表示记为 r = g ^ ( x ) ∈ R 1 × N r=\widehat{g}(x)\in R^{1\times N} r=g (x)∈R1×N,其中 N N N是维数,为了模拟对 U U U的干预保持不变的因果因子,优化 g ^ \widehat{g} g : max g ^ 1 N ∑ i = 1 N C O R ( r ~ i o , r ~ i a ) \max_{\widehat{g}}\frac{1}{N}\sum_{i=1}^{N}COR(\tilde{r}_{i}^{o},\tilde{r}_{i}^{a}) g maxN1i=1∑NCOR(r~io,r~ia)其中, r ~ i o , r ~ i a \tilde{r}_{i}^{o},\tilde{r}_{i}^{a} r~io,r~ia分别代表 R o = [ ( r 1 o ) T , . . . , ( r B o ) T ] T ∈ R B × N R^{o}=[(r_{1}^{o})^{T},...,(r_{B}^{o})^{T}]^{T}\in R^{B\times N} Ro=[(r1o)T,...,(rBo)T]T∈RB×N和 R a = [ ( r 1 a ) T , . . . , ( r B a ) T ] T ∈ R B × N R^{a}=[(r_{1}^{a})^{T},...,(r_{B}^{a})^{T}]^{T}\in R^{B\times N} Ra=[(r1a)T,...,(rBa)T]T∈RB×N的第 i i i列Z-score标准化。 B B B代表batch-size。
r j o = g ^ ( x j o ) r^{o}_{j}=\widehat{g}(x^{o}_{j}) rjo=g (xjo)和 r j a = g ^ ( x j a ) r^{a}_{j}=\widehat{g}(x^{a}_{j}) rja=g (xja), j ∈ { 1 , . . . , B } j\in\left\{1,...,B\right\} j∈{1,...,B}。我们利用COR函数(计算相关性)来衡量干预前后表征的相关性。因此,我们可以通过使其独立于 U U U来实现用表征 R R R模拟因果因子 S S S的第一步。
如原则2,因果因子 s 1 , . . . , s N s_{1},...,s_{N} s1,...,sN的因子分解应该是联合独立的,不需要其他因子的信息,因此,我们应该使表征的任何两个维度相互独立: min g ^ 1 N ( N − 1 ) ∑ i ≠ j C O R ( r ~ i o , r ~ i a ) \min_{\widehat{g}}\frac{1}{N(N-1)}\sum_{i\neq j}COR(\tilde{r}_{i}^{o},\tilde{r}_{i}^{a}) g minN(N−1)1i=j∑COR(r~io,r~ia)为了节省计算成本,我们统一 g ^ \widehat{g} g 的两个优化目标,建立相关性矩阵 C C C(correlation matrix): C i j = < r ~ i o , r ~ i a > ∣ ∣ r ~ i o ∣ ∣ ⋅ ∣ ∣ r ~ i a ∣ ∣ , i , j ∈ 1 , 2 , . . . , N C_{ij}=\frac{<\tilde{r}_{i}^{o},\tilde{r}_{i}^{a}>}{||\tilde{r}_{i}^{o}||\cdot||\tilde{r}_{i}^{a}||},i,j\in1,2,...,N Cij=∣∣r~io∣∣⋅∣∣r~ia∣∣<r~io,r~ia>,i,j∈1,2,...,N其中, < ⋅ , ⋅ > <\cdot,\cdot> <⋅,⋅>表示内积操作。因此,因此, R o R^o Ro和 R a R^a Ra的同一维度可以被视为需要最大化相关性的正对(positive pairs),而不同维度可以被看作需要最小化相关性的负对(negative pairs)。基于此,得到了因子分解损失 L F a c L_{Fac} LFac: min g ^ L F a c = 1 2 ∣ ∣ C − I ∣ ∣ F 2 \min_{\widehat{g}}L_{Fac}=\frac{1}{2}||C-I||^{2}_{F} g minLFac=21∣∣C−I∣∣F2优化上面的目标可以使相关矩阵 C C C的对角元素近似于1,这意味着非因果因子干预前后的表示是不变的。这表明我们可以有效地将因果因子从非因果因子的混合中分离出来。此外,它还使 C C C的非对角线元素接近0,即强制让表征的维度联合独立。因此,通过最小化 L F a c L_{Fac} LFac,我们可以使表征满足理想因果因子的前两个特性。
为了成功完成分类任务 X → Y X→Y X→Y,表征应该是因果充分的,即包含所有支撑信息。最直接的方法是在多个源域中使用监督标签 y y y: L c l s = l ( h ^ ( g ^ ( x o ) ) , y ) + l ( h ^ ( g ^ ( x a ) ) , y ) L_{cls}=l(\widehat{h}(\widehat{g}(x^{o})),y)+l(\widehat{h}(\widehat{g}(x^{a})),y) Lcls=l(h (g (xo)),y)+l(h (g (xa)),y)其中, h ^ \widehat{h} h 是分类器。然而,这种直截了当的方法不能保证我们所学表征的每个维度都是重要的(不能包含足够的潜在因果信息用于分类)。具体来说,可能存在携带相对较少因果信息的劣质维度(inferior dimensions),对分类做出较小贡献。因此,我们应该检测这些维度并强制它们做出更多贡献。由于在因子分解模块的帮助下,维度也需要联合独立,因此,检测到的inferior dimensions被渲染为包含更多新的因果信息,而其他维度没有包含这些信息,这使得整个表示更具因果性。
因此,为了检测inferior dimensions,我们设计一个对抗性掩码模块。我们构建了一个基于神经网络的掩码,用 w ^ \widehat{w} w 表示,以学习每个维度的贡献(输出向量 m ∈ R N m\in R^{N} m∈RN),这些维度对应于较大的比率 κ ∈ ( 0 , 1 ) κ∈(0,1) κ∈(0,1)被视为superior dimensions(优质维度),其余被视为inferior dimensions。
我们将表征乘以获得的掩码 m m m和 1 − m 1−m 1−m,可以分别突出表征的superior dimensions和inferior dimensions。然后,我们将它们输入两个不同的分类器 h ^ 1 , h ^ 2 \widehat{h}_{1},\widehat{h}_{2} h 1,h 2, L c l s L_{cls} Lcls被写成: L c l s s u p = l ( h ^ 1 ( r o ⊙ m o ) , y ) + l ( h ^ 1 ( r a ⊙ m a ) , y ) L_{cls}^{sup}=l(\widehat{h}_{1}(r^{o}\odot m^{o}),y)+l(\widehat{h}_{1}(r^{a}\odot m^{a}),y) Lclssup=l(h 1(ro⊙mo),y)+l(h 1(ra⊙ma),y) L c l s i n f = l ( h ^ 2 ( r o ⊙ ( 1 − m o ) ) , y ) + l ( h ^ 2 ( r a ⊙ ( 1 − m a ) ) , y ) L_{cls}^{inf}=l(\widehat{h}_{2}(r^{o}\odot (1-m^{o})),y)+l(\widehat{h}_{2}(r^{a}\odot (1-m^{a})),y) Lclsinf=l(h 2(ro⊙(1−mo)),y)+l(h 2(ra⊙(1−ma)),y)从上面式子可以看到这是一个对抗的行为,CIRL的最终优化目标为: min g ^ , h ^ 1 , h ^ 2 L c l s s u p + L c l s i n f + L F a c \min_{\widehat{g},\widehat{h}_{1},\widehat{h}_{2}}L_{cls}^{sup}+L_{cls}^{inf}+L_{Fac} g ,h 1,h 2minLclssup+Lclsinf+LFac min w ^ L c l s s u p − L c l s i n f \min_{\widehat{w}}L_{cls}^{sup}-L_{cls}^{inf} w minLclssup−Lclsinf