在神经网络结构可视化这一块,有很多比较实用的工具,但目前来说我看的比较舒服的还是微软小哥开发的Netron软件。
最近又遇到了一些问题,在可视化yolov5结构的时候,使用官方自带的export.py
导出的结构图跟设计的有些出入,使用的是torch.onnx.export
方法,将模型导出为onnx
格式,再使用Netron打开,这种方式虽然确实可以,但是在可视化yolov5结构的时候,卷积层模块莫名其妙出现了残差结构,同样在可视化简单模型的时候,BN层会莫名其妙不见了,这让我费解,所以使用torch.onnx.export
方法不是很靠谱。
目前我认为最可靠的方法就是使用torch.jit.trace
把模型保存为.pt
文件,然后在使用Netron进行打开。
直接上代码,我封装成了一个函数:
import torch
import torch.nn as nn
def to_pt(net, x, name="temp"):
"""
@net :保存的模型
@x :模型的输入
@name :保存名称
"""
path = f"./{name}.pt"
# 主要就是下面两行代码
script_model = torch.jit.trace(net, x)
script_model.save(path)
具体为什么呢,torch.jit.trace
是Troch Script模块中的函数,其中trace是跟踪执行步骤,记录模型调用推理时执行的每个步骤,主要就是整个模型保存结构的同时,还跟踪了输入,将前向传播过程给记录了下了,大概就是这么个意思(我也不是很熟悉这个模块,详情还是去度娘吧)。
举例:
import torch
import torch.nn as nn
class BN_Conv2d(nn.Module):
"""
Conv2d + BN + ReLU
"""
def __init__(self, c1, c2, k, s=1, p=0, d=1, g=1):
super(BN_Conv2d, self).__init__()
self.conv = nn.Conv2d(c1, c2, k, s, p, d, g)
self.bn = nn.BatchNorm2d(c2)
self.relu = nn.ReLU()
def forward(self, x):
return self.relu(self.bn(self.conv(x)))
net = BN_Conv2d(3, 16, 3, 2)
x = torch.randn((1, 3, 32, 32))
to_pt(net, x, "BN_Conv2d")
class ResBlock(nn.Module):
def __init__(self, ch_in, ch_out, stride=1):
super(ResBlock, self).__init__()
self.seq = nn.Sequential(
nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=stride, padding=1),
nn.BatchNorm2d(ch_out),
nn.ReLU(),
nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(ch_out)
)
self.extra = nn.Sequential()
if ch_in != ch_out:
self.extra = nn.Sequential(
nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=stride)
)
self.relu = nn.ReLU()
def forward(self, x):
out = self.seq(x)
x = self.extra(x)
out = out + x
out = self.relu(out)
return out
net = ResBlock(3, 64, 1)
x = torch.randn((1, 3, 224, 224))
to_pt(net, x, "ResBlock")
Yolov5:该文件放在yolov5文件下,图太长了就简单放一部分。
import torch
from models.yolo import Model
import os
def to_onnx_cfg(cfg, x, name="temp"):
# Create model
model = Model(cfg).to("cpu")
# x = torch.randn(1, 3, 640, 640).to("cpu")
script_model = torch.jit.trace(model, x)
if not os.path.exists(os.getcwd() + "/onnx"):
os.makedirs(os.getcwd() + "/onnx")
# torch.onnx.export(model, x, f"./{name}.onnx")
script_model.save(f"./onnx/{name}.pt")
x = torch.randn(1, 3, 640, 640).to("cpu")
to_onnx_cfg("./yolov5s.yaml", x, "yolov5s")
以上就是这些,对比图的话,之后我再放上来。
参考博客:
yolov5s.yaml中各参数作用意义及使用netron工具来可视化yolov5s的结构
pytorch JIT浅解析
TORCH.JIT理解