数据集和transfroms结合在一起的内容
标准数据集的使用
torchvision.datasets 官方指导文档:https://pytorch.org/vision/stable/datasets.html
关于torchvision中的datasets,有多种类型标准数据集供选用,但每个数据集中参数需要说明
例子:
CIFAR10数据集
参数
cifar-10-batches-py
如果下载设置为 True,则该目录存在或将保存到该目录。transforms.RandomCrop
torchvision.datasets.CIFAR10(root: str, train: bool = True, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, download: bool = False)
注意: torchvision提供的数据集有很多,但每个数据集的参数不尽相同,一些数据集中的含有其他参数,具体如何使用还需要参照官方文档进行使用
import torchvision
# torchvision中CIFAR10数据集下载
train_set = torchvision.datasets.CIFAR10(root="./DataSet",train=True, download=True)
test_set = torchvision.datasets.CIFAR10(root="./DataSet",train=False, download=True)
# 输出数据集的所有classes
print(test_set.classes)
# 将test_set中第一个数据提取出来,img保存图片,target保存其对应的标签
img, target = test_set[0]
img.show()
print(test_set[0])
print(img)
print(target)
输出:
# 已经下载过的数据不会重复下载
Files already downloaded and verified
Files already downloaded and verified
#test_set 中的所有classes
['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
# test_set[0]
(<PIL.Image.Image image mode=RGB size=32x32 at 0x201A62E97B8>, 3)
# img
<PIL.Image.Image image mode=RGB size=32x32 at 0x201A62E95F8>
# target
3
简单介绍一下CIFAR10数据集:
CIFAR-10 数据集由 10 个类别的 60000 个 32x32 彩色图像组成,每个类别包含 6000 个图像。有 50000 个训练图像和 10000 个测试图像。
数据集分为五个训练批次和一个测试批次,每个批次有 10000 张图像。测试批次恰好包含来自每个类别的 1000 个随机选择的图像。训练批次包含随机顺序的剩余图像,但一些训练批次可能包含来自一个类的图像多于另一个。在它们之间,训练批次恰好包含来自每个类别的 5000 张图像。
将Transforms,torchvision的数据预处理功能同dataset进行结合,将数据集中数据进行数据类型转换,并使用torchvision的tensorboard小工具进行数据类型转换后的记录
import torchvision
# 数据集中数据为PIL.Image类型,使用ToTensor进行类型转换
from torch.utils.tensorboard import SummaryWriter
dataset_transform = torchvision.transforms.Compose([
torchvision.transforms.ToTensor(),
])
# 下载数据集,在其中参数规定数据的处理方式
train_set = torchvision.datasets.CIFAR10(root="./DataSet", train=True, transform=dataset_transform, download=True)
test_set = torchvision.datasets.CIFAR10(root="./DataSet", train=False, transform=dataset_transform, download=True)
print(type(test_set[0]))
# 可以看到数据类型为tensor,使用tensorboard进行记录
writer = SummaryWriter("Set_logs")
for i in range(10):
img, target = test_set[i]
writer.add_image(tag="Top10", img_tensor=img, global_step=i)
writer.close()
在terminal中tensorboard --logdir=Set_logs
进入tensorboard界面进行查看
与Dataset的不同,dataset是现成的已经完备的数据集,dataloader是数据加载器,dataloader所作的是从dataset中取数据,如何去、取多少,由dataloader中参数进行控制
官网中的信息:
数据加载器。结合数据集和采样器,并提供给定数据集的可迭代对象。
支持具有单进程或多进程加载、自定义加载顺序和可选的自动批处理(整理)和内存固定的地图样式和可迭代样式数据集。DataLoader
1
)。True
在每个 epoch 重新洗牌数据(默认值:)False
。Iterable
。__len__
如果指定,则shuffle
不得指定。sampler
,但一次返回一批索引。batch_size
与、shuffle
、sampler
和互斥 drop_last
。0
表示数据将在主进程中加载。(默认0
:)True
是,数据加载器将在返回之前将张量复制到 CUDA 固定内存中。如果您的数据元素是自定义类型,或者您collate_fn
返回的批次是自定义类型,请参见下面的示例。True
如果数据集大小不能被批次大小整除,则设置为丢弃最后一个不完整的批次。如果False
数据集的大小不能被批大小整除,那么最后一批将更小。(默认False
:)0
:)None
,这将在每个工作子进程上调用,并在播种之后和数据加载之前以工作人员 ID(一个 int in )作为输入。(默认:)[0, num_workers - 1]``None
None
,则 RandomSampler 将使用此 RNG 生成随机索引和多处理以生成 工作人员的base_seed。(默认None
:)2
意味着将在所有工作人员中预取总共 2 * num_workers 个样本。(默认2
:)True
是 ,数据加载器将不会在数据集被使用一次后关闭工作进程。这允许保持工作人员数据集实例处于活动状态。(默认False
:)DataLoader()中参数较多,但其中大部分就有默认值。
DataLoader()中仅有参数 dataset不具备默认值,是为了方便使用者使用非官方自制的数据集。
batch_size=4
参数的具体意义
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-th8ZhBFA-1652926054632)(D:\STUDY\神经网络和深度学习\images\batch_size的具体含义.png)]
dataset[0]返回(img,target),batch_size=4将每4个dataset同类型数据打包 及 imgs[img0,img1,img2,img3],target[target0,target1,target2,target3]
注意:这里虽然打包了imgs和targets,但是imgs中元素的排列顺序和targets中元素的排列并不一定是相互对应的。控制对应关系的是参数sampler
"""
使用dataloader加载dataset中CIFAR10数据集,并将取打包重新排列,变化过程放入tensorboard中
"""
import torchvision
# 使用dataset中CIFAR10数据集的测试集做为演示
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
test_data = torchvision.datasets.CIFAR10(root="./DataSet", train=False, transform=torchvision.transforms.ToTensor(),
download=True)
test_loader = DataLoader(dataset=test_data, batch_size=64, shuffle=True, num_workers=0, drop_last=False)
# 测试数据集中第一张图片样本及其归类
img, target = test_data[0]
print(img.shape) #torch.Size([3, 32, 32])
print(target) #3
# 参数batch_size=4 的具体含义就是将test_data中每4个元素同类型的数据进行打包
"""
imgs, targets = test_loader[0]
print(imgs.shape)
print(targets)
TypeError: 'DataLoader' object does not support indexing
"""
writer = SummaryWriter("dataloader_logs")
step = 0
for data in test_loader:
imgs,targets = data
#print(imgs.shape) # torch.Size([4, 3, 32, 32])
#print(targets) # tensor([1, 6, 4, 1])
writer.add_images("test_data",imgs,step)
step += 1
writer.close()
以上就是dataloader相关的简单操作,详细具体的使用还要参考源码及官方文档
ze([4, 3, 32, 32])
#print(targets) # tensor([1, 6, 4, 1])
writer.add_images(“test_data”,imgs,step)
step += 1
writer.close()
以上就是dataloader相关的简单操作,详细具体的使用还要参考源码及官方文档
在实际训练过程中,也会使用`for data in dataloader:`这样的形式来加载数据,每个data中的imgs会被输送到神经网络中。