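Whether the task is dual-modality or multi-modality fusion, data loading is an essential part of it. When fusing camera images with LiDAR reflectance projection maps or with infrared images, for example, the images of the two modalities fed into the network must correspond one to one; otherwise the fusion is meaningless. This article explains how to load dual-modality data in PyTorch, and I hope it helps anyone who needs it.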
Figure: camera image from the KITTI dataset
Figure: LiDAR reflectance projection map from the KITTI dataset
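First, a brief introduction to `torch.utils.data.DataLoader`. It is PyTorch's main interface for loading data: it wraps a dataset and, according to the batch size, whether to shuffle (`shuffle`), the sampling strategy (`sampler`) and so on, turns it into batches of tensors that serve as the network's input.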
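A minimal sketch of that batching behaviour, using a toy `TensorDataset` in place of real images (the tensors and sizes are made up for illustration):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy dataset: 10 samples with 3 features each, plus an integer label per sample
features = torch.arange(30, dtype=torch.float32).reshape(10, 3)
labels = torch.arange(10)
dataset = TensorDataset(features, labels)

# batch_size=4 over 10 samples gives batches of 4, 4 and 2 (drop_last=False)
loader = DataLoader(dataset, batch_size=4, shuffle=True)

for batch_features, batch_labels in loader:
    print(batch_features.shape, batch_labels.tolist())
```

Each pass over `loader` yields the same batch sizes, but with `shuffle=True` the labels inside the batches come out in a different order every time.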
Source of `torch.utils.data.dataloader` on GitHub: https://github.com/pytorch/pytorch/blob/master/torch/utils/data/dataloader.py
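```
class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
                                  batch_sampler=None, num_workers=0, collate_fn=None,
                                  pin_memory=False, drop_last=False, timeout=0,
                                  worker_init_fn=None, multiprocessing_context=None)
```

This is the signature of the PyTorch version used here; newer releases add further parameters such as `generator`, `prefetch_factor` and `persistent_workers`.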
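The attributes that control the order in which data is loaded are `shuffle` and `sampler`. `shuffle` sets whether the dataset is shuffled during loading; the default is `False` (no shuffling). Shuffling breaks up the original ordering so that the samples are closer to independent, which is why `shuffle` is usually set to `True`. `sampler` is the DataLoader's sampler: it defines the rule by which indices are drawn from the dataset and defaults to `None`. If you supply a sampler, `shuffle` must be left at `False`, because the two options are mutually exclusive.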
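The mutual exclusion can be checked directly. This sketch uses a toy dataset and the stock `SequentialSampler`; the wording of the error message depends on the PyTorch version, so only the exception type is relied on:

```python
import torch
from torch.utils.data import DataLoader, SequentialSampler, TensorDataset

dataset = TensorDataset(torch.arange(6))
sampler = SequentialSampler(dataset)

# A sampler combined with shuffle=False (the default) is accepted
loader = DataLoader(dataset, batch_size=2, sampler=sampler, shuffle=False)
batches = [batch[0].tolist() for batch in loader]
print(batches)  # [[0, 1], [2, 3], [4, 5]]

# A sampler combined with shuffle=True is rejected when the loader is built
rejected = False
try:
    DataLoader(dataset, batch_size=2, sampler=sampler, shuffle=True)
except ValueError as err:
    rejected = True
    print("ValueError:", err)
print(rejected)  # True
```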
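For dual-modality loading, when `shuffle=True` the data loaded by `data_loader1` and by `data_loader2` is shuffled independently, so the images of the two modalities no longer correspond. When `shuffle=False`, the data loaded by `data_loader1` and `data_loader2` does correspond one to one in order, but the data loses its independence.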
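```python
# data is torch.utils.data; dataset1 and dataset2 hold the two modalities;
# detection_collate is the project's own collate function, defined elsewhere
data_loader1 = data.DataLoader(dataset1, args.batch_size,
                               num_workers=args.num_workers,
                               shuffle=True, collate_fn=detection_collate,
                               pin_memory=True)
data_loader2 = data.DataLoader(dataset2, args.batch_size,
                               num_workers=args.num_workers,
                               shuffle=True, collate_fn=detection_collate,
                               pin_memory=True)
```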
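The goal of dual-modality loading is to shuffle the loading order while keeping the two modalities matched, so changing `shuffle` alone cannot meet the requirement.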
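Both failure modes can be reproduced with two toy datasets standing in for the modalities. Sample `i` of each dataset simply stores the value `i`, so a pair is aligned exactly when the two values are equal. The helper `loading_order` is hypothetical, written only for this demo:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)  # only makes the demo reproducible

n = 20
modality_a = TensorDataset(torch.arange(n))
modality_b = TensorDataset(torch.arange(n))

def loading_order(loader):
    """Concatenate the sample values in the order the loader yields them."""
    return [v for (batch,) in loader for v in batch.tolist()]

# shuffle=True: each loader draws its own permutation, so the pairs break
shuffled_a = loading_order(DataLoader(modality_a, batch_size=4, shuffle=True))
shuffled_b = loading_order(DataLoader(modality_b, batch_size=4, shuffle=True))
print(shuffled_a == shuffled_b)  # almost surely False

# shuffle=False: both loaders walk 0..n-1, so the pairs match but are never shuffled
ordered_a = loading_order(DataLoader(modality_a, batch_size=4, shuffle=False))
ordered_b = loading_order(DataLoader(modality_b, batch_size=4, shuffle=False))
print(ordered_a == ordered_b == list(range(n)))  # True
```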
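Next, look at the `sampler` argument of `DataLoader`. It takes an instance of a `torch.utils.data.sampler.Sampler` subclass; the two subclasses relevant here are listed below, and they correspond to the two values of `shuffle`.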
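`class torch.utils.data.sampler.SequentialSampler(data_source)`
Samples the dataset elements in order.
Parameter: `data_source` (Dataset), the dataset to sample from.
`class torch.utils.data.sampler.RandomSampler(data_source, replacement=False, num_samples=None)`
Samples the dataset elements in random order.
Parameters: `data_source` (Dataset), the dataset to sample from; `replacement` and `num_samples` control sampling with replacement and are not used here.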
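`torch.utils.data.sampler.SequentialSampler` samples the dataset in order: its iterator always yields the indices 0 to `len(dataset) - 1`, which corresponds to `shuffle=False`. `torch.utils.data.sampler.RandomSampler` samples the dataset randomly: its iterator yields the indices in shuffled order, which corresponds to `shuffle=True`.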
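A quick check of the two samplers, using `range(8)` as a stand-in dataset (both samplers only need `len()` of the data source):

```python
import torch
from torch.utils.data import RandomSampler, SequentialSampler

data_source = range(8)  # any object with __len__ works as a stand-in dataset

print(list(SequentialSampler(data_source)))  # [0, 1, 2, 3, 4, 5, 6, 7]

torch.manual_seed(0)
random_order = list(RandomSampler(data_source))
print(random_order)  # some permutation of 0..7
print(sorted(random_order) == list(range(8)))  # True
```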
My first attempt looked like this:
```python
sampler = torch.utils.data.sampler.RandomSampler(dataset1)
data_loader1 = torch.utils.data.DataLoader(dataset1, args.batch_size,
                                           num_workers=args.num_workers, sampler=sampler,
                                           shuffle=False, collate_fn=detection_collate,
                                           pin_memory=True)
data_loader2 = torch.utils.data.DataLoader(dataset2, args.batch_size,
                                           num_workers=args.num_workers, sampler=sampler,
                                           shuffle=False, collate_fn=detection_collate,
                                           pin_memory=True)
```
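I printed `sampler` right before building `data_loader1` and again before building `data_loader2`; both prints gave the same result, something like:

```
<torch.utils.data.sampler.RandomSampler object at 0x7ffacd0a87f0>
```

This made me think that `data_loader1` and `data_loader2` would sample the data in the same order. It turned out that the images of the two modalities did not match at all! Printing the contents of the sampler with `print(list(sampler))` showed that the two lists were different. The identical address only means that both loaders hold the same sampler object; it says nothing about the order that object produces on each pass.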
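The observation can be reproduced with toy datasets: one `RandomSampler` object, shared by two loaders as in the attempt above, still gives a different order on every pass. The datasets here are placeholders for the two modalities:

```python
import torch
from torch.utils.data import DataLoader, RandomSampler, TensorDataset

torch.manual_seed(0)
n = 20
dataset1 = TensorDataset(torch.arange(n))
dataset2 = TensorDataset(torch.arange(n))

# One sampler object shared by both loaders
sampler = RandomSampler(dataset1)

first = list(sampler)
second = list(sampler)
print(first == second)  # almost surely False: each pass draws a new permutation

loader1 = DataLoader(dataset1, batch_size=4, sampler=sampler)
loader2 = DataLoader(dataset2, batch_size=4, sampler=sampler)
order1 = [v for (batch,) in loader1 for v in batch.tolist()]
order2 = [v for (batch,) in loader2 for v in batch.tolist()]
print(order1 == order2)  # almost surely False as well
```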
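Looking at the source of `torch.utils.data.sampler` (https://github.com/pytorch/pytorch/blob/master/torch/utils/data/sampler.py), `RandomSampler` essentially returns `iter(torch.randperm(n).tolist())`, where `n = len(self.data_source)` is the number of samples in the dataset. In other words, it shuffles the dataset indices and then wraps them in an iterator.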
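So the order is not fixed when `sampler = torch.utils.data.sampler.RandomSampler(dataset1)` is executed. `torch.randperm(n)` is called every time the sampler is iterated, that is, every time its `__iter__` runs: each `list(sampler)` and each pass over a `DataLoader` that uses the sampler shuffles the dataset indices anew. That is why the two printed `list(sampler)` results differ, and why two loaders sharing one `RandomSampler` still load the modalities in different orders.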
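To see when the permutation is actually drawn, the sketch below subclasses `RandomSampler` with a hypothetical `CountingRandomSampler` that only counts calls to `__iter__`. Constructing the sampler triggers no call; each full pass of a `DataLoader` triggers exactly one, and with it a fresh `torch.randperm`:

```python
import torch
from torch.utils.data import DataLoader, RandomSampler, TensorDataset

class CountingRandomSampler(RandomSampler):
    """Hypothetical helper: a RandomSampler that counts how often __iter__ runs."""

    def __init__(self, data_source):
        super().__init__(data_source)
        self.iter_calls = 0

    def __iter__(self):
        self.iter_calls += 1
        return super().__iter__()

dataset = TensorDataset(torch.arange(10))
sampler = CountingRandomSampler(dataset)
print(sampler.iter_calls)  # 0: building the sampler draws no permutation

loader1 = DataLoader(dataset, batch_size=5, sampler=sampler)
loader2 = DataLoader(dataset, batch_size=5, sampler=sampler)
for _ in loader1:
    pass
for _ in loader2:
    pass
print(sampler.iter_calls)  # 2: one fresh permutation per loader pass
```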
```python
import torch
from torch._six import int_classes as _int_classes  # older PyTorch only; torch._six was removed in later releases


class Sampler(object):
    r"""Base class for all Samplers.

    Every Sampler subclass has to provide an :meth:`__iter__` method, providing a
    way to iterate over indices of dataset elements, and a :meth:`__len__` method
    that returns the length of the returned iterators.

    .. note:: The :meth:`__len__` method isn't strictly required by
              :class:`~torch.utils.data.DataLoader`, but is expected in any
              calculation involving the length of a :class:`~torch.utils.data.DataLoader`.
    """
    # ... (remaining Sampler code and the other samplers omitted)


class RandomSampler(Sampler):
    def __init__(self, data_source, replacement=False, num_samples=None):
        self.data_source = data_source
        self.replacement = replacement
        self._num_samples = num_samples

        if not isinstance(self.replacement, bool):
            raise ValueError("replacement should be a boolean value, but got "
                             "replacement={}".format(self.replacement))

        if self._num_samples is not None and not replacement:
            raise ValueError("With replacement=False, num_samples should not be specified, "
                             "since a random permute will be performed.")

        if not isinstance(self.num_samples, int) or self.num_samples <= 0:
            raise ValueError("num_samples should be a positive integer "
                             "value, but got num_samples={}".format(self.num_samples))

    @property
    def num_samples(self):
        # dataset size might change at runtime
        if self._num_samples is None:
            return len(self.data_source)
        return self._num_samples

    def __iter__(self):
        n = len(self.data_source)
        if self.replacement:
            return iter(torch.randint(high=n, size=(self.num_samples,), dtype=torch.int64).tolist())
        # a new permutation is drawn on every call to __iter__, i.e. on every pass
        return iter(torch.randperm(n).tolist())

    def __len__(self):
        return self.num_samples
```
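The samplers provided by `torch.utils.data.sampler` therefore cannot be used as they are: every pass reshuffles the sampler's indices, so the two modalities fed into the network do not match. So I made a small modification to `torch.utils.data.sampler.RandomSampler`.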
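Since every iteration of `RandomSampler` runs `torch.randperm(n)`, and the only purpose of that step is to produce shuffled dataset indices, we can instead hand the sampler's iterator an index list that has already been shuffled.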
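First, define your own `RandomSampler` class: copy the code of `torch.utils.data.sampler.RandomSampler`, inherit from the base class `torch.utils.data.sampler.Sampler`, and make `__iter__` return `iter(s)`, where `s` is the shuffled list of dataset indices: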
```python
import torch
from torch.utils import data


class RandomSampler(data.sampler.Sampler):
    # Same as the stock RandomSampler, except that __iter__ replays the already
    # shuffled global index list s instead of calling torch.randperm again
    def __init__(self, data_source, replacement=False, num_samples=None):
        self.data_source = data_source
        self.replacement = replacement
        self._num_samples = num_samples

        if not isinstance(self.replacement, bool):
            raise ValueError("replacement should be a boolean value, but got "
                             "replacement={}".format(self.replacement))

        if self._num_samples is not None and not replacement:
            raise ValueError("With replacement=False, num_samples should not be specified, "
                             "since a random permute will be performed.")

        if not isinstance(self.num_samples, int) or self.num_samples <= 0:
            raise ValueError("num_samples should be a positive integer "
                             "value, but got num_samples={}".format(self.num_samples))

    @property
    def num_samples(self):
        # dataset size might change at runtime
        if self._num_samples is None:
            return len(self.data_source)
        return self._num_samples

    def __iter__(self):
        n = len(self.data_source)
        if self.replacement:
            return iter(torch.randint(high=n, size=(self.num_samples,), dtype=torch.int64).tolist())
        # s must be defined as a global before the DataLoaders are iterated
        return iter(s)

    def __len__(self):
        return self.num_samples
```
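In the `data.DataLoader` part of the code, first declare a global variable `s` and assign it the shuffled dataset indices, then create an instance of the `RandomSampler` class just defined and pass it as `sampler` to both `data.DataLoader`s: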
```python
global s  # inside a function, this makes s module-level so RandomSampler.__iter__ (same file) can read it
s = torch.randperm(len(dataset1)).tolist()
print(s)
sampler = RandomSampler(dataset1)  # the custom class defined above
data_loader1 = data.DataLoader(dataset1, args.batch_size,  # load data through DataLoader
                               num_workers=args.num_workers, sampler=sampler,
                               shuffle=False, collate_fn=detection_collate,
                               pin_memory=True)
data_loader2 = data.DataLoader(dataset2, args.batch_size,
                               num_workers=args.num_workers, sampler=sampler,
                               shuffle=False, collate_fn=detection_collate,
                               pin_memory=True)
```
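Because `s` is computed only once, the code above replays the same order in every epoch unless `s` is reassigned before each epoch. The sketch below is a variant of the same idea, not the code used above: a hypothetical `SharedOrderSampler` keeps the index list as an attribute instead of a global, and its `reshuffle()` method draws one new permutation per epoch that both loaders then share. Toy tensors stand in for the two modalities (modality 2 stores `i + 100` for sample `i`, so alignment can be asserted):

```python
import torch
from torch.utils.data import DataLoader, Sampler, TensorDataset

class SharedOrderSampler(Sampler):
    """Hypothetical helper: yields a stored index permutation that several
    DataLoaders can share; call reshuffle() once per epoch."""

    def __init__(self, data_source):
        self.data_source = data_source
        self.indices = []
        self.reshuffle()

    def reshuffle(self):
        # Draw one new permutation; every loader sharing this sampler
        # replays exactly this order until the next reshuffle()
        self.indices = torch.randperm(len(self.data_source)).tolist()

    def __iter__(self):
        return iter(self.indices)

    def __len__(self):
        return len(self.indices)

n = 20
dataset1 = TensorDataset(torch.arange(n))        # stand-in for modality 1
dataset2 = TensorDataset(torch.arange(n) + 100)  # stand-in for modality 2

sampler = SharedOrderSampler(dataset1)
loader1 = DataLoader(dataset1, batch_size=4, sampler=sampler)
loader2 = DataLoader(dataset2, batch_size=4, sampler=sampler)

epoch_orders = []
for epoch in range(2):
    sampler.reshuffle()  # new order for this epoch, shared by both loaders
    order = []
    for (a,), (b,) in zip(loader1, loader2):
        assert torch.equal(a + 100, b)  # sample i of modality 1 meets sample i of modality 2
        order.extend(a.tolist())
    epoch_orders.append(order)

print(epoch_orders[0] != epoch_orders[1])  # almost surely True: reshuffled each epoch
```

Keeping the permutation on the sampler object removes the need for the `global` declaration, and the explicit `reshuffle()` call makes it clear when the shared order changes.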
To make the dual-modality loading result easy to see, both datasets here use camera images from the KITTI dataset, numbered 007000 to 007480.
Figure: dual-modality loading result (batch size = 3, with image augmentation)
The printed `s`, whose entries correspond to the image numbers in the dataset:
```
[103, 11, 378, 128, 220, 149, 174, 344, 74, 156, 209, 189, 131, 283, 30, 8, 471, 193, 340, 46, 248, 109, 381, 439, 188, 183, 60, 72, 210, 110, 197, 425, 291, 175, 62, 69, 374, 366, 98, 242, 319, 339, 90, 380, 259, 315, 397, 117, 451, 176, 272, 373, 440, 417, 452, 221, 477, 441, 268, 73, 21, 433, 104, 196, 474, 47, 194, 86, 411, 83, 77, 480, 116, 4, 308, 338, 310, 398, 75, 150, 321, 311, 353, 412, 94, 449, 190, 450, 95, 124, 408, 360, 227, 329, 38, 448, 247, 371, 249, 217, 295, 31, 273, 438, 435, 472, 431, 33, 455, 354, 28, 5, 288, 382, 280, 80, 469, 15, 316, 388, 202, 282, 296, 53, 231, 287, 256, 185, 224, 178, 235, 85, 421, 70, 289, 234, 243, 346, 375, 236, 274, 142, 395, 219, 334, 263, 2, 476, 432, 172, 343, 377, 140, 401, 312, 349, 255, 57, 456, 459, 271, 7, 345, 317, 228, 119, 58, 170, 179, 465, 191, 368, 84, 148, 79, 357, 71, 404, 351, 99, 327, 298, 426, 3, 49, 384, 34, 240, 385, 40, 379, 413, 386, 37, 458, 78, 215, 269, 320, 159, 12, 63, 429, 152, 262, 337, 233, 212, 153, 162, 61, 399, 265, 281, 171, 163, 314, 479, 126, 467, 423, 364, 123, 341, 184, 133, 177, 463, 260, 239, 290, 405, 9, 208, 213, 336, 173, 436, 390, 462, 64, 389, 14, 26, 112, 264, 410, 331, 251, 211, 155, 129, 107, 229, 167, 418, 138, 407, 145, 409, 367, 478, 169, 113, 257, 181, 391, 29, 323, 18, 141, 285, 261, 82, 245, 115, 419, 127, 100, 25, 387, 168, 416, 294, 10, 277, 144, 422, 286, 59, 466, 330, 457, 6, 461, 199, 105, 36, 88, 414, 415, 27, 158, 428, 322, 146, 192, 266, 35, 137, 101, 293, 468, 65, 50, 161, 44, 442, 54, 230, 362, 394, 66, 111, 136, 226, 307, 244, 20, 121, 246, 305, 250, 52, 39, 218, 475, 376, 369, 92, 89, 356, 361, 96, 328, 276, 195, 68, 365, 350, 164, 87, 76, 186, 1, 324, 238, 434, 130, 454, 306, 473, 301, 430, 67, 214, 254, 200, 165, 19, 180, 355, 106, 157, 122, 342, 403, 143, 41, 147, 309, 45, 13, 16, 108, 445, 22, 198, 206, 297, 32, 43, 396, 302, 383, 279, 270, 24, 203, 464, 187, 102, 51, 205, 125, 81, 352, 93, 335, 443, 453, 332, 460, 204, 154, 392, 444, 0, 299, 139, 166, 42, 223, 91, 97, 363, 182, 241, 258,
151, 427, 420, 56, 406, 358, 207, 134, 333, 267, 252, 292, 132, 237, 48, 23, 400, 160, 402, 114, 470, 304, 232, 275, 424, 393, 347, 325, 118, 447, 222, 225, 372, 370, 201, 437, 300, 326, 120, 216, 313, 318, 303, 55, 446, 348, 284, 359, 278, 135, 253, 17]
```
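Assuming the dataset lists the files 007000 to 007480 in ascending order, index `i` in `s` is image `007000 + i`, so with batch size 3 the first batch would be images 007103, 007011 and 007378. A small sketch of that mapping; the `FIRST_ID` constant is an assumption based on the numbering given above, and `s_head` holds only the first few printed indices:

```python
# First few indices from the printed list s
s_head = [103, 11, 378, 128, 220]

FIRST_ID = 7000  # assumed: index 0 is image 007000, files in ascending order

# Zero-pad to the six-digit KITTI file numbering
image_ids = [f"{FIRST_ID + i:06d}" for i in s_head]
print(image_ids)  # ['007103', '007011', '007378', '007128', '007220']
```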