一、代码
'''
使用随机噪声生成手写数字
1. 数据准备
2. 网络搭建:生成器、判别器
3. 损失函数、优化器
4. 训练、测试、保存中间结果
5. 保存模型
'''
import os
import argparse
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision.utils import save_image
from torchvision import transforms,datasets
from torch.utils.tensorboard import SummaryWriter
import warnings
# Silence every warning category for the whole process.
# NOTE(review): this also hides deprecation warnings raised by torch/torchvision.
warnings.filterwarnings('ignore')

# Command-line configuration.
# Numeric options declare an explicit ``type`` so that values given on the
# command line are converted; without it argparse hands them over as strings,
# which breaks the DataLoader batch size, ``range(EPOCHS)`` and the optimizer
# learning rate. The text is passed as ``description`` because the first
# positional argument of ArgumentParser is ``prog`` (the program name shown
# in the usage line), not the description.
parser = argparse.ArgumentParser(description="parameters configuring")
parser.add_argument("--batch_size", type=int, default=128, help="the batch size of dataset")
parser.add_argument("--epochs", type=int, default=100, help="the epochs of training")
parser.add_argument("--logs_dir", default='./logs', help="the path of logs")
parser.add_argument("--models_dir", default='./models', help="the path of saved models")
parser.add_argument("--lr", type=float, default=1e-3, help="the init learning rate")
parser.add_argument("--images_dir", default='./images', help="the path of saved images")
args = parser.parse_args()
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

BATCH_SIZE = args.batch_size

# MNIST training split stored under ./dataset (download requested), with each
# image converted by ToTensor.
mnist_transform = transforms.ToTensor()
train_dataset = datasets.MNIST(
    root="./dataset",
    train=True,
    transform=mnist_transform,
    download=True,
)

