StarGAN v2: Diverse Image Synthesis for Multiple Domains
GitHub: https://github.com/clovaai/stargan-v2
The model architecture consists of four components:
① Generator ② Mapping network ③ Style encoder ④ Discriminator
Generator: the G network
The generator G translates an input image x into an output image G(x, s) that reflects a domain-specific style code s. The code s is produced either by the mapping network F or by the style encoder E, and is designed to represent the style of a domain y.
Mapping network: the F network
Given a latent code z and a domain y, the mapping network F generates a style code $s = F_y(z)$. F consists of an MLP with multiple output branches, one per domain, so a single network provides style codes for all domains. By randomly sampling latent codes z, F can efficiently learn style representations across every domain.
Style encoder: the E network
Given an image x and its corresponding domain y, the encoder E extracts the style code $s = E_y(x)$; E has the same branched structure as F above. Fed with different reference images, E produces different style codes s.
Discriminator: the D network
The discriminator D is a multi-task discriminator with multiple output branches. Each branch $D_y$ learns a binary classification that decides whether an image x is a real image of domain y or a fake $G(x, s)$ produced by G.
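A minimal sketch of how the four networks interact (shapes follow this post; `nets` is assumed to come from `build_model` in the repo's core/model.py, and the call signatures mirror the training code quoted later):

```python
import torch

batch_size, num_domains, latent_dim = 4, 2, 16

z = torch.randn(batch_size, latent_dim)            # random latent code
y = torch.randint(0, num_domains, (batch_size,))   # target domain labels
x = torch.randn(batch_size, 3, 256, 256)           # input images

s = nets.mapping_network(z, y)         # latent -> style code, [4, 64]
x_fake = nets.generator(x, s)          # images translated with style s
s_rec = nets.style_encoder(x_fake, y)  # style recovered from the output, [4, 64]
logit = nets.discriminator(x_fake, y)  # per-domain real/fake logit, [4]
```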
Adversarial loss
During training, a latent code z and a target domain $\bar y$ are sampled at random, and the mapping network produces the corresponding style code:
$\bar s = F_{\bar y}(z)$
The generator G takes the image x and this style code as input and produces the translated image $G(x, \bar s)$. The adversarial loss is:
$L_{adv} = E_{x,y}[\log D_y(x)] + E_{x,\bar y,z}[\log(1 - D_{\bar y}(G(x,\bar s)))] \qquad (1)$
$D_y$: the discriminator branch for domain y
F: provides the style code $\bar s$ for the target domain $\bar y$
G: takes the input image and the style code and generates a new image
Style reconstruction loss
This drives the style of the translated image to match the injected style code:
$L_{sty} = E_{x,\bar y,z}[\|\bar s - E_{\bar y}(G(x,\bar s))\|_1] \qquad (2)$
The E network extracts the style code, as described above. (The outer E denotes the expectation; $E_{\bar y}$ is the style encoder network.)
Diversity sensitivity loss
To make the generator G produce more diverse images, the distance between images generated with different style codes is pushed to be as large as possible:
$L_{ds} = E_{x,\bar y,z_1,z_2}[\|G(x,\bar s_1) - G(x,\bar s_2)\|_1] \qquad (3)$
where $\bar s_1$ and $\bar s_2$ are produced by F from the latent codes $z_1$ and $z_2$: $\bar s_i = F_{\bar y}(z_i)$ for $i \in \{1, 2\}$.
Cycle consistency loss
This keeps the doubly-translated image close to the original x:
$L_{cyc} = E_{x,y,\bar y,z}[\|x - G(G(x,\bar s), \hat s)\|_1] \qquad (4)$
where $\hat s = E_y(x)$ is the style code that E estimates for the original image x in its original domain y; this loss teaches G to preserve the domain-invariant characteristics of x.
Full objective
The losses above are summed; the diversity term is maximized (hence it enters with a minus sign) while all the other terms are minimized.
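Combining the terms (consistent with the training code below), the full objective is:
$\min_{G,F,E}\max_{D}\; L_{adv} + \lambda_{sty}L_{sty} - \lambda_{ds}L_{ds} + \lambda_{cyc}L_{cyc}$
where each $\lambda$ is a hyperparameter weighting its term.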
Generator implementation
The generator is built from multiple convolutional layers.
Input: an image x of shape [batch_size, 3, 512, 512] and a style code s of shape [batch_size, 64].
Output: x_fake of shape [4, 3, 512, 512] (batch_size = 4).
```python
self.from_rgb = nn.Conv2d(3, dim_in, 3, 1, 1)  # (in_channels, out_channels, kernel_size, stride, padding)
self.encode = nn.ModuleList()                  # downsampling blocks
self.decode = nn.ModuleList()                  # upsampling blocks, modulated by the style code
self.to_rgb = nn.Sequential(
    nn.InstanceNorm2d(dim_in, affine=True),
    nn.LeakyReLU(0.2),
    nn.Conv2d(dim_in, 3, 1, 1, 0))             # back to a 3-channel image
```
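A simplified sketch of how these pieces compose in the generator's forward pass (the official code also threads an optional facial mask through a high-pass filter, omitted here); note the style code s is consumed only by the decoding blocks, through the AdaIN fusion described next:

```python
def forward(self, x, s, masks=None):
    x = self.from_rgb(x)           # [B, 3, H, W] -> [B, dim_in, H, W]
    for block in self.encode:      # downsampling residual blocks (no style)
        x = block(x)
    for block in self.decode:      # upsampling blocks modulated by s via AdaIN
        x = block(x, s)
    return self.to_rgb(x)          # back to a 3-channel image
```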
The image and style code are fused by adaptive instance normalization (AdaIN): a fully-connected layer maps the style code to gamma and beta, which then modulate the instance-normalized feature map as (1 + gamma) * self.norm(x) + beta:
```python
def forward(self, x, s):
    h = self.fc(s)                                 # style code -> 2*C values
    h = h.view(h.size(0), h.size(1), 1, 1)         # reshape for broadcasting over H, W
    gamma, beta = torch.chunk(h, chunks=2, dim=1)  # split into scale and shift
    return (1 + gamma) * self.norm(x) + beta       # modulate the normalized features
```
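For the forward pass above to work, self.fc must output 2 * num_features values (gamma concatenated with beta) and self.norm must be a non-affine instance norm. A matching __init__ sketch:

```python
import torch.nn as nn

class AdaIN(nn.Module):
    def __init__(self, style_dim, num_features):
        super().__init__()
        # affine=False: scale and shift come from the style code, not learned parameters
        self.norm = nn.InstanceNorm2d(num_features, affine=False)
        # style code -> concatenated [gamma; beta]
        self.fc = nn.Linear(style_dim, num_features * 2)
```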
In the training loop the generator is invoked with a real image and a target style code (masks is an optional extra input used by the face models):

```python
x_fake = nets.generator(x_real, s_trg, masks=masks)
```
Mapping network implementation
The mapping network consists of multiple linear layers.
Input: a randomly sampled latent code z of shape [batch_size, latent_dim] and target domains y of shape [batch_size].
After the shared layers, one branch per domain outputs a style code s of shape [batch_size, 64].
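A simplified sketch of the layers the forward pass below relies on (the official MLP is deeper and each head has several layers; the widths here are illustrative):

```python
import torch.nn as nn

class MappingNetwork(nn.Module):
    def __init__(self, latent_dim=16, style_dim=64, num_domains=2):
        super().__init__()
        # shared trunk: a plain MLP over the latent code z
        self.shared = nn.Sequential(
            nn.Linear(latent_dim, 512), nn.ReLU(),
            nn.Linear(512, 512), nn.ReLU())
        # one output head per domain, each emitting a style code
        self.unshared = nn.ModuleList(
            [nn.Linear(512, style_dim) for _ in range(num_domains)])
```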
The forward pass:

```python
def forward(self, z, y):
    h = self.shared(z)             # shared MLP trunk
    out = []
    for layer in self.unshared:    # one head per domain
        out += [layer(h)]
    out = torch.stack(out, dim=1)  # (batch, num_domains, style_dim)
    idx = torch.LongTensor(range(y.size(0))).to(y.device)
    s = out[idx, y]                # pick each sample's target-domain code: (batch, style_dim)
    return s
```
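Usage, with the shapes stated above (illustrative values):

```python
import torch

mapping_network = MappingNetwork()  # the sketch above
z = torch.randn(4, 16)              # random latent codes
y = torch.tensor([0, 1, 1, 0])      # one target domain per sample
s = mapping_network(z, y)           # -> [4, 64]
```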
Style encoder implementation
Input: an image x of shape [batch_size, 3, 256, 256] and its domain y of shape [batch_size].
Output: a style code s of shape [batch_size, 64].
```python
def forward(self, x, y):
    h = self.shared(x)             # shared convolutional backbone
    h = h.view(h.size(0), -1)      # flatten to (batch, features)
    out = []
    for layer in self.unshared:    # one linear head per domain
        out += [layer(h)]
    out = torch.stack(out, dim=1)  # (batch, num_domains, style_dim)
    idx = torch.LongTensor(range(y.size(0))).to(y.device)
    s = out[idx, y]                # (batch, style_dim)
    return s
```
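Note the design shared by F and E: every domain shares one backbone, and the domain label y only selects the matching output head through the out[idx, y] gather, so each additional domain costs a single extra head rather than a whole network.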
Discriminator implementation
1) Model inputs and outputs
Input: images x of shape [4, 3, 256, 256], i.e. [batch_size, channels, H, W], together with domain labels y of length batch_size, e.g. [1, 1, 1, 0].
Output: shape [batch_size], one real/fake logit per sample.
```python
out = self.main(x)               # x: [4, 3, 256, 256] through the conv backbone
out = out.view(out.size(0), -1)  # (batch, num_domains), e.g. [4, 2]
idx = torch.LongTensor(range(y.size(0))).to(y.device)
out = out[idx, y]                # select each sample's branch for domain y
return out                       # shape [4]
```
2) Loss functions
The discriminator is trained to classify real images of their domain as 1 and everything generated by G as 0. Real branch (x_real must have gradients enabled so the R1 penalty below can differentiate through it; the official code also computes loss_reg here, on the real-image output):

```python
x_real.requires_grad_()                  # required by the R1 gradient penalty
out = nets.discriminator(x_real, y_org)
loss_real = adv_loss(out, 1)
loss_reg = r1_reg(out, x_real)           # R1 penalty on real images (defined below)
```
adv_loss is a binary cross-entropy loss on the logits:

```python
def adv_loss(logits, target):
    assert target in [1, 0]
    targets = torch.full_like(logits, fill_value=target)  # all-ones or all-zeros target
    loss = F.binary_cross_entropy_with_logits(logits, targets)
    return loss
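For example, with target 1 a confident positive logit contributes almost nothing while a negative logit is penalized heavily (values illustrative):

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([3.0, -2.0])
targets = torch.full_like(logits, 1.0)
# BCE-with-logits with target 1 equals softplus(-logit) per element
F.binary_cross_entropy_with_logits(logits, targets)  # ≈ (0.049 + 2.127) / 2 ≈ 1.09
```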
Fake branch: images generated by G are pushed toward the fake label:

```python
x_fake = nets.generator(x_real, s_trg, masks=masks)
out = nets.discriminator(x_fake, y_trg)
loss_fake = adv_loss(out, 0)
```
r1_reg is the zero-centered gradient penalty on real images:

```python
def r1_reg(d_out, x_in):
    # zero-centered gradient penalty for real images
    batch_size = x_in.size(0)
    grad_dout = torch.autograd.grad(
        outputs=d_out.sum(), inputs=x_in,
        create_graph=True, retain_graph=True, only_inputs=True
    )[0]
    grad_dout2 = grad_dout.pow(2)
    assert(grad_dout2.size() == x_in.size())
    reg = 0.5 * grad_dout2.view(batch_size, -1).sum(1).mean(0)  # 0.5 * E[||∇D(x)||²]
    return reg
```
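Because r1_reg differentiates d_out with respect to x_in, x_real must have gradients enabled before the real-image forward pass (the requires_grad_() call shown earlier); create_graph=True keeps the graph alive so the penalty itself can be backpropagated into D's weights.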
The total discriminator loss:

```python
loss = loss_real + loss_fake + args.lambda_reg * loss_reg
```
Generator loss functions
1) Adversarial loss
The generator tries to make the discriminator classify its outputs in the target domain as real (target 1). The target style code s_trg comes from a reference image via the style encoder (or from the mapping network):

```python
s_trg = nets.style_encoder(x_ref, y_trg)   # target style; used to generate x_fake
x_fake = nets.generator(x_real, s_trg, masks=masks)
out = nets.discriminator(x_fake, y_trg)
loss_adv = adv_loss(out, 1)
```

2) Style reconstruction loss
The style encoder should recover the injected style code from the generated image:

```python
s_pred = nets.style_encoder(x_fake, y_trg)   # style recovered from the fake image
loss_sty = torch.mean(torch.abs(s_pred - s_trg))
```
3) Diversity sensitivity loss
Style codes from two different latent codes z should yield generated images that are as far apart as possible (this term enters the total loss with a minus sign):

```python
x_fake = nets.generator(x_real, s_trg, masks=masks)
x_fake2 = nets.generator(x_real, s_trg2, masks=masks)
x_fake2 = x_fake2.detach()                 # no gradient through the second branch
loss_ds = torch.mean(torch.abs(x_fake - x_fake2))
```
4) Cycle consistency loss
Translating x_fake back with the original image's style code should reconstruct x_real:

```python
s_org = nets.style_encoder(x_real, y_org)  # style of the original image
x_rec = nets.generator(x_fake, s_org, masks=masks)
loss_cyc = torch.mean(torch.abs(x_rec - x_real))
```
The total generator loss (the diversity term is subtracted because it is maximized):

```python
loss = loss_adv + args.lambda_sty * loss_sty \
    - args.lambda_ds * loss_ds + args.lambda_cyc * loss_cyc
```
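One detail worth noting from the paper: the diversity weight λ_ds is decayed linearly to zero over training, since L_ds alone has no optimum. A sketch of such a schedule (variable names assumed):

```python
# hypothetical linear decay of the diversity weight over ds_iter steps
lambda_ds = initial_lambda_ds * max(0.0, 1.0 - step / ds_iter)
```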