Generating MNIST-style images with a GAN

Reference blog: https://github.com/lpty/tensorflow_tutorial
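
The script below builds a simple fully connected GAN on MNIST using the TensorFlow 1.x API (tf.layers, tf.placeholder, tensorflow.examples.tutorials): generator_graph maps uniform noise to a flattened 28x28 image, discriminator_graph scores an image as real or fake, loss_graph defines the label-smoothed sigmoid cross-entropy losses, and optimizer_graph trains the two sub-networks with separate Adam optimizers. train runs the training loop and saves a checkpoint, gen samples images from the saved model, and show plots them.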

import os
import pickle
import numpy as np
import tensorflow as tf
from matplotlib import pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data
from datetime import datetime


class MnistModel:

    def __init__(self):
        # MNIST dataset
        self.mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
        # Image size (flattened 28*28 = 784 pixels)
        self.img_size = self.mnist.train.images[0].shape[0]
        # Number of images used per training step
        self.batch_size = 64
        # Number of batches per epoch
        self.chunk_size = self.mnist.train.num_examples // self.batch_size
        # Number of training epochs
        self.epoch_size = 300
        # Number of samples to generate
        self.sample_size = 25
        # Hidden layer size of the generator and discriminator
        self.units_size = 128
        # Learning rate
        self.learning_rate = 0.001
        # Label smoothing factor
        self.smooth = 0.1

    @staticmethod
    def generator_graph(fake_imgs, units_size, out_size, alpha=0.01):
        # The generator and discriminator are separate networks, so each gets its own scope
        with tf.variable_scope('generator'):
            # Fully connected layer
            layer = tf.layers.dense(fake_imgs, units_size)
            # Leaky ReLU activation
            relu = tf.maximum(alpha * layer, layer)
            # Dropout to reduce overfitting
            drop = tf.layers.dropout(relu, rate=0.2)
            # logits
            # out_size should equal the size of a real image
            logits = tf.layers.dense(drop, out_size)
            # Activation that squashes the output into a range similar to the real image vectors
            # tanh works somewhat better than sigmoid here
            # Output range is (-1, 1); sigmoid would give [0, 1]
            outputs = tf.tanh(logits)
            return logits, outputs

    @staticmethod
    def discriminator_graph(imgs, units_size, alpha=0.01, reuse=False):
        with tf.variable_scope('discriminator', reuse=reuse):
            # Fully connected layer
            layer = tf.layers.dense(imgs, units_size)
            # Leaky ReLU activation
            relu = tf.maximum(alpha * layer, layer)
            # logits
            # The discriminator only decides real vs. fake, so out_size is fixed to 1
            logits = tf.layers.dense(relu, 1)
            # Activation
            outputs = tf.sigmoid(logits)
            return logits, outputs

    @staticmethod
    def loss_graph(real_logits, fake_logits, smooth):
        # Generator loss
        # The generator wants the discriminator to label its images as 1
        gen_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=fake_logits, labels=tf.ones_like(fake_logits) * (1 - smooth)))
        # Discriminator loss on generated images
        # The discriminator wants to label them as 0
        fake_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=fake_logits, labels=tf.zeros_like(fake_logits)))
        # Discriminator loss on real images
        # The discriminator wants to label them as 1 (smoothed to 1 - smooth)
        real_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=real_logits, labels=tf.ones_like(real_logits) * (1 - smooth)))
        # Total discriminator loss
        dis_loss = tf.add(fake_loss, real_loss)
        return gen_loss, fake_loss, real_loss, dis_loss

    @staticmethod
    def optimizer_graph(gen_loss, dis_loss, learning_rate):
        # All trainable variables
        train_vars = tf.trainable_variables()
        # Generator variables
        gen_vars = [var for var in train_vars if var.name.startswith('generator')]
        # Discriminator variables
        dis_vars = [var for var in train_vars if var.name.startswith('discriminator')]
        # Optimizers
        # The generator and discriminator are optimized separately, each over its own variables
        gen_optimizer = tf.train.AdamOptimizer(learning_rate).minimize(gen_loss, var_list=gen_vars)
        dis_optimizer = tf.train.AdamOptimizer(learning_rate).minimize(dis_loss, var_list=dis_vars)
        return gen_optimizer, dis_optimizer

    def train(self):
        # Real images and noise input
        # The number of input images is not fixed, so the batch dimension is None
        real_imgs = tf.placeholder(tf.float32, [None, self.img_size], name='real_imgs')
        fake_imgs = tf.placeholder(tf.float32, [None, self.img_size], name='fake_imgs')

        # Generator
        gen_logits, gen_outputs = self.generator_graph(fake_imgs, self.units_size, self.img_size)
        # Discriminator on real images
        real_logits, real_outputs = self.discriminator_graph(real_imgs, self.units_size)
        # Discriminator on generated images
        # The discriminator parameters are shared, so reuse the scope
        fake_logits, fake_outputs = self.discriminator_graph(gen_outputs, self.units_size, reuse=True)

        # Losses
        gen_loss, fake_loss, real_loss, dis_loss = self.loss_graph(real_logits, fake_logits, self.smooth)
        # Optimizers
        gen_optimizer, dis_optimizer = self.optimizer_graph(gen_loss, dis_loss, self.learning_rate)

        # Start training
        saver = tf.train.Saver()
        step = 0
        # Limit the fraction of GPU memory this process may use
        # By default TensorFlow grabs all GPU memory, which can fail at startup if other programs already occupy most of it
        gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.8)
        with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess:
            sess.run(tf.global_variables_initializer())
            for epoch in range(self.epoch_size):
                for _ in range(self.chunk_size):
                    batch_imgs, _ = self.mnist.train.next_batch(self.batch_size)
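                    # Rescale real images from [0, 1] to (-1, 1) to match the generator's tanh output range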
                    batch_imgs = batch_imgs * 2 - 1
                    # Noise input for the generator
                    noise_imgs = np.random.uniform(-1, 1, size=(self.batch_size, self.img_size))
                    # Optimization steps
                    _ = sess.run(gen_optimizer, feed_dict={fake_imgs: noise_imgs})
                    _ = sess.run(dis_optimizer, feed_dict={real_imgs: batch_imgs, fake_imgs: noise_imgs})
                    step += 1
                # Compute the losses at the end of each epoch
                # Discriminator loss
                loss_dis = sess.run(dis_loss, feed_dict={real_imgs: batch_imgs, fake_imgs: noise_imgs})
                # Discriminator loss on real images
                loss_real = sess.run(real_loss, feed_dict={real_imgs: batch_imgs, fake_imgs: noise_imgs})
                # Discriminator loss on generated images
                loss_fake = sess.run(fake_loss, feed_dict={real_imgs: batch_imgs, fake_imgs: noise_imgs})
                # Generator loss
                loss_gen = sess.run(gen_loss, feed_dict={fake_imgs: noise_imgs})

                print(datetime.now().strftime('%c'), ' epoch:', epoch, ' step:', step, ' loss_dis:', loss_dis,
                      ' loss_real:', loss_real, ' loss_fake:', loss_fake, ' loss_gen:', loss_gen)
            model_path = os.getcwd() + os.sep + "mnist.model"
            saver.save(sess, model_path, global_step=step)

    def gen(self):
        # Generate images from the trained model
        # Reset the graph so the generator scope can be rebuilt after train() in the same process
        tf.reset_default_graph()
        sample_imgs = tf.placeholder(tf.float32, [None, self.img_size], name='sample_imgs')
        gen_logits, gen_outputs = self.generator_graph(sample_imgs, self.units_size, self.img_size)
        saver = tf.train.Saver()
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            saver.restore(sess, tf.train.latest_checkpoint('.'))
            sample_noise = np.random.uniform(-1, 1, size=(self.sample_size, self.img_size))
            samples = sess.run(gen_outputs, feed_dict={sample_imgs: sample_noise})
        with open('samples.pkl', 'wb') as f:
            pickle.dump(samples, f)

    @staticmethod
    def show():
        # Display the generated images
        with open('samples.pkl', 'rb') as f:
            samples = pickle.load(f)
        fig, axes = plt.subplots(figsize=(7, 7), nrows=5, ncols=5, sharey=True, sharex=True)
        for ax, img in zip(axes.flatten(), samples):
            ax.xaxis.set_visible(False)
            ax.yaxis.set_visible(False)
            ax.imshow(img.reshape((28, 28)), cmap='Greys_r')
        plt.show()

if __name__ == '__main__':
    mnist = MnistModel()
    mnist.train()
    mnist.gen()
    mnist.show()
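
For reference, tf.nn.sigmoid_cross_entropy_with_logits used in loss_graph computes max(x, 0) - x * z + log(1 + exp(-|x|)) for logits x and labels z, which is a numerically stable form of the cross entropy between sigmoid(x) and the target z. Below is a minimal NumPy sketch (a standalone illustration, not part of the model above) of the discriminator's real-image loss with the label smoothing used here:

import numpy as np

def sigmoid_cross_entropy(logits, labels):
    # Numerically stable form, matching TensorFlow's documented formula:
    # max(x, 0) - x * z + log(1 + exp(-|x|))
    return np.maximum(logits, 0) - logits * labels + np.log1p(np.exp(-np.abs(logits)))

# Example: loss on real images with one-sided label smoothing (smooth = 0.1)
real_logits = np.array([2.0, -0.5, 1.5])
smoothed_labels = np.ones_like(real_logits) * (1 - 0.1)
print(sigmoid_cross_entropy(real_logits, smoothed_labels).mean())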

 
