风格迁移0-05:stylegan-源码无死角解读(1)-框架总览

以下链接是个人关于stylegan所有见解,如有错误欢迎大家指出,我会第一时间纠正,如有兴趣可以加微信:a944284742相互讨论技术。若是帮助到了你什么,一定要记得点赞奥!因为这是对我最大的鼓励。
风格迁移0-05:stylegan-目录-史上最全:https://blog.csdn.net/weixin_43013761/article/details/100895333

代码总览

根据源码中的README.md 我们可以知道,训练的开始,是要再源码的根目录下执行:train.py ,配置好config文件即可运行。我们再改文件中看到一个:

    dnnlib.submit_run(**kwargs)

函数,其参数简单明了,一个kwargs,作者是简单明了了,但是害苦了我们啊,其主要参数如下:
风格迁移0-05:stylegan-源码无死角解读(1)-框架总览_第1张图片
其中submit_run的实现如下:

def submit_run(submit_config: SubmitConfig, run_func_name: str, **run_func_kwargs) -> None:
    """Create a run dir, gather files related to the run, copy files to the run dir, and launch the run in appropriate place."""
    submit_config = copy.copy(submit_config)

    if submit_config.user_name is None:
        submit_config.user_name = get_user_name()

    submit_config.run_func_name = run_func_name
    submit_config.run_func_kwargs = run_func_kwargs

    assert submit_config.submit_target == SubmitTarget.LOCAL
    if submit_config.submit_target in {SubmitTarget.LOCAL}:
        run_dir = _create_run_dir_local(submit_config)

        submit_config.task_name = "{0}-{1:05d}-{2}".format(submit_config.user_name, submit_config.run_id, submit_config.run_desc)
        submit_config.run_dir = run_dir
        _populate_run_dir(run_dir, submit_config)

    if submit_config.print_info:
        print("\nSubmit config:\n")
        pprint.pprint(submit_config, indent=4, width=200, compact=False)
        print()

    if submit_config.ask_confirmation:
        if not util.ask_yes_no("Continue submitting the job?"):
            return

    run_wrapper(submit_config)

简单的来说,再run_wrapper函数之前的内容我们都不需要去理会,其主要的功能就是在results目录(根据配置文件)下生成一个子项目,该子项目保存了你当前的训练配置信息,如果在训练过程中发生了中断,我们可以运行子项目,继续训练。后续如果有时间,会为大家分析

其核心要点函数在于run_wrapper函数,我们看看其参数配置把:

