PyTorch: Viewing Model Parameters

Contents

    • 1. Building the Model
    • 2. state_dict()
    • 3. named_parameters()

1. Building the Model

First, build a small network model with randomly initialized weights. All of the state_dict() and named_parameters() examples below are run on this model.

import torch
import torch.nn as nn


# Fix the random seed so the printed values below are reproducible
torch.manual_seed(0)
class Model(nn.Module):

    def __init__(self, in_channel=2, out_channel=4):
        super(Model, self).__init__()

        # One Sequential block: Linear -> BatchNorm1d -> ReLU
        self.l1 = nn.Sequential(
            nn.Linear(in_channel, out_channel),
            nn.BatchNorm1d(out_channel),
            nn.ReLU())

    def forward(self, x):
        y = self.l1(x)
        return y


model = Model()
x = torch.randn(2, 2)  # a batch of 2 samples with 2 features each
y = model(x)           # training mode: BatchNorm also updates its running statistics
print(y)
'''
tensor([[0.9999, 0.0000, 0.0000, 1.0000],
        [0.0000, 1.0000, 0.9960, 0.0000]], grad_fn=<ReluBackward0>)
'''
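The dotted key names in the following sections, such as l1.0.weight and l1.1.running_mean, are paths through the module tree. l1 is the attribute name, and 0/1/2 are the positions inside the nn.Sequential. The sketch below assumes the same Model as above and uses named_modules() to list those paths:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)


class Model(nn.Module):
    # Same architecture as above: Linear -> BatchNorm1d -> ReLU in one Sequential
    def __init__(self, in_channel=2, out_channel=4):
        super().__init__()
        self.l1 = nn.Sequential(
            nn.Linear(in_channel, out_channel),
            nn.BatchNorm1d(out_channel),
            nn.ReLU())

    def forward(self, x):
        return self.l1(x)


model = Model()

# named_modules() walks the module tree; the root module has the empty name ''
module_names = []
for name, module in model.named_modules():
    module_names.append(name)
    print(repr(name), type(module).__name__)
```

The ReLU (l1.2) has no parameters or buffers, so no l1.2.* keys appear in the outputs that follow.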

2. state_dict()

model.state_dict()

Returns a dictionary containing the entire state of the model: all of its parameters plus its buffers.

Returns a dictionary containing a whole state of the module.

Both parameters and persistent buffers (e.g. running averages) are included. Keys are corresponding parameter and buffer names.
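To see the "parameters plus persistent buffers" split explicitly, you can list the parameter names and buffer names separately and compare them with the state_dict keys. This is a minimal sketch that assumes the same Model as above:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)


class Model(nn.Module):
    def __init__(self, in_channel=2, out_channel=4):
        super().__init__()
        self.l1 = nn.Sequential(
            nn.Linear(in_channel, out_channel),
            nn.BatchNorm1d(out_channel),
            nn.ReLU())

    def forward(self, x):
        return self.l1(x)


model = Model()

param_names = [name for name, _ in model.named_parameters()]   # learnable tensors
buffer_names = [name for name, _ in model.named_buffers()]     # BatchNorm statistics
state_keys = list(model.state_dict().keys())

print(param_names)
print(buffer_names)
# Every state_dict key is either a parameter or a (persistent) buffer
print(set(state_keys) == set(param_names) | set(buffer_names))
```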

print(model.state_dict().keys())
'''
odict_keys(['l1.0.weight', 'l1.0.bias', 'l1.1.weight', 'l1.1.bias', 'l1.1.running_mean', 'l1.1.running_var', 'l1.1.num_batches_tracked'])
'''


print(model.state_dict())
'''
OrderedDict([('l1.0.weight', tensor([[-0.0053,  0.3793],
                                     [-0.5820, -0.5204],
                                     [-0.2723,  0.1896],
                                     [-0.0140,  0.5607]])),
             ('l1.0.bias', tensor([-0.0628,  0.1871, -0.2137, -0.1390])),
             ('l1.1.weight', tensor([1., 1., 1., 1.])),
             ('l1.1.bias', tensor([0., 0., 0., 0.])),
             ('l1.1.running_mean', tensor([ 0.0021,  0.0166, -0.0129, -0.0015])),
             ('l1.1.running_var', tensor([0.9108, 0.9844, 0.9002, 0.9231])),
             ('l1.1.num_batches_tracked', tensor(1))])
'''
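The tensors printed above carry no requires_grad=True flag. By default (keep_vars=False), state_dict() stores detached tensors, but these still share storage with the model's parameters. The small check below assumes the same Model:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)


class Model(nn.Module):
    def __init__(self, in_channel=2, out_channel=4):
        super().__init__()
        self.l1 = nn.Sequential(
            nn.Linear(in_channel, out_channel),
            nn.BatchNorm1d(out_channel),
            nn.ReLU())

    def forward(self, x):
        return self.l1(x)


model = Model()
weight = model.l1[0].weight            # the nn.Parameter inside the Linear layer
sd = model.state_dict()                # keep_vars=False by default
sd_weight = sd['l1.0.weight']

print(weight.requires_grad, sd_weight.requires_grad)   # Parameter vs. detached tensor
print(sd_weight.data_ptr() == weight.data_ptr())       # same underlying memory

# keep_vars=True returns the Parameter objects themselves
sd_vars = model.state_dict(keep_vars=True)
print(sd_vars['l1.0.weight'] is weight)

# Shared storage: an in-place change through the state_dict tensor shows up in the model
sd_weight.zero_()
print(weight.abs().sum().item())
```

In practice, clone a state_dict tensor before modifying it in place if the model should stay untouched.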


# .items() avoids rebuilding the state dict on every iteration; ljust pads names so the tensors line up
for name, tensor in model.state_dict().items():
    print(name.ljust(32), tensor)
'''
l1.0.weight                      tensor([[-0.0053,  0.3793],
        [-0.5820, -0.5204],
        [-0.2723,  0.1896],
        [-0.0140,  0.5607]])
l1.0.bias                        tensor([-0.0628,  0.1871, -0.2137, -0.1390])
l1.1.weight                      tensor([1., 1., 1., 1.])
l1.1.bias                        tensor([0., 0., 0., 0.])
l1.1.running_mean                tensor([ 0.0021,  0.0166, -0.0129, -0.0015])
l1.1.running_var                 tensor([0.9108, 0.9844, 0.9002, 0.9231])
l1.1.num_batches_tracked         tensor(1)
'''
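Because the dictionary holds the whole state, including the BatchNorm running statistics and num_batches_tracked, loading it into a freshly initialized model reproduces every entry. This sketch assumes the same Model. The second seed (1) is an arbitrary choice to get a different initialization.

```python
import torch
import torch.nn as nn


class Model(nn.Module):
    def __init__(self, in_channel=2, out_channel=4):
        super().__init__()
        self.l1 = nn.Sequential(
            nn.Linear(in_channel, out_channel),
            nn.BatchNorm1d(out_channel),
            nn.ReLU())

    def forward(self, x):
        return self.l1(x)


torch.manual_seed(0)
model = Model()
model(torch.randn(2, 2))       # one training-mode forward pass updates the BatchNorm buffers

torch.manual_seed(1)
clone = Model()                # different random initialization, fresh buffers

result = clone.load_state_dict(model.state_dict())   # strict=True by default
print(result.missing_keys, result.unexpected_keys)

src, dst = model.state_dict(), clone.state_dict()
all_equal = all(torch.equal(src[k], dst[k]) for k in src)
print(all_equal)
print(dst['l1.1.num_batches_tracked'])
```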

3. named_parameters()

named_parameters(prefix='', recurse=True)

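Like state_dict() above, this method gives access to the model's parameters, but it returns an iterator instead of a dictionary. It also yields only the learnable parameters. Buffers such as running_mean and running_var are not included, which is why only four entries appear in the outputs below.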

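  • prefix: a prefix prepended to every parameter name
  • recurse: if True, yields the parameters of this module and all of its submodules; if False, yields only the parameters that belong directly to this module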
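Both arguments are easy to check on the Model above. Here all parameters live inside the submodule l1, so recurse=False on the top-level model yields nothing. The sketch assumes the same Model, and the prefix "net" is an arbitrary example:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)


class Model(nn.Module):
    def __init__(self, in_channel=2, out_channel=4):
        super().__init__()
        self.l1 = nn.Sequential(
            nn.Linear(in_channel, out_channel),
            nn.BatchNorm1d(out_channel),
            nn.ReLU())

    def forward(self, x):
        return self.l1(x)


model = Model()

# prefix is joined to every name with a dot
prefixed = [name for name, _ in model.named_parameters(prefix="net")]
print(prefixed)

# Model itself registers no parameters directly, only the submodule l1
top_level = list(model.named_parameters(recurse=False))
print(top_level)

# The Linear layer owns its weight and bias directly
linear_only = [name for name, _ in model.l1[0].named_parameters(recurse=False)]
print(linear_only)
```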

Returns an iterator over all model parameters, yielding each parameter's name together with the parameter itself.

(string, Parameter) – Tuple containing the name and parameter

Returns an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.

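Note: there is also a parameters() method. The only difference between the two is that parameters() yields just the parameters themselves, without their names.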
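This sketch assumes the same Model. It checks that both methods yield the same Parameter objects in the same order. It also shows a typical use of parameters(), counting the model's parameters:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)


class Model(nn.Module):
    def __init__(self, in_channel=2, out_channel=4):
        super().__init__()
        self.l1 = nn.Sequential(
            nn.Linear(in_channel, out_channel),
            nn.BatchNorm1d(out_channel),
            nn.ReLU())

    def forward(self, x):
        return self.l1(x)


model = Model()
params = list(model.parameters())
named = list(model.named_parameters())

# Same objects, same order; parameters() just drops the names
same_objects = len(params) == len(named) and all(p is q for p, (_, q) in zip(params, named))
print(same_objects)

# Linear(2, 4): 2*4 weights + 4 biases = 12; BatchNorm1d(4): 4 + 4 = 8
total = sum(p.numel() for p in model.parameters())
print(total)
```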

print(model.parameters())
print(model.named_parameters())
'''
<generator object Module.parameters at 0x...>
<generator object Module.named_parameters at 0x...>
'''
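Both calls return generators, so printing them only shows the generator object. A generator can also be consumed only once. To inspect the parameters, loop over them, or materialize them with list() or dict(). This is a minimal sketch assuming the same Model:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)


class Model(nn.Module):
    def __init__(self, in_channel=2, out_channel=4):
        super().__init__()
        self.l1 = nn.Sequential(
            nn.Linear(in_channel, out_channel),
            nn.BatchNorm1d(out_channel),
            nn.ReLU())

    def forward(self, x):
        return self.l1(x)


model = Model()

gen = model.named_parameters()
first_pass = [name for name, _ in gen]
second_pass = [name for name, _ in gen]    # the generator is already exhausted
print(len(first_pass), len(second_pass))

# dict() gives name -> Parameter lookup that can be reused
params_by_name = dict(model.named_parameters())
print(params_by_name['l1.0.weight'].shape)
```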


for param in model.parameters():
    print(param)
'''
Parameter containing:
tensor([[-0.0053,  0.3793],
        [-0.5820, -0.5204],
        [-0.2723,  0.1896],
        [-0.0140,  0.5607]], requires_grad=True)
Parameter containing:
tensor([-0.0628,  0.1871, -0.2137, -0.1390], requires_grad=True)
Parameter containing:
tensor([1., 1., 1., 1.], requires_grad=True)
Parameter containing:
tensor([0., 0., 0., 0.], requires_grad=True)
'''


for name, param in model.named_parameters(prefix="xxxxxxx"):
    print(name, param)
'''
xxxxxxx.l1.0.weight Parameter containing:
tensor([[-0.0053,  0.3793],
        [-0.5820, -0.5204],
        [-0.2723,  0.1896],
        [-0.0140,  0.5607]], requires_grad=True)
xxxxxxx.l1.0.bias Parameter containing:
tensor([-0.0628,  0.1871, -0.2137, -0.1390], requires_grad=True)
xxxxxxx.l1.1.weight Parameter containing:
tensor([1., 1., 1., 1.], requires_grad=True)
xxxxxxx.l1.1.bias Parameter containing:
tensor([0., 0., 0., 0.], requires_grad=True)
'''


print("----------------model parameters---------------------")
for name, param in model.named_parameters(prefix=""):
    if 'bias' in name:
        print(name, param.size())
'''
----------------model parameters---------------------
l1.0.bias torch.Size([4])
l1.1.bias torch.Size([4])
'''
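A common reason to filter by name, as in the loop above, is to freeze part of the model. This sketch assumes the same Model. It freezes every bias and passes only the remaining trainable parameters to an optimizer. SGD with lr=0.1 is an arbitrary choice for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)


class Model(nn.Module):
    def __init__(self, in_channel=2, out_channel=4):
        super().__init__()
        self.l1 = nn.Sequential(
            nn.Linear(in_channel, out_channel),
            nn.BatchNorm1d(out_channel),
            nn.ReLU())

    def forward(self, x):
        return self.l1(x)


model = Model()

# Freeze every parameter whose name ends in 'bias'
for name, param in model.named_parameters():
    if name.endswith('bias'):
        param.requires_grad_(False)

trainable = [name for name, p in model.named_parameters() if p.requires_grad]
print(trainable)

# Hand only the still-trainable parameters to the optimizer
optimizer = torch.optim.SGD((p for p in model.parameters() if p.requires_grad), lr=0.1)
```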
