不服就GAN：Keras —— GAN 网络生成手写数字实例（生成数字 6）

代码

1. 导包

import keras,os
from keras.models import *
from keras.layers import *
from keras.optimizers import *
from keras.losses import *
from keras import backend as K
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from PIL import Image
from keras.preprocessing import image


from keras.datasets import fashion_mnist,cifar10,cifar100,mnist
from keras.utils import to_categorical

# Expose only GPU 2 to the process. CUDA expects a plain comma-separated list
# of device indices, so the value is the bare "2" with no whitespace. This must
# take effect before the first GPU operation, i.e. before any model is built.
os.environ["CUDA_VISIBLE_DEVICES"] = "2"

2. 鉴别器和生成器的定义、创建

def generator(input_shape, output_activation='sigmoid'):
    """Build the generator network: latent vector -> 28x28x1 image.

    Args:
        input_shape: shape of the latent input, e.g. (100,).
        output_activation: activation of the final 1-channel conv. The default
            'sigmoid' yields pixels in (0, 1), matching training data scaled by
            1/255. Pass 'tanh' only if the real images are scaled to [-1, 1].

    Returns:
        A keras Model mapping a latent batch to images of shape (28, 28, 1).
    """
    inputs = Input(input_shape)
    # Project the latent vector to 128*14*14 units, then reshape to a
    # 14x14 feature map with 128 channels.
    x = Dense(128 * 14 * 14)(inputs)
    x = LeakyReLU(0.2)(x)
    x = Reshape((14, 14, 128))(x)

    x = Conv2D(256, 5, padding='same')(x)
    x = LeakyReLU(0.2)(x)

    # Stride-2 transposed conv with 'same' padding upsamples 14x14 -> 28x28.
    x = Conv2DTranspose(256, 4, strides=2, padding='same')(x)
    x = LeakyReLU(0.2)(x)

    x = Conv2D(256, 5, padding='same')(x)
    x = LeakyReLU(0.2)(x)

    x = Conv2D(256, 5, padding='same')(x)
    x = LeakyReLU(0.2)(x)

    # Fake images must share the real images' value range. Otherwise the
    # discriminator can tell them apart by sign alone, and exporting
    # `img * 255` would wrap negative values when converted to uint8.
    x = Conv2D(1, 7, activation=output_activation, padding='same')(x)
    return Model(inputs, x)
                  


def discriminator(input_shape):
    """Build the discriminator: score an image as real (near 1) or fake (near 0).

    Every convolution uses Keras' default 'valid' padding. For a (28, 28, 1)
    input the spatial size shrinks 28 -> 26 -> 12 -> 5 -> 1, with 128 channels
    after each convolution.
    """
    img_in = Input(input_shape)

    # 3x3 conv, stride 1: 28x28x1 -> 26x26x128.
    h = Conv2D(128, 3)(img_in)
    h = LeakyReLU(0.2)(h)

    # Three 4x4 stride-2 convs: 26 -> 12 -> 5 -> 1.
    for _ in range(3):
        h = Conv2D(128, 4, strides=2)(h)
        h = LeakyReLU(0.2)(h)

    h = Flatten()(h)
    h = Dropout(0.4)(h)

    # Single sigmoid unit: the training loop labels real images 1 and fakes 0.
    score = Dense(1, activation='sigmoid')(h)

    return Model(img_in, score)


# Instantiate both networks: 100-dimensional latent vectors, 28x28x1 images.
gen = generator((100,))
dis = discriminator((28, 28, 1))

# Compile the discriminator on its own (while it is still trainable) so that
# dis.train_on_batch in the training loop updates its weights.
dis_optimizer = keras.optimizers.RMSprop(lr=0.0008, clipvalue=1.0, decay=1e-8)
dis.compile(loss=keras.losses.binary_crossentropy, optimizer=dis_optimizer)

3. 联合生成器和鉴别器创建 GAN 网络

def GAN():
    """Chain gen -> dis into the combined model used to train the generator.

    Sets dis.trainable = False before the returned model is compiled. Per
    Keras' compile-time trainable semantics, the combined model therefore
    updates only the generator's weights. dis, compiled earlier while
    trainable, keeps updating through its own train_on_batch.
    """
    dis.trainable = False
    latent = Input((100,))
    validity = dis(gen(latent))
    return Model(latent, validity)


gan = GAN()
gan_optimizer = keras.optimizers.RMSprop(lr=0.0004, clipvalue=1.0, decay=1e-8)
gan.compile(loss=keras.losses.binary_crossentropy, optimizer=gan_optimizer)

4. 数据导入+规范化

