PyTorch踩坑记录——torch.functional 与 torch.nn.functional的区别

问题描述:

提示:刚入门深度学习,记录一些犯下的小错误:

由于本周开始试图复现华为的CTR库以增加记忆,熟悉代码细节,没想到第一天看基础模块的时候就遇到了麻烦,在torch.utils类中,有如下获取损失函数的代码块:

def get_loss_fn(loss):
    if isinstance(loss, str):
        if loss in ["bce", "binary_crossentropy", "binary_cross_entropy"]:
            loss = "binary_cross_entropy"
    try:
        loss_fn = getattr(torch.functional.F, loss)
    except:
        try:
            from . import losses
            loss_fn = getattr(losses, loss)
        except:
            raise NotImplementedError("loss={} is not supported.".format(loss))
    return loss_fn

其中getattr()函数是用于返回一个对象属性值(Tip: class中的方法也是一种对象属性),因此可以看出第6行代码的作用就是返回torch.functional.F这个类中的loss函数,那么问题来了:上面代码片中的torch.functional.F是哪个类呢,或者说是哪个模块呢?之前在学习PyTorch的过程中只接触过其中的:

import torch.nn.functional as F

那么这个torch.functional.Ftorch.nn.functional有何区别?


解惑

因此抱着分辨清楚的目的,查看PyTorch官方文档,我发现只有torch.nn.functional才有一系列的loss函数的实现,而输入关键词torch.functional在搜索引擎上基本找不到相关的资料,返回的搜索结果都是与前者相关的文档。于是我决定去看源码弄清楚:

PyTorch踩坑记录——torch.functional 与 torch.nn.functional的区别_第1张图片
可以看到这两个模块显然是不同的模块!!!而后我打开torch.functional.py文件,出现了我无语的一幕,原来在torch.functional.py的第一行就是这么写的:

PyTorch踩坑记录——torch.functional 与 torch.nn.functional的区别_第2张图片
问题解决了,torch.functional .F指向的就是torch.nn.functional,可能刚开始试图复现这个CTR库吧,实在搞不懂作者为什么不直接直接使用torch.nn.functional来指代?

END~


你可能感兴趣的:(PyTorch,python,nlp,数据挖掘,pytorch)