Pytorch model

Pytorch model

  • model.modules
  • model.named_modules
  • model.children
  • model.named_children
  • model.parameters
  • model.named_parameters
  • model.state_dict
  • 参考

搭建一个简单网络,继承nn.Modules

import torch 
import torch.nn as nn 

class Net(nn.Module):

    def __init__(self, num_class=10):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=6, kernel_size=3),
            nn.BatchNorm2d(6),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(in_channels=6, out_channels=9, kernel_size=3),
            nn.BatchNorm2d(9),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )

        self.classifier = nn.Sequential(
            nn.Linear(9*8*8, 128),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(128, num_class)
        )

    def forward(self, x):
        output = self.features(x)
        output = output.view(output.size()[0], -1)
        output = self.classifier(output)
    
        return output

model = Net()

net,features和classifier两个由Sequential容器组成的nn.Module子类,features和classifier各自又包含众多的网络层,它们都属于nn.Module子类,所以从外到内共有3个层次。

model.modules

model.modules()迭代遍历模型的所有子层,所有子层即指nn.Module子类,在本文的例子中,Net(), features(), classifier(),以及nn.xxx构成的卷积,池化,ReLU, Linear, BN, Dropout等都是nn.Module子类,也就是model.modules()会迭代的遍历它们所有对象。

model_modules列表中共有15个元素,首先是整个Net,然后遍历了Net下的features子层,进一步遍历了feature下的所有层,然后又遍历了classifier子层以及其下的所有层

model.named_modules

model.named_modules()不但返回模型的所有子层,还会返回这些层的名字。返回层以及层的名字的好处是可以按名字通过迭代的方法修改特定的层。
features和classifier是Net的子层,而conv2d, ReLU, BatchNorm, Maxpool2d这些有时features的子层, Linear, Dropout, ReLU等是classifier的子层,上面的model.modules()不但会遍历模型的子层,还会遍历子层的子层,以及所有子层。

model.children

model.children()只会遍历模型的子层,这里即是features和classifier

model.named_children

model.named_children()不但迭代的遍历模型的子层,还会返回子层的名字:

model.parameters

迭代地返回模型的所有参数。Python3返回的是迭代器

model.named_parameters

迭代的返回带有名字的参数,会给每个参数加上带有 .weight或 .bias的名字以区分权重和偏置

model.state_dict

model.state_dict()直接返回模型的字典,和前面几个方法不同的是这里不需要迭代,它本身就是一个字典,可以直接通过修改state_dict来修改模型各层的参数,用于参数剪枝特别方便

参考

https://blog.csdn.net/Pl_Sun/article/details/106978171
https://blog.csdn.net/Pl_Sun/article/details/106976907
https://www.cnblogs.com/wangguchangqing/p/11058525.html

你可能感兴趣的:(pytorch,pytorch,model)