import torch
import torchvision
from torch.utils.data import DataLoader
from torch import nn
import matplotlib.pyplot as plt


def generate_noise(num):
    return torch.randn(size=(num, 1, 28, 28))


class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(nn.Flatten(),
                                    nn.Linear(784, 1568),
                                    nn.ReLU(),
                                    nn.Linear(in_features=1568, out_features=1200),
                                    nn.ReLU(),
                                    nn.Linear(in_features=1200, out_features=784),
                                    nn.Sigmoid(),
                                    nn.Unflatten(dim=1, unflattened_size=torch.Size([1, 28, 28])))

    def forward(self, X):
        Y = self.layers(X)
        return Y


class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(nn.Flatten(),
                                    nn.Linear(in_features=784, out_features=392),
                                    nn.ReLU(),
                                    nn.Linear(in_features=392, out_features=196),
                                    nn.ReLU(),
                                    nn.Linear(in_features=196, out_features=1),
                                    nn.Sigmoid())

    def forward(self, X):
        Y = self.layers(X)
        return Y
def train_discriminator_one_times(discriminator, generator, train_loader, batch_size,
                                  criterion_discriminator, optimizer_discriminator):
    for datas, _ in train_loader:
        noises = generate_noise(batch_size)
        fakes = generator(noises)
        images = torch.cat([datas, fakes], dim=0)
        discriminator.train()
        dis_result = discriminator(images).reshape(batch_size * 2)
        # First batch_size samples are real (label 1), the next batch_size are fake (label 0).
        labels = torch.tensor([1 - i for i in range(2) for j in range(batch_size)], dtype=torch.float32)
        loss = criterion_discriminator(dis_result, labels)
        optimizer_discriminator.zero_grad()
        loss.sum().backward()
        optimizer_discriminator.step()
    return discriminator
def train_generator_one_times(discriminator, generator, batch_size, criterion_generator,
                              optimizer_generator):
    noises = generate_noise(batch_size)
    generator.train()
    fakes = generator(noises)
    # with torch.no_grad():
    #     discriminator.eval()
    #     dis_result = discriminator(fakes).reshape(batch_size)
    dis_result = discriminator(fakes).reshape(batch_size)
    # The generator wants its fakes classified as real, so every label is 1.
    labels = torch.tensor([1 - i for i in range(1) for j in range(batch_size)], dtype=torch.float32)
    loss = criterion_generator(dis_result, labels)
    optimizer_generator.zero_grad()
    loss.sum().backward()
    optimizer_generator.step()
    return generator
def init_weights(m):
    if type(m) == nn.Linear:
        torch.nn.init.xavier_uniform_(m.weight)
if __name__ == '__main__':
    batch_size = 400
    epochs = 100
    K = 2
    G = 25
    lr = 0.03
    noise = generate_noise(1)  # fixed noise, visualized each epoch to track progress
    # Note: train=False loads the 10,000-image MNIST test split as training data.
    MNIST_train_dataset = torchvision.datasets.MNIST(root='MNIST', train=False,
                                                     transform=torchvision.transforms.ToTensor(),
                                                     download=True)
    train_loader = DataLoader(dataset=MNIST_train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
    torch.manual_seed(5)
    discriminator = Discriminator()
    discriminator.apply(init_weights)
    generator = Generator()
    generator.apply(init_weights)
    criterion_discriminator = torch.nn.BCELoss()
    criterion_generator = torch.nn.BCELoss()
    optimizer_discriminator = torch.optim.SGD(discriminator.parameters(), lr=lr)
    optimizer_generator = torch.optim.SGD(generator.parameters(), lr=lr)
    # Resume from a previous run; comment these four lines out for a fresh start.
    generator_state_dict = torch.load('generator_state_dict.pt')
    generator.load_state_dict(generator_state_dict)
    discriminator_state_dict = torch.load('discriminator_state_dict.pt')
    discriminator.load_state_dict(discriminator_state_dict)
    for epoch in range(epochs):
        for k in range(K):
            discriminator = train_discriminator_one_times(discriminator=discriminator, generator=generator,
                                                          train_loader=train_loader, batch_size=batch_size,
                                                          criterion_discriminator=criterion_discriminator,
                                                          optimizer_discriminator=optimizer_discriminator)
            torch.save(discriminator, 'discriminator.pt')
            torch.save(discriminator.state_dict(), 'discriminator_state_dict.pt')
            print(f'Epoch.k {epoch}.{k}, training discriminator has been finished.')
        for g in range(G):  # trains batch_size * G noise samples in total
            generator = train_generator_one_times(discriminator=discriminator, generator=generator,
                                                  batch_size=batch_size, criterion_generator=criterion_generator,
                                                  optimizer_generator=optimizer_generator)
        torch.save(generator.state_dict(), 'generator_state_dict.pt')
        torch.save(generator, 'generator.pt')
        print(f'Epoch {epoch}, training generator has been finished.')
        with torch.no_grad():
            generator.eval()
            fake = generator(noise)
        plt.imshow(fake[0][0].detach().numpy(), cmap='gray')
        plt.show()
import torch
import torchvision
from torch.utils.data import DataLoader
from torch import nn
import matplotlib.pyplot as plt
def discriminator_loss(dis_result, batch_size):
    # Hand-written BCE: the first batch_size entries of dis_result are the
    # discriminator's outputs on real images, the next batch_size on fakes.
    # Each log term is floored at -100, the same clamp nn.BCELoss applies.
    loss = 0
    for i in range(batch_size):
        a = torch.log(dis_result[i])
        b = torch.log(1 - dis_result[i + batch_size])
        if (a < torch.tensor(-100)).detach().numpy():
            a = torch.tensor([-100])
        if (b < torch.tensor(-100)).detach().numpy():
            b = torch.tensor([-100])
        loss -= (a + b)
    return loss / (2 * batch_size)


def generator_loss(dis_result, batch_size):
    # Hand-written generator loss: -log D(G(z)), with the same -100 floor.
    loss = 0
    for i in range(batch_size):
        a = torch.log(dis_result[i])
        if (a < torch.tensor(-100)).detach().numpy():
            a = torch.tensor([-100])
        loss -= a
    return loss / batch_size
def generate_noise(num):
    return torch.randn(size=(num, 1, 28, 28))
class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(nn.Flatten(),
                                    nn.Linear(784, 1200),
                                    # nn.BatchNorm1d(1200),
                                    nn.GELU(),
                                    nn.Linear(in_features=1200, out_features=1600),
                                    # nn.BatchNorm1d(1600),
                                    nn.GELU(),
                                    nn.Linear(in_features=1600, out_features=784),
                                    nn.Sigmoid(),
                                    nn.Unflatten(dim=1, unflattened_size=torch.Size([1, 28, 28])))

    def forward(self, X):
        Y = self.layers(X)
        return Y


class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(nn.Flatten(),
                                    nn.Linear(in_features=784, out_features=392),
                                    # nn.BatchNorm1d(392),
                                    nn.ReLU(),
                                    nn.Linear(in_features=392, out_features=196),
                                    # nn.BatchNorm1d(196),
                                    nn.ReLU(),
                                    nn.Linear(in_features=196, out_features=1),
                                    nn.Sigmoid())

    def forward(self, X):
        Y = self.layers(X)
        return Y
def train_discriminator_one_times(discriminator, generator, train_loader, batch_size,
                                  criterion_discriminator, optimizer_discriminator):
    for datas, _ in train_loader:
        noises = generate_noise(batch_size)
        fakes = generator(noises)
        images = torch.cat([datas, fakes], dim=0)
        discriminator.train()
        dis_result = discriminator(images).reshape(batch_size * 2)
        # No labels tensor needed here: the custom loss infers real/fake from sample order.
        loss = criterion_discriminator(dis_result, batch_size)
        optimizer_discriminator.zero_grad()
        loss.sum().backward()
        optimizer_discriminator.step()
    return discriminator
def train_generator_one_times(discriminator, generator, batch_size, criterion_generator,
                              optimizer_generator):
    noises = generate_noise(batch_size)
    generator.train()
    fakes = generator(noises)
    # with torch.no_grad():
    #     discriminator.eval()
    #     dis_result = discriminator(fakes).reshape(batch_size)
    dis_result = discriminator(fakes).reshape(batch_size)
    # No labels tensor needed here: the custom loss treats every sample as "should be real".
    loss = criterion_generator(dis_result, batch_size)
    optimizer_generator.zero_grad()
    loss.sum().backward()
    optimizer_generator.step()
    return generator
if __name__ == '__main__':
    batch_size = 400
    epochs = 100
    K = 1
    G = 25
    lr = 0.1
    # Note: train=False loads the 10,000-image MNIST test split as training data.
    MNIST_train_dataset = torchvision.datasets.MNIST(root='MNIST', train=False,
                                                     transform=torchvision.transforms.ToTensor(),
                                                     download=True)
    train_loader = DataLoader(dataset=MNIST_train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
    torch.manual_seed(5)
    discriminator = Discriminator()
    generator = Generator()
    criterion_discriminator = discriminator_loss
    criterion_generator = generator_loss
    optimizer_discriminator = torch.optim.SGD(discriminator.parameters(), lr=lr)
    optimizer_generator = torch.optim.SGD(generator.parameters(), lr=lr)
    # generator_state_dict = torch.load('generator_state_dict.pt')
    # generator.load_state_dict(generator_state_dict)
    # discriminator_state_dict = torch.load('discriminator_state_dict.pt')
    # discriminator.load_state_dict(discriminator_state_dict)
    for epoch in range(epochs):
        for k in range(K):
            discriminator = train_discriminator_one_times(discriminator=discriminator, generator=generator,
                                                          train_loader=train_loader, batch_size=batch_size,
                                                          criterion_discriminator=criterion_discriminator,
                                                          optimizer_discriminator=optimizer_discriminator)
            # torch.save(discriminator, 'discriminator.pt')
            # torch.save(discriminator.state_dict(), 'discriminator_state_dict.pt')
            print(f'Epoch.k {epoch}.{k}, training discriminator has been finished.')
        for g in range(G):  # trains batch_size * G noise samples in total
            generator = train_generator_one_times(discriminator=discriminator, generator=generator,
                                                  batch_size=batch_size, criterion_generator=criterion_generator,
                                                  optimizer_generator=optimizer_generator)
        # torch.save(generator.state_dict(), 'generator_state_dict.pt')
        # torch.save(generator, 'generator.pt')
        print(f'Epoch {epoch}, training generator has been finished.')
    with torch.no_grad():
        generator.eval()
        noise = generate_noise(1)
        fake = generator(noise)
    plt.imshow(fake[0][0].detach().numpy(), cmap='gray')
    plt.show()
A loss that looks perfectly reasonable is often actually wrong, and may not match what the network really needs. At first I thought that replacing the discriminator's BCELoss with the sum of $1-x_n$ over the positive samples plus the sum of $x_n$ over the negative samples made good sense, but in fact the network simply could not train with it. Once I switched to my own hand-written BCELoss, training proceeded normally.
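Written out explicitly (using the convention from the code above, where $x_1,\dots,x_m$ are the discriminator's outputs on real images and $x_{m+1},\dots,x_{2m}$ its outputs on fakes), the version that failed and the version that worked are:

$$\mathcal{L}_{\text{linear}}=\frac{1}{2m}\sum_{i=1}^{m}\bigl[(1-x_i)+x_{i+m}\bigr],\qquad \mathcal{L}_{\text{BCE}}=-\frac{1}{2m}\sum_{i=1}^{m}\bigl[\log x_i+\log(1-x_{i+m})\bigr].$$

My reading of why the linear version fails (an interpretation, not verified here): behind the final Sigmoid, its gradient with respect to the logit $z$ is proportional to $\sigma'(z)$, which vanishes as the sigmoid saturates, whereas the $\log$ in BCE cancels that saturation and leaves the well-scaled gradient $\sigma(z)-y$.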
Also, it turns out you can use if statements inside a custom loss.
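For instance, the loops and if statements in discriminator_loss above can be collapsed into a vectorized version with torch.clamp (my own alternative sketch, not what the listing uses), keeping the same $-100$ floor that nn.BCELoss applies internally:

import torch

def discriminator_loss_clamped(dis_result, batch_size):
    # Equivalent to discriminator_loss above: the first batch_size entries of
    # dis_result are D(real), the rest are D(fake); log terms are floored at -100.
    log_real = torch.clamp(torch.log(dis_result[:batch_size]), min=-100)
    log_fake = torch.clamp(torch.log(1 - dis_result[batch_size:]), min=-100)
    return -(log_real + log_fake).mean() / 2

Like the if-based version, clamped entries contribute zero gradient, but this avoids the Python loop and the per-element .detach().numpy() calls.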
In a GAN, when we train the discriminator we need to "fix" the generator: random noise is pushed through the generator to produce fake samples, which serve as negative examples and are fed into the discriminator together with the positive examples from the real data. The generator is used in this process but not trained, so it is "fixed". Likewise, when training the generator we use the discriminator without training it, so the discriminator is then the "fixed" one. But how should this fixing actually be expressed? At first I assumed that fixing the generator, say, meant wrapping every use of it in with torch.no_grad() to disable autograd and calling generator.eval() to put it in inference mode. Those settings are certainly good for speed, but applied carelessly they can cause a string of problems. If speed is not a big concern, we can actually skip both the no-grad context and the switch to inference mode: although the generator's parameters also receive gradients while the discriminator is being trained, the update is performed by optimizer_discriminator.step(), and that optimizer was constructed like this:
optimizer_discriminator = torch.optim.SGD(discriminator.parameters(), lr=lr)
This means that when optimizer_discriminator.step() is called, only the discriminator's parameters are updated by backpropagation; the generator's parameters stay untouched (their .grad buffers do accumulate, but optimizer_generator.zero_grad() clears them before the generator's own backward pass).
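If you do want the speed benefit without relying on optimizer scope, the usual alternatives (a sketch under the variable names used above, reusing generate_noise and the custom criterion from the second listing, not what my listings do) are to cut the graph with detach() when training the discriminator, and to toggle requires_grad when training the generator:

import torch

# When training the discriminator, one line suffices: detach() cuts the graph
# at the fakes, so autograd never builds the generator's backward pass.
# fakes = generator(noises).detach()

def train_generator_step(discriminator, generator, batch_size,
                         criterion_generator, optimizer_generator):
    # Temporarily freeze the discriminator so its .grad buffers stay empty;
    # gradients still flow through its activations back into the generator.
    for p in discriminator.parameters():
        p.requires_grad_(False)
    noises = generate_noise(batch_size)
    dis_result = discriminator(generator(noises)).reshape(batch_size)
    loss = criterion_generator(dis_result, batch_size)
    optimizer_generator.zero_grad()
    loss.sum().backward()
    optimizer_generator.step()
    # Unfreeze for the next discriminator update.
    for p in discriminator.parameters():
        p.requires_grad_(True)
    return generator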
Out of habit I initially set the learning rate to $0.3$, but since results never materialized I bumped it up to $1$ and then forgot to change it back. That learning rate of $1$ caused all sorts of problems; only after turning it back down to $0.1$ did training succeed. The lesson: the learning rate is a major factor in whether training works at all, and it is easy to overlook.