pytorch 笔记:torchsummary、计算模型参数量

1 torchsummary

作用:打印神经网络的结构

以pytorch笔记:搭建简易CNN_UQI-LIUWJ的博客-CSDN博客 中搭建的CNN为例

import torch
from torchsummary import summary

class CNN(nn.Module):
    def __init__(self):
        super(CNN,self).__init__()
 
        self.conv1=nn.Sequential(
            nn.Conv2d(
                in_channels=1,
#输入shape (1,28,28)
                out_channels=16,
#输出shape(16,28,28),16也是卷积核的数量
                kernel_size=5,
                stride=1,
                padding=2),
#如果想要conv2d出来的图片长宽没有变化,那么当stride=1的时候,padding=(kernel_size-1)/2
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2)
 #在2*2空间里面下采样,输出shape(16,14,14)
        )
           
        self.conv2=nn.Sequential(
            nn.Conv2d(
                in_channels=16,
#输入shape (16,14,14)
                out_channels=32,
#输出shape(32,14,14)
                kernel_size=5,
                stride=1,
                padding=2),
#输出shape(32,7,7),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2)
        )
 
        self.fc=nn.Linear(32*7*7,10)
#输出一个十维的东西,表示我每个数字可能性的权重
        
    def forward(self,x):
            x=self.conv1(x)
            x=self.conv2(x)
            x=x.view(x.shape[0],-1)
            x=self.fc(x)
            return x
    
cnn=CNN()
summary(cnn,(1,28,28))

输出的结果是这样的:

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1           [-1, 16, 28, 28]             416
              ReLU-2           [-1, 16, 28, 28]               0
         MaxPool2d-3           [-1, 16, 14, 14]               0
            Conv2d-4           [-1, 32, 14, 14]          12,832
              ReLU-5           [-1, 32, 14, 14]               0
         MaxPool2d-6             [-1, 32, 7, 7]               0
            Linear-7                   [-1, 10]          15,690
================================================================
Total params: 28,938
Trainable params: 28,938
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.32
Params size (MB): 0.11
Estimated Total Size (MB): 0.44
----------------------------------------------------------------

2  手动计算参数量

还是使用上面说的cnn,我们结合torch的numel()方法实现之

num=0
for i in cnn.parameters():
    if i.requires_grad==True:
        num+=i.numel()
num
#28938
sum(i.numel() for i in cnn.parameters() if i.requires_grad==True)
#28938

不难发现和上面的结果是一样的

pytorch 笔记:torchsummary、计算模型参数量_第1张图片

 

你可能感兴趣的:(pytorch学习,pytorch,深度学习,神经网络)