Printing the Structure of a PyTorch Network Model


Author: Horizon Max



Contents

    • Printing the Structure of a PyTorch Network Model
        • Printing the network with print()
        • Printing a network in TensorFlow / Keras
        • Printing the network structure in PyTorch with summary

Printing the Structure of a PyTorch Network Model

Printing the network with print()

When printing a model's structure in PyTorch, we usually do it like this:

model = simpleNet()
print(model)
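
The post never defines simpleNet. A minimal definition consistent with the printed structure below (and with the torchinfo output later in the article) would look roughly like the following sketch; the forward pass, in particular the flatten before the fully connected layer, is an assumption:

import torch
import torch.nn as nn

class simpleNet(nn.Module):
    # Reconstructed from the printed structure; assumes a 3x32x32 input.
    def __init__(self, num_classes=10):
        super().__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(16),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.ReLU(),
        )
        self.layer2 = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(32),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.ReLU(),
        )
        self.layer3 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.ReLU(),
        )
        self.dropout = nn.Dropout(p=0.5)
        self.fc = nn.Linear(64 * 4 * 4, num_classes)   # 1024 -> 10 for a 32x32 input
        self.out = nn.Linear(num_classes, num_classes)

    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = torch.flatten(x, 1)        # assumed: flatten before the classifier
        x = self.dropout(x)
        x = self.fc(x)
        return self.out(x)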

The printed result:

simpleNet(
  (layer1): Sequential(
    (0): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): ReLU()
  )
  (layer2): Sequential(
    (0): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): ReLU()
  )
  (layer3): Sequential(
    (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): ReLU()
  )
  (dropout): Dropout(p=0.5, inplace=False)
  (fc): Linear(in_features=1024, out_features=10, bias=True)
  (out): Linear(in_features=10, out_features=10, bias=True)
)

It is easy to see that the network structure printed this way is not very clear, and the parameter information is hard to make out!

For a fairly simple network this may not matter much, but as the network gets deeper, the structure more complex, and the parameter count larger, the output becomes painful to read!


Printing a network in TensorFlow / Keras

Use the model.summary() method to print the network structure:

model = MyNet()
model.summary()
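
MyNet is not defined in the original post either. Judging from the summary below, it looks like an MNIST-style generator; a hypothetical Keras definition that reproduces the same layer list might be the sketch below, where the 100-dimensional input is inferred from the 25,856 parameters of the first Dense layer:

from tensorflow import keras
from tensorflow.keras import layers

def MyNet():
    # Hypothetical reconstruction matching the summary below:
    # Dense 256 -> 512 -> 1024 -> 784, each Dense followed by LeakyReLU + BatchNorm,
    # with the final 784 units reshaped into a 28x28x1 image.
    return keras.Sequential([
        keras.Input(shape=(100,)),     # assumed input size
        layers.Dense(256),
        layers.LeakyReLU(),
        layers.BatchNormalization(),
        layers.Dense(512),
        layers.LeakyReLU(),
        layers.BatchNormalization(),
        layers.Dense(1024),
        layers.LeakyReLU(),
        layers.BatchNormalization(),
        layers.Dense(784),
        layers.Reshape((28, 28, 1)),
    ])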

Compared with the print() output above, the network structure here is very clear: every layer's output shape and parameter count are listed.

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense_4 (Dense)              (None, 256)               25856
_________________________________________________________________
leaky_re_lu_3 (LeakyReLU)    (None, 256)               0
_________________________________________________________________
batch_normalization_1 (Batch (None, 256)               1024
_________________________________________________________________
dense_5 (Dense)              (None, 512)               131584
_________________________________________________________________
leaky_re_lu_4 (LeakyReLU)    (None, 512)               0
_________________________________________________________________
batch_normalization_2 (Batch (None, 512)               2048
_________________________________________________________________
dense_6 (Dense)              (None, 1024)              525312
_________________________________________________________________
leaky_re_lu_5 (LeakyReLU)    (None, 1024)              0
_________________________________________________________________
batch_normalization_3 (Batch (None, 1024)              4096
_________________________________________________________________
dense_7 (Dense)              (None, 784)               803600
_________________________________________________________________
reshape_1 (Reshape)          (None, 28, 28, 1)         0
=================================================================
Total params: 1,493,520
Trainable params: 1,489,936
Non-trainable params: 3,584
_________________________________________________________________

Printing the network structure in PyTorch with summary

First, install the torchinfo package (either the pip or the conda command works):

pip install torchinfo
conda install -c conda-forge torchinfo

Then call its summary function to print the network structure:

from torchinfo import summary

model = simpleNet()
batch_size = 64
summary(model, input_size=(batch_size, 3, 32, 32))

The printed network structure looks like this:

==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
simpleNet                                --                        --
├─Sequential: 1-1                        [64, 16, 16, 16]          --
│    └─Conv2d: 2-1                       [64, 16, 32, 32]          448
│    └─BatchNorm2d: 2-2                  [64, 16, 32, 32]          32
│    └─MaxPool2d: 2-3                    [64, 16, 16, 16]          --
│    └─ReLU: 2-4                         [64, 16, 16, 16]          --
├─Sequential: 1-2                        [64, 32, 8, 8]            --
│    └─Conv2d: 2-5                       [64, 32, 16, 16]          4,640
│    └─BatchNorm2d: 2-6                  [64, 32, 16, 16]          64
│    └─MaxPool2d: 2-7                    [64, 32, 8, 8]            --
│    └─ReLU: 2-8                         [64, 32, 8, 8]            --
├─Sequential: 1-3                        [64, 64, 4, 4]            --
│    └─Conv2d: 2-9                       [64, 64, 8, 8]            18,496
│    └─BatchNorm2d: 2-10                 [64, 64, 8, 8]            128
│    └─MaxPool2d: 2-11                   [64, 64, 4, 4]            --
│    └─ReLU: 2-12                        [64, 64, 4, 4]            --
├─Dropout: 1-4                           [64, 1024]                --
├─Linear: 1-5                            [64, 10]                  10,250
├─Linear: 1-6                            [64, 10]                  110
==========================================================================================
Total params: 34,168
Trainable params: 34,168
Non-trainable params: 0
Total mult-adds (M): 181.82
==========================================================================================
Input size (MB): 0.79
Forward/backward pass size (MB): 29.37
Params size (MB): 0.14
Estimated Total Size (MB): 30.29
==========================================================================================
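
If you want to customize the report, summary also accepts optional arguments such as depth (how many levels of nested modules to expand) and col_names (which columns to show). The exact set of options depends on the installed torchinfo version, so treat the call below as a sketch:

from torchinfo import summary

summary(
    model,
    input_size=(batch_size, 3, 32, 32),
    depth=2,                                                 # expand two levels of nested modules
    col_names=("input_size", "output_size", "num_params"),   # columns to include in the table
)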

For more details, see the torchinfo source code on GitHub: https://github.com/TylerYep/torchinfo

