项目地址:https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix
train.py:是训练的脚本,支持不同的数据集和不同的模型的训练,模型有 --model: e.g., pix2pix, cyclegan, colorization ,数据集有 --dataset_mode: e.g., aligned, unaligned, single, colorization。
//训练一个CycleGAN的模型代码
python train.py --dataroot ./datasets/maps --name maps_cyclegan --model cycle_gan
//训练一个pix2pix的模型代码
python train.py --dataroot ./datasets/facades --name facades_pix2pix --model pix2pix --direction BtoA
test.py:是训练的脚本,加载保存在'--checkpoints_dir'下的模型,将结果保存到'--results_dir'。'--model test'选项用来从一侧生成CycleGAN的结果。'--dataset_mode single':仅仅从一个集合中加载图片。'--model cycle_gan':双向加载和生成结果。结果保存在./results/。用Use '--results_dir
//测试一个CycleAGAN模型的代码(both sides)
python test.py --dataroot ./datasets/maps --name maps_cyclegan --model cycle_gan
//测试一个CycleAGAN模型的代码(one side only)
python test.py --dataroot datasets/horse2zebra/testA --name horse2zebra_pretrained --model test --no_dropout
//测试pix2pix模型的代码
python test.py --dataroot ./datasets/facades --name facades_pix2pix --model pix2pix --direction BtoA
data 目录:中包含所有与数据加载和数据预处理有关的模块,为了添加一个自定义叫dummy的数据集,你需要:1.添加一个dummy_dataset.py的文件并且定义一个子类DummyDataset(继承自BaseDataset),2.实现下面四个函数(a).__init__ (初始化类别, 你需要首先调用 BaseDataset.__init__(self, opt)), (b).__len__ (返回数据集的大小), (c).__getitem__ (得到数据), (d).modify_commandline_options (添加数据集特指选项和设置默认选项)。完成以上部分,你便可以通过指定标志--dataset_mode dummy去用数据集类别。
__init__.py:实现这个包和训练、测试脚本之间的接口,train.py和test.py用from data import create_dataset和dataset = create_dataset(opt)两条语句去创建数据集(已知opt)。
base_dataset.py:实现数据集的抽象基类(ABC)。包含一般的转换函数(如:get_transform,__scale_width),这些在之后的子类中将会用到。
image_folder.py: 实现一个图片文件夹类。我们修改官方那个的Pytorch图像文件夹类以便这个类能加载来自当前路径下和子路径下的图片。
template_dataset.py:提供具有详细文档的数据集模板。如果您计划实现自己的数据集,请检查此文件.
aligned_dataset.py包含一个可以加载图像对的dataset类。它假设有一个单独的图像目录/path/to/data/train,其中包含了{A,B}形式的图像对。测试过程中需要准备一个目录/path/到/data/test作为测试数据。
unaligned_dataset.py包含一个可以加载未对齐/未配对数据集的dataset类。假设从A域path/to/data/trainA和从B域path/to/data/trainB两个目录分别存放训练图像。
然后,您可以使用dataset标志--dataroot /path/to/data来训练模型。同样,在测试期间需要准备两个目录/path/to/data/testA和/path/to/data/testB。
single_dataset.py包含一个数据集类,它可以加载由路径dataroot /path/to/data指定的一组单个图像。
它可以用于生成周期的结果仅为一侧与模型选项-模型测试。
colorization_dataset.py实现了一个数据集类,它可以加载RGB格式的自然图像集,并将RGB格式转换为实验室颜色空间中的(L, ab)对。
它是基于pix2pixel的着色模型(——模型着色)所需要的。
models目录:包含和目标函数、优化器、网络结构相关的模块,为了添加一个自定义的model类(dummy),你需要添加一个dummy.py文件,并且定义一个子类DummyModel(继承自BaseModel类),你需要实现以下四个函数:__init__(初始这个类;你需要先调用BaseModel.__init__(self, opt));set_input(从数据集中解压数据并使用预处理);forward(生成中间结果);optimize_parameters(计算损失,梯度,更新网络权重);(可选)modify_commandline_options(特定模型选择,设置默认选项);完成这四个函数后,你可以用--model dummy来使用模型。可以以我们的模板类为例,以下我们详细的解释每一个文件:
__init__.py: 实现这个包和训练和测试脚本之间的接口。train.py和tset.py调用from models import create_model和model = create_model(opt)这两条语句去创建一个给定选项opt的模型,你也需要调用model.setup(opt)去适当的初始化这个模型。
base_model.py:给模型实现一个抽象基类,它包含一些通用的helper函数(如:setup, test, update_learning_rate, save_networks, load_networks),在之后的子类中将被用到。
template_model.py 提供一个具有详细说明文档的模型模板。如果你想要实现你自己的模型,那么希望你核查一下该文件。
pix2pix_model.py: 实现pix2pix的模型,给定成对的数据,学习一个输入到输出的一个映射,该模型训需要--dataset_mode aligned这样的数据集,默认情况下,他使用--netG unet256的U-Nets生成器,--netD basic判别器(PatchGAN),和一个--gan_mode vanilla的GAN loss(标准交叉熵)
colorization_model.py: 为图像的彩色化实现一个Pix2PixModel模型的子类(黑白图像到彩色图像),该模型的训练要求 -dataset_model colorization数据集,它训练一个pix2pix的模型,在实验的颜色空间中从L频道映射到ab频道,默认地,这个colorization数据集将自动的设定--input_nc 1 和 --output_nc 2。
cycle_gan_model.py:实现CycleGAN模型,无需成对的数据便可学习图像到图像的翻译。该模型训练需要 --dataset_mode unaligned数据集,默认地,它使用--netG resnet_9blocks的ResNet生成器,一个--netD basic判别器(通过pix2pix引入的PatchGAN),一个最小平方GANs的目标(--gan_mode lsgan)。
networks.py:模型实现网络架构(生成器和判别器),还有归一化层,初始化方法,优化调度程序(如学习率策略),GAN的目标函数(vanilla,lsgan,wgangp)。
test_model.py:实现一个可以被用来生成单向CycleGAN结果的模型,这个模型自动设定--dataset_mode single,只从一个集合加载数据。看测试说明以了解更多细节。
options目录 包含我们可选择的模块,训练选项,测试选项,和基本的选项(训练和测试用的)。TrainOptions和TestOptions都是BaseOptions的子类,它们将重用在BaseOptions中定义的选项
__init__.py:是必需的,以使Python将目录options视为包含包,
base_options.py包括训练和测试中使用的选项。它还实现了一些辅助函数,如解析、打印和保存选项。它还在数据集类和模型类中收集在modify_commandline_options函数中定义的附加选项。
train_options.py包含了仅在训练期间使用的选项。
test_options.py包含仅在测试期间使用的选项。
util目录 包含各种有用的助手函数集合。
__init__.py:是必需的,以使Python将目录util视为包含包,
get_data.py:提供了一个Python脚本,用于下载CycleGAN和pix2pix数据集。另外,您还可以使用诸如download_pix2pix_model.sh和download_cyclegan_model.sh这样的bash脚本。
html.py实现了一个模块,它将图像保存到单个HTML文件中。它包括add_header(向HTML文件添加一个文本头)、add_images(向HTML文件添加一行图像)、save(将HTML保存到磁盘)等函数。它基于Python库dominate,一个用于使用DOM API创建和操作HTML文档的Python库。
image_pool.py:实现了一个图像缓冲区,用来存储以前生成的图像。这个缓冲区使我们能够使用生成图像的历史来更新标识符,而不是使用最新生成器生成的图像。本文对原始思想进行了讨论。缓冲区的大小由标志--pool_size控制。
visualizer.py包括几个可以显示/保存图像和打印/保存日志信息的函数。它使用Python库visdom进行显示,使用Python库dominate(包装在HTML中)创建带有图像的HTML文件。
util.py由简单的辅助函数组成,如tensor2im(将张量数组转换为numpy图像数组)、diagnose_network(计算并打印梯度的平均绝对值的平均值)和mkdirs(创建多个目录)。