PyTorch查看网络模型的参数量params和FLOPs等

在PyTorch中,可以使用torchstat这个库来查看网络模型的一些信息,包括总的参数量params、MAdd、显卡内存占用量和FLOPs等。

使用前需要先安装torchstat包,如下:

pip install torchstat

示例代码如下:

from torchstat import stat
from torchvision.models import resnet50, resnet101, resnet152, resnext101_32x8d

model = resnet50()
stat(model, (3, 224, 224))

如果只是想看模型的总参数量,可以通过如下方式:

total = sum([param.nelement() for param in model.parameters()])
print("Number of parameters: %.2fM" % (total/1e6))

stat打印完整信息如下:

PyTorch查看网络模型的参数量params和FLOPs等_第1张图片

PyTorch查看网络模型的参数量params和FLOPs等_第2张图片

你可能感兴趣的:(指南,pytorch,神经网络,FLOPs,params)