从0到1,采用全连接层实现普通GAN,利用Pytorch实现随机向量生成手写数字Mnist(GAN:一种深度学习网络)

实例操作

  • 一、代码
  • 二、中间结果

一、代码

'''
Generate handwritten MNIST digits from random noise (vanilla GAN with fully connected layers).
1. Data preparation
2. Network construction: generator and discriminator
3. Loss functions and optimizers
4. Training, validation, saving intermediate results
5. Saving the model
'''
import os
import argparse
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision.utils import save_image
from torchvision import transforms,datasets
from torch.utils.tensorboard import SummaryWriter
import warnings

warnings.filterwarnings('ignore')

# Command-line configuration.
# BUGFIX: `type=` is required on numeric options — argparse passes CLI values
# through as strings otherwise, so `--batch_size 64` would hand the string
# "64" to DataLoader and `--lr 1e-4` the string "1e-4" to Adam, crashing both.
# Defaults are unchanged, so existing invocations keep working.
parser=argparse.ArgumentParser("parameters configuring")
parser.add_argument("--batch_size",type=int,default=128,help="the batch size of dataset")
parser.add_argument("--epochs",type=int,default=100,help="the epochs of training")
parser.add_argument("--logs_dir",default='./logs',help="the path of logs")
parser.add_argument("--models_dir",default='./models',help="the path of saved models")
parser.add_argument("--lr",type=float,default=1e-3,help="the init learning rate")
parser.add_argument("--images_dir",default='./images',help="the path of saved images")
args=parser.parse_args()


# Run on the GPU when one is available, otherwise fall back to the CPU.
device=torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 1. Data preparation — ToTensor() scales MNIST pixels into [0, 1], which
# matches the Sigmoid output range of the generator.
BATCH_SIZE=args.batch_size
train_dataset=datasets.MNIST(
    root="./dataset",
    train=True,
    download=True,
    transform=transforms.ToTensor(),
)

train_dataloader=DataLoader(
    dataset=train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    drop_last=False,
    num_workers=0,
)


# 2. 网络搭建
class Generator(nn.Module):
    """Maps a 100-dim noise vector to a 1x28x28 image with pixels in [0, 1].

    Architecture: four fully connected layers (100->256->512->1024->784),
    each followed by BatchNorm1d; hidden layers use ReLU, the output layer
    uses Sigmoid so pixel values match ToTensor()-scaled MNIST data.
    """

    def __init__(self):
        super(Generator, self).__init__()
        widths=[100,256,512,1024]
        layers=[]
        for n_in,n_out in zip(widths[:-1],widths[1:]):
            layers+=[nn.Linear(in_features=n_in,out_features=n_out),
                     nn.BatchNorm1d(n_out),
                     nn.ReLU()]
        layers+=[nn.Linear(in_features=widths[-1],out_features=28*28),
                 nn.BatchNorm1d(28*28),
                 nn.Sigmoid()]
        self.model1=nn.Sequential(*layers)

    def forward(self,x):
        """Generate a batch of images from noise `x` of shape (N, 100)."""
        flat=self.model1(x)
        return flat.view((-1,1,28,28))


generator=Generator().to(device)


class Discriminator(nn.Module):
    """Scores an image as real (score near 1) or generated (score near 0).

    Flattens the 1x28x28 input and funnels it through fully connected
    layers 784->512->256->128->32->1 with ReLU activations; a final
    Sigmoid maps the score into [0, 1] for use with BCELoss.
    """

    def __init__(self):
        super(Discriminator, self).__init__()
        dims=[28*28,512,256,128,32]
        layers=[nn.Flatten()]
        for d_in,d_out in zip(dims[:-1],dims[1:]):
            layers.append(nn.Linear(in_features=d_in,out_features=d_out))
            layers.append(nn.ReLU())
        layers.append(nn.Linear(in_features=dims[-1],out_features=1))
        layers.append(nn.Sigmoid())
        self.model1=nn.Sequential(*layers)

    def forward(self,x):
        """Return per-sample realness scores of shape (N, 1)."""
        return self.model1(x)


discriminator=Discriminator().to(device)

