PyTorch - Printing Model Structure, Output Shapes and Parameter Counts with torchsummary / torchsummaryX / torchinfo

1 torchsummary / torchsummaryX / torchinfo

torchsummary on GitHub: https://github.com/sksq96/pytorch-summary

torchsummaryX on GitHub: https://github.com/nmhkahn/torchsummaryX

torchinfo on GitHub: https://github.com/TylerYep/torchinfo

1.1 Installation

Install torchsummary

pip install torchsummary

Install torchsummaryX

pip install torchsummaryX

Install torchinfo, either with pip:

pip install torchinfo

or with conda:

conda install -c conda-forge torchinfo
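After installing, a quick way to confirm which of the three packages are importable in the current environment is a small standard-library check. This is only a sketch: the helper name `installed_summary_tools` is hypothetical, and it assumes each distribution name matches its module name (`torchsummary`, `torchsummaryX`, `torchinfo`).

```python
import importlib.util
from importlib.metadata import PackageNotFoundError, version

# Module names imported in the examples of this post
PACKAGES = ("torchsummary", "torchsummaryX", "torchinfo")


def installed_summary_tools(names=PACKAGES):
    """Map each package name to its installed version, or None if it cannot be imported."""
    result = {}
    for name in names:
        if importlib.util.find_spec(name) is None:
            result[name] = None  # module is not importable in this environment
            continue
        try:
            result[name] = version(name)  # version string from the distribution metadata
        except PackageNotFoundError:
            result[name] = "unknown"  # importable, but no matching distribution metadata
    return result


if __name__ == "__main__":
    for name, ver in installed_summary_tools().items():
        print(f"{name:15s} {ver or 'not installed'}")
```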

1.2 Usage

1.2.1 Using torchsummary

from torchvision import models
from torchsummary import summary

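if __name__ == '__main__':
    # summary() creates its dummy input on the GPU by default, so the model must be
    # on the GPU too; leaving out .cuda() raises a device-mismatch error
    resnet18 = models.resnet18().cuda()
    summary(resnet18, (3, 224, 224))
    # On a CPU-only machine, keep the model on the CPU and pass device="cpu" instead:
    # summary(models.resnet18(), (3, 224, 224), device="cpu")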

Output

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [-1, 64, 112, 112]           9,408
       BatchNorm2d-2         [-1, 64, 112, 112]             128
              ReLU-3         [-1, 64, 112, 112]               0
         MaxPool2d-4           [-1, 64, 56, 56]               0
            Conv2d-5           [-1, 64, 56, 56]          36,864
       BatchNorm2d-6           [-1, 64, 56, 56]             128
              ReLU-7           [-1, 64, 56, 56]               0
            Conv2d-8           [-1, 64, 56, 56]          36,864
       BatchNorm2d-9           [-1, 64, 56, 56]             128
             ReLU-10           [-1, 64, 56, 56]               0
       BasicBlock-11           [-1, 64, 56, 56]               0
           Conv2d-12           [-1, 64, 56, 56]          36,864
      BatchNorm2d-13           [-1, 64, 56, 56]             128
             ReLU-14           [-1, 64, 56, 56]               0
           Conv2d-15           [-1, 64, 56, 56]          36,864
      BatchNorm2d-16           [-1, 64, 56, 56]             128
             ReLU-17           [-1, 64, 56, 56]               0
       BasicBlock-18           [-1, 64, 56, 56]               0
           Conv2d-19          [-1, 128, 28, 28]          73,728
      BatchNorm2d-20          [-1, 128, 28, 28]             256
             ReLU-21          [-1, 128, 28, 28]               0
           Conv2d-22          [-1, 128, 28, 28]         147,456
      BatchNorm2d-23          [-1, 128, 28, 28]             256
           Conv2d-24          [-1, 128, 28, 28]           8,192
      BatchNorm2d-25          [-1, 128, 28, 28]             256
             ReLU-26          [-1, 128, 28, 28]               0
       BasicBlock-27          [-1, 128, 28, 28]               0
           Conv2d-28          [-1, 128, 28, 28]         147,456
      BatchNorm2d-29          [-1, 128, 28, 28]             256
             ReLU-30          [-1, 128, 28, 28]               0
           Conv2d-31          [-1, 128, 28, 28]         147,456
      BatchNorm2d-32          [-1, 128, 28, 28]             256
             ReLU-33          [-1, 128, 28, 28]               0
       BasicBlock-34          [-1, 128, 28, 28]               0
           Conv2d-35          [-1, 256, 14, 14]         294,912
      BatchNorm2d-36          [-1, 256, 14, 14]             512
             ReLU-37          [-1, 256, 14, 14]               0
           Conv2d-38          [-1, 256, 14, 14]         589,824
      BatchNorm2d-39          [-1, 256, 14, 14]             512
           Conv2d-40          [-1, 256, 14, 14]          32,768
      BatchNorm2d-41          [-1, 256, 14, 14]             512
             ReLU-42          [-1, 256, 14, 14]               0
       BasicBlock-43          [-1, 256, 14, 14]               0
           Conv2d-44          [-1, 256, 14, 14]         589,824
      BatchNorm2d-45          [-1, 256, 14, 14]             512
             ReLU-46          [-1, 256, 14, 14]               0
           Conv2d-47          [-1, 256, 14, 14]         589,824
      BatchNorm2d-48          [-1, 256, 14, 14]             512
             ReLU-49          [-1, 256, 14, 14]               0
       BasicBlock-50          [-1, 256, 14, 14]               0
           Conv2d-51            [-1, 512, 7, 7]       1,179,648
      BatchNorm2d-52            [-1, 512, 7, 7]           1,024
             ReLU-53            [-1, 512, 7, 7]               0
           Conv2d-54            [-1, 512, 7, 7]       2,359,296
      BatchNorm2d-55            [-1, 512, 7, 7]           1,024
           Conv2d-56            [-1, 512, 7, 7]         131,072
      BatchNorm2d-57            [-1, 512, 7, 7]           1,024
             ReLU-58            [-1, 512, 7, 7]               0
       BasicBlock-59            [-1, 512, 7, 7]               0
           Conv2d-60            [-1, 512, 7, 7]       2,359,296
      BatchNorm2d-61            [-1, 512, 7, 7]           1,024
             ReLU-62            [-1, 512, 7, 7]               0
           Conv2d-63            [-1, 512, 7, 7]       2,359,296
      BatchNorm2d-64            [-1, 512, 7, 7]           1,024
             ReLU-65            [-1, 512, 7, 7]               0
       BasicBlock-66            [-1, 512, 7, 7]               0
AdaptiveAvgPool2d-67            [-1, 512, 1, 1]               0
           Linear-68                 [-1, 1000]         513,000
================================================================
Total params: 11,689,512
Trainable params: 11,689,512
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 62.79
Params size (MB): 44.59
Estimated Total Size (MB): 107.96
----------------------------------------------------------------
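The size lines in this footer can be reproduced by hand, assuming every input value, layer output and parameter is a 4-byte float32 and 1 MB = 1024² bytes; the forward/backward figure counts each row's output twice (the activation plus its gradient). The per-row shapes below are read off the Output Shape column above with the batch dimension dropped. This is a sketch whose results match the printed numbers, not torchsummary's own code. For comparison, the torchinfo footer later in this post is consistent with 1 MB = 10⁶ bytes instead, which is why it reports 46.76 MB for the same 11,689,512 parameters.

```python
from math import prod

BYTES_PER_FLOAT32 = 4
MIB = 1024 ** 2

# (output shape without batch dim, number of rows in the table with that shape)
OUTPUTS = [
    ((64, 112, 112), 3),  # Conv2d-1 .. ReLU-3
    ((64, 56, 56), 15),   # MaxPool2d-4 .. BasicBlock-18
    ((128, 28, 28), 16),  # Conv2d-19 .. BasicBlock-34
    ((256, 14, 14), 16),  # Conv2d-35 .. BasicBlock-50
    ((512, 7, 7), 16),    # Conv2d-51 .. BasicBlock-66
    ((512, 1, 1), 1),     # AdaptiveAvgPool2d-67
    ((1000,), 1),         # Linear-68
]


def torchsummary_sizes(input_shape, outputs, total_params):
    """Return (input, forward/backward, params, total) sizes in MiB."""
    input_mb = prod(input_shape) * BYTES_PER_FLOAT32 / MIB
    activations = sum(prod(shape) * count for shape, count in outputs)
    fwd_bwd_mb = 2 * activations * BYTES_PER_FLOAT32 / MIB  # x2: outputs and their gradients
    params_mb = total_params * BYTES_PER_FLOAT32 / MIB
    return input_mb, fwd_bwd_mb, params_mb, input_mb + fwd_bwd_mb + params_mb


def decimal_mb(num_floats):
    """Size in decimal megabytes (10**6 bytes), matching the torchinfo footer."""
    return num_floats * BYTES_PER_FLOAT32 / 1e6


if __name__ == "__main__":
    sizes = torchsummary_sizes((3, 224, 224), OUTPUTS, 11_689_512)
    print([round(s, 2) for s in sizes])
    print(round(decimal_mb(11_689_512), 2))
```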

1.2.2 Using torchsummaryX

import torch
import torch.nn as nn
from torchsummaryX import summary

class Net(nn.Module):
    def __init__(self,
                 vocab_size=20, embed_dim=300,
                 hidden_dim=512, num_layers=2):
        super().__init__()

        self.hidden_dim = hidden_dim
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.encoder = nn.LSTM(embed_dim, hidden_dim,
                               num_layers=num_layers)
        self.decoder = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x):
        embed = self.embedding(x)
        out, hidden = self.encoder(embed)
        out = self.decoder(out)
        out = out.view(-1, out.size(2))
        return out, hidden

if __name__ == '__main__':
    inputs = torch.zeros((100, 1), dtype=torch.long)  # [length, batch_size]
    summary(Net(), inputs)

Output

===========================================================
            Kernel Shape   Output Shape   Params  Mult-Adds
Layer                                                      
0_embedding    [300, 20]  [100, 1, 300]     6000       6000
1_encoder              -  [100, 1, 512]  3768320    3760128
2_decoder      [512, 20]   [100, 1, 20]    10260      10240
-----------------------------------------------------------
                       Totals
Total params          3784580
Trainable params      3784580
Non-trainable params        0
Mult-Adds             3776368
===========================================================
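The Params column can be checked with the standard parameter formulas. Each `nn.LSTM` layer has four gates, each with an input-to-hidden weight, a hidden-to-hidden weight and two bias vectors (`bias_ih` and `bias_hh`), so a layer with input size d and hidden size h holds 4h(d + h) + 8h parameters; layers after the first take the previous layer's hidden state (size h) as input. In this particular output the Mult-Adds values happen to equal the weight counts with the biases left out. A sketch using the default arguments of `Net` above:

```python
def embedding_params(vocab_size, embed_dim):
    # one embed_dim-long vector per token in the vocabulary
    return vocab_size * embed_dim


def lstm_params(input_size, hidden_size, num_layers, bias=True):
    total = 0
    for layer in range(num_layers):
        d = input_size if layer == 0 else hidden_size  # deeper layers read the hidden state
        total += 4 * hidden_size * (d + hidden_size)  # weight_ih and weight_hh, 4 gates each
        if bias:
            total += 2 * 4 * hidden_size  # bias_ih and bias_hh, 4 gates each
    return total


def linear_params(n_in, n_out, bias=True):
    return n_in * n_out + (n_out if bias else 0)


if __name__ == "__main__":
    parts = {
        "0_embedding": embedding_params(20, 300),
        "1_encoder": lstm_params(300, 512, 2),
        "2_decoder": linear_params(512, 20),
    }
    print(parts, sum(parts.values()))
```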

1.2.3 Using torchinfo

from torchvision import models
from torchinfo import summary

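if __name__ == '__main__':
    # torchinfo selects a device on its own, so .cuda() is not required here
    resnet18 = models.resnet18()
    # Note: 244 rather than the usual 224, which is why the spatial sizes in the
    # output below (122, 61, 31, 16, 8) differ from the torchsummary example
    summary(resnet18, input_size=(1, 3, 244, 244))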

Output

==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
ResNet                                   [1, 1000]                 --
├─Conv2d: 1-1                            [1, 64, 122, 122]         9,408
├─BatchNorm2d: 1-2                       [1, 64, 122, 122]         128
├─ReLU: 1-3                              [1, 64, 122, 122]         --
├─MaxPool2d: 1-4                         [1, 64, 61, 61]           --
├─Sequential: 1-5                        [1, 64, 61, 61]           --
│    └─BasicBlock: 2-1                   [1, 64, 61, 61]           --
│    │    └─Conv2d: 3-1                  [1, 64, 61, 61]           36,864
│    │    └─BatchNorm2d: 3-2             [1, 64, 61, 61]           128
│    │    └─ReLU: 3-3                    [1, 64, 61, 61]           --
│    │    └─Conv2d: 3-4                  [1, 64, 61, 61]           36,864
│    │    └─BatchNorm2d: 3-5             [1, 64, 61, 61]           128
│    │    └─ReLU: 3-6                    [1, 64, 61, 61]           --
│    └─BasicBlock: 2-2                   [1, 64, 61, 61]           --
│    │    └─Conv2d: 3-7                  [1, 64, 61, 61]           36,864
│    │    └─BatchNorm2d: 3-8             [1, 64, 61, 61]           128
│    │    └─ReLU: 3-9                    [1, 64, 61, 61]           --
│    │    └─Conv2d: 3-10                 [1, 64, 61, 61]           36,864
│    │    └─BatchNorm2d: 3-11            [1, 64, 61, 61]           128
│    │    └─ReLU: 3-12                   [1, 64, 61, 61]           --
├─Sequential: 1-6                        [1, 128, 31, 31]          --
│    └─BasicBlock: 2-3                   [1, 128, 31, 31]          --
│    │    └─Conv2d: 3-13                 [1, 128, 31, 31]          73,728
│    │    └─BatchNorm2d: 3-14            [1, 128, 31, 31]          256
│    │    └─ReLU: 3-15                   [1, 128, 31, 31]          --
│    │    └─Conv2d: 3-16                 [1, 128, 31, 31]          147,456
│    │    └─BatchNorm2d: 3-17            [1, 128, 31, 31]          256
│    │    └─Sequential: 3-18             [1, 128, 31, 31]          8,448
│    │    └─ReLU: 3-19                   [1, 128, 31, 31]          --
│    └─BasicBlock: 2-4                   [1, 128, 31, 31]          --
│    │    └─Conv2d: 3-20                 [1, 128, 31, 31]          147,456
│    │    └─BatchNorm2d: 3-21            [1, 128, 31, 31]          256
│    │    └─ReLU: 3-22                   [1, 128, 31, 31]          --
│    │    └─Conv2d: 3-23                 [1, 128, 31, 31]          147,456
│    │    └─BatchNorm2d: 3-24            [1, 128, 31, 31]          256
│    │    └─ReLU: 3-25                   [1, 128, 31, 31]          --
├─Sequential: 1-7                        [1, 256, 16, 16]          --
│    └─BasicBlock: 2-5                   [1, 256, 16, 16]          --
│    │    └─Conv2d: 3-26                 [1, 256, 16, 16]          294,912
│    │    └─BatchNorm2d: 3-27            [1, 256, 16, 16]          512
│    │    └─ReLU: 3-28                   [1, 256, 16, 16]          --
│    │    └─Conv2d: 3-29                 [1, 256, 16, 16]          589,824
│    │    └─BatchNorm2d: 3-30            [1, 256, 16, 16]          512
│    │    └─Sequential: 3-31             [1, 256, 16, 16]          33,280
│    │    └─ReLU: 3-32                   [1, 256, 16, 16]          --
│    └─BasicBlock: 2-6                   [1, 256, 16, 16]          --
│    │    └─Conv2d: 3-33                 [1, 256, 16, 16]          589,824
│    │    └─BatchNorm2d: 3-34            [1, 256, 16, 16]          512
│    │    └─ReLU: 3-35                   [1, 256, 16, 16]          --
│    │    └─Conv2d: 3-36                 [1, 256, 16, 16]          589,824
│    │    └─BatchNorm2d: 3-37            [1, 256, 16, 16]          512
│    │    └─ReLU: 3-38                   [1, 256, 16, 16]          --
├─Sequential: 1-8                        [1, 512, 8, 8]            --
│    └─BasicBlock: 2-7                   [1, 512, 8, 8]            --
│    │    └─Conv2d: 3-39                 [1, 512, 8, 8]            1,179,648
│    │    └─BatchNorm2d: 3-40            [1, 512, 8, 8]            1,024
│    │    └─ReLU: 3-41                   [1, 512, 8, 8]            --
│    │    └─Conv2d: 3-42                 [1, 512, 8, 8]            2,359,296
│    │    └─BatchNorm2d: 3-43            [1, 512, 8, 8]            1,024
│    │    └─Sequential: 3-44             [1, 512, 8, 8]            132,096
│    │    └─ReLU: 3-45                   [1, 512, 8, 8]            --
│    └─BasicBlock: 2-8                   [1, 512, 8, 8]            --
│    │    └─Conv2d: 3-46                 [1, 512, 8, 8]            2,359,296
│    │    └─BatchNorm2d: 3-47            [1, 512, 8, 8]            1,024
│    │    └─ReLU: 3-48                   [1, 512, 8, 8]            --
│    │    └─Conv2d: 3-49                 [1, 512, 8, 8]            2,359,296
│    │    └─BatchNorm2d: 3-50            [1, 512, 8, 8]            1,024
│    │    └─ReLU: 3-51                   [1, 512, 8, 8]            --
├─AdaptiveAvgPool2d: 1-9                 [1, 512, 1, 1]            --
├─Linear: 1-10                           [1, 1000]                 513,000
==========================================================================================
Total params: 11,689,512
Trainable params: 11,689,512
Non-trainable params: 0
Total mult-adds (G): 2.27
==========================================================================================
Input size (MB): 0.71
Forward/backward pass size (MB): 48.20
Params size (MB): 46.76
Estimated Total Size (MB): 95.67
==========================================================================================
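Because this example feeds 244×244 images instead of 224×224, the feature maps are 122, 61, 31, 16 and 8 pixels wide instead of 112, 56, 28, 14 and 7. Both sequences follow from the usual convolution output-size formula floor((H + 2p - k) / s) + 1, applied to the ResNet-18 stem (7×7 conv, stride 2, padding 3, then 3×3 max-pool, stride 2, padding 1) and to the stride-2 3×3 convolution that opens each of the last three stages. The `Sequential` rows inside the first block of those stages (8,448 / 33,280 / 132,096) are the 1×1 downsample convolution (no bias) plus its BatchNorm2d on the shortcut branch. A sketch that checks both by hand:

```python
def conv_out(size, kernel, stride, padding):
    # standard output-size formula for Conv2d / MaxPool2d (dilation 1)
    return (size + 2 * padding - kernel) // stride + 1


def resnet18_spatial_sizes(h):
    """Spatial size after conv1, after maxpool (= layer1), then after layer2..layer4."""
    sizes = []
    h = conv_out(h, kernel=7, stride=2, padding=3)  # conv1
    sizes.append(h)
    h = conv_out(h, kernel=3, stride=2, padding=1)  # maxpool; layer1 keeps this size
    sizes.append(h)
    for _ in range(3):  # layer2, layer3, layer4 each start with a stride-2 3x3 conv
        h = conv_out(h, kernel=3, stride=2, padding=1)
        sizes.append(h)
    return sizes


def downsample_params(c_in, c_out):
    # 1x1 conv without bias, plus BatchNorm2d weight and bias
    return c_in * c_out + 2 * c_out


if __name__ == "__main__":
    print(resnet18_spatial_sizes(224))
    print(resnet18_spatial_sizes(244))
    print([downsample_params(a, b) for a, b in [(64, 128), (128, 256), (256, 512)]])
```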
