【pytorch】使用stat、profile打印网络的参数量、Flops、MAdd、内存使用的情况

目的】pytorch获取网络的参数量、MAdd、Flops
可使用库】torchstat中的stat、thop中的profile

1 stat打印

安装工具】pip install torchstat
使用例子】我们的网络只有一层,该层的数据就是整个模型的数据。
这里并没有严格按照pytorch官方提供的公式计算,个人感觉不是很好记忆;这里是使用实际的例子,来将计算方式具体化,反向的去理解公式

import torch
import torch.nn as nn
from torchstat import stat

class Net(nn.Module):
 def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 2, kernel_size=7, stride=2, padding=3, bias=False)
    def forward(self, x):
       x = self.conv1(x)

net = Net()
stat(net, (3, 500, 500))

【pytorch】使用stat、profile打印网络的参数量、Flops、MAdd、内存使用的情况_第1张图片
打印内容
stat命令打印的结果如上,我们分别分析下参数的计算方式:

  • 【params】
    网络主要的参数量
    7*7*3*2=294(W*H*C_in*C_out),因为这里偏置设为False,所有不加上偏置的参数量
  • 【memory】
    stat源码定义下查看,该参数的定义应为:节点推理时候所需的内存(具体计算公式本人暂不清楚,如果有了解的期待评论告知,多谢)
    【pytorch】使用stat、profile打印网络的参数量、Flops、MAdd、内存使用的情况_第2张图片
  • 【Flops】
    网络完成的浮点运算。这里计算以输出的Feature map为视角,其中每个元素的计算需要经历
    ((7*7*3)+(7*7*3-1))*(250*250*2) = 36625000 ~= 36.62 MFlops
    ((输出一个元素所经历的乘法次数)+(输出一个元素所经历的加法的个数))*(输出总共的元素的个数)
  • 【MAdd】
    网络完成的乘加操作的数量。一次乘加=一次乘法+一次加法,所以可以粗略的认为:Flops ~=2*MAdd
    (7*7*3)*(250*250*2) = 18375000 ~= 18.38 MMAdd
  • 【MemRead】
    网络运行时,从内存中读取的大小 = 输入的特征图大小 + 网络参数的大小
    ((500*500*3) + (7*7*3*2))*4 = 3001176.0
    这里乘以4,是因为假设这里的数是float32的,一个float32=4*byte
  • 【MemWrite】
    网络运行时,写入到内存中的大小 = 输出的特征图大小
    250*250*2*4 = 500000
  • 【MemR+W】
    MemR+W = MemRead + MemWrite。在这里等于 3001176.0+500000 = 3501176.0

能够发现,按照公式计算的 Flops/MAdd 两个变量刚好反了,按道理pytorch应该不会出现如此明显的bug,但的确自己计算是按照定义计算的,这个问题就只能先保留在这

2 profile打印

可看到打印结果与stat相应数值大小基本一致

import torch
import torch.nn as nn
# from torchstat import stat
from thop import profile

class Net(nn.Module):
   def __init__(self):
       super(Net, self).__init__()
       self.conv1 = nn.Conv2d(3, 2, kernel_size=7, stride=2, padding=3, bias=False)
   def forward(self, x):
       x = self.conv1(x)

# net = Net()
# stat(net, (3, 500, 500))

input = torch.randn(1, 3, 500, 500)
flops, params = profile(net, inputs=(input,))
print('FLOPs = ' + str(flops/1000**3) + 'G')
print('Params = ' + str(params/1000**2) + 'M')

在这里插入图片描述

你可能感兴趣的:(pytorch的框架使用记录,pytorch)