Pytorch 网络结构的可视化

在构建网络的过程中,需要查看网络结构,以便于优化,使用Pytorch常用的可视化工具有

1.Hidden layer

myNet = U_Net()
print(myNet)

#
## 可视化卷积神经网络,MyConvnet是定义的神经网络结构
hl_graph = hl.build_graph(myNet, torch.zeros([1, 3, 128 , 128]))
hl_graph.theme = hl.graph.THEMES["blue"].copy()
hl_graph.save("dataset/myNet.jpg", format = "png")

2.torchviz

x = torch.randn(size=(1, 3, 128, 128)).requires_grad_(True)
y = myNet(x)
MyConvNetVis = make_dot(y, params=dict(list(myNet.named_parameters()) + [('x', x)]))
MyConvNetVis.format = 'png'
MyConvNetVis.directory = './'
MyConvNetVis.view()

效果如下:

Pytorch 网络结构的可视化_第1张图片

 

3.Netron

使用netron的效果比较好,

x = torch.randn(1, 3, 128, 128)  # 随机生成一个输入
modelData = "./demo.pth"  # 定义模型数据保存的路径
# modelData = "./demo.onnx"  # 有人说应该是 onnx 文件,但我尝试 pth 是可以的
torch.onnx.export(myNet, x, modelData)  # 将 pytorch 模型以 onnx 格式导出并保存
netron.start(modelData)  # 输出网络结构

效果如下:

Pytorch 网络结构的可视化_第2张图片

 

你可能感兴趣的:(深度学习(PyTorch),pytorch,深度学习,人工智能)