在pytorch环境下,有两个计算FLOPs和参数量的包thop和ptflops,结果基本是一致的。
参考https://github.com/Lyken17/pytorch-OpCounter
安装方法:pip install thop
使用方法:
from torchvision.models import resnet18
from thop import profile
model = resnet18()
input = torch.randn(1, 3, 224, 224) #模型输入的形状,batch_size=1
flops, params = profile(model, inputs=(input, ))
print(flops/1e9,params/1e6) #flops单位G,para单位M
用来测试3d resnet18的FLOPs:
model =C3D_Hash_Model(48)
input = torch.randn(1, 3,10, 112, 112) #视频取10帧
flops, params = profile(model, inputs=(input, ))
print(flops/1e9,params/1e6)
参考https://github.com/sovrasov/flops-counter.pytorch
安装方法:pip install ptflops
或者 pip install git+https://github.com/sovrasov/flops-counter.pytorch.git
使用方法:
import torchvision.models as models
import torch
from ptflops import get_model_complexity_info
with torch.cuda.device(0):
net = models.resnet18()
flops, params = get_model_complexity_info(net, (3, 224, 224), as_strings=True, print_per_layer_stat=True) #不用写batch_size大小,默认batch_size=1
print('Flops: ' + flops)
print('Params: ' + params)
用来测试3d resnet18的FLOPs:
import torch
from ptflops.flops_counter import get_model_complexity_info
with torch.cuda.device(0):
net = C3D_Hash_Model(48)
flops, params = get_model_complexity_info(net, (3,10, 112, 112), as_strings=True, print_per_layer_stat=True)
print('Flops: ' + flops)
print('Params: ' + params)
如果安装ptflops出问题,可以直接到https://github.com/sovrasov/flops-counter.pytorch.git下载代码,然后直接把目录ptflops复制到项目代码中,通过from ptflops.flops_counter import get_model_complexity_info来调用函数计算FLOPs。
(待补充!)
注:
本文内容转载自https://blog.csdn.net/weixin_41519463/article/details/102468868,在此感谢。如若转载,请备注原博客链接。