submit_config  
  #运行的根目录, 
 'run_dir_root'  = {str} 'results'
 
 # 其生在'run_dir_root'目录下生成子项目的名称
 'run_desc'  = {str} 'sgan-result-1gpu'
 
 # 应该是忽略拷贝的文件
 'run_dir_ignore'  = {list} <class 'list'>: ['__pycache__', '*.pyproj', '*.sln', '*.suo', '.cache', '.idea', '.vs', '.vscode', 'results', 'datasets', 'cache']
 
 'run_dir_extra_files'  = {NoneType} None
 
 'submit_target'  = {SubmitTarget} SubmitTarget.LOCAL
 
 # 使用GPU的数目
 'num_gpus'  = {int} 1

 # 是否打印信息
 'print_info'  = {bool} False
 
 'ask_confirmation'  = {bool} False
 
 # 其生成子项目名称的前缀,如00002-sgan-result-1gpu前面的00002
 'run_id'  = {int} 11
 
 # 子项目的名称
 'run_name'  = {str} '00011-sgan-result-1gpu'
 
 # 子项目运行,的目录
 'run_dir' = {str} 'results\\00011-sgan-result-1gpu'

 # 训练运行的函数,在training.training_loop.training_loop
 'run_func_name'  = {str} 'training.training_loop.training_loop'
 
 # 运行函数的参数,后续着重分析,就是training.training_loop.training_loop函数的参数
 'run_func_kwargs'  = {dict} <class 'dict'>: {'mirror_augment': True, 'total_kimg': 25000, 'G_args': {'func_name': 'training.networks_stylegan.G_style'}, 'D_args': {'func_name': 'training.networks_stylegan.D_basic'}, 'G_opt_args': {'beta1': 0.0, 'beta2': 0.99, 'epsilon': 1e-08}, 'D_opt_args': {'beta1': 0.0, 'beta2': 0.99, 'epsilon': 1e-08}, 'G_loss_args': {'func_name': 'training.loss.G_logistic_nonsaturating'}, 'D_loss_args': {'func_name': 'training.loss.D_logistic_simplegp', 'r1_gamma': 10.0}, 'dataset_args': {'tfrecord_dir': 'result'}, 'sched_args': {'minibatch_base': 1, 'minibatch_dict': {4: 32, 8: 32, 16: 32, 32: 16, 64: 8, 128: 4, 256: 2, 512: 1}, 'lod_initial_resolution': 8, 'G_lrate_dict': {128: 0.0015, 256: 0.002, 512: 0.003, 1024: 0.003}, 'D_lrate_dict': {128: 0.0015, 256: 0.002, 512: 0.003, 1024: 0.003}}, 'grid_args': {'size': '4k', 'layout': 'random'}, 'metric_arg_list': [{'func_name': 'metrics.frechet_inception_distance.FID', 'name': 'fid50k', 'num_images': 50000, 'minibatch_per_gpu': 8}], 'tf_config':...
 
 # 电脑用户名
 'user_name'  = {str} 'zwh'
 
 # 该次程序运行吗,名称(胡乱理解就行,无关紧要)
 'task_name' ) = {str} 'zwh-00011-sgan-result-1gpu'
 
 # 电脑主机名
 'host_name'  = {str} 'localhost'

看了上面的参数之后,我相信大家可以注意到,其中大部分配置应该都是与子项目相关的,但是要注意的一个参数是’run_func_kwargs’ ,该参数是再后面的调用中传递给training.training_loop.training_loop函数的,我们先进入run_wrapper函数(一些无关要紧的代码我就不注释了):

def run_wrapper(submit_config: SubmitConfig) -> None:
	......
	#这一段都是都是log信息的收集
	......
    import dnnlib
    dnnlib.submit_config = submit_config

    try:
        print("dnnlib: Running {0}() on {1}...".format(submit_config.run_func_name, submit_config.host_name))
        start_time = time.time()
        print('=1' * 50)
        print(submit_config.run_func_name)
        # 通过字符串training.training_loop.training_loop,调用该函数
        util.call_func_by_name(func_name=submit_config.run_func_name, submit_config=submit_config, **submit_config.run_func_kwargs)
        print("dnnlib: Finished {0}() in {1}.".format(submit_config.run_func_name, util.format_time(time.time() - start_time)))
    except:
 		......
 		# 一些打印的处理
 		......

