目录
1--CVAE模型
2--代码实例
简单介绍:
与VAE类似,只不过模型的输入需要考虑图片和条件(condition)的融合,融合结果通过一个 encoder 映射到标准分布(均值和方差),从映射的标准分布中随机采样一个样本,样本也需要和条件进行融合,最后通过 decoder 重构图片;
由于模型的输入是图片和条件的融合,因此模型学习了基于条件的图片生成;
计算源图片和重构图片之间的损失,具体损失函数的推导可以参考:变分自编码器(VAE)
下面的 CVAE 中,用了最简单的融合方式(concat)将条件 Y 与输入 X 融合形成X_given_Y,同理条件 Y 与 X_given_Y 融合形成 z_given_Y;
import torch
import torch.nn as nn
class VAE(nn.Module):
def __init__(self, in_features, latent_size, y_size=0):
super(VAE, self).__init__()
self.latent_size = latent_size
self.encoder_forward = nn.Sequential( # encoder
nn.Linear(in_features + y_size, in_features),
nn.LeakyReLU(),
nn.Linear(in_features, in_features),
nn.LeakyReLU(),
nn.Linear(in_features, self.latent_size * 2)
)
self.decoder_forward = nn.Sequential( # decoder
nn.Linear(self.latent_size + y_size, in_features),
nn.LeakyReLU(),
nn.Linear(in_features, in_features),
nn.LeakyReLU(),
nn.Linear(in_features, in_features),
nn.Sigmoid()
)
def encoder(self, X): # encode
out = self.encoder_forward(X) # 这里通过一个encoder生成均值和标准差
mu = out[:, :self.latent_size] # 输出的前半部分作为均值
log_var = out[:, self.latent_size:] # 后半部分作为标准差
return mu, log_var
def decoder(self, z): # decode
mu_prime = self.decoder_forward(z)
return mu_prime
def reparameterization(self, mu, log_var): # reparameterization
epsilon = torch.randn_like(log_var)
z = mu + epsilon * torch.sqrt(log_var.exp())
return z
def loss(self, X, mu_prime, mu, log_var): # cal loss
reconstruction_loss = torch.mean(torch.square(X - mu_prime).sum(dim=1))
latent_loss = torch.mean(0.5 * (log_var.exp() + torch.square(mu) - log_var).sum(dim=1))
return reconstruction_loss + latent_loss
def forward(self, X, *args, **kwargs):
mu, log_var = self.encoder(X) # encode
z = self.reparameterization(mu, log_var) # generate z by reparameterization
mu_prime = self.decoder(z) # decode
return mu_prime, mu, log_var
class CVAE(VAE):
def __init__(self, in_features, latent_size, y_size):
super(CVAE, self).__init__(in_features, latent_size, y_size)
def forward(self, X, y = None, *args, **kwargs):
y = y.to(next(self.parameters()).device)
X_given_Y = torch.cat((X, y.unsqueeze(1)), dim = 1)
mu, log_var = self.encoder(X_given_Y)
z = self.reparameterization(mu, log_var)
z_given_Y = torch.cat((z, y.unsqueeze(1)), dim = 1)
mu_prime_given_Y = self.decoder(z_given_Y)
return mu_prime_given_Y, mu, log_var
简单的损失计算代码:
def loss(self, X, mu_prime, mu, log_var): # cal loss
reconstruction_loss = torch.mean(torch.square(X - mu_prime).sum(dim=1))
latent_loss = torch.mean(0.5 * (log_var.exp() + torch.square(mu) - log_var).sum(dim=1))
return reconstruction_loss + latent_loss
完整代码参考:liujf69/VAE