我们将定义一个简单的CNN模型对手写字体数据进行分类,并对定义好的CNN模型进行可视化。
首先导入相关库和数据:
#导入相关库和数据
import torch
import torch.nn as nn
import torchvision
import torchvision.utils as vutils
from torch.optim import SGD
import torch.utils.data as Data
from sklearn.datasets import load_boston
from sklearn.metrics import accuracy_score
import numpy as np
import matplotlib.pyplot as plt
下面导入手写字体数据,并定义数据加载器,以及准备测试数据集,代码如下:
#导入手写字体数据,并定义数据加载器
train_data = torchvision.datasets.MNIST(
root="./data/MNIST",
train=True,#训练数据集
##将数据转化为张量,取值范围为[0,1]
transform=torchvision.transforms.ToTensor(),
download=True
)
#将数据处理为数据加载器
train_loader = Data.DataLoader(
dataset=train_data,
batch_size=64,
shuffle=True,
num_workers=2,)
#准备测试数据集
test_data = torchvision.datasets.MNIST(
root="./data/MNIST/",
train=False,
download=False)
#为测试数据添加一个通道维度,并将取值范围缩放到0~1
test_data_x = test_data.data.type(torch.FloatTensor)/255.0
test_data_x = torch.unsqueeze(test_data_x,dim=1)
test_data_y = test_data.targets#测试集的标签
print("test_data_x.shape:",test_data_x.shape)
print("test_data_y.shape:",test_data_y.shape)
结果为:(针对测试数据,我们不处理为数据加载器,而是将整个测试集作为一个batch)
test_data_x.shape: torch.Size([10000, 1, 28, 28])
test_data_y.shape: torch.Size([10000])
结束了数据预处理,我们定义一个简单的CNN,用于展示如何可视化网络结构,代码如下:
#搭建一个卷积神经网络
class ConvNet(nn.Module):
def __init__(self):
super(ConvNet,self).__init__()
#定义第一个卷积层
self.conv1 = nn.Sequential(
nn.Conv2d(
in_channels=1,
out_channels=16,
kernel_size=3,
stride=1,
padding=1,
),
nn.ReLU(),#激活函数
nn.AvgPool2d(
kernel_size=2,#2*2
stride=2,)
)
#定义第二个卷积层
self.conv2 = nn.Sequential(
nn.Conv2d(
in_channels=16,
out_channels=32,
kernel_size=3,
stride=1,
padding=1,
),
nn.ReLU(),#激活函数
nn.MaxPool2d(
kernel_size=2,#2*2
stride=2,)
)
#定义全连接层
self.fc = nn.Sequential(
nn.Linear(
in_features=32*7*7,#输入特征
out_features=128,#输出特征数
),
nn.ReLU(),#激活函数
nn.Linear(128,64),
nn.ReLU(),
)
#定义最后的分类层
self.out = nn.Linear(64,10)
def forward(self,x):
x = self.conv1(x)
x = self.conv2(x)
x = x.view(x.size(0),-1)#展平多维的卷积图层
x = self.fc(x)
output = self.out(x)
return output
我们简单输出一下网络结构,代码如下:
#输出网络结构
MyConvNet = ConvNet()
print(MyConvNet)
结果为:
ConvNet(
(conv1): Sequential(
(0): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): ReLU()
(2): AvgPool2d(kernel_size=2, stride=2, padding=0)
)
(conv2): Sequential(
(0): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): ReLU()
(2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
(fc): Sequential(
(0): Linear(in_features=1568, out_features=128, bias=True)
(1): ReLU()
(2): Linear(in_features=128, out_features=64, bias=True)
(3): ReLU()
)
(out): Linear(in_features=64, out_features=10, bias=True)
)
虽然我们通过这里也可以看到网络结构,但是文本形式终归不算直观,下面我们尝试通过torchviz库对网络结构进行图像可视化处理。
首先,我们要下载graphviz和torchviz:在anaconda prompt中进入自己的Pytorch环境,输入代码来安装graphviz和torchviz
pip install graphviz torchviz
然后,我们需要额外下载exe安装包来安装graphviz的软件,否则使用时会报错:
下载地址为:Index of /Packages/stable/windows/10/cmake/Release/x64,下载win64.exe即可,并完成安装向导,在安装过程中注意勾选添加环境变量选项即可,如下:
安装结束后记得在用户环境变量和系统环境变量中都添加上Graphviz的bin文件夹所在路径,如下:
至此,安装结束。
接下来,我们开始使用pytorchviz库可视化网络
首先导入使用的包:
from torchviz import make_dot
接下来使用make_dot对网络进行可视化处理:
#使用makedot可视化网络
x = torch.randn(1,1,28,28).requires_grad_(True)
y = MyConvNet(x)
MyConvNetvis = make_dot(y,params=dict(list(MyConvNet.named_parameters())+[('x',x)]))
最后将可视化结果保存为图片形式存储:
#保存为图片
MyConvNetvis.format = "png"
MyConvNetvis.directory = "picture/chap4"
MyConvNetvis.view()
结果为:
'picture\\chap4\\Digraph.gv.png'