有关python的iter,next,yield,和pytorch的dataloader

问题

首先,对于一个python数组,我们可以用for,或者next来遍历。

其次,而对于pytorch的 torch.utils.data.DataLoader, 也可以使用类似 next(iter(DataLoader)) 的方式遍历地读取dataset的数据。

再次,使用yield关键字,也可以起到“遍历”的效果。

那么问题来了。这iter,next,yield几个东西之间有什么关联,又有什么区别呢?pytorch的DataLoader又是用的什么样的方式呢?

分析

先看一个简单例子:

fruit = ["apple", "banana", "cherry"]
# 1. 使用for循环
for f in fruit:
    print(f)
# 2. 使用iter创建迭代器,通过next遍历
fr = iter(fruit)
print(next(fr))
print(next(fr))
print(next(fr))

"""
上面两种方式,都会打印以下结果:
apple
banana
cherry
"""

不要被同样的结果迷惑了。python的iter方法将数组转为一个可遍历的对象,对这个对象遍历,和对原来的数组遍历,是不一样的。不信,可以接着再执行下面的代码:

for f in fr:
    print(f)

这次,没有任何输出。

其中的秘密,就在于fr是用iter()创建的对象。对这个对象遍历的时候,每个元素只能遍历到一次(可以理解为,其内部有当前读指针,每读一次,指针向前移动)。就是这么简单的过程。

再看yield,实际上干了差不多的事情:

def fruit_iter (fruits):
    count = len(fruits)
    for i in range(0, count):
        yield fruits[i]

it = fruit_iter(fruit)
for f in it:
  print(f)
"""
打印结果:
apple
banana
cherry
"""

上面fruit_iter被称为Generator Function,返回一个Generator对象。可以打印一下:

print(it)
"""

"""

另外,虽然从上面的例子看出,next()和for对于Generator Function都能起到遍历的作用,但next()还有额外的解释(链接):

Note: When you use next(), Python calls ._next_() on the function
you pass in as a parameter.

也就是说,next()函数实际上调用了传入函数的.__next()__成员函数。所以,如果传入的函数没有这个成员,则会报错。

到这个地方,相信对iter()函数、next()函数,以及yield有了一定的认识。但还有一个问题:pytorch的torch.utils.data.DataLoader 是用哪种方式实现的?用一个简单的例子验证一下:

import torch
# 生成一些测试数据
X = torch.normal(0, 1, (1000, 2)) # x: sample size = 1000, feature_dim = 2
y = torch.normal(0, 1, (1000, 1)) # y: sample size = 1000, dim = 1

# 定义一个函数,返回dataloader
def load_array(data_and_label, batch_size, is_train=True):
    """Construct a PyTorch data iterator."""
    dataset = data.TensorDataset(*data_and_label)
    return data.DataLoader(dataset, batch_size, shuffle=is_train)

batch_size = 10
data_iter = load_array((X, y), batch_size)

#这句会报错:next(data_iter)
next(iter(data_iter))
"""
输出前10组数据
"""

这里,为什么 next(data_iter) 报错,而 next(iter(data_iter)) 可以返回数据呢?这是因为,pytorch的DataLoader函数没有 _next_ 成员,但有 _iter_ 成员(见源文件)。所以,需要首先通过 iter() 函数返回一个 _iter_ 成员,再找这个 _iter_ 的 _next_

到这里,我们对本文开始提到的问题(iter, next, yield 和 pytorch dataloader)应该有一个比较深入的了解了。

参考

https://www.guru99.com/python-yield-return-generator.html
https://realpython.com/introduction-to-python-generators/
https://anandology.com/python-practice-book/iterators.html
https://stackoverflow.com/questions/231767/what-does-the-yield-keyword-do
https://pytorch.org/docs/stable/_modules/torch/utils/data/dataloader.html#DataLoader

你可能感兴趣的:(深度学习,python,pytorch,pytorch,python,深度学习)