不了解域适应(Domain Adaptation)的读者可以先看看这篇简介:https://blog.csdn.net/qq_33278461/article/details/90480525
以下几篇文章主要是用判别方法做域适应,有很多共性。这里主要讲第一篇,后面几篇作为对比参考;最后一篇ADDA是对利用判别方法做域适应的一篇总结性文章,它把这个过程抽象成了一个统一的框架(推荐看一看)。
注:
输入D的分别是真实的图像和生成的图像,其中真实的图像需要归一化到0-1之间。
分割:只针对source数据。
分类:real和fake,逐像素判别。
训练步骤:
先更新D网络(4个逐像素的判别器loss + 1个辅助分割loss);
再更新G网络(2个L1 loss、1个分割loss、2个判别loss);
最后更新C和F网络(2个分割loss、2个判别loss)。
网络输入输出,以及更新步骤的核心代码
代码来自:https://github.com/swamiviv/LSD-seg
# Forward passes: run the source batch and the target batch through
# self.model, self.netG and self.netD. Every output produced here is reused
# by the three loss/update phases that follow.

# Source domain
# self.model returns four values: the segmentation score and three
# intermediate outputs (bound here as fc7, pool4, pool3).
score, fc7, pool4, pool3 = self.model(data_source)
# netG maps the three intermediate outputs to outG_src, which is later
# compared against data_source_forD with an L1 loss.
outG_src = self.netG(fc7, pool4, pool3)
# netD returns two outputs per input: `_s` is scored against domain labels
# and `_c` against segmentation labels in the loss phases below.
outD_src_fake_s, outD_src_fake_c = self.netD(outG_src)  # 4-class output and segmentation output, both per pixel
outD_src_real_s, outD_src_real_c = self.netD(data_source_forD)
# Target domain
# Same pipeline for the target batch; tscore is not used in the visible code.
tscore, tfc7, tpool4, tpool3= self.model(data_target)
outG_tgt = self.netG(tfc7, tpool4, tpool3)
outD_tgt_real_s, outD_tgt_real_c = self.netD(data_target_forD)
outD_tgt_fake_s, outD_tgt_fake_c = self.netD(outG_tgt)
# Updates.
# There are three sets of updates - (1) Discriminator, (2) Generator and (3) F network
# (1) Discriminator updates
# Four per-pixel domain losses (source real/fake, target real/fake), each
# scored against its own domain label, plus one auxiliary segmentation loss
# on the real source image.
lossD_src_real_s = cross_entropy2d(outD_src_real_s, domain_labels_src_real, size_average=self.size_average)  # 4-class domain loss
lossD_src_fake_s = cross_entropy2d(outD_src_fake_s, domain_labels_src_fake, size_average=self.size_average)  # 4-class domain loss
lossD_src_real_c = cross_entropy2d(outD_src_real_c, label_forD, size_average=self.size_average)  # segmentation loss
lossD_tgt_real = cross_entropy2d(outD_tgt_real_s, domain_labels_tgt_real, size_average=self.size_average)  # 4-class domain loss
lossD_tgt_fake = cross_entropy2d(outD_tgt_fake_s, domain_labels_tgt_fake, size_average=self.size_average)  # 4-class domain loss
self.optimD.zero_grad()
lossD = lossD_src_real_s + lossD_src_fake_s + lossD_src_real_c + lossD_tgt_real + lossD_tgt_fake
# Scale the summed loss by the number of items in the source batch.
lossD /= len(data_source)
# retain_graph=True keeps the forward graph alive: the generator and
# F-network phases below backpropagate through the same netD/netG outputs.
lossD.backward(retain_graph=True)
self.optimD.step()
# (2) Generator updates
# The generator is trained with: two adversarial domain losses (its fake
# outputs scored against the corresponding "real" domain labels), one
# segmentation loss on the discriminator's segmentation output for the fake
# source image, and two L1 reconstruction losses.
# NOTE(review): these losses reuse netD outputs computed before
# self.optimD.step() ran. On recent PyTorch versions, backpropagating through
# a graph whose parameters were updated in place can raise an
# "modified by an inplace operation" error - confirm against the PyTorch
# version in use.
self.optimG.zero_grad()
lossG_src_adv_s = cross_entropy2d(outD_src_fake_s, domain_labels_src_real,size_average=self.size_average)  # 4-class domain loss: fake source vs. "source real" label
lossG_src_adv_c = cross_entropy2d(outD_src_fake_c, label_forD,size_average=self.size_average)  # segmentation loss on the fake source image
lossG_tgt_adv_s = cross_entropy2d(outD_tgt_fake_s, domain_labels_tgt_real,size_average=self.size_average)  # 4-class domain loss: fake target vs. "target real" label
# Despite the `_mse` suffix these are L1 losses; they measure reconstruction quality.
lossG_src_mse = F.l1_loss(outG_src,data_source_forD)
lossG_tgt_mse = F.l1_loss(outG_tgt,data_target_forD)
# Fix: the original line was missing the closing parenthesis of the L1 term.
lossG = lossG_src_adv_c + 0.1*(lossG_src_adv_s+ lossG_tgt_adv_s) + self.l1_weight * (lossG_src_mse + lossG_tgt_mse)
# Scale the summed loss by the number of items in the source batch.
lossG /= len(data_source)
# The graph is retained because the F-network phase backpropagates through it again.
lossG.backward(retain_graph=True)
self.optimG.step()
# (3) F network updates
# Losses stepped through self.optim: the supervised segmentation loss on the
# source batch, two adversarial domain losses, and the segmentation loss on
# the discriminator's segmentation output for the fake source image.
# NOTE(review): presumably self.optim optimizes self.model - confirm where
# the optimizer is constructed.
self.optim.zero_grad()
lossC = cross_entropy2d(score, labels_source,size_average=self.size_average)  # segmentation loss, full image
# The domain labels are swapped here: the fake source output is scored
# against the "target real" label and the fake target output against the
# "source real" label.
lossF_src_adv_s = cross_entropy2d(outD_src_fake_s, domain_labels_tgt_real,size_average=self.size_average)  # 4-class domain loss
lossF_tgt_adv_s = cross_entropy2d(outD_tgt_fake_s, domain_labels_src_real,size_average=self.size_average)  # 4-class domain loss
lossF_src_adv_c = cross_entropy2d(outD_src_fake_c, label_forD,size_average=self.size_average)  # segmentation loss, small image
lossF = lossC + self.adv_weight*(lossF_src_adv_s + lossF_tgt_adv_s) + self.c_weight*lossF_src_adv_c
# Scale the summed loss by the number of items in the source batch.
lossF /= len(data_source)
# Last backward pass of the step, so the graph is not retained.
lossF.backward()
self.optim.step()