# Shuffled batches, loaded in the main process; the final partial batch is kept.
train_dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0,
    drop_last=False,
)
class Generator(nn.Module):
    """Fully connected generator that maps noise vectors to 1x28x28 images.

    The network is a stack of Linear -> BatchNorm1d -> ReLU stages
    (256, 512, 1024 units) followed by a Linear -> BatchNorm1d -> Sigmoid
    output stage of 28*28 units, so every output value lies in [0, 1].

    Args:
        latent_dim: Length of each input noise vector. Defaults to 100,
            the size the original network was hard-coded to.
    """

    def __init__(self, latent_dim=100):
        super().__init__()
        # The layer order is unchanged, so the ``model1.<index>`` state-dict
        # keys stay the same and existing checkpoints still load.
        self.model1 = nn.Sequential(
            nn.Linear(in_features=latent_dim, out_features=256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Linear(in_features=256, out_features=512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Linear(in_features=512, out_features=1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Linear(in_features=1024, out_features=28 * 28),
            nn.BatchNorm1d(28 * 28),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Generate images from noise.

        Args:
            x: Tensor whose last dimension equals ``latent_dim``.

        Returns:
            Tensor reshaped to ``(-1, 1, 28, 28)``.
        """
        x = self.model1(x)
        return x.view((-1, 1, 28, 28))
# Instantiate the generator, then place it on the selected device.
generator = Generator()
generator = generator.to(device)
class Discriminator(nn.Module):
    """Fully connected discriminator that scores 28x28 images.

    The input is flattened and passed through Linear -> ReLU stages
    (512, 256, 128, 32 units) and a final Linear -> Sigmoid stage, giving
    one value in [0, 1] per sample.
    """

    def __init__(self):
        super().__init__()
        # The layer order is unchanged, so the ``model1.<index>`` state-dict
        # keys stay the same.
        self.model1 = nn.Sequential(
            nn.Flatten(),
            nn.Linear(in_features=28 * 28, out_features=512),
            nn.ReLU(),
            nn.Linear(in_features=512, out_features=256),
            nn.ReLU(),
            nn.Linear(in_features=256, out_features=128),
            nn.ReLU(),
            nn.Linear(in_features=128, out_features=32),
            nn.ReLU(),
            nn.Linear(in_features=32, out_features=1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Score a batch of images.

        Args:
            x: Tensor that flattens to 28*28 features per sample.

        Returns:
            Tensor with one sigmoid output per sample.
        """
        return self.model1(x)
# Instantiate the discriminator, then place it on the selected device.
discriminator = Discriminator()
discriminator = discriminator.to(device)

# Each network gets its own Adam optimizer; both use the learning rate from --lr.
generator_optimizer = torch.optim.Adam(params=generator.parameters(), lr=args.lr)
discriminator_optimizer = torch.optim.Adam(params=discriminator.parameters(), lr=args.lr)
# Binary cross-entropy over probabilities. The criterion is constructed
# without a weight tensor, so it has nothing that needs moving to a device.
loss_object2 = nn.BCELoss()


def generator_creation(fake_out):
    """Return the generator loss for discriminator outputs on generated images.

    The loss is the binary cross-entropy between ``fake_out`` and an all-ones
    target of the same shape.

    Args:
        fake_out: Discriminator outputs (probabilities) for generated images.

    Returns:
        Scalar loss tensor.
    """
    # ones_like already matches the device and dtype of ``fake_out``, so no
    # explicit device transfer is required.
    target = torch.ones_like(fake_out)
    return loss_object2(fake_out, target)


def discriminator_creation(fake_out, real_out):
    """Return the discriminator loss for generated and real images.

    The loss is the sum of two binary cross-entropy terms: ``fake_out``
    against an all-zeros target and ``real_out`` against an all-ones target.

    Args:
        fake_out: Discriminator outputs (probabilities) for generated images.
        real_out: Discriminator outputs (probabilities) for real images.

    Returns:
        Scalar loss tensor.
    """
    fake_target = torch.zeros_like(fake_out)
    real_target = torch.ones_like(real_out)
    fake_loss = loss_object2(fake_out, fake_target)
    real_loss = loss_object2(real_out, real_target)
    return fake_loss + real_loss
# Output locations come from the command-line options; create both up front.
models_save_path, images_save_path = args.models_dir, args.images_dir
os.makedirs(models_save_path, exist_ok=True)
os.makedirs(images_save_path, exist_ok=True)

EPOCHS = args.epochs

# Running counters shared by the training loop, and its progress message.
total_train_steps = total_val_steps = 0
train_info = "the {} times of train,g_train_loss is {}."
# Adversarial training loop: one generator step followed by one discriminator
# step per batch, then a preview grid every 5 epochs and a generator
# checkpoint after each epoch.
for epoch in range(EPOCHS):
    print("Epoch:{}".format(epoch))
    # The generator is switched to eval mode at the end of every epoch, so
    # both networks are put back into training mode here.
    generator.train()
    discriminator.train()
    total_train_loss = 0.0
    for imgs, _ in train_dataloader:
        imgs = imgs.to(device)
        # One 100-dimensional noise vector per real image in the batch.
        random_vector = torch.randn(imgs.shape[0], 100, dtype=torch.float32, device=device)

        # Generator step: push the discriminator's output on fakes towards 1.
        prediction = generator(random_vector)
        fake_out = discriminator(prediction)
        g_train_loss = generator_creation(fake_out)
        generator_optimizer.zero_grad()
        g_train_loss.backward()
        generator_optimizer.step()
        total_train_loss += g_train_loss.item()

        # Discriminator step: the fakes are detached so this backward pass
        # does not reach the generator; zero_grad() also discards the
        # gradients the generator's backward pass left on the discriminator.
        fake_out = discriminator(prediction.detach())
        real_out = discriminator(imgs)
        d_train_loss = discriminator_creation(fake_out, real_out)
        discriminator_optimizer.zero_grad()
        d_train_loss.backward()
        discriminator_optimizer.step()

        if (total_train_steps + 1) % 100 == 0:
            print(train_info.format(total_train_steps, g_train_loss.item()))
        total_train_steps += 1
    print("the {} epoch of train,total_g_train_loss is {}.".format(epoch, total_train_loss))

    generator.eval()
    total_val_steps += 1
    # Save a preview after every 5th completed epoch. The counter has already
    # been incremented, so it is tested directly (adding 1 again fired one
    # epoch early), and it is used as an integer in the file name.
    if total_val_steps % 5 == 0:
        with torch.no_grad():
            # Exactly 20 samples are drawn, independent of the size of the
            # last training batch.
            random_vector = torch.randn(20, 100, dtype=torch.float32, device=device)
            output = generator(random_vector)
        # ``nrow`` (images per grid row) is the keyword make_grid accepts;
        # the previous ``nrows`` spelling is not a make_grid parameter.
        save_image(
            output,
            os.path.join(images_save_path, "epoch{}.jpg".format(total_val_steps)),
            nrow=5,
            normalize=True,
        )

    # NOTE(review): the original indentation was lost; checkpointing is placed
    # inside the epoch loop because the file name is keyed on ``epoch`` - confirm.
    torch.save(generator.state_dict(), os.path.join(models_save_path, "generator{}.pth".format(epoch)))
    print("model saved!!")
二、中间结果
训练5轮:

训练100轮:
