【Pytorch】netron可视化——靠谱的使用方法

在神经网络结构可视化这一块,有很多比较实用的工具,但目前来说我看的比较舒服的还是微软小哥开发的Netron软件。

最近又遇到了一些问题,在可视化yolov5结构的时候,使用官方自带的export.py导出的结构图跟设计的有些出入,使用的是torch.onnx.export方法,将模型导出为onnx格式,再使用Netron打开,这种方式虽然确实可以,但是在可视化yolov5结构的时候,卷积层模块莫名其妙出现了残差结构,同样在可视化简单模型的时候,BN层会莫名其妙不见了,这让我费解,所以使用torch.onnx.export方法不是很靠谱。


目前我认为最可靠的方法就是使用torch.jit.trace把模型保存为.pt文件,然后在使用Netron进行打开。

直接上代码,我封装成了一个函数:

import torch
import torch.nn as nn

def to_pt(net, x, name="temp"):
    """
    @net	:保存的模型
    @x		:模型的输入
    @name	:保存名称
    """
    path = f"./{name}.pt"
    # 主要就是下面两行代码
    script_model = torch.jit.trace(net, x)
    script_model.save(path)

具体为什么呢,torch.jit.trace是Troch Script模块中的函数,其中trace是跟踪执行步骤,记录模型调用推理时执行的每个步骤,主要就是整个模型保存结构的同时,还跟踪了输入,将前向传播过程给记录了下了,大概就是这么个意思(我也不是很熟悉这个模块,详情还是去度娘吧)。

举例:

import torch
import torch.nn as nn

class BN_Conv2d(nn.Module):
    """
    Conv2d + BN + ReLU
    """

    def __init__(self, c1, c2, k, s=1, p=0, d=1, g=1):
        super(BN_Conv2d, self).__init__()
        self.conv = nn.Conv2d(c1, c2, k, s, p, d, g)
        self.bn = nn.BatchNorm2d(c2)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))
    
net = BN_Conv2d(3, 16, 3, 2)
x = torch.randn((1, 3, 32, 32))
to_pt(net, x, "BN_Conv2d")

【Pytorch】netron可视化——靠谱的使用方法_第1张图片

class ResBlock(nn.Module):
    def __init__(self, ch_in, ch_out, stride=1):
        super(ResBlock, self).__init__()
        self.seq = nn.Sequential(
            nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=stride, padding=1),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(),
            nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(ch_out)
        )
        self.extra = nn.Sequential()
        if ch_in != ch_out:
            self.extra = nn.Sequential(
                nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=stride)
            )
        self.relu = nn.ReLU()

    def forward(self, x):
        out = self.seq(x)
        x = self.extra(x)
        out = out + x
        out = self.relu(out)
        return out

net = ResBlock(3, 64, 1)
x = torch.randn((1, 3, 224, 224))
to_pt(net, x, "ResBlock")

【Pytorch】netron可视化——靠谱的使用方法_第2张图片

Yolov5:该文件放在yolov5文件下,图太长了就简单放一部分。

import torch
from models.yolo import Model
import os


def to_onnx_cfg(cfg, x, name="temp"):
    # Create model
    model = Model(cfg).to("cpu")
    # x = torch.randn(1, 3, 640, 640).to("cpu")
    script_model = torch.jit.trace(model, x)
    if not os.path.exists(os.getcwd() + "/onnx"):
        os.makedirs(os.getcwd() + "/onnx")
    # torch.onnx.export(model, x, f"./{name}.onnx")
    script_model.save(f"./onnx/{name}.pt")


x = torch.randn(1, 3, 640, 640).to("cpu")
to_onnx_cfg("./yolov5s.yaml", x, "yolov5s")

【Pytorch】netron可视化——靠谱的使用方法_第3张图片

以上就是这些,对比图的话,之后我再放上来。


参考博客:

yolov5s.yaml中各参数作用意义及使用netron工具来可视化yolov5s的结构

pytorch JIT浅解析

TORCH.JIT理解

你可能感兴趣的:(Pytorch学习笔记,深度学习,Python,pytorch,深度学习,神经网络)