pytorch学习笔记——4.1Pytorch中网络结构的可视化

摘要:

深度学习网络通常具有比较深的层次结构,因此需要可视化工具将建立的深度学习网络结构层次化的展示出来。本文中我们首先定义一个简单的CNN网络对MNIST数据进行分类,并通过PytorchViz库进行网络的可视化处理。


一、准备网络和数据

        我们将定义一个简单的CNN模型对手写字体数据进行分类,并对定义好的CNN模型进行可视化。

        首先导入相关库和数据:

#导入相关库和数据
import torch
import torch.nn as nn
import torchvision
import torchvision.utils as vutils
from torch.optim import SGD
import torch.utils.data as Data
from sklearn.datasets import load_boston
from sklearn.metrics import accuracy_score
import numpy as np
import matplotlib.pyplot as plt

        下面导入手写字体数据,并定义数据加载器,以及准备测试数据集,代码如下:

#导入手写字体数据,并定义数据加载器
train_data = torchvision.datasets.MNIST(
        root="./data/MNIST",
        train=True,#训练数据集
        ##将数据转化为张量,取值范围为[0,1]
        transform=torchvision.transforms.ToTensor(),
        download=True
        )
#将数据处理为数据加载器
train_loader = Data.DataLoader(
                                dataset=train_data,
                                batch_size=64,
                                shuffle=True,
                                num_workers=2,)
#准备测试数据集
test_data = torchvision.datasets.MNIST(
root="./data/MNIST/",
train=False,
download=False)

#为测试数据添加一个通道维度,并将取值范围缩放到0~1
test_data_x = test_data.data.type(torch.FloatTensor)/255.0
test_data_x = torch.unsqueeze(test_data_x,dim=1)
test_data_y = test_data.targets#测试集的标签
print("test_data_x.shape:",test_data_x.shape)
print("test_data_y.shape:",test_data_y.shape)

结果为:(针对测试数据,我们不处理为数据加载器,而是将整个测试集作为一个batch)

test_data_x.shape: torch.Size([10000, 1, 28, 28])
test_data_y.shape: torch.Size([10000])

        结束了数据预处理,我们定义一个简单的CNN,用于展示如何可视化网络结构,代码如下:

#搭建一个卷积神经网络
class ConvNet(nn.Module):
    def __init__(self):
        super(ConvNet,self).__init__()
        #定义第一个卷积层
        self.conv1 = nn.Sequential(
            nn.Conv2d(
                in_channels=1,
                out_channels=16,
                kernel_size=3,
                stride=1,
                padding=1,
                ),
            nn.ReLU(),#激活函数
            nn.AvgPool2d(
                kernel_size=2,#2*2
                stride=2,)
        )
        #定义第二个卷积层
        self.conv2 = nn.Sequential(
            nn.Conv2d(
                in_channels=16,
                out_channels=32,
                kernel_size=3,
                stride=1,
                padding=1,
                ),
            nn.ReLU(),#激活函数
            nn.MaxPool2d(
                kernel_size=2,#2*2
                stride=2,)
        )
        #定义全连接层
        self.fc = nn.Sequential(
            nn.Linear(
                in_features=32*7*7,#输入特征
                out_features=128,#输出特征数
                ),
            nn.ReLU(),#激活函数
            nn.Linear(128,64),
            nn.ReLU(),
        ) 
        #定义最后的分类层
        self.out = nn.Linear(64,10)
    def forward(self,x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = x.view(x.size(0),-1)#展平多维的卷积图层
        x = self.fc(x)
        output = self.out(x)
        return output

  我们简单输出一下网络结构,代码如下:

#输出网络结构
MyConvNet = ConvNet()
print(MyConvNet)

结果为:

ConvNet(
  (conv1): Sequential(
    (0): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): AvgPool2d(kernel_size=2, stride=2, padding=0)
  )
  (conv2): Sequential(
    (0): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (fc): Sequential(
    (0): Linear(in_features=1568, out_features=128, bias=True)
    (1): ReLU()
    (2): Linear(in_features=128, out_features=64, bias=True)
    (3): ReLU()
  )
  (out): Linear(in_features=64, out_features=10, bias=True)
)

二、PytorchViz库可视化网络

        虽然我们通过这里也可以看到网络结构,但是文本形式终归不算直观,下面我们尝试通过torchviz库对网络结构进行图像可视化处理。

        首先,我们要下载graphviz和torchviz:在anaconda prompt中进入自己的Pytorch环境,输入代码来安装graphviz和torchviz

pip install graphviz torchviz

         然后,我们需要额外下载exe安装包来安装graphviz的软件,否则使用时会报错:

下载地址为:Index of /Packages/stable/windows/10/cmake/Release/x64,下载win64.exe即可,并完成安装向导,在安装过程中注意勾选添加环境变量选项即可,如下:

pytorch学习笔记——4.1Pytorch中网络结构的可视化_第1张图片

 安装结束后记得在用户环境变量和系统环境变量中都添加上Graphviz的bin文件夹所在路径,如下:

pytorch学习笔记——4.1Pytorch中网络结构的可视化_第2张图片

 至此,安装结束。

        接下来,我们开始使用pytorchviz库可视化网络

        首先导入使用的包:

from torchviz import make_dot

        接下来使用make_dot对网络进行可视化处理:

#使用makedot可视化网络
x = torch.randn(1,1,28,28).requires_grad_(True)
y = MyConvNet(x)
MyConvNetvis = make_dot(y,params=dict(list(MyConvNet.named_parameters())+[('x',x)]))

        最后将可视化结果保存为图片形式存储:

#保存为图片
MyConvNetvis.format = "png"
MyConvNetvis.directory = "picture/chap4"
MyConvNetvis.view()

结果为:

'picture\\chap4\\Digraph.gv.png'

pytorch学习笔记——4.1Pytorch中网络结构的可视化_第3张图片

 

你可能感兴趣的:(pytorch入门,pytorch,学习,深度学习)