GAN-cls:具有匹配感知的判别器

  GAN-cls是一种GAN网络增强技术——具有匹配感知的判别器。在InforGAN中,使用了ACGAN的方式进行指导模拟数据与生成数据的对应关系。在GAN-cls中该效果会以更简单的方式来实现,即增强判别器的功能,令其不仅能判断图片的真伪,还能判断匹配真伪。

一、GAN-cls的具体实现

  GAN-cls具体做法是,在原有GAN网络上,将判别器的输入变为图片与对应标签的连接数据。这样判别器的输入特征中就会有生成图像的特征与对应标签的特征。然后用这样的判别器分别对真实标签与真实图片、假标签与真实图片、真实标签与假图片进行判断,预期结果为真、假、假,在训练过程中沿这个方向收敛即可。对于生成器不需要作出改变。

二、实例92:使用GNA-cls技术实现生成标签匹配的模拟数据

实例描述

  使用GAN-cls技术对判别器进行改造,并通过输入错误的样本标签让判别器学习样本与标签的匹配,从而优化生成器,是生成器最终生成与标签一样的样本,实现与ACGAN同等的效果。

1.修改判别器

  将判别器的输入改成x与y,新增加的y代表输入的标签;在内部处理中,先通过全连接网络将y变成与图片一样维度的映射,并调整为图片相同的形状,使用concat将二者连接到一起统一处理。后续的处理过程是一样的,两个卷积后再连接连接两个全连接,最后输出disc。

def discriminator(x,y):     #x是正图片,y是标签
    reuse = len([t for t in tf.global_variables() if t.name.startswith('discriminator')]) > 0
    with tf.variable_scope('discriminator', reuse=reuse):
        # 通过一个全连接层将y标签转换成与x一样的维度
        y = slim.fully_connected(y, num_outputs=n_input, activation_fn = leaky_relu)
        y = tf.reshape(y, shape=[-1, 28, 28, 1])
        x = tf.reshape(x, shape=[-1, 28, 28, 1])
        
        x= tf.concat(axis=3, values=[x,y])
        # 两个卷积层+两个全连接层
        x = slim.conv2d(x, num_outputs = 64, kernel_size=[4,4], stride=2, activation_fn=leaky_relu)
        x = slim.conv2d(x, num_outputs=128, kernel_size=[4,4], stride=2, activation_fn=leaky_relu)
        x = slim.flatten(x)
        shared_tensor = slim.fully_connected(x, num_outputs=1024, activation_fn = leaky_relu)
        disc = slim.fully_connected(shared_tensor, num_outputs=1, activation_fn=tf.nn.sigmoid)
        disc = tf.squeeze(disc, -1)

    return disc

2.添加错误而标签输入符,构建网络结构

  添加错误标签misy,同时在判别器中分别将真实样本与真实标签、生成的图片gen与真实标签、真实样本与错误标签组成输入传入到判别器中。

这里将3中输入的x与y分别按照batch_size维度连接为判别器的输入。生成结果后再使用split函数将其裁为3个结果disc_real、disc_fake和disc_mis,分别代表真实样本与真实标签、生成的图像gen与真实标签、真实样本与错误标签所对应的判别值。

x = tf.placeholder(tf.float32, [None, n_input])
y = tf.placeholder(tf.int32, [None])    #正确的标签
misy = tf.placeholder(tf.int32, [None]) #错误的标签


z_rand = tf.random_normal((batch_size, rand_dim))#38列
z = tf.concat(axis=1, values=[tf.one_hot(y, depth = classes_dim), z_rand])#50列
gen = generator(z)
genout= tf.squeeze(gen, -1)



# 判别器
xin=tf.concat([x, tf.reshape(gen, shape=[-1,784]),x],0)
yin=tf.concat([tf.one_hot(y, depth = classes_dim),tf.one_hot(y, depth = classes_dim),tf.one_hot(misy, depth = classes_dim)],0)
disc_all = discriminator(xin,yin)
disc_real,disc_fake,disc_mis =tf.split(disc_all,3)


loss_d = tf.reduce_sum(tf.square(disc_real-1) + ( tf.square(disc_fake)+tf.square(disc_mis))/2 )/2
loss_g = tf.reduce_sum(tf.square(disc_fake-1))/2

  在计算判别器的loss时,同时使用LSGAN方式,并且将错误部分的loss变成disc_fake与disc_mis的和,然后除以2。因为对于生成器生成的样本与错误的标签输入,判别器都应该判别错误。

3.使用MonitoredTrainingSession创建session开始训练

  定义global_step,使用MonitoredTrainingSession创建session,来管理检查点文件,在session中构建错误标签数据,训练模型。

gen_global_step = tf.Variable(0, trainable=False)

global_step = tf.train.get_or_create_global_step()#使用MonitoredTrainingSession,必须有

train_disc = tf.train.AdamOptimizer(0.0001).minimize(loss_d , var_list = d_vars, global_step = global_step)
train_gen = tf.train.AdamOptimizer(0.001).minimize(loss_g , var_list = g_vars, global_step = gen_global_step)



training_epochs = 3
display_step = 1


with tf.train.MonitoredTrainingSession(checkpoint_dir='log/checkpointsnew',save_checkpoint_secs  =60) as sess:

    total_batch = int(mnist.train.num_examples/batch_size)
    print("global_step.eval(session=sess)",global_step.eval(session=sess),int(global_step.eval(session=sess)/total_batch))
    for epoch in range( int(global_step.eval(session=sess)/total_batch),training_epochs):
        avg_cost = 0.

        # 遍历全部数据集
        for i in range(total_batch):
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)#取数据
            _, mis_batch_ys = mnist.train.next_batch(batch_size)#取数据
            feeds = {x: batch_xs, y: batch_ys,misy:mis_batch_ys}

            # Fit training using batch data
            l_disc, _, l_d_step = sess.run([loss_d, train_disc, global_step],feeds)
            l_gen, _, l_g_step = sess.run([loss_g, train_gen, gen_global_step],feeds)

        # 显示训练中的详细信息
        if epoch % display_step == 0:
            print("Epoch:", '%04d' % (epoch + 1), "cost=", "{:.9f} ".format(l_disc),l_gen)

    print("完成!")
    
    # 测试
    _, mis_batch_ys = mnist.train.next_batch(batch_size)
    print ("result:", loss_d.eval({x: mnist.test.images[:batch_size],y:mnist.test.labels[:batch_size],misy:mis_batch_ys},session = sess)
                        , loss_g.eval({x: mnist.test.images[:batch_size],y:mnist.test.labels[:batch_size],misy:mis_batch_ys},session = sess))
    
    # 根据图片模拟生成图片
    show_num = 10
    gensimple,inputx,inputy = sess.run(
        [genout,x,y], feed_dict={x: mnist.test.images[:batch_size],y: mnist.test.labels[:batch_size]})

    f, a = plt.subplots(2, 10, figsize=(10, 2))
    for i in range(show_num):
        a[0][i].imshow(np.reshape(inputx[i], (28, 28)))
        a[1][i].imshow(np.reshape(gensimple[i], (28, 28)))
        
        
    plt.draw()
    plt.show()  
    

你可能感兴趣的:(GAN,计算机视觉,深度学习,tensorflow)