Source code: https://github.com/Lyken17/pytorch-OpCounter
"""
Use the thop counting tool to compute a model's parameter count and computational cost.
author: czjing
source: https://github.com/Lyken17/pytorch-OpCounter
# Install the thop package with the command below:
# pip install thop
"""
# Import the required packages
import torch
import torchvision
from thop import profile
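if __name__ == '__main__':
    # 1. Build the model to profile (an untrained ResNet-50 here)
    model = torchvision.models.resnet50()
    # 2. A dummy input tensor of shape (batch, channels, height, width)
    inputs = torch.randn(1, 3, 200, 200)
    # 3. Run thop. The first returned value is a MAC (multiply-accumulate)
    #    count, even though this script stores and prints it as "flops"
    flops, params = profile(model, inputs=(inputs,))
    print('flops:', flops)
    print('params:', params)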
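To sanity-check what thop counts for a single layer, the sketch below recomputes MACs and parameters by hand for the stem convolution of torchvision's ResNet-50, `Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)`, using PyTorch's standard convolution output-size formula. The helper names are hypothetical, and counting `out_elements × (in_ch / groups) × k × k` MACs with no bias term is an assumption modeled on thop's convolution rule; different thop versions may treat bias differently.

```python
def conv_out_size(size, kernel, stride=1, padding=0, dilation=1):
    """PyTorch Conv2d output size along one spatial dimension."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


def conv2d_macs_params(in_ch, out_ch, kernel, h, w, stride=1, padding=0, groups=1, bias=False):
    """Hypothetical helper: MACs and parameter count of one square-kernel Conv2d."""
    out_h = conv_out_size(h, kernel, stride, padding)
    out_w = conv_out_size(w, kernel, stride, padding)
    per_output = (in_ch // groups) * kernel * kernel  # MACs per output element
    macs = out_ch * out_h * out_w * per_output        # bias additions not counted
    params = out_ch * per_output + (out_ch if bias else 0)
    return macs, params


# ResNet-50 stem conv at the 200x200 input used in the script
print(conv2d_macs_params(3, 64, 7, 200, 200, stride=2, padding=3))  # (94080000, 9408)
# The same layer at the common 224x224 resolution
print(conv2d_macs_params(3, 64, 7, 224, 224, stride=2, padding=3))  # (118013952, 9408)
```

Because a convolution's MACs scale with its output area, profiling the same network at 224×224 instead of 200×200 should give a larger total; this one layer grows from about 94.1M to 118.0M MACs while its 9,408 parameters stay the same.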
Sample output:
"D:\Program Files (x86)\anaconda3\envs\cv\python.exe" E:/Pycharm/handTracking/models/flops_test/flops_test.py
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register count_bn() for <class 'torch.nn.modules.batchnorm.BatchNorm2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.activation.ReLU'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.pooling.MaxPool2d'>.
[WARN] Cannot find rule for <class 'torch.nn.modules.container.Sequential'>. Treat it as zero Macs and zero Params.
[WARN] Cannot find rule for <class 'torchvision.models.resnet.Bottleneck'>. Treat it as zero Macs and zero Params.
[INFO] Register count_adap_avgpool() for <class 'torch.nn.modules.pooling.AdaptiveAvgPool2d'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
[WARN] Cannot find rule for <class 'torchvision.models.resnet.ResNet'>. Treat it as zero Macs and zero Params.
flops: 3517126400.0
params: 25557032.0
Process finished with exit code 0
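The `[WARN] Cannot find rule` lines for `Sequential`, `Bottleneck` and `ResNet` are expected. These are container modules, and as far as I can tell thop builds its totals by summing the leaf layers they hold (Conv2d, BatchNorm2d, Linear, ...), so the containers' zero counts do not shrink the result. One caveat: operations written directly in `forward` rather than as modules, such as the residual addition inside `Bottleneck`, do not appear to be counted. The sketch below models that leaf-sum with a hypothetical `Node` class and made-up numbers; it is an illustration, not thop's code.

```python
from dataclasses import dataclass, field


@dataclass
class Node:
    """Hypothetical stand-in for an nn.Module (not a thop or PyTorch class)."""
    name: str
    macs: int = 0    # own count; only leaf layers with a counting rule set this
    params: int = 0
    children: list = field(default_factory=list)


def total(node):
    """Return (macs, params) summed over the leaves below `node`.

    A container's own count is ignored, mirroring the
    'Treat it as zero Macs and zero Params' warning.
    """
    if not node.children:
        return node.macs, node.params
    macs = params = 0
    for child in node.children:
        m, p = total(child)
        macs += m
        params += p
    return macs, params


# Toy tree with made-up numbers, shaped like ResNet -> Sequential -> Bottleneck
block = Node("Bottleneck", children=[Node("conv_a", 100, 10), Node("conv_b", 200, 20)])
model = Node("ResNet", children=[Node("conv1", 50, 5), Node("layer1", children=[block])])
print(total(model))  # (350, 35)
```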
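The raw numbers in the log are easier to read with unit suffixes. thop ships a `clever_format` helper for this; the dependency-free sketch below does something similar with a hypothetical `human` function. It also derives an approximate FLOP figure, assuming the common convention that 1 MAC ≈ 2 FLOPs, and the size of ResNet-50's weights stored as float32 (4 bytes per parameter).

```python
def human(n, suffixes=("", "K", "M", "G", "T")):
    """Scale n by powers of 1000 and attach a unit suffix (hypothetical helper)."""
    i = 0
    while abs(n) >= 1000 and i < len(suffixes) - 1:
        n /= 1000.0
        i += 1
    return f"{n:.3f}{suffixes[i]}"


macs, params = 3517126400.0, 25557032.0  # values from the sample output
print("MACs:", human(macs))                               # 3.517G
print("FLOPs (~2 x MACs):", human(2 * macs))              # 7.034G
print("params:", human(params))                           # 25.557M
print("fp32 weights: %.2f MiB" % (params * 4 / 2**20))    # 97.49 MiB
```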