在pytorch中,可以导入tensorboard模块,可视化网络结构及训练流程。
下面通过“CNN训练MNIST手写数字分类”的小例子来学习一些可视化工具的用法,只需要加少量代码。
一、tensorboardX的安装
pip install tensorboard
pip install tensorflow
pip install tensorboardX
二、导入tensorboardX
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import datasets,transforms
#writer就相当于一个日志,保存你要做图的所有信息。第二句就是在你的项目目录下建立一个文件夹log,存放画图用的文件。刚开始的时候是空的
from tensorboardX import SummaryWriter
writer = SummaryWriter('log') #建立一个保存数据用的东西
三、搭建模型
#定义超参数
batch_size = 64
learning_rate = 1e-2
num_epoches = 20
#对数据进行预处理
data_tf = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize([0.5],[0.5])]
)
# 定义网络
class CNN(nn.Module):
def __init__(self):
super(CNN, self).__init__()
self.layer1 = nn.Sequential(
nn.Conv2d(1, 16, kernel_size=3),
nn.BatchNorm2d(16),
nn.ReLU(True))
self.layer2 = nn.Sequential(
nn.Conv2d(16, 32, kernel_size=3),
nn.BatchNorm2d(32),
nn.ReLU(True),
nn.MaxPool2d(kernel_size=2, stride=2)
)
self.layer3 = nn.Sequential(
nn.Conv2d(32, 64, kernel_size=3),
nn.BatchNorm2d(64),
nn.ReLU(True))
self.layer4 = nn.Sequential(
nn.Conv2d(64, 128, kernel_size=3),
nn.BatchNorm2d(128),
nn.ReLU(True),
nn.MaxPool2d(kernel_size=2, stride=2)
)
self.fc = nn.Sequential(
nn.Linear(128 * 4 * 4, 1024),
nn.ReLU(True),
nn.Linear(1024, 128),
nn.ReLU(True),
nn.Linear(128, 10)
)
def forward(self, x):
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = x.view(x.size(0), -1)
x = self.fc(x)
return x
model = CNN()
print(model)
#下载数据集MNIST手写数字训练集
train_dataset = datasets.MNIST(
root = './data',train=True,transform = data_tf,download = True)
test_dataset = datasets.MNIST(
root = './data',train = False,transform = data_tf)
train_loader = DataLoader(train_dataset,batch_size=batch_size,shuffle=True)
test_loader = DataLoader(test_dataset,batch_size=batch_size,shuffle=False)
四、保存模型结构
在这里保存模型的数据流和结构:w.add_graph()
model = CNN()
dummy_input = torch.rand(20, 1, 28, 28) # 假设输入20张1*28*28的图片
with SummaryWriter(comment='LeNet') as w:
w.add_graph(model, (dummy_input,))
五、运行代码及可视化
1.运行代码
2.在Pycharm命令行输入
tensorboard --logdir = C:\Users\huangxin1\PycharmProjects\untitled\runs
注意 tensorboard --logdir= 路径,这里的路径改为runs文件下面生成的文件的完整路径,即:
在浏览器打开命令行生成的地址,可以看到模型图结构: