生成对抗网络GAN的keras实例

生成对抗网络GAN的keras实例

导入一些需要的包

from keras.layers import Input, Dense, Reshape, Flatten, Dropout
from keras.layers import BatchNormalization, Activation, ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import Adam

import matplotlib.pyplot as plt

import sys

import numpy as np

定义优化器

# Adam hyper-parameters: learning rate 2e-4 and beta_1 = 0.5 (a smaller
# first-moment decay than Keras' default of 0.9). Passed positionally, as
# (learning rate, beta_1), so the call works across Keras versions.
_ADAM_LR = 2e-4
_ADAM_BETA_1 = 0.5
optimizer = Adam(_ADAM_LR, _ADAM_BETA_1)

构建鉴别器并编译

# Discriminator: maps one n_y_value-point curve to a single sigmoid score in
# (0, 1); binary cross-entropy trains it as a real-vs-fake classifier.
n_y_value = 20
D = Sequential()
D.add(Dense(512, input_dim=n_y_value))
D.add(LeakyReLU(alpha=0.2))
D.add(Dense(256))
# Without a non-linearity here, Dense(256) -> Dense(1) composes into a single
# linear map and the 256-unit hidden layer adds no representational capacity.
D.add(LeakyReLU(alpha=0.2))
D.add(Dense(1, activation='sigmoid'))

img = Input(shape=(n_y_value,))
validity = D(img)
Discriminator = Model(img, validity)
# Compiled while still trainable: Keras fixes the set of trainable weights at
# compile time, so Discriminator.train_on_batch keeps updating these weights
# even if `Discriminator.trainable` is set to False afterwards.
Discriminator.compile(loss='binary_crossentropy',
                      optimizer=optimizer,
                      metrics=['accuracy'])

构建生成器，并将生成器与鉴别器组合成 GAN

# Generator: maps an N_ideas-dimensional noise vector to an n_y_value-point
# curve. The final tanh bounds every output value to [-1, 1].
N_ideas = 5
G = Sequential()
G.add(Dense(512, input_dim=N_ideas))
G.add(LeakyReLU(alpha=0.2))
G.add(BatchNormalization(momentum=0.8))
G.add(Dense(512))
G.add(LeakyReLU(alpha=0.2))
G.add(BatchNormalization(momentum=0.8))
G.add(Dense(1024))
G.add(LeakyReLU(alpha=0.2))
G.add(BatchNormalization(momentum=0.8))
G.add(Dense(n_y_value, activation='tanh'))
noise = Input(shape=(N_ideas,))
G_img = G(noise)
Generator = Model(noise, G_img)

# Combined model: noise -> Generator -> Discriminator -> validity score.
# Freezing the discriminator before compiling GAN means GAN's updates reach
# only the generator's weights. The already-compiled Discriminator model is
# unaffected and still trains its own weights.
z = Input(shape=(N_ideas,))
G_img = Generator(z)
Discriminator.trainable = False
validity = Discriminator(G_img)
GAN = Model(z, validity)
# Use a dedicated Adam instance (same hyper-parameters as the discriminator's)
# instead of sharing one optimizer object between two compiled models. A
# shared Keras 2.x Adam advances a single `iterations` counter on every D and
# G step, skewing its bias correction. Newer tf.keras optimizers also refuse
# to update variables other than those they were first built with.
GAN.compile(loss='binary_crossentropy', optimizer=Adam(0.0002, 0.5))

训练过程

# Real data: every row of `x` is the same grid of n_y_value evenly spaced
# points on [-1, 1], and every real sample is the parabola y = x**2 on it.
batch_size = 64
x = np.tile(np.linspace(-1, 1, n_y_value), (batch_size, 1))
true_imgs = np.power(x, 2)

# Discriminator targets: 1 for real curves, 0 for generated ones.
valid, fake = np.ones((batch_size, 1)), np.zeros((batch_size, 1))


def _plot_progress(real_curve, fake_curve):
    """Redraw one real curve (teal) and one generated curve (red)."""
    plt.cla()
    plt.xlim((-1.2, 1.2))
    plt.ylim((-0.2, 1.2))
    plt.plot(x[0], real_curve, lw=2, c='#11AAAA')
    plt.plot(x[0], fake_curve, lw=2, c='#B62A2A')
    plt.pause(0.01)


plt.ion()  # interactive mode so the figure is redrawn inside the loop
for i in range(1600):
    # Discriminator step: first on a freshly generated batch, then on the
    # real batch.
    noise = np.random.normal(0, 1, (batch_size, N_ideas))
    G_imgs = Generator.predict(noise)  # shape (batch_size, n_y_value), e.g. (64, 20)
    d_loss_fake = Discriminator.train_on_batch(G_imgs, fake)
    d_loss_real = Discriminator.train_on_batch(true_imgs, valid)
    # Element-wise mean of the two [loss, accuracy] results
    # (Discriminator is compiled with metrics=['accuracy']).
    d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

    # Generator step: train the combined model to have fakes scored as real.
    g_loss = GAN.train_on_batch(noise, valid)
    print("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]"
          % (i, d_loss[0], 100 * d_loss[1], g_loss))

    _plot_progress(true_imgs[0], G_imgs[0])

plt.ioff()
plt.show()  # keep the final figure open after training ends

训练结束时，生成器输出的“虚假”曲线（红色）与真实曲线 y = x²（青色）的对比如下：
（图 1：生成对抗网络 GAN 的 Keras 实例——生成曲线与真实曲线对比）

你可能感兴趣的:(tensorflow,深度学习)