DCGAN生成彩色图片

我们以 CIFAR 数据集作为训练数据。

import tensorflow as tf
import matplotlib.pyplot as plt


import numpy as np

# Load the whole training set into memory as a NumPy array.
# NOTE(review): the variable is named `mnist`, but the surrounding text says the
# data is CIFAR and the graph below is built for 32x32x3 inputs. The name is
# kept because train() reads this module-level global directly.
# Presumably the array is (N, 32, 32, 3) with values in [0, 1] (train() rescales
# it with `images*2 - 1`) — TODO confirm against how images.npy was produced.
mnist = np.load('images.npy')

print('---------load data successful-----------')
def input_placeholder(img_size, noise_size):
    """Create the two graph inputs: real images and generator noise.

    Args:
        img_size: Indexable whose items 1, 2 and 3 give the image height,
            width and channel count; item 0 is ignored because the batch
            dimension is always left open.
        noise_size: Length of each noise vector.

    Returns:
        Tuple ``(img, noise)`` of float32 placeholders named ``input_image``
        and ``input_noise``.
    """
    height, width, channels = img_size[1], img_size[2], img_size[3]
    image_shape = (None, height, width, channels)
    noise_shape = (None, noise_size)

    img = tf.placeholder(dtype=tf.float32, shape=image_shape, name='input_image')
    noise = tf.placeholder(dtype=tf.float32, shape=noise_shape, name='input_noise')
    return img, noise

def generator(noise_img, output_dim, is_train=True, alpha=0.01):
    """Build the DCGAN generator and return the generated image tensor.

    The network projects the noise vector to a 4x4x512 feature map and then
    upsamples it with three stride-2 transposed convolutions
    (4x4x512 -> 8x8x256 -> 16x16x128 -> 32x32x``output_dim``), finishing with
    ``tanh``.

    Args:
        noise_img: Noise tensor fed to the first dense layer.
        output_dim: Number of channels of the generated image.
        is_train: ``True`` while building the training graph. ``False``
            re-uses the already created ``generator`` variables, runs batch
            normalisation in inference mode and disables dropout, so that
            sampling is not randomly masked.
        alpha: Negative-side slope of the leaky ReLU.

    Returns:
        The ``tanh`` output tensor of the last transposed convolution.

    Side effects:
        Registers an image summary named ``'input'`` (up to 10 images) for
        the generated outputs on every call.
    """

    def _activate(tensor):
        # Leaky ReLU expressed as max(alpha * x, x); equivalent to the usual
        # definition as long as 0 <= alpha <= 1.
        activated = tf.maximum(alpha * tensor, tensor)
        # Dropout is a training-time regulariser only: applying it while
        # sampling would zero random activations in the generated images.
        if is_train:
            activated = tf.nn.dropout(activated, keep_prob=0.9)
        return activated

    with tf.variable_scope("generator") as scope0:
        if not is_train:
            # Inference graph shares the weights created by the training graph.
            scope0.reuse_variables()

        # Noise vector -> 4 x 4 x 512 via a fully connected layer.
        layer1 = tf.layers.dense(noise_img, 4 * 4 * 512)
        layer1 = tf.reshape(layer1, [-1, 4, 4, 512])
        layer1 = tf.layers.batch_normalization(layer1, training=is_train)
        layer1 = _activate(layer1)

        # 4 x 4 x 512 -> 8 x 8 x 256
        layer2 = tf.layers.conv2d_transpose(layer1, 256, 4, strides=2, padding='same')
        layer2 = tf.layers.batch_normalization(layer2, training=is_train)
        layer2 = _activate(layer2)

        # 8 x 8 x 256 -> 16 x 16 x 128
        layer3 = tf.layers.conv2d_transpose(layer2, 128, 3, strides=2, padding='same')
        layer3 = tf.layers.batch_normalization(layer3, training=is_train)
        layer3 = _activate(layer3)

        # 16 x 16 x 128 -> 32 x 32 x output_dim
        logits = tf.layers.conv2d_transpose(layer3, output_dim, 3, strides=2, padding='same')

        outputs = tf.tanh(logits)

        tf.summary.image('input', outputs, 10)

        return outputs



