StackedGAN详解与实现(采用tensorflow2.x实现)

StackedGAN详解与实现(采用tensorflow2.3实现)

    • StackedGAN原理
    • StackedGAN实现
      • 编码器
      • 对抗网络
        • 鉴别器
        • 生成器
        • 模型构建
      • 模型训练
      • 效果展示

StackedGAN原理

StackedGAN提出了一种用于分解潜在表示以调节生成器输出的方法。与InfoGAN学习如何调节噪声以产生所需的输出,StackedGAN将GAN分解为GAN堆栈。每个GAN均以通常的区分生成器生成图片的方式进行独立训练,并带有自己的潜在编码。
StackedGAN详解与实现(采用tensorflow2.x实现)_第1张图片编码器网络由一堆简单的编码器组成,即 E n c o d e r i Encoder_i Encoderi,其中 i = 0 , . . . , n − 1 i = 0,...,n-1 i=0,...,n1对应于 n n n个特征。每个编码器都提取某些面部特征。例如, E n c o d e r 0 Encoder_0 Encoder0可以是发型特征 F e a t u r e 1 Feature_1 Feature1的编码器。所有简单的编码器都有助于使整个编码器执行正确的预测。
StackedGAN背后的想法是,如果想构建一个可以生成假名人面孔的GAN,应该简单地反转编码器。 StackedGAN由一堆更简单的GAN组成, G A N i GAN_i GANi,其中 i = 0 , . . . , n − 1 i = 0,...,n-1 i=0,...,n1对应于 n n n个特征。每个 G A N i GAN_i GANi都会学习反转其相应编码器 E n c o d e r i Encoder_i Encoderi的过程。例如, G A N 0 GAN_0 GAN0从伪造的发型特征生成伪造的名人面孔,这与 E n c o d e r 0 Encoder_0 Encoder0的过程相反。
每个 G A N i GAN_i GANi使用一个潜编码 z i z_i zi,以调节其生成器输出。例如,潜编码 z 0 z_0 z0可以修改发型。GAN的堆栈也可以用作合成假名人面孔的对象,从而完成整个编码器的逆过程。每个 G A N i GAN_i GANi的潜编码 z i z_i zi可以用来更改假名人面孔的特定属性。

StackedGAN实现

StackedGAN的详细网络模型。以2个encoder-GAN堆栈为例。

StackedGAN详解与实现(采用tensorflow2.x实现)_第2张图片StackedGAN包括编码器和GAN的堆栈。 对编码器进行预训练以执行分类。 G e n e r a t o r 1 Generator_1 Generator1学习合成基于伪标签 y f y_{f} yf和潜编码 z 1 f z_{1f} z1f的特征 f 1 f f_{1f} f1f G e n e r a t o r 0 Generator_0 Generator0使用伪特征 f 1 f f_{1f} f1f和潜码 z 0 f z_{0f} z0f产生伪图像。
StackedGAN从编码器开始。它可能是训练后的分类器,可以预测正确的标签。中间特征向量 f 1 r f_{1r} f1r可用于GAN训练。对于MNIST,可以使用基于CNN的分类器。
StackedGAN详解与实现(采用tensorflow2.x实现)_第3张图片使用Dense层提取256-dim特征。 有两种输出模型, E n c o d e r 0 Encoder_0 Encoder0 E n c o d e r 1 Encoder_1 Encoder1。 两者都将用于训练StackedGAN。

编码器

def build_encoder(inputs,num_labels=10,feature1_dim=256):
    """the Encoder Model sub networks
    Two sub networks:
    Encoder0: Image to feature1
    Encoder1: feature1 to labels

    #arguments
        inputs (layers): x - images, feature1 - feature1 layer output
        num_labels (int): number of class labels
        feature1_dim (int): feature1 dimenstionality
    #returns
        enc0,enc1 (models):Description below
    """
    kernel_size = 3
    filters = 64

    x,feature1 = inputs
    # Encoder0 or enc0
    y = keras.layers.Conv2D(filters=filters,
            kernel_size=kernel_size,
            padding='same',
            activation='relu')(x)
    y = keras.layers.MaxPool2D()(y)
    y = keras.layers.Conv2D(filters=filters,
            kernel_size=kernel_size,
            padding='same',
            activation='relu')(y)
    y = keras.layers.MaxPooling2D()(y)
    y = keras.layers.Flatten()(y)
    feature1_output = keras.layers.Dense(feature1_dim,activation='relu')(y)
    #Encoder0 or enc0: image (x or feature0) to feature1
    enc0 = keras.Model(inputs=x,outputs=feature1_output,name='encoder0')

    #Encoder1 or enc1
    y = keras.layers.Dense(num_labels)(feature1)
    labels = keras.layers.Activation('softmax')(y)
    #Encoder1 or enc1: feature1 to class labels (feature2)
    enc1 = keras.Model(inputs=feature1,outputs=labels,name='encoder1')
    #return both enc0,enc1
    return enc0,enc1

