Using fvcore to count model parameters and FLOPs in PyTorch

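fvcore is a lightweight core library open-sourced by Facebook. It provides common, basic functionality shared across computer vision frameworks, including counting a model's parameters and FLOPs.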

fvcore is a light-weight core library that provides the most common and essential functionality shared in various computer vision frameworks

Repository:
https://github.com/facebookresearch/fvcore

Install fvcore in your Python environment:

```shell
pip install fvcore
```

Example:
Suppose we want to compute the parameter count and FLOPs of the following ResNet-50 model.

```python
import torch
from torchvision.models import resnet50
from fvcore.nn import FlopCountAnalysis, parameter_count_table

# Build a ResNet-50 (randomly initialized; the weights do not affect the counts)
model = resnet50(num_classes=1000)

# Example input, passed as a tuple with one entry per forward() argument
tensor = (torch.rand(1, 3, 224, 224),)

# Analyze FLOPs
flops = FlopCountAnalysis(model, tensor)
print("FLOPs: ", flops.total())

# Analyze parameters
print(parameter_count_table(model))
```

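The terminal output is shown below. FLOPs come to 4089184256 and the parameter count is about 25.6M. Note that fvcore counts one fused multiply-add as one flop, so this figure is really a multiply-add (MAC) count of about 4.1 G. The parameter count differs slightly from my own calculation. The difference is in the BN modules: fvcore only counts the two trainable parameters, beta and gamma, and does not count moving_mean and moving_var (stored in PyTorch as the buffers running_mean and running_var). See the issue I opened on the official repository for details.
The printed messages also show that the FLOP count does not include the BN layers, the pooling layers, or the plain add operations. (There is no single agreed-upon convention for counting FLOPs: the FLOP-counting projects I have looked at on GitHub each do it a little differently, although their results are broadly similar.)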

```text
Skipped operation aten::batch_norm 53 time(s)
Skipped operation aten::max_pool2d 1 time(s)
Skipped operation aten::add_ 16 time(s)
Skipped operation aten::adaptive_avg_pool2d 1 time(s)
FLOPs:  4089184256
```
| name                   | #elements or shape   |
|:-----------------------|:---------------------|
| model                  | 25.6M                |
|  conv1                 |  9.4K                |
|   conv1.weight         |   (64, 3, 7, 7)      |
|  bn1                   |  0.1K                |
|   bn1.weight           |   (64,)              |
|   bn1.bias             |   (64,)              |
|  layer1                |  0.2M                |
|   layer1.0             |   75.0K              |
|    layer1.0.conv1      |    4.1K              |
|    layer1.0.bn1        |    0.1K              |
|    layer1.0.conv2      |    36.9K             |
|    layer1.0.bn2        |    0.1K              |
|    layer1.0.conv3      |    16.4K             |
|    layer1.0.bn3        |    0.5K              |
|    layer1.0.downsample |    16.9K             |
|   layer1.1             |   70.4K              |
|    layer1.1.conv1      |    16.4K             |
|    layer1.1.bn1        |    0.1K              |
|    layer1.1.conv2      |    36.9K             |
|    layer1.1.bn2        |    0.1K              |
|    layer1.1.conv3      |    16.4K             |
|    layer1.1.bn3        |    0.5K              |
|   layer1.2             |   70.4K              |
|    layer1.2.conv1      |    16.4K             |
|    layer1.2.bn1        |    0.1K              |
|    layer1.2.conv2      |    36.9K             |
|    layer1.2.bn2        |    0.1K              |
|    layer1.2.conv3      |    16.4K             |
|    layer1.2.bn3        |    0.5K              |
|  layer2                |  1.2M                |
|   layer2.0             |   0.4M               |
|    layer2.0.conv1      |    32.8K             |
|    layer2.0.bn1        |    0.3K              |
|    layer2.0.conv2      |    0.1M              |
|    layer2.0.bn2        |    0.3K              |
|    layer2.0.conv3      |    65.5K             |
|    layer2.0.bn3        |    1.0K              |
|    layer2.0.downsample |    0.1M              |
|   layer2.1             |   0.3M               |
|    layer2.1.conv1      |    65.5K             |
|    layer2.1.bn1        |    0.3K              |
|    layer2.1.conv2      |    0.1M              |
|    layer2.1.bn2        |    0.3K              |
|    layer2.1.conv3      |    65.5K             |
|    layer2.1.bn3        |    1.0K              |
|   layer2.2             |   0.3M               |
|    layer2.2.conv1      |    65.5K             |
|    layer2.2.bn1        |    0.3K              |
|    layer2.2.conv2      |    0.1M              |
|    layer2.2.bn2        |    0.3K              |
|    layer2.2.conv3      |    65.5K             |
|    layer2.2.bn3        |    1.0K              |
|   layer2.3             |   0.3M               |
|    layer2.3.conv1      |    65.5K             |
|    layer2.3.bn1        |    0.3K              |
|    layer2.3.conv2      |    0.1M              |
|    layer2.3.bn2        |    0.3K              |
|    layer2.3.conv3      |    65.5K             |
|    layer2.3.bn3        |    1.0K              |
|  layer3                |  7.1M                |
|   layer3.0             |   1.5M               |
|    layer3.0.conv1      |    0.1M              |
|    layer3.0.bn1        |    0.5K              |
|    layer3.0.conv2      |    0.6M              |
|    layer3.0.bn2        |    0.5K              |
|    layer3.0.conv3      |    0.3M              |
|    layer3.0.bn3        |    2.0K              |
|    layer3.0.downsample |    0.5M              |
|   layer3.1             |   1.1M               |
|    layer3.1.conv1      |    0.3M              |
|    layer3.1.bn1        |    0.5K              |
|    layer3.1.conv2      |    0.6M              |
|    layer3.1.bn2        |    0.5K              |
|    layer3.1.conv3      |    0.3M              |
|    layer3.1.bn3        |    2.0K              |
|   layer3.2             |   1.1M               |
|    layer3.2.conv1      |    0.3M              |
|    layer3.2.bn1        |    0.5K              |
|    layer3.2.conv2      |    0.6M              |
|    layer3.2.bn2        |    0.5K              |
|    layer3.2.conv3      |    0.3M              |
|    layer3.2.bn3        |    2.0K              |
|   layer3.3             |   1.1M               |
|    layer3.3.conv1      |    0.3M              |
|    layer3.3.bn1        |    0.5K              |
|    layer3.3.conv2      |    0.6M              |
|    layer3.3.bn2        |    0.5K              |
|    layer3.3.conv3      |    0.3M              |
|    layer3.3.bn3        |    2.0K              |
|   layer3.4             |   1.1M               |
|    layer3.4.conv1      |    0.3M              |
|    layer3.4.bn1        |    0.5K              |
|    layer3.4.conv2      |    0.6M              |
|    layer3.4.bn2        |    0.5K              |
|    layer3.4.conv3      |    0.3M              |
|    layer3.4.bn3        |    2.0K              |
|   layer3.5             |   1.1M               |
|    layer3.5.conv1      |    0.3M              |
|    layer3.5.bn1        |    0.5K              |
|    layer3.5.conv2      |    0.6M              |
|    layer3.5.bn2        |    0.5K              |
|    layer3.5.conv3      |    0.3M              |
|    layer3.5.bn3        |    2.0K              |
|  layer4                |  15.0M               |
|   layer4.0             |   6.0M               |
|    layer4.0.conv1      |    0.5M              |
|    layer4.0.bn1        |    1.0K              |
|    layer4.0.conv2      |    2.4M              |
|    layer4.0.bn2        |    1.0K              |
|    layer4.0.conv3      |    1.0M              |
|    layer4.0.bn3        |    4.1K              |
|    layer4.0.downsample |    2.1M              |
|   layer4.1             |   4.5M               |
|    layer4.1.conv1      |    1.0M              |
|    layer4.1.bn1        |    1.0K              |
|    layer4.1.conv2      |    2.4M              |
|    layer4.1.bn2        |    1.0K              |
|    layer4.1.conv3      |    1.0M              |
|    layer4.1.bn3        |    4.1K              |
|   layer4.2             |   4.5M               |
|    layer4.2.conv1      |    1.0M              |
|    layer4.2.bn1        |    1.0K              |
|    layer4.2.conv2      |    2.4M              |
|    layer4.2.bn2        |    1.0K              |
|    layer4.2.conv3      |    1.0M              |
|    layer4.2.bn3        |    4.1K              |
|  fc                    |  2.0M                |
|   fc.weight            |   (1000, 2048)       |
|   fc.bias              |   (1000,)            |
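
As a sanity check on what fvcore actually counted, the numbers above can be recomputed by hand in plain Python, without PyTorch. This sketch assumes torchvision's ResNet-50 layout (bias-free convolutions, the stride 2 placed in the 3x3 conv of each stage's first block, a 224x224 input) and counts only the convolutions and the final fc layer, one multiply-add per weight per output position. `conv_stats` and `resnet50_stats` are hypothetical helper names.

```python
from typing import Tuple


def conv_stats(cin: int, cout: int, k: int, hw_out: int) -> Tuple[int, int]:
    """(MACs, weights) of a bias-free k x k conv with a square hw_out x hw_out output."""
    weights = cout * cin * k * k
    return weights * hw_out * hw_out, weights


def resnet50_stats(num_classes: int = 1000, image: int = 224) -> Tuple[int, int, int]:
    """Return (MACs, trainable parameters, total BN channels) for ResNet-50."""
    macs = params = bn_channels = 0

    def add(stats: Tuple[int, int]) -> None:
        nonlocal macs, params
        macs += stats[0]
        params += stats[1]

    add(conv_stats(3, 64, 7, image // 2))  # stem conv, stride 2
    bn_channels += 64
    hw, cin = image // 4, 64               # after the stride-2 max pool

    # (bottleneck width, number of blocks, stride of the first block)
    for width, blocks, stride in [(64, 3, 1), (128, 4, 2), (256, 6, 2), (512, 3, 2)]:
        cout = width * 4
        for i in range(blocks):
            hw_out = hw // (stride if i == 0 else 1)
            add(conv_stats(cin, width, 1, hw))        # 1x1 reduce, input resolution
            add(conv_stats(width, width, 3, hw_out))  # 3x3, carries the stride
            add(conv_stats(width, cout, 1, hw_out))   # 1x1 expand
            bn_channels += width + width + cout
            if i == 0:                                # projection shortcut and its BN
                add(conv_stats(cin, cout, 1, hw_out))
                bn_channels += cout
            cin, hw = cout, hw_out

    macs += cin * num_classes                  # fc: one MAC per weight
    params += cin * num_classes + num_classes  # fc weight + bias
    params += 2 * bn_channels                  # BN gamma + beta (no running statistics)
    return macs, params, bn_channels


macs, params, bn_channels = resnet50_stats()
print(f"{macs:,} MACs")          # 4,089,184,256
print(f"{params:,} parameters")  # 25,557,032
print(f"{2 * bn_channels:,} BN running-statistic values not counted")  # 53,120
```

Convolutions plus the fc layer alone reproduce the reported 4089184256 exactly, which is consistent with BN, pooling and add_ being skipped. Counting each multiply-add as two operations would give about 8.18 G instead, which is why FLOP figures from different tools can differ by a factor of two.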

For more usage, see the documentation in the original project.
