三、pix2pixHD代码解析(dataset处理)

pix2pixHD代码解析

一、pix2pixHD代码解析(train.py + test.py)
二、pix2pixHD代码解析(options设置)
三、pix2pixHD代码解析(dataset处理)
四、pix2pixHD代码解析(models搭建)

三、pix2pixHD代码解析(dataset处理)

data_loader.py

##########################################################################
# 创建数据集加载主函数
##########################################################################
def CreateDataLoader(opt):
    from data.custom_dataset_data_loader import CustomDatasetDataLoader
    data_loader = CustomDatasetDataLoader()
    print(data_loader.name())                                             # 返回的名字为“CustomDatasetDataLoader”
    data_loader.initialize(opt)                                           # 初始化参数
    return data_loader

custom_dataset_data_loader.py

import torch.utils.data
from data.base_data_loader import BaseDataLoader


# 创建数据集
def CreateDataset(opt):
    dataset = None
    from data.aligned_dataset import AlignedDataset
    dataset = AlignedDataset()

    print("dataset [%s] was created" % (dataset.name()))               # 打印数据集名字为‘AlignedDataset’
    dataset.initialize(opt)                                            # 初始化数据集参数
    return dataset                                                     # 返回创建好的数据集


# 加载数据集
class CustomDatasetDataLoader(BaseDataLoader):
    def name(self):
        return 'CustomDatasetDataLoader'

    def initialize(self, opt):
        BaseDataLoader.initialize(self, opt)                           # 初始化参数
        self.dataset = CreateDataset(opt)                              # 创建数据集
        self.dataloader = torch.utils.data.DataLoader(                 # 加载创建好的数据集,并自定义相关参数
            self.dataset,
            batch_size=opt.batchSize,
            shuffle=not opt.serial_batches,
            num_workers=int(opt.nThreads))

    def load_data(self):
        return self.dataloader                                         # 返回数据集

    def __len__(self):
        return min(len(self.dataset), self.opt.max_dataset_size)       # 返回加载的数据集长度和一个epoch容许的加载最大容量

aligned_dataset.py


#############################################################################
# 数据读取的方式
#############################################################################

import os.path
from data.base_dataset import BaseDataset, get_params, get_transform, normalize
from data.image_folder import make_dataset
from PIL import Image


# 返回一个字典,里面由整理好的数据集:图片 + 类别
class AlignedDataset(BaseDataset):                                           # init里面都是些路径的设置
    def initialize(self, opt):
        self.opt = opt
        self.root = opt.dataroot    

        ### input A (label maps)                                             # 标签图的路径
        dir_A = '_A' if self.opt.label_nc == 0 else '_label'
        self.dir_A = os.path.join(opt.dataroot, opt.phase + dir_A)           # './geometry' + 'train' + '_label'
        ### sort 是应用在 list 上的方法,sorted 可以对所有可迭代的对象进行排序操作。
        # list 的 sort 方法返回的是对已经存在的列表进行操作;
        # 而内建函数 sorted 方法返回的是一个新的 list,而不是在原来的基础上进行的操作
        # (事实证明直接对string排序,与实际int值排序结果是不一样的,图片名并不是按照从小到大的顺序)
        self.A_paths = sorted(make_dataset(self.dir_A))                      # 返回self.dir_A下的图片路径列表

        ### input B (real images)                                            # 真实图的路径
        if opt.isTrain or opt.use_encoded_image:
            dir_B = '_B' if self.opt.label_nc == 0 else '_img'
            self.dir_B = os.path.join(opt.dataroot, opt.phase + dir_B)  
            self.B_paths = sorted(make_dataset(self.dir_B))
            # self.B_paths = self.A_paths

        ### instance maps                                                    # 实例图的路径
        if not opt.no_instance:                                              # 如果no_instance为true,则不添加实例图
            self.dir_inst = os.path.join(opt.dataroot, opt.phase + '_inst')
            self.inst_paths = sorted(make_dataset(self.dir_inst))
            # self.inst_paths = self.A_paths

        ### load precomputed instance-wise encoded features
        if opt.load_features:                              
            self.dir_feat = os.path.join(opt.dataroot, opt.phase + '_feat')
            print('----------- loading features from %s ----------' % self.dir_feat)
            self.feat_paths = sorted(make_dataset(self.dir_feat))            # 本文中没有train_feat图片

        self.dataset_size = len(self.A_paths) 
      
    def __getitem__(self, index):                                            # getitem里是具体的操作,是这个类的重点操作
        ### input A (label maps)                                             # 读取标签图A
        A_path = self.A_paths[index]                                         # 获得图片路径
        # A = Image.open(self.dir_A + '/' + A_path)                                               # 先读取一张图片
        A = Image.open(A_path)
        params = get_params(self.opt, A.size)                                # 根据输入的opt和size,返回随机参数
        if self.opt.label_nc == 0:
            transform_A = get_transform(self.opt, params)
            A_tensor = transform_A(A.convert('RGB'))
        else:
            transform_A = get_transform(self.opt, params, method=Image.NEAREST, normalize=False)  # 图像变换
            A_tensor = transform_A(A) * 255.0                                # 对数据预处理,有经过to_tensor操作,再乘255

        B_tensor = inst_tensor = feat_tensor = 0
        ### input B (real images)                                            # 接着读入真实图像B
        if self.opt.isTrain or self.opt.use_encoded_image:
            B_path = self.B_paths[index]
            # B = Image.open(self.dir_B + '/' + B_path).convert('RGB')
            B = Image.open(B_path).convert('RGB')
            transform_B = get_transform(self.opt, params)      
            B_tensor = transform_B(B)

        ### if using instance maps                                           # 接着读入instance,后续还会处理成边缘图,和论文中描述一致。
        if not self.opt.no_instance:                                         # no_instance默认值为true
            inst_path = self.inst_paths[index]
            # inst = Image.open(self.dir_inst + '/' + inst_path)
            inst = Image.open(inst_path)
            inst_tensor = transform_A(inst)                                  # 和semantic的处理方式一样  0-1

            if self.opt.load_features:                                       # 注意self.opt.load_features的作用是是否读取每个类别的预先计算的特征,论文中有10类,由聚类形成的。但默认是不执行的。我本人看论文对这一部分也是一知半解,以后有需求之后再研究。
                feat_path = self.feat_paths[index]            
                feat = Image.open(feat_path).convert('RGB')
                norm = normalize()
                feat_tensor = norm(transform_A(feat))

        input_dict = {'label': A_tensor, 'inst': inst_tensor, 'image': B_tensor, 
                      'feat': feat_tensor, 'path': A_path}

        return input_dict                                                    # 返回一个字典,记录了上述读取并经过处理的数据集。


    def __len__(self):
        return len(self.A_paths) // self.opt.batchSize * self.opt.batchSize

    def name(self):
        return 'AlignedDataset'

