PyTorch的F.dropout为什么要加self.training?

诸神缄默不语-个人CSDN博文目录

写作业的时候发现老师强调 F.dropout() 必须要传入 self.training,感到疑惑,所以上网搜寻了一下解释,最终明白了情况。

dropout方法出自Dropout: A Simple Way to Prevent Neural Networks from Overfitting,证明该方法有效的文献:Improving neural networks by preventing co-adaptation of feature detectors。
dropout方法是将输入Tensor的元素按伯努利分布随机置0,具体原理此处不赘,以后待补。总之就是训练的时候要用dropout,验证/测试的时候要关dropout。在PyTorch中的实现,是在训练阶段时输出直接乘以 1 1 − p \frac{1}{1-p} 1p1,测试阶段就直接当恒等函数1来用2

以下介绍Module的training属性,F(torch.nn.functional).dropout 和 nn(torch.nn).Dropout 中相应操作的实现方式,以及Module的training属性受train()eval()方法影响而改变的机制。

文章目录

  • 1. Module的training属性
  • 2. torch.nn.functional.dropout的入参training
  • 3. torch.nn.Dropout不需要手动开关
  • 4. Module的train()和eval()方法改变self.training
  • 5. 除正文中已列文档外的参考资料

1. Module的training属性

见torch.nn.Module官方文档
是Module的属性,布尔值,返回Module是否处于训练状态。也就是说在训练时training就是True。
默认为True,也就是Module初始化时默认为训练状态。

2. torch.nn.functional.dropout的入参training

torch.nn.functional.dropout官方文档

torch.nn.functional.dropout(input, p=0.5, training=True, inplace=False)
入参training默认为True,置True时应用Dropout,置False时不用。
因此在调用F.dropout()时,直接将self.training传入函数,就可以在训练时应用dropout,评估时关闭dropout。

示例代码:

x=F.dropout(x,p,self.training)

3. torch.nn.Dropout不需要手动开关

torch.nn.Dropout官方文档

torch.nn.Dropout(p=0.5, inplace=False)

其源代码为(Dropout源码):

class Dropout(_DropoutNd):
    def forward(self, input: Tensor) -> Tensor:
        return F.dropout(input, self.p, self.training, self.inplace)

就这个类相当于将 F.dropout() 进行了包装,内置传入了self.training,就不用像在 F.dropout() 里需要手动传参,也能实现在训练时应用dropout,评估时关闭dropout。

示例代码:

m = nn.Dropout(p=0.2)
input = torch.randn(20, 16)
output = m(input)

4. Module的train()和eval()方法改变self.training

torch.nn.Module.train官方文档
train(mode=True)
如果入参为True,则将Module设置为training mode,training随之变为True;反之则设置为evaluation mode,training为False。

torch.nn.Module.eval官方文档
eval()
将Module设置为evaluation mode,相当于 self.train(False)

5. 除正文中已列文档外的参考资料

  1. PyTorch 有哪些坑/bug? - 雷杰的回答 - 知乎 那时候F.dropout的training默认置False,更容易错了……
  2. F.dropout源代码
  3. (深度学习)Pytorch之dropout训练_junbaba_的博客-CSDN博客
  4. torch.nn.Module中的training属性详情,与Module.train()和Module.eval()的关系_chaiiiiiiiiiiiiiiiii的博客-CSDN博客

  1. 恒等函数 identity function identity function维基百科 ↩︎

  2. 测试阶段的源代码比较直接,可以看出来(来源:pytorch/symbolic_opset9.py at master · pytorch/pytorch):
    在这里插入图片描述 ↩︎

你可能感兴趣的:(人工智能学习笔记,神经网络,PyTorch,dropout,深度学习,python)