Pytorch(六)(模型参数的遍历) —— model.parameters() & model.named_parameters() & model.state_dict()

神经网络的模型参数

model.parameters(), model.named_parameters(), model.state_dict() 这三个方法都可以查看神经网络的参数信息,用于更新参数,或者用于模型的保存。作用都类似,写法略有出入

就以Pytorch之经典神经网络(一) —— 全连接网络(MNIST) 来举例 Pytorch之经典神经网络CNN(一) —— 全连接网络 / MLP (MNIST) (trainset和Dataloader & batch training & learning_rate)_hxxjxw的博客-CSDN博客   

model.named_parameters()

net.named_parameters()中param是len为2的tuple
param[0]是name,fc1.weight、fc1.bias等
param[1]是fc1.weight、fc1.bias等对应的值

一直是0,1,2,......, 这种序号

for _,param in enumerate(net.named_parameters()):
    print(param[0])
    print(param[1])
    print('----------------')

Pytorch(六)(模型参数的遍历) —— model.parameters() & model.named_parameters() & model.state_dict()_第1张图片

model.parameters()

net.parameters()中param就是fc1.weight、fc1.bias等对应的值,没带名字

for _,param in enumerate(net.parameters()):
    print(param)
    print('----------------')

Pytorch(六)(模型参数的遍历) —— model.parameters() & model.named_parameters() & model.state_dict()_第2张图片

model.state_dict()

net.state_dict() 中的param就只是str字符串 fc1.weight, fc1.bias等等

但它们可以作为参数来输出对应的值

for _,param in enumerate(net.state_dict()):
    print(param)
    print(net.state_dict()[param])
    print('----------------')

Pytorch(六)(模型参数的遍历) —— model.parameters() & model.named_parameters() & model.state_dict()_第3张图片

神经网络的各个层

当神经网络是这么定义的时候,即没有用nn.Sequential()

Pytorch(六)(模型参数的遍历) —— model.parameters() & model.named_parameters() & model.state_dict()_第4张图片

此时 print(net)

net = Net()
print(net)

Pytorch(六)(模型参数的遍历) —— model.parameters() & model.named_parameters() & model.state_dict()_第5张图片

输出单个的网络层

net = Net()
print(net.fc1)
print(net.fc2)
print(net.fc3)

输出各个网络层的weight,bias参数

net = Net()
print(net.fc1.weight)
print(net.fc1.bias)
print(net.fc2.weight)
print(net.fc2.bias)
print(net.fc3.weight)
print(net.fc3.bias)

Pytorch(六)(模型参数的遍历) —— model.parameters() & model.named_parameters() & model.state_dict()_第6张图片

当使用nn.Sequential定义的时候

import torch
import torchvision
from torchvision import transforms
from matplotlib import pyplot as plt
from torch import nn
from torch.nn import functional as F

from utils import plot_image,plot_curve,one_hot

# class Net(nn.Module):
#     def __init__(self):
#         super(Net, self).__init__()
#
#         #三层全连接层
#         #wx+b
#         self.fc1 = nn.Linear(28*28, 256)
#         self.fc2 = nn.Linear(256,64)
#         self.fc3 = nn.Linear(64,10)
#
#     def forward(self, x):
#         x = F.rule(self.fc1(x)) #F.relu和torch.relu,用哪个都行
#         x = F.relu(self.fc2(x))
#         x = F.relu(self.fc(3))
#
#         return x


class Net(nn.Module):

    def __init__(self):
        super(Net, self).__init__()

        self.fc = nn.Sequential(
            nn.Linear(28 * 28, 256),
            nn.ReLU(),
            nn.Linear(256, 64),
            nn.ReLU(),
            nn.Linear(64, 10)
        )

        def forward(self, x):
            # x: [b, 1, 28, 28]
            # h1 = relu(xw1+b1)
            x = self.fc(x)

            return x

batch_size = 512
#一次处理的图片的数量
#gpu一次可以处理并行多张图片

transform = transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.1307,), (0.3081,))
])


trainset = torchvision.datasets.MNIST(
    root='dataset/',
    train=True,  #如果为True,从 training.pt 创建数据,否则从 test.pt 创建数据。
    download=True, #如果为true,则从 Internet 下载数据集并将其放在根目录中。 如果已下载数据集,则不会再次下载。
    transform=transform
)
#train=True表示是训练数据,train=False是测试数据

train_loader = torch.utils.data.DataLoader(
    dataset=trainset,
    batch_size=batch_size,
    shuffle=True  #在加载的时候将图片随机打散
)

testset = torchvision.datasets.MNIST(
    root='dataset/',
    train=False,
    download=True,
    transform=transform
)

train_loader = torch.utils.data.DataLoader(
    dataset=testset,
    batch_size=batch_size,
    shuffle=True
)

net = Net()
print(net.fc)
print(net.fc[0])
print(net.fc[1])
print(net.fc[2])
print(net.fc[3])
print(net.fc[4])
print(net.fc[0].weight)
print(net.fc[0].bias)


Pytorch(六)(模型参数的遍历) —— model.parameters() & model.named_parameters() & model.state_dict()_第7张图片

你可能感兴趣的:(神经网络,深度学习)