Having implemented the plain GAN, let's now walk through WGAN-GP in practice
and see how WGAN-GP addresses GAN's training instability.
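As a quick reminder, WGAN swaps the GAN's log-loss for the Wasserstein critic objective, which stays informative even when the real and generated distributions barely overlap:

$$\min_G \max_{D:\,\|D\|_L \le 1} \; \mathbb{E}_{x \sim p_r}[D(x)] - \mathbb{E}_{z \sim p(z)}[D(G(z))]$$

The catch is the 1-Lipschitz constraint on the critic D, which the gradient-penalty variant handles next.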
Compared with the GAN code, the main change is the gradient_penalty function, called in step 1.3 of the training loop below.
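Concretely, the gradient penalty relaxes the Lipschitz constraint into a soft loss term that pushes the critic's gradient norm toward 1 at points interpolated between real and fake samples, which is exactly what gradient_penalty computes:

$$\text{gp} = \mathbb{E}_{\hat{x}}\Big[\big(\|\nabla_{\hat{x}} D(\hat{x})\|_2 - 1\big)^2\Big], \qquad \hat{x} = t\,x_r + (1-t)\,x_f, \quad t \sim U(0,1)$$

The penalty enters the critic loss with a weight lambda; the WGAN-GP paper uses lambda = 10, while this example uses 0.2.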
wgan.py
import torch
from torch import nn, optim, autograd
import numpy as np
import visdom
import random
from matplotlib import pyplot as plt

h_dim = 400
batchsz = 512
viz = visdom.Visdom()


class Generator(nn.Module):

    def __init__(self):
        super(Generator, self).__init__()
        # z: [b, 2] => [b, 2]
        self.net = nn.Sequential(
            # 4 layers in total
            nn.Linear(2, h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim, h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim, h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim, 2),
        )

    def forward(self, z):
        output = self.net(z)
        return output


class Discriminator(nn.Module):

    def __init__(self):
        super(Discriminator, self).__init__()
        self.net = nn.Sequential(
            # 4 layers
            nn.Linear(2, h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim, h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim, h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim, 1),
            # the output is read as a probability, so sigmoid squashes it into (0, 1);
            # note: a strict WGAN-GP critic would drop this sigmoid and output an unbounded score
            nn.Sigmoid()
        )

    def forward(self, x):
        output = self.net(x)
        return output.view(-1)


def data_generator():
    # 8-Gaussian mixture model
    scale = 2.
    centers = [
        (1, 0), (-1, 0), (0, 1), (0, -1),
        (1. / np.sqrt(2), 1. / np.sqrt(2)),
        (1. / np.sqrt(2), -1. / np.sqrt(2)),
        (-1. / np.sqrt(2), 1. / np.sqrt(2)),
        (-1. / np.sqrt(2), -1. / np.sqrt(2))
    ]
    centers = [(scale * x, scale * y) for x, y in centers]
    while True:
        dataset = []
        for i in range(batchsz):
            point = np.random.randn(2) * 0.02
            center = random.choice(centers)
            point[0] += center[0]
            point[1] += center[1]
            dataset.append(point)
        dataset = np.array(dataset).astype(np.float32)
        dataset /= 1.414
        yield dataset


# visualization helper
def generate_image(D, G, xr, epoch):
    # xr holds the real samples of x
    """
    Generates and saves a plot of the true distribution, the generator, and the critic.
    """
    N_POINTS = 128
    RANGE = 3
    plt.clf()

    points = np.zeros((N_POINTS, N_POINTS, 2), dtype='float32')
    points[:, :, 0] = np.linspace(-RANGE, RANGE, N_POINTS)[:, None]
    points[:, :, 1] = np.linspace(-RANGE, RANGE, N_POINTS)[None, :]
    points = points.reshape((-1, 2))  # (16384, 2)
    # print('p:', points.shape)

    # draw contour
    with torch.no_grad():
        points = torch.Tensor(points).cuda()  # [16384, 2]
        disc_map = D(points).cpu().numpy()    # [16384]
    x = y = np.linspace(-RANGE, RANGE, N_POINTS)
    cs = plt.contour(x, y, disc_map.reshape((len(x), len(y))).transpose())
    plt.clabel(cs, inline=1, fontsize=10)
    # plt.colorbar()

    # draw samples
    with torch.no_grad():
        z = torch.randn(batchsz, 2).cuda()  # [b, 2]
        samples = G(z).cpu().numpy()        # [b, 2]
    plt.scatter(xr[:, 0], xr[:, 1], c='orange', marker='.')
    plt.scatter(samples[:, 0], samples[:, 1], c='green', marker='+')

    viz.matplot(plt, win='contour', opts=dict(title='p(x):%d' % epoch))


def gradient_penalty(D, xr, xf):
    # [b, 1]
    t = torch.rand(batchsz, 1).cuda()
    # [b, 1] => [b, 2]
    t = t.expand_as(xr)
    # linear interpolation between real and fake data
    mid = t * xr + (1 - t) * xf
    # mark it as requiring gradients
    mid.requires_grad_()
    pred = D(mid)
    grads = autograd.grad(outputs=pred, inputs=mid,
                          grad_outputs=torch.ones_like(pred),
                          create_graph=True, retain_graph=True,
                          only_inputs=True)[0]
    gp = torch.pow(grads.norm(2, dim=1) - 1, 2).mean()
    return gp


def main():
    torch.manual_seed(23)
    np.random.seed(23)

    data_iter = data_generator()
    x = next(data_iter)  # [b, 2]
    # print(x.shape)

    G = Generator().cuda()
    D = Discriminator().cuda()
    # print(G)
    # print(D)
    optim_G = optim.Adam(G.parameters(), lr=5e-4, betas=(0.5, 0.9))
    optim_D = optim.Adam(D.parameters(), lr=5e-4, betas=(0.5, 0.9))

    viz.line([[0, 0]], [0], win='loss', opts=dict(title='loss', legend=['D', 'G']))

    for epoch in range(50000):

        # 1. train the Discriminator first
        for _ in range(5):
            # 1.1 train on real data
            x = next(data_iter)
            xr = torch.from_numpy(x).cuda()
            # feed the real data into the discriminator
            # [b, 2] => [b, 1]
            predr = D(xr)
            # maximize predr, i.e. minimize its negation
            lossr = -predr.mean()

            # 1.2 train on fake data
            z = torch.randn(batchsz, 2).cuda()
            # xf stands for x_fake
            xf = G(z).detach()  # similar to tf.stop_gradient()
            predf = D(xf)
            # minimize predf
            lossf = predf.mean()

            # 1.3 gradient penalty
            gp = gradient_penalty(D, xr, xf.detach())

            # aggregate all
            loss_D = lossr + lossf + 0.2 * gp

            # optimize
            optim_D.zero_grad()
            loss_D.backward()
            optim_D.step()

        # 2. train the Generator
        # sample a noise vector z
        z = torch.randn(batchsz, 2).cuda()
        xf = G(z)
        predf = D(xf)
        # maximize predf.mean()
        loss_G = -predf.mean()

        # optimize
        optim_G.zero_grad()
        loss_G.backward()
        optim_G.step()

        if epoch % 100 == 0:
            viz.line([[loss_D.item(), loss_G.item()]], [epoch], win='loss', update='append')
            print(loss_D.item(), loss_G.item())
            generate_image(D, G, xr.cpu(), epoch)


if __name__ == '__main__':
    main()
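Running the script requires a CUDA GPU and a visdom server started beforehand with python -m visdom.server. To sanity-check the penalty term in isolation, here is a minimal CPU-only sketch of the same gradient_penalty logic; toy_critic is just a hypothetical stand-in for the Discriminator:

import torch
from torch import nn, autograd

def gradient_penalty_cpu(D, xr, xf):
    # same logic as gradient_penalty above, minus the .cuda() calls
    t = torch.rand(xr.size(0), 1).expand_as(xr)
    mid = (t * xr + (1 - t) * xf).requires_grad_()
    pred = D(mid)
    grads = autograd.grad(outputs=pred, inputs=mid,
                          grad_outputs=torch.ones_like(pred),
                          create_graph=True, retain_graph=True)[0]
    return ((grads.norm(2, dim=1) - 1) ** 2).mean()

toy_critic = nn.Sequential(nn.Linear(2, 16), nn.ReLU(True), nn.Linear(16, 1))
xr = torch.randn(8, 2)  # stand-in real batch
xf = torch.randn(8, 2)  # stand-in fake batch
print(gradient_penalty_cpu(toy_critic, xr, xf))  # scalar >= 0

The penalty is near zero when the critic's gradient norm is already close to 1 on the interpolated points, and grows quadratically as the norm drifts away in either direction.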