GAN图像生成-pyotrch

生成对抗网络

生成对抗网络(GAN)是一种算法体系结构,它使用两个神经网络,使一个神经网络与另一个神经网络(因此称为“对抗性”)相互对立,以便生成可以传递给真实数据的新的合成数据实例。 它们广泛用于图像生成,视频生成和语音生成。
GAN图像生成-pyotrch_第1张图片
虽然大多数深度生成模型是通过最大化对数可能性或对数可能性的下限来训练的,但是GAN采取了根本不同的方法,不需要推理或明确计算数据可能性。 取而代之的是,有两个模型用于求解极小极大博弈:一个对数据进行采样的生成器和一个将数据分类为真实或生成的鉴别器。理论上,这些模型能够对任意复杂的概率分布进行建模。
简介
生成器:生成新的数据实例
鉴别器:尝试从真实数据集中区分生成的数据或伪造的数据。
判别算法尝试对输入数据进行分类; 也就是说,给定数据实例的特征,它们可以预测该数据所属的标签或类别。 因此,判别算法将特征映射到标签。 他们只关心这种相关性。另一种松散地说,生成算法则相反。 他们尝试预测给定特定标签的特征,而不是预测给定特定特征的标签。 在训练过程中,它们都从头开始,并且生成器通过训练时期学习塑造随机分布。

工作原理
生成网络被馈入的噪声可能以随机分布的形式出现,并从噪声中生成伪造数据。 来自生成器的伪数据被输入到鉴别器。 一旦训练完成,生成器应该能够从噪声中生成真实的数据。 这里有趣的事实是,生成器学会了生成有意义的图像,甚至没有实际看图像。
GAN图像生成-pyotrch_第2张图片
鉴别器或对抗网络充当生成器的对手。 它基本上是分类器或区分器,其功能是区分两个不同类别的数据。 在这里,这些类是真实数据(标记为1),而生成器生成的伪数据(标记为0)。

训练网络
关于训练GAN的重要一点是,永远不要一起训练这两个组件。而是在两个不同的阶段对网络进行训练,第一个阶段用于训练鉴别器并适当地更新权重,并且在下一步中,在禁用鉴别器训练的同时对生成器进行训练。阶段1在训练的第一阶段,将噪声作为随机数据(以分布的形式)发送给生成器。生成器创建一些随机图像,这些图像被提供给鉴别器。鉴别器还从真实图像的数据集中获取输入。鉴别器通过学习或评估输入数据的特征来学习将真实数据与伪数据区分开。鉴别器通过网络反向传播预测结果和实际结果之间的一些概率和差异,并更新鉴别器的权重。请记住,在此阶段中,反向传播会在鉴别符的结尾处停止,并且不会对生成器进行训练或更新。阶段2在此阶段中,直接将生成器生成的图像批处理作为鉴别器的输入。这次没有将真实图像提供给鉴别器。生成器通过欺骗鉴别器来学习,从而输出误报。鉴别器输出的概率根据实际结果进行评估,并且发生器的权重通过反向传播进行更新。请记住,此处在反向传播期间,不应该像以前一样更新和保持鉴别器的权重。

GAN图像生成-pyotrch_第3张图片
简单GAN的损失函数

生成器试图将以下函数最小化,而鉴别器试图将其最大化:
在这里插入图片描述
GAN的应用:

  1. 生成图像数据集的示例
  2. 生成人脸的照片
  3. 生成逼真的照片
  4. 图像到图像翻译
  5. 文本到图像翻译
  6. 语义到图像到照片翻译
  7. 照片到表情符号
  8. 面部老化
  9. 超分辨率
  10. 3D对象生成

简单GAN的问题
在实践中,GAN会遇到很多问题,尤其是在训练期间。 一种常见的故障模式涉及发生器折叠以仅产生单个样本或一小组非常相似的样本。 在这种情况下,生成器将学习使用单个图像或几个图像来欺骗鉴别器,以使其视为真实图像。 另一个问题是训练过程中的生成器和鉴别器振荡,而不是收敛到固定点。 另外,如果一个代理变得比另一个代理功能强大得多,则发送给另一代理的学习信号将变得无用,并且系统将无法学习。 要训练GAN,必须采用许多技巧,一种方法是使用深度卷积生成对抗网络

案例

  1. 导入相应的依赖包
import torch
import torch.nn as nn
import pandas as pd
import numpy as np 
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from torch import autograd
from torch.autograd import Variable
from torchvision.utils import make_grid
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm

  1. 定义常量
num_epochs = 10
n_critic = 5
display_step = 300
  1. 处理数据集
