用pytorch里的children方法自定义网络

假设我们要自定义一个网络,这个网络由resnet-18的所有除了最后一层和倒数第二层,最后一层用我们的自定义层。这时我们可以用到children方法。

children()返回网络模型里的组成元素,且children()返回的是最外层的元素

举个例子:

m = nn.Sequential(nn.Linear(2,2), 
                  nn.ReLU(),
                 nn.Sequential(nn.Sigmoid(), nn.ReLU()))

m.children()返回的是(一个list):

[Linear(in_features=2, out_features=2), ReLU(), Sequential(
   (0): Sigmoid()
   (1): ReLU()
 )]

此时我们对children方法应该有了一些了解,那么回到我们刚才的问题:

pretrained_net = torchvision.models.resnet18(pretrained=True)
net = nn.Sequential(*list(pretrained_net.children())[:-2])

通过上面两行代码就可以获得resnet-18的除最后两层的所有层。

然后我们通过add_module方法添加自己想要的层:

num_classes = 21
net.add_module('final_conv', nn.Conv2d(512, num_classes, kernel_size=1))
net.add_module('transpose_conv', nn.ConvTranspose2d(num_classes, num_classes,
                                    kernel_size=64, padding=16, stride=32))

这样,我们就完成了定义自己的网络。

你可能感兴趣的:(吴恩达深度学习,pytorch,python,深度学习)