迁移学习论文(五):Learning Semantic Representations for Unsupervised Domain Adaptation论文原理及复现工作

目录

  • 前言
  • 原理阐述
    • 文章介绍
    • 模型结构
      • 模型总述
    • 超参数设置
  • 总结

前言

  • 本文属于我迁移学习专栏里的一篇,该专栏用于记录本人研究生阶段相关迁移学习论文的原理阐述以及复现工作。
  • 本专栏的文章主要内容为解释原理,论文具体的翻译及复现代码在文章的github中。

原理阐述

文章介绍

  • 这篇文章于2018年发表在ICML会议,作者是Shaoan Xie、Zibin Zheng、Liang Chen、Chuan Chen。
  • 这篇文章解决的主要问题是如何利用伪标签来进行域适应。之前的方法都忽略了样本的语义信息,比如之前的算法可能将目标域的背包映射到源域的小汽车附近。 这篇文章最要的贡献就是提出了 moving semantic transfer network 这个网络,简称mstn,其主要是通过对齐源域(有标签)和 目标域(伪标签,网络预测一个标签)相同类别的中心,以学习到样本的语义信息。

模型结构

  • 模型是这样的:
    迁移学习论文(五):Learning Semantic Representations for Unsupervised Domain Adaptation论文原理及复现工作_第1张图片

模型总述

  • 上述模型的G特征提取器和F标签分类器以及D域分类器与DANN中的特征提取器、标签分类器和全局域分类器是一样的,这里不展开研究了。
  • 这个论文有价值的地方在于使用了伪标签,提出了semantic transfer loss,这个论文中的方法其实我也有考虑到过,我是受了DAAN的启发,但DAAN应该是受了该文的启发,因为DAAN是2019年发表的。DAAN中的局部域分类器也是将样本的每个类单独分开计算损失,但是DAAN计算的是域分类损失,而MSTN考虑的是MSE,因为相同类别经过特征提取之后的特征应当是相近的,这对应域适应中的条件概率损失。
  • 但是MSTN考虑到了两个问题,1.每次抽取样本可能会使得某些类别没有抽取到样本,那么就无从计算MSE。2.伪标签可能是不准确的,这样可能导致相反的效果,比如使一个书包的特征和一个汽车的特征进行对齐。
  • MSTN的解决办法非常有意思:
    迁移学习论文(五):Learning Semantic Representations for Unsupervised Domain Adaptation论文原理及复现工作_第2张图片
    对每个类维护一个全局特征 C T k 或 者 C S k C^k_{T}或者C^k_{S} CTkCSk,每次使用 C T k 或 者 C S k C^k_{T}或者C^k_{S} CTkCSk来计算损失, C T k 或 者 C S k C^k_{T}或者C^k_{S} CTkCSk的计算同时考虑当前的 C T k 或 者 C S k C^k_{T}或者C^k_{S} CTkCSk和本次根据样本生成的平均特征。所以就算本次抽取样本中没有某一类的样本,也可以根据该类上一次的 C T k 或 者 C S k C^k_{T}或者C^k_{S} CTkCSk来计算,同时假如有错误的伪标签也因为占比不大所以影响不大。
  • 其实MSTN这种解决办法也是尽可能的削弱错误影响,并没有根本上解决这些问题。

超参数设置

  • 学习率采用衰减,
    迁移学习论文(五):Learning Semantic Representations for Unsupervised Domain Adaptation论文原理及复现工作_第3张图片
    p是迭代次数占总的比例,学习率每次迭代更新一次,
def train(epoch, model, sourceDataLoader, targetDataLoader,DEVICE,args):
    learningRate=args.lr/math.pow((1+10*(epoch-1)/args.epoch),0.75)
  • 损失函数迁移学习论文(五):Learning Semantic Representations for Unsupervised Domain Adaptation论文原理及复现工作_第4张图片
    三项分别是标签分类损失,域分类损失,semantic transfer loss,其中 γ = λ γ=λ γ=λ,λ遵循下面的公式:
    迁移学习论文(五):Learning Semantic Representations for Unsupervised Domain Adaptation论文原理及复现工作_第5张图片
    里面的上图的γ可不是损失函数中的γ,上图的p设置为当前batchid占总的比例,如下代码所示:

    lenSourceDataLoader = len(sourceDataLoader)
    for batch_idx, (sourceData, sourceLabel) in tqdm.tqdm(enumerate(sourceDataLoader),total=lenSourceDataLoader,desc='Train epoch {}'.format(epoch),ncols=80,leave=False):
        p = float(batch_idx + 1 + epoch * lenSourceDataLoader) / args.epoch / lenSourceDataLoader
        alpha = 2. / (1. + np.exp(-10 * p)) - 1
  • CNN 采用的是AlexNet作为基本结构,fc7后面接了一个bottleneck layer(瓶颈层,主要作用是降维)。
  • 鉴别器,我们采用的是RevGard相同的结构:x-》1024-》1024-》2
  • 超参数的设置:θ = 0.7。

总结

  • 该文总体来说提供了一种思路,但是我觉得伪标签的问题其实并没有办法真正解决,会限制该类模型的上限并不会很高。

你可能感兴趣的:(迁移学习论文复现,人工智能,python,深度学习)