【Note】pytorch-CycleGAN-and-pix2pix

1. data部分

下载数据集summer2winter_yosemite,文件夹结构如下

summer2winter_yosemite
		├─ testA	310幅256x256图像
		├─ testB	239幅
		├─ trainA	1232幅
		└─ trainB	963幅

训练模型

python train.py --dataroot datasets/summer2winter_yosemite --name run01 --model cycle_gan

查看代码 train.py

首先解析TrainOptions,同时会打印出opt

opt = TrainOptions().parse()   # get training options

解析参数时会加上一些新的opt,如 CycleGAN中的add_argument,暂时没有从代码中看懂这个新增参数的机制

创建dataset,类型为data.CustomDatasetDataLoader

dataset = create_dataset(opt)  # create a dataset given opt.dataset_mode and other options

查看 CustomDatasetDataLoader类,继承自object

创建self.datasetUnalignedDataset, AlignedDataset, SingleDataset, ColorizationDataset中的一种,它们继承自BaseDataset,由参数–dataset_mode决定

self.dataset = dataset_class(opt)

查看 UnalignedDataset类,比较重要的参数有
–max_dataset_size,如果觉得数据集中的图像太多,可以指定该参数来缩小数据集
–direction,指定两个domain的方向,默认为AtoB
–preprocess,指定图像的预处理方式,默认为resize_and_crop,可选crop | scale_width | scale_width_and_crop | none,值得注意的是,当选择none时,也会将图像resize为2的幂
–load_size, --crop_size,图像载入时首先resize的尺寸,以及图像最终的尺寸
–serial_batches,影响两处地方,一是取B_img时是按照index来取还是随机取,二是构建DataLoader时是否shuffle

图像归一化的参数mean=0.5, std=0.5

读取图像,然后转换到RGB,保证图像是3通道的

A_img = Image.open(A_path).convert('RGB')
B_img = Image.open(B_path).convert('RGB')

UnalignedDataset__getitem__返回格式如下

return {'A': A, 'B': B, 'A_paths': A_path, 'B_paths': B_path}

回到CustomDatasetDataLoader__init__,创建了self.dataset之后,封装为DataLoader

self.dataloader = torch.utils.data.DataLoader(
	self.dataset,
	batch_size=opt.batch_size,
	shuffle=not opt.serial_batches,
	num_workers=int(opt.num_threads))

2. model部分

回到 train.py

model = create_model(opt)      # create a model given opt.model and other options

执行上述语句,会打印如下信息

initialize network with normal
initialize network with normal
initialize network with normal
initialize network with normal
model [CycleGANModel] was created

create_model方法中

instance = model(opt)

其中modelCycleGANModel, Pix2PixModel, TestModel, ColorizationModel中的一种,它们均继承自BaseModel,由参数 –model 指定

查看 CycleGANModel类

__init__

# specify the training losses you want to print out. The training/test scripts will call 
self.loss_names = ['D_A', 'G_A', 'cycle_A', 'idt_A', 'D_B', 'G_B', 'cycle_B', 'idt_B']

其它

parser.add_argument('--no_dropout', action='store_true', help='no dropout for the generator')

这里感觉写了个bug,因为无论命令行是否添加--no_dropoutopt.no_dropout都为True

你可能感兴趣的:(读书笔记)