李宏毅2022机器学习HW11解析

李宏毅2022机器学习HW11解析_第1张图片

准备工作

作业十一是域适应(Domain Adaptation),通过训练真实图片得到分类模型,并将其应用到涂鸦图片上进行分类,来获得更高的精准度。作业过程需要助教代码和数据集,关注本公众号,可获得代码和数据集(包括解析代码,文末有方法)。

提交地址

Kaggle:http://www.kaggle.com/competitions/ml2022-spring-hw11,有想讨论沟通的同学可进QQ群:156013866。以下为作业解析。

Simple Baseline (acc >0.44194)

方法:直接运行助教代码。注意在本地或kaggle上运行时候,需要调整相应的文件名称或者路径。代码提交kaggle得到score:0.52368 。

Medium Baseline (acc >0.64950)

方法:增加epoch+ 改变lamb值。epoch从200增加到800,lamb从0.1变为0.7,提升lamb意味着更注重domainclassifier的表现,让source domain和target domain的表现更一致,不过也不能一味的提升,太大会影响labelpredictor的能力,我试验了下发现0.7是一个不错的值。改动后提交kaggle得到socre:0.70924

Strong Baseline (acc >0.75470)

方法:增加epoch+ 动态调整lamb值。将epoch调整到1000。根据DANN论文,可使用动态调整的lamb值,从0.02动态的调整为1,这样前期可让labelpredictor更准确,后期更注重domainclassifier的表现,我们也使用这种方式。提交kaggle的socre是:0.76310

epochs = 1000
....
for epoch in range(epochs):
    lamb = np.log(1.02 + 1.7*epoch/epochs)
    ....

李宏毅2022机器学习HW11解析_第2张图片

我们看下第0,399,999 epoch的训练效果图。在这些epoch,将5000张source和target图片,输入到feature_extractor模型,每张图片的输出是一个512维度的图片,然后利用t-sne方法降维到2维,最后画出source图片的不同类别分布图和source target分布对比度。可以看出,在第0个epoch,图较乱,到第399个epoch,分类已经比较明显,target的分布和source的分布也基本上保持一致,不过边缘仍然比较模糊,最后第999个epoch,分类已经很明显,target和source的分布基本一致了,达到了DANN模型的目的。

李宏毅2022机器学习HW11解析_第3张图片

李宏毅2022机器学习HW11解析_第4张图片

李宏毅2022机器学习HW11解析_第5张图片

Boss Baseline (acc > 0.80394)

方法:利用DANN模型生成伪标签(pseudo-label)。我们借鉴DIRT的两步训练法,第一步用adversarial的方法训练一个模型,这里我们使用strong baseline得到的模型,所以第一步已经完成,第二步是利用第一步产生的模型,对target图片生成伪标签,有了标签,就可以对target做有监督学习,该方法能充分利用模型的潜在价值。在具体的实现环节,为了保证伪标签的可靠性,设计了一个超参数赋值0.95,所产生的伪标签概率高于该值才被使用,另外为保证训练的稳定性,使用了teacher网络,利用teacher网络生成伪标签,teacher网络的初始值也来自strong base模型,在训练过程中,teacher网络更新比较慢,做法是设计了一个超参数赋值0.9,teacher网络的更新中0.9的权重来自于自己,0.1的权重来自于主干网络,代码变动相对之前比较多,可用文末方式获得。提交kaggle的socre是:0.80948,相信通过调整参数能得到更好的效果。

模型被训练了400个epoch,这里画出了第0,199,399epoch的t-sne图。可以看出,第0个epoch和strong baseline最后的效果差不多,这跟我们网络的初始参数来自于strong baseline相呼应,第199个epoch,图形变动较大,source和target有了明显的区分,不过分类仍然清晰,第399个epoch,更进了一步,source和target的分布不同是因为二者来自不同的domain,即使同属一个类,他们的feature也会有所区分。

李宏毅2022机器学习HW11解析_第6张图片

李宏毅2022机器学习HW11解析_第7张图片

李宏毅2022机器学习HW11解析_第8张图片

作业十一答案获得方式:

  1. 关注微信公众号 “机器学习手艺人”
  2. 后台回复关键词:202211

你可能感兴趣的:(机器学习,人工智能,python,深度学习)