StarGAN: Unified Generative Adversarial Networks for Multi-Domain Image-to-Image Translation

# Problem

  • existing models are both inefficient and ineffective in multi-domain image-to-image translation tasks
  • they are incapable of jointly training on domains from different datasets

# New Method

  • StarGAN, a novel and scalable approach that can perform image-to-image translation for multiple domains using only a single model
  • Adding a mask vector to the domain label enables joint training between domains of different datasets

## Star Generative Adversarial Networks

### 1. Multi-Domain Image-to-Image Translation

| Notation | Meaning |
| --- | --- |
| $x$ | input image |
| $y$ | output image |
| $c$ | target domain label |
| $c'$ | original domain label |
| $D_{src}(x)$ | probability distribution over sources given by $D$ |
| $D_{cls}(c \mid x)$ | probability distribution over domain labels computed by $D$ |
| $\lambda_{cls}$ | hyper-parameter controlling the relative importance of the domain classification loss |
| $\lambda_{rec}$ | hyper-parameter controlling the relative importance of the reconstruction loss |
| $m$ | mask vector |
| $[\cdot]$ | concatenation |
| $c_i$ | label vector of the $i$-th dataset |
| $\hat{x}$ | sampled uniformly along a straight line between a pair of real and generated images |
| $\lambda_{gp}$ | hyper-parameter controlling the gradient penalty |
  • Goal: train a single generator G that learns mappings among multiple domains
  • G is trained to translate an input image x into an output image y conditioned on the target domain label c: G(x, c) → y
  • The discriminator produces probability distributions over both sources and domain labels, $D : x \rightarrow \{D_{src}(x), D_{cls}(x)\}$, which allows a single discriminator to control multiple domains
#### Adversarial Loss

$$\mathcal{L}_{adv} = \mathbb{E}_{x}[\log D_{src}(x)] + \mathbb{E}_{x,c}[\log(1 - D_{src}(G(x, c)))] \tag{1}$$

Here $D_{src}(x)$ denotes the probability distribution over sources given by $D$. The generator G tries to minimize this objective, while the discriminator D tries to maximize it.
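As a minimal NumPy sketch of Eq. (1), with illustrative toy probabilities standing in for a real discriminator's outputs (`adversarial_loss` is a name introduced here, not from the paper):

```python
import numpy as np

def adversarial_loss(d_real, d_fake):
    """Eq. (1): E_x[log D_src(x)] + E_{x,c}[log(1 - D_src(G(x, c)))].

    d_real -- D_src outputs (real/fake probabilities) on real images x
    d_fake -- D_src outputs on translated images G(x, c)
    """
    return np.mean(np.log(d_real)) + np.mean(np.log(1.0 - d_fake))

# A discriminator that is confident on both real and fake images drives
# this objective toward 0 (its maximum); the generator pushes it down.
loss = adversarial_loss(np.array([0.9, 0.8]), np.array([0.1, 0.2]))
```

Both expectations are approximated by batch means, which is how the loss is estimated in practice.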
#### Domain Classification Loss

  • add an auxiliary classifier on top of D and impose the domain classification loss when optimizing both D and G
  • decompose the objective into two terms: a domain classification loss of
    real images used to optimize D, and a domain classification loss of fake images used to optimize G
$$\mathcal{L}_{cls}^{r} = \mathbb{E}_{x,c'}[-\log D_{cls}(c' \mid x)] \tag{2}$$

$$\mathcal{L}_{cls}^{f} = \mathbb{E}_{x,c}[-\log D_{cls}(c \mid G(x, c))] \tag{3}$$
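Both Eqs. (2) and (3) are negative log-likelihoods under the auxiliary classifier, so a single NumPy helper can sketch them (`domain_cls_loss` and the toy batch below are illustrative, not from the paper's code):

```python
import numpy as np

def domain_cls_loss(probs, labels):
    """Eqs. (2)/(3): E[-log D_cls(c|x)], the negative log-likelihood of
    the true domain label under D's auxiliary classifier.

    For Eq. (2) probs come from real images and labels are the original
    domains c'; for Eq. (3) probs come from G(x, c) and labels are the
    target domains c.
    """
    probs = np.asarray(probs)
    labels = np.asarray(labels)
    return -np.mean(np.log(probs[np.arange(len(labels)), labels]))

# Toy batch of 2 images over 3 domains:
probs = np.array([[0.7, 0.2, 0.1],
                  [0.1, 0.1, 0.8]])
loss = domain_cls_loss(probs, [0, 2])
```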
#### Reconstruction Loss
  • Problem: minimizing the losses in Eqs. (1) and (3) does not guarantee that translated images preserve the content of their input images while changing only the domain-related part of the inputs
  • Method: apply a cycle consistency loss to the generator
$$\mathcal{L}_{rec} = \mathbb{E}_{x,c,c'}[\|x - G(G(x, c), c')\|_1]$$
G takes in the translated image G(x, c) and the original domain label c' as input and tries to reconstruct the original image x. The L1 norm is adopted as the reconstruction loss.
Note that a single generator is used twice: first to translate an original image into an image in the target domain, and then to reconstruct the original image from the translated image.
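The loss itself is just a mean absolute error between the input and the twice-translated image; a minimal NumPy sketch (`reconstruction_loss` is an illustrative name):

```python
import numpy as np

def reconstruction_loss(x, x_rec):
    """L_rec: L1 distance between the input x and the twice-translated
    image x_rec = G(G(x, c), c'), averaged over pixels and batch."""
    return np.mean(np.abs(np.asarray(x) - np.asarray(x_rec)))

# Toy 2x2 "images": a perfect reconstruction gives zero loss.
x = np.array([[0.0, 1.0], [0.5, 0.25]])
perfect = reconstruction_loss(x, x)
shifted = reconstruction_loss(x, x + 0.1)
```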
#### Full Objective
$$\mathcal{L}_{D} = -\mathcal{L}_{adv} + \lambda_{cls}\mathcal{L}_{cls}^{r}$$

$$\mathcal{L}_{G} = \mathcal{L}_{adv} + \lambda_{cls}\mathcal{L}_{cls}^{f} + \lambda_{rec}\mathcal{L}_{rec}$$

We use $\lambda_{cls} = 1$ and $\lambda_{rec} = 10$ in all experiments.
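Assembling the two objectives with these coefficients is then straightforward; in this sketch the individual loss values are placeholder scalars for illustration:

```python
# Hypothetical scalar values standing in for the loss terms above.
L_adv, L_cls_r, L_cls_f, L_rec = -0.3, 0.29, 0.35, 0.12

lambda_cls, lambda_rec = 1.0, 10.0  # coefficients used in the paper

# D maximizes L_adv (hence the minus sign) and classifies real images.
L_D = -L_adv + lambda_cls * L_cls_r
# G minimizes L_adv, fools the classifier, and preserves content.
L_G = L_adv + lambda_cls * L_cls_f + lambda_rec * L_rec
```

Note the sign flip on $\mathcal{L}_{adv}$: both networks are written as minimization problems, with the discriminator minimizing the negated adversarial term.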
### 2. Training with Multiple Datasets

  • Problem: the complete information on the label vector $c'$ is required when reconstructing the input image $x$ from the translated image $G(x, c)$
#### Mask Vector

  • introduce a mask vector $m$ that allows StarGAN to ignore unspecified labels and focus on the explicitly known label provided by a particular dataset
  • use an $n$-dimensional one-hot vector to represent $m$, with $n$ being the number of datasets, and define a unified version of the label as a vector

$$\tilde{c} = [c_1, c_2, \dots, c_n, m]$$

For the remaining $n-1$ unknown label vectors we simply assign zero values.
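A possible NumPy sketch of building $\tilde{c}$; the function name `unified_label` and the dataset sizes in the example (5 CelebA attributes, 8 RaFD expressions) are illustrative assumptions:

```python
import numpy as np

def unified_label(known_labels, dataset_idx, label_dims):
    """Build c~ = [c_1, ..., c_n, m]: the known label vector for the
    current dataset, zero vectors for the other n-1 datasets, and a
    one-hot mask m of length n marking which block is valid."""
    n = len(label_dims)
    parts = [np.zeros(d) for d in label_dims]
    parts[dataset_idx] = np.asarray(known_labels, dtype=float)
    m = np.zeros(n)
    m[dataset_idx] = 1.0
    return np.concatenate(parts + [m])

# Toy example: a CelebA sample (dataset 0) with 5 binary attributes,
# alongside a hypothetical 8-label second dataset.
c_tilde = unified_label([1, 0, 0, 1, 0], dataset_idx=0, label_dims=[5, 8])
```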

#### Training Strategy

  • use the unified domain label $\tilde{c}$ as input to the generator
  • the generator learns to ignore the unspecified labels, which are zero vectors, and focus on the explicitly given label
  • extend the auxiliary classifier of the discriminator to generate probability distributions over labels for all datasets
  • train the model in a multi-task learning setting, where the discriminator minimizes only the classification error associated with the known label
  • Under these settings, by alternating between CelebA and RaFD the discriminator learns all of the discriminative features for both datasets, and the generator learns to control all the labels in both datasets.
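The masked multi-task classification step can be sketched as follows. This is a simplification under stated assumptions: `masked_cls_loss` applies a softmax cross-entropy restricted to the label block of the dataset the minibatch comes from; a real implementation may instead use binary cross-entropy for multi-attribute datasets.

```python
import numpy as np

def masked_cls_loss(logits, labels, dataset_idx, label_dims):
    """D's auxiliary classifier outputs logits over the labels of *all*
    datasets, but the classification loss is computed only over the
    block belonging to the current minibatch's dataset; the remaining
    (unknown) blocks are ignored."""
    start = int(np.sum(label_dims[:dataset_idx]))
    block = logits[:, start:start + label_dims[dataset_idx]]
    # Softmax restricted to the known label block.
    e = np.exp(block - block.max(axis=1, keepdims=True))
    probs = e / e.sum(axis=1, keepdims=True)
    return -np.mean(np.log(probs[np.arange(len(labels)), labels]))

# Toy: 5 + 8 labels; a batch from dataset 1 only touches columns 5..12.
logits = np.zeros((2, 13))
logits[0, 5] = 5.0  # confident about label 0 of dataset 1
logits[1, 7] = 5.0  # confident about label 2 of dataset 1
loss = masked_cls_loss(logits, [0, 2], dataset_idx=1, label_dims=[5, 8])
```

Because the gradient of this loss is zero outside the selected block, alternating minibatches from the two datasets trains each classifier head only on data that actually carries its labels.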

## Implementation
### Improved GAN Training

  • replace Eq. (1) with the Wasserstein GAN objective with gradient penalty, defined as

$$\mathcal{L}_{adv} = \mathbb{E}_{x}[D_{src}(x)] - \mathbb{E}_{x,c}[D_{src}(G(x, c))] - \lambda_{gp}\,\mathbb{E}_{\hat{x}}\big[(\|\nabla_{\hat{x}} D_{src}(\hat{x})\|_2 - 1)^2\big]$$

where $\hat{x}$ is sampled uniformly along a straight line between a pair of real and generated images. We use $\lambda_{gp} = 10$ for all experiments.
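The penalty term can be sketched in plain NumPy for a single sample pair, using a central finite difference in place of autograd (a real implementation would backpropagate through D); `gradient_penalty` and the linear toy critic are illustrative:

```python
import numpy as np

def gradient_penalty(d, real, fake, lambda_gp=10.0, eps=1e-5):
    """lambda_gp * (||grad_{x_hat} D_src(x_hat)||_2 - 1)^2 for one pair
    of samples, with the gradient estimated by finite differences."""
    alpha = np.random.rand()
    x_hat = alpha * real + (1.0 - alpha) * fake  # point on the line between real and fake
    grad = np.zeros_like(x_hat)
    for i in range(x_hat.size):
        step = np.zeros_like(x_hat)
        step.flat[i] = eps
        grad.flat[i] = (d(x_hat + step) - d(x_hat - step)) / (2 * eps)
    return lambda_gp * (np.linalg.norm(grad) - 1.0) ** 2

# For a linear critic d(x) = w.x the gradient is w everywhere, so the
# penalty is lambda_gp * (||w|| - 1)^2 wherever x_hat lands.
w = np.array([3.0, 4.0])  # ||w|| = 5
gp = gradient_penalty(lambda x: x @ w,
                      real=np.array([0.0, 0.0]),
                      fake=np.array([1.0, 1.0]))
```

Penalizing the gradient norm toward 1 enforces the 1-Lipschitz constraint that the Wasserstein objective requires.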

### Network Architecture

  • The generator network is composed of two convolutional layers with a stride of two for downsampling, six residual blocks, and two transposed convolutional layers with a stride of two for upsampling.
  • Instance normalization is used in the generator but no normalization in the discriminator.
  • The discriminator leverages PatchGANs, which classify whether local image patches are real or fake.
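The generator described above can be sketched as a minimal PyTorch skeleton. The channel widths, kernel sizes, the boundary 7×7 convolutions, and the 5-dimensional label are illustrative assumptions, not taken from the text; the label c is assumed to be spatially replicated and concatenated to the input image.

```python
import torch
import torch.nn as nn

def conv_block(in_c, out_c, down=True):
    """Stride-2 conv (downsample) or stride-2 transposed conv (upsample),
    each followed by instance norm and ReLU, as described above."""
    conv = (nn.Conv2d(in_c, out_c, 4, stride=2, padding=1) if down
            else nn.ConvTranspose2d(in_c, out_c, 4, stride=2, padding=1))
    return nn.Sequential(conv, nn.InstanceNorm2d(out_c, affine=True),
                         nn.ReLU(inplace=True))

class ResBlock(nn.Module):
    def __init__(self, c):
        super().__init__()
        self.body = nn.Sequential(
            nn.Conv2d(c, c, 3, padding=1), nn.InstanceNorm2d(c, affine=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(c, c, 3, padding=1), nn.InstanceNorm2d(c, affine=True))

    def forward(self, x):
        return x + self.body(x)

# Skeleton generator: 2 downsampling convs, 6 residual blocks, 2 upsampling
# transposed convs, bracketed by illustrative 7x7 boundary convolutions.
label_dim = 5  # assumed size of the spatially replicated domain label c
G = nn.Sequential(
    nn.Conv2d(3 + label_dim, 64, 7, stride=1, padding=3),
    nn.InstanceNorm2d(64, affine=True), nn.ReLU(inplace=True),
    conv_block(64, 128), conv_block(128, 256),
    *[ResBlock(256) for _ in range(6)],
    conv_block(256, 128, down=False), conv_block(128, 64, down=False),
    nn.Conv2d(64, 3, 7, stride=1, padding=3), nn.Tanh())

x = torch.randn(1, 3 + label_dim, 128, 128)
y = G(x)  # output image keeps the input's spatial size
```

The symmetric stride-2 down/up path means the output resolution matches the input, and the final Tanh bounds pixel values in [-1, 1].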
