make_dot 神经网络可视化

make_dot可以打印神经网络结构,并存储。

1. 库

from torchviz import make_dot

2.实现

net_struct = make_dot(net_out)
net_struct.render("net_struct", view=False)

由于在linux系统使用该函数,没有可视化,因此将view设置为False即可将网络结构存储为pdf格式。

3.示例

import tensorwatch as tw
import torchvision.models
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchviz import make_dot

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        # 1 input channel, 6 output channel, 5 conv
        self.conv1 = nn.Conv2d(1, 6, 5, bias=False)
        self.conv2 = nn.Conv2d(6, 10, 5, bias=False)
        self.bn1 = nn.BatchNorm2d(10, eps=1e-05, momentum=0.1)
        self.relu = nn.ReLU()
    
    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.bn1(x)
        x = self.relu(x)
        return x

    def num_flat_features(self, x):
        size = x.size()[1:]
        num_features = 1
        for s in size:
            num_features *=s
        return num_features

net = Net()
print(net)
x = torch.zeros(1, 1, 20, 20, dtype=torch.float, requires_grad=False)
net_out = net(x)
net_struct = make_dot(net_out)  # plot graph of variable, not of a nn.Module
net_struct.render("net_struct", view=False)

打印结果如图所示,可以看到网络结构中对应的conv,bn,relu层及相关参数size


net.PNG

你可能感兴趣的:(make_dot 神经网络可视化)