pytorch 网络结构可视化之netron

pytorch 网络结构可视化之netron


目录

  • pytorch 网络结构可视化之netron
  • 一、netron
  • 二、使用步骤
    • 1.安装可视化工具netron
    • 2.导出可视化模型文件
      • ①导出onnx格式模型文件
      • ②torch.jit.trace转换模型文件
    • 3.netron载入模型
  • 三、总结


一、netron

netron是一个深度学习模型可视化库,支持以下格式的模型存储文件:

ONNX (.onnx, .pb)
Keras (.h5, .keras)
CoreML (.mlmodel)
TensorFlow Lite (.tflite)
TensorFlow

但netron并不支持pytorch通过torch.save方法导出的模型文件,(及可视化过程中无法捕获模型的执行操作与结构)。
因此在pytorch保存模型的时候,可以用torch.onnx模块将其导出为onnx格式的模型文件,或用torch.jit.trace模块追踪模型在输入数据后的执行路径调用的操作。

整体的流程分为两步,第一步,基于pytorch两种方法导出模型文件。第二步,netron载入模型文件,进行可视化。

二、使用步骤

1.安装可视化工具netron

pip install netron

2.导出可视化模型文件

①导出onnx格式模型文件

import torchvision
import torch
data = torch.rand(1, 3, 224, 224)
model=torchvision.models.resnet50()
output = model(data)
# 导出为onnx格式
onnx_path = "onnx_model.onnx"
torch.onnx.export(model, data, onnx_path)

②torch.jit.trace转换模型文件

torch.jit.trace在跟踪遇到的计算步骤时通过函数或模块运行示例输入,并输出执行Tracing操作的基于图形的函数。Tracing非常适用于不涉及数据相关控制流的简单模块和功能,例如标准卷积神经网络。但是,如果Tracing具有依赖于数据的if语句和循环的函数,则仅记录由示例输入执行的执行路径调用的操作,即尽量避免转换代码中有if条件控制的模型。

import torchvision
import torch
data = torch.rand(1, 3, 224, 224)
model=torchvision.models.resnet50()
output = model(data)
trace_model = torch.jit.trace(model, data)
trace_model.save("mtrace.pt")

如果模型设计多个输入,需要将传入torch.onnx.export和torch.jit.trace中的data参数改为多输入张量元组,即data=(input1,input2 )

3.netron载入模型

如果能成功转换模型,在python代码调用netron库来载入模型进行可视化。

import netron
netron.start("mtrace.pth")

netron还做了一个在线demo网站,可以直接上传模型文件查看可视化结果,与代码调用netron库来载入模型一样。网址https://netron.app/
pytorch 网络结构可视化之netron_第1张图片
整体效果比美观
pytorch 网络结构可视化之netron_第2张图片

三、总结

在实际过程中,由于网络模型中复杂的结构以及调用,导出为onnx格式的模型时会出现各式各样的问题
在这里插入图片描述
但torch.jit.trace相对好用一些,能使我们快速便捷地了解复杂模型。
pytorch 网络结构可视化之netron_第3张图片

你可能感兴趣的:(pytorch,深度学习)