Model architecture
Code
Data preparation
import os
import time
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision import datasets

os.makedirs("data", exist_ok=True)

# Scale images to [-1, 1] so they match the Tanh output range of the generator
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(0.5, 0.5),
])

train_dataset = datasets.MNIST('data',
                               train=True,
                               transform=transform,
                               download=True)
dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True)
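A quick, optional sanity check (a sketch, assuming the cell above has already run) is to pull one batch from the dataloader and confirm the shapes and the value range produced by Normalize(0.5, 0.5):

# Optional sanity check: inspect one batch from the dataloader.
imgs, labels = next(iter(dataloader))
print(imgs.shape)    # torch.Size([64, 1, 28, 28])
print(labels.shape)  # torch.Size([64])
print(imgs.min().item(), imgs.max().item())  # close to -1.0 and 1.0 after normalization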
Define the generator
'''
Generator
Input: random noise drawn from a normal distribution (length 100)
Output: a generated image of shape (1, 28, 28)
Layers:
linear1: 100 -> 256
linear2: 256 -> 512
linear3: 512 -> 28*28
reshape: 28*28 -> (1, 28, 28)
'''
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.model = nn.Sequential(nn.Linear(100, 256), nn.ReLU(),
                                   nn.Linear(256, 512), nn.ReLU(),
                                   nn.Linear(512, 28 * 28), nn.Tanh())

    def forward(self, x):
        img = self.model(x)
        # Reshape the flat 784-dim output to (batch, 1, 28, 28) to match MNIST images
        img = img.view(-1, 1, 28, 28)
        return img
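As a quick shape check (a minimal sketch, assuming only the Generator class above), a batch of 16 noise vectors should come out as 16 single-channel 28x28 images:

# Sanity check: noise of shape (16, 100) -> images of shape (16, 1, 28, 28)
z = torch.randn(16, 100)
fake = Generator()(z)
print(fake.shape)  # torch.Size([16, 1, 28, 28])
print(fake.min().item() >= -1, fake.max().item() <= 1)  # Tanh keeps values in [-1, 1]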
Define the discriminator
'''
Discriminator
Input: an image of shape (1, 28, 28)
Output: a binary-classification probability, squashed into (0, 1) by a sigmoid
Notes:
LeakyReLU is recommended in the discriminator: the generator is hard to train, and with ReLU
negative activations become exactly 0, so no gradient flows back through them.
'''
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(28 * 28, 512), nn.LeakyReLU(),
            nn.Linear(512, 256), nn.LeakyReLU(),
            nn.Linear(256, 1), nn.Sigmoid(),
        )

    def forward(self, x):
        # Flatten (batch, 1, 28, 28) images to (batch, 784) before the linear layers
        x = x.view(-1, 28 * 28)
        x = self.model(x)
        return x
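Chaining the two networks shows the shapes the training loop below relies on (again just a sketch, assuming both classes are defined):

# Sanity check: the discriminator maps a batch of generated images to one probability each.
z = torch.randn(16, 100)
probs = Discriminator()(Generator()(z))
print(probs.shape)  # torch.Size([16, 1])
print(probs.min().item(), probs.max().item())  # both inside (0, 1) because of the Sigmoid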
Initialize the models, optimizers, and loss function
device = 'cuda' if torch.cuda.is_available() else 'cpu'
gen = Generator().to(device)
dis = Discriminator().to(device)
dis_optim = torch.optim.Adam(dis.parameters(),lr=0.0001)
gen_optim = torch.optim.Adam(gen.parameters(),lr=0.0001)
bce_loss = torch.nn.BCELoss()
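Both networks are trained with binary cross-entropy: the discriminator is pushed toward label 1 on real images and label 0 on fakes, while the generator is rewarded when the discriminator assigns label 1 to its fakes. A tiny illustration of how the loss behaves on those targets (the probability values below are made up purely for illustration):

# Illustration only: per-sample BCE values for hypothetical discriminator outputs.
p = torch.tensor([[0.9], [0.5], [0.1]])           # made-up probabilities
per_sample = torch.nn.BCELoss(reduction='none')
print(per_sample(p, torch.ones_like(p)))   # target "real": low loss at 0.9, high at 0.1
print(per_sample(p, torch.zeros_like(p)))  # target "fake": the ordering is reversed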
Plotting function for images produced by the generator
def gen_img_plot(model, test_input):
    prediction = model(test_input).detach().cpu().numpy()
    prediction = np.squeeze(prediction)
    fig = plt.figure(figsize=(4, 4))
    for i in range(16):
        plt.subplot(4, 4, i + 1)
        # Map Tanh outputs from [-1, 1] back to [0, 1] for display
        plt.imshow((prediction[i] + 1) / 2)
        plt.axis('off')
    plt.show()
Function for displaying a batch of images
def img_plot(img):
    img = np.squeeze(img)
    fig = plt.figure(figsize=(4, 4))
    for i in range(16):
        plt.subplot(4, 4, i + 1)
        plt.imshow((img[i] + 1) / 2)
        plt.axis('off')
    plt.show()
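For example (a usage sketch, assuming the dataloader defined earlier), img_plot can be used to look at a batch of real, normalized MNIST digits before training:

# Show the first 16 real images of one (already normalized) batch.
real_imgs, _ = next(iter(dataloader))
img_plot(real_imgs[:16].cpu())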
Define the training function
def train(num_epoch, test_input):
    D_loss = []
    G_loss = []
    for epoch in range(num_epoch):
        d_epoch_loss = 0
        g_epoch_loss = 0
        count = len(dataloader)
        for step, (img, _) in enumerate(dataloader):
            img = img.to(device)
            size = img.size(0)
            random_noise = torch.randn(size, 100, device=device)

            '''1. Train the discriminator'''
            '''On real images, with target label 1'''
            dis_optim.zero_grad()
            real_output = dis(img)
            d_real_loss = bce_loss(real_output,
                                   torch.ones_like(real_output))
            d_real_loss.backward()
            '''On generated images, with target label 0; detach so no gradient reaches the generator here'''
            gen_img = gen(random_noise)
            fake_output = dis(gen_img.detach())
            d_fake_loss = bce_loss(fake_output,
                                   torch.zeros_like(fake_output))
            d_fake_loss.backward()
            d_loss = d_real_loss + d_fake_loss
            dis_optim.step()

            '''2. Train the generator: it wins when the discriminator labels its fakes as real'''
            gen_optim.zero_grad()
            fake_output = dis(gen_img)
            g_loss = bce_loss(fake_output,
                              torch.ones_like(fake_output))
            g_loss.backward()
            gen_optim.step()

            with torch.no_grad():
                d_epoch_loss += d_loss
                g_epoch_loss += g_loss
        with torch.no_grad():
            d_epoch_loss /= count
            g_epoch_loss /= count
            D_loss.append(d_epoch_loss.item())
            G_loss.append(g_epoch_loss.item())
            print('Epoch:', epoch + 1)
            print(f'd_epoch_loss={d_epoch_loss}')
            print(f'g_epoch_loss={g_epoch_loss}')
            gen_img_plot(gen, test_input)
    # Return the per-epoch losses so they can be plotted after training
    return D_loss, G_loss
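As a rough sanity check on the printed losses (not from the original text, just a property of the BCE objective used here): if the discriminator cannot tell real from fake and outputs 0.5 everywhere, each BCE term equals -ln(0.5) ≈ 0.693, so d_epoch_loss hovers around 1.386 and g_epoch_loss around 0.693 near that balance point:

# Expected loss values when the discriminator outputs 0.5 for every image.
import math
per_term = -math.log(0.5)      # ≈ 0.693
print(per_term * 2, per_term)  # d_loss ≈ 1.386, g_loss ≈ 0.693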
Start training
'''Start the timer'''
start_time = time.time()
'''Run the training loop'''
test_input = torch.randn(16, 100, device=device)
print(test_input)
num_epoch = 50
D_loss, G_loss = train(num_epoch, test_input)
torch.save(gen.state_dict(), 'gen_weights.pth')
torch.save(dis.state_dict(), 'dis_weights.pth')
'''Stop the timer'''
end_time = time.time()
run_time = end_time - start_time
if int(run_time) < 60:
    print(f'{round(run_time, 2)} s')
else:
    print(f'{round(run_time / 60, 2)} minutes')
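Since train returns the per-epoch losses, a simple loss-curve plot (a sketch using the D_loss and G_loss lists captured above) can help judge whether the two networks stay roughly balanced:

# Plot the discriminator and generator losses recorded during training.
plt.figure()
plt.plot(D_loss, label='D loss')
plt.plot(G_loss, label='G loss')
plt.xlabel('epoch')
plt.ylabel('loss')
plt.legend()
plt.show()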
Visualize the results
Load the trained weights
gen.load_state_dict(torch.load('/opt/software/computer_vision/codes/My_codes/paper_codes/GAN/weights/gen_weights.pth'))
Generate and plot images with the trained generator
test_new_input = torch.randn(16,100,device=device)
gen_img_plot(gen,test_new_input)
Because GAN sampling is stochastic, different noise vectors produce different digits.