Pytorch之经典神经网络Generative Model(四) —— DCGAN (MNIST)

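DCGAN 由 Radford 等人于 2015 年提出（论文 *Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks*）。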

 

DCGAN —— 深度卷积生成对抗网络

DCGAN 的思路非常简单：把生成网络（Generator）和判别网络（Discriminator）都换成卷积网络的形式。

DCGAN 是一个比较基础的模型：它在一定程度上改善了 GAN 的训练效果，但并没有从根本上解决 GAN 训练不稳定的问题，只是一个治标不治本的架构。

 

Discriminator

卷积判别网络（Discriminator）就是一个普通的卷积网络。卷积层不加 padding，因此 28×28 的输入经过两次“5×5 卷积 + 2×2 池化”后尺寸变为 28→24→12→8→4，得到 4×4×64 = 1024 维特征，再送入全连接层。结构如下：

  • 32 Filters, 5x5, Stride 1, Leaky ReLU(alpha=0.01)
  • Max Pool 2x2, Stride 2
  • 64 Filters, 5x5, Stride 1, Leaky ReLU(alpha=0.01)
  • Max Pool 2x2, Stride 2
  • Fully Connected size 4 x 4 x 64, Leaky ReLU(alpha=0.01)
  • Fully Connected size 1

 

Generator

卷积生成网络（Generator）需要把一个低维的噪声向量变换成一张图片（两次步长为 2 的转置卷积把 7×7 放大到 28×28），结构如下：

  • Fully connected of size 1024, ReLU
  • BatchNorm
  • Fully connected of size 7 x 7 x 128, ReLU
  • BatchNorm
  • Reshape into Image Tensor
  • 64 conv2d^T filters of 4x4, stride 2, padding 1, ReLU
  • BatchNorm
  • 1 conv2d^T filter of 4x4, stride 2, padding 1, TanH

 

import torch
from torch import nn
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from visdom import Visdom
 
NOISE_DIM = 96  # length of the noise vector fed to the generator (generator's default noise_dim)
# NOTE: train_gan's `noise_size=96` default is a separate literal; keep the two in sync.
batch_size = 128  # mini-batch size used by both the train and val DataLoaders below
 
def show_images(images):  # plotting helper
    """Draw a square grid of square images onto a new matplotlib figure.

    Every entry of ``images`` is flattened to one row; the grid side is
    ceil(sqrt(number of images)) and each image is reshaped to a square whose
    side is ceil(sqrt(pixels per image)) before being drawn with ``plt.imshow``.
    The figure is neither shown nor closed here, so the caller can
    ``plt.savefig`` / ``plt.close`` it.

    Args:
        images: array-like whose first axis indexes images.

    Returns:
        None.
    """
    flat = np.reshape(images, [images.shape[0], -1])
    n_imgs, n_pixels = flat.shape
    grid_side = int(np.ceil(np.sqrt(n_imgs)))
    img_side = int(np.ceil(np.sqrt(n_pixels)))

    plt.figure(figsize=(grid_side, grid_side))
    grid = gridspec.GridSpec(grid_side, grid_side)
    grid.update(wspace=0.05, hspace=0.05)

    for cell, pixels in enumerate(flat):
        axis = plt.subplot(grid[cell])
        plt.axis('off')
        # Hide tick labels and keep pixels square so the grid looks uniform.
        axis.set_xticklabels([])
        axis.set_yticklabels([])
        axis.set_aspect('equal')
        plt.imshow(pixels.reshape([img_side, img_side]))
 
 
class generator(nn.Module):
    """DCGAN generator: maps a batch of noise vectors to 1x28x28 images in [-1, 1].

    Layout (submodule order is kept so state_dict keys stay ``fc.*`` / ``conv.*``):
        fc:   Linear(noise_dim, 1024) -> ReLU -> BatchNorm1d(1024)
              Linear(1024, 7*7*128)   -> ReLU -> BatchNorm1d(7*7*128)
        view: (N, 7*7*128) -> (N, 128, 7, 7)
        conv: ConvTranspose2d(128, 64, 4, stride 2, pad 1) -> ReLU -> BatchNorm2d(64)  # 7 -> 14
              ConvTranspose2d(64, 1, 4, stride 2, pad 1) -> Tanh                     # 14 -> 28
    """

    def __init__(self, noise_dim=NOISE_DIM):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(noise_dim, 1024),
            nn.ReLU(True),
            nn.BatchNorm1d(1024),
            nn.Linear(1024, 7 * 7 * 128),
            nn.ReLU(True),
            nn.BatchNorm1d(7 * 7 * 128),
        )
        self.conv = nn.Sequential(
            nn.ConvTranspose2d(128, 64, 4, 2, padding=1),
            nn.ReLU(True),
            nn.BatchNorm2d(64),
            nn.ConvTranspose2d(64, 1, 4, 2, padding=1),
            nn.Tanh(),  # squashes pixels into [-1, 1], the range preprocess_img targets
        )

    def forward(self, x):
        """Return images of shape (N, 1, 28, 28) for noise ``x`` of shape (N, noise_dim)."""
        hidden = self.fc(x)
        feature_map = hidden.view(hidden.size(0), 128, 7, 7)  # 128 channels, 7x7 spatial
        return self.conv(feature_map)
 
class discriminator(nn.Module):
    """DCGAN discriminator: one raw (pre-sigmoid) logit per 1x28x28 image.

    conv: Conv2d(1, 32, 5)  -> LeakyReLU(0.01) -> MaxPool 2x2   # 28 -> 24 -> 12
          Conv2d(32, 64, 5) -> LeakyReLU(0.01) -> MaxPool 2x2   # 12 -> 8 -> 4
    fc:   Linear(64*4*4 = 1024, 1024) -> LeakyReLU(0.01) -> Linear(1024, 1)

    No padding is used, so the 1024-wide flatten only fits 28x28 inputs.
    No sigmoid at the end: the logits are scored with BCE-with-logits.
    """

    def __init__(self):
        super().__init__()
        slope = 0.01
        self.conv = nn.Sequential(
            nn.Conv2d(1, 32, 5, 1),
            nn.LeakyReLU(slope),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(32, 64, 5, 1),
            nn.LeakyReLU(slope),
            nn.MaxPool2d(2, 2),
        )
        flat_features = 64 * 4 * 4
        self.fc = nn.Sequential(
            nn.Linear(flat_features, 1024),
            nn.LeakyReLU(slope),
            nn.Linear(1024, 1),
        )

    def forward(self, x):
        """Return logits of shape (N, 1) for images ``x`` of shape (N, 1, 28, 28)."""
        features = self.conv(x)
        flattened = features.view(features.shape[0], -1)
        return self.fc(flattened)
 
 
def discriminator_loss(logits_real, logits_fake):  # discriminator loss
    """Binary cross-entropy loss for the discriminator.

    Pushes the logits of real images toward label 1 and the logits of generated
    images toward label 0; the two mean BCE-with-logits terms are summed.

    Args:
        logits_real: raw (pre-sigmoid) discriminator scores for real images.
        logits_fake: raw discriminator scores for generated images.

    Returns:
        Scalar loss tensor on the same device as the inputs.
    """
    # ones_like / zeros_like give targets matching each input's own shape, dtype
    # and device, so this works on CPU as well as GPU and when the real and fake
    # batches differ in size (rather than hard-coding .cuda() and sizing both
    # label tensors from the real batch).
    true_labels = torch.ones_like(logits_real)  # all 1
    false_labels = torch.zeros_like(logits_fake)  # all 0

    # Same computation as nn.BCEWithLogitsLoss() with its default mean reduction:
    # how far logits_real are from all-1 plus how far logits_fake are from all-0.
    bce = nn.functional.binary_cross_entropy_with_logits
    loss = bce(logits_real, true_labels) + bce(logits_fake, false_labels)
    return loss
 
def generator_loss(logits_fake):  # generator loss
    """Non-saturating BCE loss for the generator.

    The generator wants the discriminator to call its samples real, so the
    logits of generated images are scored against all-1 targets.

    Args:
        logits_fake: raw (pre-sigmoid) discriminator scores for generated images.

    Returns:
        Scalar loss tensor on the same device as the input.
    """
    # ones_like matches the input's shape, dtype and device (no hard-coded .cuda()).
    true_labels = torch.ones_like(logits_fake)  # all 1

    # Same computation as nn.BCEWithLogitsLoss() with its default mean reduction.
    loss = nn.functional.binary_cross_entropy_with_logits(logits_fake, true_labels)
    return loss
 
 
def train_gan(D_net, G_net, D_optimizer, G_optimizer, discriminator_loss, generator_loss, show_every=10,
                noise_size=96, num_epochs=10):
    """Alternately train the discriminator and the generator.

    For every batch of the module-level ``train_data`` loader, this runs one
    discriminator step on real images plus a batch of generated images. It then
    runs one generator step on the same noise batch.

    Every ``show_every`` iterations it prints the losses and the best losses
    seen so far, and saves a grid of 16 generated samples to
    ``plt_img/<iter>.png``. Every iteration it appends both losses to the
    module-level ``viz`` Visdom windows. After every epoch it saves both
    networks and both optimizers to ``checkpoints/ckpt_<epoch>.pth``.

    Args:
        D_net: discriminator module.
        G_net: generator module.
        D_optimizer: optimizer over D_net's parameters.
        G_optimizer: optimizer over G_net's parameters.
        discriminator_loss: callable(logits_real, logits_fake) -> scalar loss.
        generator_loss: callable(logits_fake) -> scalar loss.
        show_every: period, in iterations, of printing and image dumps.
        noise_size: length of each noise vector; must match G_net's input size.
        num_epochs: number of passes over ``train_data``.
    """
    import os

    # Follow the device the discriminator lives on instead of hard-coding .cuda();
    # identical behavior when the networks are on the GPU.
    device = next(D_net.parameters()).device
    # savefig / torch.save below fail if their output directories are missing.
    os.makedirs('plt_img', exist_ok=True)
    os.makedirs('checkpoints', exist_ok=True)

    iter_count = 0
    min_D_loss = 1000.0
    min_G_loss = 1000.0
    min_D_iter = 0
    min_G_iter = 0
    for epoch in range(num_epochs):
        for real_batch, _ in train_data:
            batchsz = real_batch.shape[0]

            # Discriminator step -------------------------------------------
            # Images are not flattened: the convolutional discriminator takes them as-is.
            real_img = real_batch.to(device)  # real data
            logits_real = D_net(real_img)  # discriminator scores for real data

            sample_noise = (torch.rand(batchsz, noise_size) - 0.5) / 0.5  # uniform noise in [-1, 1)
            g_fake_seed = sample_noise.to(device)
            fake_images = G_net(g_fake_seed)  # generated (fake) data
            # detach(): the D step must not backprop into G. G's gradients are
            # zeroed before its own step anyway, so this only saves work.
            logits_fake = D_net(fake_images.detach())

            d_total_loss = discriminator_loss(logits_real, logits_fake)
            D_optimizer.zero_grad()
            d_total_loss.backward()
            D_optimizer.step()

            # Generator step -----------------------------------------------
            # Same noise batch, fresh forward pass through G.
            fake_images = G_net(g_fake_seed)
            gen_logits_fake = D_net(fake_images)
            g_loss = generator_loss(gen_logits_fake)
            G_optimizer.zero_grad()
            g_loss.backward()
            G_optimizer.step()

            d_loss_val = d_total_loss.item()
            g_loss_val = g_loss.item()

            if iter_count % show_every == 0:
                print('Epoch: {}, Iter: {}, D_loss: {:.4}, G_loss:{:.4}'.format(epoch, iter_count, d_loss_val, g_loss_val))
                imgs_numpy = deprocess_img(fake_images.detach().cpu().numpy())
                show_images(imgs_numpy[0:16])
                plt.savefig('plt_img/%d.png' % iter_count)
                plt.close()
                print('       Min_D_loss: %f, iter %d.' % (min_D_loss, min_D_iter))
                print('       Min_G_loss: %f, iter %d.' % (min_G_loss, min_G_iter))
            viz.line([d_loss_val], [iter_count], win='D_loss', update='append')
            viz.line([g_loss_val], [iter_count], win='G_loss', update='append')
            if d_loss_val < min_D_loss:
                min_D_loss = d_loss_val
                min_D_iter = iter_count
            if g_loss_val < min_G_loss:
                min_G_loss = g_loss_val
                # Previously misspelled `min_G_oter`, so the reported iteration never changed.
                min_G_iter = iter_count
            iter_count += 1

        # Save the networks/optimizers that were passed in, not module-level globals.
        checkpoint = {
            "net_D": D_net.state_dict(),
            "net_G": G_net.state_dict(),
            'D_optim': D_optimizer.state_dict(),
            'G_optim': G_optimizer.state_dict(),
            "epoch": epoch
        }
        torch.save(checkpoint, 'checkpoints/ckpt_%s.pth' % (str(epoch)))
        print('checkpoint of epoch %d has been saved!' % epoch)
 
 
def preprocess_img(x):
    """Turn an input image into a float tensor scaled to [-1, 1].

    ``transforms.ToTensor`` produces values in [0, 1]. Subtracting 0.5 and
    dividing by 0.5 maps that onto [-1, 1], the same range as the generator's
    Tanh output. ``deprocess_img`` is the inverse mapping.
    """
    tensor = transforms.ToTensor()(x)
    centered = tensor - 0.5
    return centered / 0.5
 
def deprocess_img(x):
    """Invert ``preprocess_img``'s scaling: map values from [-1, 1] back to [0, 1].

    Works on anything that supports ``+`` and ``/`` with floats (NumPy arrays,
    tensors, Python numbers).
    """
    shifted = x + 1.0
    return shifted / 2.0
 
 
# MNIST training split; preprocess_img turns each image into a tensor in [-1, 1].
train_set = MNIST(
    root='dataset/',
    train=True,
    download=True,
    transform=preprocess_img
)

# shuffle=True reshuffles the training set every epoch, so the GAN does not
# see the same fixed sequence of batches in every epoch.
train_data = DataLoader(
    dataset=train_set,
    batch_size=batch_size,
    shuffle=True,
)

# MNIST test split, prepared the same way; train_gan below does not read it.
val_set = MNIST(
    root='dataset/',
    train=False,
    download=True,
    transform=preprocess_img
)

val_data = DataLoader(
    dataset=val_set,
    batch_size=batch_size,
)
 
# len(train_set) == 60000 and len(val_set) == 10000 (the MNIST train/test splits);
# with batch_size = 128 that gives ceil(60000 / 128) = 469 training batches per epoch.
viz = Visdom()
# Create the two loss plots that train_gan appends to on every iteration.
viz.line([0.], [0.], win='G_loss', opts=dict(title='G_loss'))
viz.line([0.], [0.], win='D_loss', opts=dict(title='D_loss'))

# Sigmoid + binary cross-entropy in one numerically stable op; both networks'
# losses are computed on raw discriminator logits.
bce_loss = nn.BCEWithLogitsLoss()
 

 
# Use the GPU when one is available, otherwise the CPU (instead of a hard-coded .cuda()).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

D = discriminator().to(device)
G = generator().to(device)

# Adam with beta1 = 0.5 instead of the default 0.9, as the DCGAN paper recommends.
D_optim = torch.optim.Adam(D.parameters(), lr=3e-4, betas=(0.5, 0.999))
G_optim = torch.optim.Adam(G.parameters(), lr=3e-4, betas=(0.5, 0.999))

train_gan(D, G, D_optim, G_optim, discriminator_loss, generator_loss, num_epochs=100)

Pytorch之经典神经网络Generative Model(四) —— DCGAN (MNIST)_第1张图片

最终训练出来的效果大致如下：

Pytorch之经典神经网络Generative Model(四) —— DCGAN (MNIST)_第2张图片

 

 

DCGAN的效果

 

你可能感兴趣的:(DCGAN)