文章目录
- 1、序言
- 2、网络结构
- 2.1、生成器
- 2.2、审判者
- 2.3、欺诈者
- 3、代码(直接复制可用)
- 4、伪造图像展示
1、序言
- GAN 升级版：辅助分类器生成对抗网络（Auxiliary Classifier Generative Adversarial Network，ACGAN）
- 本文用 Keras 实现极简的 ACGAN，利用面向对象的思想将模型封装成 3 部分：生成器、审判者和欺诈者
- https://gitee.com/arye/dl/tree/master/Keras/ACGAN
2、网络结构
2.1、生成器
data:image/s3,"s3://crabby-images/e3055/e3055bf2c68e80bfa1be44e61bcc00a5d93318e5" alt="Keras【极简】ACGAN_第1张图片"
2.2、审判者
data:image/s3,"s3://crabby-images/e0d20/e0d2017fa24900e407be5b8cf6f4c78c9a1d34d7" alt="Keras【极简】ACGAN_第2张图片"
2.3、欺诈者
data:image/s3,"s3://crabby-images/a6b81/a6b81d98732d5e3c4fa336f96a7c1708c5121245" alt="Keras【极简】ACGAN_第3张图片"
3、代码(直接复制可用)
import numpy as np, matplotlib.pyplot as mp, os
from keras.datasets import mnist
from keras.models import Sequential, Model
from keras.layers import (
Input, Dense, Reshape, Flatten, Embedding,
Conv2DTranspose, Conv2D, LeakyReLU,
BatchNormalization, Dropout,
multiply)
from keras.optimizers import Adam
from keras.utils import plot_model
# ---- Configuration ----
dir_imgs = 'images/'
path_imgs = dir_imgs + '%02d.png'
dir_model_imgs = 'model_images/'
path_cnnt_imgs = dir_model_imgs + 'cnnt.png'
path_generator_imgs = dir_model_imgs + 'generator.png'
path_cnn_imgs = dir_model_imgs + 'cnn.png'
path_judge_imgs = dir_model_imgs + 'judge.png'
path_tricker_imgs = dir_model_imgs + 'tricker.png'
for dir_i in [dir_imgs, dir_model_imgs]:
    # exist_ok avoids the check-then-create race and also creates parents.
    os.makedirs(dir_i, exist_ok=True)
shape = (28, 28, 1)  # single-channel 28x28 images
num_classes = 10
noise_dim = 100
try:
    # `learning_rate` is the current argument name; Keras 3 rejects `lr`.
    adam = Adam(learning_rate=2e-4, beta_1=.5)
except TypeError:
    # Older standalone Keras only understands `lr`.
    adam = Adam(lr=2e-4, beta_1=.5)
n_samples = 40960
batch_size = 512
times = n_samples // batch_size  # full batches per epoch; any remainder is dropped
epochs = 33
# Judge real/fake targets: first half of each judge batch is real (1), second half fake (0).
mix = np.array([1.] * batch_size + [0.] * batch_size)
# Tricker targets: the generator wants every fake to be judged "real".
trick = np.ones(batch_size * 2)
# Per-output sample weights for the judge: real/fake loss counts on every sample;
# the class loss is weighted 2 on real samples and ignored (0) on fakes.
weight = [np.ones(batch_size * 2),
          np.concatenate((np.ones(batch_size) * 2, np.zeros(batch_size)))]
def load_data():
    """Load the first ``n_samples`` MNIST training images and labels.

    Returns:
        x: float32 array of shape ``(n_samples, *shape)`` scaled to [-1, 1].
        y: integer class labels of shape ``(n_samples, 1)``.
    """
    (x, y), _ = mnist.load_data()
    # Pixels 0..255 -> [-1, 1] to match the generator's tanh output range.
    # Cast first: plain division would yield float64 and double the memory.
    x = (x[:n_samples].astype('float32') / 127.5 - 1).reshape(-1, *shape)
    return x, y[:n_samples].reshape(-1, 1)
