Calculating the FLOPs and parameter count of a PyTorch model

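When designing a neural network, the parameter count and the amount of compute have to be weighed against the target hardware, so you need to measure the model's params and FLOPs. The third-party Python library ptflops does this well. Note that ptflops actually counts MACs (multiply-accumulate operations, printed as GMac); by a common convention one MAC corresponds to roughly two FLOPs. Here is the code: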

import torch.nn as nn
import torch
from ptflops import get_model_complexity_info
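class Net(nn.Module):
    def __init__(self, in_c, class_num):
        super().__init__()
        # Feature extractor for a 28x28 input (e.g. MNIST)
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_c, 16, 3, stride=1, padding=1),  # in_c x 28 x 28 -> 16 x 28 x 28
            nn.MaxPool2d(2, 2),                           # -> 16 x 14 x 14
            nn.Conv2d(16, 64, 3, stride=1, padding=1),    # -> 64 x 14 x 14
            nn.MaxPool2d(2, 2)                            # -> 64 x 7 x 7
        )
        # Classifier: flattened 64*7*7 = 3136 features -> class_num outputs
        self.fc = nn.Sequential(
            nn.Linear(64 * 7 * 7, 120),
            nn.Linear(120, 84),
            nn.Linear(84, class_num)
        )

    def forward(self, x):
        out = self.conv1(x)
        out = out.view(out.size(0), -1)  # flatten to (batch, 3136)
        out = self.fc(out)
        return out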

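net = Net(1, 10)
# Input resolution is given without the batch dimension: (C, H, W).
# The first return value is a MAC count, not FLOPs.
macs, params = get_model_complexity_info(net, (1, 28, 28), as_strings=True,
                                         print_per_layer_stat=True, verbose=True)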

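With print_per_layer_stat=True, the statistics of every layer are printed (verbose=True additionally reports modules that ptflops treats as zero-op):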

Warning: module Net is treated as a zero-op.
Net(
  0.397 M, 100.000% Params, 0.002 GMac, 100.000% MACs, 
  (conv1): Sequential(
    0.009 M, 2.378% Params, 0.002 GMac, 83.561% MACs, 
    (0): Conv2d(0.0 M, 0.040% Params, 0.0 GMac, 5.322% MACs, 1, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): MaxPool2d(0.0 M, 0.000% Params, 0.0 GMac, 0.532% MACs, kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (2): Conv2d(0.009 M, 2.338% Params, 0.002 GMac, 77.174% MACs, 16, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): MaxPool2d(0.0 M, 0.000% Params, 0.0 GMac, 0.532% MACs, kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (fc): Sequential(
    0.387 M, 97.622% Params, 0.0 GMac, 16.439% MACs, 
    (0): Linear(0.376 M, 94.846% Params, 0.0 GMac, 15.972% MACs, in_features=3136, out_features=120, bias=True)
    (1): Linear(0.01 M, 2.561% Params, 0.0 GMac, 0.431% MACs, in_features=120, out_features=84, bias=True)
    (2): Linear(0.001 M, 0.214% Params, 0.0 GMac, 0.036% MACs, in_features=84, out_features=10, bias=True)
  )
)
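To see where these numbers come from, the per-layer totals can be reproduced by hand. The sketch below is a plain-Python calculation (no torch or ptflops needed) that assumes the counting conventions the table above appears to follow: a convolution costs k·k·C_in·C_out MACs per output position plus one per output element for the bias, a linear layer costs in·out + out, and max pooling is charged one op per input element. The helper names (conv2d_stats, linear_stats, maxpool_stats) are hypothetical, and the FLOPs line simply assumes 1 MAC ≈ 2 FLOPs.

```python
# Hand-computed params and MACs for Net(1, 10) on a 1x28x28 input.
# Conventions assumed to mirror the ptflops table (bias adds counted,
# max pooling charged one op per input element).

def conv2d_stats(c_in, c_out, k, h_out, w_out, bias=True):
    """Return (params, macs) for a Conv2d layer with the given output size."""
    params = c_out * c_in * k * k + (c_out if bias else 0)
    positions = h_out * w_out
    macs = c_out * c_in * k * k * positions   # multiply-accumulates
    if bias:
        macs += c_out * positions             # one add per output element
    return params, macs


def linear_stats(n_in, n_out, bias=True):
    """Return (params, macs) for a Linear layer with batch size 1."""
    params = n_in * n_out + (n_out if bias else 0)
    macs = n_in * n_out + (n_out if bias else 0)
    return params, macs


def maxpool_stats(c, h_in, w_in):
    """Max pooling has no weights; charge one op per input element."""
    return 0, c * h_in * w_in


layers = [
    ("conv1.0 Conv2d", conv2d_stats(1, 16, 3, 28, 28)),
    ("conv1.1 MaxPool2d", maxpool_stats(16, 28, 28)),
    ("conv1.2 Conv2d", conv2d_stats(16, 64, 3, 14, 14)),
    ("conv1.3 MaxPool2d", maxpool_stats(64, 14, 14)),
    ("fc.0 Linear", linear_stats(64 * 7 * 7, 120)),
    ("fc.1 Linear", linear_stats(120, 84)),
    ("fc.2 Linear", linear_stats(84, 10)),
]

total_params = sum(p for _, (p, _) in layers)
total_macs = sum(m for _, (_, m) in layers)

for name, (p, m) in layers:
    print(f"{name:18s} params={p:7d}  MACs={m:8d}  ({100 * m / total_macs:.3f}% MACs)")

print(f"total params: {total_params} (~{total_params / 1e6:.3f} M)")
print(f"total MACs:   {total_macs} (~{total_macs / 1e9:.3f} GMac)")
print(f"approx FLOPs: {2 * total_macs} (assuming 1 MAC = 2 FLOPs)")
```

The totals, 396,894 parameters and 2,356,862 MACs, round to the 0.397 M and 0.002 GMac in the table, and the per-layer MAC percentages match it as well. The breakdown also shows why the two columns differ: fc.0 holds about 95% of the parameters, while the second convolution accounts for about 77% of the compute.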
