torchsummary 中input size 异常的问题


本文解决问题

torchsummary针对多个输入模型的时候,其输出信息中input size等存在着错误,这里提供方案解决这个错误。


当我们使用pytorch搭建好我们自己的深度学习模型的的时候,我们总想看看具体的网络信息以及参数量大小,这时候就要请出我们的神器 torchsummary了,torchsummary的简单使用如下所示:

# pip install torchsummary
from torchsummary import summary

model = OurOwnModel()
summary(model, input_size=(3, 224, 224), device='cpu')

此时一切正常的话将会输出下面的信息:

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [-1, 64, 224, 224]           1,792
              ReLU-2         [-1, 64, 224, 224]               0
            Conv2d-3         [-1, 64, 224, 224]          36,928
              ReLU-4         [-1, 64, 224, 224]               0
         MaxPool2d-5         [-1, 64, 112, 112]               0
            Conv2d-6        [-1, 128, 112, 112]          73,856
              ReLU-7        [-1, 128, 112, 112]               0
            Conv2d-8        [-1, 128, 112, 112]         147,584
              ReLU-9        [-1, 128, 112, 112]               0
        MaxPool2d-10          [-1, 128, 56, 56]               0
           Conv2d-11          [-1, 256, 56, 56]         295,168
             ReLU-12          [-1, 256, 56, 56]               0
           Conv2d-13          [-1, 256, 56, 56]         590,080
             ReLU-14          [-1, 256, 56, 56]               0
           Conv2d-15          [-1, 256, 56, 56]         590,080
             ReLU-16          [-1, 256, 56, 56]               0
        MaxPool2d-17          [-1, 256, 28, 28]               0
           Conv2d-18          [-1, 512, 28, 28]       1,180,160
             ReLU-19          [-1, 512, 28, 28]               0
           Conv2d-20          [-1, 512, 28, 28]       2,359,808
             ReLU-21          [-1, 512, 28, 28]               0
           Conv2d-22          [-1, 512, 28, 28]       2,359,808
             ReLU-23          [-1, 512, 28, 28]               0
        MaxPool2d-24          [-1, 512, 14, 14]               0
           Conv2d-25          [-1, 512, 14, 14]       2,359,808
             ReLU-26          [-1, 512, 14, 14]               0
           Conv2d-27          [-1, 512, 14, 14]       2,359,808
             ReLU-28          [-1, 512, 14, 14]               0
           Conv2d-29          [-1, 512, 14, 14]       2,359,808
             ReLU-30          [-1, 512, 14, 14]               0
        MaxPool2d-31            [-1, 512, 7, 7]               0
           Linear-32                 [-1, 4096]     102,764,544
             ReLU-33                 [-1, 4096]               0
          Dropout-34                 [-1, 4096]               0
           Linear-35                 [-1, 4096]      16,781,312
             ReLU-36                 [-1, 4096]               0
          Dropout-37                 [-1, 4096]               0
           Linear-38                 [-1, 1000]       4,097,000
================================================================
Total params: 138,357,544
Trainable params: 138,357,544
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 218.59
Params size (MB): 527.79
Estimated Total Size (MB): 746.96
----------------------------------------------------------------

你发现一切安好,nice。但是当你像我一样开始搭建一个多输入网络的时候,这时候麻烦就来了。

from torchsummary import summary

model = OurOwnModel()
summary(model, input_size=[(3, 224, 224), (3, 224, 224), (3, 123)], device='cpu')

此时输出的信息就会有错误了。

# 上面正确的信息省略了
================================================================
Total params: 49,365,761
Trainable params: 49,365,761
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 25169045225472.00  # 输入的大小显然不对啊
Forward/backward pass size (MB): 22975.86
Params size (MB): 188.32
Estimated Total Size (MB): 25169045248636.18 # 看起来整个数据也是显然有错误的
----------------------------------------------------------------

上面的 Input Size(MB) Estimated Total Size (MB)这两项显然是有错误的。

这里提供如下的解决办法:

import torchsummary
print(torchsummary.__file__)

上面代码会输出torchsummary的安装路径,这里得到的如下:

/home/guangkun/anaconda3/envs/jet/lib/python3.7/site-packages/torchsummary/__init__.py

我们知道了torchsummary的地址之后,进入该文件夹,同级目录如下:

├── __init__.py
├── __pycache__
│   ├── __init__.cpython-37.pyc
│   └── torchsummary.cpython-37.pyc
└── torchsummary.py

修改 torchsummary.py文件(大概在100行-103行):

  total_input_size = abs(np.prod(input_size) * batch_size * 4. / (1024 ** 2.))
  total_output_size = abs(2. * total_output * 4. / (1024 ** 2.))  # x2 for gradients
  total_params_size = abs(total_params.numpy() * 4. / (1024 ** 2.))
  total_size = total_params_size + total_output_size + total_input_size

修改为:

total_input_size = abs(np.sum([np.prod(in_tuple) for in_tuple in input_size]) * batch_size * 4. / (1024 ** 2.))
total_output_size = abs(2. * total_output * 4. / (1024 ** 2.))  # x2 for gradients
total_params_size = abs(total_params.numpy() * 4. / (1024 ** 2.))
total_size = total_params_size + total_output_size + total_input_size

保存后再运行即可发现正常了,正常的输出信息如下:

================================================================
Total params: 49,365,761
Trainable params: 49,365,761
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 1.64
Forward/backward pass size (MB): 179.50
Params size (MB): 188.32
Estimated Total Size (MB): 369.45
----------------------------------------------------------------

你可能感兴趣的:(torchsummary 中input size 异常的问题)