提示:刚入门深度学习,记录一些犯下的小错误:
由于本周开始试图复现华为的CTR库以增加记忆,熟悉代码细节,没想到第一天看基础模块的时候就遇到了麻烦,在torch.utils类中,有如下获取损失函数的代码块:
def get_loss_fn(loss):
if isinstance(loss, str):
if loss in ["bce", "binary_crossentropy", "binary_cross_entropy"]:
loss = "binary_cross_entropy"
try:
loss_fn = getattr(torch.functional.F, loss)
except:
try:
from . import losses
loss_fn = getattr(losses, loss)
except:
raise NotImplementedError("loss={} is not supported.".format(loss))
return loss_fn
其中getattr()函数是用于返回一个对象属性值(Tip: class中的方法也是一种对象属性),因此可以看出第6行代码的作用就是返回torch.functional.F这个类中的loss函数,那么问题来了:上面代码片中的torch.functional.F是哪个类呢,或者说是哪个模块呢?之前在学习PyTorch的过程中只接触过其中的:
import torch.nn.functional as F
那么这个torch.functional.F与torch.nn.functional有何区别?
因此抱着分辨清楚的目的,查看PyTorch官方文档,我发现只有torch.nn.functional才有一系列的loss函数的实现,而输入关键词torch.functional在搜索引擎上基本找不到相关的资料,返回的搜索结果都是与前者相关的文档。于是我决定去看源码弄清楚:
可以看到这两个模块显然是不同的模块!!!而后我打开torch.functional.py文件,出现了我无语的一幕,原来在torch.functional.py的第一行就是这么写的:
问题解决了,torch.functional .F指向的就是torch.nn.functional,可能刚开始试图复现这个CTR库吧,实在搞不懂作者为什么不直接直接使用torch.nn.functional来指代?
END~