GAN(生成对抗神经网络)生成MNIST 基于pytorch实现

运行环境 pytorch1.3.0  需支持GPU

生成对抗神经网络分为两部分: 生成器

                                                    鉴别器

生成器作用是利用随机数生成以假乱真的数据

鉴别器的作用的判定数据真假

鉴别器的训练很简单:真数据打标签1   生成器生成数据打标签0    进行训练   就像二元分类问题一样

生成器的训练方法:  生成随机数、 随机数调用module生成数据、此数据再用鉴别器得到鉴别结果、鉴别结果与全1求误差、反向传播更新参数

笔者水平有限,讲解的比较粗糙。但是下面的代码是经过测试可以运行出结果的。

补充:从0.4起, Variable 正式合并入Tensor类,通过Variable嵌套实现的自动微分功能已经整合进入了Tensor类中。虽然为了代码的兼容性还是可以使用Variable(tensor)这种方式进行嵌套,但是这个操作其实什么都没做。谁用Variable谁就是臭弟弟。

  

import torch
import torch.nn as nn
import torchvision
import torch.nn.functional as F
from torchvision import datasets
from torchvision import transforms
from torchvision.utils import save_image
from torch.autograd import Variable
import os

if not os.path.exists('./img'):
    os.mkdir('./img')
def to_img(x):
    out=0.5*(x+1)
    out=out.clamp(0,1)
    out=out.view(-1,1,28,28)
    return out
batch_size=128
num_epoch=100
z_dimention=100

#image processing
img_transform=transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])
#MNIST dataset
mnist=datasets.MNIST(
    root='./data/',train=True,transform=img_transform,download=True
)
#data loader
dataloader=torch.utils.data.DataLoader(
    dataset=mnist,batch_size=batch_size,shuffle=True
)

class discriminator(nn.Module): #鉴别器 784图像数据映射至0 1   0代表鉴别器判定此图为假  1真
    def __init__(self):
        super(discriminator,self).__init__() 
        self.dis=nn.Sequential(  
            nn.Linear(784,256),
            nn.LeakyReLU(0.2),
            nn.Linear(256,256),
            nn.LeakyReLU(0.2),
            nn.Linear(256,1),
            nn.Sigmoid()
        )
    def forward(self,x):
        x=self.dis(x)
        return x
class generator(nn.Module):    #生成器 100长度随机数映射至784图像数据
    def __init__(self):
        super(generator,self).__init__()
        self.gen=nn.Sequential(
            nn.Linear(100,256),
            nn.ReLU(True),
            nn.Linear(256,256),
            nn.ReLU(True),
            nn.Linear(256,784),
            nn.Tanh()
        )
    def forward(self,x):
        x=self.gen(x)
        return x
D=discriminator()
G=generator()
D=D.cuda() 
G=G.cuda()
criterion=nn.BCELoss() #二分类的交叉熵
d_optimizer=torch.optim.Adam(D.parameters(),lr=0.0001) #原作者的0.003训练不出来结果
g_optimizer=torch.optim.Adam(G.parameters(),lr=0.0001)
count1=0
count2=0
for epoch in range(num_epoch): #训练100次
    for i,(img,_) in enumerate(dataloader):
        num_img=img.size(0)
        # =================train discriminator
        img=img.view(num_img,-1)
        real_img=Variable(img).cuda()
        real_label=Variable(torch.ones(num_img)).cuda()
        fake_label=Variable(torch.zeros(num_img)).cuda()
        
        real_out=D(real_img)
        d_loss_real=criterion(real_out,real_label) #真实数据对应输出 1
        real_scores=real_out
        
        z=Variable(torch.randn(num_img,z_dimention)).cuda()
        fake_img=G(z)
        fake_out=D(fake_img)
        d_loss_fake=criterion(fake_out,fake_label)
        fake_scores=fake_out
        
        count1+=1 #这个没用
        if count1%1==0:
            d_loss=d_loss_real+d_loss_fake
            d_optimizer.zero_grad()
            d_loss.backward()
            d_optimizer.step()
        
        # ===============train generator
        count2+=1
        if count2%1==0:
            z=Variable(torch.randn(num_img,z_dimention)).cuda()
            fake_img=G(z)
            output=D(fake_img)
            g_loss=criterion(output,real_label)

            g_optimizer.zero_grad()
            g_loss.backward()
            g_optimizer.step()
        
        if (i+1)%100==0:
            print('Epoch[{}/{}],d_loss:{:.6f},g_loss:{:.6f} D real:{:.6f},D fake:{:.6f}'.format(
            epoch,num_epoch,d_loss.item(),g_loss.item(),real_scores.data.mean(),fake_scores.data.mean()
            ))
    if epoch==0:
        real_images=to_img(real_img.cpu().data)
        save_image(real_images,'./img/real_images.png')
    fake_images=to_img(fake_img.cpu().data)
    save_image(fake_images,'./img/fake_images-{}.png'.format(epoch+1))
            

 GAN(生成对抗神经网络)生成MNIST 基于pytorch实现_第1张图片 第一次训练的结果

GAN(生成对抗神经网络)生成MNIST 基于pytorch实现_第2张图片第30次

GAN(生成对抗神经网络)生成MNIST 基于pytorch实现_第3张图片第90次

参考自: https://blog.csdn.net/weixin_41278720/article/details/80861284

你可能感兴趣的:(AI)