PyTorch: state_dict() and parameters() Explained

Contents

1 parameters()

1.1 model.parameters():

1.2 model.named_parameters():

2 state_dict()


The learnable parameters of a torch.nn.Module are all contained in the model's parameters and can be retrieved via model.parameters().

state_dict() is a dictionary that maps each parameter name to its tensor; it is most often used for saving and loading models.

1 parameters()

1.1 model.parameters():

        Source:

    def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
        r"""Returns an iterator over module parameters.

        This is typically passed to an optimizer.

        Args:
            recurse (bool): if True, then yields parameters of this module
                and all submodules. Otherwise, yields only parameters that
                are direct members of this module.

        Yields:
            Parameter: module parameter

        Example::

            >>> for param in model.parameters():
            >>>     print(type(param), param.size())
            <class 'torch.Tensor'> (20L,)
            <class 'torch.Tensor'> (20L, 1L, 5L, 5L)

        """
        for name, param in self.named_parameters(recurse=recurse):
            yield param

        Module.parameters() iterates over all of the network's learnable parameters; it returns a generator.

        Layers with no learnable parameters (e.g. ReLU, MaxPool) contribute nothing, so model.parameters() does not yield anything for them.

        parameters() is most commonly used to initialize an optimizer.
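As a minimal sketch of that typical use (the model and hyperparameters here are illustrative assumptions, not from the original post), passing parameters() to an optimizer looks like:

```python
import torch
import torch.nn as nn

# A minimal stand-in model; any nn.Module works the same way.
model = nn.Linear(4, 2)

# parameters() hands the optimizer every learnable tensor (weight and bias here).
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# One dummy training step to show the parameters are actually updated.
x = torch.randn(8, 4)
loss = model(x).sum()
loss.backward()
optimizer.step()
```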

Because parameters() is a generator, its values must be retrieved with a loop or with next(). Example:

>>> import torch
>>> import torch.nn as nn

>>> class Net(nn.Module):
...     def __init__(self):
...             super().__init__()
...             self.linear = nn.Linear(2,2)
...     def forward(self,x):
...             out = self.linear(x)
...             return out
... 
>>> net = Net()
>>> for para in net.parameters():
...     print(para)
... 

Parameter containing:
tensor([[-0.1954, -0.2290],
        [ 0.5897, -0.3970]], requires_grad=True)
Parameter containing:
tensor([-0.1808,  0.2044], requires_grad=True)

>>> for para in net.named_parameters():
...     print(para)
... 
('linear.weight', Parameter containing:
tensor([[-0.1954, -0.2290],
        [ 0.5897, -0.3970]], requires_grad=True))
('linear.bias', Parameter containing:
tensor([-0.1808,  0.2044], requires_grad=True))

1.2 model.named_parameters():

        named_parameters() is model.parameters() with layer names attached: it yields tuples of two elements, the parameter's name and the parameter itself.

        The name carries a .weight or .bias suffix to distinguish weights from biases.

Source:

    def named_parameters(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, Parameter]]:
        r"""Returns an iterator over module parameters, yielding both the
        name of the parameter as well as the parameter itself.

        Args:
            prefix (str): prefix to prepend to all parameter names.
            recurse (bool): if True, then yields parameters of this module
                and all submodules. Otherwise, yields only parameters that
                are direct members of this module.

        Yields:
            (string, Parameter): Tuple containing the name and parameter

        Example::

            >>> for name, param in self.named_parameters():
            >>>     if name in ['bias']:
            >>>         print(param.size())

        """
        gen = self._named_members(
            lambda module: module._parameters.items(),
            prefix=prefix, recurse=recurse)
        for elem in gen:
            yield elem

For a code example, see section 1.1 above.
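One common use of the name suffixes is filtering parameters into groups, e.g. applying weight decay only to weights and not to biases. A sketch under that assumption (the model, learning rate, and decay value are illustrative):

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

decay, no_decay = [], []
for name, param in model.named_parameters():
    # The .weight / .bias suffix distinguishes the two groups.
    (no_decay if name.endswith(".bias") else decay).append(param)

# Per-parameter options: weight decay for weights only.
optimizer = torch.optim.SGD(
    [{"params": decay, "weight_decay": 1e-4},
     {"params": no_decay, "weight_decay": 0.0}],
    lr=0.01,
)
```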

2 state_dict()

        model.state_dict() returns all of the model's state: the learnable parameters plus the non-learnable state (buffers, such as BatchNorm's running statistics). The return value is an OrderedDict.

In other words, compared with model.parameters(), state_dict() additionally includes the non-learnable buffers.

Example:

        The keys describe the network's parameters; here they are the weight and bias of each linear layer (Dropout has no parameters, so it contributes no entries):

>>> class Net(nn.Module):
...     def __init__(self):
...             super().__init__()
...             self.linear = nn.Linear(10,8)
...             self.dropout = nn.Dropout(0.5)
...             self.linear1 = nn.Linear(8,2)
...     def forward(self,x):
...             out = self.dropout(self.linear(x))
...             out = self.linear1(out)
...             return out
... 
>>> net = Net()
>>> net.state_dict()
OrderedDict([('linear.weight', tensor([[ 0.1415, -0.2228, -0.1262,  0.0992, -0.1600,  0.0141, -0.1841, -0.1907,
          0.0295, -0.1853],
        [-0.0399, -0.2487, -0.3085,  0.1602,  0.3135,  0.1379,  0.0696,  0.0362,
         -0.1619, -0.0887],
        [-0.1244, -0.1739,  0.1211, -0.2578, -0.0561,  0.0635, -0.1976, -0.2557,
          0.1761,  0.2553],
        [ 0.0912, -0.1469, -0.3012, -0.1583, -0.0028,  0.2697,  0.1947, -0.0596,
         -0.2144, -0.0785],
        [-0.1770,  0.0411,  0.1663,  0.1861,  0.2769,  0.0990,  0.1883, -0.1801,
          0.2727,  0.1219],
        [-0.1269,  0.0713,  0.2798,  0.1760,  0.0965,  0.1144,  0.2644,  0.0274,
          0.0034,  0.2702],
        [ 0.0628,  0.0682, -0.1842,  0.1461,  0.0678, -0.2264, -0.1249, -0.1715,
          0.1115,  0.2459],
        [ 0.1198, -0.2584,  0.0234,  0.2756,  0.1174, -0.1212,  0.3024, -0.2304,
         -0.2950,  0.0970]])), ('linear.bias', tensor([-0.3036, -0.1933,  0.2412,  0.3137, -0.3007,  0.2386, -0.1975,  0.3127])), ('linear1.weight', tensor([[-0.1725,  0.3027,  0.1985,  0.1394, -0.1245,  0.2913,  0.0136,  0.1633],
        [-0.1558, -0.0865, -0.3032,  0.1374,  0.2967, -0.2886,  0.0430, -0.1246]])), ('linear1.bias', tensor([-0.1232, -0.0690]))])
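To actually see the non-learnable part, a module with buffers such as BatchNorm1d is clearer than the Dropout example above (Dropout holds no state at all). A sketch, also showing the usual save/load round trip built on state_dict() (an in-memory buffer is used here instead of a file path):

```python
import io
import torch
import torch.nn as nn

net = nn.Sequential(nn.Linear(4, 4), nn.BatchNorm1d(4))

param_keys = {name for name, _ in net.named_parameters()}
state_keys = set(net.state_dict().keys())

# BatchNorm's running statistics appear only in state_dict(), not in parameters():
print(state_keys - param_keys)
# e.g. {'1.running_mean', '1.running_var', '1.num_batches_tracked'}

# The standard save/load pattern:
buf = io.BytesIO()
torch.save(net.state_dict(), buf)
buf.seek(0)
net.load_state_dict(torch.load(buf))
```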

References:

- "PyTorch 中 model.state_dict(), model.modules(), model.children(), model.named_children() 等含义", yaoyz105, CSDN blog
- "model.parameters() 与 model.state_dict()", Zhihu
