网络模型的参数量和FLOPs的计算 Pytorch

目录

1、torchstat 

2、thop

3、fvcore 

4、flops_counter


FLOPS和FLOPs的区别:

  • FLOPS:注意全大写,是floating point operations per second的缩写,意指每秒浮点运算次数,理解为计算速度。是一个衡量硬件性能的指标。
  • FLOPs:注意s小写,是floating point operations的缩写(s表复数),意指浮点运算数,理解为计算量。可以用来衡量算法/模型的复杂度。

在介绍torchstat包和thop包之前,先总结一下:

  • torchstat包可以统计卷积神经网络和全连接神经网络的参数和计算量。
  • thop包可以统计统计卷积神经网络、全连接神经网络以及循环神经网络的参数和计算量,程序示例等详见下文。

1、torchstat 

pip install torchstat -i https://pypi.tuna.tsinghua.edu.cn/simple

在实际操作中,我们可以调用torchstat包,帮助我们统计模型的parameters和FLOPs。如果不修改这个包里面的一些代码,那么这个包只适用于输入为3通道的图像的模型。

import torch
import torch.nn as nn
from torchstat import stat
 
 
class Simple(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, 3, 1, padding=1, bias=False)
        self.conv2 = nn.Conv2d(16, 32, 3, 1, padding=1, bias=False)
 
    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        return x
 
 
model = Simple()
stat(model, (3, 244, 244))   # 统计模型的参数量和FLOPs,(3,244,244)是输入图像的size

网络模型的参数量和FLOPs的计算 Pytorch_第1张图片

 如果把torchstat包中的一行程序进行一点点改动,那么这个包可以用来统计全连接神经网络的参数量和计算量。当然手动计算全连接神经网络的参数量和计算量也很快 =_= 。进入torchstat源代码之后,如下图所示,注释掉圈红的地方,就可以用torchstat包统计全连接神经网络的参数量和计算量了。

网络模型的参数量和FLOPs的计算 Pytorch_第2张图片

2、thop

pip install thop -i https://pypi.tuna.tsinghua.edu.cn/simple
import torch
import torch.nn as nn
from thop import profile
 
class Simple(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 10)
 
    def forward(self, x):
        x = self.fc1(x)
        return x
 
net = Simple()
input = torch.randn(1, 10)  # batchsize=1, 输入向量长度为10
macs, params = profile(net, inputs=(input, ))
print(' FLOPs: ', macs*2)   # 一般来讲,FLOPs是macs的两倍
print('params: ', params)

3、fvcore 

pip install fvcore -i https://pypi.tuna.tsinghua.edu.cn/simple

用它比较好

import torch
from torchvision.models import resnet50
from fvcore.nn import FlopCountAnalysis, parameter_count_table

# 创建resnet50网络
model = resnet50(num_classes=1000)

# 创建输入网络的tensor
tensor = (torch.rand(1, 3, 224, 224),)

# 分析FLOPs
flops = FlopCountAnalysis(model, tensor)
print("FLOPs: ", flops.total())

# 分析parameters
print(parameter_count_table(model))

 终端输出结果如下,FLOPs为4089184256,模型参数数量约为25.6M(这里的参数数量和我自己计算的有些出入,主要是在BN模块中,这里只计算了beta和gamma两个训练参数,没有统计moving_mean和moving_var两个参数),具体可以看下我在官方提的issue。
通过终端打印的信息我们可以发现在计算FLOPs时并没有包含BN层,池化层还有普通的add操作(我发现计算FLOPs时并没有统一的规定,在github上看的计算FLOPs项目基本每个都不同,但计算出来的结果大同小异)。

网络模型的参数量和FLOPs的计算 Pytorch_第3张图片

注意:在使用fvcore模块计算模型的flops时,遇到了问题,记录一下解决方案。首先是在jit_analysis.py的589行出错。经过调试发现,op_counts.values()的类型是int32,但是计算要求的类型只能是int、float、np.float64和np.int64,因此需要手动进行强制转换。修改如下:

网络模型的参数量和FLOPs的计算 Pytorch_第4张图片

4、flops_counter

pip install ptflops -i https://pypi.tuna.tsinghua.edu.cn/simple

用它也很好,结果和fvcore一样

from ptflops import get_model_complexity_info

macs, params = get_model_complexity_info(model, (112, 9, 9), as_strings=True,
                                         print_per_layer_stat=True, verbose=True)
print('{:<30}  {:<8}'.format('Computational complexity: ', macs))
print('{:<30}  {:<8}'.format('Number of parameters: ', params))

网络模型的参数量和FLOPs的计算 Pytorch_第5张图片

你可能感兴趣的:(笔记,pytorch,深度学习,人工智能)