查看模型各层参数(Pytorch)

利用Pytorch搭建CNN网络

这个实验用到的数据集是Mnist数据集,图片维度是1×28×28

import torch.nn as nn

class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        # 卷积层
        self.conv1 = nn.Sequential(
            nn.Conv2d(   # 图片的维度:(1, 28, 28)
                in_channels= 1,   # 图片的高度
                out_channels= 16, # 输出的高度:filter的个数
                kernel_size =5,   # filter的像素点是5×5
                stride = 1,       # 每次扫描跳的范围
                padding = 2       # 补全边缘像素点
            ),  # 图片的维度:(16,28,28)
            nn.ReLU(),  # 图片的维度:(16,28,28)
            nn.MaxPool2d(kernel_size=2,),  # 图片的维度:(16,14,14)
        )
        # 卷积层
        self.conv2 = nn.Sequential(  # 图片的维度:(16,14,14)
            nn.Conv2d(16,32,5,1,2),  # 图片的维度:(32,14,14)
            nn.ReLU(),
            nn.MaxPool2d(2)  # 图片的维度:(32,7,7)
        )
        # 线性层
        self.out = nn.Linear(32*7*7, 10)

    # 展平操作
    def forward(self, x):
        print(x.size())  # 查看模型的输入,tensorboardX input_to_model
        x = self.conv1(x)
        print(x.size()) 
        x = self.conv2(x)  # 图片的维度:(batch,32,7,7)
        print(x.size()) 
        # 展平操作, -1表示自适应
        x = x.view(x.size(0), -1)  # 图片的维度:(batch,32*7*7)
        print(x.size()) 
        output = self.out(x)
        return output

cnn = CNN()
print(cnn)

模型的结构:

CNN(
  (conv1): Sequential(
    (0): Conv2d(1, 16, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (conv2): Sequential(
    (0): Conv2d(16, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (out): Linear(in_features=1568, out_features=10, bias=True)
)

 模型参数访问:

  • parameters()
  • named_parameters()
for name, parameter in cnn.named_parameters():
    print(name, ':', parameter.size())

 输出:

conv1.0.weight : torch.Size([16, 1, 5, 5])
conv1.0.bias : torch.Size([16])
conv2.0.weight : torch.Size([32, 16, 5, 5])
conv2.0.bias : torch.Size([32])
out.weight : torch.Size([10, 1568])
out.bias : torch.Size([10])

 

你可能感兴趣的:(pytorch,python,deep,learning)