PyTorch UNIT (Coupled GAN algorithm for Unsupervised Image-to-Image Translation), Part 1: The Entry Point

UNIT and Coupled GAN (CoGAN) share the same first author, Ming-Yu Liu; the two papers were accepted at NIPS 2017 and NIPS 2016 respectively, so the author clearly has an outstanding record in GAN research.
The theory behind the paper will be covered in a separate article; here we focus only on the details of the code implementation.
Source code
This codebase is a substantial piece of work. The authors use YAML files to separate hyperparameter settings from the network definitions, much like prototxt files in Caffe, which is a practice worth learning from. The modules are cleanly organized: network definitions, training code, and so on live in separate files. Since I am not yet fluent in PyTorch, reading this code was not easy, so I walk through my analysis of it below.

Entry-script arguments

Training example:

python cocogan_train.py --config ../exps/unit/blondhair.yaml --log ../logs --resume 1 --gpu 0

cocogan_train.py is the training entry script; it takes four arguments:

config: specifies the training hyperparameter file
log: specifies the log directory
resume: specifies whether to resume from a previous training run
gpu: specifies the GPU id to run on; defaults to 0
Testing example:

python cocogan_translate_one_image.py --config ../exps/unit/corgi2husky.yaml --a2b 1 --weights ../outputs/unit/corgi2husky/corgi2husky_gen_00500000.pkl --image_name ../images/corgi001.jpg --output_image_name ../results/corgi2husky_corgi001.jpg

cocogan_translate_one_image.py is the testing entry script; it takes seven arguments:

a2b: 1 means translating from a to b; any other value means b to a
trans_alone: show only the translated image; defaults to 0
image_name: the input image to translate
output_image_name: the name of the output image
weights: the trained generator weights
config: specifies the testing hyperparameter file
gpu: specifies the GPU id to run on; defaults to 0

YAML file

The file given after the config argument specifies the hyperparameters. Below is the example for translating hair color.

# Copyright (C) 2017 NVIDIA Corporation.  All rights reserved.
# Licensed under the CC BY-NC-ND 4.0 license (https://creativecommons.org/licenses/by-nc-nd/4.0/legalcode).
train:
  snapshot_save_iterations: 5000 # How often do you want to save trained models
  image_save_iterations: 2500 # How often do you want to save output images during training
  image_display_iterations: 100
  display: 1 # How often do you want to log the training stats
  snapshot_prefix: ../outputs/unit/celeba/blondhair/blondhair # Where do you want to save the outputs
  hyperparameters:
    trainer: COCOGANTrainer
    lr: 0.0001             # learning rate
    ll_direct_link_w: 100  # weight on the self L1 reconstruction loss
    kl_direct_link_w: 0.1 # weight on VAE encoding loss
    ll_cycle_link_w: 100   # weight on the cycle L1 reconstruction loss
    kl_cycle_link_w: 0.1  # weight on the cycle KL (VAE encoding) loss
    gan_w: 10              # weight on the adversarial loss
    batch_size: 1          # image batch size per domain
    max_iterations: 2000000 # maximum number of training iterations
    gen:
      name: COCOResGen2
      ch: 64               # base channel number per layer
      input_dim_a: 3
      input_dim_b: 3
      n_enc_front_blk: 3
      n_enc_res_blk: 3
      n_enc_shared_blk: 1
      n_gen_shared_blk: 1
      n_gen_res_blk: 3
      n_gen_front_blk: 3
    dis:
      name: COCOSharedDis
      ch: 64
      input_dim_a: 3
      input_dim_b: 3
      n_front_layer: 2
      n_shared_layer: 4
  datasets:
    train_a: # Domain 1 dataset
      channels: 3       # image channel number
      scale: 1.0        # scaling factor for scaling image before processing
      crop_image_size: 128 # crop image size
      class_name: dataset_celeba           # dataset class name
      root: ../datasets/celeba/    # dataset folder location
      folder: img_align_crop_resize_celeba/
      list_name: lists/Blond_Hair_ON.txt
    train_b: # Domain 2 dataset
      channels: 3       # image channel number
      scale: 1.0        # scaling factor for scaling image before processing
      crop_image_size: 128 # crop image size
      class_name: dataset_celeba           # dataset class name
      root: ../datasets/celeba/    # dataset folder location
      folder: img_align_crop_resize_celeba/
      list_name: lists/Blond_Hair_OFF.txt
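The weights under hyperparameters combine the individual loss terms into the total generator objective. As a hedged sketch of the weighting scheme (the term values below are made-up numbers, not outputs of the real trainer; the real code computes each term from the network outputs):

```python
# Weights as in the YAML above.
weights = {
    'gan_w': 10, 'll_direct_link_w': 100, 'kl_direct_link_w': 0.1,
    'll_cycle_link_w': 100, 'kl_cycle_link_w': 0.1,
}

def total_gen_loss(terms, weights):
    # Weighted sum of the individual loss terms, keyed by the YAML names
    # (each term name plus the '_w' suffix picks out its weight).
    return sum(weights[k + '_w'] * terms[k] for k in terms)

# Illustrative term values only.
terms = {'gan': 0.7, 'll_direct_link': 0.01, 'kl_direct_link': 2.0,
         'll_cycle_link': 0.02, 'kl_cycle_link': 2.5}
loss = total_gen_loss(terms, weights)  # 7.0 + 1.0 + 0.2 + 2.0 + 0.25 = 10.45
```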

Hyperparameter parsing

This corresponds to the file net_config.py:

import yaml


class NetConfig(object):
    def __init__(self, config):
        stream = open(config,'r')
        docs = yaml.load_all(stream)
        for doc in docs:
            for k, v in doc.items():
                if k == "train":
                    for k1, v1 in v.items():
                        cmd = "self." + k1 + "=" + repr(v1)
                        print(cmd)
                        exec(cmd)
        stream.close()


def dict_from_class(cls):
    return dict(
        (key, value)
        for (key, value) in cls.__dict__.items())


class SettingConfig(object):
    def __init__(self, config):
        stream = open(config,'r')
        docs = yaml.load_all(stream)
        for doc in docs:
            for k, v in doc.items():
                if k == "train":
                    for k1, v1 in v.items():
                        cmd = "self." + k1 + "=" + repr(v1)
                        print(cmd)
                        exec(cmd)
        stream.close()
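NetConfig and SettingConfig are identical: both flatten every key under the top-level train section into instance attributes by building and exec-ing assignment strings. The exec/repr round-trip is fragile, and calling yaml.load_all without an explicit Loader is deprecated in modern PyYAML. A Python-3-friendly equivalent sketch using setattr and safe_load_all:

```python
import yaml


class NetConfig(object):
    """Flatten every key under the top-level `train` section into attributes."""

    def __init__(self, config_path):
        with open(config_path, 'r') as stream:
            for doc in yaml.safe_load_all(stream):
                for k, v in doc.items():
                    if k == "train":
                        for k1, v1 in v.items():
                            # setattr avoids building and exec-ing code strings.
                            setattr(self, k1, v1)
```

Afterwards, accesses like config.hyperparameters['batch_size'] work exactly as in the entry script.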

The entry script

import sys
from tools import *
from trainers import *
from datasets import *
import torchvision
import itertools
from common import *
import tensorboard
from tensorboard import summary
from optparse import OptionParser
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
parser = OptionParser()
parser.add_option('--gpu', type=int, help="gpu id", default=0)
parser.add_option('--resume', type=int, help="resume training?", default=0)
parser.add_option('--config', type=str, help="net configuration")
parser.add_option('--log', type=str, help="log path")

MAX_EPOCHS = 100000

def main(argv):
  (opts, args) = parser.parse_args(argv)

  # Load experiment setting
  assert isinstance(opts, object)
  config = NetConfig(opts.config)

  batch_size = config.hyperparameters['batch_size']
  max_iterations = config.hyperparameters['max_iterations']

  train_loader_a = get_data_loader(config.datasets['train_a'], batch_size)
  train_loader_b = get_data_loader(config.datasets['train_b'], batch_size)

  trainer = []
  exec ("trainer=%s(config.hyperparameters)" % config.hyperparameters['trainer'])
  # Check if resume training
  iterations = 0
  if opts.resume == 1:
    iterations = trainer.resume(config.snapshot_prefix)
  trainer.cuda(opts.gpu)

  ######################################################################################################################
  # Setup logger and prepare image outputs
  train_writer = tensorboard.FileWriter("%s/%s" % (opts.log,os.path.splitext(os.path.basename(opts.config))[0]))
  image_directory, snapshot_directory = prepare_snapshot_and_image_folder(config.snapshot_prefix, iterations, config.image_save_iterations)

  for ep in range(0, MAX_EPOCHS):
    for it, (images_a, images_b) in enumerate(itertools.izip(train_loader_a,train_loader_b)):
      if images_a.size(0) != batch_size or images_b.size(0) != batch_size:
        continue
      images_a = Variable(images_a.cuda(opts.gpu))
      images_b = Variable(images_b.cuda(opts.gpu))

      # Main training code: the core training steps
      trainer.dis_update(images_a, images_b, config.hyperparameters)
      image_outputs = trainer.gen_update(images_a, images_b, config.hyperparameters)
      assembled_images = trainer.assemble_outputs(images_a, images_b, image_outputs)

      # Dump training stats in log file (show intermediate results)
      if (iterations+1) % config.display == 0:
        write_loss(iterations, max_iterations, trainer, train_writer)

      if (iterations+1) % config.image_save_iterations == 0:
        img_filename = '%s/gen_%08d.jpg' % (image_directory, iterations + 1)
        torchvision.utils.save_image(assembled_images.data / 2 + 0.5, img_filename, nrow=1)
        write_html(snapshot_directory + "/index.html", iterations + 1, config.image_save_iterations, image_directory)
      elif (iterations + 1) % config.image_display_iterations == 0:
        img_filename = '%s/gen.jpg' % (image_directory)
        torchvision.utils.save_image(assembled_images.data / 2 + 0.5, img_filename, nrow=1)

      # Save network weights (intermediate snapshots of the network parameters)
      if (iterations+1) % config.snapshot_save_iterations == 0:
        trainer.save(config.snapshot_prefix, iterations)

      iterations += 1
      if iterations >= max_iterations:
        return

if __name__ == '__main__':
  main(sys.argv)
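Note that this script targets Python 2 (itertools.izip; exec used to rebind a local variable). In Python 3, exec("trainer=...") inside main would not rebind the local trainer. A sketch of a lookup-based alternative that works on both versions; the COCOGANTrainer below is a stand-in stub, since the real class comes from the trainers package:

```python
class COCOGANTrainer(object):
    """Stand-in stub; the real trainer comes from `from trainers import *`."""

    def __init__(self, hyperparameters):
        self.hyperparameters = hyperparameters


# Registry mapping the `trainer` name from the YAML to a class.
TRAINERS = {"COCOGANTrainer": COCOGANTrainer}


def build_trainer(name, hyperparameters):
    # Name lookup instead of exec(): no string evaluation, and the
    # returned object is bound to a local variable the normal way.
    return TRAINERS[name](hyperparameters)


trainer = build_trainer("COCOGANTrainer", {"batch_size": 1})
```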

The most important functions, highlighted above, are the following; we analyze them in a separate article.

trainer.dis_update
trainer.gen_update
trainer.assemble_outputs

As for the less central functions, the data loader is as follows:

def get_data_loader(conf, batch_size):
  dataset = []
  print("dataset=%s(conf)" % conf['class_name'])
  exec ("dataset=%s(conf)" % conf['class_name'])
  return torch.utils.data.DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True, num_workers=10)

There is also a utility function that prepares the folders for intermediate training outputs:

def prepare_snapshot_and_image_folder(snapshot_prefix, iterations, image_save_iterations, all_size=1536):
  snapshot_directory = prepare_snapshot_folder(snapshot_prefix)
  image_directory = prepare_image_folder(snapshot_directory)
  write_html(snapshot_directory + "/index.html", iterations + 1, image_save_iterations, image_directory, all_size)
  return image_directory, snapshot_directory

def prepare_snapshot_folder(snapshot_prefix):
  snapshot_directory = os.path.dirname(snapshot_prefix)
  if not os.path.exists(snapshot_directory):
    os.makedirs(snapshot_directory)
  return snapshot_directory

def prepare_image_folder(snapshot_directory):
  image_directory = os.path.join(snapshot_directory, 'images')
  if not os.path.exists(image_directory):
    os.makedirs(image_directory)
  return image_directory
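To make the folder layout concrete: with the snapshot_prefix from the YAML above, dirname() strips the trailing file-name prefix, and the images/ subfolder is created inside the resulting snapshot directory. A small illustration:

```python
import os

snapshot_prefix = '../outputs/unit/celeba/blondhair/blondhair'
# dirname() drops the trailing "blondhair" file-name prefix,
# leaving the directory in which snapshots are stored.
snapshot_directory = os.path.dirname(snapshot_prefix)
# Training images go into an images/ subfolder of that directory.
image_directory = os.path.join(snapshot_directory, 'images')
```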

def write_html(filename, iterations, image_save_iterations, image_directory, all_size=1536):
  html_file = open(filename, "w")
  html_file.write('''
  <!DOCTYPE html>
  <html>
  <head>
    <title>Experiment name = UnitNet</title>
    <meta content="1" http-equiv="refresh">
  </head>
  <body>
  ''')
  html_file.write("<h3>current</h3>")
  img_filename = '%s/gen.jpg' % (image_directory)
  html_file.write("""
        <p><a href="%s"><img src="%s" style="width:%dpx"></a><br>
        """ % (img_filename, img_filename, all_size))
  for j in range(iterations, image_save_iterations - 1, -1):
    if j % image_save_iterations == 0:
      img_filename = '%s/gen_%08d.jpg' % (image_directory, j)
      html_file.write("<h3>iteration [%d]</h3>" % j)
      html_file.write("""
        <p><a href="%s"><img src="%s" style="width:%dpx"></a><br>
        """ % (img_filename, img_filename, all_size))
  html_file.write("</body></html>")
  html_file.close()

The steps that record intermediate results:

      # Dump training stats in log file
      if (iterations+1) % config.display == 0:
        write_loss(iterations, max_iterations, trainer, train_writer)

      if (iterations+1) % config.image_save_iterations == 0:
        img_filename = '%s/gen_%08d.jpg' % (image_directory, iterations + 1)
        torchvision.utils.save_image(assembled_images.data / 2 + 0.5, img_filename, nrow=1)
        write_html(snapshot_directory + "/index.html", iterations + 1, config.image_save_iterations, image_directory)
      elif (iterations + 1) % config.image_display_iterations == 0:
        img_filename = '%s/gen.jpg' % (image_directory)
        torchvision.utils.save_image(assembled_images.data / 2 + 0.5, img_filename, nrow=1)

      # Save network weights
      if (iterations+1) % config.snapshot_save_iterations == 0:
        trainer.save(config.snapshot_prefix, iterations)

where write_loss is:

def write_loss(iterations, max_iterations, trainer, train_writer):
  print("Iteration: %08d/%08d" % (iterations + 1, max_iterations))
  members = [attr for attr in dir(trainer) \
             if not callable(getattr(trainer, attr)) and not attr.startswith("__") and 'loss' in attr]
  for m in members:
    train_writer.add_summary(summary.scalar(m, getattr(trainer, m)), iterations + 1)
  members = [attr for attr in dir(trainer) \
             if not callable(getattr(trainer, attr)) and not attr.startswith("__") and 'acc' in attr]
  for m in members:
    train_writer.add_summary(summary.scalar(m, getattr(trainer, m)), iterations + 1)
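write_loss discovers what to log by reflection: any non-callable, non-dunder attribute of the trainer whose name contains 'loss' (or 'acc') is written out as a scalar summary. A minimal sketch of that filter using a dummy trainer, with no tensorboard dependency:

```python
class DummyTrainer(object):
    def __init__(self):
        self.gen_total_loss = 0.5
        self.dis_total_loss = 1.2
        self.true_acc = 0.9

    def gen_update(self):  # methods are callable, so the filter skips them
        pass


def loggable_members(trainer, keyword):
    # Mirror the filter in write_loss: non-callable, non-dunder attributes
    # whose name contains the keyword.
    return sorted(attr for attr in dir(trainer)
                  if not callable(getattr(trainer, attr))
                  and not attr.startswith("__")
                  and keyword in attr)


print(loggable_members(DummyTrainer(), 'loss'))  # ['dis_total_loss', 'gen_total_loss']
print(loggable_members(DummyTrainer(), 'acc'))   # ['true_acc']
```

This is why the trainer classes store each loss as an attribute (e.g. self.gen_total_loss) rather than a local variable: it makes the value visible to this reflection-based logger.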
