Dual-Modality Data Loading in PyTorch

    • Dual-modality data fusion
    • torch.utils.data.dataloader
    • Dual-modality data loading

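Dual-modality data fusion

Whether you fuse two modalities or many, data loading is an essential step. When fusing camera images with LiDAR reflectance projection images or with infrared images, for example, the images of the different modalities fed into the network must correspond one-to-one; otherwise the fusion is meaningless. This post explains how to load dual-modality data in PyTorch, and I hope it helps anyone facing the same problem.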
[Figure 1: camera image from the KITTI dataset]
[Figure 2: LiDAR reflectance projection image from the KITTI dataset]

torch.utils.data.dataloader

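First, a brief introduction to torch.utils.data.dataloader. It is PyTorch's main interface for loading data: it wraps a dataset and, according to settings such as batch_size, whether to shuffle, and the sampler, packs the samples into batches of tensors that serve as the network's input.
Source of torch.utils.data.dataloader on GitHub: https://github.com/pytorch/pytorch/blob/master/torch/utils/data/dataloader.py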

class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
                                  batch_sampler=None, num_workers=0, collate_fn=None,
                                  pin_memory=False, drop_last=False, timeout=0,
                                  worker_init_fn=None, multiprocessing_context=None)

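The attributes that control the loading order are shuffle and sampler. shuffle indicates whether the dataset order is shuffled during loading; it defaults to False (no shuffling). Shuffling makes the samples in each batch more independent of one another, so shuffle is usually set to True. sampler is the DataLoader's sampler and defines the rule by which the dataset is sampled; it defaults to None. If a sampler is specified, shuffle must be False.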
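The mutual exclusion between sampler and shuffle can be checked directly. The following is a minimal sketch, assuming a toy TensorDataset of integers in place of real images: DataLoader refuses a sampler combined with shuffle=True when it is constructed, while a sampler with shuffle=False is accepted.

```python
import torch
from torch.utils.data import DataLoader, RandomSampler, TensorDataset

# Toy dataset of six samples standing in for real images
dataset = TensorDataset(torch.arange(6))

# Passing a sampler together with shuffle=True is rejected at construction time
try:
    DataLoader(dataset, batch_size=2, sampler=RandomSampler(dataset), shuffle=True)
    rejected = False
except ValueError as err:
    rejected = True
    print("ValueError:", err)

# With shuffle=False, the sampler alone decides the order
loader = DataLoader(dataset, batch_size=2, sampler=RandomSampler(dataset), shuffle=False)
print([batch[0].tolist() for batch in loader])
```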

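For dual-modality loading, if shuffle=True, data_loader1 and data_loader2 each shuffle their data independently, so the images of the two modalities no longer correspond one-to-one. If shuffle=False, data_loader1 and data_loader2 load the data in the same order and the pairs do correspond, but the order is then identical in every epoch and the samples lose their independence.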

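import torch.utils.data as data

# dataset1, dataset2, args and detection_collate are defined elsewhere in the training script.
# With shuffle=True, each loader draws its own independent random order.
data_loader1 = data.DataLoader(dataset1, args.batch_size,
                              num_workers=args.num_workers,
                              shuffle=True, collate_fn=detection_collate,
                              pin_memory=True)
data_loader2 = data.DataLoader(dataset2, args.batch_size,
                              num_workers=args.num_workers,
                              shuffle=True, collate_fn=detection_collate,
                              pin_memory=True)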
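The misalignment caused by two independently shuffled loaders is easy to reproduce. This is a minimal sketch, assuming two toy TensorDatasets in which sample i simply holds the value i (standing in for the i-th camera image and the i-th LiDAR projection); zipping the loaders and comparing the ids shows how many pairs no longer match.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Two toy "modalities": sample i in each dataset holds the value i
camera = TensorDataset(torch.arange(20))
lidar = TensorDataset(torch.arange(20))

# Each loader with shuffle=True draws its own, independent permutation
loader1 = DataLoader(camera, batch_size=3, shuffle=True)
loader2 = DataLoader(lidar, batch_size=3, shuffle=True)

mismatched = 0
for (ids1,), (ids2,) in zip(loader1, loader2):
    # Count positions where the two loaders delivered different samples
    mismatched += int((ids1 != ids2).sum())
print("mismatched pairs in one epoch:", mismatched)
```

With shuffle=False in both loaders the count would be zero, but every epoch would then repeat the same order.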

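The goal of dual-modality loading is to shuffle the loading order while keeping the two modalities matched, so changing the shuffle attribute alone cannot meet the requirement.
Next, consider DataLoader's sampler attribute, which is set with a subclass of torch.utils.data.sampler.Sampler. Two subclasses are relevant here, and they correspond to the two shuffle settings: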

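class torch.utils.data.sampler.SequentialSampler(data_source)
Samples elements sequentially, always in the same order.
Parameters: data_source (Dataset): the dataset to sample from.

class torch.utils.data.sampler.RandomSampler(data_source)
Samples elements in random order.
Parameters: data_source (Dataset): the dataset to sample from.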

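torch.utils.data.sampler.SequentialSampler samples the dataset in order and always returns an iterator over 0 to len(dataset) - 1; this is what DataLoader uses when shuffle=False. torch.utils.data.sampler.RandomSampler samples the dataset randomly and returns an iterator over a shuffled order; this is what DataLoader uses when shuffle=True.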
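The difference between the two samplers is visible by listing the indices they yield. A minimal sketch, assuming a plain Python list as data_source (both samplers only need len() of it):

```python
from torch.utils.data import RandomSampler, SequentialSampler

# Any object with __len__ can serve as data_source for these samplers
data_source = list(range(8))

seq = list(SequentialSampler(data_source))
rnd = list(RandomSampler(data_source))
print(seq)  # [0, 1, 2, 3, 4, 5, 6, 7]
print(rnd)  # a permutation of 0..7 that changes from run to run
```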

My first attempt looked like this:

import torch

# A single RandomSampler object shared by both loaders
sampler = torch.utils.data.sampler.RandomSampler(dataset1)
data_loader1 = torch.utils.data.DataLoader(dataset1, args.batch_size,
                                  num_workers=args.num_workers, sampler=sampler,
                                  shuffle=False, collate_fn=detection_collate,
                                  pin_memory=True)
data_loader2 = torch.utils.data.DataLoader(dataset2, args.batch_size,
                                  num_workers=args.num_workers, sampler=sampler,
                                  shuffle=False, collate_fn=detection_collate,
                                  pin_memory=True)

I printed sampler before creating data_loader1 and before creating data_loader2, and both outputs were identical, similar to:

<torch.utils.data.sampler.RandomSampler object at 0x7ffacd0a87f0>

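This led me to believe that data_loader1 and data_loader2 would sample in the same order. In practice, however, the images of the two modalities did not match at all! I then printed the contents of the sampler with print(list(sampler)) twice and found that the two lists were different.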

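Looking at the source of torch.utils.data.sampler (https://github.com/pytorch/pytorch/blob/master/torch/utils/data/sampler.py), RandomSampler's __iter__ essentially returns iter(torch.randperm(n).tolist()), where n = len(self.data_source) is the number of samples in the dataset. In other words, it shuffles the dataset indices and then wraps them in an iterator.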

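This means torch.randperm(n) is not called once when the sampler is constructed, but every time the sampler is iterated: each call to iter(sampler), whether from list(sampler) or from a DataLoader starting a pass over the data, shuffles the dataset indices anew. That is why the two printed list(sampler) results differ, and why two loaders sharing one sampler object still receive different index orders.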

# Excerpt of torch/utils/data/sampler.py from an older PyTorch 1.x release
import torch


class Sampler(object):
    r"""Base class for all Samplers.
    Every Sampler subclass has to provide an :meth:`__iter__` method, providing a
    way to iterate over indices of dataset elements, and a :meth:`__len__` method
    that returns the length of the returned iterators.
    .. note:: The :meth:`__len__` method isn't strictly required by
              :class:`~torch.utils.data.DataLoader`, but is expected in any
              calculation involving the length of a :class:`~torch.utils.data.DataLoader`.
    """
    # (remaining Sampler methods omitted)


class RandomSampler(Sampler):

    def __init__(self, data_source, replacement=False, num_samples=None):
        self.data_source = data_source
        self.replacement = replacement
        self._num_samples = num_samples

        if not isinstance(self.replacement, bool):
            raise ValueError("replacement should be a boolean value, but got "
                             "replacement={}".format(self.replacement))

        if self._num_samples is not None and not replacement:
            raise ValueError("With replacement=False, num_samples should not be specified, "
                             "since a random permute will be performed.")

        if not isinstance(self.num_samples, int) or self.num_samples <= 0:
            raise ValueError("num_samples should be a positive integer "
                             "value, but got num_samples={}".format(self.num_samples))

    @property
    def num_samples(self):
        # dataset size might change at runtime
        if self._num_samples is None:
            return len(self.data_source)
        return self._num_samples

    def __iter__(self):
        n = len(self.data_source)
        if self.replacement:
            return iter(torch.randint(high=n, size=(self.num_samples,), dtype=torch.int64).tolist())
        # A new permutation is drawn on every call to __iter__,
        # i.e. every time the sampler is iterated
        return iter(torch.randperm(n).tolist())

    def __len__(self):
        return self.num_samples
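The effect of drawing the permutation inside __iter__ can be observed with the library's RandomSampler. This minimal sketch assumes a toy TensorDataset of ten integers: listing the same sampler twice gives two independent permutations, and two DataLoaders that share the sampler each trigger their own __iter__ call, so in all but a rare coincidence they deliver different orders.

```python
import torch
from torch.utils.data import DataLoader, RandomSampler, TensorDataset

dataset = TensorDataset(torch.arange(10))
sampler = RandomSampler(dataset)

# Each list(...) call invokes sampler.__iter__, which draws a new permutation
first = list(sampler)
second = list(sampler)
print(first)
print(second)

# Each DataLoader pass also calls iter(sampler), so sharing the object
# does not make the two loaders agree on the order
loader1 = DataLoader(dataset, batch_size=3, sampler=sampler)
loader2 = DataLoader(dataset, batch_size=3, sampler=sampler)
order1 = torch.cat([batch[0] for batch in loader1]).tolist()
order2 = torch.cat([batch[0] for batch in loader2]).tolist()
print(order1)
print(order2)
```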

Dual-modality data loading

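The samplers provided by torch.utils.data.sampler therefore cannot be used as they are: every iteration reshuffles the indices, so the images of the two modalities fed into the network do not match. The fix is a small modification of torch.utils.data.sampler.RandomSampler.

Since RandomSampler calls torch.randperm(n) on every iteration, and the only purpose of that call is to produce shuffled dataset indices, we can instead hand the sampler's iterator an index list that has already been shuffled.

First, define your own RandomSampler class by copying the code of torch.utils.data.sampler.RandomSampler, inherit from the base class torch.utils.data.sampler.Sampler, and make __iter__ return iter(s), where s is the shuffled list of dataset indices: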

import torch
import torch.utils.data as data


class RandomSampler(data.sampler.Sampler):
    # Copy of torch.utils.data.sampler.RandomSampler, except that without
    # replacement __iter__ returns the pre-shuffled global index list s
    # instead of calling torch.randperm(n) on every iteration

    def __init__(self, data_source, replacement=False, num_samples=None):
        self.data_source = data_source
        self.replacement = replacement
        self._num_samples = num_samples

        if not isinstance(self.replacement, bool):
            raise ValueError("replacement should be a boolean value, but got "
                             "replacement={}".format(self.replacement))

        if self._num_samples is not None and not replacement:
            raise ValueError("With replacement=False, num_samples should not be specified, "
                             "since a random permute will be performed.")

        if not isinstance(self.num_samples, int) or self.num_samples <= 0:
            raise ValueError("num_samples should be a positive integer "
                             "value, but got num_samples={}".format(self.num_samples))

    @property
    def num_samples(self):
        # dataset size might change at runtime
        if self._num_samples is None:
            return len(self.data_source)
        return self._num_samples

    def __iter__(self):
        n = len(self.data_source)
        if self.replacement:
            return iter(torch.randint(high=n, size=(self.num_samples,), dtype=torch.int64).tolist())
        # s is the module-level list of shuffled indices shared by every loader
        return iter(s)

    def __len__(self):
        return self.num_samples

data.DataLoader代码段中,首先定义一个全局变量s,然后把shuffle后的数据集序号赋值给s,接着调用刚刚定义的RandomSampler类,并把输出sampler送到data.DataLoader中:

# s holds one shuffled list of dataset indices; declaring it global lets
# RandomSampler.__iter__ (defined at module level) read it when this code
# runs inside a function such as the training routine
global s
s = torch.randperm(len(dataset1)).tolist()
print(s)
sampler = RandomSampler(dataset1)
data_loader1 = data.DataLoader(dataset1, args.batch_size,   # load data through DataLoader
                              num_workers=args.num_workers, sampler=sampler,
                              shuffle=False, collate_fn=detection_collate,
                              pin_memory=True)
data_loader2 = data.DataLoader(dataset2, args.batch_size,
                              num_workers=args.num_workers, sampler=sampler,
                              shuffle=False, collate_fn=detection_collate,
                              pin_memory=True)
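Because s is computed only once, both loaders above replay the same index order in every epoch. The following is a minimal sketch of a variant that keeps the two loaders aligned but reshuffles once per epoch, assuming toy TensorDatasets in place of the KITTI data. SharedIndexSampler and its reshuffle method are hypothetical names, not PyTorch APIs: the sampler stores the index list it replays instead of reading a global, and the list is regenerated before either loader is iterated.

```python
import torch
from torch.utils.data import DataLoader, Sampler, TensorDataset


class SharedIndexSampler(Sampler):
    """Hypothetical sampler that replays a stored index order.

    __iter__ draws nothing random, so every loader that shares this
    object sees exactly the same order.
    """

    def __init__(self, num_samples):
        self.num_samples = num_samples
        self.indices = list(range(num_samples))

    def reshuffle(self):
        # Draw one new permutation; call once per epoch, before iterating the loaders
        self.indices = torch.randperm(self.num_samples).tolist()

    def __iter__(self):
        return iter(self.indices)

    def __len__(self):
        return self.num_samples


# Toy stand-ins for the two modalities: sample i holds the value i in both
camera = TensorDataset(torch.arange(10))
lidar = TensorDataset(torch.arange(10))

sampler = SharedIndexSampler(len(camera))
loader1 = DataLoader(camera, batch_size=3, sampler=sampler)
loader2 = DataLoader(lidar, batch_size=3, sampler=sampler)

epoch_orders = []
for epoch in range(2):
    sampler.reshuffle()
    seen = []
    for (ids1,), (ids2,) in zip(loader1, loader2):
        # Both loaders deliver the same sample ids in the same positions
        assert torch.equal(ids1, ids2)
        seen.extend(ids1.tolist())
    epoch_orders.append(seen)
    print(f"epoch {epoch}: {seen}")
```

Calling reshuffle() at the start of each epoch plays the role of the global s, while giving every epoch a fresh order.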

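To make the effect easy to see, both datasets here use the camera images of the KITTI dataset, numbered 007000 to 007480.

Dual-modality loading result (batch size = 3, with image augmentation applied)

The printed s, whose entries correspond to image numbers in the dataset: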
[103, 11, 378, 128, 220, 149, 174, 344, 74, 156, 209, 189, 131, 283, 30, 8, 471, 193, 340, 46, 248, 109, 381, 439, 188, 183, 60, 72, 210, 110, 197, 425, 291, 175, 62, 69, 374, 366, 98, 242, 319, 339, 90, 380, 259, 315, 397, 117, 451, 176, 272, 373, 440, 417, 452, 221, 477, 441, 268, 73, 21, 433, 104, 196, 474, 47, 194, 86, 411, 83, 77, 480, 116, 4, 308, 338, 310, 398, 75, 150, 321, 311, 353, 412, 94, 449, 190, 450, 95, 124, 408, 360, 227, 329, 38, 448, 247, 371, 249, 217, 295, 31, 273, 438, 435, 472, 431, 33, 455, 354, 28, 5, 288, 382, 280, 80, 469, 15, 316, 388, 202, 282, 296, 53, 231, 287, 256, 185, 224, 178, 235, 85, 421, 70, 289, 234, 243, 346, 375, 236, 274, 142, 395, 219, 334, 263, 2, 476, 432, 172, 343, 377, 140, 401, 312, 349, 255, 57, 456, 459, 271, 7, 345, 317, 228, 119, 58, 170, 179, 465, 191, 368, 84, 148, 79, 357, 71, 404, 351, 99, 327, 298, 426, 3, 49, 384, 34, 240, 385, 40, 379, 413, 386, 37, 458, 78, 215, 269, 320, 159, 12, 63, 429, 152, 262, 337, 233, 212, 153, 162, 61, 399, 265, 281, 171, 163, 314, 479, 126, 467, 423, 364, 123, 341, 184, 133, 177, 463, 260, 239, 290, 405, 9, 208, 213, 336, 173, 436, 390, 462, 64, 389, 14, 26, 112, 264, 410, 331, 251, 211, 155, 129, 107, 229, 167, 418, 138, 407, 145, 409, 367, 478, 169, 113, 257, 181, 391, 29, 323, 18, 141, 285, 261, 82, 245, 115, 419, 127, 100, 25, 387, 168, 416, 294, 10, 277, 144, 422, 286, 59, 466, 330, 457, 6, 461, 199, 105, 36, 88, 414, 415, 27, 158, 428, 322, 146, 192, 266, 35, 137, 101, 293, 468, 65, 50, 161, 44, 442, 54, 230, 362, 394, 66, 111, 136, 226, 307, 244, 20, 121, 246, 305, 250, 52, 39, 218, 475, 376, 369, 92, 89, 356, 361, 96, 328, 276, 195, 68, 365, 350, 164, 87, 76, 186, 1, 324, 238, 434, 130, 454, 306, 473, 301, 430, 67, 214, 254, 200, 165, 19, 180, 355, 106, 157, 122, 342, 403, 143, 41, 147, 309, 45, 13, 16, 108, 445, 22, 198, 206, 297, 32, 43, 396, 302, 383, 279, 270, 24, 203, 464, 187, 102, 51, 205, 125, 81, 352, 93, 335, 443, 453, 332, 460, 204, 154, 392, 444, 0, 299, 139, 
166, 42, 223, 91, 97, 363, 182, 241, 258, 151, 427, 420, 56, 406, 358, 207, 134, 333, 267, 252, 292, 132, 237, 48, 23, 400, 160, 402, 114, 470, 304, 232, 275, 424, 393, 347, 325, 118, 447, 222, 225, 372, 370, 201, 437, 300, 326, 120, 216, 313, 318, 303, 55, 446, 348, 284, 359, 278, 135, 253, 17]

[Figure 3: dual-modality loading result]
The same approach also applies to multi-modal data loading.
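A common alternative, not the approach taken in this post, is to wrap all modalities in one Dataset whose __getitem__ returns the matching samples together, and let a single DataLoader shuffle it. The following is a minimal sketch, assuming each modality is a map-style dataset of the same length; PairedDataset is a hypothetical name. Because one index selects every modality at once, the pairs cannot drift apart, and the wrapper accepts any number of modalities. It also gives a single place where the same random augmentation could be applied to all of them.

```python
import torch
from torch.utils.data import DataLoader, Dataset, TensorDataset


class PairedDataset(Dataset):
    """Hypothetical wrapper: index i returns the i-th sample of every modality."""

    def __init__(self, *datasets):
        if len({len(d) for d in datasets}) != 1:
            raise ValueError("all modalities must have the same number of samples")
        self.datasets = datasets

    def __len__(self):
        return len(self.datasets[0])

    def __getitem__(self, index):
        # [0] unwraps TensorDataset's one-element tuple; a real dataset
        # might return (image, target) instead
        return tuple(d[index][0] for d in self.datasets)


# Toy stand-ins: sample i holds the value i in every modality
camera = TensorDataset(torch.arange(10))
lidar = TensorDataset(torch.arange(10))
infrared = TensorDataset(torch.arange(10))

loader = DataLoader(PairedDataset(camera, lidar, infrared), batch_size=3, shuffle=True)
for cam_ids, lidar_ids, ir_ids in loader:
    print(cam_ids.tolist(), lidar_ids.tolist(), ir_ids.tolist())
```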