可以明确的知道,其主要核心点在于函数util.call_func_by_name,其通过字符串加载模块,并调用函数,training_loop,现在我们来看看在前面没有注释的:

 'run_func_kwargs'  = {dict} <class 'dict'>: {
	'mirror_augment': True,  
	
	'total_kimg': 25000, 
	
	'G_args': {'func_name': 'training.networks_stylegan.G_style'}, 
	
	'D_args': {'func_name': 'training.networks_stylegan.D_basic'}, 
	
	'G_opt_args': {'beta1': 0.0, 'beta2': 0.99, 'epsilon': 1e-08}, 
	
	'D_opt_args': {'beta1': 0.0, 'beta2': 0.99, 'epsilon': 1e-08}, 
	
	'G_loss_args': {'func_name': 'training.loss.G_logistic_nonsaturating'}, 
	
	'D_loss_args': {'func_name': 'training.loss.D_logistic_simplegp', 'r1_gamma': 10.0}, 
	
	'dataset_args': {'tfrecord_dir': 'result'}, 
	
	'sched_args': {'minibatch_base': 1, 'minibatch_dict': {4: 32, 8: 32, 16: 32, 32: 16, 64: 8, 128: 4, 256: 2, 512: 1}, 
	
	'lod_initial_resolution': 8, 
	
	'G_lrate_dict': {128: 0.0015, 256: 0.002, 512: 0.003, 1024: 0.003}, 
	
	'D_lrate_dict': {128: 0.0015, 256: 0.002, 512: 0.003, 1024: 0.003}}, 
	
	'grid_args': {'size': '4k', 'layout': 'random'},
	
	'metric_arg_list': [{
		'func_name': 'metrics.frechet_inception_distance.FID', 
		'name': 'fid50k', 
		'num_images': 50000, 
		'minibatch_per_gpu': 8}], 

	'tf_config':...
	total_kimg :2500

这样为大家展开,应该还是很清晰明了的,注意,上面的配置,是我本人的配置,并不代表的配置会和我的一样。我们还是来看看training_loop函数,在这之前我们看看其函数参数的意,当作结合上面一起注释了

# discriminators 网络框架
D_args =  {'func_name': 'training.networks_stylegan.D_basic'}

# discriminators 网络的损失函数
D_loss_args = {'func_name': 'training.loss.D_logistic_simplegp', 'r1_gamma': 10.0}

# 应该是 discriminators网络求损失是相关的参数,后续详细了解
D_opt_args = {'beta1': 0.0, 'beta2': 0.99, 'epsilon': 1e-08}

# 重复次数,可能是进行多次鉴别
D_repeats =  1

# 生成网络框架
G_args =  {'func_name': 'training.networks_stylegan.G_style'}

# 生成网络损失
G_loss_args =  {'func_name': 'training.loss.G_logistic_nonsaturating'}

# 生成网络相关的超参数
G_opt_args =  {'beta1': 0.0, 'beta2': 0.99, 'epsilon': 1e-08}

# 平滑,不是很理解,后续在了解
G_smoothing_kimg = 10.0

# 生成训练数据的目录
dataset_args =  {'tfrecord_dir': 'result'}

# 这个也是不是了解
drange_net = [-1, 1]

# 网格,输出图片,4K可能代表4*1024
grid_args =  {'size': '4k', 'layout': 'random'}

# 图片快照,也不是很理解
image_snapshot_ticks =  1

# 对模型进行指标衡量的参数
metric_arg_list =   [{
    'func_name': 'metrics.frechet_inception_distance.FID', 
	'name': 'fid50k', 
	'num_images': 50000, 
	'minibatch_per_gpu': 8}]
# minibatch重复次数
minibatch_repeats =  4

# 镜像翻转
mirror_augment =  True

# 大概是训练10ticks打印一次图片
network_snapshot_ticks =  10

# 不是很理解
reset_opt_for_new_lod =  True

# 这个参数比较重要,大家因为某种原因断了,训练子项目的时候,可以设置为上次断了的时候,训练的图片张数,如4000
resume_kimg = 0.0

# 加载预模型的id,00002-sgan-result-1gpu的前缀,如这里的00002
resume_run_id = } None

resume_snapshot =  None
# 继续训练的时间点
resume_time =  0.0

# 保存模型
save_tf_graph =  False

save_weight_histograms =  False

# 最小的minibatch的基数,以及各个分辨率在生成和鉴别网络的学习率
sched_args =  {'minibatch_base': 1, 'minibatch_dict': {4: 32, 8: 32, 16: 32, 32: 16, 64: 8, 128: 4, 256: 2, 512: 1}, 'lod_initial_resolution': 8, 'G_lrate_dict': {128: 0.0015, 256: 0.002, 512: 0.003, 1024: 0.003}, 'D_lrate_dict': {128: 0.0015, 256: 0.002, 512: 0.003, 1024: 0.003}}

