目的:大致看懂cycleGAN代码结构
参考:https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/docs/overview.md
目录
一、概览
train.py
test.py
data文件夹
models文件夹
options文件夹
util文件夹
二、train.py
三、test.py
四、data文件夹
4.1 __init__.py
4.2 base_dataset.py
4.3 image_folder.py
4.4 template_dataset.py
4.5 aligned_dataset.py
4.6 unaligned_dataset.py
4.7 single、clolorization dataset
五、models
六、options and util
用于模型训练
--model
: e.g., pix2pix
, cyclegan
, colorization
) and
different datasets (with option --dataset_mode
: e.g., aligned
, unaligned
, single
, colorization
用于模型测试
包含关于所有数据加载数据处理的程序。
模型相关的objective functions, optimizations, and network architectures.
训练,测试以及相关模型的选项
相关帮助函数的杂项汇总
train获取数据(data文件夹),创建模型(model),循环epoch然后在epoch内更新参数,存储网络
读取数据,创建模型,然后将数据送入模型进行test
包含着模型加载与处理的程序。
用于给train和test过程生成数据集
from data import create_dataset
dataset = create_dataset(opt) to create a dataset given the option opt.
https://docs.python.org/3/library/abc.html
get_transform
, __scale_width
), which can be later used in subclasses.用于运用abstract base class abc。
pytorch默认只从文件夹中读文件,作者可以从文件夹和子文件夹中读文件。
创建一个数据集的模板,以及详细的描述
用于加载样本对(主要用于pix2pix,对于我们的cycleGAN并无太大作用)
用于unpaired 数据集,用于cycleGAN,训练时trainA 和trainB 中应该放入domainA和domainB中的东西,test时也是这样。
/path/to/data/train
, which contains image pairs in the form of {A,B}. See here on how to prepare aligned datasets. During test time, you need to prepare a directory /path/to/data/test
as test data./path/to/data/trainA
and from domain B /path/to/data/trainB
respectively. Then you can train the model with the dataset flag --dataroot /path/to/data
. Similarly, you need to prepare two directories /path/to/data/testA
and /path/to/data/testB
during test time.--dataroot /path/to/data
. It can be used for generating CycleGAN results only for one side with the model option -model test
.--model colorization
).modules related to objective functions, optimizations, and network architectures.
更改models的顺序:
To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel.
You need to implement the following five functions:
-- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
--
--
--
--
In the function <__init__>, you need to define four lists:
-- self.loss_names (str list): specify the training losses that you want to plot and save.
-- self.model_names (str list): define networks used in our training.
-- self.visual_names (str list): specify the images that you want to display and save.
-- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an usage.