def discriminator(img_or_noise, reuse=False, alpha=0.01):
    """Build the DCGAN discriminator and return its logits and probabilities.

    Three stride-2 3x3 convolutions (128, 256, 512 filters) are followed by a
    single-unit dense layer. Batch normalisation is applied after the second
    and third convolutions, always in training mode, and every convolution is
    followed by a leaky ReLU and dropout with ``keep_prob=0.9``.

    Args:
        img_or_noise: Image batch to classify (real images or generator
            output). The flatten step is hard-coded to 4*4*512 features per
            example, which matches 32x32 inputs after three stride-2 layers.
        reuse: When true, re-use the existing ``discriminator`` variables
            instead of creating new ones.
        alpha: Negative-side slope of the leaky ReLU.

    Returns:
        Tuple ``(logits, outputs)`` where ``outputs`` is ``sigmoid(logits)``.
    """

    def _leaky_then_dropout(tensor):
        # max(alpha * x, x) is the leaky ReLU; dropout follows immediately.
        activated = tf.maximum(alpha * tensor, tensor)
        return tf.nn.dropout(activated, keep_prob=0.9)

    with tf.variable_scope('discriminator') as scope1:
        if reuse:
            scope1.reuse_variables()

        # First convolution has no batch normalisation.
        hidden = tf.layers.conv2d(img_or_noise, 128, 3, strides=2, padding='same')
        hidden = _leaky_then_dropout(hidden)

        # Remaining convolutions share the same conv -> BN -> activation shape.
        for n_filters in (256, 512):
            hidden = tf.layers.conv2d(hidden, n_filters, 3, strides=2, padding='same')
            hidden = tf.layers.batch_normalization(hidden, training=True)
            hidden = _leaky_then_dropout(hidden)

        flat = tf.reshape(hidden, (-1, 4 * 4 * 512))
        logits = tf.layers.dense(flat, 1)
        return logits, tf.sigmoid(logits)


def inference(real_img, fake_noise, image_depth=3, smooth=0.1):
    """Wire generator and discriminator together and build both GAN losses.

    Args:
        real_img: Tensor of real images fed to the discriminator.
        fake_noise: Noise tensor fed to the generator.
        image_depth: Channel count requested from the generator.
        smooth: Label-smoothing amount; "real" targets become ``1 - smooth``
            for both the generator loss and the discriminator's real loss.

    Returns:
        Tuple ``(g_loss, d_loss)``; ``d_loss`` is the sum of the
        discriminator's real and fake losses.

    Side effects:
        Registers scalar summaries ``d_loss_real`` and ``d_loss_fake``.
    """

    def _mean_sigmoid_xent(logits, labels):
        # Mean sigmoid cross-entropy over the batch.
        per_example = tf.nn.sigmoid_cross_entropy_with_logits(logits=logits,
                                                              labels=labels)
        return tf.reduce_mean(per_example)

    generated = generator(fake_noise, image_depth, is_train=True)
    real_logits, real_probs = discriminator(real_img)
    fake_logits, fake_probs = discriminator(generated, reuse=True)

    # Generator wants the discriminator to label its samples as (smoothed) real.
    g_loss = _mean_sigmoid_xent(fake_logits,
                                tf.ones_like(fake_probs) * (1 - smooth))

    # Discriminator: real images -> (smoothed) 1, generated images -> 0.
    d_loss_real = _mean_sigmoid_xent(real_logits,
                                     tf.ones_like(real_probs) * (1 - smooth))
    d_loss_fake = _mean_sigmoid_xent(fake_logits,
                                     tf.zeros_like(fake_probs))

    d_loss = tf.add(d_loss_real, d_loss_fake)

    tf.summary.scalar('d_loss_real', d_loss_real)
    tf.summary.scalar('d_loss_fake', d_loss_fake)

    return g_loss, d_loss

def test(fake_placeholder, output_dim=3, num_images=25):
    """Restore the latest checkpoint, sample images and plot them.

    Relies on the module-level ``saver`` and on the sibling functions
    ``generator`` and ``plot_image``.

    Args:
        fake_placeholder: Noise placeholder; its last static dimension gives
            the noise vector length.
        output_dim: Channel count requested from the generator.
        num_images: Number of noise vectors (and thus images) to sample.
    """
    with tf.Session() as sess:
        latest = tf.train.latest_checkpoint('checkpoints')
        saver.restore(sess, latest)

        noise_dim = fake_placeholder.shape.as_list()[-1]
        noise_batch = np.random.uniform(-1, 1, size=[num_images, noise_dim])

        # Inference-mode generator that re-uses the trained variables.
        sampler = generator(fake_placeholder, output_dim, is_train=False)
        samples = sess.run(sampler, feed_dict={fake_placeholder: noise_batch})

        plot_image(samples)

def plot_image(samples):
    """Draw the given samples on a 5x5 grid of axes.

    Args:
        samples: Iterable of images. Each value ``v`` is remapped to
            ``(v + 1) / 2`` before display, which takes tanh-range values
            to [0, 1]. Images beyond the 25th are not drawn.
    """
    rescaled = (samples + 1) / 2.
    _, grid = plt.subplots(nrows=5, ncols=5, figsize=(7, 7))

    for picture, axis in zip(rescaled, grid.ravel()):
        axis.imshow(picture, cmap='Greys_r')
        axis.axis('off')


