[Model Complexity] Using torchsummary, torchstat and profile

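Model complexity is an important metric when comparing different models. It covers the number of parameters, floating point operations (FLOPs), memory footprint and memory traffic (reads plus writes). This post records a few ways to measure model complexity.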

1. torchsummary

torchsummary reports the total number of parameters and the parameters of each layer, but it cannot compute FLOPs.

Take resnet18 as an example:

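import torchsummary
import torchvision.models as models

# pretrained=True is deprecated in torchvision 0.13+ (use weights=models.ResNet18_Weights.DEFAULT);
# the weights do not affect the parameter count
model = models.resnet18(pretrained=True)
# input_size excludes the batch dimension
torchsummary.summary(model, (3, 224, 224), device='cpu')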

Alternatively, on a machine with a CUDA GPU (torchsummary's device argument defaults to 'cuda'):

torchsummary.summary(model.cuda(), (3, 224, 224))

Output:

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [-1, 64, 112, 112]           9,408
       BatchNorm2d-2         [-1, 64, 112, 112]             128
              ReLU-3         [-1, 64, 112, 112]               0
         MaxPool2d-4           [-1, 64, 56, 56]               0
            Conv2d-5           [-1, 64, 56, 56]          36,864
       BatchNorm2d-6           [-1, 64, 56, 56]             128
              ReLU-7           [-1, 64, 56, 56]               0
            Conv2d-8           [-1, 64, 56, 56]          36,864
       BatchNorm2d-9           [-1, 64, 56, 56]             128
             ReLU-10           [-1, 64, 56, 56]               0
       BasicBlock-11           [-1, 64, 56, 56]               0
           Conv2d-12           [-1, 64, 56, 56]          36,864
      BatchNorm2d-13           [-1, 64, 56, 56]             128
             ReLU-14           [-1, 64, 56, 56]               0
           Conv2d-15           [-1, 64, 56, 56]          36,864
      BatchNorm2d-16           [-1, 64, 56, 56]             128
             ReLU-17           [-1, 64, 56, 56]               0
       BasicBlock-18           [-1, 64, 56, 56]               0
           Conv2d-19          [-1, 128, 28, 28]          73,728
      BatchNorm2d-20          [-1, 128, 28, 28]             256
             ReLU-21          [-1, 128, 28, 28]               0
           Conv2d-22          [-1, 128, 28, 28]         147,456
      BatchNorm2d-23          [-1, 128, 28, 28]             256
           Conv2d-24          [-1, 128, 28, 28]           8,192
      BatchNorm2d-25          [-1, 128, 28, 28]             256
             ReLU-26          [-1, 128, 28, 28]               0
       BasicBlock-27          [-1, 128, 28, 28]               0
           Conv2d-28          [-1, 128, 28, 28]         147,456
      BatchNorm2d-29          [-1, 128, 28, 28]             256
             ReLU-30          [-1, 128, 28, 28]               0
           Conv2d-31          [-1, 128, 28, 28]         147,456
      BatchNorm2d-32          [-1, 128, 28, 28]             256
             ReLU-33          [-1, 128, 28, 28]               0
       BasicBlock-34          [-1, 128, 28, 28]               0
           Conv2d-35          [-1, 256, 14, 14]         294,912
      BatchNorm2d-36          [-1, 256, 14, 14]             512
             ReLU-37          [-1, 256, 14, 14]               0
           Conv2d-38          [-1, 256, 14, 14]         589,824
      BatchNorm2d-39          [-1, 256, 14, 14]             512
           Conv2d-40          [-1, 256, 14, 14]          32,768
      BatchNorm2d-41          [-1, 256, 14, 14]             512
             ReLU-42          [-1, 256, 14, 14]               0
       BasicBlock-43          [-1, 256, 14, 14]               0
           Conv2d-44          [-1, 256, 14, 14]         589,824
      BatchNorm2d-45          [-1, 256, 14, 14]             512
             ReLU-46          [-1, 256, 14, 14]               0
           Conv2d-47          [-1, 256, 14, 14]         589,824
      BatchNorm2d-48          [-1, 256, 14, 14]             512
             ReLU-49          [-1, 256, 14, 14]               0
       BasicBlock-50          [-1, 256, 14, 14]               0
           Conv2d-51            [-1, 512, 7, 7]       1,179,648
      BatchNorm2d-52            [-1, 512, 7, 7]           1,024
             ReLU-53            [-1, 512, 7, 7]               0
           Conv2d-54            [-1, 512, 7, 7]       2,359,296
      BatchNorm2d-55            [-1, 512, 7, 7]           1,024
           Conv2d-56            [-1, 512, 7, 7]         131,072
      BatchNorm2d-57            [-1, 512, 7, 7]           1,024
             ReLU-58            [-1, 512, 7, 7]               0
       BasicBlock-59            [-1, 512, 7, 7]               0
           Conv2d-60            [-1, 512, 7, 7]       2,359,296
      BatchNorm2d-61            [-1, 512, 7, 7]           1,024
             ReLU-62            [-1, 512, 7, 7]               0
           Conv2d-63            [-1, 512, 7, 7]       2,359,296
      BatchNorm2d-64            [-1, 512, 7, 7]           1,024
             ReLU-65            [-1, 512, 7, 7]               0
       BasicBlock-66            [-1, 512, 7, 7]               0
AdaptiveAvgPool2d-67            [-1, 512, 1, 1]               0
           Linear-68                 [-1, 1000]         513,000
================================================================
Total params: 11,689,512
Trainable params: 11,689,512
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 62.79
Params size (MB): 44.59
Estimated Total Size (MB): 107.96
----------------------------------------------------------------
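The Param # column follows from the layer shapes alone: a bias-free Conv2d has out·in·k·k weights, a BatchNorm2d has one weight and one bias per channel, and a Linear layer has in·out weights plus out biases. The sketch below rebuilds Total params and Params size for resnet18 from a hand-written description of its layer sizes (read off the table above; the helper names are hypothetical), assuming 4-byte float32 parameters and 1 MB = 1024² bytes, which is what the printed 44.59 corresponds to.

```python
# Hypothetical helpers that recompute torchsummary's "Param #" column for resnet18.

def conv_params(c_in, c_out, k):
    """Conv2d without bias (all resnet18 convs use bias=False)."""
    return c_out * c_in * k * k

def bn_params(c):
    """BatchNorm2d: one weight and one bias per channel (running stats are buffers)."""
    return 2 * c

def linear_params(n_in, n_out):
    """Linear layer with bias."""
    return n_in * n_out + n_out

def basic_block_params(c_in, c_out, downsample):
    """Two 3x3 conv+BN pairs, plus a 1x1 conv+BN shortcut when the shape changes."""
    total = conv_params(c_in, c_out, 3) + bn_params(c_out)
    total += conv_params(c_out, c_out, 3) + bn_params(c_out)
    if downsample:
        total += conv_params(c_in, c_out, 1) + bn_params(c_out)
    return total

def resnet18_params(num_classes=1000):
    total = conv_params(3, 64, 7) + bn_params(64)  # stem: 7x7 conv + BN
    c_in = 64
    for c_out in (64, 128, 256, 512):  # layer1 .. layer4, two BasicBlocks each
        total += basic_block_params(c_in, c_out, downsample=(c_in != c_out))
        total += basic_block_params(c_out, c_out, downsample=False)
        c_in = c_out
    return total + linear_params(512, num_classes)

total = resnet18_params()
params_mb = total * 4 / 1024 ** 2  # float32 = 4 bytes per parameter
print(f"Total params: {total:,}")            # Total params: 11,689,512
print(f"Params size (MB): {params_mb:.2f}")  # Params size (MB): 44.59
```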

It works for Vision Transformers as well:

import torchsummary
import timm
model = timm.create_model('vit_small_patch16_224', pretrained=True)
torchsummary.summary(model.cuda(), (3, 224, 224))

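Output (this appears to come from an older timm release, in which vit_small_patch16_224 used 768-dim embeddings, 8 blocks and 8 heads; recent timm versions define it as a 384-dim, 12-block model with roughly 22M parameters, so the numbers will differ):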

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1          [-1, 768, 14, 14]         590,592
          Identity-2             [-1, 196, 768]               0
        PatchEmbed-3             [-1, 196, 768]               0
           Dropout-4             [-1, 197, 768]               0
         LayerNorm-5             [-1, 197, 768]           1,536
            Linear-6            [-1, 197, 2304]       1,769,472
           Dropout-7          [-1, 8, 197, 197]               0
            Linear-8             [-1, 197, 768]         590,592
           Dropout-9             [-1, 197, 768]               0
        Attention-10             [-1, 197, 768]               0
         Identity-11             [-1, 197, 768]               0
        LayerNorm-12             [-1, 197, 768]           1,536
           Linear-13            [-1, 197, 2304]       1,771,776
             GELU-14            [-1, 197, 2304]               0
          Dropout-15            [-1, 197, 2304]               0
           Linear-16             [-1, 197, 768]       1,770,240
          Dropout-17             [-1, 197, 768]               0
              Mlp-18             [-1, 197, 768]               0
         Identity-19             [-1, 197, 768]               0
            Block-20             [-1, 197, 768]               0
        LayerNorm-21             [-1, 197, 768]           1,536
           Linear-22            [-1, 197, 2304]       1,769,472
          Dropout-23          [-1, 8, 197, 197]               0
           Linear-24             [-1, 197, 768]         590,592
          Dropout-25             [-1, 197, 768]               0
        Attention-26             [-1, 197, 768]               0
         Identity-27             [-1, 197, 768]               0
        LayerNorm-28             [-1, 197, 768]           1,536
           Linear-29            [-1, 197, 2304]       1,771,776
             GELU-30            [-1, 197, 2304]               0
          Dropout-31            [-1, 197, 2304]               0
           Linear-32             [-1, 197, 768]       1,770,240
          Dropout-33             [-1, 197, 768]               0
              Mlp-34             [-1, 197, 768]               0
         Identity-35             [-1, 197, 768]               0
            Block-36             [-1, 197, 768]               0
        LayerNorm-37             [-1, 197, 768]           1,536
           Linear-38            [-1, 197, 2304]       1,769,472
          Dropout-39          [-1, 8, 197, 197]               0
           Linear-40             [-1, 197, 768]         590,592
          Dropout-41             [-1, 197, 768]               0
        Attention-42             [-1, 197, 768]               0
         Identity-43             [-1, 197, 768]               0
        LayerNorm-44             [-1, 197, 768]           1,536
           Linear-45            [-1, 197, 2304]       1,771,776
             GELU-46            [-1, 197, 2304]               0
          Dropout-47            [-1, 197, 2304]               0
           Linear-48             [-1, 197, 768]       1,770,240
          Dropout-49             [-1, 197, 768]               0
              Mlp-50             [-1, 197, 768]               0
         Identity-51             [-1, 197, 768]               0
            Block-52             [-1, 197, 768]               0
        LayerNorm-53             [-1, 197, 768]           1,536
           Linear-54            [-1, 197, 2304]       1,769,472
          Dropout-55          [-1, 8, 197, 197]               0
           Linear-56             [-1, 197, 768]         590,592
          Dropout-57             [-1, 197, 768]               0
        Attention-58             [-1, 197, 768]               0
         Identity-59             [-1, 197, 768]               0
        LayerNorm-60             [-1, 197, 768]           1,536
           Linear-61            [-1, 197, 2304]       1,771,776
             GELU-62            [-1, 197, 2304]               0
          Dropout-63            [-1, 197, 2304]               0
           Linear-64             [-1, 197, 768]       1,770,240
          Dropout-65             [-1, 197, 768]               0
              Mlp-66             [-1, 197, 768]               0
         Identity-67             [-1, 197, 768]               0
            Block-68             [-1, 197, 768]               0
        LayerNorm-69             [-1, 197, 768]           1,536
           Linear-70            [-1, 197, 2304]       1,769,472
          Dropout-71          [-1, 8, 197, 197]               0
           Linear-72             [-1, 197, 768]         590,592
          Dropout-73             [-1, 197, 768]               0
        Attention-74             [-1, 197, 768]               0
         Identity-75             [-1, 197, 768]               0
        LayerNorm-76             [-1, 197, 768]           1,536
           Linear-77            [-1, 197, 2304]       1,771,776
             GELU-78            [-1, 197, 2304]               0
          Dropout-79            [-1, 197, 2304]               0
           Linear-80             [-1, 197, 768]       1,770,240
          Dropout-81             [-1, 197, 768]               0
              Mlp-82             [-1, 197, 768]               0
         Identity-83             [-1, 197, 768]               0
            Block-84             [-1, 197, 768]               0
        LayerNorm-85             [-1, 197, 768]           1,536
           Linear-86            [-1, 197, 2304]       1,769,472
          Dropout-87          [-1, 8, 197, 197]               0
           Linear-88             [-1, 197, 768]         590,592
          Dropout-89             [-1, 197, 768]               0
        Attention-90             [-1, 197, 768]               0
         Identity-91             [-1, 197, 768]               0
        LayerNorm-92             [-1, 197, 768]           1,536
           Linear-93            [-1, 197, 2304]       1,771,776
             GELU-94            [-1, 197, 2304]               0
          Dropout-95            [-1, 197, 2304]               0
           Linear-96             [-1, 197, 768]       1,770,240
          Dropout-97             [-1, 197, 768]               0
              Mlp-98             [-1, 197, 768]               0
         Identity-99             [-1, 197, 768]               0
           Block-100             [-1, 197, 768]               0
       LayerNorm-101             [-1, 197, 768]           1,536
          Linear-102            [-1, 197, 2304]       1,769,472
         Dropout-103          [-1, 8, 197, 197]               0
          Linear-104             [-1, 197, 768]         590,592
         Dropout-105             [-1, 197, 768]               0
       Attention-106             [-1, 197, 768]               0
        Identity-107             [-1, 197, 768]               0
       LayerNorm-108             [-1, 197, 768]           1,536
          Linear-109            [-1, 197, 2304]       1,771,776
            GELU-110            [-1, 197, 2304]               0
         Dropout-111            [-1, 197, 2304]               0
          Linear-112             [-1, 197, 768]       1,770,240
         Dropout-113             [-1, 197, 768]               0
             Mlp-114             [-1, 197, 768]               0
        Identity-115             [-1, 197, 768]               0
           Block-116             [-1, 197, 768]               0
       LayerNorm-117             [-1, 197, 768]           1,536
          Linear-118            [-1, 197, 2304]       1,769,472
         Dropout-119          [-1, 8, 197, 197]               0
          Linear-120             [-1, 197, 768]         590,592
         Dropout-121             [-1, 197, 768]               0
       Attention-122             [-1, 197, 768]               0
        Identity-123             [-1, 197, 768]               0
       LayerNorm-124             [-1, 197, 768]           1,536
          Linear-125            [-1, 197, 2304]       1,771,776
            GELU-126            [-1, 197, 2304]               0
         Dropout-127            [-1, 197, 2304]               0
          Linear-128             [-1, 197, 768]       1,770,240
         Dropout-129             [-1, 197, 768]               0
             Mlp-130             [-1, 197, 768]               0
        Identity-131             [-1, 197, 768]               0
           Block-132             [-1, 197, 768]               0
       LayerNorm-133             [-1, 197, 768]           1,536
        Identity-134                  [-1, 768]               0
          Linear-135                 [-1, 1000]         769,000
================================================================
Total params: 48,602,344
Trainable params: 48,602,344
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 237.11
Params size (MB): 185.40
Estimated Total Size (MB): 423.09
----------------------------------------------------------------
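Adding up the Param # column above gives exactly the reported 48,602,344, which suggests that torchsummary only counts the weight and bias of the submodules it hooks. Parameters registered directly on the top-level model (in timm's ViT, the class token and the position embedding) do not appear in any row. The arithmetic sketch below assumes the configuration visible in the table (768-dim embeddings, 197 tokens, 8 blocks, MLP hidden size 2304, qkv projection without bias); in practice `sum(p.numel() for p in model.parameters())` is a simple cross-check against any tool's total.

```python
# Recompute the ViT parameter count from the layer shapes in the torchsummary table.
# Assumed configuration (read off the table): dim 768, 196 patches + 1 class token,
# 8 blocks, MLP hidden size 2304, qkv projection without bias, 1000 classes.
DIM, TOKENS, DEPTH, HIDDEN, CLASSES = 768, 197, 8, 2304, 1000

def linear(n_in, n_out, bias=True):
    return n_in * n_out + (n_out if bias else 0)

layer_norm = 2 * DIM                       # weight + bias
patch_embed = 3 * 16 * 16 * DIM + DIM      # 16x16 conv, 3 -> 768 channels, with bias
block = (layer_norm
         + linear(DIM, 3 * DIM, bias=False)  # qkv projection
         + linear(DIM, DIM)                  # attention output projection
         + layer_norm
         + linear(DIM, HIDDEN)               # MLP fc1
         + linear(HIDDEN, DIM))              # MLP fc2
head = linear(DIM, CLASSES)

rows_total = patch_embed + DEPTH * block + layer_norm + head
cls_token = DIM              # shape (1, 1, 768), lives on the top-level module
pos_embed = TOKENS * DIM     # shape (1, 197, 768), lives on the top-level module
full_total = rows_total + cls_token + pos_embed

print(f"sum of summary rows: {rows_total:,}")             # sum of summary rows: 48,602,344
print(f"including cls_token/pos_embed: {full_total:,}")   # including cls_token/pos_embed: 48,754,408
```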

2. torchstat

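torchstat computes the parameters (params), floating point operations (FLOPs), memory footprint (memory) and memory reads plus writes (MemR+W) of a model; its selling point is that it covers pretty much everything.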

from torchstat import stat
import torchvision.models as models

model = models.resnet18(pretrained=True)
# torchstat runs its dummy forward pass on the CPU; input_size excludes the batch dimension
stat(model.to('cpu'), (3, 224, 224))

Output:

                 module name  input shape output shape      params memory(MB)             MAdd            Flops  MemRead(B)  MemWrite(B) duration[%]    MemR+W(B)
0                      conv1    3 224 224   64 112 112      9408.0       3.06    235,225,088.0    118,013,952.0    639744.0    3211264.0       7.48%    3851008.0
1                        bn1   64 112 112   64 112 112       128.0       3.06      3,211,264.0      1,605,632.0   3211776.0    3211264.0       6.54%    6423040.0
2                       relu   64 112 112   64 112 112         0.0       3.06        802,816.0        802,816.0   3211264.0    3211264.0       0.53%    6422528.0
3                    maxpool   64 112 112   64  56  56         0.0       0.77      1,605,632.0        802,816.0   3211264.0     802816.0       5.68%    4014080.0
4             layer1.0.conv1   64  56  56   64  56  56     36864.0       0.77    231,010,304.0    115,605,504.0    950272.0     802816.0       5.85%    1753088.0
5               layer1.0.bn1   64  56  56   64  56  56       128.0       0.77        802,816.0        401,408.0    803328.0     802816.0       1.96%    1606144.0
6              layer1.0.relu   64  56  56   64  56  56         0.0       0.77        200,704.0        200,704.0    802816.0     802816.0       0.08%    1605632.0
7             layer1.0.conv2   64  56  56   64  56  56     36864.0       0.77    231,010,304.0    115,605,504.0    950272.0     802816.0       2.34%    1753088.0
8               layer1.0.bn2   64  56  56   64  56  56       128.0       0.77        802,816.0        401,408.0    803328.0     802816.0       2.01%    1606144.0
9             layer1.1.conv1   64  56  56   64  56  56     36864.0       0.77    231,010,304.0    115,605,504.0    950272.0     802816.0       2.46%    1753088.0
10              layer1.1.bn1   64  56  56   64  56  56       128.0       0.77        802,816.0        401,408.0    803328.0     802816.0       2.04%    1606144.0
11             layer1.1.relu   64  56  56   64  56  56         0.0       0.77        200,704.0        200,704.0    802816.0     802816.0       0.07%    1605632.0
12            layer1.1.conv2   64  56  56   64  56  56     36864.0       0.77    231,010,304.0    115,605,504.0    950272.0     802816.0       2.45%    1753088.0
13              layer1.1.bn2   64  56  56   64  56  56       128.0       0.77        802,816.0        401,408.0    803328.0     802816.0       2.09%    1606144.0
14            layer2.0.conv1   64  56  56  128  28  28     73728.0       0.38    115,505,152.0     57,802,752.0   1097728.0     401408.0       6.13%    1499136.0
15              layer2.0.bn1  128  28  28  128  28  28       256.0       0.38        401,408.0        200,704.0    402432.0     401408.0       0.29%     803840.0
16             layer2.0.relu  128  28  28  128  28  28         0.0       0.38        100,352.0        100,352.0    401408.0     401408.0       0.11%     802816.0
17            layer2.0.conv2  128  28  28  128  28  28    147456.0       0.38    231,110,656.0    115,605,504.0    991232.0     401408.0       3.17%    1392640.0
18              layer2.0.bn2  128  28  28  128  28  28       256.0       0.38        401,408.0        200,704.0    402432.0     401408.0       0.29%     803840.0
19     layer2.0.downsample.0   64  56  56  128  28  28      8192.0       0.38     12,744,704.0      6,422,528.0    835584.0     401408.0       3.46%    1236992.0
20     layer2.0.downsample.1  128  28  28  128  28  28       256.0       0.38        401,408.0        200,704.0    402432.0     401408.0       0.31%     803840.0
21            layer2.1.conv1  128  28  28  128  28  28    147456.0       0.38    231,110,656.0    115,605,504.0    991232.0     401408.0       1.70%    1392640.0
22              layer2.1.bn1  128  28  28  128  28  28       256.0       0.38        401,408.0        200,704.0    402432.0     401408.0       0.31%     803840.0
23             layer2.1.relu  128  28  28  128  28  28         0.0       0.38        100,352.0        100,352.0    401408.0     401408.0       0.09%     802816.0
24            layer2.1.conv2  128  28  28  128  28  28    147456.0       0.38    231,110,656.0    115,605,504.0    991232.0     401408.0       1.78%    1392640.0
25              layer2.1.bn2  128  28  28  128  28  28       256.0       0.38        401,408.0        200,704.0    402432.0     401408.0       0.33%     803840.0
26            layer3.0.conv1  128  28  28  256  14  14    294912.0       0.19    115,555,328.0     57,802,752.0   1581056.0     200704.0       2.81%    1781760.0
27              layer3.0.bn1  256  14  14  256  14  14       512.0       0.19        200,704.0        100,352.0    202752.0     200704.0       0.28%     403456.0
28             layer3.0.relu  256  14  14  256  14  14         0.0       0.19         50,176.0         50,176.0    200704.0     200704.0       0.08%     401408.0
29            layer3.0.conv2  256  14  14  256  14  14    589824.0       0.19    231,160,832.0    115,605,504.0   2560000.0     200704.0       3.50%    2760704.0
30              layer3.0.bn2  256  14  14  256  14  14       512.0       0.19        200,704.0        100,352.0    202752.0     200704.0       0.27%     403456.0
31     layer3.0.downsample.0  128  28  28  256  14  14     32768.0       0.19     12,794,880.0      6,422,528.0    532480.0     200704.0       2.51%     733184.0
32     layer3.0.downsample.1  256  14  14  256  14  14       512.0       0.19        200,704.0        100,352.0    202752.0     200704.0       0.27%     403456.0
33            layer3.1.conv1  256  14  14  256  14  14    589824.0       0.19    231,160,832.0    115,605,504.0   2560000.0     200704.0       5.04%    2760704.0
34              layer3.1.bn1  256  14  14  256  14  14       512.0       0.19        200,704.0        100,352.0    202752.0     200704.0       0.30%     403456.0
35             layer3.1.relu  256  14  14  256  14  14         0.0       0.19         50,176.0         50,176.0    200704.0     200704.0       0.08%     401408.0
36            layer3.1.conv2  256  14  14  256  14  14    589824.0       0.19    231,160,832.0    115,605,504.0   2560000.0     200704.0       2.06%    2760704.0
37              layer3.1.bn2  256  14  14  256  14  14       512.0       0.19        200,704.0        100,352.0    202752.0     200704.0       0.28%     403456.0
38            layer4.0.conv1  256  14  14  512   7   7   1179648.0       0.10    115,580,416.0     57,802,752.0   4919296.0     100352.0       3.63%    5019648.0
39              layer4.0.bn1  512   7   7  512   7   7      1024.0       0.10        100,352.0         50,176.0    104448.0     100352.0       0.25%     204800.0
40             layer4.0.relu  512   7   7  512   7   7         0.0       0.10         25,088.0         25,088.0    100352.0     100352.0       0.06%     200704.0
41            layer4.0.conv2  512   7   7  512   7   7   2359296.0       0.10    231,185,920.0    115,605,504.0   9537536.0     100352.0       5.48%    9637888.0
42              layer4.0.bn2  512   7   7  512   7   7      1024.0       0.10        100,352.0         50,176.0    104448.0     100352.0       0.24%     204800.0
43     layer4.0.downsample.0  256  14  14  512   7   7    131072.0       0.10     12,819,968.0      6,422,528.0    724992.0     100352.0       2.82%     825344.0
44     layer4.0.downsample.1  512   7   7  512   7   7      1024.0       0.10        100,352.0         50,176.0    104448.0     100352.0       0.26%     204800.0
45            layer4.1.conv1  512   7   7  512   7   7   2359296.0       0.10    231,185,920.0    115,605,504.0   9537536.0     100352.0       3.18%    9637888.0
46              layer4.1.bn1  512   7   7  512   7   7      1024.0       0.10        100,352.0         50,176.0    104448.0     100352.0       0.24%     204800.0
47             layer4.1.relu  512   7   7  512   7   7         0.0       0.10         25,088.0         25,088.0    100352.0     100352.0       0.06%     200704.0
48            layer4.1.conv2  512   7   7  512   7   7   2359296.0       0.10    231,185,920.0    115,605,504.0   9537536.0     100352.0       4.86%    9637888.0
49              layer4.1.bn2  512   7   7  512   7   7      1024.0       0.10        100,352.0         50,176.0    104448.0     100352.0       0.23%     204800.0
50                   avgpool  512   7   7  512   1   1         0.0       0.00              0.0              0.0         0.0          0.0       0.64%          0.0
51                        fc          512         1000    513000.0       0.00      1,023,000.0        512,000.0   2054048.0       4000.0       1.02%    2058048.0
total                                                   11689512.0      25.65  3,638,757,912.0  1,821,399,040.0   2054048.0       4000.0     100.00%  101756992.0
=================================================================================================================================================================
Total params: 11,689,512
-----------------------------------------------------------------------------------------------------------------------------------------------------------------
Total memory: 25.65MB
Total MAdd: 3.64GMAdd
Total Flops: 1.82GFlops
Total MemR+W: 97.04MB
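The memory columns can be reproduced by hand if every value is assumed to be float32 (4 bytes), which is what the numbers suggest: memory(MB) is the size of the layer's output, MemWrite is that same output size in bytes, and MemRead is the input plus the layer's parameters. The sketch below (hypothetical helper names) checks this against the conv1 and bn1 rows. Note also that the MemRead and MemWrite entries in the total row are identical to the fc row, so they should not be read as totals.

```python
# Reproduce torchstat's memory columns for two rows of the resnet18 table,
# assuming every value is stored as float32 (4 bytes).
BYTES = 4

def numel(shape):
    n = 1
    for d in shape:
        n *= d
    return n

def memory_columns(in_shape, out_shape, params):
    """Hypothetical helper: returns (memory in MB, MemRead in B, MemWrite in B)."""
    mem_read = (numel(in_shape) + params) * BYTES  # read the input and the weights
    mem_write = numel(out_shape) * BYTES           # write the output feature map
    memory_mb = mem_write / 1024 ** 2              # size of the output feature map
    return round(memory_mb, 2), mem_read, mem_write

print(memory_columns((3, 224, 224), (64, 112, 112), 9408))  # conv1 -> (3.06, 639744, 3211264)
print(memory_columns((64, 112, 112), (64, 112, 112), 128))  # bn1   -> (3.06, 3211776, 3211264)
```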

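Here MAdd is the theoretical number of multiplications and additions in the network. Numerically, FLOPs is about half of MAdd, because torchstat's Flops column counts each multiply-accumulate only once.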
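The conv and fc rows of the table are consistent with the following per-layer formulas (inferred from the printed numbers, not taken from torchstat's source): for a bias-free Conv2d with K = C_in·k·k multiplies per output element and N output elements, Flops = K·N and MAdd = (2K - 1)·N, i.e. K multiplies plus K - 1 additions per output. A Linear layer follows the same pattern with K = in_features. Since 2K - 1 is close to 2K, the ratio MAdd/Flops is close to 2.

```python
# Per-layer counts that match torchstat's MAdd and Flops columns for resnet18's conv/fc rows.
# Formulas inferred from the printed table; bias terms are ignored (resnet18 convs have none).

def conv2d_counts(c_in, k, out_shape):
    n_out = out_shape[0] * out_shape[1] * out_shape[2]  # output elements
    macs_per_output = c_in * k * k                       # K multiplies per output element
    flops = macs_per_output * n_out                      # each multiply-accumulate counted once
    madd = (2 * macs_per_output - 1) * n_out             # K multiplies + (K - 1) additions
    return madd, flops

def linear_counts(n_in, n_out):
    return (2 * n_in - 1) * n_out, n_in * n_out

print(conv2d_counts(3, 7, (64, 112, 112)))  # conv1          -> (235225088, 118013952)
print(conv2d_counts(64, 3, (64, 56, 56)))   # layer1.0.conv1 -> (231010304, 115605504)
print(linear_counts(512, 1000))             # fc             -> (1023000, 512000)
```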

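Awkwardly, torchstat does not seem to handle Vision Transformer style models: it raises an error, probably caused by the feature flattening. Even more awkwardly, I don't know how to fix it.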

3. profile

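profile (the `profile` function from the thop package) computes the parameters (params) and FLOPs of a model. Strictly speaking, thop counts multiply-accumulate operations (MACs), the convention that many papers report as FLOPs.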

import torch
from thop import profile
import torchvision.models as models

model = models.resnet18(pretrained=True)
# thop needs a real input tensor, including the batch dimension
dummy_input = torch.randn(1, 3, 224, 224)
# profile returns (operation count, parameter count); the operation count is in MACs
flops, params = profile(model, (dummy_input,))
print('the flops is {}G, the params is {}M'.format(round(flops / (10 ** 9), 2), round(params / (10 ** 6), 2)))

Output:

the flops is 1.82G, the params is 11.69M

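So torchsummary, torchstat and profile all report the same number of parameters, and torchstat and profile also agree on FLOPs (1.82G).

profile also works for Vision Transformer models; it treats every architecture alike.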
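One way to see that the shared 1.82G figure is a multiply-accumulate count is to count the MACs of resnet18's conv and fc layers by hand. The sketch below assumes a 224x224 input and the standard resnet18 layout (stride-2 stem conv and max-pool, then each of layer2 to layer4 halving the spatial size). It gives about 1.814G; the remaining 7,325,696 in torchstat's total Flops is exactly the sum of its BatchNorm, ReLU and max-pool rows. Doubling a MAC count is the usual way to convert to FLOPs in the strict sense (one multiply plus one add).

```python
# Count the multiply-accumulates (MACs) of resnet18's conv and fc layers for a 224x224 input.

def conv_macs(c_in, c_out, k, out_hw):
    """Bias-free conv: c_in*k*k MACs for each of the c_out*out_hw*out_hw outputs."""
    return c_in * k * k * c_out * out_hw * out_hw

def resnet18_macs():
    total = conv_macs(3, 64, 7, 112)  # stem conv, stride 2 -> 112x112
    c_in, hw = 64, 56                 # after the stride-2 max-pool
    for stage, c_out in enumerate((64, 128, 256, 512)):
        if stage > 0:
            hw //= 2                  # layer2..layer4 start with a stride-2 block
        for block in range(2):
            total += conv_macs(c_in, c_out, 3, hw) + conv_macs(c_out, c_out, 3, hw)
            if block == 0 and c_in != c_out:
                total += conv_macs(c_in, c_out, 1, hw)  # 1x1 shortcut conv
            c_in = c_out
    return total + 512 * 1000         # fc layer: 512 -> 1000

macs = resnet18_macs()
print(f"{macs:,} MACs = {macs / 1e9:.2f} GMACs")  # 1,814,073,344 MACs = 1.81 GMACs
print(f"~{2 * macs / 1e9:.2f} GFLOPs if each MAC counts as 2 operations")
```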

import torch
from thop import profile
import timm
model = timm.create_model('vit_small_patch16_224', pretrained=True)
dummy_input = torch.randn(1, 3, 224, 224)
flops, params = profile(model, (dummy_input,))
print('the flops is {}G, the params is {}M'.format(round(flops / (10 ** 9), 2), round(params / (10 ** 6), 2)))

Output (again, the exact numbers depend on the timm version):

the flops is 9.42G, the params is 48.6M
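A hand count suggests what the 9.42G covers. Assuming the configuration seen in the torchsummary table (197 tokens, 768-dim, 8 blocks, 8 heads of size 96, MLP hidden size 2304), the patch-embedding conv, the Linear layers and the classifier alone add up to about 9.41G MACs. The two attention matrix products (q·kᵀ and attn·v) are computed with tensor operations rather than nn.Module layers, and thop counts operations through module hooks, so they appear to be left out; by this estimate they would add roughly 0.48G MACs. Along the same lines, the 48.6M parameters appear to match the sum of the module rows rather than the total that includes cls_token and pos_embed.

```python
# Rough MAC count for the ViT configuration seen in the torchsummary table (assumed).
TOKENS, PATCHES, DIM, DEPTH, HEADS, HIDDEN, CLASSES = 197, 196, 768, 8, 8, 2304, 1000
HEAD_DIM = DIM // HEADS  # 96

def linear_macs(n_tokens, n_in, n_out):
    return n_tokens * n_in * n_out

patch_embed = PATCHES * (3 * 16 * 16) * DIM  # 16x16 conv applied to 196 patches
per_block_linear = (linear_macs(TOKENS, DIM, 3 * DIM)    # qkv projection
                    + linear_macs(TOKENS, DIM, DIM)      # attention output projection
                    + linear_macs(TOKENS, DIM, HIDDEN)   # MLP fc1
                    + linear_macs(TOKENS, HIDDEN, DIM))  # MLP fc2
head = linear_macs(1, DIM, CLASSES)  # the classifier only sees the class token
module_macs = patch_embed + DEPTH * per_block_linear + head

# q @ k^T and attn @ v: each costs TOKENS * TOKENS * HEAD_DIM MACs per head
attention_matmuls = DEPTH * HEADS * 2 * TOKENS * TOKENS * HEAD_DIM

print(f"conv + Linear layers: {module_macs / 1e9:.2f} GMACs")     # conv + Linear layers: 9.41 GMACs
print(f"attention matmuls: {attention_matmuls / 1e9:.2f} GMACs")  # attention matmuls: 0.48 GMACs
```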
