领域自适应阅读笔记2

Progressive Feature Alignment for Unsupervised Domain Adaptation

来源:CVPR 2019,厦门大学信息科学与工程学院,腾讯AI实验室。

作者:Chaoqi Chen, Weiping Xie等

数据集:Office-31(31类样本,三个域,A,W,D,本文适配6次),ImageCLEF-DA(12类样本,三个域,I,P,C,本文适配6次),MNIST,SVHN,USPS(这三个数据集,MNIST的图片size是28*28,SVHN的图片size是16*16,USPS的图片size是32*32,并且一张图片上通常有多个数字,文中沿袭之前的工作,进行MNIST到SVHN的双向适配,MNIST到USPS的单向适配)

实现框架:Caffe

实验的backbone:AlexNet

损失函数:

整体的优化目标


对齐整体分布的损失函数


对齐原域和目标域每一类样本的损失函数

笔记里我记得比较杂,可能突出不了本文的重点。目前正在探索中,目标是,既突出重点,又能把从该篇论文中得到的关于已学知识的进一步认识给记录下来。

摘要部分

之前有人提出用加入伪标签的方法来进行原域和目标域的类分布(class-level distribution)的对齐,但是这种方法对错误累积(error calculation)非常敏感,因而不能保留跨域的种类一致性(cross-domain category consistency)。

本文提出用PFAN通过探索目标域的类内的多样性来对齐原域和目标域的有辨别力的特征。

 We propose the Progressive Feature Alignment Network (PFAN) to align the discriminative features across domains progressively and effectively, via exploiting the intra-class variation in the target domain.

特别地,本文用特别提出用Easy to-Hard Transfer Strategy (EHTS) 和 Adaptive Prototype Alignment (APA) 来训练模型。

 同时,为了减慢原域分类损失的收敛速度,本文把在softmax函数中加入了一个temperature variate。

Introduction




本文方法的动机

圆圈代表原域样本,三角代表目标域样本,绿色代表数字9,橘色代表数字7,蓝色代表数字1.

用原域样本训练出分类器,对目标域样本进行分类,可以看到,目标域样本(三角形)可分成三种情况:

第一种,作者称为easy samples。

这类目标域样本,由于和原域非常接近,被分类器分对的可能性很大,不需要适配就能给它们分配伪标签。

第二种称为hard samples。图中紫色圆圈圈住的样本。

它们是那些离原域很远的目标域样本,,它们位于分类边界附近,分类器不知道把它们分成哪一类。

第三种称为叫做false-easy samples。图中红色圆圈圈住的样本。

它们属于easy samples,但是分类器给它们分配的伪标签是错误的。虽然分类器把这类样本分错了,但是分类器对自己

迷之自信,认为自己分得是非常正确的,换句话说,分类器对自己的分类结果有很高的confidence。

作者认为,这些false-easy samples会给种类对齐(catagory alignment)带来错误信息,可能会造成错误累积。

作者提出的PFAN网络,主要采取EHTS和APA。EHTS的作用是渐进式地选择那些值得信赖的目标域样本(已经被分配了伪标签),APA的作用是对于原域和目标域中的每一类,对齐它们的原型(prototype).

EHTS和APA是相互作用的,EHTS可以促进APA,APA又可以反过来促进EHTS。

In this paper, we propose a Progressive Feature AlignmentNetwork (PFAN), which largely extends the abilityof prior discriminative representations-based approaches byexplicitly enforcing the category alignment in a progressivemanner. Firstly, an Easy-to-Hard Transfer Strategy(EHTS) progressively selects reliable pseudo-labeled targetsamples with cross-domain similarity measurements. However,the selected samples may include some misclassifiedtarget samples with high confidence. Then, to suppressthe negative influence of falsely-labeled samples, we proposean Adaptive Prototype Alignment (APA) to align thesource and target prototypes for each category. Rather thanbackpropagating the category loss for target samples basedon pseudo-labeled samples, our work statistically align thecross-domain class distributions based on the source samplesand the selected pseudo-labeled target samples。



