Contents
1 parameters()
1.1 model.parameters():
1.2 model.named_parameters():
2 state_dict()
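The learnable parameters of a torch.nn.Module are exposed as the module's parameters and can be obtained with model.parameters().
state_dict() returns a dictionary that maps names to the model's tensors (its parameters and persistent buffers); it is mostly used for saving a model.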
1 parameters()
1.1 model.parameters():
Source code:
def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
    r"""Returns an iterator over module parameters.

    This is typically passed to an optimizer.

    Args:
        recurse (bool): if True, then yields parameters of this module
            and all submodules. Otherwise, yields only parameters that
            are direct members of this module.

    Yields:
        Parameter: module parameter

    Example::

        >>> for param in model.parameters():
        >>>     print(type(param), param.size())
        <class 'torch.Tensor'> (20L,)
        <class 'torch.Tensor'> (20L, 1L, 5L, 5L)

    """
    for name, param in self.named_parameters(recurse=recurse):
        yield param
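Module.parameters() returns the network's parameters: it yields every learnable parameter of the model one at a time, i.e. it is a generator.
Some layers (e.g. ReLU, MaxPool) have no learnable parameters, so they contribute nothing to the output of model.parameters().
parameters() is most often seen when initializing an optimizer.
Because parameters() is a generator, use a loop or next() to get the values.
Example: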
>>> import torch
>>> import torch.nn as nn
>>> class Net(nn.Module):
...     def __init__(self):
...         super().__init__()
...         self.linear = nn.Linear(2,2)
...     def forward(self,x):
...         out = self.linear(x)
...         return out
...
>>> net = Net()
>>> for para in net.parameters():
...     print(para)
...
Parameter containing:
tensor([[-0.1954, -0.2290],
        [ 0.5897, -0.3970]], requires_grad=True)
Parameter containing:
tensor([-0.1808,  0.2044], requires_grad=True)
>>> for para in net.named_parameters():
...     print(para)
...
('linear.weight', Parameter containing:
tensor([[-0.1954, -0.2290],
        [ 0.5897, -0.3970]], requires_grad=True))
('linear.bias', Parameter containing:
tensor([-0.1808,  0.2044], requires_grad=True))
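The points in section 1.1 (generator behaviour, parameter-free layers, and passing parameters to an optimizer) can be sketched together as follows. The toy model, the learning rate, and the variable names are assumptions made for illustration only.

```python
import torch
import torch.nn as nn

# Hypothetical toy model: one Linear layer followed by a parameter-free ReLU.
model = nn.Sequential(nn.Linear(2, 2), nn.ReLU())

# parameters() is a generator, so values are pulled with next() or a loop.
gen = model.parameters()
first = next(gen)  # the weight of the Linear layer
print(type(first).__name__, tuple(first.shape))

# A layer without learnable parameters yields nothing.
relu_params = list(nn.ReLU().parameters())
print(len(relu_params))

# Consuming the generator, e.g. to count the learnable scalars.
n_params = sum(p.numel() for p in model.parameters())
print(n_params)

# The most common use: hand the parameters to an optimizer.
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
```

Each call to model.parameters() creates a fresh generator, so the partially consumed `gen` above does not affect the later calls.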
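1.2 model.named_parameters():
named_parameters() is model.parameters() with names attached: it yields tuples of two elements, the parameter's name (prefixed by its layer name) and the Parameter itself.
The names end in .weight or .bias, which distinguishes weights from biases.
Source code: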
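One possible use of the .weight / .bias suffixes is to select parameters by name. The sketch below assumes a small nn.Sequential model (which names its children by index) and, purely as an illustration, turns off gradients for the biases.

```python
import torch.nn as nn

# Hypothetical model; nn.Sequential names its children "0", "1", "2", ...
model = nn.Sequential(nn.Linear(10, 8), nn.ReLU(), nn.Linear(8, 2))

# Collect all parameter names; the ReLU at index 1 contributes none.
names = [name for name, _ in model.named_parameters()]
print(names)

# Use the suffix to pick out the biases and stop training them.
for name, param in model.named_parameters():
    if name.endswith(".bias"):
        param.requires_grad = False

trainable = [name for name, p in model.named_parameters() if p.requires_grad]
print(trainable)
```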
def named_parameters(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, Parameter]]:
    r"""Returns an iterator over module parameters, yielding both the
    name of the parameter as well as the parameter itself.

    Args:
        prefix (str): prefix to prepend to all parameter names.
        recurse (bool): if True, then yields parameters of this module
            and all submodules. Otherwise, yields only parameters that
            are direct members of this module.

    Yields:
        (string, Parameter): Tuple containing the name and parameter

    Example::

        >>> for name, param in self.named_parameters():
        >>>     if name in ['bias']:
        >>>         print(param.size())

    """
    gen = self._named_members(
        lambda module: module._parameters.items(),
        prefix=prefix, recurse=recurse)
    for elem in gen:
        yield elem
For a code example, see section 1.1.
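The prefix and recurse arguments described in the docstring can be illustrated with a small sketch. The Net class and its `scale` parameter are hypothetical, chosen so that the module has one direct parameter and one submodule.

```python
import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(1))  # direct member of Net
        self.linear = nn.Linear(2, 2)             # parameters live in a submodule

net = Net()

# Default: this module and all submodules.
all_names = [name for name, _ in net.named_parameters()]

# recurse=False: only parameters that are direct members of this module.
direct_names = [name for name, _ in net.named_parameters(recurse=False)]

# prefix: prepended (with a dot) to every name.
prefixed_names = [name for name, _ in net.named_parameters(prefix="net")]

print(all_names)
print(direct_names)
print(prefixed_names)
```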
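2 state_dict()
model.state_dict() returns all of the model's state, both the learnable parameters and the non-learnable tensors (persistent buffers, such as the running statistics of a BatchNorm layer), as an ordered dictionary (OrderedDict).
In other words, it covers everything model.parameters() yields plus the non-learnable part.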
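To make the "non-learnable part" concrete, the sketch below compares the keys of state_dict() with the names from named_parameters(). The model is a hypothetical one that includes a BatchNorm1d layer, because that layer registers buffers.

```python
import torch.nn as nn

# Hypothetical model: BatchNorm1d holds buffers in addition to weight and bias.
model = nn.Sequential(nn.Linear(4, 3), nn.BatchNorm1d(3))

param_names = {name for name, _ in model.named_parameters()}
buffer_names = {name for name, _ in model.named_buffers()}
state_keys = set(model.state_dict().keys())

# Entries that are in state_dict() but are not learnable parameters.
extra = sorted(state_keys - param_names)
print(extra)

# By default state_dict() returns detached tensors, not Parameter objects.
print(model.state_dict()["0.weight"].requires_grad)
```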
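Example:
Each key names one tensor of the network; here the keys are the weight and bias of the two linear layers (Dropout has no parameters or buffers, so it adds no entries).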
>>> class Net(nn.Module):
...     def __init__(self):
...         super().__init__()
...         self.linear = nn.Linear(10,8)
...         self.dropout = nn.Dropout(0.5)
...         self.linear1 = nn.Linear(8,2)
...     def forward(self,x):
...         out = self.dropout(self.linear(x))
...         out = self.linear1(out)
...         return out
...
>>> net = Net()
>>> net.state_dict()
OrderedDict([('linear.weight', tensor([[ 0.1415, -0.2228, -0.1262, 0.0992, -0.1600, 0.0141, -0.1841, -0.1907,
0.0295, -0.1853],
[-0.0399, -0.2487, -0.3085, 0.1602, 0.3135, 0.1379, 0.0696, 0.0362,
-0.1619, -0.0887],
[-0.1244, -0.1739, 0.1211, -0.2578, -0.0561, 0.0635, -0.1976, -0.2557,
0.1761, 0.2553],
[ 0.0912, -0.1469, -0.3012, -0.1583, -0.0028, 0.2697, 0.1947, -0.0596,
-0.2144, -0.0785],
[-0.1770, 0.0411, 0.1663, 0.1861, 0.2769, 0.0990, 0.1883, -0.1801,
0.2727, 0.1219],
[-0.1269, 0.0713, 0.2798, 0.1760, 0.0965, 0.1144, 0.2644, 0.0274,
0.0034, 0.2702],
[ 0.0628, 0.0682, -0.1842, 0.1461, 0.0678, -0.2264, -0.1249, -0.1715,
0.1115, 0.2459],
[ 0.1198, -0.2584, 0.0234, 0.2756, 0.1174, -0.1212, 0.3024, -0.2304,
-0.2950, 0.0970]])), ('linear.bias', tensor([-0.3036, -0.1933, 0.2412, 0.3137, -0.3007, 0.2386, -0.1975, 0.3127])), ('linear1.weight', tensor([[-0.1725, 0.3027, 0.1985, 0.1394, -0.1245, 0.2913, 0.0136, 0.1633],
[-0.1558, -0.0865, -0.3032, 0.1374, 0.2967, -0.2886, 0.0430, -0.1246]])), ('linear1.bias', tensor([-0.1232, -0.0690]))])
>>>
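Since state_dict() is mostly used for saving models, a minimal save-and-restore round trip might look like the sketch below. The `make_net` helper is hypothetical, and an in-memory buffer stands in for a file path such as "model.pt" so that the example leaves nothing on disk.

```python
import io
import torch
import torch.nn as nn

def make_net():
    # Hypothetical helper: same layer sizes as the example above.
    return nn.Sequential(nn.Linear(10, 8), nn.Dropout(0.5), nn.Linear(8, 2))

src = make_net()
dst = make_net()  # independently initialized, so its weights differ from src

# Save only the state_dict, not the whole module object.
buffer = io.BytesIO()  # assumption: in-memory stand-in for a file on disk
torch.save(src.state_dict(), buffer)

# Restore into a model with the same architecture.
buffer.seek(0)
dst.load_state_dict(torch.load(buffer))

same = all(
    torch.equal(a, b)
    for a, b in zip(src.state_dict().values(), dst.state_dict().values())
)
print(same)
```

load_state_dict() matches entries by key, so the receiving model has to define layers with the same names and shapes.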
References:
The meaning of model.state_dict(), model.modules(), model.children(), model.named_children(), etc. in PyTorch (yaoyz105's blog, CSDN)
model.parameters() and model.state_dict() (Zhihu)