Stabilizing GAN training with spectral normalization (spectral-normalization-gan): a TensorFlow implementation

Reference code: https://github.com/christiancosgrove/pytorch-spectral-normalization-gan

Reference code: https://github.com/heykeetae/Self-Attention-GAN

Reference code: https://github.com/taki0112/Self-Attention-GAN-Tensorflow


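Spectral normalization constrains each weight matrix W by dividing it by its spectral norm σ(W), i.e. its largest singular value. The normalized weight W/σ(W) then has spectral norm 1, which bounds how much the layer can stretch its input (its Lipschitz constant) and keeps discriminator training stable. Computing σ(W) exactly would require an SVD, so it is estimated cheaply with power iteration.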
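To make the idea concrete, here is a minimal pure-Python sketch (no TensorFlow; the helper names are illustrative and not taken from the referenced repositories) that estimates σ(W) by power iteration and divides W by it. The 2×2 matrix is chosen because its singular values are known exactly (√45 and √5), so the estimate can be checked.

```python
import math
import random


def mat_vec(w, x):
    """Compute W x, with W stored as a list of rows."""
    return [sum(wij * xj for wij, xj in zip(row, x)) for row in w]


def mat_t_vec(w, y):
    """Compute W^T y."""
    return [sum(w[i][j] * y[i] for i in range(len(w))) for j in range(len(w[0]))]


def l2_normalize(v, eps=1e-12):
    n = math.sqrt(sum(x * x for x in v))
    return [x / (n + eps) for x in v]


def estimate_sigma(w, iterations=50, seed=0):
    """Estimate sigma(W), the largest singular value of W, by power iteration."""
    rng = random.Random(seed)
    u = l2_normalize([rng.gauss(0.0, 1.0) for _ in range(len(w[0]))])
    v = None
    for _ in range(iterations):
        v = l2_normalize(mat_vec(w, u))    # approx. top left singular vector
        u = l2_normalize(mat_t_vec(w, v))  # approx. top right singular vector
    # sigma ~= v^T W u
    return sum(vi * wui for vi, wui in zip(v, mat_vec(w, u)))


def spectrally_normalize(w, iterations=50):
    """Return (W / sigma(W), sigma(W))."""
    sigma = estimate_sigma(w, iterations)
    return [[x / sigma for x in row] for row in w], sigma


W = [[3.0, 0.0], [4.0, 5.0]]  # singular values are sqrt(45) ~ 6.708 and sqrt(5)
W_sn, sigma = spectrally_normalize(W)
print(round(sigma, 6))                 # 6.708204
print(round(estimate_sigma(W_sn), 6))  # 1.0
```

After normalization the largest singular value of W_sn is 1, so ‖W_sn x‖ ≤ ‖x‖ holds for every input x.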

The spectral normalization code follows and can be copied as-is; how to call it is shown in the next snippet:


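import tensorflow as tf  # TensorFlow 1.x API (tf.get_variable, tf.variable_scope, tf.layers)

weight_init = tf.random_normal_initializer(mean=0.0, stddev=0.02)
weight_regularizer = None


def spectral_norm(w, iteration=1, name="spectral_norm"):
    """Return w / sigma(w), where sigma(w) (largest singular value) is estimated by power iteration.

    Creates a non-trainable variable `u`, so call it once per layer inside that layer's
    own tf.variable_scope (tf.name_scope alone does not scope tf.get_variable).
    """
    with tf.variable_scope(name):
        w_shape = w.shape.as_list()
        # Flatten the kernel to a 2-D matrix [-1, w_shape[-1]], e.g. (48, 64), (1024, 128), (2048, 256)
        w_mat = tf.reshape(w, [-1, w_shape[-1]])
        u = tf.get_variable("u", [1, w_shape[-1]],
                            initializer=tf.truncated_normal_initializer(), trainable=False)

        u_hat = u
        v_hat = None
        for _ in range(iteration):
            # Power iteration. u is stored across training steps (see u.assign below),
            # so iteration = 1 is usually enough.
            v_hat = l2_norm(tf.matmul(u_hat, tf.transpose(w_mat)))  # (1, rows), e.g. (1, 48)
            u_hat = l2_norm(tf.matmul(v_hat, w_mat))                 # (1, cols), e.g. (1, 64)

        # Treat u and v as constants: gradients flow only through w_mat in sigma.
        u_hat = tf.stop_gradient(u_hat)
        v_hat = tf.stop_gradient(v_hat)
        sigma = tf.matmul(tf.matmul(v_hat, w_mat), tf.transpose(u_hat))  # (1, 1), estimate of sigma(w)
        w_norm = w_mat / sigma

        # Save u_hat for the next step; the returned tensor is produced only after the assignment.
        with tf.control_dependencies([u.assign(u_hat)]):
            w_norm = tf.reshape(w_norm, w_shape)

    return w_norm


def l2_norm(v, eps=1e-12):
    # Divide a (1, n) vector by its L2 norm; eps avoids division by zero.
    return v / (tf.reduce_sum(v ** 2) ** 0.5 + eps)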
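Why is iteration = 1 usually enough? Because u is a variable that keeps its value between training steps, each step continues the power iteration where the previous one stopped. The sketch below mirrors the math of the TensorFlow function with plain Python lists (the class name PersistentSpectralNorm is hypothetical) and calls it repeatedly on a fixed W. In real training W changes a little every step; the assumption is that it changes slowly enough for u to keep tracking its top singular vector.

```python
import math
import random


def l2_normalize(v, eps=1e-12):
    n = math.sqrt(sum(x * x for x in v))
    return [x / (n + eps) for x in v]


class PersistentSpectralNorm:
    """Pure-Python mirror of spectral_norm(w, iteration=1).

    self.u plays the role of the non-trainable TF variable "u": it is kept
    between calls (like u.assign(u_hat)), so each call performs one
    power-iteration step that continues from the previous call.
    """

    def __init__(self, out_dim, seed=0):
        rng = random.Random(seed)
        self.u = [rng.gauss(0.0, 1.0) for _ in range(out_dim)]  # shape (1, out_dim)

    def __call__(self, w):
        # w: list of rows with shape (rows, out_dim), i.e. after tf.reshape(w, [-1, out_dim])
        rows, cols = len(w), len(self.u)
        v_hat = l2_normalize([sum(w[i][j] * self.u[j] for j in range(cols)) for i in range(rows)])  # u W^T
        u_hat = l2_normalize([sum(w[i][j] * v_hat[i] for i in range(rows)) for j in range(cols)])   # v W
        sigma = sum(v_hat[i] * w[i][j] * u_hat[j] for i in range(rows) for j in range(cols))       # v W u^T
        self.u = u_hat  # same role as u.assign(u_hat)
        return [[x / sigma for x in row] for row in w], sigma


W = [[3.0, 4.0], [0.0, 5.0]]  # largest singular value is sqrt(45) ~ 6.708
sn = PersistentSpectralNorm(out_dim=2)
estimates = [sn(W)[1] for _ in range(20)]  # twenty "training steps" on a fixed W
print(round(estimates[0], 4))   # first-step estimate, depends on the random start
print(round(estimates[-1], 6))  # 6.708204
```

Every estimate is at most the true σ(W), and the gap shrinks with each call even though each call runs only one iteration.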


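Calling spectral normalization: apply it to the layer's weight variable w, inside that layer's scope. w must be the kernel itself, not the layer's output: tf.get_default_graph().get_tensor_by_name(self.core.name) just returns self.core, the output tensor, so normalizing it has no effect on training. The normalized kernel must also be the one the layer actually uses:

# x, in_ch, out_ch and kernel_size are placeholders for your own tensor and ints
w = tf.get_variable("kernel", [kernel_size, kernel_size, in_ch, out_ch], initializer=weight_init)
y = tf.nn.conv2d(x, spectral_norm(w), strides=[1, 1, 1, 1], padding='SAME')

This is the layer in my own code; its scope is opened earlier. That scope must be a tf.variable_scope, not only a tf.name_scope, because tf.get_variable (used for u inside spectral_norm and for the kernel here) ignores name scopes. Since tf.layers.conv2d_transpose creates and uses its kernel internally, the kernel is created explicitly and passed to tf.nn.conv2d_transpose:

elif self.layer_type == 'transconv2d':
    # Assumes kernel_size and strides are ints and padding is 'same' (output = input * strides).
    x = self.conditioned
    in_ch = x.shape.as_list()[-1]
    # conv2d_transpose kernels have shape [h, w, out_channels, in_channels]
    w = tf.get_variable('transconv2d_kernel', [kernel_size, kernel_size, filters, in_ch],
                        initializer=kernel_initializer)
    w = spectral_norm(w)
    out_shape = tf.stack([tf.shape(x)[0], tf.shape(x)[1] * strides, tf.shape(x)[2] * strides, filters])
    self.core = tf.nn.conv2d_transpose(x, w, out_shape, [1, strides, strides, 1], padding='SAME')
    b = tf.get_variable('transconv2d_bias', [filters], initializer=tf.zeros_initializer())
    self.core = tf.nn.bias_add(self.core, b)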
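When the transposed convolution is built with tf.nn.conv2d_transpose (so that a spectrally normalized kernel can be passed in), output_shape must be supplied explicitly, whereas tf.layers.conv2d_transpose computes it for you. The helper below sketches that computation for NHWC tensors. The function names are hypothetical, and the 'valid' rule (in * stride + max(kernel - stride, 0)) is assumed to match what tf.layers uses in TF 1.x, so check it against your version.

```python
def deconv_output_length(input_length, kernel_size, stride, padding):
    """Spatial output size of a transposed convolution.

    'same'  -> input_length * stride
    'valid' -> input_length * stride + max(kernel_size - stride, 0)
    """
    padding = padding.lower()
    if padding not in ("same", "valid"):
        raise ValueError("padding must be 'same' or 'valid'")
    length = input_length * stride
    if padding == "valid":
        length += max(kernel_size - stride, 0)
    return length


def conv2d_transpose_output_shape(batch, height, width, filters, kernel_size, stride, padding):
    """Build the NHWC output_shape list expected by tf.nn.conv2d_transpose."""
    return [batch,
            deconv_output_length(height, kernel_size, stride, padding),
            deconv_output_length(width, kernel_size, stride, padding),
            filters]


print(conv2d_transpose_output_shape(16, 8, 8, 128, 4, 2, "same"))   # [16, 16, 16, 128]
print(conv2d_transpose_output_shape(16, 8, 8, 128, 5, 2, "valid"))  # [16, 19, 19, 128]
```

As a sanity check, a forward VALID convolution with kernel 5 and stride 2 maps 19 back to ceil((19 - 5 + 1) / 2) = 8.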