class FashionMNIST(Dataset):
    def __init__(self, transform=None):
        self.transform = transform
        fashion_df = pd.read_csv('./fashion-mnist_train.csv')
        print(fashion_df.shape)
        self.labels = fashion_df.label.values
        self.images = fashion_df.iloc[:, 1:].values.astype('uint8').reshape(-1, 28, 28)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        label = self.labels[idx]
        img = Image.fromarray(self.images[idx])
        if self.transform:
            img = self.transform(img)
        return img, label
transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
])
dataset = FashionMNIST(transform=transform)
data_loader = torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=True)
  1. 定义模型
class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.label_emb = nn.Embedding(10, 10)
        self.model = nn.Sequential(
            nn.Linear(794, 1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.3),
            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.3),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, x, labels):
        x = x.view(x.size(0), 784)
        c = self.label_emb(labels)
        x = torch.cat([x, c], 1)
        out = self.model(x)
        return out.squeeze()
class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.label_emb = nn.Embedding(10, 10)
        self.model = nn.Sequential(
            nn.Linear(110, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(1024, 784),
            nn.Tanh()
        )
    def forward(self, z, labels):
        z = z.view(z.size(0), 100)
        c = self.label_emb(labels)
        x = torch.cat([z, c], 1)
        out = self.model(x)
        return out.view(x.size(0), 28, 28)
generator = Generator().cuda()
discriminator = Discriminator().cuda()
  1. 定义损失函数
criterion = nn.BCELoss()
def generator_train_step(batch_size, discriminator, generator, g_optimizer, criterion):
    g_optimizer.zero_grad()
    z = Variable(torch.randn(batch_size, 100)).cuda()
    fake_labels = Variable(torch.LongTensor(np.random.randint(0, 10, batch_size))).cuda()
    fake_images = generator(z, fake_labels)
    validity = discriminator(fake_images, fake_labels)
    g_loss = criterion(validity, Variable(torch.ones(batch_size)).cuda())
    g_loss.backward()
    g_optimizer.step()
    return g_loss.item()


def discriminator_train_step(batch_size, discriminator, generator, d_optimizer, criterion, real_images, labels):
    d_optimizer.zero_grad()

    # train with real images
    real_validity = discriminator(real_images, labels)
    real_loss = criterion(real_validity, Variable(torch.ones(batch_size)).cuda())
    
    # train with fake images
    z = Variable(torch.randn(batch_size, 100)).cuda()
    fake_labels = Variable(torch.LongTensor(np.random.randint(0, 10, batch_size))).cuda()
    fake_images = generator(z, fake_labels)
    fake_validity = discriminator(fake_images, fake_labels)
    fake_loss = criterion(fake_validity, Variable(torch.zeros(batch_size)).cuda())
    
    d_loss = real_loss + fake_loss
    d_loss.backward()
    d_optimizer.step()
    return d_loss.item()
  1. 定义优化器
d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=1e-4)
g_optimizer = torch.optim.Adam(generator.parameters(), lr=1e-4)

  1. 开始运行
for epoch in range(num_epochs):
    print('Starting epoch {}...'.format(epoch))
    for i, (images, labels) in enumerate(tqdm(data_loader)):
        real_images = Variable(images).cuda()
        labels = Variable(labels).cuda()
        generator.train()
        batch_size = real_images.size(0)
        d_loss = discriminator_train_step(len(real_images), discriminator,
                                          generator, d_optimizer, criterion,
                                          real_images, labels)
        

        g_loss = generator_train_step(batch_size, discriminator, generator, g_optimizer, criterion)

    generator.eval()
    print('g_loss: {}, d_loss: {}'.format(g_loss, d_loss))
    z = Variable(torch.randn(9, 100)).cuda()
    labels = Variable(torch.LongTensor(np.arange(9))).cuda()
    sample_images = generator(z, labels).unsqueeze(1).data.cpu()
    grid = make_grid(sample_images, nrow=3, normalize=True).permute(1,2,0).numpy()
    plt.imshow(grid)
    plt.show()
  1. 可视化
z = Variable(torch.randn(100, 100)).cuda()
labels = Variable(torch.LongTensor([i for _ in range(10) for i in range(10)])).cuda()
sample_images = generator(z, labels).unsqueeze(1).data.cpu()
grid = make_grid(sample_images, nrow=10, normalize=True).permute(1,2,0).numpy()
fig, ax = plt.subplots(figsize=(15,15))
ax.imshow(grid)
_ = plt.yticks([])
_ = plt.xticks(np.arange(15, 300, 30), ['T-Shirt', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot'], rotation=45, fontsize=20)

经过30轮后:
GAN图像生成-pyotrch_第4张图片
GAN图像生成-pyotrch_第5张图片
GAN图像生成-pyotrch_第6张图片

你可能感兴趣的:(深度学习,深度学习,pytorch)