由于模型分析的需要,除了对比模型在指定任务上的表现外,我们可能还需要评估模型的FLOPs、参数量、MAdd、显卡内存占用量等参数,模型的这类参数对模型的实用性有很大的意义。
这里首先推荐一个功能强大的工具包torchstat。该工具包可通过pip直接安装:
pip install torchstat
安装完成后,使用torchstat进行模型分析的方法如下:
from torchstat import stat
stat(model, (3, 224, 224))
这里的model变量即为待分析的模型,而输入的另一个参数表示输入图片的大小。分析的效果如下(以resnet18为例):
[MAdd]: AdaptiveAvgPool2d is not supported!
[Flops]: AdaptiveAvgPool2d is not supported!
[Memory]: AdaptiveAvgPool2d is not supported!
module name input shape output shape params memory(MB) MAdd Flops MemRead(B) MemWrite(B) duration[%] MemR+W(B)
0 conv1 3 224 224 64 112 112 9408.0 3.06 235,225,088.0 118,013,952.0 639744.0 3211264.0 4.05% 3851008.0
1 bn1 64 112 112 64 112 112 128.0 3.06 3,211,264.0 1,605,632.0 3211776.0 3211264.0 2.15% 6423040.0
2 relu 64 112 112 64 112 112 0.0 3.06 802,816.0 802,816.0 3211264.0 3211264.0 0.36% 6422528.0
3 maxpool 64 112 112 64 56 56 0.0 0.77 1,605,632.0 802,816.0 3211264.0 802816.0 19.73% 4014080.0
4 layer1.0.conv1 64 56 56 64 56 56 36864.0 0.77 231,010,304.0 115,605,504.0 950272.0 802816.0 3.12% 1753088.0
5 layer1.0.bn1 64 56 56 64 56 56 128.0 0.77 802,816.0 401,408.0 803328.0 802816.0 0.59% 1606144.0
6 layer1.0.relu 64 56 56 64 56 56 0.0 0.77 200,704.0 200,704.0 802816.0 802816.0 0.07% 1605632.0
7 layer1.0.conv2 64 56 56 64 56 56 36864.0 0.77 231,010,304.0 115,605,504.0 950272.0 802816.0 2.88% 1753088.0
8 layer1.0.bn2 64 56 56 64 56 56 128.0 0.77 802,816.0 401,408.0 803328.0 802816.0 0.58% 1606144.0
9 layer1.1.conv1 64 56 56 64 56 56 36864.0 0.77 231,010,304.0 115,605,504.0 950272.0 802816.0 2.90% 1753088.0
10 layer1.1.bn1 64 56 56 64 56 56 128.0 0.77 802,816.0 401,408.0 803328.0 802816.0 0.56% 1606144.0
11 layer1.1.relu 64 56 56 64 56 56 0.0 0.77 200,704.0 200,704.0 802816.0 802816.0 0.06% 1605632.0
12 layer1.1.conv2 64 56 56 64 56 56 36864.0 0.77 231,010,304.0 115,605,504.0 950272.0 802816.0 2.86% 1753088.0
13 layer1.1.bn2 64 56 56 64 56 56 128.0 0.77 802,816.0 401,408.0 803328.0 802816.0 0.58% 1606144.0
14 layer2.0.conv1 64 56 56 128 28 28 73728.0 0.38 115,505,152.0 57,802,752.0 1097728.0 401408.0 2.30% 1499136.0
15 layer2.0.bn1 128 28 28 128 28 28 256.0 0.38 401,408.0 200,704.0 402432.0 401408.0 0.34% 803840.0
16 layer2.0.relu 128 28 28 128 28 28 0.0 0.38 100,352.0 100,352.0 401408.0 401408.0 0.06% 802816.0
17 layer2.0.conv2 128 28 28 128 28 28 147456.0 0.38 231,110,656.0 115,605,504.0 991232.0 401408.0 2.73% 1392640.0
18 layer2.0.bn2 128 28 28 128 28 28 256.0 0.38 401,408.0 200,704.0 402432.0 401408.0 0.34% 803840.0
19 layer2.0.downsample.0 64 56 56 128 28 28 8192.0 0.38 12,744,704.0 6,422,528.0 835584.0 401408.0 2.02% 1236992.0
20 layer2.0.downsample.1 128 28 28 128 28 28 256.0 0.38 401,408.0 200,704.0 402432.0 401408.0 0.33% 803840.0
21 layer2.1.conv1 128 28 28 128 28 28 147456.0 0.38 231,110,656.0 115,605,504.0 991232.0 401408.0 2.60% 1392640.0
22 layer2.1.bn1 128 28 28 128 28 28 256.0 0.38 401,408.0 200,704.0 402432.0 401408.0 0.33% 803840.0
23 layer2.1.relu 128 28 28 128 28 28 0.0 0.38 100,352.0 100,352.0 401408.0 401408.0 0.06% 802816.0
24 layer2.1.conv2 128 28 28 128 28 28 147456.0 0.38 231,110,656.0 115,605,504.0 991232.0 401408.0 2.61% 1392640.0
25 layer2.1.bn2 128 28 28 128 28 28 256.0 0.38 401,408.0 200,704.0 402432.0 401408.0 0.33% 803840.0
26 layer3.0.conv1 128 28 28 256 14 14 294912.0 0.19 115,555,328.0 57,802,752.0 1581056.0 200704.0 2.42% 1781760.0
27 layer3.0.bn1 256 14 14 256 14 14 512.0 0.19 200,704.0 100,352.0 202752.0 200704.0 0.23% 403456.0
28 layer3.0.relu 256 14 14 256 14 14 0.0 0.19 50,176.0 50,176.0 200704.0 200704.0 0.05% 401408.0
29 layer3.0.conv2 256 14 14 256 14 14 589824.0 0.19 231,160,832.0 115,605,504.0 2560000.0 200704.0 3.12% 2760704.0
30 layer3.0.bn2 256 14 14 256 14 14 512.0 0.19 200,704.0 100,352.0 202752.0 200704.0 0.24% 403456.0
31 layer3.0.downsample.0 128 28 28 256 14 14 32768.0 0.19 12,794,880.0 6,422,528.0 532480.0 200704.0 1.65% 733184.0
32 layer3.0.downsample.1 256 14 14 256 14 14 512.0 0.19 200,704.0 100,352.0 202752.0 200704.0 0.23% 403456.0
33 layer3.1.conv1 256 14 14 256 14 14 589824.0 0.19 231,160,832.0 115,605,504.0 2560000.0 200704.0 3.09% 2760704.0
34 layer3.1.bn1 256 14 14 256 14 14 512.0 0.19 200,704.0 100,352.0 202752.0 200704.0 0.25% 403456.0
35 layer3.1.relu 256 14 14 256 14 14 0.0 0.19 50,176.0 50,176.0 200704.0 200704.0 0.05% 401408.0
36 layer3.1.conv2 256 14 14 256 14 14 589824.0 0.19 231,160,832.0 115,605,504.0 2560000.0 200704.0 3.16% 2760704.0
37 layer3.1.bn2 256 14 14 256 14 14 512.0 0.19 200,704.0 100,352.0 202752.0 200704.0 0.24% 403456.0
38 layer4.0.conv1 256 14 14 512 7 7 1179648.0 0.10 115,580,416.0 57,802,752.0 4919296.0 100352.0 3.27% 5019648.0
39 layer4.0.bn1 512 7 7 512 7 7 1024.0 0.10 100,352.0 50,176.0 104448.0 100352.0 0.19% 204800.0
40 layer4.0.relu 512 7 7 512 7 7 0.0 0.10 25,088.0 25,088.0 100352.0 100352.0 0.06% 200704.0
41 layer4.0.conv2 512 7 7 512 7 7 2359296.0 0.10 231,185,920.0 115,605,504.0 9537536.0 100352.0 5.92% 9637888.0
42 layer4.0.bn2 512 7 7 512 7 7 1024.0 0.10 100,352.0 50,176.0 104448.0 100352.0 0.21% 204800.0
43 layer4.0.downsample.0 256 14 14 512 7 7 131072.0 0.10 12,819,968.0 6,422,528.0 724992.0 100352.0 1.85% 825344.0
44 layer4.0.downsample.1 512 7 7 512 7 7 1024.0 0.10 100,352.0 50,176.0 104448.0 100352.0 0.18% 204800.0
45 layer4.1.conv1 512 7 7 512 7 7 2359296.0 0.10 231,185,920.0 115,605,504.0 9537536.0 100352.0 5.59% 9637888.0
46 layer4.1.bn1 512 7 7 512 7 7 1024.0 0.10 100,352.0 50,176.0 104448.0 100352.0 0.20% 204800.0
47 layer4.1.relu 512 7 7 512 7 7 0.0 0.10 25,088.0 25,088.0 100352.0 100352.0 0.06% 200704.0
48 layer4.1.conv2 512 7 7 512 7 7 2359296.0 0.10 231,185,920.0 115,605,504.0 9537536.0 100352.0 5.62% 9637888.0
49 layer4.1.bn2 512 7 7 512 7 7 1024.0 0.10 100,352.0 50,176.0 104448.0 100352.0 0.20% 204800.0
50 avgpool 512 7 7 512 1 1 0.0 0.00 0.0 0.0 0.0 0.0 0.48% 0.0
51 fc 512 1000 513000.0 0.00 1,023,000.0 512,000.0 2054048.0 4000.0 6.92% 2058048.0
total 11689512.0 25.65 3,638,757,912.0 1,821,399,040.0 2054048.0 4000.0 100.00% 101756992.0
=================================================================================================================================================================
Total params: 11,689,512
-----------------------------------------------------------------------------------------------------------------------------------------------------------------
Total memory: 25.65MB
Total MAdd: 3.64GMAdd
Total Flops: 1.82GFlops
Total MemR+W: 97.04MB
从分析结果可以看出,torchstat的功能非常强大,不仅可以实现FLOPs、参数量、MAdd、显卡内存占用量等模型参数的分析,还可以看到模型每一层的分析结果,工具包不支持的layer也会列在分析结果前提醒使用者。
虽然torchstat的功能十分强大,但是也有一些缺陷:
1. 限制模型输入仅能为图片
2. 限制模型每一个layer的输入须为单个变量
3. 对Pytorch-0.4.1及以下版本的支持不足(具体可参考https://blog.csdn.net/qq_40329272/article/details/106797617)
以上这些缺陷是在实践中发现的,具体表现为程序报错。如果修改模型也无法适配torchstat,这时就要考虑另选分析工具,这里介绍一下thop工具包。
对于torchstat无法适用的模型某一个layer的输入为多个变量和Pytorch-0.4.1版本等情况,可以尝试使用thop工具包进行模型分析。thop工具包同样可以通过pip进行安装:
pip install thop
thop工具包相对torchstat而言,功能较为简单,仅支持FLOPs和参数量的计算(或者是我没有发现,不过我看源码是只返回这俩参量)。thop工具包的使用方法如下(以resnet18为例):
from thop import profile
from thop import clever_format
input = torch.randn(1, 3, 224, 224)
flops, params = profile(model, inputs=(input, ))
print(flops, params) # 1819066368.0 11689512.0
flops, params = clever_format([flops, params], "%.3f")
print(flops, params) # 1.819G 11.690M
这里的model变量即为待分析的模型,而clever_format函数的调用则是增加输出结果的可读性。
由于模型结构的千变万化,所以很难有一款工具包能够对所有模型都适用,难免会有缺陷或者不支持的情况。这里介绍了两款Pytorch模型分析的工具包——torchstat和thop,推荐首选torchstat进行模型分析,如果出现无法解决的程序报错,再尝试使用thop。由于不同的工具包对各种layer的支持有差异,所以最后的计算结果可能会不一致,这时需要结合工具包运行中给出的warning提示进行分析,一般只有较小的差异。
参考资料1:https://dotnet.ctolib.com/Swall0w-torchstat.html
参考资料2:https://www.jianshu.com/p/6514b8fb1ada
参考资料3:https://www.jianshu.com/p/cbada26ea29d
参考资料4:https://blog.csdn.net/qq_40329272/article/details/106797617