最近在回顾廖雪峰老师的python教程,在“高阶函数”这一节有一段利用filter函数和生成器求素数的代码,这里为了方便理解和调试进行了简化:
def func(n):
return lambda x: x % n > 0
def primes():
it = (i for i in np.arange(3, 20, 2))
while True:
n = next(it)
yield n
it = filter(func(n), it)
print(list(primes()))
运行结果:[3, 5, 7, 11, 13, 17, 19]
代码看上去很简单:func()的功能是返回一个匿名函数,作用是判断输入值是否大于n;prime()是一个返回素数的生成器,实现过程是先定义一个不含偶数的生成器,然后每返回一个素数,就利用filter排除生成器中可以被这个素数整除的元素。看到这里的时候我在想,为什么要单独定义一个func()呢,反正返回的也是一个lambda,干脆把func()里的内容放进filter,写python当然是追求代码越简洁越好咯。然而,噩梦开始了:
def primes():
it = (i for i in np.arange(3, 20, 2))
while True:
n = next(it)
yield n
it = filter(lambda x: x % n > 0, it)
print(list(primes()))
运行结果:[3, 5, 7, 9, 11, 13, 15, 17, 19]
心凉了一会,既然踩到了坑,就把坑慢慢填起来吧。先来回忆一下filter()的用法:接收一个函数和一个序列,传入的函数依次作用于每个元素,然后根据返回的布尔值决定是否保留该元素。试一下利用filter实现对偶数的筛选:
L = list(filter(lambda s: s % 2 == 0, (i for i in range(5))))
print(L)
运行结果:[0, 2, 4]
结果意料之中,filter()的逻辑很简洁,匿名函数和生成器作为filter()函数参数也没问题。但是回过头来看之前的结果,在yield 3之后filter并没有起到排除可以被3整除元素的作用,9和15依然出现在了最终结果中,所以猜测lambda这里出了问题。换一个简单的lambda再试一下:
def primes():
it = (i for i in np.arange(3, 20, 2))
while True:
n = next(it)
yield n
it = filter(lambda x: x < 15, it)
print(list(primes()))
运行结果:[3, 5, 7, 9, 11, 13]
结果正确。为什么会出现这种问题,我在找了很多资料后发现,其实官方的FAQ中介绍了这样一个python的坑,我们看一下官方的例子:
squares = []
for x in range(5):
squares.append(lambda: x**2)
print([i() for i in squares])
运行结果:[16, 16, 16, 16, 16]
感觉是不是跟前面的例子很类似,但是结果同样的匪夷所思?明明是把0到4的平方传进去的,但是为什么出来的却全是4的平方?官方的解释如下:
This happens because x is not local to the lambdas, but is defined in the outer scope, and it is accessed when the lambda is called — not when it is defined. At the end of the loop, the value of x is 4, so all the functions now 16.
官方也补充到,我们还可以进一步通过继续更改x的值并查看lambda的结果如何变化来验证这一点。
x = 8
squares[2]()
运行结果:64
为了避免这种情况,我们需要把外部数值作为参数传给lambda来代替全局变量(原始代码中func(n)函数正是这个作用),代码如下:
squares = []
for x in range(5):
squares.append(lambda n=x: n**2)
print([i() for i in squares])
运行结果:[0, 1, 4, 9, 16]
同理,之前的lambda也可以改为:
def primes():
it = (i for i in np.arange(3, 20, 2))
while True:
n = next(it)
yield n
it = filter(lambda x, n=n: x % n > 0, it)
print(list(primes()))
运行结果:[3, 5, 7, 11, 13, 17, 19]