Pytorch模型分析:计算Pytorch模型的FLOPs、模型参数量、MAdd、模型显存占用量

  由于模型分析的需要,除了对比模型在指定任务上的表现外,我们可能还需要评估模型的FLOPs、参数量、MAdd、显卡内存占用量等参数,模型的这类参数对模型的实用性有很大的意义。

torchstat

  这里首先推荐一个功能强大的工具包torchstat。该工具包可通过pip直接安装:

pip install torchstat

  安装完成后,使用torchstat进行模型分析的方法如下:

from torchstat import stat

stat(model, (3, 224, 224))

  这里的model变量即为待分析的模型,而输入的另一个参数表示输入图片的大小。分析的效果如下(以resnet18为例):

[MAdd]: AdaptiveAvgPool2d is not supported!
[Flops]: AdaptiveAvgPool2d is not supported!
[Memory]: AdaptiveAvgPool2d is not supported!
                 module name  input shape output shape      params memory(MB)             MAdd            Flops  MemRead(B)  MemWrite(B) duration[%]    MemR+W(B)
0                      conv1    3 224 224   64 112 112      9408.0       3.06    235,225,088.0    118,013,952.0    639744.0    3211264.0       4.05%    3851008.0
1                        bn1   64 112 112   64 112 112       128.0       3.06      3,211,264.0      1,605,632.0   3211776.0    3211264.0       2.15%    6423040.0
2                       relu   64 112 112   64 112 112         0.0       3.06        802,816.0        802,816.0   3211264.0    3211264.0       0.36%    6422528.0
3                    maxpool   64 112 112   64  56  56         0.0       0.77      1,605,632.0        802,816.0   3211264.0     802816.0      19.73%    4014080.0
4             layer1.0.conv1   64  56  56   64  56  56     36864.0       0.77    231,010,304.0    115,605,504.0    950272.0     802816.0       3.12%    1753088.0
5               layer1.0.bn1   64  56  56   64  56  56       128.0       0.77        802,816.0        401,408.0    803328.0     802816.0       0.59%    1606144.0
6              layer1.0.relu   64  56  56   64  56  56         0.0       0.77        200,704.0        200,704.0    802816.0     802816.0       0.07%    1605632.0
7             layer1.0.conv2   64  56  56   64  56  56     36864.0       0.77    231,010,304.0    115,605,504.0    950272.0     802816.0       2.88%    1753088.0
8               layer1.0.bn2   64  56  56   64  56  56       128.0       0.77        802,816.0        401,408.0    803328.0     802816.0       0.58%    1606144.0
9             layer1.1.conv1   64  56  56   64  56  56     36864.0       0.77    231,010,304.0    115,605,504.0    950272.0     802816.0       2.90%    1753088.0
10              layer1.1.bn1   64  56  56   64  56  56       128.0       0.77        802,816.0        401,408.0    803328.0     802816.0       0.56%    1606144.0
11             layer1.1.relu   64  56  56   64  56  56         0.0       0.77        200,704.0        200,704.0    802816.0     802816.0       0.06%    1605632.0
12            layer1.1.conv2   64  56  56   64  56  56     36864.0       0.77    231,010,304.0    115,605,504.0    950272.0     802816.0       2.86%    1753088.0
13              layer1.1.bn2   64  56  56   64  56  56       128.0       0.77        802,816.0        401,408.0    803328.0     802816.0       0.58%    1606144.0
14            layer2.0.conv1   64  56  56  128  28  28     73728.0       0.38    115,505,152.0     57,802,752.0   1097728.0     401408.0       2.30%    1499136.0
15              layer2.0.bn1  128  28  28  128  28  28       256.0       0.38        401,408.0        200,704.0    402432.0     401408.0       0.34%     803840.0
16             layer2.0.relu  128  28  28  128  28  28         0.0       0.38        100,352.0        100,352.0    401408.0     401408.0       0.06%     802816.0
17            layer2.0.conv2  128  28  28  128  28  28    147456.0       0.38    231,110,656.0    115,605,504.0    991232.0     401408.0       2.73%    1392640.0
18              layer2.0.bn2  128  28  28  128  28  28       256.0       0.38        401,408.0        200,704.0    402432.0     401408.0       0.34%     803840.0
19     layer2.0.downsample.0   64  56  56  128  28  28      8192.0       0.38     12,744,704.0      6,422,528.0    835584.0     401408.0       2.02%    1236992.0
20     layer2.0.downsample.1  128  28  28  128  28  28       256.0       0.38        401,408.0        200,704.0    402432.0     401408.0       0.33%     803840.0
21            layer2.1.conv1  128  28  28  128  28  28    147456.0       0.38    231,110,656.0    115,605,504.0    991232.0     401408.0       2.60%    1392640.0
22              layer2.1.bn1  128  28  28  128  28  28       256.0       0.38        401,408.0        200,704.0    402432.0     401408.0       0.33%     803840.0
23             layer2.1.relu  128  28  28  128  28  28         0.0       0.38        100,352.0        100,352.0    401408.0     401408.0       0.06%     802816.0
24            layer2.1.conv2  128  28  28  128  28  28    147456.0       0.38    231,110,656.0    115,605,504.0    991232.0     401408.0       2.61%    1392640.0
25              layer2.1.bn2  128  28  28  128  28  28       256.0       0.38        401,408.0        200,704.0    402432.0     401408.0       0.33%     803840.0
26            layer3.0.conv1  128  28  28  256  14  14    294912.0       0.19    115,555,328.0     57,802,752.0   1581056.0     200704.0       2.42%    1781760.0
27              layer3.0.bn1  256  14  14  256  14  14       512.0       0.19        200,704.0        100,352.0    202752.0     200704.0       0.23%     403456.0
28             layer3.0.relu  256  14  14  256  14  14         0.0       0.19         50,176.0         50,176.0    200704.0     200704.0       0.05%     401408.0
29            layer3.0.conv2  256  14  14  256  14  14    589824.0       0.19    231,160,832.0    115,605,504.0   2560000.0     200704.0       3.12%    2760704.0
30              layer3.0.bn2  256  14  14  256  14  14       512.0       0.19        200,704.0        100,352.0    202752.0     200704.0       0.24%     403456.0
31     layer3.0.downsample.0  128  28  28  256  14  14     32768.0       0.19     12,794,880.0      6,422,528.0    532480.0     200704.0       1.65%     733184.0
32     layer3.0.downsample.1  256  14  14  256  14  14       512.0       0.19        200,704.0        100,352.0    202752.0     200704.0       0.23%     403456.0
33            layer3.1.conv1  256  14  14  256  14  14    589824.0       0.19    231,160,832.0    115,605,504.0   2560000.0     200704.0       3.09%    2760704.0
34              layer3.1.bn1  256  14  14  256  14  14       512.0       0.19        200,704.0        100,352.0    202752.0     200704.0       0.25%     403456.0
35             layer3.1.relu  256  14  14  256  14  14         0.0       0.19         50,176.0         50,176.0    200704.0     200704.0       0.05%     401408.0
36            layer3.1.conv2  256  14  14  256  14  14    589824.0       0.19    231,160,832.0    115,605,504.0   2560000.0     200704.0       3.16%    2760704.0
37              layer3.1.bn2  256  14  14  256  14  14       512.0       0.19        200,704.0        100,352.0    202752.0     200704.0       0.24%     403456.0
38            layer4.0.conv1  256  14  14  512   7   7   1179648.0       0.10    115,580,416.0     57,802,752.0   4919296.0     100352.0       3.27%    5019648.0
39              layer4.0.bn1  512   7   7  512   7   7      1024.0       0.10        100,352.0         50,176.0    104448.0     100352.0       0.19%     204800.0
40             layer4.0.relu  512   7   7  512   7   7         0.0       0.10         25,088.0         25,088.0    100352.0     100352.0       0.06%     200704.0
41            layer4.0.conv2  512   7   7  512   7   7   2359296.0       0.10    231,185,920.0    115,605,504.0   9537536.0     100352.0       5.92%    9637888.0
42              layer4.0.bn2  512   7   7  512   7   7      1024.0       0.10        100,352.0         50,176.0    104448.0     100352.0       0.21%     204800.0
43     layer4.0.downsample.0  256  14  14  512   7   7    131072.0       0.10     12,819,968.0      6,422,528.0    724992.0     100352.0       1.85%     825344.0
44     layer4.0.downsample.1  512   7   7  512   7   7      1024.0       0.10        100,352.0         50,176.0    104448.0     100352.0       0.18%     204800.0
45            layer4.1.conv1  512   7   7  512   7   7   2359296.0       0.10    231,185,920.0    115,605,504.0   9537536.0     100352.0       5.59%    9637888.0
46              layer4.1.bn1  512   7   7  512   7   7      1024.0       0.10        100,352.0         50,176.0    104448.0     100352.0       0.20%     204800.0
47             layer4.1.relu  512   7   7  512   7   7         0.0       0.10         25,088.0         25,088.0    100352.0     100352.0       0.06%     200704.0
48            layer4.1.conv2  512   7   7  512   7   7   2359296.0       0.10    231,185,920.0    115,605,504.0   9537536.0     100352.0       5.62%    9637888.0
49              layer4.1.bn2  512   7   7  512   7   7      1024.0       0.10        100,352.0         50,176.0    104448.0     100352.0       0.20%     204800.0
50                   avgpool  512   7   7  512   1   1         0.0       0.00              0.0              0.0         0.0          0.0       0.48%          0.0
51                        fc          512         1000    513000.0       0.00      1,023,000.0        512,000.0   2054048.0       4000.0       6.92%    2058048.0
total                                                   11689512.0      25.65  3,638,757,912.0  1,821,399,040.0   2054048.0       4000.0     100.00%  101756992.0
=================================================================================================================================================================
Total params: 11,689,512
-----------------------------------------------------------------------------------------------------------------------------------------------------------------
Total memory: 25.65MB
Total MAdd: 3.64GMAdd
Total Flops: 1.82GFlops
Total MemR+W: 97.04MB

  从分析结果可以看出,torchstat的功能非常强大,不仅可以实现FLOPs、参数量、MAdd、显卡内存占用量等模型参数的分析,还可以看到模型每一层的分析结果,工具包不支持的layer也会列在分析结果前提醒使用者。

  虽然torchstat的功能十分强大,但是也有一些缺陷
  1. 限制模型输入仅能为图片
  2. 限制模型每一个layer的输入须为单个变量
  3. 对Pytorch-0.4.1及以下版本的支持不足(具体可参考https://blog.csdn.net/qq_40329272/article/details/106797617)

  以上这些缺陷是在实践中发现的,具体表现为程序报错。如果修改模型也无法适配torchstat,这时就要考虑另选分析工具,这里介绍一下thop工具包。

thop

  对于torchstat无法适用的模型某一个layer的输入为多个变量和Pytorch-0.4.1版本等情况,可以尝试使用thop工具包进行模型分析。thop工具包同样可以通过pip进行安装:

pip install thop

  thop工具包相对torchstat而言,功能较为简单,仅支持FLOPs和参数量的计算(或者是我没有发现,不过我看源码是只返回这俩参量)。thop工具包的使用方法如下(以resnet18为例):

from thop import profile
from thop import clever_format

input = torch.randn(1, 3, 224, 224)
flops, params = profile(model, inputs=(input, ))
print(flops, params) # 1819066368.0 11689512.0
flops, params = clever_format([flops, params], "%.3f")
print(flops, params) # 1.819G 11.690M

  这里的model变量即为待分析的模型,而clever_format函数的调用则是增加输出结果的可读性。

总结

  由于模型结构的千变万化,所以很难有一款工具包能够对所有模型都适用,难免会有缺陷或者不支持的情况。这里介绍了两款Pytorch模型分析的工具包——torchstat和thop,推荐首选torchstat进行模型分析,如果出现无法解决的程序报错,再尝试使用thop。由于不同的工具包对各种layer的支持有差异,所以最后的计算结果可能会不一致,这时需要结合工具包运行中给出的warning提示进行分析,一般只有较小的差异。

参考资料1:https://dotnet.ctolib.com/Swall0w-torchstat.html
参考资料2:https://www.jianshu.com/p/6514b8fb1ada
参考资料3:https://www.jianshu.com/p/cbada26ea29d
参考资料4:https://blog.csdn.net/qq_40329272/article/details/106797617

你可能感兴趣的:(Pytorch打怪之路,python,机器学习,深度学习,pytorch,神经网络)