Ming-Yu Liu (劉洺堉) is the first author of both UNIT and Coupled GAN (CoGAN), and both papers were accepted at NIPS, which speaks to his strong track record on GANs.
The theory behind the paper deserves its own article; here we only walk through the details of the code implementation.
Source code
This codebase can fairly be called a grand production. The author uses YAML files to separate hyperparameter settings from the network definitions, much like Caffe's prototxt style, which is well worth borrowing. The code is cleanly modularized: network definitions, training, and so on live in separate files. Since I am not fluent in PyTorch, reading this code was not easy, so I break it down below.
Entry-script arguments
Training example:
python cocogan_train.py --config ../exps/unit/blondhair.yaml --log ../logs --resume 1 --gpu 0
cocogan_train.py is the training entry script and takes four arguments:
config: the training hyperparameter file
log: the log directory
resume: whether to resume from a previous training run (1 to resume)
gpu: the GPU id to run on, default 0
Testing example:
python cocogan_translate_one_image.py --config ../exps/unit/corgi2husky.yaml --a2b 1 --weights ../outputs/unit/corgi2husky/corgi2husky_gen_00500000.pkl --image_name ../images/corgi001.jpg --output_image_name ../results/corgi2husky_corgi001.jpg
cocogan_translate_one_image.py is the testing entry script and takes the following arguments:
a2b: 1 translates from domain A to domain B; any other value translates from B to A
trans_alone: output only the translated image, default 0
image_name: the input image to translate
output_image_name: the filename of the output image
weights: the trained generator weights
config: the hyperparameter file used for testing
gpu: the GPU id to run on, default 0
The YAML file
The file given to --config specifies the hyperparameters. Below is the example that changes hair color to blond.
# Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
# Licensed under the CC BY-NC-ND 4.0 license (https://creativecommons.org/licenses/by-nc-nd/4.0/legalcode).
train:
  snapshot_save_iterations: 5000 # How often do you want to save trained models
  image_save_iterations: 2500 # How often do you want to save output images during training
  image_display_iterations: 100
  display: 1 # How often do you want to log the training stats
  snapshot_prefix: ../outputs/unit/celeba/blondhair/blondhair # Where do you want to save the outputs
  hyperparameters:
    trainer: COCOGANTrainer
    lr: 0.0001 # learning rate
    ll_direct_link_w: 100 # weight on the self L1 reconstruction loss
    kl_direct_link_w: 0.1 # weight on the VAE encoding loss
    ll_cycle_link_w: 100 # weight on the cycle L1 reconstruction loss
    kl_cycle_link_w: 0.1 # weight on the cycle VAE encoding loss
    gan_w: 10 # weight on the adversarial loss
    batch_size: 1 # image batch size per domain
    max_iterations: 2000000 # maximum number of training iterations
    gen:
      name: COCOResGen2
      ch: 64 # base channel number per layer
      input_dim_a: 3
      input_dim_b: 3
      n_enc_front_blk: 3
      n_enc_res_blk: 3
      n_enc_shared_blk: 1
      n_gen_shared_blk: 1
      n_gen_res_blk: 3
      n_gen_front_blk: 3
    dis:
      name: COCOSharedDis
      ch: 64
      input_dim_a: 3
      input_dim_b: 3
      n_front_layer: 2
      n_shared_layer: 4
  datasets:
    train_a: # Domain 1 dataset
      channels: 3 # image channel number
      scale: 1.0 # scaling factor for scaling image before processing
      crop_image_size: 128 # crop image size
      class_name: dataset_celeba # dataset class name
      root: ../datasets/celeba/ # dataset folder location
      folder: img_align_crop_resize_celeba/
      list_name: lists/Blond_Hair_ON.txt
    train_b: # Domain 2 dataset
      channels: 3 # image channel number
      scale: 1.0 # scaling factor for scaling image before processing
      crop_image_size: 128 # crop image size
      class_name: dataset_celeba # dataset class name
      root: ../datasets/celeba/ # dataset folder location
      folder: img_align_crop_resize_celeba/
      list_name: lists/Blond_Hair_OFF.txt
The hyperparameter parsing functions
They live in net_config.py:
import yaml

class NetConfig(object):
    def __init__(self, config):
        stream = open(config, 'r')
        docs = yaml.load_all(stream)
        for doc in docs:
            for k, v in doc.items():
                if k == "train":
                    for k1, v1 in v.items():
                        # Build a statement such as "self.lr=0.0001" and execute it,
                        # so every key under "train" becomes an attribute
                        cmd = "self." + k1 + "=" + repr(v1)
                        print(cmd)
                        exec(cmd)
        stream.close()

def dict_from_class(cls):
    return dict(
        (key, value)
        for (key, value) in cls.__dict__.items())

class SettingConfig(object):
    def __init__(self, config):
        stream = open(config, 'r')
        docs = yaml.load_all(stream)
        for doc in docs:
            for k, v in doc.items():
                if k == "train":
                    for k1, v1 in v.items():
                        cmd = "self." + k1 + "=" + repr(v1)
                        print(cmd)
                        exec(cmd)
        stream.close()
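In effect, NetConfig turns every key under the top-level train block into an attribute of the config object, which is why the entry script can write config.snapshot_prefix or index into config.hyperparameters. Below is a minimal sketch of the same idea without exec, using setattr and yaml.safe_load; the class name SimpleNetConfig and the usage lines are illustrative only, not part of the repository:

import yaml

class SimpleNetConfig(object):
    """Reads the 'train:' block of a UNIT-style YAML file and exposes
    each key as an attribute, mirroring what NetConfig does with exec."""
    def __init__(self, path):
        with open(path, 'r') as stream:
            doc = yaml.safe_load(stream)          # parse the YAML document
        for key, value in doc['train'].items():   # only the 'train' block is used
            setattr(self, key, value)             # e.g. self.hyperparameters = {...}

# Hypothetical usage, assuming the blondhair.yaml shown above:
# config = SimpleNetConfig('../exps/unit/blondhair.yaml')
# print(config.hyperparameters['lr'])           # 0.0001
# print(config.hyperparameters['gen']['name'])  # COCOResGen2
# print(config.datasets['train_a']['root'])     # ../datasets/celeba/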
The entry script
import sys
from tools import *
from trainers import *
from datasets import *
import torchvision
import itertools
from common import *
import tensorboard
from tensorboard import summary
from optparse import OptionParser
import os

os.environ['CUDA_VISIBLE_DEVICES'] = '0'

parser = OptionParser()
parser.add_option('--gpu', type=int, help="gpu id", default=0)
parser.add_option('--resume', type=int, help="resume training?", default=0)
parser.add_option('--config', type=str, help="net configuration")
parser.add_option('--log', type=str, help="log path")

MAX_EPOCHS = 100000

def main(argv):
    (opts, args) = parser.parse_args(argv)

    # Load experiment setting
    assert isinstance(opts, object)
    config = NetConfig(opts.config)

    batch_size = config.hyperparameters['batch_size']
    max_iterations = config.hyperparameters['max_iterations']

    train_loader_a = get_data_loader(config.datasets['train_a'], batch_size)
    train_loader_b = get_data_loader(config.datasets['train_b'], batch_size)

    # Instantiate the trainer class named in the YAML (e.g. COCOGANTrainer)
    trainer = []
    exec("trainer=%s(config.hyperparameters)" % config.hyperparameters['trainer'])

    # Check if resume training
    iterations = 0
    if opts.resume == 1:
        iterations = trainer.resume(config.snapshot_prefix)
    trainer.cuda(opts.gpu)

    ##################################################################################
    # Setup logger and prepare image outputs
    train_writer = tensorboard.FileWriter("%s/%s" % (opts.log, os.path.splitext(os.path.basename(opts.config))[0]))
    image_directory, snapshot_directory = prepare_snapshot_and_image_folder(config.snapshot_prefix, iterations, config.image_save_iterations)

    for ep in range(0, MAX_EPOCHS):
        for it, (images_a, images_b) in enumerate(itertools.izip(train_loader_a, train_loader_b)):
            if images_a.size(0) != batch_size or images_b.size(0) != batch_size:
                continue
            images_a = Variable(images_a.cuda(opts.gpu))
            images_b = Variable(images_b.cuda(opts.gpu))

            # Main training code: update the discriminators, then the generators
            trainer.dis_update(images_a, images_b, config.hyperparameters)
            image_outputs = trainer.gen_update(images_a, images_b, config.hyperparameters)
            assembled_images = trainer.assemble_outputs(images_a, images_b, image_outputs)

            # Dump training stats in the log file (intermediate results)
            if (iterations + 1) % config.display == 0:
                write_loss(iterations, max_iterations, trainer, train_writer)
            if (iterations + 1) % config.image_save_iterations == 0:
                img_filename = '%s/gen_%08d.jpg' % (image_directory, iterations + 1)
                torchvision.utils.save_image(assembled_images.data / 2 + 0.5, img_filename, nrow=1)
                write_html(snapshot_directory + "/index.html", iterations + 1, config.image_save_iterations, image_directory)
            elif (iterations + 1) % config.image_display_iterations == 0:
                img_filename = '%s/gen.jpg' % (image_directory)
                torchvision.utils.save_image(assembled_images.data / 2 + 0.5, img_filename, nrow=1)

            # Save network weights (intermediate checkpoints)
            if (iterations + 1) % config.snapshot_save_iterations == 0:
                trainer.save(config.snapshot_prefix, iterations)

            iterations += 1
            if iterations >= max_iterations:
                return

if __name__ == '__main__':
    main(sys.argv)
The most important functions have been marked above; they are the following, and we analyze them in a separate article. A rough sketch of the interface they form together is given after the short list below.
trainer.dis_update
trainer.gen_update
trainer.assemble_outputs
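Before that deeper dive, it helps to keep in mind the interface the training loop assumes. The sketch below only illustrates the call pattern; the method names are taken from the calls in the loop above, while the bodies are placeholders and not the actual COCOGANTrainer implementation:

import torch.nn as nn

class TrainerInterface(nn.Module):
    """Placeholder showing the methods cocogan_train.py expects on a trainer."""
    def dis_update(self, images_a, images_b, hyperparameters):
        # one optimization step on the two domain discriminators
        pass

    def gen_update(self, images_a, images_b, hyperparameters):
        # one optimization step on the shared-latent generators;
        # returns the reconstructed / translated images for visualization
        pass

    def assemble_outputs(self, images_a, images_b, image_outputs):
        # stitches inputs and outputs into one grid image for saving
        pass

    def resume(self, snapshot_prefix):
        # loads the latest checkpoint and returns the iteration count
        return 0

    def save(self, snapshot_prefix, iterations):
        # writes generator / discriminator weights to .pkl files
        pass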
Among the less central helpers, get_data_loader is as follows:
def get_data_loader(conf, batch_size):
    dataset = []
    print("dataset=%s(conf)" % conf['class_name'])
    # Instantiate the dataset class named in the YAML, e.g. dataset_celeba
    exec("dataset=%s(conf)" % conf['class_name'])
    return torch.utils.data.DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True, num_workers=10)
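The exec trick simply instantiates whatever dataset class the YAML names (dataset_celeba in the example). A rough equivalent without exec, assuming the dataset classes can be imported from the repository's datasets package, might look like this (the function name get_data_loader_by_name is hypothetical):

import datasets
import torch.utils.data

def get_data_loader_by_name(conf, batch_size):
    # Look the dataset class up by name instead of building a string for exec;
    # conf['class_name'] is e.g. 'dataset_celeba' from the YAML above.
    dataset_cls = getattr(datasets, conf['class_name'])
    dataset = dataset_cls(conf)
    return torch.utils.data.DataLoader(dataset=dataset, batch_size=batch_size,
                                       shuffle=True, num_workers=10)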
There is also a set of utility functions that prepare the folders where intermediate training results are stored:
def prepare_snapshot_and_image_folder(snapshot_prefix, iterations, image_save_iterations, all_size=1536):
    snapshot_directory = prepare_snapshot_folder(snapshot_prefix)
    image_directory = prepare_image_folder(snapshot_directory)
    write_html(snapshot_directory + "/index.html", iterations + 1, image_save_iterations, image_directory, all_size)
    return image_directory, snapshot_directory

def prepare_snapshot_folder(snapshot_prefix):
    snapshot_directory = os.path.dirname(snapshot_prefix)
    if not os.path.exists(snapshot_directory):
        os.makedirs(snapshot_directory)
    return snapshot_directory

def prepare_image_folder(snapshot_directory):
    image_directory = os.path.join(snapshot_directory, 'images')
    if not os.path.exists(image_directory):
        os.makedirs(image_directory)
    return image_directory

def write_html(filename, iterations, image_save_iterations, image_directory, all_size=1536):
    html_file = open(filename, "w")
    # (approximate HTML template; the exact markup in the repository may differ)
    html_file.write('''
    <html>
    <head><title>Experiment name = UnitNet</title></head>
    <body>
    ''')
    # Always show the most recent image first
    html_file.write("<h3>current</h3>")
    img_filename = '%s/gen.jpg' % (image_directory)
    html_file.write("""
      <p><a href="%s"><img src="%s" style="width:%dpx"></a><br></p>
      """ % (img_filename, img_filename, all_size))
    # Then list every saved snapshot image, newest first
    for j in range(iterations, image_save_iterations - 1, -1):
        if j % image_save_iterations == 0:
            img_filename = '%s/gen_%08d.jpg' % (image_directory, j)
            html_file.write("<h3>iteration [%d]</h3>" % j)
            html_file.write("""
              <p><a href="%s"><img src="%s" style="width:%dpx"></a><br></p>
              """ % (img_filename, img_filename, all_size))
    html_file.write("</body></html>")
    html_file.close()
The steps that log intermediate results:
# Dump training stats in log file
if (iterations + 1) % config.display == 0:
    write_loss(iterations, max_iterations, trainer, train_writer)
if (iterations + 1) % config.image_save_iterations == 0:
    img_filename = '%s/gen_%08d.jpg' % (image_directory, iterations + 1)
    torchvision.utils.save_image(assembled_images.data / 2 + 0.5, img_filename, nrow=1)
    write_html(snapshot_directory + "/index.html", iterations + 1, config.image_save_iterations, image_directory)
elif (iterations + 1) % config.image_display_iterations == 0:
    img_filename = '%s/gen.jpg' % (image_directory)
    torchvision.utils.save_image(assembled_images.data / 2 + 0.5, img_filename, nrow=1)
# Save network weights
if (iterations + 1) % config.snapshot_save_iterations == 0:
    trainer.save(config.snapshot_prefix, iterations)
where write_loss is:
def write_loss(iterations, max_iterations, trainer, train_writer):
    print("Iteration: %08d/%08d" % (iterations + 1, max_iterations))
    # Collect every non-callable trainer attribute whose name contains 'loss'
    members = [attr for attr in dir(trainer)
               if not callable(getattr(trainer, attr)) and not attr.startswith("__") and 'loss' in attr]
    for m in members:
        train_writer.add_summary(summary.scalar(m, getattr(trainer, m)), iterations + 1)
    # Same for attributes whose name contains 'acc'
    members = [attr for attr in dir(trainer)
               if not callable(getattr(trainer, attr)) and not attr.startswith("__") and 'acc' in attr]
    for m in members:
        train_writer.add_summary(summary.scalar(m, getattr(trainer, m)), iterations + 1)
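In other words, any non-callable trainer attribute whose name contains 'loss' or 'acc' is written to TensorBoard automatically; the trainer only has to store its losses as plain attributes. A small standalone illustration of the attribute filtering (DummyTrainer and its attribute names are made up for this example):

class DummyTrainer(object):
    def __init__(self):
        self.gen_total_loss = 0.37   # picked up: name contains 'loss'
        self.dis_total_loss = 0.85   # picked up: name contains 'loss'
        self.dis_true_acc = 0.91     # picked up by the 'acc' pass
        self.iterations = 100        # ignored: neither 'loss' nor 'acc'

trainer = DummyTrainer()
loss_names = [attr for attr in dir(trainer)
              if not callable(getattr(trainer, attr))
              and not attr.startswith("__") and 'loss' in attr]
print(loss_names)  # ['dis_total_loss', 'gen_total_loss']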