Pytorch38,DCGAN,实现生成MNIST数字图片

直接上程序:

main.py 主程序:

import numpy as np
import random
import os
import math
import torch
import torchvision
from torchvision import datasets,transforms
import network
from PIL import Image
from torch.autograd import Variable
import torch.nn as nn
import time

def data_load(dir, batch, degrees=0, hflip_p=0.0, vflip_p=0.0):  # load + normalize images
    """Return a shuffled DataLoader over the ImageFolder rooted at ``dir``.

    Every image is converted to 1-channel grayscale, resized to 32x32,
    optionally augmented, turned into a tensor and normalized with
    mean=0.5 / std=0.5, i.e. mapped from [0, 1] to [-1, 1] (the same range as
    the generator's Tanh output).

    Args:
        dir: ImageFolder root. Images must be inside sub-folders of it; the
            sub-folder name becomes the class label. The name ``dir`` shadows
            the builtin but is kept because callers pass it by keyword.
        batch: batch size. Incomplete final batches are dropped.
        degrees: maximum angle for ``RandomRotation``; 0 disables rotation.
        hflip_p: probability of a horizontal flip; 0 disables it.
        vflip_p: probability of a vertical flip; 0 disables it.
            Flipping turns most digits into non-digits, so keep both flip
            probabilities at 0 for MNIST.

    Returns:
        ``torch.utils.data.DataLoader`` yielding ``(images, labels)``.
    """
    # The original referenced undefined globals a/b/c here (NameError on the
    # first call); augmentation is now opt-in through keyword arguments.
    steps = [transforms.Grayscale(1), transforms.Resize(size=(32, 32))]
    if degrees:
        steps.append(transforms.RandomRotation(degrees))
    if hflip_p:
        steps.append(transforms.RandomHorizontalFlip(hflip_p))
    if vflip_p:
        steps.append(transforms.RandomVerticalFlip(vflip_p))
    steps.append(transforms.ToTensor())
    steps.append(transforms.Normalize(mean=[0.5], std=[0.5]))

    train_dataset = torchvision.datasets.ImageFolder(root=dir, transform=transforms.Compose(steps))
    return torch.utils.data.DataLoader(train_dataset, batch_size=batch, shuffle=True, drop_last=True)



def _save_checkpoint(model, optimizer, path):
    """Save ``model`` and ``optimizer`` state dicts to ``path``.

    The file holds ``{"net": model_state, "optimizer": optimizer_state}``,
    the layout the startup code reads back.
    """
    torch.save({"net": model.state_dict(), "optimizer": optimizer.state_dict()}, path)


def train(optimizer_net, optimizer_gan, epochs=200, batch=5000, data_dir="train", save_every=50):
    """Train the discriminator ``network.net`` and generator ``network.gan`` as a DCGAN.

    Each step:
    (1) pushes the discriminator towards 1 on real images and 0 on generated
        images;
    (2) updates the generator so the discriminator scores its images as 1.

    Per-epoch average losses are printed. Checkpoints are written to
    ``net.pth`` / ``gan.pth`` every ``save_every`` epochs and at the end.
    A sample grid is then rendered with ``batch_make``.

    Args:
        optimizer_net: optimizer over ``network.net`` parameters.
        optimizer_gan: optimizer over ``network.gan`` parameters.
        epochs: number of passes over the training set.
        batch: mini-batch size handed to ``data_load``.
        data_dir: ImageFolder root handed to ``data_load``.
        save_every: checkpoint interval in epochs (<= 0 disables periodic saves).

    Raises:
        ValueError: if the loader yields no batch at all. Incomplete batches
            are dropped, so this happens when the dataset is smaller than ``batch``.
    """
    loss_function = nn.BCELoss()
    time1 = time.time()
    train_loader = data_load(dir=data_dir, batch=batch)
    if len(train_loader) == 0:
        raise ValueError(
            f"data_load({data_dir!r}, batch={batch}) produced no full batch; lower `batch`"
        )

    # make()/batch_make() may have switched the generator to eval mode;
    # BatchNorm must use per-batch statistics while training.
    network.net.train()
    network.gan.train()

    for e in range(epochs):
        print(str(e + 1) + "/" + str(epochs))
        loss_all = 0.0
        fake_loss_all = 0.0
        make_loss_all = 0.0
        count = 0

        for images, _ in train_loader:
            count += 1
            x = images.cuda()
            n = x.size(0)

            # Discriminator (NET): train real images towards label 1.
            network.net.zero_grad()
            out = network.net(x)
            loss = loss_function(out, torch.ones_like(out))
            loss_all += loss.item()
            loss.backward()

            # Discriminator (NET): train generated images towards label 0.
            # detach() keeps this loss from back-propagating into the generator.
            noise = torch.randn(n, 512, 1, 1, device=x.device)
            fake_x = network.gan(noise)
            fake_out = network.net(fake_x.detach())
            fake_loss = loss_function(fake_out, torch.zeros_like(fake_out))
            fake_loss_all += fake_loss.item()
            fake_loss.backward()

            optimizer_net.step()

            # Generator (GAN): use the loss of NET scoring the fakes as 1.
            # This backward also fills NET's grads; they are cleared by the
            # next iteration's network.net.zero_grad() before NET steps again.
            network.gan.zero_grad()
            make_fake_out = network.net(fake_x)
            make_loss = loss_function(make_fake_out, torch.ones_like(make_fake_out))
            make_loss_all += make_loss.item()
            make_loss.backward()
            optimizer_gan.step()

        print("avg_loss=" + str(loss_all / count))
        print("avg_fake_loss=" + str(fake_loss_all / count))
        print("avg_make_loss=" + str(make_loss_all / count))
        print("-----------------------------------------------------")
        # Periodic checkpoint after every `save_every` completed epochs; the
        # final epoch is covered by the save after the loop.
        if save_every > 0 and (e + 1) % save_every == 0 and e + 1 != epochs:
            _save_checkpoint(network.net, optimizer_net, "net.pth")
            _save_checkpoint(network.gan, optimizer_gan, "gan.pth")

    print(str(time.time() - time1) + "秒")
    _save_checkpoint(network.net, optimizer_net, "net.pth")
    _save_checkpoint(network.gan, optimizer_gan, "gan.pth")
    batch_make(filename="batch_make.png")


def make():  # generate a single image
    """Sample one image from the generator, save it to ``test1.png`` and show it.

    The generator's previous train/eval mode is restored afterwards, so a
    later ``train`` call does not run BatchNorm in eval mode.
    """
    was_training = network.gan.training
    network.gan.eval()
    try:
        with torch.no_grad():  # inference only; no autograd graph needed
            make_noise = torch.randn(1, 512, 1, 1).cuda()
            make_fake_x = network.gan(make_noise).view(1, 32, 32)
    finally:
        network.gan.train(was_training)

    # The generator ends in Tanh ([-1, 1]), matching training images that were
    # normalized with mean=0.5/std=0.5. Undo that before saving: save_image
    # clamps to [0, 1], which would otherwise turn every negative pixel black.
    make_fake_x = make_fake_x * 0.5 + 0.5
    torchvision.utils.save_image(make_fake_x.cpu(), "test1.png")
    with Image.open("test1.png") as img:
        img.show()

def batch_make(batch=120, filename="batch_make.png"):  # generate a grid of images
    """Sample ``batch`` generator images, save them as one grid to ``filename`` and show it.

    The grid has ``int(sqrt(batch))`` columns (10 for the default 120).
    The generator's previous train/eval mode is restored afterwards.

    Raises:
        ValueError: if ``batch`` is smaller than 1.
    """
    if batch < 1:
        raise ValueError(f"batch must be >= 1, got {batch}")

    was_training = network.gan.training
    network.gan.eval()
    try:
        with torch.no_grad():  # inference only; no autograd graph needed
            make_noise = torch.randn(batch, 512, 1, 1).cuda()
            make_fake_x = network.gan(make_noise).view(batch, 1, 32, 32)
    finally:
        network.gan.train(was_training)

    # Map the Tanh output [-1, 1] back to [0, 1]; save_image clamps to [0, 1].
    make_fake_x = make_fake_x * 0.5 + 0.5
    # nrow must be passed by keyword: in current torchvision the third
    # positional parameter of save_image is `format`, not `nrow`.
    torchvision.utils.save_image(make_fake_x.cpu(), filename, nrow=max(1, int(math.sqrt(batch))))
    with Image.open(filename) as img:
        img.show()


def _load_checkpoint(model, optimizer, path):
    """Restore ``model`` and ``optimizer`` from the checkpoint at ``path`` if it exists.

    Returns True when a checkpoint was loaded, False when ``path`` is missing.
    """
    if not os.path.exists(path):
        return False
    checkpoint = torch.load(path)
    model.load_state_dict(checkpoint["net"])
    optimizer.load_state_dict(checkpoint["optimizer"])
    return True


def _apply_hyperparams(optimizer, lr, betas):
    """Force ``lr`` and ``betas`` onto every param group of ``optimizer``.

    ``Optimizer.load_state_dict`` also restores the hyper-parameters stored in
    the checkpoint. Re-applying them keeps the configured values
    authoritative, while the restored per-parameter state is kept.
    """
    for group in optimizer.param_groups:
        group["lr"] = lr
        group["betas"] = betas


def main():
    """Interactive menu: 1 train, 2 render a sample grid, 3 render one image, 0 quit."""
    rate_net = 0.0002
    rate_gan = 0.0002
    # betas=(0.5, 0.999) must be set explicitly; training is very slow otherwise.
    betas = (0.5, 0.999)
    make_batch = 120

    # Build the optimizers with their final hyper-parameters *before* loading
    # checkpoints. If they are rebuilt after loading, the restored optimizer
    # state is silently thrown away.
    optimizer_net = torch.optim.Adam(network.net.parameters(), rate_net, betas=betas)
    optimizer_g = torch.optim.Adam(network.gan.parameters(), rate_gan, betas=betas)
    _load_checkpoint(network.net, optimizer_net, "net.pth")
    _load_checkpoint(network.gan, optimizer_g, "gan.pth")
    _apply_hyperparams(optimizer_net, rate_net, betas)
    _apply_hyperparams(optimizer_g, rate_gan, betas)

    menu = "1:训练   2:批量生成图片(" + str(make_batch) + ")   3:生成单张图片   0:退出:"
    while True:
        s = input(menu)
        # Only offered options are accepted (previously "4" slipped through).
        while s not in ("0", "1", "2", "3"):
            s = input("输入错误,重新输入:" + menu)
        if s == "1":
            train(optimizer_net, optimizer_g)
        elif s == "2":
            batch_make(make_batch)
        elif s == "3":
            make()
        else:  # "0"
            break


if __name__ == "__main__":
    main()

network.py 神经网络:

import torch
import torch.nn as nn
from torch.autograd import Variable

def weights_init(m):  # DCGAN-style initial weights; use via module.apply(weights_init)
    """Initialise one module in place, DCGAN style.

    Modules are matched by class name:
    - a name containing "Conv" (Conv2d, ConvTranspose2d, ...): weight ~ N(0, 0.02);
    - a name containing "BatchNorm": weight ~ N(1, 0.02), bias = 0.
    Any other module is left untouched.

    Conv biases are not modified. Missing parameters are skipped instead of
    raising AttributeError; a ``BatchNorm2d(affine=False)`` has
    ``weight``/``bias`` set to None.
    """
    classname = type(m).__name__
    weight = getattr(m, "weight", None)
    bias = getattr(m, "bias", None)
    if "Conv" in classname:
        if weight is not None:
            nn.init.normal_(weight, 0.0, 0.02)
    elif "BatchNorm" in classname:
        if weight is not None:
            nn.init.normal_(weight, 1.0, 0.02)
        if bias is not None:
            nn.init.constant_(bias, 0)

class Net(nn.Module):  # discriminator network
    """DCGAN discriminator: (N, 1, 32, 32) image batch -> (N,) probabilities of "real".

    Five stride-2 4x4 convolutions halve the spatial size each time
    (32 -> 16 -> 8 -> 4 -> 2 -> 1). The layer order, and therefore every
    ``conv.<index>`` state_dict key, is fixed so saved checkpoints stay loadable.
    """

    # (in_channels, out_channels) of the three middle Conv+BatchNorm+LeakyReLU stages
    _MIDDLE_STAGES = ((32, 64), (64, 128), (128, 256))

    def __init__(self):
        super().__init__()
        # First stage has no BatchNorm: 32x32 -> 16x16
        layers = [nn.Conv2d(1, 32, 4, 2, 1, bias=False), nn.LeakyReLU(0.2, True)]
        # 16 -> 8 -> 4 -> 2
        for c_in, c_out in self._MIDDLE_STAGES:
            layers += [
                nn.Conv2d(c_in, c_out, 4, 2, 1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.2, True),
            ]
        # 2x2 -> 1x1 single logit, squashed to a probability
        layers += [nn.Conv2d(256, 1, 4, 2, 1, bias=False), nn.Sigmoid()]
        self.conv = nn.Sequential(*layers)

    def forward(self, input):
        scores = self.conv(input)  # (N, 1, 1, 1)
        return scores.view(-1)

# Module-level discriminator instance on the GPU with DCGAN initial weights;
# Module.cuda() and Module.apply() both return the module itself.
net = Net().cuda().apply(weights_init)

class Gan(nn.Module):  # generator network
    """DCGAN generator: (N, 512, 1, 1) noise -> (N, 1, 32, 32) images in [-1, 1].

    Five stride-2 4x4 transposed convolutions double the spatial size each time
    (1 -> 2 -> 4 -> 8 -> 16 -> 32), using h_out = (h_in - 1) * stride - 2 * padding + kernel_size.
    The layer order, and therefore every ``conv.<index>`` state_dict key, is
    fixed so saved checkpoints stay loadable.
    """

    # (in_channels, out_channels) of the four ConvTranspose+BatchNorm+ReLU stages
    _UPSAMPLE_STAGES = ((512, 256), (256, 128), (128, 64), (64, 32))

    def __init__(self):
        super().__init__()
        layers = []
        # 1 -> 2 -> 4 -> 8 -> 16
        for c_in, c_out in self._UPSAMPLE_STAGES:
            layers.extend([
                nn.ConvTranspose2d(c_in, c_out, 4, 2, 1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(True),
            ])
        # 16 -> 32, single channel, Tanh to match images normalized to [-1, 1]
        layers.extend([nn.ConvTranspose2d(32, 1, 4, 2, 1, bias=False), nn.Tanh()])
        self.conv = nn.Sequential(*layers)

    def forward(self, input):
        images = self.conv(input)
        return images

# Module-level generator instance on the GPU with DCGAN initial weights;
# Module.cuda() and Module.apply() both return the module itself.
gan = Gan().cuda().apply(weights_init)

注意点:
1. net.apply(weights_init) 和 gan.apply(weights_init) 用于给网络设置初始权重。起初认为最好加上,否则训练会变慢;但加上后也可能出现无法训练的情况。补充:实验证明还是不要加为好,加了反而问题更多(如需去掉,删除 network.py 中对 weights_init 的 apply 调用即可)。

2.optimizer_net=torch.optim.Adam(network.net.parameters(),rate_net,betas=(0.5, 0.999))
optimizer_g=torch.optim.Adam(network.gan.parameters(),rate_gan,betas=(0.5, 0.999))
这里的betas=(0.5, 0.999)一定要写,不然训练会慢

3. DCGAN 对超参数非常敏感,上面的第 1、2 点就是例子;参数调不好,一切都白搭。

4. 实验发现,判别网络(NET)对假图片的输出很难接近真图片的标签 1,因此训练可以一直持续下去。

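5. 用于训练的真实图片全部保存在“train\mnist_train”文件夹下,包含 0-9 共 55000 张。由于 data_load 使用 ImageFolder(root="train") 读取数据,图片必须放在 train 下的子文件夹(这里是 mnist_train)中,不能直接放在 train 目录里。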

下面是训练后的效果(都是随机采样;如果挑选最好的一组,效果应该会更好,例如采样 1000 个,选出 NET 输出最接近 1 的 100 个):

100轮:
Pytorch38,DCGAN,实现生成MNIST数字图片_第1张图片

500轮:
Pytorch38,DCGAN,实现生成MNIST数字图片_第2张图片

1600轮:

Pytorch38,DCGAN,实现生成MNIST数字图片_第3张图片

你可能感兴趣的:(机器学习,pytorch,神经网络,深度学习)