E n c o d e r 0 Encoder_0 Encoder0的输出 f 1 r f_{1r} f1r是希望 G e n e r a t o r 1 Generator_1 Generator1学习进行合成的256维特征向量。可用作 E n c o d e r 0 Encoder_0 Encoder0的辅助输出。训练整个编码器以对MNIST数字 x r x_r xr进行分类。 正确的标签 y r y_r yr E n c o d e r 1 Encoder_1 Encoder1预测。 在此过程中,将学习中间特征集 f 1 r f_1r f1r并将其用于 G e n e r a t o r 0 Generator_0 Generator0训练。 当GAN针对此编码器进行训练时,下标 r r r用于强调和区分真实数据与伪数据。
假设编码器输入 x r x_r xr,输出为中间特征 f 1 r f_{1r} f1r和标签 y r y_r yr,则每个GAN都会以通常的鉴别网络-对抗网络方式进行训练。

对抗网络

损失函数:
鉴别器
L i ( D ) = − E f i ∼ p d a t a l o g D ( f i ) − E f i + 1 ∼ p d a t a , z i l o g [ 1 − D ( G ( f i + 1 , z i ) ) ] \mathcal L_i^{(D)} = -\mathbb E_{f_i\sim p_{data}}logD(f_i)-\mathbb E_{f_{i+1}\sim p_{data},z_i}log[1 − D(G(f_{i+1},z_i))] Li(D)=EfipdatalogD(fi)Efi+1pdata,zilog[1D(G(fi+1,zi))]
生成器
L i ( G ) a d v = − E f i ∼ p d a t a , z i l o g D ( G ( f i + 1 , z i ) ) \mathcal L_i^{(G)adv} = -\mathbb E_{f_i\sim p_{data},z_i}logD(G(f_{i+1},z_i)) Li(G)adv=Efipdata,zilogD(G(fi+1,zi))
L i ( D ) c o n d = ∥ E i ( G ( f i + 1 , z i ) ) , f i ∥ 2 \mathcal L_i^{(D)cond} = \| \mathbb E_i(G(f_{i+1},z_i)),f_i \|_2 Li(D)cond=Ei(G(fi+1,zi)),fi2
L i ( D ) e n t = ∥ Q i ( G ( f i + 1 , z i ) ) , z i ∥ 2 \mathcal L_i^{(D)ent} = \| Q_i(G(f_{i+1},z_i)),z_i \|_2 Li(D)ent=Qi(G(fi+1,zi)),zi2
L i ( G ) = λ 1 L i ( G ) a d v + λ 2 L i ( D ) c o n d + λ 3 L i ( D ) e n t \mathcal L_i^{(G)} = \lambda_1 \mathcal L_i^{(G)adv}+\lambda_2 \mathcal L_i^{(D)cond} +\lambda_3 \mathcal L_i^{(D)ent} Li(G)=λ1Li(G)adv+λ2Li(D)cond+λ3Li(D)ent
条件损失函数 L i ( D ) c o n d \mathcal L_i^{(D)cond} Li(D)cond确保了在从输入噪声编码 z i z_i zi合成输出 f i f_i fi时,生成器不会忽略输入 f i + 1 f_{i+1} fi+1。 编码器 E n c o d e r i Encoder_i Encoderi必须能够通过反转 G e n e r a t o r i Generator_i Generatori的过程来恢复生成器输入。生成器输入和使用编码器恢复的输入之间的差通过欧几里德距离(均方误差(MSE))测量。
StackedGAN详解与实现(采用tensorflow2.x实现)_第4张图片但是,条件损失函数引入了新问题。生成器忽略输入噪声编码 z i z_i zi,仅依赖于 f i + 1 f_{i+1} fi+1。 熵损失函数 L i ( D ) e n t \mathcal L_i^{(D)ent} Li(D)ent确保生成器不会忽略噪声编码 z i z_i zi。 Q网络从生成器的输出中恢复噪声矢量。恢复的噪声与输入噪声之间的差异也可以通过欧几里德距离(MSE)进行测量。
StackedGAN详解与实现(采用tensorflow2.x实现)_第5张图片

