DCGAN基于深度卷积生成对抗网络的实例 ——CIFAR-10图像的生成

代码摘要:

1.读取数据

2.搭建鉴别网络和生成网络

3.初始化网络的权重值

4.训练鉴别网络和生成网络(时间大概需要几个小时)

root路径可以自己设置 

"分批读取CIFAR-10图片并将部分批次保存为图片文件"

from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
import torchvision.transforms as transforms
from torchvision.utils import save_image

dataset = CIFAR10(root='./data', download=True,
        transform=transforms.ToTensor())
dataloader = DataLoader(dataset, batch_size=64, shuffle=True)#喂入大小是把原来数据集中的多少图片组合成一张图片
batch_size=64
for batch_idx, data in enumerate(dataloader):
    if batch_idx==len(dataloader)-1:
        continue   
    real_images, _ = data

    print ('#{} has {} images.'.format(batch_idx, batch_size))
    if batch_idx % 100 == 0:
        path = './data/CIFAR10_shuffled_batch{:03d}.png'.format(batch_idx)
        save_image(real_images, path, normalize=True)

运行结果部分展示:

DCGAN基于深度卷积生成对抗网络的实例 ——CIFAR-10图像的生成_第1张图片

"搭建生成网络和鉴别网络"
"隐藏的卷积层(即除了最后的输出卷积层外)的输出都需要经过规范化操作"
import torch.nn as nn

# 搭建生成网络
latent_size = 64 # 潜在大小
n_channel = 3 # 输出通道数
n_g_feature = 64 # 生成网络隐藏层大小
"生成网络采用了四层转置卷积操作"
gnet = nn.Sequential(
        # 输入大小 = (64, 1, 1)
        #有点像互相关的反操作,(x-4)/1=1-->x=4
        nn.ConvTranspose2d(latent_size, 4 * n_g_feature, kernel_size=4,
        bias=False),
        nn.BatchNorm2d(4 * n_g_feature),
        nn.ReLU(),
        # 大小 = (256, 4, 4)
        #{x+2(填充)-4(核尺寸)+2(步长)}/2=4-->x=8
        nn.ConvTranspose2d(4 * n_g_feature, 2 * n_g_feature, kernel_size=4,
                stride=2, padding=1, bias=False),
        nn.BatchNorm2d(2 * n_g_feature),
        nn.ReLU(),
        # 大小 = (128, 8, 8)
        nn.ConvTranspose2d(2 * n_g_feature, n_g_feature, kernel_size=4,
                stride=2, padding=1, bias=False),
        nn.BatchNorm2d(n_g_feature),
        nn.ReLU(),
        # 大小 = (64, 16, 16)
        nn.ConvTranspose2d(n_g_feature, n_channel, kernel_size=4,
                stride=2, padding=1),
        nn.Sigmoid(),
        # 图片大小 = (3, 32, 32)
        )
print (gnet)

# 搭建鉴别网络
n_d_feature = 64 # 鉴别网络隐藏层大小
"鉴别网络采用了4层互相关操作"
dnet = nn.Sequential( 
        # 图片大小 = (3, 32, 32)
        nn.Conv2d(n_channel, n_d_feature, kernel_size=4,
                stride=2, padding=1),
        nn.LeakyReLU(0.2),
        # 大小 = (64, 16, 16)
        nn.Conv2d(n_d_feature, 2 * n_d_feature, kernel_size=4,
                stride=2, padding=1, bias=False),
        nn.BatchNorm2d(2 * n_d_feature),
        nn.LeakyReLU(0.2),
        # 大小 = (128, 8, 8)
        nn.Conv2d(2 * n_d_feature, 4 * n_d_feature, kernel_size=4,
                stride=2, padding=1, bias=False),
        nn.BatchNorm2d(4 * n_d_feature),
        nn.LeakyReLU(0.2),
        # 大小 = (256, 4, 4)
        nn.Conv2d(4 * n_d_feature, 1, kernel_size=4),
        # 对数赔率张量大小 = (1, 1, 1)
        )
print(dnet)

代码运行的部分结果:

DCGAN基于深度卷积生成对抗网络的实例 ——CIFAR-10图像的生成_第2张图片

"初始化权重值"
import torch.nn.init as init

def weights_init(m): # 用于初始化权重值的函数
    if type(m) in [nn.ConvTranspose2d, nn.Conv2d]:
        init.xavier_normal_(m.weight)
    elif type(m) == nn.BatchNorm2d:
        init.normal_(m.weight, 1.0, 0.02)
        init.constant_(m.bias, 0)
