第一章 Python深度学习入门之环境软件配置
第二章 Python深度学习入门之数据处理Dataset的使用
第三章 数据可视化TensorBoard和TochVision的使用
我们通过上一节的学习知道了如何使用DataSet类读取文件夹中的图片数据,并对数据进行整理排序。接下来我们就需要将这些数据进行可视化显示,方便对数据进行分析和对比。我在学习之前对于数据的可视化会使用Python的画图库matlibplot,通过这个库我们可以绘制很多图形,比如折线图、柱状图等等。但是这个画图库太麻烦了,需要自己编写代码,今天我们要聊的TensorBoard不需要编写代码就可以显示数据,并支持多种类型的数据,比如离散数据、图片数据等等。
Tensorboard是TensorFlow提供的一组可视化工具,可以帮助开发者方便的理解、调试、优化TensorFlow 程序。为后续的学习提供了巨大的方便。
说到调试,我们一般的调试都是通过Pycharm在某行代码处打个断点,然后一步一步执行,并关注Pycharm的变量池,找到问题点。关键是有的时候我们的程序并没有错误,只是训练结果并不是很满意,想优化程序,但是在运行过程中我们不知道每一步得到的结果是什么,变量池的数据也不是很直观。这时TensorBoard就可以解决以上问题。TensorBoard可以展示我们程序的训练过程结果,更直观的关注代码的效果。
from torch.utils.tensorboard import SummaryWriter
from PIL import Image
import numpy as np
import os
我们利用上一章下载的蚂蚁和蜜蜂的数据作为数据集,使用TensorBoard进行显示。
image_path = 'H:\\learn\\hymenoptera_data\\train\\ants_image' # 此处为数据集的文件夹路径
image_name = os.listdir(image_path)
该处我们从文件夹中获取到了蚂蚁数据集的图片名称,因为我们要显示图片,所以我们要将每一张图片放入TensorBoard中。
首先我们需要通过SummaryWriter创建一个存储数据的文件夹,然后再向这个文件夹中添加需要显示的数据和图片。
writer = SummaryWriter('logs') # 在当前目录下创建logs文件夹,存储需要显示的数据
# 因为我们的数据都存储在一个列表中,因此我们需要遍历列表将每张图片给加载进去,这里我们测试10张图片
for i in range(10):
image = Image.open(os.path.join(image_path, image_name[i])) # 读取图片
image_array = np.array(image) # 将image图片的类型转换成TensorBoard需要的类型
writer.add_image('test', image_array, i, dataformats='HWC')
# 我们再测试一下显示折线图
for i in range(100):
writer.add_scalar("y = 2x",3 * i, i)
writer.close() # 最后记得关闭writer
然后运行代码,可以看到程序运行完后会终止。然后我们打开Terminal,进入到pytorch环境运行以下代码
(pytorch) PS H:> tensorboard --logdir=‘logs’ --port=‘6006’
这个 --port=‘6006’ 这条指令可以修改界面显示的端口号,默认是6006。
出现上面的输出,就说明TensorBoard启动成功,我们点击链接或者复制用浏览器打开即可显示数据内容。
接下来我们来解析上面代码用到的方法:
我们通过上面TensorBoard显示的图来理解一下这几个参数的含义。
torchvision是Pytorch的一个图形库,在后续深度学习的学习过程中相当重要,因此我们需要着重探究torchvision下的包。
我们通过查看transforms的源代码可以看到它提供了很多图形变换的类,我们可以通过Pycharm的Structure来查看他所有的方法。
Example:
>>> transforms.Compose([
>>> transforms.CenterCrop(10), #首先对图片进行中心裁剪
>>> transforms.PILToTensor(), #然后将图片转换成tensor类型的
>>> transforms.ConvertImageDtype(torch.float), #再对图片进行类型转换
>>> ])
torchvision包中常用的就是上面这几个类,在后续的学习中我们会经常用到,因此我们要熟练掌握,也可以经常去看看torchvision的官网:https://pytorch.org/vision/stable/index.html
torchvision.dataset包我们主要是用来下载数据集的,在之前的学习中,我们都是通过数据集链接在浏览器中下载的,比如第二章中的蚂蚁和蜜蜂的数据集。今天通过dataset包我们就可以下载比较常用的数据集,比如CIFAR10、MNIST等等。
class CIFAR10(VisionDataset):
"""`CIFAR10 `_ Dataset.
Args:
root (string): Root directory of dataset where directory
``cifar-10-batches-py`` exists or will be saved to if download is set to True.
train (bool, optional): If True, creates dataset from training set, otherwise
creates from test set.
transform (callable, optional): A function/transform that takes in an PIL image
and returns a transformed version. E.g, ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
download (bool, optional): If true, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
downloaded again.
"""
def __init__(
self,
root: str,
train: bool = True,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
download: bool = False,
)
通过上面的源码我们可以知道下载CIFAR10数据集需要传入5个参数:
接下来我们就通过代码来测试一下通过torchvision.dataset类从网上下载CIFAR10数据集。
通过上面的截图我们可以看到数据集已经下载下来了,在同一目录下创建了一个cifar10文件夹,虽然下载速度比较慢,但是这样下载这些数据集比较方便。我们也可以通过这种方式去下载其他的数据集。
以上的TensorBoard和TorchVision是Pytorch中比较重要的库,在后续深度学习的学习过程中会经常用到这些知识点。因此我们还是需要反复的去使用这些库函数和方法,并且要多去看官方文档,官方文档才是最具权威性的学习资料。然后大家有什么问题可以在下方评论留言,大家一起学习进步!!