def get_optimizer(g_loss, d_loss, beta=0.4, learning_rate=0.001):
    """Create one Adam training op per network.

    Trainable variables are split by name prefix (``generator`` /
    ``discriminator``) so that each optimizer only updates its own network.
    Both ops are created under a control dependency on the graph's
    ``UPDATE_OPS`` collection.

    Args:
        g_loss: Generator loss tensor.
        d_loss: Discriminator loss tensor.
        beta: ``beta1`` passed to both Adam optimizers.
        learning_rate: Learning rate passed to both Adam optimizers.

    Returns:
        Tuple ``(g_opt, d_opt)`` of training ops.
    """
    gen_vars = []
    disc_vars = []
    for variable in tf.trainable_variables():
        if variable.name.startswith('generator'):
            gen_vars.append(variable)
        if variable.name.startswith('discriminator'):
            disc_vars.append(variable)

    pending_updates = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(pending_updates):
        g_opt = tf.train.AdamOptimizer(learning_rate, beta1=beta).minimize(
            g_loss, var_list=gen_vars)
        d_opt = tf.train.AdamOptimizer(learning_rate, beta1=beta).minimize(
            d_loss, var_list=disc_vars)

    return g_opt, d_opt

def train(real_placeholder,fake_placeholder,g_train_opt,d_train_opt,epoches,noise_size=100,batch_size=64,n_samples=25):
    """Run the alternating generator/discriminator training loop.

    Relies on module-level names defined elsewhere in this file: ``mnist``
    (training images), ``saver``, ``summary``, ``g_loss`` and ``d_loss``.

    Each step runs the generator op, then the discriminator op, then writes
    the merged summaries. Losses are printed every 50 steps and a checkpoint
    is written to ``./checkpoints/generator.ckpt`` at the end of every epoch.

    Args:
        real_placeholder: Placeholder for real image batches.
        fake_placeholder: Placeholder for noise batches.
        g_train_opt: Generator training op.
        d_train_opt: Discriminator training op.
        epoches: Number of passes over the data set to run.
        noise_size: Length of each sampled noise vector.
        batch_size: Examples per step; a trailing partial batch is skipped.
        n_samples: Accepted for interface compatibility; not used here.
    """
    checkpoint_dir = 'checkpoints'
    # Plain Python counter. Incrementing a tf.Variable with `v = v + 1` would
    # add a new op to the graph on every step instead of counting.
    global_step_ = 0

    with tf.Session() as sess:

        sess.run(tf.global_variables_initializer())

        # Resume only if a checkpoint exists, so the first run can start
        # from the freshly initialised weights.
        latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
        if latest_checkpoint is not None:
            saver.restore(sess, latest_checkpoint)
        # saver.save() needs the parent directory to exist; MakeDirs is a
        # no-op when it already does.
        tf.gfile.MakeDirs(checkpoint_dir)

        summary_writer = tf.summary.FileWriter('log/', sess.graph)
        try:
            # 1-based epoch index, inclusive of `epoches`.
            for e in range(1, epoches + 1):

                for step in range(len(mnist) // batch_size):
                    global_step_ += 1
                    images = mnist[step * batch_size:(step + 1) * batch_size]

                    # Rescale to the generator's tanh range.
                    # Assumes the stored images are in [0, 1] — TODO confirm.
                    batch_image = images * 2 - 1

                    batch_noise = np.random.uniform(-1, 1, size=(batch_size, noise_size))
                    feed = {real_placeholder: batch_image,
                            fake_placeholder: batch_noise}

                    sess.run(g_train_opt, feed_dict=feed)
                    sess.run(d_train_opt, feed_dict=feed)

                    summary_str = sess.run(summary, feed_dict=feed)
                    # Use the run-wide step so later epochs do not overwrite
                    # the points written by earlier ones.
                    summary_writer.add_summary(summary_str, global_step_)

                    if step % 50 == 0:
                        train_loss_d, train_loss_g = sess.run([d_loss, g_loss],
                                                              feed_dict=feed)

                        print('step:{}/Epoch:{}/total Epoch:{}'.format(step,e,epoches),
                              'Discriminator Loss:{:.4f}..'.format(train_loss_d),'Generator Loss:{:.4f}..'.format(train_loss_g))

                summary_writer.flush()
                saver.save(sess, './checkpoints/generator.ckpt', global_step=global_step_)
        finally:
            # Flushes pending events and releases the log file handle even if
            # training is interrupted.
            summary_writer.close()



# Build the whole model inside a fresh graph, then train and sample from it.
# The names bound here (g_loss, d_loss, summary, saver) are module-level
# globals that train() and test() read directly, so they must keep these names.
with tf.Graph().as_default():

    # Graph inputs: 32x32x3 images and 100-dimensional noise vectors.
    real_img,fake_img=input_placeholder([-1,32,32,3],noise_size=100)

    g_loss,d_loss=inference(real_img,fake_img)
    # Must come after inference(): it merges the summaries registered so far.
    summary = tf.summary.merge_all()
    g_train_opt,d_train_opt=get_optimizer(g_loss,d_loss)

    # Created after the optimizers so it covers the variables that exist in
    # the graph at this point.
    saver=tf.train.Saver()
    train(real_img,fake_img,g_train_opt,d_train_opt,epoches=5)
    # Sample 25 images (one per cell of the 5x5 plot grid).
    test(fake_img,num_images=25)

你可能感兴趣的:(我的Python学习,tensorflow学习)