生成对抗网络(GAN)就是两个网络相互对抗、在对抗中共同进步的过程。GAN的主要部分为生成器(Generator)和判别器(Discriminator)。
生成器:输入一个随机向量,生成一个图片。希望生成的图片越像真的越好。
判别器:输入一个图片,判别图片的真伪。希望生成的图片判别为假,数据集中的图片判别为真。
class Discriminator(nn.Module):
    """Fully connected discriminator that scores images as real or fake.

    The network is ``Linear -> LeakyReLU(0.2) -> Linear -> LeakyReLU(0.2)
    -> Linear -> Sigmoid`` and returns one value in (0, 1) per sample.

    Args:
        in_features: Number of values per flattened image (default 28 * 28).
        hidden_features: Width of the two hidden layers (default 256).
    """

    def __init__(self, in_features=28 * 28, hidden_features=256):
        super().__init__()
        # Remembered so forward() can flatten inputs to the right width.
        self.in_features = in_features
        # Attribute names are kept stable so saved state_dict keys still load.
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.lr1 = nn.LeakyReLU(0.2)
        self.fc2 = nn.Linear(hidden_features, hidden_features)
        self.lr2 = nn.LeakyReLU(0.2)
        self.fc3 = nn.Linear(hidden_features, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, input):
        """Return a ``(N, 1)`` tensor of scores for a batch of images.

        ``input`` may have any shape whose elements can be regrouped into
        rows of ``in_features`` values; it is flattened before the first layer.
        """
        # reshape (not view) so non-contiguous inputs are accepted as well.
        x = input.reshape(-1, self.in_features)
        x = self.lr1(self.fc1(x))
        x = self.lr2(self.fc2(x))
        return self.sigmoid(self.fc3(x))
class Generator(nn.Module):
    """Fully connected generator that maps noise vectors to flattened images.

    The network is ``Linear -> ReLU -> Linear -> ReLU -> Linear -> Tanh``,
    so every output value lies in [-1, 1].

    Args:
        z_dimension: Length of each input noise vector (default 100).
        hidden_features: Width of the two hidden layers (default 256).
        out_features: Number of values per generated image (default 28 * 28).
    """

    def __init__(self, z_dimension=100, hidden_features=256, out_features=28 * 28):
        super().__init__()
        # Attribute names are kept stable so saved state_dict keys still load.
        self.fc1 = nn.Linear(z_dimension, hidden_features)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(hidden_features, hidden_features)
        self.relu2 = nn.ReLU()
        self.fc3 = nn.Linear(hidden_features, out_features)
        self.tanh = nn.Tanh()

    def forward(self, input):
        """Return a ``(N, out_features)`` tensor generated from noise ``input``.

        ``input`` must have ``z_dimension`` values in its last dimension.
        """
        x = self.relu1(self.fc1(input))
        x = self.relu2(self.fc2(x))
        return self.tanh(self.fc3(x))
# --- Discriminator update ---
# Score the real batch and compare against the "real" labels.
real_out=D(real_img)
d_loss_real=criterion(real_out,real_label)
# Generate fakes; detach() cuts the graph so this step sends no gradients into G.
fake_img=G(z).detach()
fake_out=D(fake_img)
d_loss_fake=criterion(fake_out,fake_label)
# Total discriminator loss is the sum of the real-batch and fake-batch losses.
d_loss=d_loss_real+d_loss_fake
d_optimizer.zero_grad()
d_loss.backward()
d_optimizer.step()
# --- Generator update ---
# No detach() here: gradients must flow back through D into G.
fake_img=G(z)
output=D(fake_img)
# The generator is trained against the "real" labels, i.e. rewarded when D is fooled.
g_loss=criterion(output,real_label)
g_optimizer.zero_grad()
g_loss.backward()
g_optimizer.step()
import os

import matplotlib.pyplot as plt
import torch
import torchvision
from torch import nn
from torch.autograd import Variable
from torch.utils.data import DataLoader
class Discriminator(nn.Module):
    """Fully connected discriminator that scores images as real or fake.

    The network is ``Linear -> LeakyReLU(0.2) -> Linear -> LeakyReLU(0.2)
    -> Linear -> Sigmoid`` and returns one value in (0, 1) per sample.

    Args:
        in_features: Number of values per flattened image (default 28 * 28).
        hidden_features: Width of the two hidden layers (default 256).
    """

    def __init__(self, in_features=28 * 28, hidden_features=256):
        super().__init__()
        # Remembered so forward() can flatten inputs to the right width.
        self.in_features = in_features
        # Attribute names are kept stable so saved state_dict keys still load.
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.lr1 = nn.LeakyReLU(0.2)
        self.fc2 = nn.Linear(hidden_features, hidden_features)
        self.lr2 = nn.LeakyReLU(0.2)
        self.fc3 = nn.Linear(hidden_features, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, input):
        """Return a ``(N, 1)`` tensor of scores for a batch of images.

        ``input`` may have any shape whose elements can be regrouped into
        rows of ``in_features`` values; it is flattened before the first layer.
        """
        # reshape (not view) so non-contiguous inputs are accepted as well.
        x = input.reshape(-1, self.in_features)
        x = self.lr1(self.fc1(x))
        x = self.lr2(self.fc2(x))
        return self.sigmoid(self.fc3(x))
