风格迁移1-07:Liquid Warping GAN(Impersonator)-源码无死角解析(2)-数据读取,预处理

以下链接是个人关于Liquid Warping GAN(Impersonator)-姿态迁移,所有见解,如有错误欢迎大家指出,我会第一时间纠正。有兴趣的朋友可以加微信:a944284742相互讨论技术。若是帮助到了你什么,一定要记得点赞!因为这是对我最大的鼓励。
风格迁移1-00:Liquid Warping GAN(Impersonator)-目录-史上最新无死角讲解

前言

根据上篇博客介绍的train.py,我们可以看到如下代码:

class Train(object):
    def __init__(self):
        # 对命令行参数进行解析
        self._opt = TrainOptions().parse()
        # 创建训练,测试数据迭代对象
        data_loader_train = CustomDatasetDataLoader(self._opt, is_for_train=True)
        data_loader_test = CustomDatasetDataLoader(self._opt, is_for_train=False)

        # 加载训练以及迭代数据
        self._dataset_train = data_loader_train.load_data()
        self._dataset_test = data_loader_test.load_data()

        # 获取训练以及迭代数据的长度,即每个epoch需要迭代多少次
        self._dataset_train_size = len(data_loader_train)
        self._dataset_test_size = len(data_loader_test)
        print('#train video clips = %d' % self._dataset_train_size)
        print('#test video clips = %d' % self._dataset_test_size)

现在我们就来看看数据迭代器CustomDatasetDataLoader的实现。

源码解析

首先,我们找到CustomDatasetDataLoader类的实现,其位于data/custom_dataset_data_loader.py,实现过程如下(如果不想看过程,只想知道迭代获得的是什么数据,直接翻到博客末尾即可):

class CustomDatasetDataLoader(object):
    def __init__(self, opt, is_for_train=True):
        self._opt = opt
        # 根据传入的参数,选择是否训练
        self._is_for_train = is_for_train
        # 根据传入的参数,指定加载相乘的数目
        self._num_threds = opt.n_threads_train if is_for_train else opt.n_threads_test
        # 创建对应的数据集迭代器
        self._create_dataset()

    def _create_dataset(self):
        # self._opt.dataset_mode默认为iPER,即数据的存储格式,为作者默认格式
        self._dataset = DatasetFactory.get_by_name(self._opt.dataset_mode, self._opt, self._is_for_train)
        self._dataloader = torch.utils.data.DataLoader(
            self._dataset,
            batch_size=self._opt.batch_size,
            shuffle=not self._opt.serial_batches,
            num_workers=int(self._num_threds),
            drop_last=True)

    def load_data(self):
        return self._dataloader

    def __len__(self):
        return len(self._dataset)

可以知道,其核心部分还在在于:

        # self._opt.dataset_mode默认为iPER,即数据的存储格式,为作者默认格式
        self._dataset = DatasetFactory.get_by_name(self._opt.dataset_mode, self._opt, self._is_for_train)

我们对其实现过程进行追踪:

class DatasetFactory(object):
    def __init__(self):
        pass
    # 指定该为一个静态方法
    @staticmethod
    def get_by_name(dataset_name, opt, is_for_train):
        """
        根据dataset_name参数,获得对应的数据迭代器,默认dataset_name = 'iPER'
        """
        if dataset_name == 'iPER':
            from data.imper_dataset import ImPerDataset
            dataset = ImPerDataset(opt, is_for_train)
		......
		......

继续追踪ImPerDataset的实现,可以看到其继承父类ImPerBaseDataset,那么先来看看这个类,其位于data/imper_dataset.py,实现过程如下:

