PyTorch: the difference between nn.Dropout and nn.functional.dropout

import torch
import torch.nn as nn


class FunctionalDropout(nn.Module):
    def __init__(self, p=0.0):
        super(FunctionalDropout, self).__init__()
        # the drop rate must be kept in a plain attribute by hand
        self.p = p

    def forward(self, inputs):
        # training=True is hard-coded, so calling .eval() on the model
        # has no effect on this dropout
        return nn.functional.dropout(inputs, p=self.p, training=True)


class NNDropout(nn.Module):
    def __init__(self, p=0.0):
        super(NNDropout, self).__init__()
        # nn.Dropout is registered as a submodule and stores p itself
        self.drop_layer = nn.Dropout(p=p)

    def forward(self, inputs):
        return self.drop_layer(inputs)


functional_dropout = FunctionalDropout(p=0.5)
nn_dropout = NNDropout(p=0.5)

inputs = torch.rand(10)

print("train mode:")
print(f"functional_dropout: {functional_dropout(inputs)}")
print(f"nn_dropout: {nn_dropout(inputs)}")

# switch both models to evaluation mode
functional_dropout.eval()
nn_dropout.eval()

print("evaluation mode:")
print(f"functional_dropout: {functional_dropout(inputs)}")  # still drops!
print(f"nn_dropout: {nn_dropout(inputs)}")  # identity: dropout is off

print(f"{functional_dropout}")
print(f"{nn_dropout}")

Output:

train mode:
functional_dropout: tensor([0.3566, 0.0000, 0.0000, 0.0000, 0.0000, 1.8748, 0.8864, 0.0000, 1.5444,
        0.0000])
nn_dropout: tensor([0.3566, 1.0434, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000])
evaluation mode:
functional_dropout: tensor([0.3566, 1.0434, 1.8903, 0.0000, 0.0000, 1.8748, 0.0000, 0.0000, 1.5444,
        0.0000])
nn_dropout: tensor([0.1783, 0.5217, 0.9452, 0.6607, 0.0062, 0.9374, 0.4432, 0.0659, 0.7722,
        0.7022])
        
FunctionalDropout()
NNDropout(
  (drop_layer): Dropout(p=0.5, inplace=False)
)
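
One detail worth noting in the output: PyTorch implements inverted dropout, so in training mode the surviving activations are scaled by 1/(1-p) (a factor of 2 here), while evaluation mode passes the input through unchanged. That is why the first training value 0.3566 is exactly twice 0.1783, the first value of nn_dropout's evaluation output. A minimal sketch of the scaling, reusing the imports above (the exact zero pattern is random):

torch.manual_seed(0)
x = torch.ones(8)
drop = nn.Dropout(p=0.5)

drop.train()
print(drop(x))  # surviving entries become 1 / (1 - 0.5) = 2.0, the rest 0.0

drop.eval()
print(drop(x))  # identity in eval mode: every entry stays 1.0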

So which one should we use?

For applying dropout, the two are completely equivalent. The difference in usage is not huge, but there are a few reasons to favor nn.Dropout over nn.functional.dropout:

Dropout is designed to be applied only during training, so when running prediction or evaluation on the model we want dropout to be turned off.

The nn.Dropout module handles this conveniently and switches dropout off as soon as the model enters evaluation mode, whereas nn.functional.dropout knows nothing about the evaluation/prediction mode; it only obeys the training argument it is given.

Even though we can pass training=False to functional dropout to turn it off, it is still not as convenient a solution as nn.Dropout; see the sketch below for the usual way to wire this up.
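
If we do want to stay with the functional form, the usual fix is to pass the module's own training flag instead of a hard-coded True. A minimal sketch reworking the FunctionalDropout class from above (the name FunctionalDropoutFixed is just for illustration):

class FunctionalDropoutFixed(nn.Module):
    def __init__(self, p=0.5):
        super(FunctionalDropoutFixed, self).__init__()
        self.p = p

    def forward(self, inputs):
        # self.training is flipped by .train() / .eval(), so dropout
        # is now correctly disabled in evaluation mode
        return nn.functional.dropout(inputs, p=self.p, training=self.training)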

The drop rate (the p in the code above) is also stored inside the nn.Dropout module, so we do not have to keep it in an extra variable.

Finally, every module assigned to our model is registered in the model, so the model class keeps track of them; this is why calling eval() can switch off the dropout module. With functional dropout the model simply does not know about it, which is also why FunctionalDropout() prints as an empty module above while NNDropout lists its Dropout submodule.
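
This registration is easy to inspect with the two models defined above: the Dropout layer shows up among nn_dropout's submodules and its training flag follows eval()/train(), while functional_dropout has no submodules at all:

print(list(nn_dropout.modules()))          # the model plus its registered Dropout submodule
print(list(functional_dropout.modules()))  # only the bare FunctionalDropout model

nn_dropout.eval()
print(nn_dropout.drop_layer.training)  # False: eval() reached the submodule
nn_dropout.train()
print(nn_dropout.drop_layer.training)  # True again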

References

1. Pytorch: nn.Dropout vs. F.dropout
