# %matplotlib inline
# 上述代码是一个注释,用于在Jupyter Notebook等环境中显示Matplotlib绘图的结果在单元格内部显示,而不是弹出新的窗口。
import torch
import torchvision
from torch.utils import data
from torchvision import transforms
from d2l import torch as d2l
# 导入必要的库和模块
# - torch:PyTorch库,用于构建和训练神经网络
# - torchvision:PyTorch中用于处理图像数据的库
# - torch.utils.data:PyTorch中用于处理数据加载的模块
# - torchvision.transforms:用于定义和应用数据转换的模块
# - d2l.torch:Dive into Deep Learning(《动手深度学习》)书中提供的PyTorch实用函数和工具
d2l.use_svg_display()
# 设置绘图的显示格式为SVG格式,这可以使绘图在Jupyter Notebook中以矢量图形的形式显示,更清晰和美观。
1、读取数据集
# 通过ToTensor实例将图像数据从PIL类型变换成32位浮点数格式,
# 并除以255使得所有像素的数值均在0〜1之间
trans = transforms.ToTensor()
# 创建FashionMNIST数据集的训练集实例
mnist_train = torchvision.datasets.FashionMNIST(
root="../data", # 数据集存放的根目录
train=True, # 表示加载训练集
transform=trans, # 数据变换,将图像数据转换为Tensor格式并归一化
download=True # 是否下载数据集(如果尚未下载的话)
)
# 创建FashionMNIST数据集的测试集实例
mnist_test = torchvision.datasets.FashionMNIST(
root="../data", # 数据集存放的根目录
train=False, # 表示加载测试集
transform=trans, # 数据变换,将图像数据转换为Tensor格式并归一化
download=True # 是否下载数据集(如果尚未下载的话)
)
数据集介绍
那么如何查看数据集中图片大小和通道数呢?以及训练和验证数据多少呢?
下面代码是将数字标签转换为文本标签
def get_fashion_mnist_labels(labels): #@save
"""
返回Fashion-MNIST数据集的文本标签
参数:
labels: 包含数值标签的列表或数组
返回:
包含对应文本标签的列表
"""
text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
return [text_labels[int(i)] for i in labels]
上面是啥意思呢?就是比如1表示苹果,这里以前标记的是1,现在转换为苹果
下面是可视化
def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5): #@save
"""
绘制图像列表
参数:
imgs: 包含图像的列表
num_rows: 图像展示的行数
num_cols: 图像展示的列数
titles: 可选参数,图像标题的列表
scale: 可选参数,控制图像的缩放比例
返回:
无返回值,显示绘制的图像
"""
# 计算绘图区域的尺寸
figsize = (num_cols * scale, num_rows * scale)
# 创建一个具有指定尺寸的子图
_, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize)
# 将子图数组展平,以便逐个访问每个子图
axes = axes.flatten()
# 遍历图像列表并在每个子图中绘制图像
for i, (ax, img) in enumerate(zip(axes, imgs)):
if torch.is_tensor(img):
# 如果图像是PyTorch的张量,将其转换为NumPy数组并在子图上显示
ax.imshow(img.numpy())
else:
# 如果图像是PIL图像,直接在子图上显示
ax.imshow(img)
# 隐藏子图的x轴和y轴
ax.axes.get_xaxis().set_visible(False)
ax.axes.get_yaxis().set_visible(False)
if titles:
# 如果提供了标题列表,设置当前子图的标题
ax.set_title(titles[i])
# 返回绘制的子图数组
return axes
2、读取小批量
问:
(1)一个进程通常占用一个核心吗
是的
batch_size = 256
def get_dataloader_workers(): #@save
"""使用4个进程来读取数据"""
return 4
# 创建训练数据迭代器
train_iter = data.DataLoader(
mnist_train, # 使用的数据集实例
batch_size, # 每个批次的样本数量
shuffle=True, # 是否在每个epoch前打乱数据顺序
num_workers=get_dataloader_workers() # 用于加载数据的进程数量
)
(2)上面的data在哪里定义的?
看文章开头定义的
3、整合所有组件
就是合成上边所有的代码
def load_data_fashion_mnist(batch_size, resize=None): #@save
"""
下载Fashion-MNIST数据集,然后将其加载到内存中
参数:
batch_size: 批次大小,用于小批量训练
resize: 可选参数,指定图像调整的大小
返回:
包含训练数据迭代器和测试数据集的元组
"""
# 创建数据变换列表,将图像转换为Tensor格式
trans = [transforms.ToTensor()]
# 如果提供了resize参数,将图像调整大小添加到变换列表
if resize:
trans.insert(0, transforms.Resize(resize))
# 将变换列表组合成一个组合变换
trans = transforms.Compose(trans)
# 创建FashionMNIST数据集的训练集实例
mnist_train = torchvision.datasets.FashionMNIST(
root="../data", # 数据集存放的根目录
train=True, # 表示加载训练集
transform=trans, # 数据变换,包括调整大小和转换为Tensor
download=True # 是否下载数据集(如果尚未下载的话)
)
# 创建FashionMNIST数据集的测试集实例
mnist_test = torchvision.datasets.FashionMNIST(
root="../data", # 数据集存放的根目录
train=False, # 表示加载测试集
transform=trans, # 数据变换,包括调整大小和转换为Tensor
download=True # 是否下载数据集(如果尚未下载的话)
)
# 创建训练数据迭代器,并指定批次大小、是否打乱顺序和数据加载进程数量
train_data = data.DataLoader(
mnist_train, # 使用的训练数据集实例
batch_size, # 每个批次的样本数量
shuffle=True, # 是否在每个epoch前打乱数据顺序
num_workers=get_dataloader_workers() # 数据加载进程数量
)
# 创建测试数据迭代器,并指定批次大小、不打乱顺序和数据加载进程数量
test_data = data.DataLoader(
mnist_test, # 使用的测试数据集实例
batch_size, # 每个批次的样本数量
shuffle=False, # 不打乱数据顺序
num_workers=get_dataloader_workers() # 数据加载进程数量
)
# 返回训练数据迭代器和测试数据迭代器的元组
return train_data, test_data