# 3. Loss function and optimizers — one Adam optimizer per network, sharing
# the same learning rate from the CLI arguments.
generator_optimizer=torch.optim.Adam(generator.parameters(),lr=args.lr)
discriminator_optimizer=torch.optim.Adam(discriminator.parameters(),lr=args.lr)

# Binary cross-entropy over the discriminator's sigmoid outputs.
loss_object2=nn.BCELoss().to(device)


def generator_creation(fake_out):
    """Generator loss: BCE pushing D's scores on fake images toward 1 (real)."""
    real_labels=torch.ones_like(fake_out).to(device)
    return loss_object2(fake_out,real_labels)


def discriminator_creation(fake_out,real_out):
    """Discriminator loss: BCE toward 0 on fakes plus BCE toward 1 on reals."""
    loss_fake=loss_object2(fake_out,torch.zeros_like(fake_out).to(device))
    loss_real=loss_object2(real_out,torch.ones_like(real_out).to(device))
    return loss_fake+loss_real


# 4. Training loop: each batch updates the generator then the discriminator;
# each epoch samples validation images and checkpoints the generator.
models_save_path=args.models_dir
images_save_path=args.images_dir
os.makedirs(models_save_path,exist_ok=True)
os.makedirs(images_save_path,exist_ok=True)

EPOCHS=args.epochs
total_train_steps=0
total_val_steps=0
train_info="the {} times of train,g_train_loss is {}."
for epoch in range(EPOCHS):
    print("Epoch:{}".format(epoch))
    # --- training ---
    generator.train()
    discriminator.train()
    total_train_loss=0.0
    for batch,(imgs,_) in enumerate(train_dataloader):
        imgs=imgs.to(device)

        # 1) Update the generator: draw noise, generate fakes, and push the
        #    discriminator's score on them toward "real" (1).
        random_vector=torch.randn(size=(imgs.shape[0],100),dtype=torch.float32).to(device)
        prediction=generator(random_vector)
        fake_out=discriminator(prediction)
        g_train_loss=generator_creation(fake_out)
        generator_optimizer.zero_grad()
        g_train_loss.backward()
        generator_optimizer.step()
        total_train_loss+=g_train_loss.item()

        # 2) Update the discriminator. detach() stops gradients from this
        #    step flowing back into the generator.
        fake_out=discriminator(prediction.detach())
        real_out=discriminator(imgs)
        d_train_loss=discriminator_creation(fake_out,real_out)
        discriminator_optimizer.zero_grad()
        d_train_loss.backward()
        discriminator_optimizer.step()

        if (total_train_steps+1)%100==0:
            print(train_info.format(total_train_steps,g_train_loss.item()))
        total_train_steps+=1
    print("the {} epoch of train,total_g_train_loss is {}.".format(epoch,total_train_loss))

    # --- validation: sample images from fresh noise without gradients ---
    generator.eval()
    with torch.no_grad():
        # NOTE(review): reuses `imgs` leaked from the training loop to size
        # the noise batch; assumes the dataloader yielded at least one batch.
        random_vector=torch.randn(size=(imgs.shape[0],100),dtype=torch.float32).to(device)
        output=generator(random_vector)
        total_val_steps+=1

    # Save a 5-wide grid of the first 20 samples every 5 epochs.
    if (total_val_steps+1)%5==0:
        # BUGFIX: the save_image/make_grid keyword is `nrow`, not `nrows`
        # (the misspelled kwarg raises a TypeError), and the filename index
        # must use integer division (//) to avoid names like "epoch1.0.jpg".
        save_image(output[:20],os.path.join(images_save_path,"epoch{}.jpg".format((total_val_steps+1)//5)),nrow=5,normalize=True)

    torch.save(generator.state_dict(),os.path.join(models_save_path,"generator{}.pth".format(epoch)))
    print("model saved!!")

二、中间结果

训练5轮:
(图片占位符:训练 5 轮后生成的手写数字样例截图)

训练100轮:
(图片占位符:训练 100 轮后生成的手写数字样例截图)

你可能感兴趣的:(Pytorch,深度学习,深度学习,pytorch,生成对抗网络,人工智能,卷积神经网络)