GAN学习

基于MNIST的生成对抗样本(GAN)

一、导入MNIST数据集(训练模块中)

dataset=torchvision.datasets.MNIST("mnist_data",train=True,download=True)

参数:目录,是否是train模式,是否在线下载,传入transform(因为原来是utf8格式的,要转换成浮点型)

了解MNIST数据集
1、输出数据集长度

print(len(dataset))

2、查看数据集中具体数据(前五个)

for i in range(len(dataset)):
    if i<5:
        print(dataset[i])
    else:
        break

查看结果是5个image格式的对象信息
GAN学习_第1张图片shape()函数:用于计算数组的行数列数。torchvision.datasets.MNIST:产生的image对象前面是数据后面是标签,所以dataset[][]为二维。产生报错信息:image 对象没有shape模式,所以需要调用一下数据集中的transform
GAN学习_第2张图片
添加如下代码:

 transform=torchvision.transforms.Compose(
      [torchvision.transforms.Resize(28),
       torchvision.transforms.ToTensor(),
])

再次运行代码,会得到数据的具体参数:12828
GAN学习_第3张图片
二、生成器的大概框架

class Generator(nn.Module):
    def __init__(self,in_dim):
        super(Generator,self).__init__()
        self.model=nn.Sequential(
            nn.Linear(in_dim, 64),  
            torch.nn.ReLU(inplace=True),  
            nn.Linear(64, 128),
            torch.nn.ReLU(inplace=True),
            nn.Linear(128, 256),
            torch.nn.ReLU(inplace=True),
            nn.Linear(256, 512),
            torch.nn.ReLU(inplace=True),
            nn.Linear(512, 1024),
            torch.nn.ReLU(inplace=True),
            nn.Linear(1024,torch.prod(image_size,dtype=torch.int32)),
            nn.Tanh(),
        )
    def forward(self,z):
        output=self.model(z)
        image=output.reshape(z.shape[0],*image_size)
        return image

生成器框架分析:

  • 输入高斯噪声z,输出生成的图像
  • 整体由两部分函数构成,一部分构建网络,一部分进行连接使用

三、判别器的整体框架

class Discriminator(nn.Module):
    #输入一张照片输出概率
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(np.prod(image_size, dtype=torch.int32), 1024),
            torch.nn.ReLU(inplace=True),
            nn.Linear(1024, 512),
            torch.nn.ReLU(inplace=True),
            nn.Linear(512, 256),
            torch.nn.ReLU(inplace=True),
            nn.Linear(256, 128),
            torch.nn.ReLU(inplace=True),
            nn.Linear(128, 1),
            nn.Sigmoid(),
        )

    def forward(self, image):  
        prob = self.model(image.reshape(image.shape[0],-1))
        return prob

判别器框架分析:

  • 输入生成的图像,输出为判别图像概率
  • 整体由两部分函数构成,一部分构建网络,一部分进行连接使用

四、数据训练(全部代码)

'''
实验:基于MNIST(手写数字识别)实现生成对抗网络(GAN)
'''
import torch
import torch.nn as nn
#生成器类的实现
import torchvision.datasets
from torch.utils.data import DataLoader
import numpy as np

image_size=[1,28,28]#定义常量(图片大小)
num_epoch=100
latent_dim=32
batch_size=32
class Generator(nn.Module):
    #输入高斯噪声,输出图像
    # in_dim: 高斯噪声z的输入维度
    def __init__(self):#进行一些模块的定义
        #对父类实例化,集成nn.Module
        super(Generator,self).__init__()
        #定义一个model,这个model可以用很多DNN去做,我们使用一个容器nn.Sequential()将他们装起来
        #容器中需要传入一个个的model的,我们不需要用列表,只需要传入
        self.model = nn.Sequential(
            nn.Linear(latent_dim, 64),
            torch.nn.ReLU(inplace=True),
            nn.Linear(64, 128),
            torch.nn.ReLU(inplace=True),
            nn.Linear(128, 256),
            torch.nn.ReLU(inplace=True),
            nn.Linear(256, 512),
            torch.nn.ReLU(inplace=True),
            nn.Linear(512, 1024),
            torch.nn.ReLU(inplace=True),
            nn.Linear(1024, np.prod(image_size, dtype=np.int32)),
            nn.Tanh(),
        )

    #由于生成器的输入是一个高斯噪声z(文中随机的高斯变量)
    def forward(self,z):#将定义函数进行连接
        #z的形状:[batchsize,1*28*28],定义为两维,一般将1*28*28定义为任意维度latent_dim
        #将z传到model中
        output=self.model(z)
        #将矩阵映射为图像
        # 参数:z的维度batchsize,图像大小即位image_size
        image=output.reshape(z.shape[0],*image_size)#由于image_size为列表,所以添加星号表示为将列表元素独立出来分别传入
        #得到生成器的输出
        return image

