StarGAN多数据集训练

我们进行多数据集训练时,引入了掩码向量mask vector,我们引入一个掩码向量m,允许StarGAN忽略未知的标签。
在这里插入图片描述
在用多个数据集训练时把mask向量添加到生成器中,生成器G忽略未指定的标签(零向量),并关注给定的标签。除了输入标签的维度外,生成器的结构与单个数据集的训练完全相同。
因为训练涉及两个数据集,所以在训练时,判别器每次只针对当前已知的标签,来最小化分类误差。例如,当训练是基于CelebA时,判别器最小化的目标只是与CelebA属性相关的分类误差。通过在CelebA和Fer2013之间交替变换,判别器学到了两个数据集上的所有特征。
StarGAN多数据集训练_第1张图片

def label2onehot(self, labels, dim):
        """将标签索引转换为一个one-hot量。"""
        batch_size = labels.size(0)
        out = torch.zeros(batch_size, dim)
        out[np.arange(batch_size), labels.long()] = 1
        return out
mask_celeba = self.label2onehot(torch.zeros(x_fixed.size(0)), 2).to(self.device)  # 对于celeba我们用torch.zeros
mask_fer = self.label2onehot(torch.ones(x_fixed.size(0)), 2).to(self.device)  # 对于Fer2013我们用torch.ones

if dataset == 'CelebA':
mask = self.label2onehot(torch.zeros(x_real.size(0)), 2)
c_org = torch.cat([c_org, zero, mask], dim=1)
c_trg = torch.cat([c_trg, zero, mask], dim=1)

elif dataset == 'Fer':
mask = self.label2onehot(torch.ones(x_real.size(0)), 2)
c_org = torch.cat([zero, c_org, mask], dim=1)
c_trg = torch.cat([zero, c_trg, mask], dim=1)

训练:
同时使用CelebA和Fer2013训练
bash
$ python main.py–mode=‘train’ --dataset=‘Both’ --cdim=5 --c2dim=8 --imagesize=256–numiters=200000 --numitersdecay=100000
测试:
在celeba上进行表情合成
bash
$ python main.py–mode=‘test’ --dataset=‘Both’ --cdim=5 --c2dim=8 --imagesize=256–testmodel=200000

你可能感兴趣的:(StarGAN多数据集训练)