# Import the required packages
from keras.layers import Input, Dense, Reshape, Flatten, Dropout
from keras.layers import BatchNormalization, Activation, ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import Adam
import matplotlib.pyplot as plt
import sys
import numpy as np
# Define the optimizer: Adam with learning rate 0.0002 and beta_1 = 0.5.
# The arguments are positional so this works whether the installed Keras
# names the first parameter `lr` or `learning_rate`.
optimizer = Adam(0.0002, 0.5)
# Build and compile the discriminator.
# It maps a vector of `n_y_value` sampled values to the probability that the
# vector is a real sample (sigmoid output), trained with binary cross-entropy.
n_y_value = 20  # number of sample points per curve / discriminator input width
D = Sequential()
# input_dim lets the Sequential build immediately (so D.summary() works
# before the model is called) and mirrors the generator's first layer.
D.add(Dense(512, input_dim=n_y_value))
D.add(LeakyReLU(alpha=0.2))
D.add(Dense(256))
# Without a non-linearity here, Dense(256) -> Dense(1) would collapse into a
# single affine map: the layer would add parameters but no capacity.
D.add(LeakyReLU(alpha=0.2))
D.add(Dense(1, activation='sigmoid'))
# D.summary()
img = Input(shape=(n_y_value,))
validity = D(img)
Discriminator = Model(img, validity)
Discriminator.compile(loss='binary_crossentropy',
                      optimizer=optimizer,
                      metrics=['accuracy'])
# Build the generator, then stack generator and discriminator into the GAN.
N_ideas = 5  # width of the generator's random input vector
G = Sequential()
G.add(Dense(512, input_dim=N_ideas))
G.add(LeakyReLU(alpha=0.2))
G.add(BatchNormalization(momentum=0.8))
G.add(Dense(512))
G.add(LeakyReLU(alpha=0.2))
G.add(BatchNormalization(momentum=0.8))
G.add(Dense(1024))
G.add(LeakyReLU(alpha=0.2))
G.add(BatchNormalization(momentum=0.8))
# tanh bounds every generated value to [-1, 1].
G.add(Dense(n_y_value, activation='tanh'))
# G.summary()
noise = Input(shape=(N_ideas,))
G_img = G(noise)
Generator = Model(noise, G_img)

# Combined model: noise -> Generator -> Discriminator -> validity.
z = Input(shape=(N_ideas,))
G_img = Generator(z)
# Freeze the discriminator inside the combined model so GAN.train_on_batch
# only updates generator weights. Discriminator was compiled above while still
# trainable, so its own train_on_batch keeps updating its weights.
Discriminator.trainable = False
validity = Discriminator(G_img)
GAN = Model(z, validity)
# Give the combined model its own optimizer instance (same class and
# hyperparameters as `optimizer`) rather than sharing the discriminator's,
# so the two models do not share optimizer state.
GAN.compile(loss='binary_crossentropy',
            optimizer=type(optimizer).from_config(optimizer.get_config()))
# Training loop.
batch_size = 64
n_iterations = 1600
# Real samples: y = x**2 on n_y_value evenly spaced points in [-1, 1],
# repeated once per batch row (every "real" sample is the same curve).
x = np.tile(np.linspace(-1, 1, n_y_value), (batch_size, 1))
true_imgs = np.power(x, 2)
# Discriminator targets: 1 for real samples, 0 for generated ones.
valid = np.ones((batch_size, 1))
fake = np.zeros((batch_size, 1))
plt.ion()  # interactive mode so the figure refreshes while training
for i in range(n_iterations):
    # 1) Train the discriminator on one generated batch and one real batch.
    noise = np.random.normal(0, 1, (batch_size, N_ideas))
    G_imgs = Generator.predict(noise)
    d_loss_fake = Discriminator.train_on_batch(G_imgs, fake)
    d_loss_real = Discriminator.train_on_batch(true_imgs, valid)
    # train_on_batch returns [loss, accuracy] because Discriminator was
    # compiled with metrics=['accuracy']; average over the two batches.
    d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
    # 2) Train the generator through the combined model, targeting "real".
    g_loss = GAN.train_on_batch(noise, valid)
    print("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (i, d_loss[0], 100 * d_loss[1], g_loss))
    # Redraw the target curve and the first sample generated this iteration
    # (produced before this iteration's weight updates).
    plt.cla()
    plt.xlim((-1.2, 1.2))
    plt.ylim((-0.2, 1.2))
    plt.plot(x[0], true_imgs[0], lw=2, c='#11AAAA')
    plt.plot(x[0], G_imgs[0], lw=2, c='#B62A2A')
    plt.pause(0.01)
plt.ioff()
plt.show()