image_folder.py

###############################################################################
# Code from
# https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py
# Modified the original code so that it also loads images from the current
# directory as well as the subdirectories
# 获得指定目录下的图片路径 + 加载路径图片
###############################################################################
import torch.utils.data as data
from PIL import Image
import os

# 本程序支持的图片扩展名
IMG_EXTENSIONS = [
    '.jpg', '.JPG', '.jpeg', '.JPEG',
    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tiff'
]


def is_image_file(filename):
    ### any()函数用于判断给定的可迭代参数iterable是否全部为False,则返回False,如果有一个为True,则返回True。
    # 元素除了是0、空、FALSE外都算TRUE。
    # 函数等价于:
    # def any(iterable):
    #     for element in iterable:
    #         if element:
    #             return True
    #     return False
    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)


# 制作数据集:获得数据集的图片路径列表
def make_dataset(dir):                                                 # dir为数据集文件夹路径
    images = []                                                        # 创建空列表
    assert os.path.isdir(dir), '%s is not a valid directory' % dir     # 确认路径存在

    ### os.walk() 方法是一个简单易用的文件、目录遍历器,可以帮助我们高效的处理文件、目录方面的事情
    # top -- 是你所要遍历的目录的地址, 返回的是一个三元组(root,dirs,files)。
    # root 所指的是当前正在遍历的这个文件夹的本身的地址,和输入的os.walk(dir)种的dir一致
    # dirs 是一个 list ,内容是该文件夹中所有的 目录 的名字(不包括子目录),若无则为[]
    # files 同样是 list , 内容是该文件夹中所有的 文件 的名字(不包括子目录),若无则为[]
    for root, _, fnames in sorted(os.walk(dir)):                       # fnames为文件中读取的照片文件
        for fname in fnames:
            if is_image_file(fname):
                path = os.path.join(root, fname)                       # 将文件夹路径dir 和 图片名称fname 结合起来
                images.append(path)                                    # 将图片路径存放到image列表里
                # temp = fname
                # images.append(temp)
    return images                                                      # 返回图片路径列表


def default_loader(path):
    return Image.open(path).convert('RGB')


