pytorch框架学习(13)——可视化工具TensorBoard

文章目录

  • 1. TensorBoard简介
  • 2. tensorboard使用
    • 2.1 SummaryWriter
    • 2.2 方法

1. TensorBoard简介

TensorBoard:TensorFlow中强大的可视化工具
支持标量、图像、文本、音频、视频和Embedding等多种数据可视化

  • 运行机制
    pytorch框架学习(13)——可视化工具TensorBoard_第1张图片

tensorboard --logdir=./runs

  • 作业
    熟悉TensorBoard的运行机制,安装TensorBoard,并绘制曲线 y = 2*x
import numpy as np
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(comment='test_tensorboard')

for x in range(100):
    writer.add_scalar('y=2x', x * 2, x)
writer.close()

pytorch框架学习(13)——可视化工具TensorBoard_第2张图片

2. tensorboard使用

2.1 SummaryWriter

  • 功能:提供创建event file的高级接口
  • 主要属性:
    • log_dir:event file输出文件夹
    • comment:不指定log_dir时,文件夹后缀
    • filename_suffix:event file文件名后缀
log_dir = "./train_log/test_log_dir"
# writer = SummaryWriter(log_dir=log_dir, comment='_scalars', filename_suffix="12345678")
writer = SummaryWriter(comment='_scalars', filename_suffix="12345678")

for x in range(100):
    writer.add_scalar('y=pow_2_x', 2 ** x, x)

writer.close()

2.2 方法

  • 1.add_sclalar(只能记录一条曲线)

    • 功能:记录标量
    • tag:图像的标签名,图的唯一标识
    • scalar_value:要记录的标量
    • global_step:x轴
  • 2.add_scalars()

    • main_tag:该图的标签
    • tag_scalar_dict:key是变量的tag,value是变量的值
max_epoch = 100

writer = SummaryWriter(comment='test_comment', filename_suffix="test_suffix")

for x in range(max_epoch):

    writer.add_scalar('y=2x', x * 2, x)
    writer.add_scalar('y=pow_2_x', 2 ** x, x)
    writer.add_scalars('data/scalar_group', {"xsinx": x * np.sin(x),
                                             "xcosx": x * np.cos(x)}, x)
  • 3.add_histogram()
    • 功能:统计直方图与多分位数折线图
    • tag:图像的标签名,图的唯一标识
    • values:要统计的参数
    • global_step:y轴
    • bins:取直方图的bins(一般不需设置)
    writer = SummaryWriter(comment='test_comment', filename_suffix="test_suffix")

    for x in range(2):

        np.random.seed(x)

        data_union = np.arange(100)
        data_normal = np.random.normal(size=1000)

        writer.add_histogram('distribution union', data_union, x)
        writer.add_histogram('distribution normal', data_normal, x)

        plt.subplot(121).hist(data_union, label="union")
        plt.subplot(122).hist(data_normal, label="normal")
        plt.legend()
        plt.show()

    writer.close()
  • 4.add_image()

    • 功能:记录图像
    • tag:图像的标签名,图的唯一标识
    • img_tensor:图像数据,注意尺度
    • global_step:x轴
    • dataformats:数据形式,CHW,HWC,HW(C:chanel,H:high,W:width)
  • 5.torchvision.utils.make_grid

    • 功能:制作网格图像
    • tensor:图像数据,BCH*W形式
    • nrow:行数(列数自动计算)
    • padding:图像间距(像素单位)
    • normalize:是否将图像值标准化
    • range:标准化范围
    • scale_each:是否单张图维度标准化
    • pad_value:padding的像素值
  • 6.add_graph()

    • 功能:可视化模型计算图
    • model:模型,必须是nn.Module
    • input_to_model:输出给模型的数据
    • verbose:是否打印计算图结构信息
  • 7.torchsummary

    • 功能:查看模型信息,便于调试
    • model:pytorch模型
    • input_size:模型输入size
    • batch_size:batch size
    • device:“cuda” or “cpu”

你可能感兴趣的:(Pytorch,数据可视化,pytorch)