首先,对于一个python数组,我们可以用for,或者next来遍历。
其次,而对于pytorch的 torch.utils.data.DataLoader, 也可以使用类似 next(iter(DataLoader)) 的方式遍历地读取dataset的数据。
再次,使用yield关键字,也可以起到“遍历”的效果。
那么问题来了。这iter,next,yield几个东西之间有什么关联,又有什么区别呢?pytorch的DataLoader又是用的什么样的方式呢?
先看一个简单例子:
fruit = ["apple", "banana", "cherry"]
# 1. 使用for循环
for f in fruit:
print(f)
# 2. 使用iter创建迭代器,通过next遍历
fr = iter(fruit)
print(next(fr))
print(next(fr))
print(next(fr))
"""
上面两种方式,都会打印以下结果:
apple
banana
cherry
"""
不要被同样的结果迷惑了。python的iter方法将数组转为一个可遍历的对象,对这个对象遍历,和对原来的数组遍历,是不一样的。不信,可以接着再执行下面的代码:
for f in fr:
print(f)
这次,没有任何输出。
其中的秘密,就在于fr是用iter()创建的对象。对这个对象遍历的时候,每个元素只能遍历到一次(可以理解为,其内部有当前读指针,每读一次,指针向前移动)。就是这么简单的过程。
再看yield,实际上干了差不多的事情:
def fruit_iter (fruits):
count = len(fruits)
for i in range(0, count):
yield fruits[i]
it = fruit_iter(fruit)
for f in it:
print(f)
"""
打印结果:
apple
banana
cherry
"""
上面fruit_iter被称为Generator Function,返回一个Generator对象。可以打印一下:
print(it)
"""
"""
另外,虽然从上面的例子看出,next()和for对于Generator Function都能起到遍历的作用,但next()还有额外的解释(链接):
Note: When you use next(), Python calls ._next_() on the function
you pass in as a parameter.
也就是说,next()函数实际上调用了传入函数的.__next()__成员函数。所以,如果传入的函数没有这个成员,则会报错。
到这个地方,相信对iter()函数、next()函数,以及yield有了一定的认识。但还有一个问题:pytorch的torch.utils.data.DataLoader 是用哪种方式实现的?用一个简单的例子验证一下:
import torch
# 生成一些测试数据
X = torch.normal(0, 1, (1000, 2)) # x: sample size = 1000, feature_dim = 2
y = torch.normal(0, 1, (1000, 1)) # y: sample size = 1000, dim = 1
# 定义一个函数,返回dataloader
def load_array(data_and_label, batch_size, is_train=True):
"""Construct a PyTorch data iterator."""
dataset = data.TensorDataset(*data_and_label)
return data.DataLoader(dataset, batch_size, shuffle=is_train)
batch_size = 10
data_iter = load_array((X, y), batch_size)
#这句会报错:next(data_iter)
next(iter(data_iter))
"""
输出前10组数据
"""
这里,为什么 next(data_iter) 报错,而 next(iter(data_iter)) 可以返回数据呢?这是因为,pytorch的DataLoader函数没有 _next_ 成员,但有 _iter_ 成员(见源文件)。所以,需要首先通过 iter() 函数返回一个 _iter_ 成员,再找这个 _iter_ 的 _next_
到这里,我们对本文开始提到的问题(iter, next, yield 和 pytorch dataloader)应该有一个比较深入的了解了。
https://www.guru99.com/python-yield-return-generator.html
https://realpython.com/introduction-to-python-generators/
https://anandology.com/python-practice-book/iterators.html
https://stackoverflow.com/questions/231767/what-does-the-yield-keyword-do
https://pytorch.org/docs/stable/_modules/torch/utils/data/dataloader.html#DataLoader