【Deep Learning】THOP+torchstat 计算PyTorch模型的FLOPs,问题记录与解决

文章目录

  • 1. 前言
  • 2.THOP: PyTorch-OpCounter
    • 2.1 安装
    • 2.1 使用方式
    • 2.2 问题与解决
  • 3. torchstat
    • 3.1 安装
    • 3.2 使用方式
    • 2.2 问题与解决

1. 前言

在博客中计算PyTorch模型的FLOPs 一文中,介绍了到了衡量一个深度学习模型大小的指标,尤其是FLOPs,它衡量了一个模型的复杂度。如果计算FLOPs,在下面我们介绍了两款工具。THOP: PyTorch-OpCountertorchstat

源码下载链接:
THOP: PyTorch-OpCounter:https://github.com/Lyken17/pytorch-OpCounter
torchstat:https://github.com/Swall0w/torchstat

2.THOP: PyTorch-OpCounter

2.1 安装

两种安装方式:

(1) pip install thop (now continously intergrated on Github actions)

(2) pip install --upgrade git+https://github.com/Lyken17/pytorch-OpCounter.git

注意:这里强烈推荐使用第二种安装方式!第一种方式安装的不是最新版本。

2.1 使用方式

from torchvision.models import resnet50
from thop import profile
model = resnet50()
input = torch.randn(1, 3, 224, 224)
macs, params = profile(model, inputs=(input, ))

具体使用方式请查阅:https://github.com/Lyken17/pytorch-OpCounter

2.2 问题与解决

在最开始,我使用第一种方式安装:默认安装了0.0.31.post2005241907的版本。由于我使用的pytorch为1.0.0,不支持: nn.SyncBatchNorm(多GPU执行所需的),而THOP:0.031没有对pytorch:1.0.0修复此bug,导致运行出错。

经过查阅发现:
链接:https://github.com/Lyken17/pytorch-OpCounter/pull/88/commits/958f7a4f5140a6e05ee184672b89e3c501b12140

解决方式:
新版本thop 0.0.4.post2009101201对此做出了修复,于是通过第二种方式进行了安装,重新运行,问题解决!
【Deep Learning】THOP+torchstat 计算PyTorch模型的FLOPs,问题记录与解决_第1张图片

3. torchstat

3.1 安装

同样,也是两种安装方式:

(1)pip安装

$ pip install torchstat

(2)源码安装

下载源码后,在根目录执行:

$ python3 setup.py install

3.2 使用方式

from torchstat import stat
import torchvision.models as models

model = models.resnet18()
stat(model, (3, 224, 224))

具体使用方式请查阅:https://github.com/Swall0w/torchstat

2.2 问题与解决

首先,需要说明,我构建的模型,需要在创建时赋予几个超参数,然后将模型复制到cuda上。如下:

from torchstat import stat
from models.pResNet_demo2 import pResNet as pResNet_demo2

model = pResNet_demo2(nums_ResUnit, alpha, num_classes, n_bands, spatial_size, inplanes)
use_cuda = torch.cuda.is_available()
if use_cuda:
    model = model.cuda()
stat(model, (n_bands, spatial_size, spatial_size))

但是在运行时,报错了!只要是因为我的模型是利用gpu运行,而torchstat的源码中没有考虑这一点!!以下是torchstat的部分源码,它在将输入shape喂给模型之前,没有考虑将数据赋给gpu!!
【Deep Learning】THOP+torchstat 计算PyTorch模型的FLOPs,问题记录与解决_第2张图片

解决方式:

参考链接:https://github.com/Swall0w/torchstat/issues/22

修该torchstat源码:model_hook.py

第 22行改为:

x = torch.rand(1, *self._input_size).to('cuda')

第47行改为:

itemsize = input[0].cpu().detach().numpy().itemsize

运行成功:
【Deep Learning】THOP+torchstat 计算PyTorch模型的FLOPs,问题记录与解决_第3张图片

你可能感兴趣的:(深度学习,#,Pytorch,深度学习,python)