# 子项目的配置
submit_config =  {'run_dir_root': 'results', 'run_desc': 'sgan-result-1gpu', 'run_dir_ignore': ['__pycache__', '*.pyproj', '*.sln', '*.suo', '.cache', '.idea', '.vs', '.vscode', 'results', 'datasets', 'cache'], 'run_dir_extra_files': None, 'submit_target': <SubmitTarget.LOCAL: 1>, 'num_gpus': 1, 'print_info': False, 'ask_confirmation': False, 'run_id': 6, 'run_name': '00006-sgan-result-1gpu', 'run_dir': 'results\\00006-sgan-result-1gpu', 'run_func_name': 'training.training_loop.training_loop', 'run_func_kwargs': {'mirror_augment': True, 'total_kimg': 25000, 'G_args': {'func_name': 'training.networks_stylegan.G_style'}, 'D_args': {'func_name': 'training.networks_stylegan.D_basic'}, 'G_opt_args': {'beta1': 0.0, 'beta2': 0.99, 'epsilon': 1e-08}, 'D_opt_args': {'beta1': 0.0, 'beta2': 0.99, 'epsilon': 1e-08}, 'G_loss_args': {'func_name': 'training.loss.G_logistic_nonsaturating'}, 'D_loss_args': {'func_name': 'training.loss.D_logistic_simplegp', 'r1_gamma': 10.0}, 'dataset_args': {'tfrecord_dir...

# 随机种子
tf_config =  {'rnd.np_random_seed': 1000}

# 训练数据总的图片张数
total_kimg =  25000

在源码中,也有对函数参数的详细注解,大家可以结合一起参考,有错误的地方欢迎提出,本人好进行修改。下面是函数总体框架的注解:

def training_loop(
    #print('sched_args: ',sched_args)
    # Initialize dnnlib and TensorFlow.
    # 根据子项目对初始化一些基本配置
    ctx = dnnlib.RunContext(submit_config, train)
    tflib.init_tf(tf_config)

    # Load training set.
    # 加载训练数据,其会把所有分辨率的数据都加载进来
    training_set = dataset.load_dataset(data_dir=config.data_dir, verbose=True, **dataset_args)

    # Construct networks.,
    # 如果指定了resume_run_id,则加其中的预训练模型,如果没有则从零开始训练。该处为核心重点,后续仔细分析
    with tf.device('/gpu:0'):
        if resume_run_id is not None:
            network_pkl = misc.locate_network_pkl(resume_run_id, resume_snapshot)
            print('Loading networks from "%s"...' % network_pkl)
            G, D, Gs = misc.load_pkl(network_pkl)
        else:
            print('Constructing networks...')
            G = tflib.Network('G', num_channels=training_set.shape[0], resolution=training_set.shape[1], label_size=training_set.label_size, **G_args)
            D = tflib.Network('D', num_channels=training_set.shape[0], resolution=training_set.shape[1], label_size=training_set.label_size, **D_args)

            # 如果有多个GPU存在,其会其多个GPU权重的平均值。可以理解为,专门用来保存权重的
            Gs = G.clone('Gs')
    G.print_layers(); D.print_layers()

    print('Building TensorFlow graph...')
    with tf.name_scope('Inputs'), tf.device('/cpu:0'):
        # 图片分辨的,以2的多少次方进行输入,就是我们训练数据的2,3,4,5,6,7,......
        lod_in          = tf.placeholder(tf.float32, name='lod_in', shape=[])
        # 学习率
        lrate_in        = tf.placeholder(tf.float32, name='lrate_in', shape=[])
        # 输入minibatch数目
        minibatch_in    = tf.placeholder(tf.int32, name='minibatch_in', shape=[])
        # 每个GPU训练的批次大小
        minibatch_split = minibatch_in // submit_config.num_gpus
        # 这个参数也比较奇怪,后续分析内部代码时讲解
        Gs_beta         = 0.5 ** tf.div(tf.cast(minibatch_in, tf.float32), G_smoothing_kimg * 1000.0) if G_smoothing_kimg > 0.0 else 0.0

    # 对网络进行优化,应该包含了损失函数在里面
    G_opt = tflib.Optimizer(name='TrainG', learning_rate=lrate_in, **G_opt_args)
    D_opt = tflib.Optimizer(name='TrainD', learning_rate=lrate_in, **D_opt_args)
    for gpu in range(submit_config.num_gpus):
        with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu):
            # 为每个个GPU拷贝一份
            G_gpu = G if gpu == 0 else G.clone(G.name + '_shadow')
            D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow')
            lod_assign_ops = [tf.assign(G_gpu.find_var('lod'), lod_in), tf.assign(D_gpu.find_var('lod'), lod_in)]
            # 获得训练数据图片和标签
            reals, labels = training_set.get_minibatch_tf()
            # 对训练数据的真实图片进行处理,主要把图片分成多个区域进行平滑,注意这里的reals包含多张图片,分别对应不同的分辨率,
            # 其实这里说是分辨率不太合适,总的来说,他们分辨率都是1024,但是平滑插值不一样.其不是用来训练的数据,是用来求损失用的,具体细节后面分析,也属于一个比较重要的地方
            reals = process_reals(reals, lod_in, mirror_augment, training_set.dynamic_range, drange_net)
            with tf.name_scope('G_loss'), tf.control_dependencies(lod_assign_ops):
                G_loss = dnnlib.util.call_func_by_name(G=G_gpu, D=D_gpu, opt=G_opt, training_set=training_set, minibatch_size=minibatch_split, **G_loss_args)
            with tf.name_scope('D_loss'), tf.control_dependencies(lod_assign_ops):
                D_loss = dnnlib.util.call_func_by_name(G=G_gpu, D=D_gpu, opt=D_opt, training_set=training_set, minibatch_size=minibatch_split, reals=reals, labels=labels, **D_loss_args)

            # 注册梯度下降求损失的方法
            G_opt.register_gradients(tf.reduce_mean(G_loss), G_gpu.trainables)
            D_opt.register_gradients(tf.reduce_mean(D_loss), D_gpu.trainables)
    # 反向传播需要的op
    G_train_op = G_opt.apply_updates()
    D_train_op = D_opt.apply_updates()

    # 计算权重平均值
    Gs_update_op = Gs.setup_as_moving_average_of(G, beta=Gs_beta)
    with tf.device('/gpu:0'):
        try:
            peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse()
        except tf.errors.NotFoundError:
            peak_gpu_mem_op = tf.constant(0)

    print('Setting up snapshot image grid...')
    # 在保存图片的时候,其一次会保存很多张,这个是其相关的设置,每一张图片看成一个网格
    grid_size, grid_reals, grid_labels, grid_latents = misc.setup_snapshot_image_grid(G, training_set, **grid_args)
    # 训练安排,如配置目前是训练了多少张,还有使用几个gpu等等,该函数会在随着训练图片的张数,被多次调用,其中会改变sched.lod_in参数。
    sched = training_schedule(cur_nimg=total_kimg*1000, training_set=training_set, num_gpus=submit_config.num_gpus, **sched_args)

    # 进行一次训练,输出的是(28,3,1024,1024),从保存的图片的结果可以知道
    grid_fakes = Gs.run(grid_latents, grid_labels, is_validation=True, minibatch_size=sched.minibatch//submit_config.num_gpus)

    print('Setting up run dir...')
    # 把图片保存到子项目根目录
    misc.save_image_grid(grid_reals, os.path.join(submit_config.run_dir, 'reals.png'), drange=training_set.dynamic_range, grid_size=grid_size)
    misc.save_image_grid(grid_fakes, os.path.join(submit_config.run_dir, 'fakes%06d.png' % resume_kimg), drange=drange_net, grid_size=grid_size)

    # log的收集
    summary_log = tf.summary.FileWriter(submit_config.run_dir)
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        G.setup_weight_histograms(); D.setup_weight_histograms()
    metrics = metric_base.MetricGroup(metric_arg_list)


    print('Training...\n')
    # 更改训练图片的张数
    ctx.update('', cur_epoch=resume_kimg, max_epoch=total_kimg)
    maintenance_time = ctx.get_last_update_interval()
    cur_nimg = int(resume_kimg * 1000)
    # 当前训练的tick数从零开始
    cur_tick = 0
    tick_start_nimg = cur_nimg
    prev_lod = -1.0
    while cur_nimg < total_kimg * 1000:
        if ctx.should_stop(): break

        # Choose training parameters and configure training ops.
        # 主要根据cur_nimg,更改sched.lod参数
        sched = training_schedule(cur_nimg=cur_nimg, training_set=training_set, num_gpus=submit_config.num_gpus, **sched_args)
        training_set.configure(sched.minibatch // submit_config.num_gpus, sched.lod)

        # 如果设置了该参数,sched.lod会变成2把,总的来说,生成图片的样子会很平滑,很模糊。
        if reset_opt_for_new_lod:
            if np.floor(sched.lod) != np.floor(prev_lod) or np.ceil(sched.lod) != np.ceil(prev_lod):
                G_opt.reset_optimizer_state(); D_opt.reset_optimizer_state()
        prev_lod = sched.lod

        # Run training ops.
        # 经过上面一次生成器的迭代,对生成器进行多次迭代优化
        for _mb_repeat in range(minibatch_repeats):
            for _D_repeat in range(D_repeats):
                tflib.run([D_train_op, Gs_update_op], {lod_in: sched.lod, lrate_in: sched.D_lrate, minibatch_in: sched.minibatch})
                cur_nimg += sched.minibatch
            tflib.run([G_train_op], {lod_in: sched.lod, lrate_in: sched.G_lrate, minibatch_in: sched.minibatch})

        # 下面都是信息的打印,以及快照图片保存,就不进行介绍了
        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000)
        if cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done:
            cur_tick += 1
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = ctx.get_time_since_last_update()
            total_time = ctx.get_time_since_start() + resume_time

            # Report progress.
            print('tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %-6.1f gpumem %-4.1f' % (
                autosummary('Progress/tick', cur_tick),
                autosummary('Progress/kimg', cur_nimg / 1000.0),
                autosummary('Progress/lod', sched.lod),
                autosummary('Progress/minibatch', sched.minibatch),
                dnnlib.util.format_time(autosummary('Timing/total_sec', total_time)),
                autosummary('Timing/sec_per_tick', tick_time),
                autosummary('Timing/sec_per_kimg', tick_time / tick_kimg),
                autosummary('Timing/maintenance_sec', maintenance_time),
                autosummary('Resources/peak_gpu_mem_gb', peak_gpu_mem_op.eval() / 2**30)))
            autosummary('Timing/total_hours', total_time / (60.0 * 60.0))
            autosummary('Timing/total_days', total_time / (24.0 * 60.0 * 60.0))

            # Save snapshots.
            if cur_tick % image_snapshot_ticks == 0 or done:
                grid_fakes = Gs.run(grid_latents, grid_labels, is_validation=True, minibatch_size=sched.minibatch//submit_config.num_gpus)
                misc.save_image_grid(grid_fakes, os.path.join(submit_config.run_dir, 'fakes%06d.png' % (cur_nimg // 1000)), drange=drange_net, grid_size=grid_size)
            if cur_tick % network_snapshot_ticks == 0 or done or cur_tick == 1:
                pkl = os.path.join(submit_config.run_dir, 'network-snapshot-%06d.pkl' % (cur_nimg // 1000))
                misc.save_pkl((G, D, Gs), pkl)
                metrics.run(pkl, run_dir=submit_config.run_dir, num_gpus=submit_config.num_gpus, tf_config=tf_config)

            # Update summaries and RunContext.
            metrics.update_autosummaries()
            tflib.autosummary.save_summaries(summary_log, cur_nimg)
            ctx.update('%.2f' % sched.lod, cur_epoch=cur_nimg // 1000, max_epoch=total_kimg)
            maintenance_time = ctx.get_last_update_interval() - tick_time

    # Write final results.
    # 保存最后的模型
    misc.save_pkl((G, D, Gs), os.path.join(submit_config.run_dir, 'network-final.pkl'))
    summary_log.close()

    ctx.close()

上面是训练框架的总体预览,因为注解勉强还过的去,就不做总结了,后面会详细分析其内部的每一个细节。

你可能感兴趣的:(风格迁移,style-gan,图片生成,图片融合)