动手学深度学习李沐---yield关键字的理解

当时对这个yield的关键字不是很明白,找了很多blog,大致理解就是:yield就是 return 返回一个值,并且记住这个返回的位置,下次迭代就从这个位置后(下一行)开始。后来发现,哈哈哈确实说的一点不错

def data_iter(batch_size, features, labels):
    num_examples = len(features)
    indices = list(range(num_examples))
    # 这些样本是随机读取的,没有特定的顺序
    random.shuffle(indices)
    for i in range(0, num_examples, batch_size):
        batch_indices = torch.tensor(
            indices[i: min(i + batch_size, num_examples)])
        yield features[batch_indices], labels[batch_indices]

batch_size = 10

for X, y in data_iter(batch_size, features, labels):
    print(X, '\n', y)
    break

 其实加了yield的函数,相当于变成了一个生成器函数,只有当调用的时候,才能进入迭代。

下面是用jupyter对代码进行稍加调试,以便理解

1、首先输入如下代码,发现无输出,表明函数并没有真正的执行

Input[1]:
batch_size = 10
g = data_iter(batch_size, features, labels)

Output:
None

2、接下来直接对该函数进行调用,并打印出结果,发现终于进入了循环,并且输出了feature的值,但是并没有打印i值,因为遇到yield关键字,返回了yield后面的值,并记住了改处的位置,然后直接退出函数

Input[2]:
print(next(g))

Output:
进入函数
进入循环
(tensor([[-2.2516,  0.8632],
        [ 0.0199,  0.6596],
        [-0.2909,  0.6650],
        [ 1.0914, -1.1055],
        [ 0.8471, -1.1018],
        [-0.4365, -1.0246],
        [ 0.8232, -0.6454],
        [ 0.7633, -0.1330],
        [ 0.4759,  1.3009],
        [ 0.7626,  0.6188]]), tensor([[-3.2485],
        [ 1.9843],
        [ 1.3509],
        [10.1546],
        [ 9.6300],
        [ 6.8132],
        [ 8.0320],
        [ 6.1773],
        [ 0.7226],
        [ 3.6142]]))

3、接着调用函数,程序会从上次记录的位置直接进入,从yield的后一句开始运行,所以先打上一轮i的值0,然后进行下一次循环,所以不会打印“进入函数”,直接打印“进入循环”,得到新的features的值,返回新值,然后记录位置,退出函数

Input[3]:
print(next(g))

Output:
0
进入循环
(tensor([[-0.1043,  0.2224],
        [-1.3881, -0.8711],
        [ 0.5147, -0.2740],
        [-1.6334, -0.9906],
        [ 1.1541, -0.7065],
        [-1.2059, -1.2036],
        [ 0.4678,  0.7486],
        [ 1.1672, -0.3777],
        [ 0.5903,  0.5108],
        [-0.5092, -0.1133]]), tensor([[3.2222],
        [4.3944],
        [6.1510],
        [4.3060],
        [8.9257],
        [5.8633],
        [2.6000],
        [7.8166],
        [3.6278],
        [3.5842]]))

4、再调用函数,注意,输出的i的值变成了10,并不是从0再重新开始,因为yield关键字会记录位置,所以每次都是接着上一次的程序运行。

Input[4]:
print(next(g))

Output:
10
进入循环
(tensor([[ 0.2803,  0.2294],
        [ 0.0492,  0.3186],
        [-0.3837,  0.2482],
        [ 1.5445,  0.2595],
        [-1.2430, -1.2258],
        [-0.6183, -1.2285],
        [-0.4411,  0.7067],
        [-0.5177, -0.3614],
        [ 1.5170, -0.9192],
        [-2.1817, -0.9584]]), tensor([[ 3.9902],
        [ 3.2159],
        [ 2.5958],
        [ 6.3975],
        [ 5.8732],
        [ 7.1226],
        [ 0.9102],
        [ 4.3772],
        [10.3453],
        [ 3.0877]]))

5、return和yield的返回值也有区别,return返回的是一个list,但yield返回的直接是函数值,所以如果这里直接将yield换成return会报错,赋值语句左右不匹配。

6、如果将break去掉,发现会打印1000组X,y。更加证实了这里的data_iter是一个生成器,当迭代到生成器中i无法再迭代时,退出for循环​​​​​​​ 

动手学深度学习李沐---yield关键字的理解_第1张图片

你可能感兴趣的:(动手学深度学习,深度学习,python,人工智能)