P6:DataLoader的使用

1、准备数据集(测试集)

import torchvision

test_data = torchvision.datasets.CIFAR10('./dataset', train=False, transform=torchvision.transforms.ToTensor())

注意数据集中的图片是PIL的格式,需要格式转换。

2、使用DataLoader

from torch.utils.data import DataLoader

test_loader = DataLoader(dataset=test_data, batch_size=64, shuffle=True, num_workers=0, drop_last=False)

3、查看数据集中图片的尺寸及target

# 测试数据集中第一张图片及target
img, target = test_data[0]
print(img.shape)
print(target)

结果如下:

4、DataLoader的返回

P6:DataLoader的使用_第1张图片

 其做了一个打包处理。

测试如下:

for data in test_loader:
    imgs, targets = data
    print(imgs.shape)
    print(targets)

结果如下:

P6:DataLoader的使用_第2张图片

4代表4张图片(batch_size的大小)

5、drop_last的作用

如果为True,则若有剩余且数量小于batch_size,直接丢弃;

如果为False,则保留。

6、实际中需要结合epoch来使用

代码如下:

import torchvision

# 准备的测试集
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

test_data = torchvision.datasets.CIFAR10('./dataset', train=False, transform=torchvision.transforms.ToTensor())

test_loader = DataLoader(dataset=test_data, batch_size=64, shuffle=True, num_workers=0, drop_last=False)

# 测试数据集中第一张图片及target
img, target = test_data[0]
print(img.shape)
print(target)

writer = SummaryWriter('dataloader')
for epoch in range(2):
    step = 0
    for data in test_loader:
        imgs, targets = data
        # print(imgs.shape)
        # print(targets)
        writer.add_images('Epoch:{}'.format(epoch), imgs, step)
        step = step + 1

writer.close()

 结果如下:

P6:DataLoader的使用_第3张图片

 

 

你可能感兴趣的:(PyTorch学习笔记,python,人工智能,深度学习,pytorch)