当时对这个yield的关键字不是很明白,找了很多blog,大致理解就是:yield就是 return 返回一个值,并且记住这个返回的位置,下次迭代就从这个位置后(下一行)开始。后来发现,哈哈哈确实说的一点不错
def data_iter(batch_size, features, labels):
num_examples = len(features)
indices = list(range(num_examples))
# 这些样本是随机读取的,没有特定的顺序
random.shuffle(indices)
for i in range(0, num_examples, batch_size):
batch_indices = torch.tensor(
indices[i: min(i + batch_size, num_examples)])
yield features[batch_indices], labels[batch_indices]
batch_size = 10
for X, y in data_iter(batch_size, features, labels):
print(X, '\n', y)
break
其实加了yield的函数,相当于变成了一个生成器函数,只有当调用的时候,才能进入迭代。
下面是用jupyter对代码进行稍加调试,以便理解
1、首先输入如下代码,发现无输出,表明函数并没有真正的执行
Input[1]:
batch_size = 10
g = data_iter(batch_size, features, labels)
Output:
None
2、接下来直接对该函数进行调用,并打印出结果,发现终于进入了循环,并且输出了feature的值,但是并没有打印i值,因为遇到yield关键字,返回了yield后面的值,并记住了改处的位置,然后直接退出函数
Input[2]:
print(next(g))
Output:
进入函数
进入循环
(tensor([[-2.2516, 0.8632],
[ 0.0199, 0.6596],
[-0.2909, 0.6650],
[ 1.0914, -1.1055],
[ 0.8471, -1.1018],
[-0.4365, -1.0246],
[ 0.8232, -0.6454],
[ 0.7633, -0.1330],
[ 0.4759, 1.3009],
[ 0.7626, 0.6188]]), tensor([[-3.2485],
[ 1.9843],
[ 1.3509],
[10.1546],
[ 9.6300],
[ 6.8132],
[ 8.0320],
[ 6.1773],
[ 0.7226],
[ 3.6142]]))
3、接着调用函数,程序会从上次记录的位置直接进入,从yield的后一句开始运行,所以先打上一轮i的值0,然后进行下一次循环,所以不会打印“进入函数”,直接打印“进入循环”,得到新的features的值,返回新值,然后记录位置,退出函数
Input[3]:
print(next(g))
Output:
0
进入循环
(tensor([[-0.1043, 0.2224],
[-1.3881, -0.8711],
[ 0.5147, -0.2740],
[-1.6334, -0.9906],
[ 1.1541, -0.7065],
[-1.2059, -1.2036],
[ 0.4678, 0.7486],
[ 1.1672, -0.3777],
[ 0.5903, 0.5108],
[-0.5092, -0.1133]]), tensor([[3.2222],
[4.3944],
[6.1510],
[4.3060],
[8.9257],
[5.8633],
[2.6000],
[7.8166],
[3.6278],
[3.5842]]))
4、再调用函数,注意,输出的i的值变成了10,并不是从0再重新开始,因为yield关键字会记录位置,所以每次都是接着上一次的程序运行。
Input[4]:
print(next(g))
Output:
10
进入循环
(tensor([[ 0.2803, 0.2294],
[ 0.0492, 0.3186],
[-0.3837, 0.2482],
[ 1.5445, 0.2595],
[-1.2430, -1.2258],
[-0.6183, -1.2285],
[-0.4411, 0.7067],
[-0.5177, -0.3614],
[ 1.5170, -0.9192],
[-2.1817, -0.9584]]), tensor([[ 3.9902],
[ 3.2159],
[ 2.5958],
[ 6.3975],
[ 5.8732],
[ 7.1226],
[ 0.9102],
[ 4.3772],
[10.3453],
[ 3.0877]]))
5、return和yield的返回值也有区别,return返回的是一个list,但yield返回的直接是函数值,所以如果这里直接将yield换成return会报错,赋值语句左右不匹配。
6、如果将break去掉,发现会打印1000组X,y。更加证实了这里的data_iter是一个生成器,当迭代到生成器中i无法再迭代时,退出for循环