NVlabs/noise2noise Code (Part 3): Walkthrough of the Network Training Code

Goal: understand the network training code so that it is easier to modify later.

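Contents

1. Changing the Number of Iterations

1.1 Where It Is Defined

Where train_config Is Initialized

Where EasyDict Is Defined

1.2 How to Change the Number of Iterations

2. Network Architecture

2.1 Original Network Architecture and Code Walkthrough

2.2 How autoencoder Is Called During Training

3. Call Chain of the Training Functions

3.1 From config.py to submit_run

3.2 run_wrapper in submit.py

3.3 Calling train.py

3.4 Printed Trace Messages

4. The Training Function

4.1 Input Parameters and Types

4.2 Initialization

4.3 Building the Network

4.4 Applying Gradient Updates and Iterating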



1. Changing the Number of Iterations

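The default number of iterations is large, so a single run takes a long time and is hard to debug. For debugging we make it much smaller, which can be done directly in def train(args) under if __name__ == "__main__": in config.py.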

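1.1 Where It Is Defined

The values are set in config.py; making them smaller here is enough to shorten the run time. The original values are kept in the commented-out lines: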

if __name__ == "__main__":
    def train(args):
        if args:
            n2n = args.noise2noise if 'noise2noise' in args else True
            train_config.noise2noise = n2n
            if 'long_train' in args and args.long_train:
                #train_config.iteration_count = 500000
                train_config.iteration_count = 500
                #train_config.eval_interval = 5000
                train_config.eval_interval = 50
                train_config.ramp_down_perc = 0.5

Where train_config Is Initialized

train_config = dnnlib.EasyDict(
    iteration_count=300000,
    eval_interval=1000,
    minibatch_size=4,
    run_func_name="train.train",
    learning_rate=0.0003,
    ramp_down_perc=0.3,
    noise=gaussian_noise_config,
#    noise=poisson_noise_config,
    noise2noise=True,
    train_tfrecords='datasets/imagenet_val_raw.tfrecords',
    validation_config=default_validation_config
)

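As shown, train_config is a dnnlib.EasyDict. It works like a dictionary but is more convenient to use, because its entries can also be read and written as attributes (for example train_config.iteration_count).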

Where EasyDict Is Defined

Both train_config and the validation config are initialized with dnnlib.EasyDict. The class is defined in dnnlib/util.py:

from typing import Any

class EasyDict(dict):
    """Convenience class that behaves like a dict but allows access with the attribute syntax."""

    def __getattr__(self, name: str) -> Any:
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name)

    def __setattr__(self, name: str, value: Any) -> None:
        self[name] = value

    def __delattr__(self, name: str) -> None:
        del self[name]
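A quick self-contained check of how EasyDict behaves. The class is copied from the listing above (with the typing import it needs); the config values are sample values taken from train_config.

```python
from typing import Any


class EasyDict(dict):
    """Dict that also allows attribute-style access (copy of dnnlib.util.EasyDict)."""

    def __getattr__(self, name: str) -> Any:
        try:
            return self[name]
        except KeyError:
            # Re-raise as AttributeError so hasattr()/getattr() behave normally
            raise AttributeError(name)

    def __setattr__(self, name: str, value: Any) -> None:
        self[name] = value

    def __delattr__(self, name: str) -> None:
        del self[name]


cfg = EasyDict(iteration_count=300000, eval_interval=1000)
cfg.iteration_count = 500        # attribute write goes into the underlying dict
cfg.ramp_down_perc = 0.5         # new keys can be added the same way
print(cfg["iteration_count"], cfg.eval_interval)   # 500 1000
print(hasattr(cfg, "noise"))                       # False
print(dict(**cfg))               # still a plain dict, so ** unpacking works
```

Because an EasyDict is still a dict, it can be expanded with **, which is exactly how train_config is handed to submit_run later on.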

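1.2 How to Change the Number of Iterations

Change the number of iterations, the interval between evaluations and progress output (eval_interval), and ramp_down_perc. Note that ramp_down_perc is not an output slope: it is the fraction of the training iterations over which the learning rate is ramped down (it is passed to compute_ramped_down_lrate in train.py). These assignments sit inside the long_train branch shown above; the config printed in the log in section 2.2 has ramp_down_perc = 0.5 instead of the default 0.3, so that branch was indeed taken.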

                train_config.iteration_count = 100
                train_config.eval_interval = 10
                train_config.ramp_down_perc = 0.5
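To see what eval_interval controls, here is a small sketch of the schedule used by the training loop in train.py (for i in range(iteration_count): if i % eval_interval == 0: ...). It only lists the iterations at which evaluation and printing happen; the helper name is made up for illustration. Because the check sits at the top of each iteration, before the training step, the last evaluation in a 100-iteration run is at iteration 90, which matches the log in section 2.2.

```python
def eval_iterations(iteration_count: int, eval_interval: int) -> list:
    """Iterations at which the loop in train.py evaluates and prints (i % eval_interval == 0)."""
    return [i for i in range(iteration_count) if i % eval_interval == 0]


sched = eval_iterations(100, 10)
print(sched)                                  # [0, 10, 20, 30, 40, 50, 60, 70, 80, 90]
print(len(eval_iterations(300000, 1000)))     # 300 evaluations with the default config
```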

With 500 iterations, the final PSNR was 27.34 dB and the run took 2m28s.

With 100 iterations, the final PSNR was 23.42 dB and the run took 1m48s.
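For reading these numbers: PSNR compares a denoised image with the clean reference through their mean squared error. The following is a generic sketch that assumes 8-bit pixel values with a peak of 255; the exact scaling used by ValidationSet.evaluate in train.py is not shown in this post.

```python
import math


def psnr(reference, estimate, peak: float = 255.0) -> float:
    """Peak signal-to-noise ratio in dB: 10 * log10(peak^2 / MSE).

    Assumption: pixels on the 8-bit scale, so peak = 255.
    """
    if len(reference) != len(estimate):
        raise ValueError("inputs must have the same number of pixels")
    mse = sum((r - e) ** 2 for r, e in zip(reference, estimate)) / len(reference)
    if mse == 0:
        return math.inf          # identical images
    return 10.0 * math.log10(peak * peak / mse)


# Every pixel off by 10 -> MSE = 100 -> about 28.13 dB
print(round(psnr([100, 120, 140], [110, 130, 150]), 2))
```

Because PSNR is logarithmic, the gap between 27.34 dB and 23.42 dB (about 3.9 dB) corresponds to roughly a 2.5x lower MSE.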

2. Network Architecture

2.1 Original Network Architecture and Code Walkthrough

The original network can be viewed as a U-Net:

    skips = [x]

    n = x
    n = conv_lr('enc_conv0', n, 48)
    n = conv_lr('enc_conv1', n, 48)
    n = maxpool2d(n)
    skips.append(n)

    n = conv_lr('enc_conv2', n, 48)
    n = maxpool2d(n)
    skips.append(n)

    n = conv_lr('enc_conv3', n, 48)
    n = maxpool2d(n)
    skips.append(n)

    n = conv_lr('enc_conv4', n, 48)
    n = maxpool2d(n)
    skips.append(n)

    n = conv_lr('enc_conv5', n, 48)
    n = maxpool2d(n)
    n = conv_lr('enc_conv6', n, 48)

    #-----------------------------------------------
    n = upscale2d(n)
    n = tf.concat([n, skips.pop()], axis=1)
    n = conv_lr('dec_conv5', n, 96)
    n = conv_lr('dec_conv5b', n, 96)

    n = upscale2d(n)
    n = tf.concat([n, skips.pop()], axis=1)
    n = conv_lr('dec_conv4', n, 96)
    n = conv_lr('dec_conv4b', n, 96)

    n = upscale2d(n)
    n = tf.concat([n, skips.pop()], axis=1)
    n = conv_lr('dec_conv3', n, 96)
    n = conv_lr('dec_conv3b', n, 96)

    n = upscale2d(n)
    n = tf.concat([n, skips.pop()], axis=1)
    n = conv_lr('dec_conv2', n, 96)
    n = conv_lr('dec_conv2b', n, 96)

    n = upscale2d(n)
    n = tf.concat([n, skips.pop()], axis=1)
    n = conv_lr('dec_conv1a', n, 64)
    n = conv_lr('dec_conv1b', n, 32)

    n = conv('dec_conv1', n, 3, gain=1.0)

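The structure is shown in the figure below; the exact channel counts and feature-map sizes should be read from the code above.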

[Figure 1: U-Net structure of the network]

So the network structure is the one defined in the autoencoder function (in network.py).
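To double-check the channel and size bookkeeping, here is a pure-Python sketch that walks through the same sequence of layers as autoencoder() and counts the parameters of each 3x3 convolution (weights plus one bias per output channel). It assumes a 3-channel 256x256 input and that every conv keeps the spatial size, as in the layer table printed in section 2.2; the helper names are made up for illustration.

```python
def conv_params(cin: int, cout: int, k: int = 3) -> int:
    """Weights (k*k*cin*cout) plus one bias per output channel."""
    return k * k * cin * cout + cout


def trace_autoencoder(in_ch: int = 3, size: int = 256):
    """Return (name, params, (C, H, W)) for every conv, following autoencoder().

    Assumption: 'same'-padded 3x3 convs, so only maxpool/upscale change H and W.
    """
    layers = []
    ch, s = in_ch, size
    skips = [(ch, s)]                      # skips = [x]

    def conv(name: str, cout: int) -> None:
        nonlocal ch
        layers.append((name, conv_params(ch, cout), (cout, s, s)))
        ch = cout

    # Encoder: two convs at full resolution, then conv + maxpool stages
    conv("enc_conv0", 48)
    conv("enc_conv1", 48)
    s //= 2
    skips.append((ch, s))                  # maxpool2d, then push the skip
    for i in (2, 3, 4):
        conv(f"enc_conv{i}", 48)
        s //= 2
        skips.append((ch, s))
    conv("enc_conv5", 48)
    s //= 2                                # pooled, but NOT pushed onto skips
    conv("enc_conv6", 48)

    # Decoder: upscale2d, concat the matching skip on the channel axis, two convs
    for i in (5, 4, 3, 2):
        s *= 2
        skip_ch, skip_s = skips.pop()
        assert skip_s == s, "skip resolution must match"
        ch += skip_ch
        conv(f"dec_conv{i}", 96)
        conv(f"dec_conv{i}b", 96)
    s *= 2
    skip_ch, _ = skips.pop()               # the input image x itself
    ch += skip_ch
    conv("dec_conv1a", 64)
    conv("dec_conv1b", 32)
    conv("dec_conv1", 3)
    return layers


layers = trace_autoencoder()
for name, params, shape in layers:
    print(f"{name:12s} {params:8d} {shape}")
print("Total", sum(p for _, p, _ in layers))
```

The 144-channel inputs of dec_conv4/3/2 and the 99-channel input of dec_conv1a come from the concatenations (96 + 48 and 96 + 3), which is what the WeightShape column of the printed layer table shows, and the total agrees with the 991203 reported there.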

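2.2 How autoencoder Is Called During Training

To see when autoencoder is called, add print statements directly inside it (and in the other functions along the call chain). The lines starting with ---------- in the output below are these added messages. The autoencoder message appears when the network is constructed, when the training graph is built, and several more times once training starts.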

(n2n) jcx@smart-dsp:~/Desktop/NVlabs_noise2noise$ CUDA_VISIBLE_DEVICES=0 python config.py --desc='-test' train --train-tfrecords=datasets/part_bsd300.tfrecords --long-train=false --noise=gaussian
----------train in config.py
----------Iteration count is 100 and eval_interval is 10
{'iteration_count': 100, 'eval_interval': 10, 'minibatch_size': 4, 'run_func_name': 'train.train', 'learning_rate': 0.0003, 'ramp_down_perc': 0.5, 'noise': {'func_name': 'train.AugmentGaussian', 'train_stddev_rng_range': (0.0, 50.0), 'validation_stddev': 25.0}, 'noise2noise': True, 'train_tfrecords': 'datasets/part_bsd300.tfrecords', 'validation_config': {'dataset_dir': 'datasets/kodak'}}
----------submit_run in submit.py
----------submit_config.submit_target in {SubmitTarget.LOCAL},create new dir to run
Creating the run dir: results/00010-autoencoder-test-n2n
----------_populate_run_dir function in submit.py. Copying files to the run dir
----------run_wrapper function in submit.py
dnnlib: Running train.train() on localhost...
----------train in train.py
----------Setting up dataset source from datasets/part_bsd300.tfrecords
----------net = tflib.Network(**config.net_config) in train.py
-------------autoencoder in network.py

autoencoder                 Params      OutputShape             WeightShape
---                         ---         ---                     ---
x                           -           (?, 3, 256, 256)        -
enc_conv0                   1344        (?, 48, 256, 256)       (3, 3, 3, 48)
enc_conv1                   20784       (?, 48, 256, 256)       (3, 3, 48, 48)
MaxPool                     -           (?, 48, 128, 128)       -
enc_conv2                   20784       (?, 48, 128, 128)       (3, 3, 48, 48)
MaxPool_1                   -           (?, 48, 64, 64)         -
enc_conv3                   20784       (?, 48, 64, 64)         (3, 3, 48, 48)
MaxPool_2                   -           (?, 48, 32, 32)         -
enc_conv4                   20784       (?, 48, 32, 32)         (3, 3, 48, 48)
MaxPool_3                   -           (?, 48, 16, 16)         -
enc_conv5                   20784       (?, 48, 16, 16)         (3, 3, 48, 48)
MaxPool_4                   -           (?, 48, 8, 8)           -
enc_conv6                   20784       (?, 48, 8, 8)           (3, 3, 48, 48)
Upscale2D                   -           (?, 48, 16, 16)         -
dec_conv5                   83040       (?, 96, 16, 16)         (3, 3, 96, 96)
dec_conv5b                  83040       (?, 96, 16, 16)         (3, 3, 96, 96)
Upscale2D_1                 -           (?, 96, 32, 32)         -
dec_conv4                   124512      (?, 96, 32, 32)         (3, 3, 144, 96)
dec_conv4b                  83040       (?, 96, 32, 32)         (3, 3, 96, 96)
Upscale2D_2                 -           (?, 96, 64, 64)         -
dec_conv3                   124512      (?, 96, 64, 64)         (3, 3, 144, 96)
dec_conv3b                  83040       (?, 96, 64, 64)         (3, 3, 96, 96)
Upscale2D_3                 -           (?, 96, 128, 128)       -
dec_conv2                   124512      (?, 96, 128, 128)       (3, 3, 144, 96)
dec_conv2b                  83040       (?, 96, 128, 128)       (3, 3, 96, 96)
Upscale2D_4                 -           (?, 96, 256, 256)       -
dec_conv1a                  57088       (?, 64, 256, 256)       (3, 3, 99, 64)
dec_conv1b                  18464       (?, 32, 256, 256)       (3, 3, 64, 32)
dec_conv1                   867         (?, 3, 256, 256)        (3, 3, 32, 3)
---                         ---         ---                     ---
Total                       991203

----------Building TensorFlow graph...
-------------autoencoder in network.py
----------train_step = opt.apply_updates() in train.py
----------Training...
-------------autoencoder in network.py
-------------autoencoder in network.py
-------------autoencoder in network.py
Average PSNR: 7.15
iter 0          time 6s           sec/eval 0.0     sec/iter 0.00    maintenance 6.5
Average PSNR: 18.83
iter 10         time 37s          sec/eval 1.7     sec/iter 0.17    maintenance 28.3
Average PSNR: 19.80
iter 20         time 45s          sec/eval 1.1     sec/iter 0.11    maintenance 7.4
Average PSNR: 21.02
iter 30         time 53s          sec/eval 1.0     sec/iter 0.10    maintenance 7.2
Average PSNR: 21.56
iter 40         time 1m 01s       sec/eval 1.1     sec/iter 0.11    maintenance 7.1
Average PSNR: 22.35
iter 50         time 1m 09s       sec/eval 1.1     sec/iter 0.11    maintenance 6.9
Average PSNR: 21.74
iter 60         time 1m 18s       sec/eval 1.1     sec/iter 0.11    maintenance 7.5
Average PSNR: 22.93
iter 70         time 1m 26s       sec/eval 1.1     sec/iter 0.11    maintenance 6.6
Average PSNR: 22.64
iter 80         time 1m 34s       sec/eval 1.0     sec/iter 0.10    maintenance 7.7
Average PSNR: 23.02
iter 90         time 1m 43s       sec/eval 1.0     sec/iter 0.10    maintenance 7.7
Elapsed time: 1m 52s
dnnlib: Finished train.train() in 1m 53s.
----------try in run wrapper finally handle _finished.txt

3. Call Chain of the Training Functions

[Figure 2: call chain of the training functions]

3.1 From config.py to submit_run

config.py contains the entry point. For training, it fills in the relevant parameters and then calls the train function:

    submit_config.run_desc = desc + args.desc
    if args.run_dir_root is not None:
        submit_config.run_dir_root = args.run_dir_root
    if args.command is not None:
        args.func(args)
    else:
        # Train if no subcommand was given
        train(args)

The key statement in train (in config.py) is:

dnnlib.submission.submit.submit_run(submit_config, **train_config)

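The arguments train_config and submit_config are the settings defined above. train_config holds the training-related parameters, while submit_config (a dnnlib.SubmitConfig, handled in submit.py) controls how the run is carried out: creating the run directory, writing results, and so on.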

#train config
train_config = dnnlib.EasyDict(
    iteration_count=300000,
    eval_interval=1000,
    minibatch_size=4,
    run_func_name="train.train",
    learning_rate=0.0003,
    ramp_down_perc=0.3,
    noise=gaussian_noise_config,
#    noise=poisson_noise_config,
    noise2noise=True,
    train_tfrecords='datasets/imagenet_val_raw.tfrecords',
    validation_config=default_validation_config
)

#submit config
submit_config = dnnlib.SubmitConfig()
submit_config.run_dir_root = 'results'
submit_config.run_dir_ignore += ['datasets', 'results']

desc = "autoencoder"
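One detail worth noting is why train_config contains run_func_name at all. Assuming submit_run is declared as submit_run(submit_config, run_func_name, **run_func_kwargs) (check dnnlib/submission/submit.py to confirm), ** unpacking peels run_func_name off train_config and everything else becomes the keyword arguments that are later forwarded to train.train. The stand-in below is hypothetical and only mimics that argument handling with plain dicts.

```python
def submit_run_sketch(submit_config: dict, run_func_name: str, **run_func_kwargs) -> dict:
    """Hypothetical stand-in for submit_run: records what would be called and with what.

    Assumed signature: submit_run(submit_config, run_func_name, **run_func_kwargs).
    """
    submit_config["run_func_name"] = run_func_name      # e.g. "train.train"
    submit_config["run_func_kwargs"] = run_func_kwargs  # the rest of train_config
    return submit_config


train_config = {
    "iteration_count": 300000,
    "eval_interval": 1000,
    "minibatch_size": 4,
    "run_func_name": "train.train",
    "learning_rate": 0.0003,
}
result = submit_run_sketch({"run_dir_root": "results"}, **train_config)
print(result["run_func_name"])              # train.train
print(sorted(result["run_func_kwargs"]))    # run_func_name is not among these keys
```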

3.2 run_wrapper in submit.py

submit_run creates the run directory and runs the target function inside it. The core of this step is run_wrapper, which executes the wrapped function.

Core statements:

    try:
        print("dnnlib: Running {0}() on {1}...".format(submit_config.run_func_name, submit_config.host_name))
        start_time = time.time()
        util.call_func_by_name(func_name=submit_config.run_func_name, submit_config=submit_config, **submit_config.run_func_kwargs)
        print("dnnlib: Finished {0}() in {1}.".format(submit_config.run_func_name, util.format_time(time.time() - start_time)))
    except:
        if is_local:
            raise
        else:
            traceback.print_exc()

            log_src = os.path.join(submit_config.run_dir, "log.txt")
            log_dst = os.path.join(get_path_from_template(submit_config.run_dir_root), "{0}-error.txt".format(submit_config.run_name))
            shutil.copyfile(log_src, log_dst)
    finally:
        open(os.path.join(submit_config.run_dir, "_finished.txt"), "w").close()
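The try/except/finally structure above means that _finished.txt is written whether the run succeeds or fails, while an error is re-raised only for local runs. Below is a simplified, hypothetical imitation of that control flow; it uses temporary directories, catches Exception rather than using a bare except, and leaves out the log copying.

```python
import os
import tempfile
import traceback


def run_wrapper_sketch(run_dir: str, func, is_local: bool = True) -> None:
    """Hypothetical imitation of the try/except/finally flow in run_wrapper."""
    try:
        func()
    except Exception:
        if is_local:
            raise                  # local run: let the error propagate
        traceback.print_exc()      # remote run: only log it (the real code also copies log.txt)
    finally:
        # Runs on success and on failure alike
        open(os.path.join(run_dir, "_finished.txt"), "w").close()


def failing_job():
    raise RuntimeError("simulated training failure")


with tempfile.TemporaryDirectory() as run_dir:
    try:
        run_wrapper_sketch(run_dir, failing_job, is_local=True)
    except RuntimeError:
        pass                       # the error reached the caller...
    finished = os.path.exists(os.path.join(run_dir, "_finished.txt"))
    print(finished)                # ...but _finished.txt was still written

with tempfile.TemporaryDirectory() as run_dir2:
    run_wrapper_sketch(run_dir2, failing_job, is_local=False)   # error is only printed
    finished_remote = os.path.exists(os.path.join(run_dir2, "_finished.txt"))
```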

The function itself is invoked through call_func_by_name:

util.call_func_by_name(func_name=submit_config.run_func_name, submit_config=submit_config, **submit_config.run_func_kwargs)

3.3 Calling train.py

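call_func_by_name resolves the function from its dotted name (here "train.train", i.e. train() in train.py) and calls it with the given arguments: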

def call_func_by_name(*args, func_name: str = None, **kwargs) -> Any:
    """Finds the python object with the given name and calls it as a function."""
    assert func_name is not None
    func_obj = get_obj_by_name(func_name)
    assert callable(func_obj)
    return func_obj(*args, **kwargs)
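get_obj_by_name is not listed above. Here is a simplified sketch of the idea, assuming the last dot always separates the module from the attribute (the dnnlib helper is more general and tries several module/attribute splits). It is enough to show how "train.train" turns into a call to train() in train.py, and how call_func_by_name(**noise) in train.py can build the noise augmenter from the func_name key of the noise config. The _sketch names are hypothetical.

```python
import importlib
from typing import Any


def get_obj_by_name_sketch(name: str) -> Any:
    """Simplified lookup: "pkg.module.attr" -> getattr(import_module("pkg.module"), "attr")."""
    module_name, _, attr_name = name.rpartition(".")
    module = importlib.import_module(module_name)
    return getattr(module, attr_name)


def call_func_by_name_sketch(*args, func_name: str = None, **kwargs) -> Any:
    """Same shape as dnnlib.util.call_func_by_name, built on the simplified lookup."""
    assert func_name is not None
    func_obj = get_obj_by_name_sketch(func_name)
    assert callable(func_obj)
    return func_obj(*args, **kwargs)


# Positional arguments are forwarded unchanged
print(call_func_by_name_sketch(16.0, func_name="math.sqrt"))      # 4.0

# A config dict with a func_name key can be expanded with **, like
# call_func_by_name(**noise) in train.py; the remaining keys become kwargs
cfg = {"func_name": "datetime.timedelta", "minutes": 2, "seconds": 28}
print(call_func_by_name_sketch(**cfg).total_seconds())             # 148.0
```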

3.4 Printed Trace Messages

The printed messages confirm the nesting and call order described above:

----------train in config.py
----------Iteration count is 100 and eval_interval is 10
{'iteration_count': 100, 'eval_interval': 10, 'minibatch_size': 4, 'run_func_name': 'train.train', 'learning_rate': 0.0003, 'ramp_down_perc': 0.5, 'noise': {'func_name': 'train.AugmentGaussian', 'train_stddev_rng_range': (0.0, 50.0), 'validation_stddev': 25.0}, 'noise2noise': True, 'train_tfrecords': 'datasets/part_bsd300.tfrecords', 'validation_config': {'dataset_dir': 'datasets/kodak'}}
----------submit_run in submit.py
----------submit_config.submit_target in {SubmitTarget.LOCAL},create new dir to run
Creating the run dir: results/00503-autoencoder-n2n
----------_populate_run_dir function in submit.py. Copying files to the run dir
----------run_wrapper function in submit.py
dnnlib: Running train.train() on localhost...
---------------train function in train.py

4. The Training Function

train.py carries out the actual network training.

[Figure 3: overview of train.py]

4.1 Input Parameters and Types

def train(
    submit_config: dnnlib.SubmitConfig,
    iteration_count: int,
    eval_interval: int,
    minibatch_size: int,
    learning_rate: float,
    ramp_down_perc: float,
    noise: dict,
    validation_config: dict,
    train_tfrecords: str,
    noise2noise: bool):

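train receives the training and validation settings, minibatch_size, iteration_count and so on, and passes them on to the functions in train.py that set up and run training. Apart from submit_config, which run_wrapper passes in, these parameters correspond one-to-one to the keys of train_config in config.py (minus run_func_name).

4.2 Initialization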

    #create validation set
    noise_augmenter = dnnlib.util.call_func_by_name(**noise)
    validation_set = ValidationSet(submit_config)
    validation_set.load(**validation_config)
    print("----------train in train.py")

    # Create a run context (hides low level details, exposes simple API to manage the run)
    ctx = dnnlib.RunContext(submit_config, config)

    # Initialize TensorFlow graph and session using good default settings
    tfutil.init_tf(config.tf_config)

    #create dataset
    dataset_iter = create_dataset(train_tfrecords, minibatch_size, noise_augmenter.add_train_noise_tf)
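In noise2noise mode each training sample provides two independently noised copies of the same clean image, which is why dataset_iter.get_next() later returns noisy_input, noisy_target and clean_target. The pure-Python sketch below imitates that pairing on a flat list of pixel values. It assumes pixels in [0, 1], standard deviations given on the 0-255 scale as in the printed noise config (train_stddev_rng_range (0.0, 50.0)), and a fresh stddev drawn for every noisy copy; the function names are hypothetical and this is not the repository's AugmentGaussian.

```python
import random


def add_train_noise(pixels, stddev_range=(0.0, 50.0), rng=random):
    """Add zero-mean Gaussian noise with a stddev drawn uniformly from stddev_range.

    Assumption: stddev_range is on the 0-255 scale, pixels are in [0, 1].
    """
    lo, hi = stddev_range
    sigma = rng.uniform(lo, hi) / 255.0        # convert to the [0, 1] pixel scale
    return [p + rng.gauss(0.0, sigma) for p in pixels]


def make_training_triplet(clean, stddev_range=(0.0, 50.0), rng=random):
    """Return (noisy_input, noisy_target, clean_target) with independent noise draws."""
    noisy_input = add_train_noise(clean, stddev_range, rng)
    noisy_target = add_train_noise(clean, stddev_range, rng)
    return noisy_input, noisy_target, clean


rng = random.Random(0)
clean = [0.2, 0.5, 0.8, 0.5]
noisy_input, noisy_target, clean_target = make_training_triplet(clean, rng=rng)
print(noisy_input != noisy_target)             # the two noisy copies differ
```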

4.3 Building the Network

    # Construct the network using the Network helper class and a function defined in config.net_config
    with tf.device("/gpu:0"):
        print("----------net = tflib.Network(**config.net_config) in train.py")
        net = tflib.Network(**config.net_config)

    # Optionally print layer information
    net.print_layers()

    print('----------Building TensorFlow graph...')
    with tf.name_scope('Inputs'), tf.device("/cpu:0"):
        lrate_in        = tf.placeholder(tf.float32, name='lrate_in', shape=[])

        noisy_input, noisy_target, clean_target = dataset_iter.get_next()
        noisy_input_split = tf.split(noisy_input, submit_config.num_gpus)
        noisy_target_split = tf.split(noisy_target, submit_config.num_gpus)
        clean_target_split = tf.split(clean_target, submit_config.num_gpus)

    # Define the loss function using the Optimizer helper class, this will take care of multi GPU
    opt = tflib.Optimizer(learning_rate=lrate_in, **config.optimizer_config)

    #compute meansq_error
    for gpu in range(submit_config.num_gpus):
        with tf.device("/gpu:%d" % gpu):
            net_gpu = net if gpu == 0 else net.clone()

            denoised = net_gpu.get_output_for(noisy_input_split[gpu])

            if noise2noise:
                meansq_error = tf.reduce_mean(tf.square(noisy_target_split[gpu] - denoised))
            else:
                meansq_error = tf.reduce_mean(tf.square(clean_target_split[gpu] - denoised))
            # Create an autosummary that will average over all GPUs
            with tf.control_dependencies([autosummary("Loss", meansq_error)]):
                opt.register_gradients(meansq_error, net_gpu.trainables)
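The only difference between noise2noise and ordinary supervised training is the target in the loss: the network output is compared with noisy_target when noise2noise is True and with clean_target otherwise. Below is a plain-Python rendering of that branch, plus how tf.split divides a minibatch evenly across GPUs (tf.split with an integer count requires minibatch_size to be divisible by num_gpus). The helper names are illustrative.

```python
def mean_squared_error(target, prediction):
    """Same quantity as tf.reduce_mean(tf.square(target - prediction)) for flat lists."""
    return sum((t - p) ** 2 for t, p in zip(target, prediction)) / len(target)


def training_loss(denoised, noisy_target, clean_target, noise2noise: bool):
    # noise2noise=True: the target is a second noisy copy; False: the clean image
    target = noisy_target if noise2noise else clean_target
    return mean_squared_error(target, denoised)


def split_minibatch(batch, num_gpus):
    """Equal-size shards along the batch axis, like tf.split(x, num_gpus)."""
    if len(batch) % num_gpus:
        raise ValueError("minibatch size must be divisible by num_gpus")
    n = len(batch) // num_gpus
    return [batch[i * n:(i + 1) * n] for i in range(num_gpus)]


print(training_loss([0.5, 0.5], [0.7, 0.3], [0.5, 0.5], noise2noise=True))   # about 0.04
print(training_loss([0.5, 0.5], [0.7, 0.3], [0.5, 0.5], noise2noise=False))  # 0.0
print(split_minibatch(["img0", "img1", "img2", "img3"], 2))
```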

4.4 Applying Gradient Updates and Iterating

[Figure 4: gradient updates and the training loop]

    #apply updates
    print("----------train_step = opt.apply_updates() in train.py")
    train_step = opt.apply_updates()

    # Create a log file for Tensorboard
    summary_log = tf.summary.FileWriter(submit_config.run_dir)
    summary_log.add_graph(tf.get_default_graph())

    print('----------Training...')
    time_maintenance = ctx.get_time_since_last_update()
    ctx.update(loss='run %d' % submit_config.run_id, cur_epoch=0, max_epoch=iteration_count)

    # The actual training loop
    for i in range(iteration_count):
        # Whether to stop the training or not should be asked from the context
        if ctx.should_stop():
            break

        # Dump training status
        if i % eval_interval == 0:

            time_train = ctx.get_time_since_last_update()
            time_total = ctx.get_time_since_start()

            # Evaluate 'x' to draw a batch of inputs
            [source_mb, target_mb] = tfutil.run([noisy_input, clean_target])
            denoised = net.run(source_mb)
            #save_image(submit_config, denoised[0], "img_{0}_y_pred.png".format(i))
            #save_image(submit_config, target_mb[0], "img_{0}_y.png".format(i))
            #save_image(submit_config, source_mb[0], "img_{0}_x_aug.png".format(i))

            validation_set.evaluate(net, i, noise_augmenter.add_validation_noise_np)

            print('iter %-10d time %-12s sec/eval %-7.1f sec/iter %-7.2f maintenance %-6.1f' % (
                autosummary('Timing/iter', i),
                dnnlib.util.format_time(autosummary('Timing/total_sec', time_total)),
                autosummary('Timing/sec_per_eval', time_train),
                autosummary('Timing/sec_per_iter', time_train / eval_interval),
                autosummary('Timing/maintenance_sec', time_maintenance)))

            dnnlib.tflib.autosummary.save_summaries(summary_log, i)
            ctx.update(loss='run %d' % submit_config.run_id, cur_epoch=i, max_epoch=iteration_count)
            time_maintenance = ctx.get_last_update_interval() - time_train

        lrate =  compute_ramped_down_lrate(i, iteration_count, ramp_down_perc, learning_rate)
        tfutil.run([train_step], {lrate_in: lrate})
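compute_ramped_down_lrate is called in the loop but not shown in this post. The following sketch shows the kind of schedule the parameter names describe: the learning rate stays at learning_rate for the first (1 - ramp_down_perc) of the iterations and is then smoothly decreased toward zero. The cosine-squared shape is an assumption for illustration; check compute_ramped_down_lrate in train.py for the exact formula.

```python
import math


def ramped_down_lrate(i, iteration_count, ramp_down_perc, learning_rate):
    """Constant lrate, then a smooth decay over the last ramp_down_perc of training.

    Assumption: cosine-squared decay shape (see compute_ramped_down_lrate in train.py).
    """
    ramp_down_start = iteration_count * (1.0 - ramp_down_perc)
    if i < ramp_down_start:
        return learning_rate
    # t runs from 0 at the start of the ramp-down to 1 at iteration_count
    t = (i - ramp_down_start) / (ramp_down_perc * iteration_count)
    return learning_rate * (0.5 + 0.5 * math.cos(t * math.pi)) ** 2


for i in (0, 40, 50, 75, 99):
    print(i, ramped_down_lrate(i, 100, 0.5, 0.0003))
```

With the debugging settings from section 1.2 (iteration_count = 100, ramp_down_perc = 0.5) the decay would start at iteration 50.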