class ImPerBaseDataset(DatasetBase):
    """
    该数据迭代器,每次都是获得一对图像
    """
    def __init__(self, opt, is_for_train):
        super(ImPerBaseDataset, self).__init__(opt, is_for_train)
        self._name = 'ImPerBaseDataset'

        self._intervals = opt.intervals

        # read dataset,从txt文件读取训练集或者测试集数据的路径,
        # 以及smpls(主要存储视频中人物的形状:如身高,胖瘦等等信息)中的数据
        self._read_dataset_paths()

    def __getitem__(self, index):
        # assert (index < self._dataset_size)

        # start_time = time.time()
        # get sample data,根据index获得一个短视频的所有信息,包含了:
        # 'images' = 该短视频的所有帧图片路径,假设共n张图片
        # 'cams'  = 每帧图片对应的摄像头参数,论文中的K,形状为[n,3]
        # 'thetas' = 论文中的 θ 参数,形状为[n,72]--人体的72个关节
        # 'betas' =  论文中的 β 参数,形状为[n, 10]--高矮瘦胖的描述信息
        # 'length' = 图片帧的总数目,即为前面假设的n
        v_info = self._vids_info[index % self._num_videos]

        # 对v_info进行解析,并且从该id对应的短视频中,随机读取两张图片的像素存储到images
        # smpls存储的是这两张图像对应的cams,thetas,betas 参数,并且连接了起来,即形状为[2,85]
        # 注意,在默认的训练中,该_load_pairs函数,已经被重写,略有区别
        images, smpls = self._load_pairs(v_info)

        # pack data
        sample = {
            'images': images,
            'smpls': smpls
        }

        # 改变图片形状,转化为张量,默认大小为[2,3,256,256]
        sample = self._transform(sample)
        # print(time.time() - start_time)

        return sample

    def __len__(self):
        return self._dataset_size

    def _read_dataset_paths(self):
        # 获得视频图像的目录,以及smpls信息的目录的路径
        self._root = self._opt.data_dir
        self._vids_dir = os.path.join(self._root, self._opt.images_folder)
        self._smpls_dir = os.path.join(self._root, self._opt.smpls_folder)

        # read video list
        self._num_videos = 0
        self._dataset_size = 0
        # 获得目录下所有文件的路径
        use_ids_filename = self._opt.train_ids_file if self._is_for_train else self._opt.test_ids_file
        use_ids_filepath = os.path.join(self._root, use_ids_filename)

        # 读取视频对应id的信息,训练时,use_ids_filepath为train.txt中的内容
        self._vids_info = self._read_vids_info(use_ids_filepath)

    def _read_vids_info(self, file_path):
        vids_info = []
        with open(file_path, 'r') as reader:

            lines = []
            for line in reader:
                line = line.rstrip()
                lines.append(line)

            total = len(lines)
            for i, line in enumerate(lines):
                # 获得一个子目录下所有图像的路径
                images_path = glob.glob(os.path.join(self._vids_dir, line, '*'))
                # 进行排序
                images_path.sort()
                # 加载每个子目录对应的pose_shape.pkl文件
                smpl_data = load_pickle_file(os.path.join(self._smpls_dir, line, 'pose_shape.pkl'))
                # 获得子目录对应的摄像头参数
                cams = smpl_data['cams']
                # kps_data = load_pickle_file(os.path.join(self._smpls_dir, line, 'kps.pkl'))
                # kps = (kps_data['kps'] + 1) / 2.0 * 1024

                # 判断摄像头参数的数目,和该目录下的图像数目,是否相同
                assert len(images_path) == len(cams), '{} != {}'.format(len(images_path), len(cams))

                # 该处存储的是一个子目录下所有图片的信息
                info = {
                    'images': images_path,
                    'cams': cams,
                    'thetas': smpl_data['pose'],
                    'betas': smpl_data['shape'],
                    'length': len(images_path)
                }
                # vids_info存储所有子目录的信息
                vids_info.append(info)

                # 默认self._intervals = 10,获取图片里的时候并不是每一帧都获取,
                # 而是间隔 self._intervals 帧后再获取
                self._dataset_size += info['length'] // self._intervals
                # self._dataset_size += info['length']

                # 记录视频的数目
                self._num_videos += 1
                print('loading video = {}, {} / {}'.format(line, i, total))


                if self._opt.debug:
                    if i > 1:
                        break

        return vids_info

    @property
    def video_info(self):
        return self._vids_info

    def _load_pairs(self, vid_info):
        length = vid_info['length']
        pair_ids = np.random.choice(length, size=2, replace=False)

        smpls = np.concatenate((vid_info['cams'][pair_ids],
                                vid_info['thetas'][pair_ids],
                                vid_info['betas'][pair_ids]), axis=1)

        images = []
        images_paths = vid_info['images']
        for t in pair_ids:
            image_path = images_paths[t]
            image = cv_utils.read_cv2_img(image_path)

            images.append(image)

        return images, smpls

    def _create_transform(self):
        transform_list = [
            ImageTransformer(output_size=self._opt.image_size),
            ToTensor()]
        self._transform = transforms.Compose(transform_list)

知道了父类的实现,我们再回过头来看看class ImPerDataset(ImPerBaseDataset)的实现过程就很简单了,如下:

class ImPerDataset(ImPerBaseDataset):

    def __init__(self, opt, is_for_train):
        super(ImPerDataset, self).__init__(opt, is_for_train)
        self._name = 'ImPerDataset'


    def _load_pairs(self, vid_info):
        """
        该函数是对父类函数的重写,
        """
        length = vid_info['length']
        # 选择两个配对的图片帧
        start = np.random.randint(0, 15)
        end = np.random.randint(0, length)
        pair_ids = np.array([start, end], dtype=np.int32)

        # 把对应的信息连接起来,想知道细节,可以看父类中该函数的实现
        smpls = np.concatenate((vid_info['cams'][pair_ids],
                                vid_info['thetas'][pair_ids],
                                vid_info['betas'][pair_ids]), axis=1)

        # 读取配对的图像像素
        images = []
        images_paths = vid_info['images']
        for t in pair_ids:
            image_path = images_paths[t]
            image = cv_utils.read_cv2_img(image_path)

            images.append(image)

        return images, smpls

总结

其实总的来说,就是随机选取一个短视频(每个短视频都是对一个人进行录像)的两帧照片像素,并且获每张对应的cams,thetas,betas参数:

        'cams'  = 每帧图片对应的摄像头参数,论文中的K,形状为[2,3]
        'thetas' = 论文中的 θ 参数,形状为[272]--人体的72个关节
        'betas' =  论文中的 β 参数,形状为[2, 10]--高矮瘦胖的描述信息

并且这3类参数,或链接层一起,变成一个形状为 [2,85] 的参数smpls。

你可能感兴趣的:(风格迁移)