生成对抗网络GANs

一、什么是GANs网络

1. GANs概念

GANs当然不是真的“干”,而是一种新型网络结构,全称是生成对抗网络(Generative Adversarial Nets,GANs),第一个词是“生成”,第二个词是“对抗”,因此称作生成对抗网络。

2. GANs分类

GANs又进化为两个分支:一个是典型的GANs,一个是拓展的GANs。
lsgans(Loss-Sensitive Generative Adversarial Networks on Lipschitz Densities** )
http://wiseodd.github.io/techblog/2017/03/02/least-squares-gan/
今天提到的是最原始版本的GANs网络(原始论文请戳这里),

二、pytorch中实现GANs网络

英文参考资料:Generative Adversarial Networks (GANs) in 50 lines of code (PyTorch)

上述英文参考资料对应中文地址请戳这里

三、理论推导过程

四、代码分析过程

根据对于GANs网络的理解,加上上述参考资料的辅助,对于代码进行分析如下:

1.几个记号

R:real data,原始、真实数据集

I:作为熵的一项来源,进入生成器的随机噪音,实际操作中先从均匀分布中采样得到样本数据,然后将该样本数据输入到生成器模型G中生成假数据fake data

G:生成器模型,试图模仿原始数据

D:判别器,试图区别 G 的生成数据和 真实数据R

目标:我们教 G 糊弄 D、教 D 当心 G 的“训练”环(在gan网络的目标函数中体现为两部分损失函数的加和)

2. 真假数据来源

真实数据(real data)采样于均值为mu,标准差为sigma的正态分布;假数据(fake data)为提高点难度,使用均匀分布(uniform distribution )而非标准正太分布。这意味着, 生成器模型 G 不能简单地改变输入(放大/缩小、平移)来复制真实数据,而需要用非线性的方式来改造数据。

pytorch代码如下
# ##### DATA: Target data and generator input data
#从正太分布中进行抽样
def get_distribution_sampler(mu, sigma):
    return lambda n: torch.Tensor(np.random.normal(mu, sigma, (1, n)))  # Gaussian
#上述生成结果是(1,5)的Variable,是一个二维结构
#从均匀分布中进行采样
def get_generator_input_sampler():
    return lambda m, n: torch.rand(m, n)  # Uniform-dist data into generator, _NOT_ Gaussian

3.生成器G和判别器D的网络结构

# ##### MODELS: Generator model and discriminator model
#生成器G具有两层隐层,三个线性映射(linear maps)
class Generator(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(Generator, self).__init__()
        #Linear是一种仿射操作,也就是一种线性操作;an affine operation: y = Wx + b(x1->x2)
        self.map1 = nn.Linear(input_size, hidden_size) 
        self.map2 = nn.Linear(hidden_size, hidden_size)
        self.map3 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        x = F.elu(self.map1(x))
        x = F.sigmoid(self.map2(x))
        return self.map3(x)

#判别器D从 R 或 G 那里获得样本,然后通过sigmoid输出 0 或 1 的判别值,对应反例和正例
class Discriminator(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(Discriminator, self).__init__()
        self.map1 = nn.Linear(input_size, hidden_size)
        self.map2 = nn.Linear(hidden_size, hidden_size)
        self.map3 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        x = F.elu(self.map1(x))
        x = F.elu(self.map2(x))
        return F.sigmoid(self.map3(x))

4.训练过程

for epoch in range(num_epochs):
    for d_index in range(d_steps):
        # 1. Train D on real+fake
        D.zero_grad()

        #  1A: Train D on real
        d_real_data = Variable(d_sampler(d_input_size))
        d_real_decision = D(preprocess(d_real_data))
        d_real_error = criterion(d_real_decision, Variable(torch.ones(1)))  # ones = true
        d_real_error.backward() # compute/store gradients, but don't change params

        #  1B: Train D on fake
        d_gen_input = Variable(gi_sampler(minibatch_size, g_input_size))
        d_fake_data = G(d_gen_input).detach()  # detach to avoid training G on these labels
        d_fake_decision = D(preprocess(d_fake_data.t()))
        d_fake_error = criterion(d_fake_decision, Variable(torch.zeros(1)))  # zeros = fake
        d_fake_error.backward()
        d_optimizer.step()     # Only optimizes D's parameters; changes based on stored gradients from backward()

    for g_index in range(g_steps):
        # 2. Train G on D's response (but DO NOT train D on these labels)
        G.zero_grad()

        gen_input = Variable(gi_sampler(minibatch_size, g_input_size))
        g_fake_data = G(gen_input)
        dg_fake_decision = D(preprocess(g_fake_data.t()))
        g_error = criterion(dg_fake_decision, Variable(torch.ones(1)))  # we want to fool, so pretend it's all genuine

        g_error.backward()
        g_optimizer.step()  # Only optimizes G's parameters

    if epoch % print_interval == 0:
        print("%s: D: %s/%s G: %s (Real: %s, Fake: %s) " % (epoch,
                                                            extract(d_real_error)[0],
                                                            extract(d_fake_error)[0],
                                                            extract(g_error)[0],
                                                            stats(extract(d_real_data)),
                                                            stats(extract(d_fake_data))))

五、Gans网络的应用

你可能感兴趣的:(生成对抗网络GANs)