A detailed look at PyTorch's nn.Module class and its parameters: state_dict and parameters


import torch
import torch.nn.functional as F
from torch.optim import SGD

class MyNet(torch.nn.Module):
    def __init__(self):
        super(MyNet, self).__init__()  # always call the parent class constructor first
        self.conv1 = torch.nn.Conv2d(3, 32, 3, 1, 1)
        self.relu1 = torch.nn.ReLU()
        self.max_pooling1 = torch.nn.MaxPool2d(2, 1)

        self.mlp = torch.nn.Sequential(
            torch.nn.Conv2d(3, 32, 3, 2, 1),
            torch.nn.Sigmoid(),
            torch.nn.MaxPool2d(3, 1),
        )
        # self.conv2 = torch.nn.Conv2d(3, 32, 3, 1, 1)
        # self.relu2=torch.nn.ReLU()
        # self.max_pooling2=torch.nn.MaxPool2d(2,1)

        self.dense1 = torch.nn.Linear(32 * 3 * 3, 128)
        self.dense2 = torch.nn.Linear(128, 10)

    def forward(self, x):
        # Note: forward is never called in this script -- the model is only built so
        # that its registered parameters can be inspected below. The layer sizes are
        # illustrative and are not consistent end-to-end for a real input.
        x = self.conv1(x)
        x = self.relu1(x)
        x = self.max_pooling1(x)
        x = self.mlp(x)  # replaces the commented-out conv2/relu2/max_pooling2 block
        x = x.view(x.size(0), -1)  # flatten before the fully connected layers
        x = self.dense1(x)
        x = self.dense2(x)
        return x

model = MyNet()  # build the model
print(model.parameters())  # parameters() is a generator, so only the generator object is printed

print('\n')
# named_parameters() yields (name, Parameter) pairs for every registered parameter,
# including those inside the nn.Sequential sub-module (prefixed with "mlp.")
for name, para in model.named_parameters():
    print(name)
    print(para.shape)
    print("---------------------------")

print('\n')
for name, para in model.named_parameters():
    if 'bias' in name:
        print(name)
        # para.type (without parentheses) prints the bound method itself;
        # use para.type() or para.dtype to see the actual tensor type
        print(para.type)
        print("---------------------------")

print('\n')
# parameters whose name contains any of these substrings (bias terms,
# LayerNorm weights) are commonly excluded from weight decay
no_decay = ["bias", "LayerNorm.weight"]
i = 0
for n, p in model.named_parameters():
    i += 1
    print("{}".format(i))
    print(any(nd in n for nd in no_decay))
# split the parameter *names* into a decay group and a no-decay group
params_decay = [n for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)]
params_nodecay = [n for n, p in model.named_parameters() if any(nd in n for nd in no_decay)]
print(params_decay)
print(params_nodecay)
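
The two lists above collect parameter names only. To actually give the two groups different weight decay, the Parameter objects themselves are usually gathered and passed to the optimizer as parameter groups. Below is a minimal sketch using the SGD imported at the top; the learning-rate and weight-decay values are placeholders:

decay_params = [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)]
nodecay_params = [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)]
optimizer = SGD(
    [
        {"params": decay_params, "weight_decay": 1e-4},   # placeholder weight-decay value
        {"params": nodecay_params, "weight_decay": 0.0},  # bias parameters: no decay
    ],
    lr=0.01,  # placeholder learning rate
)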

Output of the original script:

<generator object Module.parameters at 0x000002018E623048>


conv1.weight
torch.Size([32, 3, 3, 3])
---------------------------
conv1.bias
torch.Size([32])
---------------------------
mlp.0.weight
torch.Size([32, 3, 3, 3])
---------------------------
mlp.0.bias
torch.Size([32])
---------------------------
dense1.weight
torch.Size([128, 288])
---------------------------
dense1.bias
torch.Size([128])
---------------------------
dense2.weight
torch.Size([10, 128])
---------------------------
dense2.bias
torch.Size([10])
---------------------------


conv1.bias
<built-in method type of Parameter object at 0x0000020193FE17C8>
---------------------------
mlp.0.bias
<built-in method type of Parameter object at 0x0000020193FE1868>
---------------------------
dense1.bias
<built-in method type of Parameter object at 0x0000020193FE1908>
---------------------------
dense2.bias
<built-in method type of Parameter object at 0x0000020193FE19A8>
---------------------------


1
False
2
True
3
False
4
True
5
False
6
True
7
False
8
True
['conv1.weight', 'mlp.0.weight', 'dense1.weight', 'dense2.weight']
['conv1.bias', 'mlp.0.bias', 'dense1.bias', 'dense2.bias']
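
The title also contrasts state_dict and parameters. In short, state_dict() returns an OrderedDict mapping names to plain tensors (parameters and any buffers) and is what gets saved and loaded, while parameters()/named_parameters() yield the trainable Parameter objects (requires_grad=True) that are handed to an optimizer. A minimal sketch of the two side by side; the file name is a placeholder:

# state_dict(): name -> tensor mapping used for saving/loading weights
for name, tensor in model.state_dict().items():
    print(name, tensor.shape)

# named_parameters(): trainable Parameter objects handed to the optimizer
for name, param in model.named_parameters():
    print(name, param.requires_grad)

torch.save(model.state_dict(), "mynet.pth")     # "mynet.pth" is a placeholder path
model.load_state_dict(torch.load("mynet.pth"))  # restore the weights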
