第一部分为z,代表噪声;
第二部分为c,代表隐含编码;
目标是希望隐含编码c在每个维度上都具备可解释的特征。
InfoGAN和前面介绍过的GAN区别在于,真实训练数据不带有标签,而输入数据为隐含编码和随机噪声的组合,最后通过判别器一端和最大化互信息的方式还原隐含编码的信息。也就是说,判别器D最终需要同时具备还原隐含编码和辨别真伪的能力。前者是为了让生成图像能够很好地具备编码中的特性,也就是说隐含编码可以对生成网络产生相对显著的影响;后者是要求生成模型在还原信息的同时保证生成的数据与真实数据非常逼近。
互信息是两个随机变量之间依赖程度的度量。对于随机变量X和Y,互信息为I(X;Y),H(X)和H(Y)为边缘熵,H(X|Y)和H(Y|X)为条件熵,它们满足:$I(X;Y) = H(X) - H(X|Y) = H(Y) - H(Y|X)$。
当X和Y相互独立的时候,互信息为0。给定任意的 $x \sim P_G(x)$,希望生成器的后验分布 $P_G(c|x)$ 有一个相对较小的熵,即希望隐含编码c的信息在生成过程中不会流失。对此我们修改目标函数:$\min_G \max_D V_I(D,G) = V(D,G) - \lambda I(c; G(z,c))$
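由于后验概率 $P(c|x)$ 难以得到,导致互信息难以直接最大化,实际计算时可以定义一个近似 $P(c|x)$ 的辅助分布 $Q(c|x)$ 来获取互信息的下界,推导如下:
$$I(c;G(z,c)) = H(c) - H(c|G(z,c)) = \mathbb{E}_{x\sim G(z,c)}\big[\mathbb{E}_{c'\sim P(c|x)}[\log P(c'|x)]\big] + H(c)$$
$$= \mathbb{E}_{x\sim G(z,c)}\big[D_{KL}(P(\cdot|x)\,\|\,Q(\cdot|x)) + \mathbb{E}_{c'\sim P(c|x)}[\log Q(c'|x)]\big] + H(c) \ge \mathbb{E}_{x\sim G(z,c)}\big[\mathbb{E}_{c'\sim P(c|x)}[\log Q(c'|x)]\big] + H(c)$$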
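由此可以得到互信息的下界值:$I(c;G(z,c)) \ge \mathbb{E}_{x\sim G(z,c)}\big[\mathbb{E}_{c'\sim P(c|x)}[\log Q(c'|x)]\big] + H(c)$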
我们可以重新改写之前的不等式,并使用蒙特卡洛方法逼近:$L_I(G,Q) = \mathbb{E}_{c\sim P(c),\,x\sim G(z,c)}[\log Q(c|x)] + H(c) \le I(c;G(z,c))$
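得到我们最终的目标函数:$\min_{G,Q}\max_D V_{\text{InfoGAN}}(D,G,Q) = V(D,G) - \lambda L_I(G,Q)$,其中 $L_I(G,Q)$ 为互信息 $I(c;G(z,c))$ 的变分下界。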
我们发现通过控制隐含编码中的离散(类别)变量可以调节生成的数字是几,其他参数可以调节生成字符的倾斜程度、字体宽度等。
from __future__ import print_function, division
from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Dropout, multiply, concatenate
from keras.layers import BatchNormalization, Activation, Embedding, ZeroPadding2D, Lambda
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import Adam
from keras.utils import to_categorical
import keras.backend as K
import matplotlib.pyplot as plt
import numpy as np
class INFOGAN():
    """InfoGAN on MNIST: generator, discriminator and recognition (Q) network.

    The generator input has ``latent_dim`` values: Gaussian noise followed by
    a one-hot categorical code with ``num_classes`` entries.  The
    discriminator and the Q network are built on one shared convolutional
    trunk; the Q network is trained to recover the categorical code from a
    generated image.
    """

    def __init__(self):
        """Build and compile the discriminator, Q network, generator and
        the combined model (generator -> discriminator + Q network)."""
        self.img_rows = 28
        self.img_cols = 28
        self.channels = 1
        self.num_classes = 10
        self.img_shape = (self.img_rows, self.img_cols, self.channels)
        # Total generator input width: noise plus the one-hot code.
        self.latent_dim = 72

        optimizer = Adam(0.0002, 0.5)
        losses = ['binary_crossentropy', self.mutual_info_loss]

        # Build the discriminator and the recognition network
        self.discriminator, self.auxilliary = self.build_disk_and_q_net()

        self.discriminator.compile(loss=['binary_crossentropy'],
                                   optimizer=optimizer,
                                   metrics=['accuracy'])

        # Compile the recognition network Q
        self.auxilliary.compile(loss=[self.mutual_info_loss],
                                optimizer=optimizer,
                                metrics=['accuracy'])

        # Build the generator
        self.generator = self.build_generator()

        # The generator takes noise concatenated with the target code
        # and produces an image
        gen_input = Input(shape=(self.latent_dim,))
        img = self.generator(gen_input)

        # Mark the discriminator as non-trainable before the combined
        # model is built and compiled
        self.discriminator.trainable = False

        # The discriminator scores the generated image
        valid = self.discriminator(img)
        # The recognition network predicts the code from the same image
        target_label = self.auxilliary(img)

        # Combined model; output order is [validity, code], matching
        # the order of `losses`
        self.combined = Model(gen_input, [valid, target_label])
        self.combined.compile(loss=losses,
                              optimizer=optimizer)

    @property
    def noise_dim(self):
        """Width of the noise part of the generator input.

        Derived as ``latent_dim - num_classes`` so the noise and the
        one-hot code always concatenate to exactly ``latent_dim`` values.
        """
        return self.latent_dim - self.num_classes

    def build_generator(self):
        """Return the generator ``Model`` mapping a latent vector to an image.

        A dense layer is reshaped to 7x7x128, passed through two
        upsampling + convolution stages and a final convolution with
        ``channels`` filters and tanh activation.  Prints the layer
        summary as a side effect.
        """
        model = Sequential()

        model.add(Dense(128 * 7 * 7, activation="relu", input_dim=self.latent_dim))
        model.add(Reshape((7, 7, 128)))
        model.add(BatchNormalization(momentum=0.8))
        model.add(UpSampling2D())
        model.add(Conv2D(128, kernel_size=3, padding="same"))
        model.add(Activation("relu"))
        model.add(BatchNormalization(momentum=0.8))
        model.add(UpSampling2D())
        model.add(Conv2D(64, kernel_size=3, padding="same"))
        model.add(Activation("relu"))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Conv2D(self.channels, kernel_size=3, padding='same'))
        model.add(Activation("tanh"))

        gen_input = Input(shape=(self.latent_dim,))
        img = model(gen_input)

        model.summary()

        return Model(gen_input, img)

    def build_disk_and_q_net(self):
        """Return ``(discriminator, q_network)`` built on one shared trunk.

        Both models take an image of shape ``img_shape``.  The
        discriminator ends in a single sigmoid unit; the Q network adds a
        128-unit hidden layer and a softmax over ``num_classes``.
        """
        img = Input(shape=self.img_shape)

        # Shared layers between discriminator and recognition network
        model = Sequential()
        model.add(Conv2D(64, kernel_size=3, strides=2, input_shape=self.img_shape, padding="same"))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dropout(0.25))
        model.add(Conv2D(128, kernel_size=3, strides=2, padding="same"))
        model.add(ZeroPadding2D(padding=((0, 1), (0, 1))))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dropout(0.25))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Conv2D(256, kernel_size=3, strides=2, padding="same"))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dropout(0.25))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Conv2D(512, kernel_size=3, strides=2, padding="same"))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dropout(0.25))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Flatten())

        img_embedding = model(img)

        # Discriminator head
        validity = Dense(1, activation='sigmoid')(img_embedding)

        # Recognition head
        q_net = Dense(128, activation='relu')(img_embedding)
        label = Dense(self.num_classes, activation='softmax')(q_net)

        # Return discriminator and recognition network
        return Model(img, validity), Model(img, label)

    def mutual_info_loss(self, c, c_given_x):
        """Loss minimised to maximise the mutual-information lower bound.

        Args:
            c: target code tensor (the one-hot code fed to the generator).
            c_given_x: the Q network's softmax output for the same samples.

        Returns:
            The batch mean of the cross-entropy ``-sum(c * log(c_given_x))``
            plus the batch mean of ``-sum(c * log(c))``.  The second term
            depends only on the targets, not on any network output.
        """
        eps = 1e-8  # keeps log() finite for zero probabilities
        conditional_entropy = K.mean(- K.sum(K.log(c_given_x + eps) * c, axis=1))
        entropy = K.mean(- K.sum(K.log(c + eps) * c, axis=1))

        return conditional_entropy + entropy

    def sample_generator_input(self, batch_size):
        """Sample the two parts of a generator input batch.

        Returns:
            ``(noise, labels)``: ``noise`` is standard-normal with shape
            ``(batch_size, noise_dim)``; ``labels`` is the one-hot encoding
            of uniformly drawn class indices in ``[0, num_classes)``.
        """
        # Noise width is derived from latent_dim rather than hard-coded, so
        # noise + code always matches the generator's input size.
        sampled_noise = np.random.normal(0, 1, (batch_size, self.noise_dim))
        sampled_labels = np.random.randint(0, self.num_classes, batch_size).reshape(-1, 1)
        sampled_labels = to_categorical(sampled_labels, num_classes=self.num_classes)

        return sampled_noise, sampled_labels

    def train(self, epochs, batch_size=128, sample_interval=50):
        """Train on MNIST for ``epochs`` iterations of one batch each.

        Args:
            epochs: number of training iterations (one batch per iteration).
            batch_size: number of real and of generated images per iteration.
            sample_interval: a sample grid is saved every this many iterations.
        """
        # Load the dataset; class labels are not used for training
        (X_train, _), (_, _) = mnist.load_data()

        # Rescale pixel values to [-1, 1] to match the generator's tanh output
        X_train = (X_train.astype(np.float32) - 127.5) / 127.5
        X_train = np.expand_dims(X_train, axis=3)

        # Adversarial ground truths
        valid = np.ones((batch_size, 1))
        fake = np.zeros((batch_size, 1))

        for epoch in range(epochs):

            # ---------------------
            #  Train Discriminator
            # ---------------------

            # Select a random batch of real images
            idx = np.random.randint(0, X_train.shape[0], batch_size)
            imgs = X_train[idx]

            # Sample noise and categorical labels
            sampled_noise, sampled_labels = self.sample_generator_input(batch_size)
            gen_input = np.concatenate((sampled_noise, sampled_labels), axis=1)

            # Generate a batch of new images
            gen_imgs = self.generator.predict(gen_input)

            # Train on real and generated data
            d_loss_real = self.discriminator.train_on_batch(imgs, valid)
            d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake)

            # Avg. loss
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

            # ---------------------
            #  Train Generator and Q-network
            # ---------------------

            g_loss = self.combined.train_on_batch(gen_input, [valid, sampled_labels])

            # Report progress.  g_loss is [total, validity loss, code loss]
            # (combined outputs are [valid, target_label]), so the Q loss
            # is g_loss[2] and the adversarial G loss is g_loss[1].
            print("%d [D loss: %.2f, acc.: %.2f%%] [Q loss: %.2f] [G loss: %.2f]"
                  % (epoch, d_loss[0], 100 * d_loss[1], g_loss[2], g_loss[1]))

            # If at save interval => save generated image samples
            if epoch % sample_interval == 0:
                self.sample_images(epoch)

    def sample_images(self, epoch):
        """Save a grid of generated digits to ``images/<epoch>.png``.

        Column ``i`` holds ``r`` samples generated with the categorical
        code fixed to class ``i`` and fresh noise per sample.
        """
        import os

        r, c = 10, 10

        fig, axs = plt.subplots(r, c)
        for i in range(c):
            # One noise vector per row so noise and label batches line up
            sampled_noise, _ = self.sample_generator_input(r)
            label = to_categorical(np.full(fill_value=i, shape=(r, 1)), num_classes=self.num_classes)
            gen_input = np.concatenate((sampled_noise, label), axis=1)
            gen_imgs = self.generator.predict(gen_input)
            # Map tanh output from [-1, 1] back to [0, 1] for display
            gen_imgs = 0.5 * gen_imgs + 0.5
            for j in range(r):
                axs[j, i].imshow(gen_imgs[j, :, :, 0], cmap='gray')
                axs[j, i].axis('off')

        # savefig does not create missing directories
        if not os.path.isdir("images"):
            os.makedirs("images")
        fig.savefig("images/%d.png" % epoch)
        plt.close(fig)

    def save_model(self):
        """Write generator and discriminator architecture (JSON) and weights
        (HDF5) under ``saved_model/``."""
        import os

        def save(model, model_name):
            model_path = "saved_model/%s.json" % model_name
            weights_path = "saved_model/%s_weights.hdf5" % model_name
            # Context manager so the architecture file is flushed and closed
            with open(model_path, 'w') as arch_file:
                arch_file.write(model.to_json())
            model.save_weights(weights_path)

        if not os.path.isdir("saved_model"):
            os.makedirs("saved_model")

        save(self.generator, "generator")
        save(self.discriminator, "discriminator")
if __name__ == '__main__':
    # Script entry point: build the InfoGAN and train it on MNIST,
    # saving a sample grid every 50 iterations.
    infogan = INFOGAN()
    infogan.train(epochs=50000, batch_size=128, sample_interval=50)
from __future__ import print_function, division
from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Dropout, multiply, concatenate
from keras.layers import BatchNormalization, Activation, Embedding, ZeroPadding2D, Lambda
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import Adam
from keras.utils import to_categorical
import keras.backend as K
import matplotlib.pyplot as plt
import numpy as np
class INFOGAN():
    """InfoGAN on MNIST: generator, discriminator and recognition (Q) network.

    The generator input has ``latent_dim`` values: Gaussian noise followed by
    a one-hot categorical code with ``num_classes`` entries.  The
    discriminator and the Q network are built on one shared convolutional
    trunk; the Q network is trained to recover the categorical code from a
    generated image.
    """

    def __init__(self):
        """Build and compile the discriminator, Q network, generator and
        the combined model (generator -> discriminator + Q network)."""
        self.img_rows = 28
        self.img_cols = 28
        self.channels = 1
        self.num_classes = 10
        self.img_shape = (self.img_rows, self.img_cols, self.channels)
        # Total generator input width: noise plus the one-hot code.
        self.latent_dim = 72

        optimizer = Adam(0.0002, 0.5)
        losses = ['binary_crossentropy', self.mutual_info_loss]

        # Build the discriminator and the recognition network
        self.discriminator, self.auxilliary = self.build_disk_and_q_net()

        self.discriminator.compile(loss=['binary_crossentropy'],
                                   optimizer=optimizer,
                                   metrics=['accuracy'])

        # Compile the recognition network Q
        self.auxilliary.compile(loss=[self.mutual_info_loss],
                                optimizer=optimizer,
                                metrics=['accuracy'])

        # Build the generator
        self.generator = self.build_generator()

        # The generator takes noise concatenated with the target code
        # and produces an image
        gen_input = Input(shape=(self.latent_dim,))
        img = self.generator(gen_input)

        # Mark the discriminator as non-trainable before the combined
        # model is built and compiled
        self.discriminator.trainable = False

        # The discriminator scores the generated image
        valid = self.discriminator(img)
        # The recognition network predicts the code from the same image
        target_label = self.auxilliary(img)

        # Combined model; output order is [validity, code], matching
        # the order of `losses`
        self.combined = Model(gen_input, [valid, target_label])
        self.combined.compile(loss=losses,
                              optimizer=optimizer)

    @property
    def noise_dim(self):
        """Width of the noise part of the generator input.

        Derived as ``latent_dim - num_classes`` so the noise and the
        one-hot code always concatenate to exactly ``latent_dim`` values.
        """
        return self.latent_dim - self.num_classes

    def build_generator(self):
        """Return the generator ``Model`` mapping a latent vector to an image.

        A dense layer is reshaped to 7x7x128, passed through two
        upsampling + convolution stages and a final convolution with
        ``channels`` filters and tanh activation.  Prints the layer
        summary as a side effect.
        """
        model = Sequential()

        model.add(Dense(128 * 7 * 7, activation="relu", input_dim=self.latent_dim))
        model.add(Reshape((7, 7, 128)))
        model.add(BatchNormalization(momentum=0.8))
        model.add(UpSampling2D())
        model.add(Conv2D(128, kernel_size=3, padding="same"))
        model.add(Activation("relu"))
        model.add(BatchNormalization(momentum=0.8))
        model.add(UpSampling2D())
        model.add(Conv2D(64, kernel_size=3, padding="same"))
        model.add(Activation("relu"))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Conv2D(self.channels, kernel_size=3, padding='same'))
        model.add(Activation("tanh"))

        gen_input = Input(shape=(self.latent_dim,))
        img = model(gen_input)

        model.summary()

        return Model(gen_input, img)

    def build_disk_and_q_net(self):
        """Return ``(discriminator, q_network)`` built on one shared trunk.

        Both models take an image of shape ``img_shape``.  The
        discriminator ends in a single sigmoid unit; the Q network adds a
        128-unit hidden layer and a softmax over ``num_classes``.
        """
        img = Input(shape=self.img_shape)

        # Shared layers between discriminator and recognition network
        model = Sequential()
        model.add(Conv2D(64, kernel_size=3, strides=2, input_shape=self.img_shape, padding="same"))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dropout(0.25))
        model.add(Conv2D(128, kernel_size=3, strides=2, padding="same"))
        model.add(ZeroPadding2D(padding=((0, 1), (0, 1))))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dropout(0.25))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Conv2D(256, kernel_size=3, strides=2, padding="same"))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dropout(0.25))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Conv2D(512, kernel_size=3, strides=2, padding="same"))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dropout(0.25))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Flatten())

        img_embedding = model(img)

        # Discriminator head
        validity = Dense(1, activation='sigmoid')(img_embedding)

        # Recognition head
        q_net = Dense(128, activation='relu')(img_embedding)
        label = Dense(self.num_classes, activation='softmax')(q_net)

        # Return discriminator and recognition network
        return Model(img, validity), Model(img, label)

    def mutual_info_loss(self, c, c_given_x):
        """Loss minimised to maximise the mutual-information lower bound.

        Args:
            c: target code tensor (the one-hot code fed to the generator).
            c_given_x: the Q network's softmax output for the same samples.

        Returns:
            The batch mean of the cross-entropy ``-sum(c * log(c_given_x))``
            plus the batch mean of ``-sum(c * log(c))``.  The second term
            depends only on the targets, not on any network output.
        """
        eps = 1e-8  # keeps log() finite for zero probabilities
        conditional_entropy = K.mean(- K.sum(K.log(c_given_x + eps) * c, axis=1))
        entropy = K.mean(- K.sum(K.log(c + eps) * c, axis=1))

        return conditional_entropy + entropy

    def sample_generator_input(self, batch_size):
        """Sample the two parts of a generator input batch.

        Returns:
            ``(noise, labels)``: ``noise`` is standard-normal with shape
            ``(batch_size, noise_dim)``; ``labels`` is the one-hot encoding
            of uniformly drawn class indices in ``[0, num_classes)``.
        """
        # Noise width is derived from latent_dim rather than hard-coded, so
        # noise + code always matches the generator's input size.
        sampled_noise = np.random.normal(0, 1, (batch_size, self.noise_dim))
        sampled_labels = np.random.randint(0, self.num_classes, batch_size).reshape(-1, 1)
        sampled_labels = to_categorical(sampled_labels, num_classes=self.num_classes)

        return sampled_noise, sampled_labels

    def train(self, epochs, batch_size=128, sample_interval=50):
        """Train on MNIST for ``epochs`` iterations of one batch each.

        Args:
            epochs: number of training iterations (one batch per iteration).
            batch_size: number of real and of generated images per iteration.
            sample_interval: a sample grid is saved every this many iterations.
        """
        # Load the dataset; class labels are not used for training
        (X_train, _), (_, _) = mnist.load_data()

        # Rescale pixel values to [-1, 1] to match the generator's tanh output
        X_train = (X_train.astype(np.float32) - 127.5) / 127.5
        X_train = np.expand_dims(X_train, axis=3)

        # Adversarial ground truths
        valid = np.ones((batch_size, 1))
        fake = np.zeros((batch_size, 1))

        for epoch in range(epochs):

            # ---------------------
            #  Train Discriminator
            # ---------------------

            # Select a random batch of real images
            idx = np.random.randint(0, X_train.shape[0], batch_size)
            imgs = X_train[idx]

            # Sample noise and categorical labels
            sampled_noise, sampled_labels = self.sample_generator_input(batch_size)
            gen_input = np.concatenate((sampled_noise, sampled_labels), axis=1)

            # Generate a batch of new images
            gen_imgs = self.generator.predict(gen_input)

            # Train on real and generated data
            d_loss_real = self.discriminator.train_on_batch(imgs, valid)
            d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake)

            # Avg. loss
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

            # ---------------------
            #  Train Generator and Q-network
            # ---------------------

            g_loss = self.combined.train_on_batch(gen_input, [valid, sampled_labels])

            # Report progress.  g_loss is [total, validity loss, code loss]
            # (combined outputs are [valid, target_label]), so the Q loss
            # is g_loss[2] and the adversarial G loss is g_loss[1].
            print("%d [D loss: %.2f, acc.: %.2f%%] [Q loss: %.2f] [G loss: %.2f]"
                  % (epoch, d_loss[0], 100 * d_loss[1], g_loss[2], g_loss[1]))

            # If at save interval => save generated image samples
            if epoch % sample_interval == 0:
                self.sample_images(epoch)

    def sample_images(self, epoch):
        """Save a grid of generated digits to ``images/<epoch>.png``.

        Column ``i`` holds ``r`` samples generated with the categorical
        code fixed to class ``i`` and fresh noise per sample.
        """
        import os

        r, c = 10, 10

        fig, axs = plt.subplots(r, c)
        for i in range(c):
            # One noise vector per row so noise and label batches line up
            sampled_noise, _ = self.sample_generator_input(r)
            label = to_categorical(np.full(fill_value=i, shape=(r, 1)), num_classes=self.num_classes)
            gen_input = np.concatenate((sampled_noise, label), axis=1)
            gen_imgs = self.generator.predict(gen_input)
            # Map tanh output from [-1, 1] back to [0, 1] for display
            gen_imgs = 0.5 * gen_imgs + 0.5
            for j in range(r):
                axs[j, i].imshow(gen_imgs[j, :, :, 0], cmap='gray')
                axs[j, i].axis('off')

        # savefig does not create missing directories
        if not os.path.isdir("images"):
            os.makedirs("images")
        fig.savefig("images/%d.png" % epoch)
        plt.close(fig)

    def save_model(self):
        """Write generator and discriminator architecture (JSON) and weights
        (HDF5) under ``saved_model/``."""
        import os

        def save(model, model_name):
            model_path = "saved_model/%s.json" % model_name
            weights_path = "saved_model/%s_weights.hdf5" % model_name
            # Context manager so the architecture file is flushed and closed
            with open(model_path, 'w') as arch_file:
                arch_file.write(model.to_json())
            model.save_weights(weights_path)

        if not os.path.isdir("saved_model"):
            os.makedirs("saved_model")

        save(self.generator, "generator")
        save(self.discriminator, "discriminator")
if __name__ == '__main__':
    # Script entry point: build the InfoGAN and train it on MNIST,
    # saving a sample grid every 50 iterations.
    infogan = INFOGAN()
    infogan.train(epochs=50000, batch_size=128, sample_interval=50)