class Generator(nn.Module):
    """Fully connected generator that maps noise vectors to flattened images.

    The network is ``Linear -> ReLU -> Linear -> ReLU -> Linear -> Tanh``,
    so every output value lies in [-1, 1].

    Args:
        z_dimension: Length of each input noise vector (default 100).
        hidden_features: Width of the two hidden layers (default 256).
        out_features: Number of values per generated image (default 28 * 28).
    """

    def __init__(self, z_dimension=100, hidden_features=256, out_features=28 * 28):
        super().__init__()
        # Attribute names are kept stable so saved state_dict keys still load.
        self.fc1 = nn.Linear(z_dimension, hidden_features)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(hidden_features, hidden_features)
        self.relu2 = nn.ReLU()
        self.fc3 = nn.Linear(hidden_features, out_features)
        self.tanh = nn.Tanh()

    def forward(self, input):
        """Return a ``(N, out_features)`` tensor generated from noise ``input``.

        ``input`` must have ``z_dimension`` values in its last dimension.
        """
        x = self.relu1(self.fc1(input))
        x = self.relu2(self.fc2(x))
        return self.tanh(self.fc3(x))
def to_img(x):
    """Convert values in [-1, 1] to a batch of 28x28 single-channel images.

    Values are mapped with ``0.5 * (x + 1)``, clamped to [0, 1], and
    reshaped to ``(N, 28, 28, 1)``.

    Args:
        x: Tensor whose element count is a multiple of 28 * 28.

    Returns:
        Tensor of shape ``(N, 28, 28, 1)`` with values in [0, 1].
    """
    out = (0.5 * (x + 1)).clamp(0, 1)
    # reshape (not view) so non-contiguous inputs are handled too.
    return out.reshape(-1, 28, 28, 1)
def save_images(fake_image, epoch):
    """Save a 5x5 grid of generated images to ``images/epoch_<epoch+1>.png``.

    Args:
        fake_image: Batch indexed as ``fake_image[n, :, :, 0]``, i.e. laid
            out as (N, H, W, C). Only the first 25 images are drawn; if the
            batch holds fewer, the remaining cells are left blank.
        epoch: Zero-based epoch index; the file name uses ``epoch + 1``.
    """
    rows, cols = 5, 5
    # savefig does not create missing directories, so make sure it exists.
    os.makedirs("images", exist_ok=True)
    fig, axs = plt.subplots(rows, cols)
    # axs.flat walks the grid row by row, the same order as nested i/j loops.
    for cnt, ax in enumerate(axs.flat):
        ax.axis('off')
        if cnt < len(fake_image):
            ax.imshow(fake_image[cnt, :, :, 0], cmap='gray')
    fig.savefig("images/epoch_{}.png".format(epoch + 1))
    # Close this specific figure so figures do not pile up across epochs.
    plt.close(fig)
if __name__ == '__main__':
    # Training hyper-parameters.
    epochs = 100
    batch_size = 128
    z_dimension = 100

    # Normalize with mean 0.5 / std 0.5 rescales the pixel values; to_img()
    # applies the inverse mapping 0.5 * (x + 1) before images are saved.
    transformer = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5,), (0.5,)),
    ])
    mnist = torchvision.datasets.MNIST(root='data', train=True, transform=transformer, download=True)
    dataloader = DataLoader(dataset=mnist, batch_size=batch_size, shuffle=True)

    D = Discriminator()
    G = Generator()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    D = D.to(device)
    G = G.to(device)

    criterion = nn.BCELoss()
    d_optimizer = torch.optim.Adam(D.parameters(), lr=0.0003)
    g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0003)

    for epoch in range(epochs):
        for idx, (img, _) in enumerate(dataloader):
            # The last batch may be smaller than batch_size, so size from data.
            num_img = img.size(0)
            real_img = img.view(num_img, -1).to(device)
            # Plain tensors replace the deprecated Variable wrapper and are
            # created directly on the target device.
            real_label = torch.ones(num_img, 1, device=device)
            fake_label = torch.zeros(num_img, 1, device=device)
            z = torch.randn(num_img, z_dimension, device=device)

            # ---- Train the discriminator ----
            real_out = D(real_img)
            d_loss_real = criterion(real_out, real_label)
            # detach() keeps this step from sending gradients into G.
            fake_img = G(z).detach()
            fake_out = D(fake_img)
            d_loss_fake = criterion(fake_out, fake_label)
            d_loss = d_loss_real + d_loss_fake
            d_optimizer.zero_grad()
            d_loss.backward()
            d_optimizer.step()

            # ---- Train the generator ----
            # Fresh noise; the generator is scored against the "real" labels.
            z = torch.randn(num_img, z_dimension, device=device)
            fake_img = G(z)
            output = D(fake_img)
            g_loss = criterion(output, real_label)
            g_optimizer.zero_grad()
            g_loss.backward()
            g_optimizer.step()

            if idx % 100 == 0:
                print('epoch:{} ,d_loss:{} ,g_loss:{}'.format(epoch, d_loss.item(), g_loss.item()))

        # Once per epoch: save a grid built from the last generated batch.
        fake_image = to_img(fake_img.detach().cpu())
        save_images(fake_image, epoch)
        print('--------------------------')

    # Persist the trained weights after the final epoch.
    torch.save(G.state_dict(), 'generator.pth')
    torch.save(D.state_dict(), 'discriminator.pth')