class ImageFolder(data.Dataset):

    def __init__(self, root, transform=None, return_paths=False,
                 loader=default_loader):
        imgs = make_dataset(root)                                      # imgs为root目录下图片路径列表
        if len(imgs) == 0:                                             # 图片数量 = 0 报错
            raise(RuntimeError("Found 0 images in: " + root + "\n"
                               "Supported image extensions are: " +
                               ",".join(IMG_EXTENSIONS)))

        self.root = root
        self.imgs = imgs
        self.transform = transform
        self.return_paths = return_paths
        self.loader = loader

    def __getitem__(self, index):
        path = self.imgs[index]                                        # 获取指定图片路径
        img = self.loader(path)                                        # 加载图片
        if self.transform is not None:
            img = self.transform(img)                                  # 图片进行变换
        if self.return_paths:
            return img, path                                           # 返回图片和路径
        else:
            return img                                                 # 仅返回图片

    def __len__(self):
        return len(self.imgs)                                          # 返回指定目录下图片数量

base_dataset.py

import torch.utils.data as data
from PIL import Image
import torchvision.transforms as transforms
import numpy as np
import random

class BaseDataset(data.Dataset):
    def __init__(self):
        super(BaseDataset, self).__init__()

    def name(self):
        return 'BaseDataset'

    def initialize(self, opt):
        pass


# 这个函数是根据用户指定的方式resize或者crop出合适大小的输入尺寸。
# size:输入图片的尺寸
def get_params(opt, size):
    w, h = size
    new_h = h
    new_w = w
    if opt.resize_or_crop == 'resize_and_crop':
        # opt.loadSize为自己输入的尺寸,将图像缩放到这个大小
        new_h = new_w = opt.loadSize                                     # 将宽和高设置为同样大小
    elif opt.resize_or_crop == 'scale_width_and_crop':                   # 我已在opt处设置为‘scale_width_and_crop’
        new_w = opt.loadSize
        new_h = opt.loadSize * h // w                                    # 高度按照原图宽高比计算

    x = random.randint(0, np.maximum(0, new_w - opt.fineSize))           # ???不明白此处的随机数什么意思
    y = random.randint(0, np.maximum(0, new_h - opt.fineSize))
    
    flip = random.random() > 0.5                                         # 随机数是否大于0.5,flip是bool型变量,此行代码意思为随机生成True或者False
    return {'crop_pos': (x, y), 'flip': flip}                            # 最终的返回值,在data.aligned_dataset 45行,当作params传入了下方get_transform()函数


# 图像变换
def get_transform(opt, params, method=Image.BICUBIC, normalize=True):
    transform_list = []
    if 'resize' in opt.resize_or_crop:                                   # 若opt.resize_or_crop中有'resize'
        osize = [opt.loadSize, opt.loadSize]
        transform_list.append(transforms.Scale(osize, method))   
    elif 'scale_width' in opt.resize_or_crop:
        transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.loadSize, method)))

    ### lambda函数也叫匿名函数,即,函数没有具体的名称。先来看一个最简单例子:
    # def f(x):
    #   return x**2
    # print f(4)
    #
    # Python中使用lambda的话,写成这样:
    # g = lambda x : x**2
    # print g(4)

    if 'crop' in opt.resize_or_crop:
        # 使用transforms.Lambda封装其为transforms策略
        transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.fineSize)))

    if opt.resize_or_crop == 'none':
        base = float(2 ** opt.n_downsample_global)
        if opt.netG == 'local':
            base *= (2 ** opt.n_local_enhancers)
        transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base, method)))

    if opt.isTrain and not opt.no_flip:
        transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip'])))

    transform_list += [transforms.ToTensor()]

    if normalize:
        transform_list += [transforms.Normalize((0.5, 0.5, 0.5),         # mean和std均为0.5
                                                (0.5, 0.5, 0.5))]
    return transforms.Compose(transform_list)

def normalize():    
    return transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))

def __make_power_2(img, base, method=Image.BICUBIC):
    ow, oh = img.size        
    h = int(round(oh / base) * base)
    w = int(round(ow / base) * base)
    if (h == oh) and (w == ow):
        return img
    return img.resize((w, h), method)

def __scale_width(img, target_width, method=Image.BICUBIC):
    ow, oh = img.size
    if (ow == target_width):
        return img    
    w = target_width
    h = int(target_width * oh / ow)    
    return img.resize((w, h), method)

# 随机平移滑动裁剪
def __crop(img, pos, size):
    ow, oh = img.size
    x1, y1 = pos
    tw = th = size                                                       # 输入的尺寸 opt.fineSize
    if (ow > tw or oh > th):        
        return img.crop((x1, y1, x1 + tw, y1 + th))                      # 随机裁剪,因为虽然每次裁剪测大小一样,但是起始点位置不一样
    return img

def __flip(img, flip):
    if flip:
        return img.transpose(Image.FLIP_LEFT_RIGHT)
    return img

你可能感兴趣的:(PyTorch,GAN)