参考:Python3 zip() 函数 | 菜鸟教程 (runoob.com)
1.1 描述
zip() 函数用于将可迭代的对象作为参数,将对象中对应的元素打包成一个个元组,然后返回由这些元组组成的对象,这样做的好处是节约了不少的内存。
我们可以使用list()转换来输出列表。
如果各个迭代器的元素个数不一致,则返回列表长度与最短的对象相同,利用 * 号操作符,可以将元组解压为列表。
1.2 语法
zip([iterable, ...])
参数说明:
iterable – 一个或多个迭代器;
1.3 返回值
返回一个对象。
1.4 例子
>>> a = [1, 2, 3]
>>> b = [4, 5, 6]
>>> c = [7, 8, 9, 10, 11, 12]
>>> zip_object = zip(a, b)
>>> zip_object
# 可以看到,zip的返回值是一个对象
# 解包的两种方式
# 方法1
>>> list(zip_object)
[(1, 4), (2, 5), (3, 6)]
>>> list(zip(a, c))
[(1, 7), (2, 8), (3, 9)]
# 方法2
>>> print(*a)
1 2 3
>>> print(*zip_object)
# 一般用法
>>> for i, j in zip(a, b):
... print(f"i: {i}, j: {j}")
...
i: 1, j: 4
i: 2, j: 5
i: 3, j: 6
一般zip()函数用于我们想同时处理两个列表里的数据。
a = [1, 2, 3]
b = [4, 5, 6]
c = [7, 8, 9, 10, 11, 12]
for i, j in zip(a, b):
print(f"i: {i}, j: {j}")
"""
i: 1, j: 4
i: 2, j: 5
i: 3, j: 6
"""
for x,y,z in zip(a, b, c):
print(f"x: {x}, y: {y}, z: {z}")
"""
x: 1, y: 4, z: 7
x: 2, y: 5, z: 8
x: 3, y: 6, z: 9
"""
神经网络Dataloader中的应用
train_data_loader = torch.utils.data.DataLoader(train_dataset,
batch_sampler=train_batch_sampler,
pin_memory=True,
num_workers=nw,
collate_fn=train_dataset.collate_fn)
collate_fn默认是对数据(图片)通过torch.stack()进行简单的拼接。对于分类网络来说,默认方法是可以的(因为传入的就是数据的图片),但是对于目标检测来说,train_dataset返回的是一个tuple,即(image, target)。如果我们还是采用默认的合并方法,那么就会出错。所以我们需要自定义一个方法,即collate_fn=train_dataset.collate_fn,其代码如下:
@staticmethod
def collate_fn(batch):
return tuple(zip(*batch))
首先batch是用list将(image, target)进行封装,即
一个batch(batch_size=8) = [(image, target),
(image, target),
(image, target),
(image, target)]
zip(*batch)即先将list进行解包(去掉list),再用zip()进行压缩,再用tuple进行解包,这里我们用一个小例子演示一下过程。
>>> img = [[1, 2], [3, 4], [5, 6]]
>>> target = {"key1": "value1", "key2": "value2"}
>>> batch = [(img, target), (img, target), (img, target), (img, target)]
>>> batch
[([[1, 2], [3, 4], [5, 6]], {'key1': 'value1', 'key2': 'value2'}),
([[1, 2], [3, 4], [5, 6]], {'key1': 'value1', 'key2': 'value2'}),
([[1, 2], [3, 4], [5, 6]], {'key1': 'value1', 'key2': 'value2'}),
([[1, 2], [3, 4], [5, 6]], {'key1': 'value1', 'key2': 'value2'})]
>>> print(*batch) # 去掉了最外层的list
([[1, 2], [3, 4], [5, 6]], {'key1': 'value1', 'key2': 'value2'})
([[1, 2], [3, 4], [5, 6]], {'key1': 'value1', 'key2': 'value2'})
([[1, 2], [3, 4], [5, 6]], {'key1': 'value1', 'key2': 'value2'})
([[1, 2], [3, 4], [5, 6]], {'key1': 'value1', 'key2': 'value2'})
>>> zip(*batch) # 对解包数据进行zip压缩
>>> tuple(zip(*batch)) # 用tuple进行解包
(
([[1, 2], [3, 4], [5, 6]], [[1, 2], [3, 4], [5, 6]], [[1, 2], [3, 4], [5, 6]], [[1, 2], [3, 4], [5, 6]]),
({'key1': 'value1', 'key2': 'value2'}, {'key1': 'value1', 'key2': 'value2'}, {'key1': 'value1', 'key2': 'value2'}, {'key1': 'value1', 'key2': 'value2'})
)
>>> tuple(zip(*batch))[0]
([[1, 2], [3, 4], [5, 6]], [[1, 2], [3, 4], [5, 6]], [[1, 2], [3, 4], [5, 6]], [[1, 2], [3, 4], [5, 6]])
>>> tuple(zip(*batch))[1]
({'key1': 'value1', 'key2': 'value2'}, {'key1': 'value1', 'key2': 'value2'}, {'key1': 'value1', 'key2': 'value2'}, {'key1': 'value1', 'key2': 'value2'})
这样操作之后,图片数据为tuple(zip(*batch))[0],target数据为tuple(zip(*batch))[1]。
之后再迭代(enumerate该dataloader就方便多了)。