GAN 怎么输入一维数据:利用 GAN 网络完成一维数据点的生成

代码:

import argparse

import numpy as np

from scipy.stats import norm

import tensorflow as tf

import matplotlib.pyplot as plt

from matplotlib import animation

import seaborn as sns

# Global seaborn plot styling. With color_codes=True, seaborn remaps the
# matplotlib shorthand color codes ('b', 'g', 'r', ...) to its palette.
sns.set(color_codes=True)

# Single seed shared by NumPy and TensorFlow so that runs are repeatable.
seed = 42

np.random.seed(seed)

# TF1.x graph-level seed API (tf.compat.v1.set_random_seed in TF2).
tf.set_random_seed(seed)

class DataDistribution(object):
    """The "real" data distribution: a one-dimensional Gaussian.

    Samples drawn from here are what the GAN's generator is meant to imitate.
    """

    def __init__(self, mu=4, sigma=0.5):
        """Configure the Gaussian.

        Args:
            mu: Mean of the distribution (defaults to 4, the original value).
            sigma: Standard deviation (defaults to 0.5, the original value).
        """
        self.mu = mu
        self.sigma = sigma

    def sample(self, N):
        """Draw ``N`` samples and return them sorted in ascending order.

        Args:
            N: Number of samples to draw.

        Returns:
            A 1-D ``numpy.ndarray`` of length ``N``, sorted ascending.
        """
        samples = np.random.normal(self.mu, self.sigma, N)
        samples.sort()  # in-place ascending sort
        return samples

class GeneratorDistibution(object):
    """Noise source that produces the generator's input points.

    The class name's spelling is kept as-is for compatibility with existing
    callers.
    """

    def __init__(self, range):
        """Store the half-width of the sampling interval.

        Args:
            range: Points are spread over ``[-range, range]``. The parameter
                shadows the builtin ``range``, but the name is kept so that
                keyword callers (``GeneratorDistibution(range=...)``) still work.
        """
        self.range = range

    def sample(self, N):
        """Return ``N`` evenly spaced points on ``[-range, range]``, each jittered.

        Each point gets its own independent uniform offset in ``[0, 0.01)``.

        Args:
            N: Number of points.

        Returns:
            A 1-D ``numpy.ndarray`` of length ``N``.
        """
        grid = np.linspace(-self.range, self.range, N)
        # Bug fix: the previous ``np.random.normal(N)`` meant
        # ``normal(loc=N, scale=1.0)``. That returns ONE scalar centred on N,
        # so every point was shifted by the same ~0.01*N offset. Drawing N
        # uniform values jitters each point independently instead.
        jitter = np.random.random(N) * 0.01
        return grid + jitter

def linear(input, output_dim, scope=None, stddev=1.0):
    """Single fully connected layer computing ``input @ w + b``.

    The variables ``w`` (shape ``[in_dim, output_dim]``, normal init with
    ``stddev``) and ``b`` (shape ``[output_dim]``, zero init) are obtained with
    ``tf.get_variable`` inside the variable scope ``scope``. When ``scope`` is
    falsy, the scope is ``'linear'``.

    Args:
        input: 2-D tensor, presumably ``(batch, in_dim)`` — TODO confirm.
            ``in_dim`` is read from its static shape at index 1.
        output_dim: Number of output units.
        scope: Variable-scope name for this layer's parameters.
        stddev: Standard deviation of the normal weight initializer.

    Returns:
        The affine output tensor.
    """
    in_dim = input.get_shape()[1]
    weight_init = tf.random_normal_initializer(stddev=stddev)
    bias_init = tf.constant_initializer(0.0)
    with tf.variable_scope(scope or 'linear'):
        weights = tf.get_variable('w', [in_dim, output_dim], initializer=weight_init)
        bias = tf.get_variable('b', [output_dim], initializer=bias_init)
        return tf.matmul(input, weights) + bias

def generator(input, h_dim):
    """Generator network G: maps noise input to one-dimensional outputs.

    The network is a softplus hidden layer of width ``h_dim`` (scope ``'g0'``)
    followed by a linear output layer of width 1 (scope ``'g1'``).

    Args:
        input: Noise tensor passed to ``linear`` (assumed 2-D — TODO confirm).
        h_dim: Width of the hidden layer.

    Returns:
        Output tensor of the final width-1 linear layer.
    """
    hidden = tf.nn.softplus(linear(input, h_dim, 'g0'))
    return linear(hidden, 1, 'g1')

def discriminator(input, h_dim): # 预训练判别D网络

h0 = tf.tanh(linear(input, h_dim * 2, 'd0'))

h1 = tf.tanh(linear

你可能感兴趣的:(gan怎么输入一维数据)