Pytorch 模型参数量和浮点数运算数统计 方法总结

使用了四个包进行测试:
torchstat.stat
torchsummary
thop
ptflops

import torch.nn as nn

from nets.classifier import Resnet50RoIHead, VGG16RoIHead
from nets.resnet50 import resnet50
from nets.rpn import RegionProposalNetwork
from nets.vgg16 import decom_vgg16
import os
import datetime

import torch
from ptflops import get_model_complexity_info
from torchstat import stat
import torchsummary
from thop import profile
from thop import clever_format

anchors_size = [8, 16, 32]
backbone = "resnet50"
pretrained = False
net = FasterRCNN(2, anchor_scales=anchors_size, backbone=backbone, pretrained=pretrained)
################################################## stat
stat(net, (3, 608, 608))
##################################################

################################################## torchsummary
torchsummary.summary(net.cuda(), (3, 608, 608))
##################################################

################################################## thop
myinput = torch.zeros((1, 3, 608, 608)).to('cuda')
flops, params = profile(net.to('cuda'), inputs=(myinput,))
flops, params = clever_format([flops, params], "%.3f")
print(flops, params)
##################################################

################################################## ptflops
with torch.cuda.device(0):
    net = retinanet(2, phi=2)
    macs, params = get_model_complexity_info(net,
                                             (3, 608, 608),
                                             # (3, 224, 224),
                                             as_strings=True,
                                             print_per_layer_stat=True, verbose=True)
    print('{:<30}  {:<8}'.format('Computational complexity: ', macs))
    print('{:<30}  {:<8}'.format('Number of parameters: ', params))
total = sum([param.nelement() for param in net.parameters()])
print("Number of parameters: %.2fM" % (total / 1e6))
##################################################

你可能感兴趣的:(小工具,pytorch,深度学习,python)