class GAN:
    """Minimal ACGAN composed of three Keras models.

    Attributes (all None until ``modeling()`` is called):
        generator: maps ``[noise, class_id]`` to a fake image.
        judge: discriminator; maps an image to ``[real/fake prob, class probs]``.
        tricker: ``judge(generator(noise, class_id))`` with the judge frozen,
            so training it updates the generator only.
    """

    def __init__(self):
        self.generator = None
        self.judge = None
        self.tricker = None

    def modeling(self):
        """Build all models; the judge must exist before the tricker."""
        self.build_generator()
        self.build_judge()
        self.build_tricker()

    def build_generator(self):
        """Build the generator: input (noise, class id), output a fake image."""
        cnnt = Sequential()
        # 3x3x384 -> 7x7x192 (valid, k=5, s=1) -> 14x14x96 -> 28x28x1
        cnnt.add(Dense(3 * 3 * 384, input_dim=noise_dim, activation='relu'))
        cnnt.add(Reshape((3, 3, 384)))
        cnnt.add(Conv2DTranspose(192, 5, strides=1, activation='relu',
                                 kernel_initializer='glorot_normal'))
        cnnt.add(BatchNormalization())
        cnnt.add(Conv2DTranspose(96, 5, strides=2, padding='same', activation='relu',
                                 kernel_initializer='glorot_normal'))
        cnnt.add(BatchNormalization())
        # tanh output lies in [-1, 1], the same range load_data() scales real images to.
        cnnt.add(Conv2DTranspose(1, 5, strides=2, padding='same', activation='tanh',
                                 kernel_initializer='glorot_normal'))
        noise = Input(shape=(noise_dim,))
        num = Input(shape=(1,), dtype='int32')
        # Embedding yields (batch, 1, noise_dim); flatten to (batch, noise_dim) so the
        # product with the noise needs no implicit rank broadcasting and matches the
        # Sequential's declared input shape.
        num_emb = Flatten()(Embedding(num_classes, noise_dim,
                                      embeddings_initializer='glorot_normal')(num))
        h = multiply([noise, num_emb])
        x_fake = cnnt(h)
        self.generator = Model([noise, num], x_fake)
        plot_model(cnnt, path_cnnt_imgs,
                   show_shapes=True, show_layer_names=False)
        plot_model(self.generator, path_generator_imgs,
                   show_shapes=True, show_layer_names=False)

    def build_judge(self):
        """Build and compile the judge: input an image, output real/fake and class."""
        cnn = Sequential()
        cnn.add(Conv2D(32, 3, strides=2, padding='same', activation=LeakyReLU(.2),
                       input_shape=shape))
        cnn.add(Dropout(.3))
        cnn.add(Conv2D(64, 3, strides=1, padding='same', activation=LeakyReLU(.2)))
        cnn.add(Dropout(.3))
        cnn.add(Conv2D(128, 3, strides=2, padding='same', activation=LeakyReLU(.2)))
        cnn.add(Dropout(.3))
        cnn.add(Conv2D(256, 3, strides=1, padding='same', activation=LeakyReLU(.2)))
        cnn.add(Dropout(.3))
        cnn.add(Flatten())
        image = Input(shape)
        flatten = cnn(image)
        judgement = Dense(1, activation='sigmoid')(flatten)
        num = Dense(num_classes, activation='softmax')(flatten)
        self.judge = Model(image, [judgement, num])
        self.judge.compile(
            adam, ['binary_crossentropy', 'sparse_categorical_crossentropy'])
        plot_model(cnn, path_cnn_imgs,
                   show_shapes=True, show_layer_names=False)
        plot_model(self.judge, path_judge_imgs,
                   show_shapes=True, show_layer_names=False)

    def build_tricker(self):
        """Chain generator -> frozen judge; training this improves the generator."""
        noise = Input(shape=(noise_dim,))
        num_noise = Input(shape=(1,), dtype='int32')
        x_fake = self.generator([noise, num_noise])
        # The judge was already compiled as trainable; freezing it now only affects
        # models compiled afterwards, i.e. this one.
        self.judge.trainable = False
        judgement, num_judgement = self.judge(x_fake)
        self.tricker = Model([noise, num_noise], [judgement, num_judgement])
        # Separate optimizer instance with the same hyperparameters as `adam`:
        # sharing one instance across two models mixes their optimizer state, and
        # newer Keras optimizers refuse variables they were not built with.
        optimizer = type(adam).from_config(adam.get_config())
        self.tricker.compile(
            optimizer, ['binary_crossentropy', 'sparse_categorical_crossentropy'])
        plot_model(self.tricker, path_tricker_imgs,
                   show_shapes=True, show_layer_names=False)

    def train_judge(self, x, y):
        """Train the judge on one batch of real images plus as many fakes.

        Returns the judge's ``train_on_batch`` losses.
        """
        noise = np.random.uniform(-1, 1, (batch_size, noise_dim))
        num_noise = np.random.randint(0, num_classes, (batch_size, 1))
        x_fake = self.generator.predict([noise, num_noise], verbose=0)
        x = np.concatenate((x, x_fake))
        num = np.concatenate((y, num_noise), axis=0)
        return self.judge.train_on_batch(x, [mix, num], sample_weight=weight)

    def train_tricker(self):
        """Train the generator (through the frozen judge) to have fakes judged real."""
        noise = np.random.uniform(-1, 1, (batch_size * 2, noise_dim))
        num = np.random.randint(0, num_classes, (batch_size * 2, 1))
        return self.tricker.train_on_batch([noise, num], [trick, num])

    def train(self, x, y):
        """Run one judge step then one tricker step; return both losses."""
        loss_d = self.train_judge(x, y)
        loss_t = self.train_tricker()
        return loss_d, loss_t

    def save_fig(self, epoch):
        """Save a 10x10 grid of generated images to ``path_imgs % epoch``.

        Image ``(i, j)`` is generated for class ``j``.
        """
        nrows, ncols = 10, 10
        # Same uniform(-1, 1) noise distribution the generator is trained on.
        noises = np.random.uniform(-1, 1, (nrows * ncols, noise_dim))
        # Class ids shaped (N, 1) to match the generator's Input(shape=(1,)).
        nums = np.array(list(range(nrows)) * ncols).reshape(-1, 1)
        imgs = self.generator.predict([noises, nums], verbose=0)
        imgs = .5 * imgs + .5  # [-1, 1] -> [0, 1] for display
        for i in range(nrows):
            for j in range(ncols):
                mp.subplot(nrows, ncols, i * ncols + j + 1)
                mp.imshow(imgs[i * ncols + j].reshape(28, 28), cmap='gray')
                mp.axis('off')
        mp.savefig(path_imgs % epoch)
        mp.close()
if __name__ == '__main__':
    # Build the three models, load data, then alternate judge/tricker steps.
    gan = GAN()
    gan.modeling()
    x, y = load_data()
    header = '[loss_d_total loss_d_2 loss_d_10] [loss_t_total loss_t_2 loss_t_10]'
    for epoch in range(epochs):
        # Yellow (ANSI 33) epoch banner followed by the loss column legend.
        print('\033[033m{} %s\033[0m'.format(epoch) % header)
        for step in range(times):
            lo = step * batch_size
            hi = lo + batch_size
            loss_d, loss_t = gan.train(x[lo:hi], y[lo:hi])
            print(step, loss_d, loss_t)
        # One sample grid per epoch.
        gan.save_fig(epoch)
4、伪造图像展示
训练约 30 个 epoch 后有较好结果
data:image/s3,"s3://crabby-images/9cc44/9cc441845b14ca6bad0d5cd399b252d563d317a3" alt="Keras【极简】ACGAN_第4张图片"