注:本文为学习 DataWhale 开源教程《深入浅出 Pytorch》学习笔记,仅记录笔者认为价值较大、使用较少的知识点。
Pytorch 自身仅提供了简单的可视化功能,即打印模型,会展现模型的基本结构,例如运行下列代码:
import torchvision.models as models
model = models.resnet18()
print(model)
以上代码加载了 torchvision 中定义好的 resnet18 模型,并直接打印出来,结果如图:
第三方库 torchinfo 提供了更完备的可视化功能,我们可通过 conda 命令安装该第三方库:
conda install -c conda-forge torchinfo
该库可直接通过 summary 函数使用即可,需要的两个参数为模型及输入维度:
summary(resnet18, (1, 3, 224, 224))
# 1 为batch_size
打印结果如图:
如图,可以便捷地看到输入在每一层的维度以及该层的参数数量。
在训练过程中,除了简单打印之外,我们还可以通过第三方库 tensorBoard 来实现训练过程的可视化。
常规情况下,可通过 conda 安装:
conda install tensorBoardX
安装完成之后,首先配置训练过程记录的存放位置:
from tensorboardX import SummaryWriter
writer = SummaryWriter('./cache')
# 此处配置在当前工作目录的子目录 cache 下面
接着通过以下命令可以运行 tensorBoard:
tensorboard --logdir=/path/to/logs/
由于笔者在服务器上使用 Pytorch,因此需要在服务器上配置安装 tensorBoard,在服务器上的安装略有不同,见下文。
在服务器上安装 tensorBoard,基本流程大致类似,但会遇到一些其他问题,笔者在此将整体流程和遇到的问题列于下,供读者参考。
首先同样需要通过 conda 或 pip 安装该库,并在 Python 环境中设置保存路径。
接着运行下列命令运行 tensorBoard:
tensorboard --logdir=/home/user/notebook/cache --port=8081
同常规情况不同的是,在此需要指定运行端口号,之后在本地通过 ip:端口号即可访问 tensorBoard 界面,同时,此处需指明数据文件的存储位置。例如笔者使用服务器 ip 为 1.1.1.1234,则通过 1.1.1.1234:8081 即可访问该界面。
安装过程中遇到的几个问题:
该问题解决方案:重装 tensorBoard 1.15.0 版本,可使用 pip 命令:
pip install tensorboardX==1.15.0
该问题解决方案:找到报错的文件路径(即 xx/collections/_init_.py),打开该文件,在文件中加入:
from collections.abc import Mapping
首先定义一个模型:
import torch.nn as nn
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(in_channels=3,out_channels=32,kernel_size = 3)
self.pool = nn.MaxPool2d(kernel_size = 2,stride = 2)
self.conv2 = nn.Conv2d(in_channels=32,out_channels=64,kernel_size = 5)
self.adaptive_pool = nn.AdaptiveMaxPool2d((1,1))
self.flatten = nn.Flatten()
self.linear1 = nn.Linear(64,32)
self.relu = nn.ReLU()
self.linear2 = nn.Linear(32,1)
self.sigmoid = nn.Sigmoid()
def forward(self,x):
x = self.conv1(x)
x = self.pool(x)
x = self.conv2(x)
x = self.pool(x)
x = self.adaptive_pool(x)
x = self.flatten(x)
x = self.linear1(x)
x = self.relu(x)
x = self.linear2(x)
y = self.sigmoid(x)
return y
model = Net()
print(model)
接着使用 add_graph 函数启用 tensorBoard:
writer.add_graph(model, input_to_model = torch.rand(1, 3, 224, 224))
writer.close()
然后在浏览器打开对应网页即可: