图像生成—使用GANs给出代码示例

文章目录

  • 图像生成简单介绍—使用GANs给出代码示例
    • 1. 什么是生成对抗网络(GANs)
    • 2. 准备数据集
    • 3. 构建生成器和判别器
    • 4. 训练GAN模型
    • 5. 生成新图像
    • 6. 总结

图像生成简单介绍—使用GANs给出代码示例

图像生成是指使用计算机算法生成图像的过程。这些图像可以是真实的照片、绘画、3D渲染或者是完全想象的图像。图像生成技术涵盖了一系列算法,包括基于规则的方法、基于统计学的方法、深度学习等。

基于规则的方法通常是通过手动设计规则来生成图像。例如,计算机图形学中的几何建模就是一种基于规则的方法,通过定义几何形状、光照、材质等参数来生成图像。

基于统计学的方法则是通过对大量图像数据进行分析,学习数据中的规律,然后使用这些规律来生成新的图像。这些方法包括基于纹理的方法、基于样式的方法等。

深度学习方法则是最近几年兴起的一种生成图像的方法,它利用神经网络模型进行训练,以学习输入图像和输出图像之间的映射关系。这些模型包括生成对抗网络(GAN)、变分自编码器(VAE)等,能够生成高质量、逼真的图像。

图像生成技术在许多领域都有应用,例如计算机游戏、电影制作、虚拟现实、视觉特效等。同时,它也在艺术创作、产品设计、医学图像处理等领域得到广泛应用。

图像生成是一种涉及生成新图像样本的技术,通常基于深度学习模型。在这份教程中,我们将介绍如何使用生成对抗网络(GANs)生成图像。

1. 什么是生成对抗网络(GANs)

生成对抗网络(GANs)是一种深度学习技术,由两个独立的神经网络组成:生成器(Generator)和判别器(Discriminator)。生成器的任务是生成与真实图像类似的图像,而判别器的任务是区分生成的图像是否为真实图像。这两个网络相互竞争,生成器试图生成越来越真实的图像,而判别器试图越来越准确地识别生成的图像。

2. 准备数据集

首先,我们需要一个用于训练的图像数据集。这里,我们以CIFAR-10数据集为例进行说明。CIFAR-10包含10个类别的60000张32x32彩色图像。我们将使用PyTorch框架,首先需要安装并导入相应的库:

!pip install torch torchvision

import torch
import torchvision
import torchvision.transforms as transforms

接下来,加载和预处理数据:

