下载数据集summer2winter_yosemite
,文件夹结构如下
summer2winter_yosemite
├─ testA 310幅256x256图像
├─ testB 239幅
├─ trainA 1232幅
└─ trainB 963幅
训练模型
python train.py --dataroot datasets/summer2winter_yosemite --name run01 --model cycle_gan
查看代码 train.py
首先解析TrainOptions
,同时会打印出opt
opt = TrainOptions().parse() # get training options
解析参数时会加上一些新的opt,如 CycleGAN中的add_argument,暂时没有从代码中看懂这个新增参数的机制
创建dataset
,类型为data.CustomDatasetDataLoader
dataset = create_dataset(opt) # create a dataset given opt.dataset_mode and other options
查看 CustomDatasetDataLoader类,继承自object
创建self.dataset
是UnalignedDataset
, AlignedDataset
, SingleDataset
, ColorizationDataset
中的一种,它们继承自BaseDataset
,由参数–dataset_mode决定
self.dataset = dataset_class(opt)
查看 UnalignedDataset类,比较重要的参数有
–max_dataset_size,如果觉得数据集中的图像太多,可以指定该参数来缩小数据集
–direction,指定两个domain的方向,默认为AtoB
–preprocess,指定图像的预处理方式,默认为resize_and_crop
,可选crop | scale_width | scale_width_and_crop | none
,值得注意的是,当选择none
时,也会将图像resize为2的幂
–load_size, --crop_size,图像载入时首先resize的尺寸,以及图像最终的尺寸
–serial_batches,影响两处地方,一是取B_img
时是按照index
来取还是随机取,二是构建DataLoader
时是否shuffle
图像归一化的参数mean=0.5, std=0.5
读取图像,然后转换到RGB,保证图像是3通道的
A_img = Image.open(A_path).convert('RGB')
B_img = Image.open(B_path).convert('RGB')
UnalignedDataset
的__getitem__
返回格式如下
return {'A': A, 'B': B, 'A_paths': A_path, 'B_paths': B_path}
回到CustomDatasetDataLoader
的__init__
,创建了self.dataset
之后,封装为DataLoader
self.dataloader = torch.utils.data.DataLoader(
self.dataset,
batch_size=opt.batch_size,
shuffle=not opt.serial_batches,
num_workers=int(opt.num_threads))
回到 train.py
model = create_model(opt) # create a model given opt.model and other options
执行上述语句,会打印如下信息
initialize network with normal
initialize network with normal
initialize network with normal
initialize network with normal
model [CycleGANModel] was created
在 create_model
方法中
instance = model(opt)
其中model
是CycleGANModel
, Pix2PixModel
, TestModel
, ColorizationModel
中的一种,它们均继承自BaseModel
,由参数 –model 指定
查看 CycleGANModel类
在__init__
中
# specify the training losses you want to print out. The training/test scripts will call
self.loss_names = ['D_A', 'G_A', 'cycle_A', 'idt_A', 'D_B', 'G_B', 'cycle_B', 'idt_B']
parser.add_argument('--no_dropout', action='store_true', help='no dropout for the generator')
这里感觉写了个bug,因为无论命令行是否添加--no_dropout
,opt.no_dropout
都为True