鉴别器

构建 D i s c r i m i n a t o r 0 Discriminator_0 Discriminator0 D i s c r i m i n a t o r 1 Discriminator_1 Discriminator1的函数。 除特征向量输入 Z 0 Z_0 Z0和辅助网络 Q 0 Q_0 Q0之外,dis0鉴别器与GAN鉴别器类似。创建dis0:

def discriminator(inputs,activation='sigmoid',num_codes=None):
    """discriminator model
    Arguments:
        inputs (Layer): input layer of the discriminator
        activation (string): name of output activation layer
        num_labels (int): dimension of one-hot labels for ACGAN & InfoGAN
        num_codes (int): num_codes-dim Q network as output
                if StackedGAN or 2 Q netwoek if InfoGAN
    Returns:
        Model: Discriminator model
    """
    kernel_size = 5
    layer_filters = [32,64,128,256]
    x = inputs
    for filters in layer_filters:
        if filters == layer_filters[-1]:
            strides = 1
        else:
            strides = 2
        x = keras.layers.LeakyReLU(0.2)(x)
        x = keras.layers.Conv2D(filters=filters,
                kernel_size=kernel_size,
                strides=strides,
                padding='same')(x)
    x = keras.layers.Flatten()(x)
    outputs = keras.layers.Dense(1)(x)
    if activation is not None:
        print(activation)
        outputs = keras.layers.Activation(activation)(outputs)
    # StackedGAN Q0 output
    # z0_recon is reconstruction of z0 normal distribution
    z0_recon = keras.layers.Dense(num_codes)(x)
    z0_recon = keras.layers.Activation('tanh',name='z0')(z0_recon)
    outputs = [outputs,z0_recon]
    return keras.Model(inputs,outputs,name='discriminator')

dis1鉴别器由三层MLP组成。 最后一层区分真实和伪。网络共享dis1的前两层。其第三层重建 z 1 z_1 z1

def build_disciminator(inputs,z_dim=50):
    """Discriminator 1 model
    将feature1分类为真实/伪图像,并恢复输入噪声或潜编码
    #argumnets
        inputs (layer): feature1
        z_dim (int): noise dimensionality
    #Returns
        dis1 (Model): feature1 as real/fake and recovered latent code
    """
    #input is 256-dim feature1
    x = keras.layers.Dense(256,activation='relu')(inputs)
    x = keras.layers.Dense(256,activation='relu')(x)
    # first output is probality that feature1 is real
    f1_source = keras.layers.Dense(1)(x)
    f1_source = keras.layers.Activation('sigmoid',name='feature1_source')(f1_source)
    #z1 reonstruction (Q1 network)
    z1_recon = keras.layers.Dense(z_dim)(x)
    z1_recon = keras.layers.Activation('tanh',name='z1')(z1_recon)

    discriminator_outputs = [f1_source,z1_recon]
    dis1 = keras.Model(inputs,discriminator_outputs,name='dis1')
    return dis1

生成器

gen1生成器由带有标签和噪声编码 z 1 f z_{1f} z1f作为输入的三个密集层组成。 第三层生成伪造的特征 f 1 f f_{1f} f1f

def build_generator(latent_codes,image_size,feature1_dim=256):
    """build generator model sub networks
    Two sub networks:
        class and noise to feature1
        feature1 to image

    #Argument
        latent_codes (layers): dicrete code (labels), noise and feature1 features
        image_size (int): target size of one side
        feature1_dim (int): feature1 dimensionality
    #Return
        gen0,gen1 (models)
    """
    #latent codes and network parameters
    labels,z0,z1,feature1 = latent_codes
    #image_resize = image_size // 4
    #kernel_size = 5
    #layer_filters = [128,64,32,1]
    
    #gen1 inputs
    inputs = [labels,z1] #10+50=60-dim
    x = keras.layers.concatenate(inputs,axis=1)
    x = keras.layers.Dense(512,activation='relu')(x)
    x = keras.layers.BatchNormalization()(x)
    x = keras.layers.Dense(512,activation='relu')(x)
    x = keras.layers.BatchNormalization()(x)
    fake_feature1 = keras.layers.Dense(feature1_dim,activation='relu')(x)
    #gen1: classes and noise (feature2 + z1) to feature1
    gen1 = keras.Model(inputs,fake_feature1,name='gen1')
    #gen0: feature1 + z0 to feature0 (image)
    gen0 = generator(feature1,image_size,codes=z0)
    return gen0,gen1

