一、pix2pixHD代码解析(train.py + test.py)
二、pix2pixHD代码解析(options设置)
三、pix2pixHD代码解析(dataset处理)
四、pix2pixHD代码解析(models搭建)
data_loader.py
##########################################################################
# 创建数据集加载主函数
##########################################################################
def CreateDataLoader(opt):
from data.custom_dataset_data_loader import CustomDatasetDataLoader
data_loader = CustomDatasetDataLoader()
print(data_loader.name()) # 返回的名字为“CustomDatasetDataLoader”
data_loader.initialize(opt) # 初始化参数
return data_loader
custom_dataset_data_loader.py
import torch.utils.data
from data.base_data_loader import BaseDataLoader
# 创建数据集
def CreateDataset(opt):
dataset = None
from data.aligned_dataset import AlignedDataset
dataset = AlignedDataset()
print("dataset [%s] was created" % (dataset.name())) # 打印数据集名字为‘AlignedDataset’
dataset.initialize(opt) # 初始化数据集参数
return dataset # 返回创建好的数据集
# 加载数据集
class CustomDatasetDataLoader(BaseDataLoader):
def name(self):
return 'CustomDatasetDataLoader'
def initialize(self, opt):
BaseDataLoader.initialize(self, opt) # 初始化参数
self.dataset = CreateDataset(opt) # 创建数据集
self.dataloader = torch.utils.data.DataLoader( # 加载创建好的数据集,并自定义相关参数
self.dataset,
batch_size=opt.batchSize,
shuffle=not opt.serial_batches,
num_workers=int(opt.nThreads))
def load_data(self):
return self.dataloader # 返回数据集
def __len__(self):
return min(len(self.dataset), self.opt.max_dataset_size) # 返回加载的数据集长度和一个epoch容许的加载最大容量
aligned_dataset.py
#############################################################################
# 数据读取的方式
#############################################################################
import os.path
from data.base_dataset import BaseDataset, get_params, get_transform, normalize
from data.image_folder import make_dataset
from PIL import Image
# 返回一个字典,里面由整理好的数据集:图片 + 类别
class AlignedDataset(BaseDataset): # init里面都是些路径的设置
def initialize(self, opt):
self.opt = opt
self.root = opt.dataroot
### input A (label maps) # 标签图的路径
dir_A = '_A' if self.opt.label_nc == 0 else '_label'
self.dir_A = os.path.join(opt.dataroot, opt.phase + dir_A) # './geometry' + 'train' + '_label'
### sort 是应用在 list 上的方法,sorted 可以对所有可迭代的对象进行排序操作。
# list 的 sort 方法返回的是对已经存在的列表进行操作;
# 而内建函数 sorted 方法返回的是一个新的 list,而不是在原来的基础上进行的操作
# (事实证明直接对string排序,与实际int值排序结果是不一样的,图片名并不是按照从小到大的顺序)
self.A_paths = sorted(make_dataset(self.dir_A)) # 返回self.dir_A下的图片路径列表
### input B (real images) # 真实图的路径
if opt.isTrain or opt.use_encoded_image:
dir_B = '_B' if self.opt.label_nc == 0 else '_img'
self.dir_B = os.path.join(opt.dataroot, opt.phase + dir_B)
self.B_paths = sorted(make_dataset(self.dir_B))
# self.B_paths = self.A_paths
### instance maps # 实例图的路径
if not opt.no_instance: # 如果no_instance为true,则不添加实例图
self.dir_inst = os.path.join(opt.dataroot, opt.phase + '_inst')
self.inst_paths = sorted(make_dataset(self.dir_inst))
# self.inst_paths = self.A_paths
### load precomputed instance-wise encoded features
if opt.load_features:
self.dir_feat = os.path.join(opt.dataroot, opt.phase + '_feat')
print('----------- loading features from %s ----------' % self.dir_feat)
self.feat_paths = sorted(make_dataset(self.dir_feat)) # 本文中没有train_feat图片
self.dataset_size = len(self.A_paths)
def __getitem__(self, index): # getitem里是具体的操作,是这个类的重点操作
### input A (label maps) # 读取标签图A
A_path = self.A_paths[index] # 获得图片路径
# A = Image.open(self.dir_A + '/' + A_path) # 先读取一张图片
A = Image.open(A_path)
params = get_params(self.opt, A.size) # 根据输入的opt和size,返回随机参数
if self.opt.label_nc == 0:
transform_A = get_transform(self.opt, params)
A_tensor = transform_A(A.convert('RGB'))
else:
transform_A = get_transform(self.opt, params, method=Image.NEAREST, normalize=False) # 图像变换
A_tensor = transform_A(A) * 255.0 # 对数据预处理,有经过to_tensor操作,再乘255
B_tensor = inst_tensor = feat_tensor = 0
### input B (real images) # 接着读入真实图像B
if self.opt.isTrain or self.opt.use_encoded_image:
B_path = self.B_paths[index]
# B = Image.open(self.dir_B + '/' + B_path).convert('RGB')
B = Image.open(B_path).convert('RGB')
transform_B = get_transform(self.opt, params)
B_tensor = transform_B(B)
### if using instance maps # 接着读入instance,后续还会处理成边缘图,和论文中描述一致。
if not self.opt.no_instance: # no_instance默认值为true
inst_path = self.inst_paths[index]
# inst = Image.open(self.dir_inst + '/' + inst_path)
inst = Image.open(inst_path)
inst_tensor = transform_A(inst) # 和semantic的处理方式一样 0-1
if self.opt.load_features: # 注意self.opt.load_features的作用是是否读取每个类别的预先计算的特征,论文中有10类,由聚类形成的。但默认是不执行的。我本人看论文对这一部分也是一知半解,以后有需求之后再研究。
feat_path = self.feat_paths[index]
feat = Image.open(feat_path).convert('RGB')
norm = normalize()
feat_tensor = norm(transform_A(feat))
input_dict = {'label': A_tensor, 'inst': inst_tensor, 'image': B_tensor,
'feat': feat_tensor, 'path': A_path}
return input_dict # 返回一个字典,记录了上述读取并经过处理的数据集。
def __len__(self):
return len(self.A_paths) // self.opt.batchSize * self.opt.batchSize
def name(self):
return 'AlignedDataset'
image_folder.py
###############################################################################
# Code from
# https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py
# Modified the original code so that it also loads images from the current
# directory as well as the subdirectories
# 获得指定目录下的图片路径 + 加载路径图片
###############################################################################
import torch.utils.data as data
from PIL import Image
import os
# 本程序支持的图片扩展名
IMG_EXTENSIONS = [
'.jpg', '.JPG', '.jpeg', '.JPEG',
'.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tiff'
]
def is_image_file(filename):
### any()函数用于判断给定的可迭代参数iterable是否全部为False,则返回False,如果有一个为True,则返回True。
# 元素除了是0、空、FALSE外都算TRUE。
# 函数等价于:
# def any(iterable):
# for element in iterable:
# if element:
# return True
# return False
return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
# 制作数据集:获得数据集的图片路径列表
def make_dataset(dir): # dir为数据集文件夹路径
images = [] # 创建空列表
assert os.path.isdir(dir), '%s is not a valid directory' % dir # 确认路径存在
### os.walk() 方法是一个简单易用的文件、目录遍历器,可以帮助我们高效的处理文件、目录方面的事情
# top -- 是你所要遍历的目录的地址, 返回的是一个三元组(root,dirs,files)。
# root 所指的是当前正在遍历的这个文件夹的本身的地址,和输入的os.walk(dir)种的dir一致
# dirs 是一个 list ,内容是该文件夹中所有的 目录 的名字(不包括子目录),若无则为[]
# files 同样是 list , 内容是该文件夹中所有的 文件 的名字(不包括子目录),若无则为[]
for root, _, fnames in sorted(os.walk(dir)): # fnames为文件中读取的照片文件
for fname in fnames:
if is_image_file(fname):
path = os.path.join(root, fname) # 将文件夹路径dir 和 图片名称fname 结合起来
images.append(path) # 将图片路径存放到image列表里
# temp = fname
# images.append(temp)
return images # 返回图片路径列表
def default_loader(path):
return Image.open(path).convert('RGB')
class ImageFolder(data.Dataset):
def __init__(self, root, transform=None, return_paths=False,
loader=default_loader):
imgs = make_dataset(root) # imgs为root目录下图片路径列表
if len(imgs) == 0: # 图片数量 = 0 报错
raise(RuntimeError("Found 0 images in: " + root + "\n"
"Supported image extensions are: " +
",".join(IMG_EXTENSIONS)))
self.root = root
self.imgs = imgs
self.transform = transform
self.return_paths = return_paths
self.loader = loader
def __getitem__(self, index):
path = self.imgs[index] # 获取指定图片路径
img = self.loader(path) # 加载图片
if self.transform is not None:
img = self.transform(img) # 图片进行变换
if self.return_paths:
return img, path # 返回图片和路径
else:
return img # 仅返回图片
def __len__(self):
return len(self.imgs) # 返回指定目录下图片数量
base_dataset.py
import torch.utils.data as data
from PIL import Image
import torchvision.transforms as transforms
import numpy as np
import random
class BaseDataset(data.Dataset):
def __init__(self):
super(BaseDataset, self).__init__()
def name(self):
return 'BaseDataset'
def initialize(self, opt):
pass
# 这个函数是根据用户指定的方式resize或者crop出合适大小的输入尺寸。
# size:输入图片的尺寸
def get_params(opt, size):
w, h = size
new_h = h
new_w = w
if opt.resize_or_crop == 'resize_and_crop':
# opt.loadSize为自己输入的尺寸,将图像缩放到这个大小
new_h = new_w = opt.loadSize # 将宽和高设置为同样大小
elif opt.resize_or_crop == 'scale_width_and_crop': # 我已在opt处设置为‘scale_width_and_crop’
new_w = opt.loadSize
new_h = opt.loadSize * h // w # 高度按照原图宽高比计算
x = random.randint(0, np.maximum(0, new_w - opt.fineSize)) # ???不明白此处的随机数什么意思
y = random.randint(0, np.maximum(0, new_h - opt.fineSize))
flip = random.random() > 0.5 # 随机数是否大于0.5,flip是bool型变量,此行代码意思为随机生成True或者False
return {'crop_pos': (x, y), 'flip': flip} # 最终的返回值,在data.aligned_dataset 45行,当作params传入了下方get_transform()函数
# 图像变换
def get_transform(opt, params, method=Image.BICUBIC, normalize=True):
transform_list = []
if 'resize' in opt.resize_or_crop: # 若opt.resize_or_crop中有'resize'
osize = [opt.loadSize, opt.loadSize]
transform_list.append(transforms.Scale(osize, method))
elif 'scale_width' in opt.resize_or_crop:
transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.loadSize, method)))
### lambda函数也叫匿名函数,即,函数没有具体的名称。先来看一个最简单例子:
# def f(x):
# return x**2
# print f(4)
#
# Python中使用lambda的话,写成这样:
# g = lambda x : x**2
# print g(4)
if 'crop' in opt.resize_or_crop:
# 使用transforms.Lambda封装其为transforms策略
transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.fineSize)))
if opt.resize_or_crop == 'none':
base = float(2 ** opt.n_downsample_global)
if opt.netG == 'local':
base *= (2 ** opt.n_local_enhancers)
transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base, method)))
if opt.isTrain and not opt.no_flip:
transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip'])))
transform_list += [transforms.ToTensor()]
if normalize:
transform_list += [transforms.Normalize((0.5, 0.5, 0.5), # mean和std均为0.5
(0.5, 0.5, 0.5))]
return transforms.Compose(transform_list)
def normalize():
return transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
def __make_power_2(img, base, method=Image.BICUBIC):
ow, oh = img.size
h = int(round(oh / base) * base)
w = int(round(ow / base) * base)
if (h == oh) and (w == ow):
return img
return img.resize((w, h), method)
def __scale_width(img, target_width, method=Image.BICUBIC):
ow, oh = img.size
if (ow == target_width):
return img
w = target_width
h = int(target_width * oh / ow)
return img.resize((w, h), method)
# 随机平移滑动裁剪
def __crop(img, pos, size):
ow, oh = img.size
x1, y1 = pos
tw = th = size # 输入的尺寸 opt.fineSize
if (ow > tw or oh > th):
return img.crop((x1, y1, x1 + tw, y1 + th)) # 随机裁剪,因为虽然每次裁剪测大小一样,但是起始点位置不一样
return img
def __flip(img, flip):
if flip:
return img.transpose(Image.FLIP_LEFT_RIGHT)
return img