The link below leads to all of my personal notes on StyleGAN. If you find any mistakes, please point them out and I will correct them right away. If you are interested, add me on WeChat (a944284742) to discuss the technology. And if this post helped you in any way, be sure to give it a like, because that is the biggest encouragement for me.
Style Transfer 0-00: StyleGAN - table of contents (the most complete ever): https://blog.csdn.net/weixin_43013761/article/details/100895333
To be honest, in past projects I never paid much attention to network architecture, because it always felt like the same toolbox being recombined: convolutions, pooling, deconvolutions, skip connections, and so on. This network, however, is relatively complex, and the source code is not easy to read, so this time I plan to dig into it properly. Let's get started.
From Style Transfer 0-05: StyleGAN source-code walkthrough (1) - framework overview, we know that the networks are built starting from:
def training_loop(
    ...
    G = tflib.Network('G', num_channels=training_set.shape[0], resolution=training_set.shape[1], label_size=training_set.label_size, **G_args)
    D = tflib.Network('D', num_channels=training_set.shape[0], resolution=training_set.shape[1], label_size=training_set.label_size, **D_args)
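How do these two tflib.Network calls turn into concrete graphs? In the official train.py, the build function is selected by a func_name string carried in G_args and D_args; the Network constructor imports that module and calls the named function to build a template graph. Roughly (paraphrasing the repo's train.py from memory):

    # train.py (official StyleGAN repo): build functions are chosen by name.
    G = EasyDict(func_name='training.networks_stylegan.G_style')    # generator build function
    D = EasyDict(func_name='training.networks_stylegan.D_basic')    # discriminator build function
    # tflib.Network('G', ..., **G_args) imports training.networks_stylegan, calls
    # G_style(latents_in, labels_in, ...) once to build a template graph, and keeps
    # the variables it creates.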
This is where the two networks are constructed. For the generator, the call chain eventually reaches the G_style function in training/networks_stylegan.py, annotated below:
#----------------------------------------------------------------------------
# Style-based generator used in the StyleGAN paper.
# Composed of two sub-networks (G_mapping and G_synthesis) that are defined below.

def G_style(
    latents_in,                                 # First input: Latent vectors (Z) [minibatch, latent_size].
    labels_in,                                  # Second input: Conditioning labels [minibatch, label_size].
    truncation_psi        = 0.7,                # Style strength multiplier for the truncation trick. None = disable.
    truncation_cutoff     = 8,                  # Number of layers for which to apply the truncation trick. None = disable.
    truncation_psi_val    = None,               # Value for truncation_psi to use during validation.
    truncation_cutoff_val = None,               # Value for truncation_cutoff to use during validation.
    dlatent_avg_beta      = 0.995,              # Decay for tracking the moving average of W during training. None = disable.
    style_mixing_prob     = 0.9,                # Probability of mixing styles during training. None = disable.
    is_training           = False,              # Network is under training? Enables and disables specific features.
    is_validation         = False,              # Network is under validation? Chooses which value to use for truncation_psi.
    is_template_graph     = False,              # True = template graph constructed by the Network class, False = actual evaluation.
    components            = dnnlib.EasyDict(),  # Container for sub-networks. Retained between calls.
    **kwargs):                                  # Arguments for sub-networks (G_mapping and G_synthesis).

    # Validate arguments.
    assert not is_training or not is_validation
    assert isinstance(components, dnnlib.EasyDict)
    if is_validation:
        truncation_psi = truncation_psi_val
        truncation_cutoff = truncation_cutoff_val
    if is_training or (truncation_psi is not None and not tflib.is_tf_expression(truncation_psi) and truncation_psi == 1):
        truncation_psi = None
    if is_training or (truncation_cutoff is not None and not tflib.is_tf_expression(truncation_cutoff) and truncation_cutoff <= 0):
        truncation_cutoff = None
    if not is_training or (dlatent_avg_beta is not None and not tflib.is_tf_expression(dlatent_avg_beta) and dlatent_avg_beta == 1):
        dlatent_avg_beta = None
    if not is_training or (style_mixing_prob is not None and not tflib.is_tf_expression(style_mixing_prob) and style_mixing_prob <= 0):
        style_mixing_prob = None

    # Setup components.
    # Instantiates G_synthesis.
    if 'synthesis' not in components:
        components.synthesis = tflib.Network('G_synthesis', func_name=G_synthesis, **kwargs)
    num_layers = components.synthesis.input_shape[1]
    dlatent_size = components.synthesis.input_shape[2]
    if 'mapping' not in components:
        components.mapping = tflib.Network('G_mapping', func_name=G_mapping, dlatent_broadcast=num_layers, **kwargs)

    # Setup variables.
    lod_in = tf.get_variable('lod', initializer=np.float32(0), trainable=False)
    # Shape (512,): running average of W.
    dlatent_avg = tf.get_variable('dlatent_avg', shape=[dlatent_size], initializer=tf.initializers.zeros(), trainable=False)

    # Evaluate mapping network.
    # W, broadcast per layer to shape (?, 18, 512).
    dlatents = components.mapping.get_output_for(latents_in, labels_in, **kwargs)

    # Update moving average of W.
    # W follows the distribution of the training data. To keep the generator from
    # producing only images that resemble the high-density regions of that distribution,
    # dlatent_avg is updated continuously during training, yielding the average W
    # of the data distribution (the "average face").
    if dlatent_avg_beta is not None:
        with tf.variable_scope('DlatentAvg'):
            batch_avg = tf.reduce_mean(dlatents[:, 0], axis=0)
            update_op = tf.assign(dlatent_avg, tflib.lerp(batch_avg, dlatent_avg, dlatent_avg_beta))
            with tf.control_dependencies([update_op]):
                dlatents = tf.identity(dlatents)

    # Perform style mixing regularization.
    if style_mixing_prob is not None:
        with tf.name_scope('StyleMix'):
            # Draw latents2 at random with the same shape as the input latents_in (Z).
            latents2 = tf.random_normal(tf.shape(latents_in))
            # Map latents2 together with labels_in (the paper's z) to a second W vector.
            dlatents2 = components.mapping.get_output_for(latents2, labels_in, **kwargs)
            # Index of each layer.
            layer_idx = np.arange(num_layers)[np.newaxis, :, np.newaxis]
            # Number of layers active at the current level of detail (lod).
            cur_layers = num_layers - tf.cast(lod_in, tf.int32) * 2
            # Random mixing cutoff: draw a uniform number; if it falls below
            # style_mixing_prob, pick a random crossover layer, otherwise keep
            # every layer from dlatents.
            mixing_cutoff = tf.cond(
                tf.random_uniform([], 0.0, 1.0) < style_mixing_prob,
                lambda: tf.random_uniform([], 1, cur_layers, dtype=tf.int32),
                lambda: cur_layers)
            # Mix the styles: dlatents is no longer the original dlatents but a
            # per-layer combination of dlatents and dlatents2 chosen by mixing_cutoff.
            dlatents = tf.where(tf.broadcast_to(layer_idx < mixing_cutoff, tf.shape(dlatents)), dlatents, dlatents2)

    # Apply truncation trick. truncation_cutoff defaults to 8.
    if truncation_psi is not None and truncation_cutoff is not None:
        with tf.variable_scope('Truncation'):
            # layer_idx = [[[0], [1], [2], ..., [17]]]
            layer_idx = np.arange(num_layers)[np.newaxis, :, np.newaxis]
            ones = np.ones(layer_idx.shape, dtype=np.float32)
            # coefs has shape [1, 18, 1].
            coefs = tf.where(layer_idx < truncation_cutoff, truncation_psi * ones, ones)
            # Interpolate toward the average face.
            dlatents = tflib.lerp(dlatent_avg, dlatents, coefs)

    # Evaluate synthesis network.
    with tf.control_dependencies([tf.assign(components.synthesis.find_var('lod'), lod_in)]):
        images_out = components.synthesis.get_output_for(dlatents, force_clean_graph=is_template_graph, **kwargs)
    return tf.identity(images_out, name='images_out')
As the function shows, the overall flow is fairly simple. First,

    if 'synthesis' not in components:
        components.synthesis = tflib.Network('G_synthesis', func_name=G_synthesis, **kwargs)

calls G_synthesis to build the synthesis network g of the paper, the part marked by the red box in the figure:
Once it has been created, the next step,

    if 'mapping' not in components:
        components.mapping = tflib.Network('G_mapping', func_name=G_mapping, dlatent_broadcast=num_layers, **kwargs)

builds the G_mapping network, which corresponds to the following structure in the paper:
That is, it maps the input Z vector to the W vector.
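The broadcast to (?, 18, 512) mentioned in the comments above simply repeats the single 512-dim W once per synthesis layer. A minimal NumPy sketch of the idea (mapping_mlp is a purely illustrative stand-in for the real 8-layer fully-connected mapping network):

    import numpy as np

    def mapping_mlp(z):
        # Stand-in for the 8-layer MLP of G_mapping; returns W with the same shape as Z.
        return z

    batch, latent_size, num_layers = 4, 512, 18
    z = np.random.randn(batch, latent_size).astype(np.float32)   # latents_in (Z)
    w = mapping_mlp(z)                                           # (4, 512)
    w_broadcast = np.tile(w[:, np.newaxis], (1, num_layers, 1))  # (4, 18, 512), one W per layer
    print(w_broadcast.shape)                                     # (4, 18, 512)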
The overall details of these two networks are explained in the next section. From the code above we can see that once both networks have been built, the following line appears:

    dlatent_avg = tf.get_variable('dlatent_avg', shape=[dlatent_size], initializer=tf.initializers.zeros(), trainable=False)

It holds the "average face". Its purpose is to counter the uneven distribution of the training data, which makes faces from low-density regions unlikely to be generated. The average face is not static: it is updated at every training step by the following code:
    # Update moving average of W.
    # W follows the distribution of the training data. To keep the generator from
    # producing only images that resemble the high-density regions of that distribution,
    # dlatent_avg is updated continuously during training, yielding the average W
    # of the data distribution (the "average face").
    if dlatent_avg_beta is not None:
        with tf.variable_scope('DlatentAvg'):
            batch_avg = tf.reduce_mean(dlatents[:, 0], axis=0)
            update_op = tf.assign(dlatent_avg, tflib.lerp(batch_avg, dlatent_avg, dlatent_avg_beta))
            with tf.control_dependencies([update_op]):
                dlatents = tf.identity(dlatents)
Its purpose is to make the average face available for the truncation trick later on, which needs it as an input.
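Numerically, this update is a plain exponential moving average: tflib.lerp(a, b, t) computes a + (b - a) * t, so with beta = 0.995 the old average keeps 99.5% of its weight at each step. A minimal NumPy sketch of one update (shapes match the code above):

    import numpy as np

    dlatent_avg_beta = 0.995
    dlatent_avg = np.zeros(512, dtype=np.float32)        # the tracked "average face" W

    def update_dlatent_avg(dlatents, dlatent_avg):
        # dlatents: (batch, 18, 512); like the code above, only layer 0 is averaged.
        batch_avg = dlatents[:, 0].mean(axis=0)          # (512,)
        # lerp(batch_avg, dlatent_avg, beta) = batch_avg + (dlatent_avg - batch_avg) * beta
        return batch_avg + (dlatent_avg - batch_avg) * dlatent_avg_beta

    dlatents = np.random.randn(8, 18, 512).astype(np.float32)
    dlatent_avg = update_dlatent_avg(dlatents, dlatent_avg)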
Next, controlled by style_mixing_prob, style mixing is performed. The principle is simple: in addition to the input latents_in, a second latent latents2 is drawn at random with the same shape; both are passed through the mapping network, producing two W codes, dlatents and dlatents2. A random crossover layer is then chosen, some layers take their style from dlatents and the remaining layers from dlatents2, and the combination becomes the new dlatents.
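A minimal NumPy sketch of that per-layer selection (the cutoff is fixed here for clarity; in the real code it is drawn with tf.random_uniform, and with probability 1 - style_mixing_prob no mixing happens at all):

    import numpy as np

    num_layers = 18
    dlatents  = np.random.randn(4, num_layers, 512).astype(np.float32)  # W from latents_in
    dlatents2 = np.random.randn(4, num_layers, 512).astype(np.float32)  # W from latents2

    layer_idx = np.arange(num_layers)[np.newaxis, :, np.newaxis]        # (1, 18, 1)
    mixing_cutoff = 7                                                   # random in the real code
    # Layers 0..6 keep the style of dlatents, layers 7..17 take it from dlatents2.
    mixed = np.where(layer_idx < mixing_cutoff, dlatents, dlatents2)    # (4, 18, 512)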
After that comes the truncation trick: using the average face obtained earlier, the style-mixed dlatents is interpolated toward it, which is what "truncation" means here.
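Numerically the truncation trick is again just a lerp toward the average face, applied only to the first truncation_cutoff layers (psi < 1 shrinks a layer's W toward the average; a coefficient of 1 leaves it untouched). A sketch with the default values from the code above:

    import numpy as np

    num_layers, truncation_psi, truncation_cutoff = 18, 0.7, 8
    dlatent_avg = np.zeros(512, dtype=np.float32)        # tracked average W ("average face")
    dlatents = np.random.randn(4, num_layers, 512).astype(np.float32)

    layer_idx = np.arange(num_layers)[np.newaxis, :, np.newaxis]           # (1, 18, 1)
    coefs = np.where(layer_idx < truncation_cutoff, truncation_psi, 1.0)   # 0.7 for layers 0..7, else 1.0
    # lerp(dlatent_avg, dlatents, coefs): coef 1.0 keeps dlatents as-is,
    # smaller coefs pull the style toward the average face.
    truncated = dlatent_avg + (dlatents - dlatent_avg) * coefs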
With that, the overall framework of the generator network is covered. In the next section I will walk through G_synthesis and G_mapping in detail, i.e. the two networks in the red boxes of the figures above.
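To close, here is how this whole generator is typically driven end to end, adapted from the repo's pretrained_example.py (the local pickle filename is an assumption; the official script downloads the FFHQ pickle from a URL instead):

    import pickle
    import numpy as np
    import dnnlib.tflib as tflib

    tflib.init_tf()
    # Assumed local copy of the pretrained FFHQ pickle.
    with open('karras2019stylegan-ffhq-1024x1024.pkl', 'rb') as f:
        _G, _D, Gs = pickle.load(f)                      # Gs = long-term average of G

    latents = np.random.randn(1, Gs.input_shape[1])      # Z, shape (1, 512)
    fmt = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True)
    # truncation_psi here is exactly the parameter discussed above.
    images = Gs.run(latents, None, truncation_psi=0.7, randomize_noise=True, output_transform=fmt)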