AIGC Notes -- Building a CVAE Model

Contents

1--The CVAE Model

2--Code Example


1--The CVAE Model

A brief introduction:

        A CVAE works like a VAE, except that the model's input fuses the image with a condition. The fused input is mapped by an encoder to the parameters of a Gaussian distribution (a mean and a log-variance); a latent sample is drawn from that distribution, fused with the condition again, and finally passed through a decoder to reconstruct the image.

        Because the input is the fusion of image and condition, the model learns condition-dependent image generation.

        The loss is computed between the source image and its reconstruction; for the derivation of the loss function, see: Variational Autoencoder (VAE)
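The objective is the same ELBO-style loss as the plain VAE, just conditioned on Y. Sketched here for a diagonal-Gaussian posterior (notation is mine; see the linked derivation for the full treatment):

```latex
\mathcal{L}
= \underbrace{\mathbb{E}_{q_\phi(z \mid x, y)}\!\left[\lVert x - \hat{x} \rVert^2\right]}_{\text{reconstruction}}
+ \underbrace{\tfrac{1}{2}\sum_{i=1}^{d}\left( e^{\log\sigma_i^2} + \mu_i^2 - 1 - \log\sigma_i^2 \right)}_{\mathrm{KL}\left[\, q_\phi(z \mid x, y) \,\Vert\, \mathcal{N}(0, I) \,\right]}
```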

2--Code Example

        The CVAE below uses the simplest fusion method (concatenation): the condition Y is concatenated with the input X to form X_given_Y, and likewise Y is concatenated with the latent sample z to form z_given_Y;
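As a quick illustration of the concat fusion (the shapes here are hypothetical: a batch of 4 flattened 784-dim images, one scalar condition per sample):

```python
import torch

X = torch.randn(4, 784)             # a batch of flattened images
y = torch.tensor([0., 1., 2., 3.])  # one scalar condition per sample

# append the condition as one extra feature column -> shape (4, 785)
X_given_Y = torch.cat((X, y.unsqueeze(1)), dim=1)
```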

import torch
import torch.nn as nn

class VAE(nn.Module):
    def __init__(self, in_features, latent_size, y_size=0):
        super(VAE, self).__init__()

        self.latent_size = latent_size

        self.encoder_forward = nn.Sequential( # encoder
            nn.Linear(in_features + y_size, in_features),
            nn.LeakyReLU(),
            nn.Linear(in_features, in_features),
            nn.LeakyReLU(),
            nn.Linear(in_features, self.latent_size * 2)
        )

        self.decoder_forward = nn.Sequential( # decoder
            nn.Linear(self.latent_size + y_size, in_features),
            nn.LeakyReLU(),
            nn.Linear(in_features, in_features),
            nn.LeakyReLU(),
            nn.Linear(in_features, in_features),
            nn.Sigmoid()
        )

    def encoder(self, X): # encode
        out = self.encoder_forward(X) # a single encoder predicts both mean and log-variance
        mu = out[:, :self.latent_size] # first half of the output is the mean
        log_var = out[:, self.latent_size:] # second half is the log-variance
        return mu, log_var

    def decoder(self, z): # decode
        mu_prime = self.decoder_forward(z)
        return mu_prime

    def reparameterization(self, mu, log_var): # reparameterization trick
        epsilon = torch.randn_like(log_var) # epsilon ~ N(0, I)
        z = mu + epsilon * torch.exp(0.5 * log_var) # z = mu + eps * std
        return z

    def loss(self, X, mu_prime, mu, log_var): # reconstruction + KL divergence
        reconstruction_loss = torch.mean(torch.square(X - mu_prime).sum(dim=1))
        # KL[N(mu, sigma^2) || N(0, I)] = 0.5 * sum(exp(log_var) + mu^2 - 1 - log_var)
        latent_loss = torch.mean(0.5 * (log_var.exp() + torch.square(mu) - 1 - log_var).sum(dim=1))
        return reconstruction_loss + latent_loss

    def forward(self, X, *args, **kwargs):
        mu, log_var = self.encoder(X) # encode
        z = self.reparameterization(mu, log_var) # generate z by reparameterization
        mu_prime = self.decoder(z) # decode
        return mu_prime, mu, log_var

class CVAE(VAE):
    def __init__(self, in_features, latent_size, y_size):
        super(CVAE, self).__init__(in_features, latent_size, y_size)

    def forward(self, X, y=None, *args, **kwargs):
        # y is assumed to be a float tensor of shape (batch,), i.e. y_size == 1
        y = y.to(next(self.parameters()).device)
        X_given_Y = torch.cat((X, y.unsqueeze(1)), dim=1)

        mu, log_var = self.encoder(X_given_Y)
        z = self.reparameterization(mu, log_var)
        z_given_Y = torch.cat((z, y.unsqueeze(1)), dim=1)

        mu_prime_given_Y = self.decoder(z_given_Y)
        return mu_prime_given_Y, mu, log_var
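A minimal smoke test of the model, including conditional generation (decode a standard-normal z concatenated with a chosen label). A condensed restatement of the class is included so the snippet runs standalone; the sizes (784-dim inputs as for flattened MNIST, 16-dim latent, scalar condition) are illustrative assumptions, not requirements:

```python
import torch
import torch.nn as nn

# Condensed restatement of the CVAE above (fewer hidden layers), for a standalone run.
class CVAE(nn.Module):
    def __init__(self, in_features, latent_size, y_size):
        super().__init__()
        self.latent_size = latent_size
        self.encoder_forward = nn.Sequential(
            nn.Linear(in_features + y_size, in_features), nn.LeakyReLU(),
            nn.Linear(in_features, latent_size * 2))
        self.decoder_forward = nn.Sequential(
            nn.Linear(latent_size + y_size, in_features), nn.LeakyReLU(),
            nn.Linear(in_features, in_features), nn.Sigmoid())

    def forward(self, X, y):
        X_given_Y = torch.cat((X, y.unsqueeze(1)), dim=1)   # fuse input with condition
        out = self.encoder_forward(X_given_Y)
        mu, log_var = out[:, :self.latent_size], out[:, self.latent_size:]
        z = mu + torch.randn_like(log_var) * torch.exp(0.5 * log_var)  # reparameterize
        z_given_Y = torch.cat((z, y.unsqueeze(1)), dim=1)   # fuse latent with condition
        return self.decoder_forward(z_given_Y), mu, log_var

model = CVAE(in_features=784, latent_size=16, y_size=1)
X = torch.rand(8, 784)                  # batch of flattened "images"
y = torch.randint(0, 10, (8,)).float()  # class labels as a scalar condition

mu_prime, mu, log_var = model(X, y)     # conditional reconstruction

# Conditional generation: sample z ~ N(0, I) and decode it with a chosen label.
z = torch.randn(8, 16)
digit = torch.full((8,), 3.0)           # ask for class "3"
samples = model.decoder_forward(torch.cat((z, digit.unsqueeze(1)), dim=1))
```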


Full code reference: liujf69/VAE