transform = transforms.Compose([
    transforms.Resize(64),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)

3. 构建生成器和判别器

接下来,我们需要构建生成器和判别器网络。这里,我们使用卷积层和反卷积层构建网络。生成器的输入是随机噪声,输出是生成的图像;判别器的输入是图像,输出是它判断图像是否为真实图像的概率。

import torch.nn as nn

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.main = nn.Sequential(
            nn.ConvTranspose2d(100, 64 * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(64 * 8),
            nn.ReLU(True),
            nn.ConvTranspose2d(64 * 8, 64 * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64 * 4),
            nn.ReLU(True),
            nn.ConvTranspose2d(64 * 4, 64 * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64 * 2),
            nn.ReLU(True),
            nn.ConvTranspose2d(64 * 2, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.ConvTranspose2d(64, 3, 4, 2, 1, bias=False),
            nn.Tanh()
        )

    def forward(self, x):
        x = self.main(x)
        return x

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            nn.Conv2d(3, 64, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 64 * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64 * 2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64 * 2, 64 * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64 * 4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64 * 4, 64 * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64 * 8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64 * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = self.main(x)
        return x

generator = Generator()
discriminator = Discriminator()

4. 训练GAN模型

为了训练模型,我们需要定义损失函数和优化器。这里我们使用二元交叉熵损失(Binary Cross Entropy Loss)和Adam优化器。

criterion = nn.BCELoss()
optimizer_g = torch.optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_d = torch.optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

现在我们可以开始训练模型。在每个训练循环中,我们首先训练判别器,然后训练生成器。

save_every = 10  # 保存模型的频率,每训练十次保存一次
start_time = time.time()  # 获取开始时间

log_file = open("training_log.txt", "a")

num_epochs = 100
for epoch in range(num_epochs):
    for i, (images, _) in enumerate(trainloader):
        # Train discriminator
        discriminator.zero_grad()
        real_images = images.to(device)
        batch_size = real_images.size(0)
        label = torch.full((batch_size,), 1, device=device, dtype=torch.float)
        output = discriminator(real_images).view(-1)
        errD_real = criterion(output, label)
        errD_real.backward()

        noise = torch.randn(batch_size, 100, 1, 1, device=device)
        fake_images = generator(noise).to(device)
        label.fill_(0.0)
        output = discriminator(fake_images.detach()).view(-1)
        errD_fake = criterion(output, label)
        errD_fake.backward()
        errD = errD_real + errD_fake
        optimizer_d.step()

        # Train generator
        generator.zero_grad()
        label.fill_(1)
        output = discriminator(fake_images).view(-1)
        errG = criterion(output, label)
        errG.backward()
        optimizer_g.step()

    # 保存模型
    if (epoch+1) % save_every == 0:
        generator_name = f'generator_{epoch+1}.pth'
        discriminator_name = f'discriminator_{epoch+1}.pth'
        torch.save(generator, generator_name)
        torch.save(discriminator, discriminator_name)


    # 将输出打印到终端并保存到log文件中
    log_str = f'Epoch [{epoch+1}/{num_epochs}] Loss_D: {errD.item():.4f} Loss_G: {errG.item():.4f} Time: {time.time()-start_time:.2f}s\n'
    print(log_str)
    log_file.write(log_str)
    log_file.flush()  # 刷新缓冲区

log_file.close()

代码开始运行

图像生成—使用GANs给出代码示例_第1张图片

5. 生成新图像

训练完成后,我们可以使用生成器来生成新的图像。

import torch
import torchvision
import numpy as np
import matplotlib.pyplot as plt
import torch.nn as nn

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.main = nn.Sequential(
            nn.ConvTranspose2d(100, 64 * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(64 * 8),
            nn.ReLU(True),
            nn.ConvTranspose2d(64 * 8, 64 * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64 * 4),
            nn.ReLU(True),
            nn.ConvTranspose2d(64 * 4, 64 * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64 * 2),
            nn.ReLU(True),
            nn.ConvTranspose2d(64 * 2, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.ConvTranspose2d(64, 3, 4, 2, 1, bias=False),
            nn.Tanh()
        )

    def forward(self, x):
        x = self.main(x)
        return x

# generator = Generator()
# 加载保存的模型文件,'model/generator_82.pth'填写自己的生成模型文件路径
generator = torch.load('model/generator_82.pth', map_location=device)
generator.to(device)

with torch.no_grad():
   noise = torch.randn(1, 100, 1, 1, device=device)
   generated_image = generator(noise)


def imshow(img):
    img = img / 2 + 0.5  # unnormalize
    np_img = img.cpu().numpy()
    plt.imshow(np.transpose(np_img, (1, 2, 0)))
    plt.show()

imshow(torchvision.utils.make_grid(generated_image.cpu()))

通过上述代码通过加载生成器模型,可以生成图片,这个模型训练的次数一般越多越好,我训练的82次,也就图一乐,虽说图片啥也不是,但是勉强有图片轮廓。
图像生成—使用GANs给出代码示例_第2张图片

如果不调用生成模型,将上述代码修改为如下

generator = Generator()

# 加载保存的模型文件,'model/generator_82.pth'填写自己的生成模型文件路径
# generator = torch.load('model/generator_82.pth', map_location=device)
generator.to(device)

重新运行代码,可得到下面随机噪声的图片,说明我们生成器模型是有点作用的,刚才的图片并不是随机噪声,随机噪声是下面这种图片。

图像生成—使用GANs给出代码示例_第3张图片

6. 总结

在本教程中,我们介绍了如何使用生成对抗网络(GANs)生成图像。我们以 CIFAR-10 数据集为例,构建了生成器和判别器网络,并进行了训练。最后,我们使用训练好的生成器生成了新的图像。

GANs 是一种非常强大的图像生成技术,但训练过程可能具有挑战性。为了获得高质量的生成图像,可能需要调整网络结构、损失函数和训练参数。此外,还有许多 GANs 的变体可供尝试,如 Deep Convolutional GANs(DCGANs)、Wasserstein GANs(WGANs)等。

你可能感兴趣的:(图像重建,计算机视觉,人工智能,图像生成,GANs,图像处理)