pytorch: 计算网络模型的计算量(FLOPs)和参数量(Params)

计算量:
FLOPS,浮点运算次数,指运行一次网络模型需要进行浮点运算的次数。
参数量:
Params,是指网络模型中需要训练的参数总数。

第一步:安装模块(thop)

pip install thop

第二步:计算

import torch
from thop import profile

net = Model()  # 定义好的网络模型
input = torch.randn(1, 3, 112, 112)
flops, params = profile(net, (inputs,))
print('flops: ', flops, 'params: ', params)

注意:

  • 输入input的第一维度是批量(batch size),批量的大小不回影响参数量, 计算量是batch_size=1的倍数
  • profile(net, (inputs,))的 (inputs,)中必须加上逗号,否者会报错

你可能感兴趣的:(pytorch,pytorch,深度学习)