自2014年Ian Goodfellow提出生成对抗网络(GAN)的概念后,生成对抗网络变成为了学术界的一个火热的研究热点,Yann LeCun更是称之为”过去十年间机器学习领域最让人激动的点子”.
生成对抗网络包括一个生成器(Generator,简称G)生成数据,一个鉴别器(Discriminator,简称D)来鉴别真实数据和生成数据,两者同时训练,直到达到一个纳什均衡,生成器生成的数据与真实样本无差别,鉴别器也无法正确的区分生成数据和真实数据.
生成模型故名思议就是已知模型,用来生成适应该模型的数据,那么它有什么应用场景呢?
生成模型主要的应用场景有两种:
当我们拥有大量的数据,例如图像、语音、文本等,如果生成模型可以帮助我们模拟这些高维数据的分布,那么对很多应用将大有裨益。
针对数据量缺乏的场景,生成模型则可以帮助生成数据,提高数据数量,从而利用半监督学习提升学习效率。语言模型(language model)是生成模型被广泛使用的例子之一,通过合理建模,语言模型不仅可以帮助生成语言通顺的句子,还在机器翻译、聊天对话等研究领域有着广泛的辅助应用。
那么,如果有数据集S={x1,…xn},如何建立一个关于这个类型数据的生成模型呢?
最简单的方法就是:假设这些数据的分布P{X}服从g(x;θ),在观测数据上通过最大化似然函数得到θ的值,即最大似然法。
就是说我们在已知x的情况下,哪个参数(θ)才能更好的模拟出我们现在的分布,我们找到最能模拟我们源数据的参数
似然函数相关知识请戳
例如,我们知道一一些文本中有若干单词,我们就可以用单词出现的频率作为这些数据的分布(如单词“text”的概率0.3,“today”的概率为0.1),以这些概率来生成新的文档。
单词“text”的概率0.3,“today”的概率为0.1 最能模拟我们现在的文本
GAN也是一种生成模型,不过是一种以于半监督学习方式训练的模型,基于神经网络,经常被用在图像处理和半监督学习领域。
GAN有一个生成器(Generator,简称G)生成数据,一个鉴别器(Discriminator,简称D)鉴别数据是否与真实数据相似。
鉴别器的作用G:尽最大努力区分生成器生成的数据和真实数据
生成器作用D:生成和真实数据几乎没有差距的数据
上述的博弈过程就基本上是GAN的原理了。
那么GAN的数学形式是怎样的呢?
假设我们的生成模型是g(z),其中z是一个随机噪声,而g将这个随机噪声转化为数据类型x,仍拿图片问题举例,这里g的输出就是一张图片。
D是一个判别模型,对任何输入x,D(x)的输出是0-1范围内的一个实数,用来判断这个图片是一个真实图片的概率是多大。
令Pr和Pg分别代表真实图像的分布与生成图像的分布.
鉴别器的作用效果(尽最大努力区分):
上式数值越大说明区分效果越好
生成器的作用效果:
要最小化鉴别器的区分效果
整体效果大概是下面这样:
图中黑色虚线是真实数据的高斯分布,绿色的线是生成网络学习到的伪造分布,蓝色的线是判别网络判定为真实图片的概率,标x的横线代表服从高斯分布x的采样空间,标z的横线代表服从均匀分布z的采样空间。可以看出G就是学习了从z的空间到x的空间的映射关系。
直观上来看,绿线(生成数据)慢慢趋近于黑线(真实数据),蓝线(是真实数据的概率)慢慢趋于稳定(是真实数据的概率等于1),说明我们已经骗过了判别器
优势
劣势
以上是GAN的发明者者回答网友问
常见的改进深度卷积的对抗生成网络(DC-GAN),在图像中有着很重要的应用
在图像生成过程中,如何设计生成模型和判别模型呢?深度学习里,对图像分类建模,刻画图像不同层次,抽象信息表达的最有效的模型是:CNN (convolutional neural network,卷积神经网络)。
在CSDN上看到一个例子,会发现DC-GAN的优化效果会好很多 :
我们需要准备以下的东西:
# ##### DATA: Target data and generator input data
def get_distribution_sampler(mu, sigma):
return lambda n: torch.Tensor(np.random.normal(mu, sigma, (1, n))) # Gaussian
def get_generator_input_sampler():
return lambda m, n: torch.rand(m, n) # Uniform-dist data into generator, _NOT_ Gaussian
# ##### MODELS: Generator model and discriminator model
class Generator(nn.Module):
def __init__(self, input_size, hidden_size, output_size):
super(Generator, self).__init__()
self.map1 = nn.Linear(input_size, hidden_size)
self.map2 = nn.Linear(hidden_size, hidden_size)
self.map3 = nn.Linear(hidden_size, output_size)
def forward(self, x):
x = F.elu(self.map1(x))
x = F.sigmoid(self.map2(x))
return self.map3(x)
class Discriminator(nn.Module):
def __init__(self, input_size, hidden_size, output_size):
super(Discriminator, self).__init__()
self.map1 = nn.Linear(input_size, hidden_size)
self.map2 = nn.Linear(hidden_size, hidden_size)
self.map3 = nn.Linear(hidden_size, output_size)
def forward(self, x):
x = F.elu(self.map1(x))
x = F.elu(self.map2(x))
return F.sigmoid(self.map3(x))
for d_index in range(d_steps):
# 1. Train D on real+fake
D.zero_grad()
# 1A: Train D on real
d_real_data = Variable(d_sampler(d_input_size))
d_real_decision = D(preprocess(d_real_data))
d_real_error = criterion(d_real_decision, Variable(torch.ones(1))) # ones = true
d_real_error.backward() # compute/store gradients, but don't change params
# 1B: Train D on fake
d_gen_input = Variable(gi_sampler(minibatch_size, g_input_size))
d_fake_data = G(d_gen_input).detach() # detach to avoid training G on these labels
d_fake_decision = D(preprocess(d_fake_data.t()))
d_fake_error = criterion(d_fake_decision, Variable(torch.zeros(1))) # zeros = fake
d_fake_error.backward()
d_optimizer.step() # Only optimizes D's parameters; changes based on stored gradients from backward()
for g_index in range(g_steps):
# 2. Train G on D's response (but DO NOT train D on these labels)
G.zero_grad()
gen_input = Variable(gi_sampler(minibatch_size, g_input_size))
g_fake_data = G(gen_input)
dg_fake_decision = D(preprocess(g_fake_data.t()))
g_error = criterion(dg_fake_decision, Variable(torch.ones(1))) # we want to fool, so pretend it's all genuine
g_error.backward()
g_optimizer.step() # Only optimizes G's parameters
代码链接
https://github.com/devnag/pytorch-generative-adversarial-networks/blob/master/gan_pytorch.py
下面是训练的结果:
经过2万训练回合, 平均 G 的输出过度 4.0, 但然后回来在一个相当稳定, 正确的范围 (左)。同样, 标准偏差最初下落在错误方向, 但然后上升到期望1.25 范围 (正确), 匹配 R。
由 G自动生成的最终分配结果如上图。
参考链接
[到底什么是生成式对抗网络](https://zhuanlan.zhihu.com/p/26994666?utm_source=com.tencent.tim&utm_medium=social
)
火热的生成对抗网络(GAN),你究竟好在哪里
Generative Adversarial Networks (GANs) in 50 lines of code (PyTorch)