原创 | 清华开源迁移学习算法库

作者:清华大数据软件团队机器学习组

本文长度为1700字,建议阅读6分钟

本文为你介绍 Trans-Learn 算法库。

Trans-Learn是基于PyTorch实现的一个高效、简洁的迁移学习算法库,目前发布了第一个子库——深度域自适应算法库(DALIB),支持的算法包括:

  • Domain Adversarial Neural Networks (DANN)

  • Deep Adaptation Network (DAN)

  • Joint Adaptation Networks (JAN)

  • Conditional Adversarial Domain Adaptation (CDAN)

  • Maximum Classifier Discrepancy (MCD)

  • Margin Disparity Discrepancy (MDD)

项目地址:

https://github.com/thuml/Transfer-Learning-Library

域自适应背景介绍

目前深度学习模型在部分计算机视觉、自然语言处理任务中已经超过了人类的表现,但是它们的成功依赖于大规模的数据标注。但是实际场景中,标注数据往往是稀缺的。解决标注数据稀缺问题的一个方法是通过计算机模拟生成训练数据,例如用计算机图形学的技术合成训练数据。

 

图表 1 VisDA2017竞赛任务

但是由于训练数据和测试数据不再服从独立同分布,训练得到的深度网络的准确率大打折扣。为了解决上述数据漂移造成的问题,域自适应(Domain Adaptation) 的概念被提出。域自适应的目标是将模型在源域(Source) 学到的知识迁移到目标域(Target)。例如计算机模拟生成训练数据的例子中,合成数据是源域,真实场景的数据是目标域。

 

域自适应有效地缓解了深度学习对于人工标注数据的依赖,受到了学术界和工业界广泛的关注。目前已经被引入到图片分类、图像分割(Segmentation)、目标检测(Object Detection)、机器翻译(Machine Translation) 等众多任务上。吴恩达曾说过:“在监督学习之后,迁移学习将引领下一波机器学习技术商业化浪潮。”随着产品级的机器学习应用进入数据稀缺的领域,监督学习得到的尖端模型性能大打折扣,域自适应变得至关重要。

 

研究现状

深度域自适应方法主要包括以下两大类:

1. 矩匹配。通过最小化分布差异来对齐不同域的特征分布。例如深度适配网络DAN,联合适配网络JAN。

2. 对抗训练域对抗网络DANN是最早的工作,它引入一个领域判别器,鼓励特征提取器学到领域无关的特征。 在DANN的基础上,衍生出了一系列方法,例如条件域对抗网络CDAN,间隔差异散度MDD等。

图表 2 DANN网络架构图

图表 3 MDD网络架构图

上述方法在实验数据上体现了良好的性能。然而目前学术界域自适应方法的开源实现中存在下述问题:

  • 复用性差。域自适应方法和模型架构、数据集耦合在一起,不利于域自适应方法在新的模型、数据集上复用。

  • 稳定性差。部分对抗训练方法随着训练进行,准确率会大幅度下降。

DALIB设计的初衷就是让用户通过少数几行代码,就可以将域自适应算法用在实际项目中,而无需考虑域自适应模块的实现细节。

易用性

DALIB将现有域自适应训练代码中的域自适应损失函数分离出来,按照PyTorch交叉熵损失函数的形式进行封装,方便用户的使用。域自适应损失函数也和模型架构进行了解耦,因此不依赖于具体的分类任务,所以算法库很容易扩展到图片分类以外的分类任务。

 

如下,使用两行代码即可定义一个与任务无关的域对抗损失函数。

 

不同域自适应损失函数中有一些公用的模块,例如所有算法中都用到的分类器模块,对抗训练中用到的梯度翻转模块、域判别器模块,核方法中的核函数模块等。这些公用模块和提供的域自适应损失函数是分离的。因此,在DALIB中,用户可以像搭积木一样,重新定制自己需要的域自适应损失函数。

 

例如,核方法中,用户可以自己定义不同参数的高斯核或者其他核函数,然后传入到多核最大均值差异(MK-MMD)的计算中。

 

目前,所有的模块和损失函数均已提供详细的API说明文档。

https://dalib.readthedocs.io/en/latest/

稳定性

域自适应算法研究领域往往关注方法的创新程度或者理论层面的价值,而忽视了工程实现中的稳定性和可复现性。在复现现有的算法的过程中,出现了部分算法准确率不稳定的问题。通过对数值方面的改进,这些问题都已经得到解决。(具体实现就不在此处展开了。)

 

此外,DALIB几乎在所有任务上,准确率都比原论文汇报准确率高,部分数据集上甚至能高14%。下图分别是Office-31和VisDA-2017上的测试结果。

 

图表 4 Office-31上不同算法的准确率

 

图表 5 VisDA2017上不同算法的准确率

算法库提供了各个算法在Office-31、Office-Home和VisDA-2017上的测试结果,以及所有的测试脚本。我们认为开源该算法库对于这个领域未来的研究工作是具有巨大价值的。

 

未来的工作

域自适应算法子库DALIB下一个版本会支持域自适应算法的不同设定,包括部分域自适应任务(Partial Domain Adaptation)、开放集域自适应任务(Open-set Domain Adaptation)、通用域自适应任务(Universal Domain Adaptation)等。

 

迁移学习算法库Trans-Learn目前还处于初期开发阶段,难免有不完善的地方,欢迎其他研究者提意见。同时迁移学习这个方向也还在不断发展,今后会不断跟进新工作中比较好的算法。

当前版本由龙明盛老师课题组的江俊广和付博同学开发,如果有任何意见和建议,欢迎联系[email protected]

[email protected]

编辑:于腾凯

校对:林亦霖

你可能感兴趣的:(算法,机器学习,人工智能,深度学习,python)