https://arxiv.org/abs/1705.07215v5
DRAGAN is a WGAN-family model that sits between WGAN-gp and WGAN-div. Like CT-GAN, it works by modifying the gradient-penalty term of the WGAN-gp loss.
After analyzing the training dynamics of different GANs, the DRAGAN authors find that a major cause of degenerate training (the paper focuses on mode collapse) is the discriminator D developing sharply rising gradients in small neighborhoods around the real data points x, which makes training unstable.
The natural fix is therefore to add a gradient penalty. But unlike WGAN-gp, which evaluates the penalty at random interpolates between real and fake data, DRAGAN only constrains the gradient of D in the neighborhood of the real data, as sketched below.
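A minimal sketch of this difference in where the penalty is evaluated (tensor names are my own; this is an illustration, not the paper's reference code):

import torch

batch, dims = 64, 784
x_real = torch.randn(batch, dims)  # stand-in for a batch of real samples
x_fake = torch.randn(batch, dims)  # stand-in for a batch of generated samples

# WGAN-gp evaluates the gradient penalty at random interpolates
# between real and fake samples
eps = torch.rand(batch, 1)
x_interp = eps * x_real + (1 - eps) * x_fake

# DRAGAN instead evaluates it at random perturbations of the real samples
# only, so just the region around the real data is regularized
delta = torch.randn_like(x_real) * 10.0 ** 0.5  # delta ~ N(0, cI) with c = 10
x_perturbed = x_real + delta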
Following this idea, DRAGAN proposes a new gradient penalty:

λ · E_{x∼P_real, δ∼N(0, cI)} [ (‖∇_x D(x + δ)‖₂ − k)² ]
It is essentially the same as the penalty used in WGAN-gp; the norm in the middle is likewise the L2 norm.
In WGAN-gp the target value k is fixed to 1. DRAGAN generalizes this to an arbitrary k, but the default is still 1.
The more interesting part is how far from x the penalty is applied: the perturbation δ is drawn from N(0, cI), where I is the identity matrix and c is a constant, set to 10 in the paper. (This way the gradient is computed in a region fairly close to x.)
So the gradient is computed in a sufficiently small neighborhood of x, and its norm is pushed to stay close to 1. (I suspect pushing the gradient norm towards 1 borrows from the traditional GAN: its loss involves log x, whose derivative 1/x is close to 1 when x is near 1.)
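To connect the formula to code, here is a self-contained toy check (my own construction, not from the paper): with a linear critic the input gradient is the weight vector w everywhere, so the penalty must reduce to (‖w‖₂ − k)², and the autograd computation below reproduces exactly that. Note that sampling with torch.randn_like(x) * c ** 0.5 treats c as the variance of N(0, cI); the gp_loss in model.py below passes c to torch.normal as a standard deviation instead.

import torch

torch.manual_seed(0)
w = torch.randn(784)
D = lambda x: x @ w  # linear critic: its input gradient is w for every sample

c, k = 10.0, 1.0
x = torch.randn(64, 784)                # stand-in batch of real samples
delta = torch.randn_like(x) * c ** 0.5  # delta ~ N(0, cI)
x_ = (x + delta).requires_grad_(True)

grad = torch.autograd.grad(D(x_).sum(), x_)[0]  # shape (64, 784), each row is w
gp = ((grad.norm(2, dim=1) - k) ** 2).mean()
print(gp.item(), (w.norm(2).item() - k) ** 2)   # the two values agree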
Related reading
WGAN-gp: model theory and Python implementation
WGAN: model theory and Python implementation
CT-GAN: model theory and Python implementation
WGAN-div: model theory and Python implementation
Experiments
The results look decent. (It may be worth trying an MLP version, which could work even better; I did a similar comparison in an earlier post, and you can adapt it from there. A rough sketch is given below.)
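Since an MLP version is suggested above, here is a rough sketch of what MLP-based G and D could look like while keeping the same call signatures as the conv models in model.py (layer sizes are my own guess and untuned; treat this as a starting point, not a tested implementation):

import torch.nn as nn


class MLPGenerator(nn.Module):
    def __init__(self, input_size):
        super(MLPGenerator, self).__init__()
        self.main = nn.Sequential(
            nn.Linear(input_size, 256), nn.LeakyReLU(0.2),
            nn.Linear(256, 512), nn.LeakyReLU(0.2),
            nn.Linear(512, 28 * 28), nn.Tanh(),
        )

    def forward(self, z):
        # accepts the same (B, N_IDEAS, 1, 1) noise that main.py uses
        out = self.main(z.view(z.size(0), -1))
        return out.view(-1, 1, 28, 28)


class MLPDiscriminator(nn.Module):
    def __init__(self, input_size=1):
        super(MLPDiscriminator, self).__init__()
        self.main = nn.Sequential(
            nn.Linear(28 * 28, 512), nn.LeakyReLU(0.2),
            nn.Linear(512, 256), nn.LeakyReLU(0.2),
            nn.Linear(256, 1),  # unbounded critic score, same as the conv version
        )

    def forward(self, x):
        return self.main(x.view(x.size(0), -1))

gp_loss should work with these unchanged, since it only calls D on perturbed inputs.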
main.py
import os

import matplotlib.pyplot as plt
import torch
import torchvision
from torch.utils.data import DataLoader

from model import Generator, Discriminator, gp_loss

if __name__ == '__main__':
    LR = 0.0002
    EPOCH = 20  # 50
    BATCH_SIZE = 100
    N_IDEAS = 100
    nc = 2  # number of D updates per G update
    TRAINED = False
    lam = 10  # gradient-penalty weight

    DOWNLOAD_MNIST = False
    mnist_root = '../Conditional-GAN/mnist/'
    if not os.path.exists(mnist_root) or not os.listdir(mnist_root):
        # no mnist dir, or mnist dir is empty
        DOWNLOAD_MNIST = True

    train_data = torchvision.datasets.MNIST(
        root=mnist_root,
        train=True,  # this is training data
        # converts a PIL.Image or numpy.ndarray to a torch.FloatTensor of
        # shape (C x H x W) and normalizes to the range [0.0, 1.0]
        transform=torchvision.transforms.ToTensor(),
        download=DOWNLOAD_MNIST,
    )
    train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)

    torch.cuda.empty_cache()
    if TRAINED:
        G = torch.load('G.pkl').cuda()
        D = torch.load('D.pkl').cuda()
    else:
        G = Generator(N_IDEAS).cuda()
        D = Discriminator().cuda()

    optimizerG = torch.optim.Adam(G.parameters(), lr=LR)
    optimizerD = torch.optim.Adam(D.parameters(), lr=LR)

    for epoch in range(EPOCH):
        tmpD, tmpG = 0, 0
        for step, (x, y) in enumerate(train_loader):
            x = x.cuda()
            rand_noise = torch.randn((x.shape[0], N_IDEAS, 1, 1)).cuda()
            G_imgs = G(rand_noise)

            D_fake = torch.squeeze(D(G_imgs))
            D_real = torch.squeeze(D(x))
            D_real_loss = -torch.mean(D_real)
            D_fake_loss = torch.mean(D_fake)
            # WGAN critic loss plus the DRAGAN penalty around the real data
            D_loss = D_real_loss + D_fake_loss + lam * gp_loss(D, x, cuda=True)

            optimizerD.zero_grad()
            D_loss.backward(retain_graph=True)
            optimizerD.step()

            # update G once every nc steps of D
            if (step + 1) % nc == 0:
                G_loss = -torch.mean(D_fake)
                optimizerG.zero_grad()
                G_loss.backward(retain_graph=True)
                optimizerG.step()
                tmpG += G_loss.cpu().detach().data
            tmpD += D_loss.cpu().detach().data

        tmpD /= (step + 1)
        tmpG /= (step + 1)
        # tmpG is only accumulated every nc steps, hence * nc to recover its mean;
        # tmpD is accumulated every step and needs no correction
        print(
            'epoch %d avg of loss: D: %.6f, G: %.6f' % (epoch, tmpD, tmpG * nc)
        )
        if epoch % 2 == 0:
            plt.imshow(torch.squeeze(G_imgs[0].cpu().detach()).numpy())
            plt.show()

    torch.save(G, 'G.pkl')
    torch.save(D, 'D.pkl')
model.py
import os

import torch
import torch.autograd as autograd
import torch.nn as nn
import torch.utils.data as Data
import torchvision


class Generator(nn.Module):
    def __init__(self, input_size):
        super(Generator, self).__init__()
        strides = [1, 2, 2, 2]
        padding = [0, 1, 1, 1]
        channels = [input_size, 256, 128, 64, 1]  # ends with 1 channel (grayscale)
        kernels = [4, 3, 4, 4]
        model = []
        for i, stride in enumerate(strides):
            model.append(
                nn.ConvTranspose2d(
                    in_channels=channels[i],
                    out_channels=channels[i + 1],
                    stride=stride,
                    kernel_size=kernels[i],
                    padding=padding[i],
                )
            )
            if i != len(strides) - 1:
                model.append(nn.BatchNorm2d(channels[i + 1], 0.8))
                model.append(nn.LeakyReLU(0.2))
            else:
                model.append(nn.Tanh())
        self.main = nn.Sequential(*model)

    def forward(self, x):
        return self.main(x)


class Discriminator(nn.Module):
    def __init__(self, input_size=1):
        super(Discriminator, self).__init__()
        strides = [2, 2, 2, 1]
        padding = [1, 1, 1, 0]
        channels = [input_size, 64, 128, 256, 1]  # starts with 1 channel (grayscale)
        kernels = [4, 4, 4, 3]
        model = []
        for i, stride in enumerate(strides):
            model.append(
                nn.Conv2d(
                    in_channels=channels[i],
                    out_channels=channels[i + 1],
                    stride=stride,
                    kernel_size=kernels[i],
                    padding=padding[i],
                )
            )
            model.append(nn.LeakyReLU(0.2))
        # no Sigmoid on the last layer: the critic outputs an unbounded score, as in WGAN
        self.main = nn.Sequential(*model)

    def forward(self, x):
        return self.main(x)


def gp_loss(D, real_x, c=10, k=1, cuda=False):
    # DRAGAN gradient penalty: perturb the real batch with Gaussian noise,
    # then push the norm of D's input gradient at the perturbed points towards k.
    # Note torch.normal takes a standard deviation, so this samples
    # theta ~ N(0, c^2 I); the paper writes the perturbation as N(0, cI).
    theta = torch.normal(0, torch.ones((real_x.shape[0], 1, 1, 1)) * c)
    if cuda:
        theta = theta.cuda()
    x_ = (real_x + theta).requires_grad_(True)
    y_ = D(x_)
    # compute dD/dx at the perturbed points
    grad = autograd.grad(
        outputs=y_,
        inputs=x_,
        grad_outputs=torch.ones_like(y_),
        create_graph=True,
        retain_graph=True,
        only_inputs=True,
    )[0]
    grad = grad.view(x_.shape[0], -1)
    gp = ((grad.norm(2, dim=1) - k) ** 2).mean()
    return gp


if __name__ == '__main__':
    N_IDEAS = 100
    G = Generator(N_IDEAS)
    rand_noise = torch.randn((10, N_IDEAS, 1, 1))
    print(G(rand_noise).shape)

    DOWNLOAD_MNIST = False
    mnist_root = '../Conditional-GAN/mnist/'
    if not os.path.exists(mnist_root) or not os.listdir(mnist_root):
        # no mnist dir, or mnist dir is empty
        DOWNLOAD_MNIST = True

    train_data = torchvision.datasets.MNIST(
        root=mnist_root,
        train=True,
        transform=torchvision.transforms.ToTensor(),
        download=DOWNLOAD_MNIST,
    )
    D = Discriminator(1)
    print(len(train_data))
    train_loader = Data.DataLoader(dataset=train_data, batch_size=2, shuffle=True)
    for step, (x, y) in enumerate(train_loader):
        print(x.shape)
        print(D(x).shape)
        print(gp_loss(D, x))  # fixed: the third positional argument is c, not x
        break
judge.py
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision.utils as vutils

from model import Generator, Discriminator

if __name__ == '__main__':
    BATCH_SIZE = 100
    N_IDEAS = 100
    img_shape = (1, 28, 28)
    TIME = 5  # number of image grids to generate

    G = torch.load("G.pkl").cuda()
    for t in range(TIME):
        rand_noise = torch.randn((BATCH_SIZE, N_IDEAS, 1, 1)).cuda()
        G_imgs = G(rand_noise).cpu().detach()
        fig = plt.figure(figsize=(10, 10))
        plt.axis("off")
        # arrange the batch into a 10x10 grid and convert CHW -> HWC for imshow
        plt.imshow(np.transpose(
            vutils.make_grid(G_imgs, nrow=10, padding=0, normalize=True, scale_each=True),
            (1, 2, 0),
        ))
        plt.savefig(str(t) + '.png')
        plt.show()