#调用apply()函数,torch.nn.Module类实例会递归地让自己成为weights_init()里面函数的m
gnet.apply(weights_init)
dnet.apply(weights_init)
"训练生成网络和鉴别网络并输出图片"
import torch
import torch.optim

# 损失
criterion = nn.BCEWithLogitsLoss()

# 优化器
#Adam优化器的默认学习率n=0.01,过高,应减小为0.002,动量参数默认0.9,会造成震荡,减小为0.5
goptimizer = torch.optim.Adam(gnet.parameters(),
        lr=0.0002, betas=(0.5, 0.999))
doptimizer = torch.optim.Adam(dnet.parameters(), 
        lr=0.0002, betas=(0.5, 0.999))

# 用于测试的固定噪声,用来查看相同的潜在张量在训练过程中生成图片的变换
batch_size = 64
fixed_noises = torch.randn(batch_size, latent_size, 1, 1)

# 训练过程
epoch_num = 10
for epoch in range(epoch_num):
    for batch_idx, data in enumerate(dataloader):
        if batch_idx==len(dataloader)-1: #剔除最后一张是(16,3,32,32)
            continue   
        # 载入本批次数据
        real_images, _ = data#real_images(64,3,32,32)
        
        # 训练鉴别网络
        labels = torch.ones(batch_size) # 真实数据对应标签为1(64,)
        preds = dnet(real_images) # 对真实数据进行判别(64,1,1,1)
      
        outputs = preds.reshape(-1)#(64,)
        dloss_real = criterion(outputs, labels) # 真实数据的鉴别器损失
        dmean_real = outputs.sigmoid().mean() # 计算鉴别器将多少比例的真数据判定为真,仅用于输出显示
        
        noises = torch.randn(batch_size, latent_size, 1, 1) # 潜在噪声(64,64,1,1)
        fake_images = gnet(noises) # 生成假数据(64,3,32,32)
        labels = torch.zeros(batch_size) # 假数据对应标签为0
        fake = fake_images.detach()# 使得梯度的计算不回溯到生成网络,可用于加快训练速度.删去此步结果不变
        preds = dnet(fake) # 对假数据进行鉴别
        outputs = preds.view(-1)
        dloss_fake = criterion(outputs, labels) # 假数据的鉴别器损失
        dmean_fake = outputs.sigmoid().mean()
                # 计算鉴别器将多少比例的假数据判定为真,仅用于输出显示
        
        dloss = dloss_real + dloss_fake # 总的鉴别器损失
        dnet.zero_grad()
        dloss.backward()
        doptimizer.step()
        
        # 训练生成网络
        labels = torch.ones(batch_size)
                # 生成网络希望所有生成的数据都被认为是真数据
        preds = dnet(fake_images) # 把假数据通过鉴别网络
        outputs = preds.view(-1)
        gloss = criterion(outputs, labels) # 真数据看到的损失
        gmean_fake = outputs.sigmoid().mean()
                # 计算鉴别器将多少比例的假数据判定为真,仅用于输出显示
        gnet.zero_grad()
        gloss.backward()
        goptimizer.step()
        
        # 输出本步训练结果
        if batch_idx % 100 == 0:
            print('[{}/{}]'.format(epoch, epoch_num) +
                    '[{}/{}]'.format(batch_idx, len(dataloader)) +
                    '鉴别网络损失:{:g} 生成网络损失:{:g}'.format(dloss, gloss) +
                    '真数据判真比例:{:g} 假数据判真比例:{:g}/{:g}'.format(
                    dmean_real, dmean_fake, gmean_fake))
            fake = gnet(fixed_noises) # 由固定潜在张量生成假数据
            save_image(fake, # 保存假数据
                    './data/images_epoch{:02d}_batch{:03d}.png'.format(
                    epoch, batch_idx))

代码运行结果:人工停止的结果

DCGAN基于深度卷积生成对抗网络的实例 ——CIFAR-10图像的生成_第3张图片

 

生成的假图像:

依次是第0论的0,300,600,第1轮的100

DCGAN基于深度卷积生成对抗网络的实例 ——CIFAR-10图像的生成_第4张图片DCGAN基于深度卷积生成对抗网络的实例 ——CIFAR-10图像的生成_第5张图片DCGAN基于深度卷积生成对抗网络的实例 ——CIFAR-10图像的生成_第6张图片

DCGAN基于深度卷积生成对抗网络的实例 ——CIFAR-10图像的生成_第7张图片

 

你可能感兴趣的:(pytorch,机器学习)