PFAN的整体结构


接下来是过于详细的解说:

        每一个原域样本通过嵌入函数G(即图2中的特征提取器)后,会得到一个D维的特征向量,本文假定原域和目标域中的样本均有C类,对原域中的每一类样本,计算其经过潜入函数后得到向量的均值,该均值就是文中提到的source prototype,这个均值也是一个D维的向量。故原域中共有C个prototype.。


source prototype的计算公式

对于一个不带标签的目标域样本,文中通过以下方式为其分配伪标签。

          首先,文中定义一个相似性度量函数psi,该函数是一个cosine相似度函数,用于衡量经过特征提取器提取后的目标域向量和原域的某一类prototype之间相似性,原域有C类,共有C个prototype,这样,就计算出了C个结果,最后,在这C个结果中选取psi最大的值对应的类别数作为该目标域样本的标签。


相似性度量函数


       接下来开始选择目标域样本的easy samples,选择的标准是设置一个阈值tau,刚才,我们为每一个目标域样本都计算出了C个psi值,选择最大的psi值和tau比较,若大于tau,则该目标域样本被选中成easy samples。

    由于随着训练的进行,psi的值是逐渐增大的,所以为了控制easy samples的生成速率,文中对于tau值的设定是随着训练的进行而不断变化的,文中给出了tau的计算公式。


tau的计算公式


目标域的easy samples的选择方式

APA通过对齐原域的prototype和被选择出来的目标域样本的prototype来减弱false-easy samples的负面影响以及促进原域和目标域的种类一致性。对齐是通过最小化原域和目标域的prototype之间的欧氏距离来实现的。


原域和目标域的prototype的度量

等式5里面提到的prototype是全局的,也就是说要找出原域中某一类的全部样本,然后再计算prototype,但是实际训练中,是用mini-batch方法训练的,每次只处理batchsize个样本,因而,可以用原域样本中batchsize个样本去算出原域的一个local prototype,再用目标域里通过EHTS选取出的样本来算出目标域的local prototype,然后去对齐这两个prototype。

但是这种方法有缺点,那就是当每一个mini-batch中包含的样本类别不到C时,目标域里的一个false-sample样本就会让计算出的prototype和真实的prototype之间产生很大的差异。

基于上面的问题,本文采用如下方式解决:首先APA根据最初选取的目标域的easy-samples计算出目标域的一个global prototype,然后,在每次迭代过程中,都要去计算C个目标域的local prototype,迭代到当前,总共迭代了I次,一共计算出了I*C个目标域的local prototype,那么每一种类的样本都计算出了I个local prototype,把这I个local prototype取平均值,然后利用这个平均值,通过公式8所表达的那样来得到当前迭代次数下的某一类的global prototype。公式8里首先用前面提到的psi函数来度量当前上一迭代步的global prototype和本次迭代步计算出的平均值之间的相似度,然后利用这个相似度的平方和1减去这个相似度的平方分别做系数,根据这个平均数以及上一迭代步的global prototype来计算新的global prototype。原域样本某一类的global prototype也通过这种方式来计算。


平均值计算公式


第I次迭代时,global prototype的计算方式



整体的算法流程


在这篇论文里,提到了其他运用伪标签的论文,我读过的有Learning Semantic Representations for Unsupervised Domain Adaptation以及Collaborative and Adversarial Network for Unsupervised domain adaptation。这两篇都是用对抗方法的。







关于对之前方法的总结,我认为作者总结得不错的部分,摘抄下来:

Many approaches utilize a distance metric to measurethe domain discrepancy between the source and target domains,such as maximum mean discrepancy (MMD), KLdivergenceor Wasserstein distance [12, 22, 37, 24, 42, 6].Most prior efforts intend to achieve domain alignment bymatching P(Xs) and P(Xt). However, an exact domainlevelalignment does not imply a fine-grained class-to-classoverlap.

你可能感兴趣的:(领域自适应阅读笔记2)