pytorch中的可视化:网络模型可视化以及特征图可视化

一、使用netron工具可视化pytorch模型,tensorboard太丑了不直观。

项目地址:https://github.com/lutzroeder/Netro

参考:使用netron工具可视化pytorch模型_jieleiping的博客-CSDN博客_netron pytorch

1.安装netron

pip install netron

2.直接可视化onnx/pth/pt(注意部分版本pt/pth需要转化为onnx,因为出现没线)

import netron

netron.start('net_pytorch.pth')

netron.start('net.onnx')

3.模型导出为onnx格式

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.onnx

import netron


class ForwardNet(nn.Module):
    def __init__(self):
        super(ForwardNet, self).__init__()
        self.block1 = nn.Sequential(
            nn.Conv2d(64, 64, 3, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 32, 1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 64, 3, padding=1, bias=False),
            nn.BatchNorm2d(64)
        )

        self.conv1 = nn.Conv2d(3, 64, 3, padding=1, bias=False)
        self.output = nn.Sequential(
            nn.Conv2d(64, 1, 3, padding=1, bias=True),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = self.conv1(x)
        identity = x
        x = F.relu(self.block1(x) + identity)
        x = self.output(x)
        return x


input = torch.rand(1, 3, 416, 416)
model = ForwardNet()
output = model(input)

onnx_path = "netForwatch.onnx"
torch.onnx.export(model, input, onnx_path)

netron.start(onnx_path)

执行上面代码后,会调用本地浏览器打开,形式和tensorboard差不多。

pytorch中的可视化:网络模型可视化以及特征图可视化_第1张图片

4.pth/pt转onnx

对pytorch模型格式(.pt/.pth)支持不友好,因此需要存为onnx

import torch
from model import Model

pytorch_net_path = 'net.pth'
onnx_net_path = 'net.onnx'

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu' )

# 权重导入模型
net = Model().to(device)
net.load_state_dict(torch.load(pytorch_net_path, map_location=device))
net.eval()

input = torch.randn(1, 1, 30, 30).to(device)   # BCHW  其中Batch必须为1,因为测试时一般为1,尺寸HW必须和训练时的尺寸一致
torch.onnx.export(net, input, onnx_net_path, verbose=False)

二、特征图可视化

参考:PyTorch模型训练特征图可视化(TensorboardX) - 知乎

你可能感兴趣的:(pytorch相关,onnx,pth转化)