MNIST Handwritten Digit Generation with a GAN

(Reproduced from the Bilibili tutorial by 日月光华)

Principle:

This post uses a simple handwritten-digit generation network to illustrate the basic idea of a GAN. It consists of two parts, a generator and a discriminator. The generator takes normally distributed (Gaussian) noise as input and outputs a 28x28 image. The discriminator takes 28x28 images as input, both real images and images produced by the generator, and outputs a probability. The "adversarial" part lies in the optimization objectives: the generator tries to make the discriminator classify its outputs as real, while the discriminator tries to classify the generated images as fake and the real images as real. If you are interested, you can extend this network by:

  • Adding a variable learning rate

  • Adding convolutional layers

  • Increasing the network depth
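
The first suggestion (a variable learning rate) can be sketched with PyTorch's built-in schedulers. A minimal example, where the model and the `step_size`/`gamma` values are illustrative choices, not taken from the tutorial:

```python
import torch
import torch.nn as nn

# Placeholder model standing in for the generator
model = nn.Linear(100, 256)
g_optim = torch.optim.Adam(model.parameters(), lr=0.0001)
# Halve the learning rate every 50 epochs
scheduler = torch.optim.lr_scheduler.StepLR(g_optim, step_size=50, gamma=0.5)

for epoch in range(100):
    # ... run one epoch of training, then:
    g_optim.step()
    scheduler.step()

print(g_optim.param_groups[0]['lr'])  # 0.0001 * 0.5**2 = 2.5e-05
```

In the full script you would call `scheduler.step()` once per epoch, after the inner `dataloader` loop.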

Enough talk, here is the code, ready to run:


import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import numpy as np
import matplotlib.pyplot as plt
from torchvision import transforms


# Draw a grid of images produced by the generator
def draw_genImg(model, input):
    pred = np.squeeze(model(input).detach().cpu().numpy())
    size = input.shape[0]
    for i in range(size):
        plt.subplot(4, int(size/4), i+1)
        plt.imshow((pred[i]+1)/2, cmap='gray')  # map [-1,1] back to [0,1]
    plt.show()

#generator
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.main = nn.Sequential(
            nn.Linear(100,256),
            nn.LeakyReLU(),
            nn.Linear(256,512),
            nn.LeakyReLU(),
            nn.Linear(512,28*28),
            nn.Tanh()
        )
    def forward(self,x):
        x= self.main(x)
        img = x.view(-1, 28, 28)
        return img

#discriminator
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            nn.Linear(28*28,512),
            nn.LeakyReLU(),
            nn.Linear(512,256),
            nn.LeakyReLU(),
            nn.Linear(256,1),
            nn.Sigmoid()
        )
    def forward(self,x):
        x = x.view(-1,28*28)
        conf = self.main(x)
        return conf

if __name__=="__main__":

    batch_size = 64
    epoch_size = 200
    pred_size  = 16
    device  = 'cuda' if torch.cuda.is_available() else 'cpu'
    test_input = torch.randn(pred_size,100, device =device)
    # data
    transform = transforms.Compose([
        transforms.ToTensor(),            # scales pixels to [0,1]
        transforms.Normalize(0.5, 0.5),   # (x - 0.5) / 0.5 maps [0,1] to [-1,1]
    ])
   
    train_ds   = torchvision.datasets.MNIST('data', train = True, transform = transform, download=True)  #data folder
    dataloader = torch.utils.data.DataLoader(train_ds, batch_size= batch_size, shuffle = True)

    
    gen  = Generator().to(device)
    dis  = Discriminator().to(device)
   
    g_optim = torch.optim.Adam(gen.parameters(), lr = 0.0001)
    d_optim = torch.optim.Adam(dis.parameters(), lr = 0.0001)
    loss_fn = nn.BCELoss()

    D_loss =  []
    G_loss =  []

    for epoch in range(epoch_size):
        d_epoch_loss = 0
        g_epoch_loss = 0
        count = len(dataloader)
        
        for step,(img,_) in enumerate(dataloader):
            img = img.to(device)
            size= img.size(0)
            random_noise  = torch.randn(size, 100, device = device)
            
            # train the generator: make dis classify fakes as real (target 1)
            g_optim.zero_grad()

            fake_out    =  dis(gen(random_noise))
            g_loss      = loss_fn(fake_out, torch.ones_like(fake_out))
            g_loss.backward()
            g_optim.step()


            # train the discriminator: real images -> 1, generated images -> 0
            d_optim.zero_grad()

            real_out    = dis(img)
            d_real_loss = loss_fn(real_out, torch.ones_like(real_out))
            d_real_loss.backward()

            fake_out    =  dis(gen(random_noise).detach())
            d_fake_loss = loss_fn(fake_out, torch.zeros_like(fake_out))
            d_fake_loss.backward()

            d_loss = d_real_loss + d_fake_loss
            d_optim.step()  #over
            
            # accumulate the per-epoch losses as plain floats
            with torch.no_grad():
                g_epoch_loss += g_loss.item()
                d_epoch_loss += d_loss.item()
    
        # after each epoch: log the average losses and, near the end, draw samples
        with torch.no_grad():
            g_epoch_loss /= count
            d_epoch_loss /= count
            G_loss.append(g_epoch_loss)
            D_loss.append(d_epoch_loss)
            print('epoch:', epoch, 'g_epoch_loss:', g_epoch_loss, 'd_epoch_loss:', d_epoch_loss)

            if epoch > epoch_size - 5:
                draw_genImg(gen, test_input)
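
The script collects `G_loss` and `D_loss` but never visualizes them. A small follow-up snippet could plot the two curves; the short placeholder lists below just stand in for the real per-epoch values:

```python
import matplotlib.pyplot as plt

# G_loss and D_loss are assumed to hold one averaged loss per epoch;
# these placeholder values are for illustration only.
G_loss = [1.35, 1.10, 0.95]
D_loss = [0.90, 1.05, 1.15]

epochs = range(1, len(G_loss) + 1)
plt.plot(epochs, G_loss, label='generator loss')
plt.plot(epochs, D_loss, label='discriminator loss')
plt.xlabel('epoch')
plt.ylabel('loss')
plt.legend()
plt.show()
```

In a healthy run the two curves should roughly balance rather than one collapsing to zero.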
        


                


[Generated samples after epoch 1, epoch 30, and epoch 200]

The three images above show the generated results after epoch 1, epoch 30, and epoch 200, respectively.
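
To act on the "add convolutional layers" suggestion, one option is to replace the fully connected discriminator with a small CNN. The layer sizes here are illustrative choices, not from the original tutorial:

```python
import torch
import torch.nn as nn

# Hypothetical convolutional discriminator for 28x28 MNIST images
class ConvDiscriminator(nn.Module):
    def __init__(self):
        super(ConvDiscriminator, self).__init__()
        self.main = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=4, stride=2, padding=1),   # 28x28 -> 14x14
            nn.LeakyReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1),  # 14x14 -> 7x7
            nn.LeakyReLU(),
            nn.Flatten(),
            nn.Linear(64 * 7 * 7, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        # expects image-shaped input rather than flattened vectors
        return self.main(x.view(-1, 1, 28, 28))

out = ConvDiscriminator()(torch.randn(4, 1, 28, 28))
print(out.shape)  # torch.Size([4, 1])
```

The generator would need a matching change (e.g. transposed convolutions) for a full DCGAN-style setup, but this drop-in discriminator already works with the flattened-image training loop above.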