#判别器类的实现
class Discriminator(nn.Module):
    #输入一张照片输出概率
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(np.prod(image_size, dtype=np.int32), 1024),
            torch.nn.ReLU(inplace=True),
            nn.Linear(1024, 512),
            torch.nn.ReLU(inplace=True),
            nn.Linear(512, 256),
            torch.nn.ReLU(inplace=True),
            nn.Linear(256, 128),
            torch.nn.ReLU(inplace=True),
            nn.Linear(128, 1),
            nn.Sigmoid(),
        )

    def forward(self, image):  # 将定义函数进行连接
        # image图像的形状:[batchsize,1,28,28]
        # 将图像传到model中
        prob = self.model(image.reshape(image.shape[0],-1))
        # 得到生成器的输出,返回概率
        return prob
#训练Training

#对MNIST数据集API进行导入(直接下载)
    #参数介绍:目录,是否是train模式,是否在线下载,还要传入transform(因为原来是utf8格式的,要转换成浮点型)
dataset=torchvision.datasets.MNIST("mnist_data",train=True,download=True,
                                   transform=torchvision.transforms.Compose(#可以包含很多操作
                                       [
                                       torchvision.transforms.Resize(28),#调整图像大小
                                       torchvision.transforms.ToTensor(),#将utf-8转换为浮点型
                                       # torchvision.transforms.Normalize(mean=[0.5],std=[0.5])#归一化,需要提前计算均值和方差
                                        ]))
#将导入的数据引入downloder中,需要传入参数dataset,batch_size(32或64均可)和shuffle
#dataloader的作用就是把dataset中的每个样本构成一个mini_barch,后面进行批训练
dataloader=DataLoader(dataset,batch_size=batch_size,shuffle=True)
#构建优化器2个,分别对生成器和判别起的参数进行优化

#实例化一个generator和discriminator
generator = Generator()
discriminator = Discriminator()
# 参数params可迭代的参数,其实是对函数调用即可得到.第二个为学习率
g_optimizer=torch.optim.Adam(generator.parameters(),lr=0.0001,betas=(0.4, 0.8), weight_decay=0.0001)
d_optimizer=torch.optim.Adam(discriminator.parameters(),lr=0.0001,betas=(0.4, 0.8), weight_decay=0.0001)
# 损失函数定义:BCE二项的交叉熵函数
loss_fn=nn.BCELoss()


for epoch in range(num_epoch):#对epoch进行循环
    # 对dataloader进行遍历,使用enumerate()函数的返回参数是index(索引)和sample(一个元祖,图片和标签)
    for i,mini_batch in enumerate(dataloader):
        # 对mini_batch进行解析
        gt_images,_=mini_batch#我们只需要mini_batch中的图片(ground_true真实图片),不需要标签
        #随机生成变量z服从正态分布,形状为barch_size,z的维度1*28*28,此处定义为latent_dim
        z=torch.randn(batch_size,latent_dim)#z的大小为batch_size*latent_dim
        #将z喂入生成器得到预期图像
        pred_images=generator(z)
        #把预期图像送入判别器中进行概率预测
        # discriminator(pred_images)

        #对生成器进行优化

        #梯度置0
        g_optimizer.zero_grad()
        #计算损失函数,通过图像概率和目标计算损失
        # target=torch.ones(batch_size,1)
        g_loss=loss_fn(discriminator(pred_images),torch.ones(batch_size,1))
        #计算梯度
        g_loss.backward()
        #参数优化(更新)
        g_optimizer.step()

        #对判别器进行优化
        d_optimizer.zero_grad()
        #调用detach()函数对预测图像的梯度进行隔离,不需要计算它的梯度
        # d_loss = 0.5*(loss_fn(discriminator(gt_images), torch.ones(batch_size,1))+loss_fn(discriminator(pred_images.detach()),torch.zeros(batch_size,1)))
        real_loss = loss_fn(discriminator(gt_images), torch.ones(batch_size,1))
        fales_loss =loss_fn(discriminator(pred_images.detach()),torch.zeros(batch_size,1))
        d_loss=0.5*real_loss+fales_loss#当两个loss都在不断下降并相差不大,则训练成功
        #计算梯度
        d_loss.backward()
        #参数优化(更新)
        d_optimizer.step()

          # #保存照片结果(每隔1000步对所生成的照片进行保存)
        # if i % 1000==0:
        #     # 参数:存储的图片数据,存储的文件名称,
        #     for index,image  in enumerate(pred_images):
        #         torchvision.utils.save_image(image,f"image_{index}.png")
        #
        #
        # if i % 50 == 0:
        #     print(f"step:{len(dataloader)*epoch+i}, g_loss:{g_loss.item()}, d_loss:{d_loss.item()}, real_loss:{real_loss.item()}, fake_loss:{fales_loss.item()}")
        #

五、存在问题
仅学习了解了GAN的整体框架,没有对生成器和判别器使用的网络进行设计
可以自行百度进行设计获得更好的效果
GAN学习_第4张图片

GAN学习_第5张图片

你可能感兴趣的:(生成对抗网络,学习,深度学习)