pytorch中查看模型的参数量和计算量

做完剪枝后,需要看一个模型被压缩后的计算量和参数量,可以使用这两种方法。

1.安装包

pip install torchsummary
pip install torchstat

2.编写代码

torchsummary库:统计中只有参数量

import torch
from torchsummary import summary
from vgg import vgg
model_path = './pruned.pth.tar'
checkpoint = torch.load(model_path)
#model = vgg() 
model = vgg(cfg = checkpoint['cfg'])
model.load_state_dict(checkpoint['state_dict'])
model.to('cuda')
summary(model,(3,32,32))

torchstat库:统计参数量和计算量,较为详细

import torch
from torchstat import stat
from vgg import vgg

model_path = "./model_best.pth.tar"
checkpoint = torch.load(model_path)
#model = vgg(cfg = checkpoint['cfg'])
model = vgg()
model.load_state_dict(checkpoint['state_dict'])
stat(model,(3,32,32))

你可能感兴趣的:(模型训练,深度学习)