Style Transfer 0-07: stylegan - Source Code Walkthrough with No Blind Spots (3) - Overview of the generate Network Framework

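The link below collects all of my notes on stylegan. If you spot a mistake, please point it out and I will correct it right away. If you are interested, add me on WeChat (a944284742) to discuss the technical details. If this helped you, please remember to give it a like, since that is my biggest encouragement.
Style Transfer 0-00: stylegan - Table of Contents - The Most Complete Ever: https://blog.csdn.net/weixin_43013761/article/details/100895333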

generate net workflow

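To be honest, in earlier projects I never paid much attention to network structure, because it always seemed to come down to combinations of convolution, pooling, deconvolution, skip connections, and so on. This network, however, is comparatively complex, and the source code is hard to read, so I want to work through it properly. Let's get started.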

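From Style Transfer 0-05: stylegan - Source Code Walkthrough with No Blind Spots (1) - Framework Overview, we know that the networks are built starting from: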

# Excerpt from training_loop(); its arguments and the surrounding code are omitted.
G = tflib.Network('G', num_channels=training_set.shape[0], resolution=training_set.shape[1], label_size=training_set.label_size, **G_args)
D = tflib.Network('D', num_channels=training_set.shape[0], resolution=training_set.shape[1], label_size=training_set.label_size, **D_args)

For the generator, this eventually calls the G_style function in training/networks_stylegan.py. The function, with my comments added, is as follows:

#----------------------------------------------------------------------------
# Style-based generator used in the StyleGAN paper.
# Composed of two sub-networks (G_mapping and G_synthesis) that are defined below.

def G_style(
    latents_in,                                     # First input: Latent vectors (Z) [minibatch, latent_size].
    labels_in,                                      # Second input: Conditioning labels [minibatch, label_size].
    truncation_psi          = 0.7,                  # Style strength multiplier for the truncation trick. None = disable.
    truncation_cutoff       = 8,                    # Number of layers for which to apply the truncation trick. None = disable.
    truncation_psi_val      = None,                 # Value for truncation_psi to use during validation.
    truncation_cutoff_val   = None,                 # Value for truncation_cutoff to use during validation.
    dlatent_avg_beta        = 0.995,                # Decay for tracking the moving average of W during training. None = disable.
    style_mixing_prob       = 0.9,                  # Probability of mixing styles during training. None = disable.
    is_training             = False,                # Network is under training? Enables and disables specific features.
    is_validation           = False,                # Network is under validation? Chooses which value to use for truncation_psi.
    is_template_graph       = False,                # True = template graph constructed by the Network class, False = actual evaluation.
    components              = dnnlib.EasyDict(),    # Container for sub-networks. Retained between calls.
    **kwargs):                                      # Arguments for sub-networks (G_mapping and G_synthesis).

    # Validate arguments.
    assert not is_training or not is_validation
    assert isinstance(components, dnnlib.EasyDict)
    if is_validation:
        truncation_psi = truncation_psi_val
        truncation_cutoff = truncation_cutoff_val
    if is_training or (truncation_psi is not None and not tflib.is_tf_expression(truncation_psi) and truncation_psi == 1):
        truncation_psi = None
    if is_training or (truncation_cutoff is not None and not tflib.is_tf_expression(truncation_cutoff) and truncation_cutoff <= 0):
        truncation_cutoff = None
    if not is_training or (dlatent_avg_beta is not None and not tflib.is_tf_expression(dlatent_avg_beta) and dlatent_avg_beta == 1):
        dlatent_avg_beta = None
    if not is_training or (style_mixing_prob is not None and not tflib.is_tf_expression(style_mixing_prob) and style_mixing_prob <= 0):
        style_mixing_prob = None

    # Setup components.
    # Creating the Network calls G_synthesis to build the synthesis sub-network
    if 'synthesis' not in components:
        components.synthesis = tflib.Network('G_synthesis', func_name=G_synthesis, **kwargs)
    num_layers = components.synthesis.input_shape[1]
    dlatent_size = components.synthesis.input_shape[2]
    if 'mapping' not in components:
        components.mapping = tflib.Network('G_mapping', func_name=G_mapping, dlatent_broadcast=num_layers, **kwargs)

    # Setup variables.
    lod_in = tf.get_variable('lod', initializer=np.float32(0), trainable=False)
    # (512,)
    dlatent_avg = tf.get_variable('dlatent_avg', shape=[dlatent_size], initializer=tf.initializers.zeros(), trainable=False)

    # Evaluate mapping network.
    # dlatents has shape (?, 18, 512) after being broadcast across the layers
    dlatents = components.mapping.get_output_for(latents_in, labels_in, **kwargs)

    # Update moving average of W.
    # W follows the distribution of the training data. dlatent_avg is updated at every
    # training step to track the mean W vector of that distribution (the average face).
    if dlatent_avg_beta is not None:
        with tf.variable_scope('DlatentAvg'):
            batch_avg = tf.reduce_mean(dlatents[:, 0], axis=0)
            update_op = tf.assign(dlatent_avg, tflib.lerp(batch_avg, dlatent_avg, dlatent_avg_beta))
            with tf.control_dependencies([update_op]):
                dlatents = tf.identity(dlatents)

    # Perform style mixing regularization.
    if style_mixing_prob is not None:
        with tf.name_scope('StyleMix'):
            # Sample a second random latent, latents2, with the same shape as latents_in (Z)
            latents2 = tf.random_normal(tf.shape(latents_in))

            # Map latents2 (a second z in the paper's notation) and labels_in to a second W vector
            dlatents2 = components.mapping.get_output_for(latents2, labels_in, **kwargs)

            # Layer indices, shape [1, num_layers, 1]
            layer_idx = np.arange(num_layers)[np.newaxis, :, np.newaxis]

            # Number of layers active at the current level of detail
            cur_layers = num_layers - tf.cast(lod_in, tf.int32) * 2

            # With probability style_mixing_prob, pick a random crossover layer in [1, cur_layers);
            # otherwise use cur_layers, which leaves the active layers unmixed
            mixing_cutoff = tf.cond(
                tf.random_uniform([], 0.0, 1.0) < style_mixing_prob,
                lambda: tf.random_uniform([], 1, cur_layers, dtype=tf.int32),
                lambda: cur_layers)

            # Mix: layers below mixing_cutoff keep dlatents, the remaining layers take dlatents2
            dlatents = tf.where(tf.broadcast_to(layer_idx < mixing_cutoff, tf.shape(dlatents)), dlatents, dlatents2)

    # Apply truncation trick. truncation_cutoff defaults to 8.
    if truncation_psi is not None and truncation_cutoff is not None:
        with tf.variable_scope('Truncation'):
            # layer_idx has shape [1, num_layers, 1] with values 0..num_layers-1 (0..17 for 18 layers)
            layer_idx = np.arange(num_layers)[np.newaxis, :, np.newaxis]
            ones = np.ones(layer_idx.shape, dtype=np.float32)
            # [1,18,1]
            coefs = tf.where(layer_idx < truncation_cutoff, truncation_psi * ones, ones)
            # Interpolate from the average face toward dlatents, weighted per layer by coefs
            dlatents = tflib.lerp(dlatent_avg, dlatents, coefs)

    # Evaluate synthesis network.
    with tf.control_dependencies([tf.assign(components.synthesis.find_var('lod'), lod_in)]):
        images_out = components.synthesis.get_output_for(dlatents, force_clean_graph=is_template_graph, **kwargs)
    return tf.identity(images_out, name='images_out')

As the function shows, the flow is fairly simple. First,

    if 'synthesis' not in components:
        components.synthesis = tflib.Network('G_synthesis', func_name=G_synthesis, **kwargs)

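calls the G_synthesis function to build the Synthesis network g from the paper, the part outlined in red in the figure:
[Figure 1: Synthesis network g, outlined in red]
Once that is created, the code continues with: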

    if 'mapping' not in components:
        components.mapping = tflib.Network('G_mapping', func_name=G_mapping, dlatent_broadcast=num_layers, **kwargs)

builds the G_mapping network, which corresponds to the following structure in the paper:
[Figure 2: network structure built by G_mapping]
It converts the input Z vector into a W vector.
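To make the (?, 18, 512) shape mentioned in the code comments concrete, here is a minimal pure-Python sketch rather than the TensorFlow code. It assumes the relation num_layers = 2 * log2(resolution) - 2 from the StyleGAN reference G_synthesis, which gives 18 layers at resolution 1024. The helper names num_style_layers and broadcast_w are hypothetical and exist only for illustration.

```python
import math

def num_style_layers(resolution):
    """Number of per-layer W inputs the synthesis network expects (assumed formula)."""
    resolution_log2 = int(math.log2(resolution))
    assert resolution == 2 ** resolution_log2 and resolution >= 4
    return resolution_log2 * 2 - 2

def broadcast_w(w_batch, num_layers):
    """Repeat each W vector once per layer: [N, D] -> [N, num_layers, D]."""
    return [[list(w) for _ in range(num_layers)] for w in w_batch]

layers = num_style_layers(1024)
w_batch = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]  # toy batch: N=2, D=3 instead of 512
dlatents = broadcast_w(w_batch, layers)
print(layers, len(dlatents), len(dlatents[0]), len(dlatents[0][0]))
```

Every layer starts with the same copy of W. Style mixing and truncation later change individual layers, which is why the extra layer axis is needed.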
The details of both networks are covered in the next section. From the code above, once these two networks are built, we reach the following line:

dlatent_avg = tf.get_variable('dlatent_avg', shape=[dlatent_size], initializer=tf.initializers.zeros(), trainable=False)

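This variable stores the average face, i.e. the mean of W. It addresses the fact that faces from low-density regions of the data distribution are poorly represented and therefore hard for the generator to produce well. The average face is not fixed: it is updated at every training step by the following code: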

    # Update moving average of W.
    # W follows the distribution of the training data. dlatent_avg is updated at every
    # training step to track the mean W vector of that distribution (the average face).
    if dlatent_avg_beta is not None:
        with tf.variable_scope('DlatentAvg'):
            batch_avg = tf.reduce_mean(dlatents[:, 0], axis=0)
            update_op = tf.assign(dlatent_avg, tflib.lerp(batch_avg, dlatent_avg, dlatent_avg_beta))
            with tf.control_dependencies([update_op]):
                dlatents = tf.identity(dlatents)

Its purpose is to support the truncation trick later on, which needs the average face in its computation.
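The update above is an exponential moving average. A minimal pure-Python sketch, assuming tflib.lerp(a, b, t) computes a + (b - a) * t, so that the new average is beta * old_average + (1 - beta) * batch_average. The helper names batch_mean and update_dlatent_avg are hypothetical, and the vectors are two-dimensional toys rather than 512-dimensional W vectors.

```python
def lerp(a, b, t):
    """Element-wise a + (b - a) * t, the assumed behaviour of tflib.lerp."""
    return [x + (y - x) * t for x, y in zip(a, b)]

def batch_mean(w_batch):
    """Mean over the batch axis: [N, D] -> [D]."""
    n = len(w_batch)
    return [sum(col) / n for col in zip(*w_batch)]

def update_dlatent_avg(dlatent_avg, w_batch, beta=0.995):
    """One moving-average step: beta * old average + (1 - beta) * batch mean."""
    return lerp(batch_mean(w_batch), dlatent_avg, beta)

dlatent_avg = [0.0, 0.0]            # initialized to zeros, as in G_style
w_batch = [[1.0, 2.0], [3.0, 4.0]]  # toy batch whose mean is [2.0, 3.0]

first_step = update_dlatent_avg(dlatent_avg, w_batch)
for _ in range(1000):
    dlatent_avg = update_dlatent_avg(dlatent_avg, w_batch)
print(first_step, dlatent_avg)
```

With beta = 0.995 a single step moves the average only 0.5% of the way toward the batch mean, so the average face changes slowly and is not dominated by any one minibatch.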
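Next, style mixing is applied according to the style_mixing_prob parameter. The idea is simple: besides the input latents_in, a second random latent of the same shape, latents2, is created. The mapping network turns both into W vectors, dlatents and dlatents2. A crossover layer is then chosen at random; layers below it keep dlatents and the remaining layers take dlatents2, which gives a new dlatents.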
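The mixing step can be sketched in pure Python as follows. This is an illustration under stated assumptions, not the TensorFlow code: the per-layer W vectors are replaced by the labels "A" and "B" so the result is easy to read, and the helper names sample_cutoff and mix_styles are hypothetical.

```python
import random

def sample_cutoff(cur_layers, style_mixing_prob, rng):
    """With probability style_mixing_prob pick a crossover layer in [1, cur_layers);
    otherwise return cur_layers, which means no mixing among the active layers."""
    if rng.random() < style_mixing_prob:
        return rng.randrange(1, cur_layers)
    return cur_layers

def mix_styles(dlatents, dlatents2, mixing_cutoff):
    """Layers with index < mixing_cutoff come from dlatents, the rest from dlatents2."""
    return [
        [w1 if layer < mixing_cutoff else w2
         for layer, (w1, w2) in enumerate(zip(sample1, sample2))]
        for sample1, sample2 in zip(dlatents, dlatents2)
    ]

num_layers = 6                      # toy value; 18 at resolution 1024
dlatents = [["A"] * num_layers]     # batch of one sample, W from latents_in
dlatents2 = [["B"] * num_layers]    # batch of one sample, W from latents2

rng = random.Random(0)
cutoff = sample_cutoff(num_layers, style_mixing_prob=0.9, rng=rng)
print(cutoff, mix_styles(dlatents, dlatents2, cutoff))
```

Because the crossover is drawn from [1, cur_layers), the first layer always comes from dlatents whenever mixing happens.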

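Next comes the truncation trick. It interpolates between the average face from earlier and dlatents, which pulls dlatents toward the average and achieves the truncation. Note that, according to the argument validation at the top of G_style, truncation is disabled when is_training is True and style mixing is disabled when it is False, so in practice truncation acts on the dlatents produced by the mapping network at inference time.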
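A minimal pure-Python sketch of the truncation step, assuming the same lerp definition a + (b - a) * t as tflib.lerp. The function name truncate and the toy values are hypothetical. Only layers with index below cutoff are pulled toward the average; the remaining layers use a coefficient of 1 and stay unchanged.

```python
def truncate(dlatents, dlatent_avg, psi=0.7, cutoff=8):
    """Apply w' = avg + (w - avg) * coef per layer, with coef = psi for
    layer index < cutoff and coef = 1 for the remaining layers."""
    out = []
    for sample in dlatents:
        new_sample = []
        for layer, w in enumerate(sample):
            coef = psi if layer < cutoff else 1.0
            new_sample.append([avg + (x - avg) * coef
                               for avg, x in zip(dlatent_avg, w)])
        out.append(new_sample)
    return out

dlatent_avg = [0.0, 1.0]                       # toy average face, D=2
dlatents = [[[1.0, 3.0] for _ in range(3)]]    # one sample, 3 layers, D=2

truncated = truncate(dlatents, dlatent_avg, psi=0.5, cutoff=2)
print(truncated)
```

A psi of 1 leaves W unchanged and a psi of 0 collapses the affected layers onto the average face, so psi trades variation against closeness to the average.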

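That completes the overall framework of the generate network. In the next section I will explain G_synthesis and G_mapping in detail, i.e. the two networks highlighted in the screenshots above.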