(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Keep only the training images whose label is the digit 6.
is_six = y_train.flatten() == 6
x_train = x_train[is_six]

# Add a trailing channel axis and scale pixel values from [0, 255] to [0, 1].
n_samples = x_train.shape[0]
x_train = x_train.reshape(n_samples, 28, 28, 1).astype('float32') / 255.

5. 训练

# Number of training iterations (one batch each), not full passes over the data.
epochs = 10000
batch_size = 64


generated_img = []
discriminator_loss = []
generator_loss = []
save_dir = './A-GAN-PHOTO'
# The sample PNGs below are written into save_dir; create it so saving does
# not fail on a fresh checkout.
os.makedirs(save_dir, exist_ok=True)


def _save_sample(arr, filename):
    """Write one image array with values expected in [0, 1] to save_dir as a PNG."""
    # Clip first: out-of-range floats would wrap around when cast to uint8.
    img = image.array_to_img(np.clip(arr, 0.0, 1.0) * 255, scale=False)
    img.save(os.path.join(save_dir, filename))


start = 0

for epoch in range(epochs):
    # ---------------------
    #  Train the discriminator
    # ---------------------
    noise = np.random.normal(size=(batch_size, 100))
    stop = start + batch_size
    real_img = x_train[start:stop]
    fake_img = gen.predict(noise)

    # First half of the batch is generated (label 0), second half real (label 1).
    data = np.concatenate([fake_img, real_img])
    valid = np.ones((batch_size, 1))
    fake = np.zeros((batch_size, 1))
    label = np.concatenate([fake, valid])
    label += 0.05 * np.random.random(label.shape)  # label noise during training

    d_loss = dis.train_on_batch(data, label)

    # ---------------------
    #  Train the generator
    # ---------------------
    # gan targets "real" (1) so the generator learns to fool dis; dis was frozen
    # before gan.compile, so only generator weights change here.
    noise_ = np.random.normal(size=(batch_size, 100))
    g_loss = gan.train_on_batch(noise_, valid)

    start += batch_size
    if start > len(x_train) - batch_size:
        start = 0
        # Reshuffle on wrap-around so batches differ between passes and the
        # tail samples that never filled a whole batch eventually get used.
        # Rebinding (rather than shuffling in place) leaves real_img, a view
        # of the old array, unchanged for the snapshot below.
        x_train = x_train[np.random.permutation(len(x_train))]

    if epoch % 100 == 0:
        im = fake_img[0]
        generated_img.append(im)
        _save_sample(im, 'fake_six' + str(epoch) + '.png')  # one generated image
        _save_sample(real_img[0], 'real_six' + str(epoch) + '.png')  # a real one for comparison

        print('discriminator_loss:', d_loss)
        print('adversal_loss:', g_loss)
        discriminator_loss.append(d_loss)
        generator_loss.append(g_loss)
        print("epoch:%d" % epoch + "========")

6. 可视化

# Show the snapshots saved every 100 iterations as a 5 x 20 grid, in training
# order (left to right, top to bottom).
nrows, ncols = 5, 20
fig, axes = plt.subplots(nrows=nrows, ncols=ncols, sharex=True, sharey=True, figsize=(80, 12))
imgs = [i.reshape(28, 28) for i in generated_img]

# The chunk variable is deliberately not named `image`: that would rebind the
# keras.preprocessing `image` module used by the training loop.
for r, row in enumerate(axes):
    row_imgs = imgs[r * ncols:(r + 1) * ncols]
    for img, ax in zip(row_imgs, row):
        # Single-channel data: without cmap matplotlib applies a false-colour map.
        ax.imshow(img, cmap='gray')
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)

fig.tight_layout(pad=0.1)

（图 1：训练过程中每 100 次迭代保存的生成图像网格）

# Loss curves, one point per snapshot (every 100 training iterations).
for series, series_name in ((discriminator_loss, 'discriminator_loss'),
                            (generator_loss, 'generator_loss')):
    plt.plot(series, label=series_name)
plt.legend()

（图 2：判别器损失与生成器损失曲线）

# Show the snapshots saved every 100 iterations as a 5 x 20 grid, in training
# order (left to right, top to bottom).
nrows, ncols = 5, 20
fig, axes = plt.subplots(nrows=nrows, ncols=ncols, sharex=True, sharey=True, figsize=(80, 12))
imgs = [i.reshape(28, 28) for i in generated_img]

# The chunk variable is deliberately not named `image`: that would rebind the
# keras.preprocessing `image` module used by the training loop.
for r, row in enumerate(axes):
    row_imgs = imgs[r * ncols:(r + 1) * ncols]
    for img, ax in zip(row_imgs, row):
        # Single-channel data: without cmap matplotlib applies a false-colour map.
        ax.imshow(img, cmap='gray')
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)

fig.tight_layout(pad=0.1)

（图 3：再次绘制的生成图像网格）

你可能感兴趣的:(不服就GAN)