torch之nn.moduleList 和Sequential由来、用法和实例

nn.moduleList 和Sequential由来、用法和实例

注意:本博文详细记述了nn.moduleList 和 nn.Sequential的区别以及使用,解释比较清晰。


  • Sequential和ModuleList在使用上的区别
  1. Sequential继承自Module,在其内部封装有forward函数,直接调用即可;
  2. ModuleList虽然同样是继承Module,但是其内部并没有封装forward函数,只是使用append或者extend操作对list进行扩充,本质上是一个list
  3. 使用:ModuleList作为网络结构的一部分,需要网络结构独立构建forward函数,而forward函数中可以引用ModuleList成员;
  4. 使用:Sequential 建立nn.Sequential()对象,必须小心确保一个块的输出大小与下一个块的输入大小匹配,基本上,它的行为就像一个nn.Module;

  • 问题:对于nn.Sequential使用时,直接调用解析:

  • 实例:

class ASPP(nn.Module):
    def __init__(self, in_channels, out_channels, paddings, dilations):
        # todo depthwise separable conv
        super(ASPP, self).__init__()
        self.conv11 = nn.Sequential(nn.Conv2d(in_channels, out_channels, 1, bias=False, ),
                                    nn.BatchNorm2d(out_channels))
        self.conv33 = nn.Sequential(nn.Conv2d(in_channels, out_channels, 3,
                                    padding=paddings, dilation=dilations, bias=False, ),
                                      nn.BatchNorm2d(out_channels))
        self.conv_p = nn.Sequential(nn.Conv2d(in_channels, out_channels, 1, bias=False, ),
                                    nn.BatchNorm2d(out_channels))

        self.concate_conv = nn.Sequential(nn.Conv2d(out_channels * 3, out_channels, 1, bias=False),
                                          nn.BatchNorm2d(out_channels))

        # self.upsample = nn.Upsample(mode='bilinear', align_corners=True)

    def forward(self, x):
        conv11 = self.conv11(x)		#上述self.conv11实现中并没有代用input参数,为何在此处可以直接调用,并且传入参数x,见下解释
        conv33 = self.conv33(x)

        # image pool and upsample
        image_pool = nn.AvgPool2d(kernel_size=x.size()[2:])
        image_pool = image_pool(x)
        upsample = nn.Upsample(size=x.size()[2:], mode='bilinear', align_corners=True)
        upsample = upsample(image_pool)
        upsample = self.conv_p(upsample)


        # concate
        concate = torch.cat([conv11, conv33, upsample], dim=1)

        return self.concate_conv(concate)
  • 解释:
  1. nn.Sequential是继承自Module的类,同样nn.Conv2d等函数也是继承自Module的类,而Module中实现由__calller__函数来调用子类中的forward函数;

  2. nn.Sequential内部封装有forward函数,forward函数带有input参数;

#source:https://pytorch.org/docs/stable/_modules/torch/nn/modules/container.html#Sequential
class Sequential(Module):
    def forward(self, input):
        for module in self._modules.values():
            input = module(input)
        return input

#source:https://pytorch.org/docs/stable/_modules/torch/nn/modules/conv.html#Conv2d
class Conv2d(_ConvNd):
    def forward(self, input):
        if self.padding_mode == 'circular':
            expanded_padding = ((self.padding[1] + 1) // 2, self.padding[1] // 2,
                                (self.padding[0] + 1) // 2, self.padding[0] // 2)
            return F.conv2d(F.pad(input, expanded_padding, mode='circular'),
                            self.weight, self.bias, self.stride,
                            _pair(0), self.dilation, self.groups)
        return F.conv2d(input, self.weight, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)

#source:https://pytorch.org/docs/stable/_modules/torch/nn/modules/activation.html#ReLU
class ReLU(Module):
    @weak_script_method
    def forward(self, input):
        return F.relu(input, inplace=self.inplace)

你可能感兴趣的:(torch)