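(1) Visualizing the network structure, as shown in the red box in the figure below
(2) Visualizing feature-map shapes and counting parameters
This uses the torchsummary package, which can be installed with: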
pip install torchsummary
Example:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchsummary import summary
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)   # 1x28x28 -> 10x24x24
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)  # 10x12x12 -> 20x8x8
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)                  # 320 = 20 * 4 * 4
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))                   # -> 10x12x12
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))  # -> 20x4x4
        x = x.view(-1, 320)                                          # flatten
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # torch.device requires PyTorch >= 0.4.0
model = Net().to(device)
summary(model, (1, 28, 28))
Output:
# Columns: layer type, output feature-map shape (-1 stands for the batch dimension), number of parameters in that layer
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1           [-1, 10, 24, 24]             260
            Conv2d-2             [-1, 20, 8, 8]           5,020
         Dropout2d-3             [-1, 20, 8, 8]               0
            Linear-4                   [-1, 50]          16,050
            Linear-5                   [-1, 10]             510
================================================================
Total params: 21,840    # parameter count of the whole model (sum of the per-layer counts above)
Trainable params: 21,840
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00    # size of the (preprocessed) input tensor
Forward/backward pass size (MB): 0.06    # estimated memory for the layer outputs and their gradients in one forward/backward pass
Params size (MB): 0.08    # 21,840 float32 parameters x 4 bytes
Estimated Total Size (MB): 0.15
----------------------------------------------------------------
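The Output Shape and Param # columns above can be reproduced by hand. The pure-Python sketch below (no PyTorch needed) recomputes them for `Net`, assuming stride-1 unpadded convolutions and 2x2 max pooling as in the model definition. The size estimates follow our reading of the pytorch-summary source (4 bytes per float32 value, layer outputs counted twice to account for gradients, 1 MB = 1024² bytes); treat that as an approximation rather than a documented API. All helper names (`conv2d_params`, `net_summary`, `size_estimates`, ...) are hypothetical.

```python
from math import prod

BYTES_PER_FLOAT = 4   # float32
MB = 1024 ** 2        # sizes are reported in units of 1024 * 1024 bytes


def conv2d_params(c_in, c_out, k):
    # c_out kernels of shape c_in x k x k, plus one bias per output channel
    return c_out * c_in * k * k + c_out


def linear_params(n_in, n_out):
    # n_out x n_in weight matrix, plus one bias per output feature
    return n_out * n_in + n_out


def conv_out(size, k):
    # spatial size after a stride-1, padding-0 convolution
    return size - k + 1


def net_summary(h=28):
    """Return (name, output_shape, params) rows for Net on a 1 x h x h input."""
    h1 = conv_out(h, 5)        # conv1: 28 -> 24
    h2 = conv_out(h1 // 2, 5)  # max_pool2d(2): 24 -> 12, then conv2: 12 -> 8
    flat = 20 * (h2 // 2) ** 2  # after the second pooling: 20 * 4 * 4 = 320
    return [
        ("Conv2d-1", (10, h1, h1), conv2d_params(1, 10, 5)),
        ("Conv2d-2", (20, h2, h2), conv2d_params(10, 20, 5)),
        ("Dropout2d-3", (20, h2, h2), 0),  # dropout has no parameters
        ("Linear-4", (50,), linear_params(flat, 50)),
        ("Linear-5", (10,), linear_params(50, 10)),
    ]


def size_estimates(rows, input_shape=(1, 28, 28)):
    """Approximate the MB figures printed at the bottom of the summary."""
    total_params = sum(p for _, _, p in rows)
    total_output = sum(prod(shape) for _, shape, _ in rows)
    input_mb = prod(input_shape) * BYTES_PER_FLOAT / MB
    fwd_bwd_mb = 2 * total_output * BYTES_PER_FLOAT / MB  # outputs + their gradients
    params_mb = total_params * BYTES_PER_FLOAT / MB
    return {
        "input": input_mb,
        "forward_backward": fwd_bwd_mb,
        "params": params_mb,
        "total": input_mb + fwd_bwd_mb + params_mb,
    }


if __name__ == "__main__":
    rows = net_summary()
    for name, shape, params in rows:
        print(f"{name:>20}  {str([-1, *shape]):>25} {params:>15,}")
    print(f"Total params: {sum(p for _, _, p in rows):,}")
    for key, value in size_estimates(rows).items():
        print(f"{key} (MB): {value:.2f}")
```

Running it prints the same per-layer rows, a total of 21,840 parameters, and the 0.00 / 0.06 / 0.08 / 0.15 MB estimates shown above.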
References:
https://blog.csdn.net/MasterCayman/article/details/118693319
https://github.com/sksq96/pytorch-summary
(1) Install thop, a package for counting a model's operations and parameters
pip install thop
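(2) Basic usage. thop's profile runs the model once on a dummy input and returns the operation count and the parameter count. Note that the operation count is in multiply-accumulate operations (MACs); FLOPs are often approximated as 2 x MACs.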
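import torch
from torchvision.models import resnet50
from thop import profile

model = resnet50()
dummy_input = torch.randn(1, 3, 224, 224)  # a batch of one 3-channel 224x224 image
macs, params = profile(model, inputs=(dummy_input,))  # thop counts multiply-accumulates (MACs)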
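To see roughly what such an operation count measures, the sketch below applies simplified counting rules to the small `Net` from the torchsummary example: each output element of a convolution costs (input channels x kernel height x kernel width) multiply-accumulates, and each output of a linear layer costs one dot product over its inputs. This is an assumption for illustration; biases, pooling and activations are ignored here, whereas thop's exact per-layer rules depend on its version. The function names are hypothetical, and FLOPs are approximated as 2 x MACs.

```python
def conv2d_macs(c_in, c_out, k, h_out, w_out, groups=1):
    # every output element needs (c_in / groups) * k * k multiply-accumulates
    return c_out * h_out * w_out * (c_in // groups) * k * k


def linear_macs(n_in, n_out):
    # each of the n_out outputs is a dot product of length n_in
    return n_in * n_out


# Net from the torchsummary example, 1 x 28 x 28 input (shapes from its summary)
layer_macs = {
    "conv1": conv2d_macs(1, 10, 5, 24, 24),
    "conv2": conv2d_macs(10, 20, 5, 8, 8),
    "fc1": linear_macs(320, 50),
    "fc2": linear_macs(50, 10),
}
total_macs = sum(layer_macs.values())
approx_flops = 2 * total_macs  # one multiply + one add per MAC

print(layer_macs)
print("MACs:", total_macs, "approx. FLOPs:", approx_flops)
```

Under these rules the convolutions dominate: conv2 alone accounts for 320,000 of the 480,500 MACs, even though fc1 holds most of the parameters.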
References: https://www.jianshu.com/p/6514b8fb1ada
https://zhuanlan.zhihu.com/p/337810633
Main reference: https://blog.csdn.net/junmuzi/article/details/83109660?utm_medium=distribute.pc_relevant.none-task-blog-BlogCommendFromMachineLearnPai2-2.control&depth_1-utm_source=distribute.pc_relevant.none-task-blog-BlogCommendFromMachineLearnPai2-2.control
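import torch
from torchvision.models import googlenet
from thop import profile

model = googlenet()
model.eval()  # count the inference graph; in training mode GoogLeNet also runs its auxiliary classifiers
dummy_input = torch.randn(1, 3, 224, 224)
macs, params = profile(model, inputs=(dummy_input,))  # operation count is in MACs
print("MACs:", macs)
print("params:", params)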
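The raw numbers printed above are large floats that are hard to read. thop provides `clever_format` for this (its README uses `clever_format([macs, params], "%.3f")`). The standalone `human_readable` function below is a hypothetical equivalent, sketched only to show the unit scaling (K = 10^3, M = 10^6, G = 10^9, T = 10^12); it is not thop's implementation.

```python
def human_readable(num, precision=3):
    """Scale a raw count to K / M / G / T with a fixed number of decimals."""
    for unit, scale in (("T", 1e12), ("G", 1e9), ("M", 1e6), ("K", 1e3)):
        if abs(num) >= scale:
            return f"{num / scale:.{precision}f}{unit}"
    return f"{num:.{precision}f}"  # small counts are printed unscaled


print(human_readable(480_500))     # MAC count of the small Net under simplified rules
print(human_readable(25_557_032))  # a count in the tens of millions
```

Used with the profiling code above, this would be `print("MACs:", human_readable(macs))`.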
Reference: https://www.jianshu.com/p/cbada26ea29d?from=groupmessage