from __future__ import print_function, division
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Dropout
from keras.layers import BatchNormalization, Activation, ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import Adam
import pandas as pd
import matplotlib.pyplot as plt
import sys
import numpy as np
import csv
class GAN():
    """Vanilla fully-connected GAN trained on 1x200 numeric rows from gan_data.csv.

    The generator maps a 100-d Gaussian latent vector to a (1, 200) sample via
    tanh; the discriminator is a 2-layer MLP with a sigmoid validity output.
    Checkpoints go to gen_model/, generated samples to gen_data/ (or gen_test/
    in test mode).
    """

    # Single scale shared by normalisation (raw / SCALE - 1 -> [-1, 1], the
    # range of the generator's tanh output) and denormalisation
    # ((gen + 1) * SCALE).  NOTE(review): the original code divided by 194 on
    # the way in but multiplied by 192 on the way out, which does not invert;
    # one constant keeps the round trip consistent.
    DATA_SCALE = 194.0

    def __init__(self):
        # Data layout: each training example is a single row of 200 values.
        self.data_rows = 1
        self.data_cols = 200
        self.channels = 1          # kept for compatibility; not used below
        self.data_shape = (self.data_rows, self.data_cols)
        self.latent_dim = 100
        self.sample_size = 200     # rows emitted per sampling checkpoint

        optimizer = Adam(0.0002, 0.5)

        # Discriminator trains standalone on real/fake batches.
        self.discriminator = self.build_discriminator()
        self.discriminator.compile(loss='binary_crossentropy',
                                   optimizer=optimizer,
                                   metrics=['accuracy'])

        # Combined model: generator -> (frozen) discriminator.  Freezing only
        # affects the combined graph; the standalone discriminator still trains.
        self.generator = self.build_generator()
        z = Input(shape=(self.latent_dim,))
        data = self.generator(z)
        self.discriminator.trainable = False
        validity = self.discriminator(data)
        self.combined = Model(z, validity)
        self.combined.compile(loss='binary_crossentropy', optimizer=optimizer)

    def build_generator(self):
        """Return a Model mapping (latent_dim,) noise to a data_shape sample in [-1, 1]."""
        model = Sequential()
        model.add(Dense(256, input_dim=self.latent_dim))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(1024))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        # tanh keeps output in [-1, 1], matching the normalisation in train().
        model.add(Dense(np.prod(self.data_shape), activation='tanh'))
        model.add(Reshape(self.data_shape))
        model.summary()

        noise = Input(shape=(self.latent_dim,))
        data = model(noise)
        return Model(noise, data)

    def build_discriminator(self):
        """Return a Model mapping a data_shape sample to a validity score in (0, 1)."""
        model = Sequential()
        # BUG FIX: the original passed `inputshape=` (missing underscore),
        # which Keras rejects as an unknown keyword argument.
        model.add(Flatten(input_shape=self.data_shape))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dense(256))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dense(1, activation='sigmoid'))
        model.summary()

        data = Input(shape=self.data_shape)
        validity = model(data)
        return Model(data, validity)

    def train(self, epochs, batch_size=128, sample_interval=500):
        """Run adversarial training; sample and checkpoint every sample_interval epochs.

        Reads gan_data.csv (3520 rows x 200 cols, header-less) and rescales it
        to [-1, 1] with DATA_SCALE.  NOTE(review): the 3520-row reshape assumes
        the CSV's exact size — confirm against the data file.
        """
        data = pd.read_csv("gan_data.csv", header=None)
        data = np.array(data.values.tolist()).reshape(3520, 1, 200)
        data = data / self.DATA_SCALE - 1

        # Label tensors reused every epoch.
        valid = np.ones((batch_size, 1))
        fake = np.zeros((batch_size, 1))

        for epoch in range(epochs):
            # --- Train discriminator on one real and one generated batch ---
            idx = np.random.randint(0, data.shape[0], batch_size)
            x = data[idx]
            noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
            gen_x = self.generator.predict(noise)
            d_loss_real = self.discriminator.train_on_batch(x, valid)
            d_loss_fake = self.discriminator.train_on_batch(gen_x, fake)
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

            # --- Train generator to make the discriminator output "valid" ---
            noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
            g_loss = self.combined.train_on_batch(noise, valid)

            print("%d [D loss: %f, acc: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100 * d_loss[1], g_loss))

            if epoch % sample_interval == 0:
                self.sample_data(epoch)
                if not os.path.exists("gen_model"):
                    os.makedirs("gen_model")
                self.generator.save_weights("gen_model/G_model%d.hdf5" % epoch, True)
                self.discriminator.save_weights("gen_model/D_model%d.hdf5" % epoch, True)

    def data_write_csv(self, epoch, gen_datas, num):
        """Write `num` generated rows to a CSV named after `epoch`.

        Epoch 666 is a sentinel meaning "test mode" and routes the output to
        gen_test/test.csv instead of gen_data/<epoch>.csv.
        """
        if not os.path.exists("gen_data"):
            os.makedirs("gen_data")
        if epoch == 666:
            file_name = "gen_test/test.csv"
        else:
            file_name = "gen_data/%d.csv" % epoch
        # Robustness: make sure the target directory exists regardless of which
        # branch chose it (the original could write into a missing gen_test/).
        target_dir = os.path.dirname(file_name)
        if target_dir and not os.path.exists(target_dir):
            os.makedirs(target_dir)
        print(file_name)
        gen_datas = gen_datas.reshape(num, self.data_cols)
        dt = pd.DataFrame(gen_datas)
        dt.to_csv(file_name, index=False)

    def sample_data(self, epoch):
        """Generate sample_size rows from fresh noise and save them for this epoch."""
        noise = np.random.normal(0, 1, (self.sample_size, self.latent_dim))
        gen_datas = self.generator.predict(noise)
        # Invert the training normalisation (was a hard-coded 192, which did
        # not match the /194 used on the way in).
        gen_datas = (gen_datas + 1) * self.DATA_SCALE
        self.data_write_csv(epoch, gen_datas, self.sample_size)

    def test(self, gen_nums=200):
        """Load the epoch-9000 checkpoint and write gen_nums generated rows to gen_test/."""
        self.generator.load_weights("gen_model/G_model9000.hdf5", by_name=True)
        self.discriminator.load_weights("gen_model/D_model9000.hdf5", by_name=True)
        noise = np.random.normal(0, 1, (gen_nums, self.latent_dim))
        gen_datas = self.generator.predict(noise)
        gen_datas = (gen_datas + 1) * self.DATA_SCALE
        print(gen_datas)
        if not os.path.exists("gen_test"):
            os.makedirs("gen_test")
        # 666 is the test-mode sentinel understood by data_write_csv().
        self.data_write_csv(666, gen_datas, gen_nums)
if __name__ == '__main__':
    # Train from scratch, sampling/checkpointing every 500 epochs, then run
    # the generation pass against the saved weights.
    EPOCHS = 10000
    BATCH_SIZE = 256
    SAMPLE_INTERVAL = 500

    network = GAN()
    network.train(epochs=EPOCHS, batch_size=BATCH_SIZE, sample_interval=SAMPLE_INTERVAL)
    network.test()