../d2l/torch.py中的lambda表达式

astype

astype = lambda x, *args, **kwargs: x.type(*args, **kwargs)
cmp = d2l.astype(y_hat, y.dtype) == y

解释:x接受第0个参数y_hat,args接收其它后面的参数y.dtype(这里是torch.int64),x.type是将x的元素强制转换成某个属性。综合起来这个lambda的意思是将y_hat的元素类型设置为和y.dtype一样的类型。

reduce_sum

reduce_sum = lambda x, *args, **kwargs: x.sum(*args, **kwargs)
d2l.reduce_sum(d2l.astype(cmp, y.dtype))

解释:等价于d2l.astype(cmp, y.dtype).sum();

另外,bool类型是直接可以加和的,例子如下:

a = torch.tensor([[False, True], [True, True]])
a.sum()
output: tensor(3)

你可能感兴趣的:(深度学习算法,python,numpy)