获取网络模型的每一层参数量与计算量(Flops)———Pytorch

一、前言

       在现在AI各种技术需要luo地的时期,网络模型大小能够满足嵌入式平台极为重要,不仅仅需要看模型效果,也要看模型的计算量与参数,所以在评估模型的时候就要分析网络的参数量与计算量;

二、推荐pytorch工具

      1、ptflops

            安装: pip3 install --upgrade git+https://github.com/sovrasov/flops-counter.pytorch.git

           使用:具体使用参照 https://github.com/sovrasov/flops-counter.pytorch,在第三节我也贴上了我使用代码,有具体解释与注意事项,可看后面

           评价:博主就是用的这个,他可以直接看每一层的参数量与计算量,每一层的参数计算量占模型所有参数的百分比,且有每一层的卷积步长、核大小、输入输出通道数量等,并且也会输出总量;推荐

     2、thop

           安装:pip3 install thop    或者  pip3 install --upgrade git+https://github.com/Lyken17/pytorch-OpCounter.git

           使用:同样也可参照官方:https://github.com/Lyken17/pytorch-OpCounter,也可以看我使用的代码

          评价:也是很多人用的,可以自定义网络层计算,可看官方使用方法,也可以输出参数的总量与计算量,但是博主没发现如何输出每一层的计算量参数量,但是看到了借口,改源码可能需要编译,但是没时间就没弄了,以后有时间再研究;

三、使用实例

     1、ptflops使用

from torchvision.models import resnet50
import torch
import torchvision.models as models
# import torch
from ptflops import get_model_complexity_info

# model = models.resnet50()   #调用官方的模型,
checkpoints = '自己模型的path'
model = torch.load(checkpoints)
model_name = 'yolov3 cut'
flops, params = get_model_complexity_info(model, (3,320,320),as_strings=True,print_per_layer_stat=True)
print("%s |%s |%s" % (model_name,flops,params))

#注意,这里输入一定是要tuple类型,且不需要输入batch,直接输入输入通道数量与尺寸,如(3,320,320)  320为网络输入尺寸;print_per_layer_stat这个参数表示是否打印每一层的参数量与计算量;输出为网络模型的总参数量(单位M,即百万)与计算量(单位G,即十亿)

输出样例;

获取网络模型的每一层参数量与计算量(Flops)———Pytorch_第1张图片

2、thop使用

from torchvision.models import resnet50
from thop import profile

# model = resnet50()
checkpoints = '模型path'
model = torch.load(checkpoints)
model_name = 'yolov3 cut asff'
input = torch.randn(1, 3, 224, 224)
flops, params = profile(model, inputs=(input, ),verbose=True)
print("%s | %.2f | %.2f" % (model_name, params / (1000 ** 2), flops / (1000 ** 3)))#这里除以1000的平方,是为了化成M的单位,

注意:输入必须是四维的;profile函数中verbose参数,设置为true,我认为就是输出每一层的参数量与计算量的,但是没输出,下方代码为profile函数的源码,可看到31行verbose参数,待研究;

def profile(model, inputs, custom_ops=None, verbose=True):
    handler_collection = []
    if custom_ops is None:
        custom_ops = {}

    def add_hooks(m):
        if len(list(m.children())) > 0:
            return

        if hasattr(m, "total_ops") or hasattr(m, "total_params"):
            logger.warning("Either .total_ops or .total_params is already defined in %s. "
                           "Be careful, it might change your code's behavior." % str(m))

        m.register_buffer('total_ops', torch.zeros(1))
        m.register_buffer('total_params', torch.zeros(1))

        for p in m.parameters():
            m.total_params += torch.Tensor([p.numel()])

        m_type = type(m)
        fn = None
        if m_type in custom_ops:  # if defined both op maps, use custom_ops to overwrite.
            fn = custom_ops[m_type]
        elif m_type in register_hooks:
            fn = register_hooks[m_type]

        if fn is None:
            if verbose:
                logger.info("THOP has not implemented counting method for ", m)
        else:
            if verbose:
                logger.info("Register FLOP counter for module %s" % str(m))
            handler = m.register_forward_hook(fn)
            handler_collection.append(handler)

    training = model.training

    model.eval()
    model.apply(add_hooks)

    with torch.no_grad():
        model(*inputs)

    total_ops = 0
    total_params = 0
    for m in model.modules():
        if len(list(m.children())) > 0:  # skip for non-leaf module
            continue
        total_ops += m.total_ops
        total_params += m.total_params

    total_ops = total_ops.item()
    total_params = total_params.item()

    # reset model to original status
    model.train(training)
    for handler in handler_collection:
        handler.remove()

    # remove temporal buffers
    for n, m in model.named_modules():
        if len(list(m.children())) > 0:
            continue
        if "total_ops" in m._buffers:
            m._buffers.pop("total_ops")
        if "total_params" in m._buffers:
            m._buffers.pop("total_params")

    return total_ops, total_params

提高输出可读性

       加入一下代码

from thop import clever_format
macs, params = clever_format([flops, params], "%.3f")

    

你可能感兴趣的:(Pytorch)