TensorBoard是Tensorflow自带的一个强大的可视化工具,也是一个Web应用程序套件。目前也可以在pytorch中使用,TensorBoard目前支持7种可视化,Scalars,Images,Audio,Graphs,Distributions,Histograms和Embeddings。其中可视化的主要功能如下。
(1)Scalars:展示训练过程中的准确率、损失值、权重/偏置的变化情况。
(2)Images:展示训练过程中记录的图像。
(3)Audio:展示训练过程中记录的音频。
(4)Graphs:展示模型的数据流图,以及训练在各个设备上消耗的内存和时间。
(5)Distributions:展示训练过程中记录的数据的分部图。
(6)Histograms:展示训练过程中记录的数据的柱状图。
(7)Embeddings:展示词向量后的投影分部。
使用TensorBoard展示数据,需要在执行Tensorflow就算图的过程中,将各种类型的数据汇总并记录到日志文件中。然后使用TensorBoard读取这些日志文件,解析数据并生产数据可视化的Web页面,让我们可以在浏览器中观察各种汇总数据。
pytorch从1.2.0开始支持tensorboard
# 如果有conda
conda install tensorboard
# 或者pip安装
pip install tensorboard
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter('logs')
x = range(100)
for i in x:
writer.add_scalar('y=x+10', i+10, i)
writer.close()python
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter("runs/logs_fina") # 存放log文件的目录
writer.add_scalar('train/loss', ave_train_loss, epoch) # 画loss,横坐标为epoch
writer.add_scalar('train/lr', ave_lr, epoch)
writer.close()
运行
tensorboard --logdir=logs
# 或者使用logs文件夹的绝对地址
tensorboard --logdir=绝对地址
# 当然也可以修改端口和ip
tensorboard --logdir=logs --host=127.0.0.1
在http://localhost:6006/网址的右侧,可以看到有一个smoothing可调,这个用来调整平滑度,默认0.6,有些图不平滑的话很难看出趋势。
其中对img_tensor的形状有要求,而默认格式是(3,H,W)即通道(channel)为3,H为高度,W为宽度,不是格式需要使用dataformats=''
,该参数填写的数据为:CHW
, HWC
, HW
,默认为 CHW
from torch.utils.tensorboard import SummaryWriter
from PIL import Image
import numpy as np
writer = SummaryWriter('logs')
image_path = 'dog.png'
img_PIL = Image.open(image_path) # 创建PIL的图片类
image_array = np.array(img_PIL) # 转成
# print(image_array.shape) # (1200,1920,3)
writer.add_image('dog', image_array, 0, dataformats='HWC')
writer.close()
dummy_input = torch.rand(512, 1, 28, 28) # 网络中输入的数据维度
with SummaryWriter(comment='LeNet') as w:
w.add_graph(net, (dummy_input,)) # net是你的网络名