Pytorch之提取模型中的某一层

modules()会返回模型中所有模块的迭代器,它能够访问到最内层,比如self.layer1.conv1这个模块,还有一个与它们相对应的是name_children()属性以及named_modules(),这两个不仅会返回模块的迭代器,还会返回网络层的名字。

方法如下:

new_model = nn.Sequential(*list(model.children())[:2] 

取模型中的前两层

如果希望提取出模型中的所有卷积层,可以像下面这样操作:

for layer in model.named_modules():
    if isinstance(layer[1],nn.Conv2d):
         conv_model.add_module(layer[0],layer[1])
#使用isinstance可以判断这个模块是不是所需要的类型实例

 

你可能感兴趣的:(Pytorch学习)