Pytorch的模型结构可视化(tensorboard)

在pytorch中,可以导入tensorboard模块,可视化网络结构及训练流程。

下面通过“CNN训练MNIST手写数字分类”的小例子来学习一些可视化工具的用法,只需要加少量代码。

一、tensorboardX的安装

pip install tensorboard
pip install tensorflow
pip install tensorboardX

二、导入tensorboardX

import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import datasets,transforms

#writer就相当于一个日志,保存你要做图的所有信息。第二句就是在你的项目目录下建立一个文件夹log,存放画图用的文件。刚开始的时候是空的
from tensorboardX import SummaryWriter
writer = SummaryWriter('log') #建立一个保存数据用的东西

三、搭建模型

#定义超参数
batch_size = 64
learning_rate = 1e-2
num_epoches = 20

#对数据进行预处理
data_tf = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize([0.5],[0.5])]
)


# 定义网络
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3),
            nn.BatchNorm2d(16),
            nn.ReLU(True))

        self.layer2 = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=3),
            nn.BatchNorm2d(32),
            nn.ReLU(True),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )

        self.layer3 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=3),
            nn.BatchNorm2d(64),
            nn.ReLU(True))

        self.layer4 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )

        self.fc = nn.Sequential(
            nn.Linear(128 * 4 * 4, 1024),
            nn.ReLU(True),
            nn.Linear(1024, 128),
            nn.ReLU(True),
            nn.Linear(128, 10)
        )

    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

model = CNN()
print(model)

#下载数据集MNIST手写数字训练集
train_dataset = datasets.MNIST(
    root = './data',train=True,transform = data_tf,download = True)
test_dataset = datasets.MNIST(
    root = './data',train = False,transform = data_tf)
train_loader = DataLoader(train_dataset,batch_size=batch_size,shuffle=True)
test_loader = DataLoader(test_dataset,batch_size=batch_size,shuffle=False)

四、保存模型结构

在这里保存模型的数据流和结构:w.add_graph()

model = CNN()
dummy_input = torch.rand(20, 1, 28, 28)  # 假设输入20张1*28*28的图片
with SummaryWriter(comment='LeNet') as w:
    w.add_graph(model, (dummy_input,))

五、运行代码及可视化

1.运行代码

2.在Pycharm命令行输入

tensorboard --logdir = C:\Users\huangxin1\PycharmProjects\untitled\runs

注意 tensorboard --logdir= 路径,这里的路径改为runs文件下面生成的文件的完整路径,即:

Pytorch的模型结构可视化(tensorboard)_第1张图片

在浏览器打开命令行生成的地址,可以看到模型图结构:

Pytorch的模型结构可视化(tensorboard)_第2张图片

你可能感兴趣的:(Pytorch学习,深度学习)