gen0生成器类似于其他GAN生成器.

def generator(inputs,image_size,activation='sigmoid',codes=None):
    """generator model
    Arguments:
        inputs (layer): input layer of generator
        image_size (int): Target size of one side
        activation (string): name of output activation layer
        labels (tensor): input labels
        codes (list): 2-dim disentangled codes for infoGAN
    returns:
        model: generator model
    """
    image_resize = image_size // 4
    kernel_size = 5
    layer_filters = [128,64,32,1]
    ## generator 0 of StackedGAN
    inputs = [inputs,codes]
    x = keras.layers.concatenate(inputs,axis=1)
    x = keras.layers.Dense(image_resize*image_resize*layer_filters[0])(x)
    x = keras.layers.Reshape((image_resize,image_resize,layer_filters[0]))(x)
    for filters in layer_filters:
        if filters > layer_filters[-2]:
            strides = 2
        else:
            strides = 1
        x = keras.layers.BatchNormalization()(x)
        x = keras.layers.Activation('relu')(x)
        x = keras.layers.Conv2DTranspose(filters=filters,
                kernel_size=kernel_size,
                strides=strides,
                padding='same')(x)
    if activation is not None:
        x = keras.layers.Activation(activation)(x)
    return keras.Model(inputs,x,name='generator')

模型构建

def build_and_train_models():
    #build StackedGAN
    #数据加载
    (x_train,y_train),(x_test,y_test) = keras.datasets.mnist.load_data()
    image_size = x_train.shape[1]
    x_train = np.reshape(x_train,[-1,image_size,image_size,1])
    x_train = x_train.astype('float32') / 255.

    x_test = np.reshape(x_test,[-1,image_size,image_size,1])
    x_test = x_test.astype('float32') / 255.

    num_labels = len(np.unique(y_train))
    y_train = keras.utils.to_categorical(y_train)
    y_test = keras.utils.to_categorical(y_test)

    #超参数
    model_name = 'stackedGAN_mnist'
    batch_size = 64
    train_steps = 40000
    lr = 2e-4
    decay = 6e-8
    input_shape = (image_size,image_size,1)
    label_shape = (num_labels,)
    z_dim = 50
    z_shape = (z_dim,)
    feature1_dim = 256
    feature1_shape = (feature1_dim,)

    #discriminator 0 and Q network 0 models
    inputs = keras.layers.Input(shape=input_shape,name='discriminator0_input')
    dis0 = discriminator(inputs,num_codes=z_dim)
    optimizer = keras.optimizers.RMSprop(lr=lr,decay=decay)
    # 损失函数:1)图像是真实的概率
    # 2)MSE z0重建损失
    loss = ['binary_crossentropy','mse']
    loss_weights = [1.0,10.0]
    dis0.compile(loss=loss,loss_weights=loss_weights,
            optimizer=optimizer,
            metrics=['accuracy'])
    dis0.summary()

    #discriminator 1 and Q network 1 models
    input_shape = (feature1_dim,)
    inputs = keras.layers.Input(shape=input_shape,name='discriminator1_input')
    dis1 = build_disciminator(inputs,z_dim=z_dim)
    # 损失函数: 1) feature1是真实的概率 (adversarial1 loss)
    # 2) MSE z1 重建损失 (Q1 network loss or entropy1 loss)
    loss = ['binary_crossentropy','mse']
    loss_weights = [1.0,1.0]
    dis1.compile(loss=loss,loss_weights=loss_weights,
            optimizer=optimizer,
            metrics=['acc'])
    dis1.summary()

    #generator models
    feature1 = keras.layers.Input(shape=feature1_shape,name='featue1_input')
    labels = keras.layers.Input(shape=label_shape,name='labels')
    z1 = keras.layers.Input(shape=z_shape,name='z1_input')
    z0 = keras.layers.Input(shape=z_shape,name='z0_input')
    latent_codes = (labels,z0,z1,feature1)
    gen0,gen1 = build_generator(latent_codes,image_size)
    gen0.summary()
    gen1.summary()

    #encoder models
    input_shape = (image_size,image_size,1)
    inputs = keras.layers.Input(shape=input_shape,name='encoder_input')
    enc0,enc1 = build_encoder((inputs,feature1),num_labels)
    enc0.summary()
    enc1.summary()
    encoder = keras.Model(inputs,enc1(enc0(inputs)))
    encoder.summary()

    data = (x_train,y_train),(x_test,y_test)
    #训练对抗网路前,需要已经训练完成的编码器网络
    train_encoder(encoder,data,model_name=model_name)

    #adversarial0 model = generator0 + discrimnator0 + encoder0
    optimizer = keras.optimizers.RMSprop(lr=lr*0.5,decay=decay*0.5)
    enc0.trainable = False
    dis0.trainable = False
    gen0_inputs = [feature1,z0]
    gen0_outputs = gen0(gen0_inputs)
    adv0_outputs = dis0(gen0_outputs) + [enc0(gen0_outputs)]
    adv0 = keras.Model(gen0_inputs,adv0_outputs,name='adv0')
    # 损失函数:1)feature1是真实的概率
    # 2)Q network 0 损失
    # 3)condition0 损失
    loss = ['binary_crossentropy','mse','mse']
    loss_weights = [1.0,10.0,1.0]
    adv0.compile(loss=loss,
            loss_weights=loss_weights,
            optimizer=optimizer,
            metrics=['acc'])
    adv0.summary()

    #adversarial1 model = generator1 + discrimnator1 + encoder1
    enc1.trainable = False
    dis1.trainable = False
    gen1_inputs = [labels,z1]
    gen1_outputs = gen1(gen1_inputs)
    adv1_outputs = dis1(gen1_outputs) + [enc1(gen1_outputs)]
    adv1 = keras.Model(gen1_inputs,adv1_outputs,name='adv1')
    #损失函数:1)标签是真实的概率
    #2)Q network 1 损失
    #3)conditional1 损失
    loss_weights = [1.0,1.0,1.0]
    loss = ['binary_crossentropy','mse','categorical_crossentropy']
    adv1.compile(loss=loss,
            loss_weights=loss_weights,
            optimizer=optimizer,
            metrics=['acc'])
    adv1.summary()

    models = (enc0,enc1,gen0,gen1,dis0,dis1,adv0,adv1)
    params = (batch_size,train_steps,num_labels,z_dim,model_name)
    train(models,data,params)

模型训练

#训练对抗网路前,需要已经训练完成的编码器网络
def train_encoder(model,data,model_name='stackedgan_mnist',batch_size=64):
    """Train Encoder model
    # Arguments
        model (model): Encoder
        data (tensor): train and test data
        model_name (string): model name
        batch_size (int): train batch size
    """
    (x_train,y_train),(x_test,y_test) = data
    model.compile(loss='categorical_crossentropy',optimizer='adam',metrics=['acc'])
    model.fit(x_train,y_train,validation_data=(x_test,y_test),
            epochs=20,batch_size=batch_size)
    model.save(model_name + '-encoder.h5')
    score = model.evaluate(x_test,y_test,batch_size=batch_size,verbose=0)
    print("\nTest accuracy: %.1f%%" % (100.0 * score[1]))

训练顺序为:
1. D i s c r i m i n a t o r 1 Discriminator_1 Discriminator1 Q 1 Q_1 Q1
2. D i s c r i m i n a t o r 0 Discriminator_0 Discriminator0 Q 0 Q_0 Q0
3. A d v e r s a r i a l 1 Adversarial_1 Adversarial1
4. A d v e r s a r i a l 0 Adversarial_0 Adversarial0

def train(models,data,params):
    """train networks
    Arguments
        models (models): encoder,generator,discriminator,adversarial
        data (tuple): x_train,y_train
        params (tuple): parameters
    """
    enc0,enc1,gen0,gen1,dis0,dis1,adv0,adv1 = models
    batch_size,train_steps,num_labels,z_dim,model_name = params
    (x_train,y_train),_ = data
    save_interval = 500

    z0 = np.random.normal(scale=0.5,size=[16,z_dim])
    z1 = np.random.normal(scale=0.5,size=[16,z_dim])
    noise_class = np.eye(num_labels)[np.arange(0,16) % num_labels]
    noise_params = [noise_class,z0,z1]
    train_size = x_train.shape[0]
    print(model_name,'labels for generated images: ',np.argmax(noise_class,axis=1))

    for i in range(train_steps):
        rand_indexes = np.random.randint(0,train_size,size=batch_size)
        real_images = x_train[rand_indexes]
        # real feature1 from encoder0 output
        real_feature1 = enc0.predict(real_images)
        # generate random 50-dim z1 latent code
        real_z1 = np.random.normal(scale=0.5,size=[batch_size,z_dim])
        #real labels
        real_labels = y_train[rand_indexes]
        #generate fake feature1 using generator1 from real labels and 50-dim z1 latent code
        fake_z1 = np.random.normal(scale=0.5,size=[batch_size,z_dim])
        fake_feature1 = gen1.predict([real_labels,fake_z1])
        #real + fake data
        feature1 = np.concatenate((real_feature1,fake_feature1))
        z1 = np.concatenate((real_z1,fake_z1))
        #label 1st half as real and 2nd half as fake
        y = np.ones([2*batch_size,1])
        y[batch_size:,:] = 0

        #train discriminator1 to classify feature1 as real/fake and recover
        metrics = dis1.train_on_batch(feature1,[y,z1])
        log = "%d: [dis1_loss: %f]" % (i, metrics[0])

        #train the discriminator0 for 1 batch
        #1 batch of reanl and fake images
        real_z0 = np.random.normal(scale=0.5,size=[batch_size,z_dim])
        fake_z0 = np.random.normal(scale=0.5,size=[batch_size,z_dim])
        fake_images = gen0.predict([real_feature1,fake_z0])

        #real + fake data
        x = np.concatenate((real_images,fake_images))
        z0 = np.concatenate((real_z0,fake_z0))
        #train discriminator0 to classify image as real/fake and recover latent code (z0)
        metrics = dis0.train_on_batch(x,[y,z0])
        log = "%s [dis0_loss: %f]" % (log, metrics[0])
        
        # 对抗训练
        # 生成fake z1,labels
        fake_z1 = np.random.normal(scale=0.5,size=[batch_size,z_dim])
        #input to generator1 is sampling fr real labels and 50-dim z1 latent code
        gen1_inputs = [real_labels,fake_z1]
        y = np.ones([batch_size,1])
        #train generator1
        metrics = adv1.train_on_batch(gen1_inputs,[y,fake_z1,real_labels])
        fmt = "%s [adv1_loss: %f, enc1_acc: %f]"
        log = fmt % (log, metrics[0], metrics[6])
        # input to generator0 is real feature1 and 50-dim z0 latent code
        fake_z0 = np.random.normal(scale=0.5,size=[batch_size,z_dim])
        gen0_inputs = [real_feature1,fake_z0]
        #train generator0
        metrics = adv0.train_on_batch(gen0_inputs,[y,fake_z0,real_feature1])
        log = "%s [adv0_loss: %f]" % (log, metrics[0])
        print(log)
        if (i + 1) % save_interval == 0:
            genenators = (gen0,gen1)
            plot_images(genenators,noise_params=noise_params,
                    show=False,
                    step=(i+1),
                    model_name=model_name)
    gen1.save(model_name + '-gen1.h5')
    gen0.save(model_name + '-gen0.h5')

