pytorch八:nn.Module深入分析

如果想要深入理解nn.Module,研究其原理很有必要。首先来看看nn.Module基类的构造函数:

nn.Module??

pytorch八:nn.Module深入分析_第1张图片

  • _parameters:字典,保存用户直接设置的parameter,如self.param1=nn.Parameter(t.randn(3,3))会被检测到,在字典中加入一个key为‘param1’,value为对应的item,而self.submodel = nn.Linear(3,4)中的parameters则不会存于此。只可通过model.named_modules()查看全部参数来获得。
  • _modules:子module。通过self.submodule = nn.Linear(3,4)指定的子module会存于此。
  • _buffers:缓存。如batchnorm使用momentum机制,每次前向传播需要用到上一次前向传播的结果。
  • _training:BatchNorm与Dropout层在训练阶段和测试阶段中采取的策略不同,通过判断training的值来决定前向传播的策略。
class Net(nn.Module):
    def __init__(self):
        nn.Module.__init__(self)
        self.param1 = nn.Par

你可能感兴趣的:(pytorch)