PyTorch: How to Get the Outputs of Intermediate Layers of a Model

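Getting intermediate layers

The examples below assume a model whose `features` attribute is an `nn.Sequential` built from named layers. torchvision's DenseNet builds its feature extractor this way: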

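from collections import OrderedDict
import torch.nn as nn

# Excerpt from a DenseNet-style __init__: the stem layers are registered with explicit names
self.features = nn.Sequential(
    OrderedDict(
        [
            ("conv0", nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)),
            ("norm0", nn.BatchNorm2d(num_init_features)),
            ("relu0", nn.ReLU(inplace=True)),
            ("pool0", nn.MaxPool2d(kernel_size=3, stride=2, padding=1)),
        ]
    )
)
# Inside the loop over dense blocks (i = 0, 1, ...), each block and transition
# is appended to the same Sequential under a readable name:
self.features.add_module("denseblock%d" % (i + 1), block)
self.features.add_module("transition%d" % (i + 1), trans)
...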
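To see which names such a construction produces, here is a minimal, self-contained sketch. The stem matches the excerpt above, but `make_features` and its "dense blocks" and "transitions" are hypothetical stand-ins (a single 3x3 conv, and a 1x1 conv plus average pooling), not real DenseNet blocks.

```python
from collections import OrderedDict

import torch.nn as nn


def make_features(num_init_features=8, num_blocks=2):
    """Build a toy DenseNet-like `features` stack (blocks are hypothetical stand-ins)."""
    features = nn.Sequential(
        OrderedDict(
            [
                ("conv0", nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)),
                ("norm0", nn.BatchNorm2d(num_init_features)),
                ("relu0", nn.ReLU(inplace=True)),
                ("pool0", nn.MaxPool2d(kernel_size=3, stride=2, padding=1)),
            ]
        )
    )
    for i in range(num_blocks):
        # Stand-in for a dense block: keeps channels and spatial size unchanged
        block = nn.Conv2d(num_init_features, num_init_features, kernel_size=3, padding=1)
        features.add_module("denseblock%d" % (i + 1), block)
        if i != num_blocks - 1:  # no transition after the last block
            trans = nn.Sequential(
                nn.Conv2d(num_init_features, num_init_features, kernel_size=1),
                nn.AvgPool2d(kernel_size=2, stride=2),
            )
            features.add_module("transition%d" % (i + 1), trans)
    return features


features = make_features()
names = [name for name, _ in features.named_children()]
print(names)
# ['conv0', 'norm0', 'relu0', 'pool0', 'denseblock1', 'transition1', 'denseblock2']
```

Because every submodule keeps both its position and its name, intermediate outputs can be selected either way.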

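Getting outputs by index

`nn.Sequential` can be iterated in order, so you can run its submodules one at a time and keep the output of a chosen position. In the stem above, index 2 is `relu0`: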

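# Inside a method of the model (e.g. forward), where self.features is available
x = torch.rand([1, 3, 320, 320])  # dummy input batch
features = []
for i, layer in enumerate(self.features):  # submodules run in registration order
    x = layer(x)
    if i == 2:  # index 2 is relu0 (conv0 -> norm0 -> relu0)
        features.append(x)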
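The same idea as a standalone sketch that collects several indices at once. The four-layer `features` model and the helper `outputs_at` are hypothetical, chosen only to make the example runnable without the full network.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for self.features (stem only)
features = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=7, stride=2, padding=3, bias=False),  # index 0
    nn.BatchNorm2d(8),                                                # index 1
    nn.ReLU(inplace=True),                                            # index 2
    nn.MaxPool2d(kernel_size=3, stride=2, padding=1),                 # index 3
)
features.eval()  # BatchNorm uses running statistics


def outputs_at(seq, x, indices):
    """Run `seq` layer by layer and return {index: output} for the requested indices."""
    wanted = set(indices)
    outs = {}
    for i, layer in enumerate(seq):
        x = layer(x)
        if i in wanted:
            outs[i] = x
    return outs


with torch.no_grad():
    outs = outputs_at(features, torch.rand(1, 3, 320, 320), [2, 3])
print({i: tuple(t.shape) for i, t in outs.items()})
# {2: (1, 8, 160, 160), 3: (1, 8, 80, 80)}
```

Index-based selection is brittle if layers are later inserted or removed, which is one reason to prefer names when they are available.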

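Getting outputs by layer name

`named_children()` yields each submodule together with the name it was registered under, so you can select layers by name, for example the output of every `denseblock*`: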

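# Inside a method of the model (e.g. forward), where self.features is available
x = torch.rand([1, 3, 320, 320])  # dummy input batch
features = []
# named_children() is the public equivalent of iterating self.features._modules.items()
for name, module in self.features.named_children():
    x = module(x)
    if name.startswith("denseblock"):  # keep the output of every dense block
        features.append(x)
        print(x.shape)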
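A self-contained sketch of the name-based approach wrapped in a model's `forward`. `ToyNet` and its layers are hypothetical: the "dense blocks" are plain 3x3 convolutions and the "transition" is a single average-pooling layer, used only so the shapes are easy to follow.

```python
from collections import OrderedDict

import torch
import torch.nn as nn


class ToyNet(nn.Module):
    """Hypothetical DenseNet-like model; blocks are simple conv stand-ins."""

    def __init__(self, c=8):
        super().__init__()
        self.features = nn.Sequential(OrderedDict([
            ("conv0", nn.Conv2d(3, c, kernel_size=7, stride=2, padding=3, bias=False)),
            ("relu0", nn.ReLU(inplace=True)),
            ("pool0", nn.MaxPool2d(kernel_size=3, stride=2, padding=1)),
            ("denseblock1", nn.Conv2d(c, c, kernel_size=3, padding=1)),
            ("transition1", nn.AvgPool2d(kernel_size=2, stride=2)),
            ("denseblock2", nn.Conv2d(c, c, kernel_size=3, padding=1)),
        ]))

    def forward(self, x):
        feats = []
        for name, module in self.features.named_children():
            x = module(x)
            if name.startswith("denseblock"):
                feats.append(x)  # collect every dense-block output
        return feats


model = ToyNet().eval()
with torch.no_grad():
    feats = model(torch.rand(1, 3, 320, 320))
print([tuple(f.shape) for f in feats])
# [(1, 8, 80, 80), (1, 8, 40, 40)]
```

Returning a list of multi-scale features like this is a common way to feed a backbone into a detection or segmentation head.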