效果展示

#绘制生成图片
def plot_images(generators,noise_params,show=False,step=0,model_name='gan'):
    """generator fake images and plot
    Arguments
        generators (model): gen0 and gen1 models for fake images generation
        noise_params (list): noise parameters (label,z0 and z1 codes)
        show (bool): whether to show plot or not
        step (int): Appended tor filename of the save images
        model_name (string): model name
    """
    gen0,gen1 = generators
    noise_class,z0,z1 = noise_params
    os.makedirs(model_name,exist_ok=True)
    filename = os.path.join(model_name,'%05d.png' % step)
    feature1 = gen1.predict([noise_class,z1])
    images = gen0.predict([feature1,z0])
    print(model_name,'labels for generated images: ',np.argmax(noise_class,axis=1))
    plt.figure(figsize=(2.2,2.2))
    num_images = images.shape[0]
    image_size = images.shape[1]
    rows = int(math.sqrt(noise_class.shape[0]))
    for i in range(num_images):
        plt.subplot(rows,rows,i + 1)
        image = np.reshape(images[i],[image_size,image_size])
        plt.imshow(image,cmap='gray')
        plt.axis('off')
    plt.savefig(filename)
    if show:
        plt.show()
    else:
        plt.close('all')
if __name__ == '__main__:
    build_and_train_models()
step=10000

StackedGAN详解与实现(采用tensorflow2.x实现)_第6张图片

修改书写角度的分离编码

StackedGAN详解与实现(采用tensorflow2.x实现)_第7张图片

你可能感兴趣的:(深度学习,#,tensorflow,#,GAN,tensorflow,深度学习,python,神经网络,生成器)