小土堆pytorch学习笔记003 | 下载数据集dataset 及报错处理

目录

1、下载数据集

2、展示数据集里面的内容

3、DataLoader 的使用

例子:

结果展示:


1、下载数据集

# 数据集

import torchvision

train_set = torchvision.datasets.CIFAR10(root="./test10_dataset", train=True, download=True)
test_set = torchvision.datasets.CIFAR10(root="./test10_dataset", train=False, download=True)

如果上述代码在下载的时候,报错,那么需要添加两行代码。

# 数据集

import torchvision

import ssl
ssl._create_default_https_context = ssl._create_unverified_context

train_set = torchvision.datasets.CIFAR10(root="./test10_dataset", train=True, download=True)
test_set = torchvision.datasets.CIFAR10(root="./test10_dataset", train=False, download=True)

运行结果:

2、展示数据集里面的内容

# 数据集
import ssl
import torchvision
from torch.utils.tensorboard import SummaryWriter

ssl._create_default_https_context = ssl._create_unverified_context

dataset_transforms = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor()
])

train_set = torchvision.datasets.CIFAR10(root="./test10_dataset", train=True, transform=dataset_transforms, download=True)
test_set = torchvision.datasets.CIFAR10(root="./test10_dataset", train=False, transform=dataset_transforms, download=True)


print(test_set[1])
writer = SummaryWriter("test10_logs")
for i in range(10):
    img,target = test_set[i]
    writer.add_image("test_set",img,i)

writer.close()

# print(test_set[0])
# img, target = test_set[0]
# img.show()

结果展示:

小土堆pytorch学习笔记003 | 下载数据集dataset 及报错处理_第1张图片

3、DataLoader 的使用

https://pytorch.org/docs/stable/data.htmlicon-default.png?t=N7T8http://xn--dataloader-po3sm345a小土堆pytorch学习笔记003 | 下载数据集dataset 及报错处理_第2张图片

例子:

import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

# 准备的测试集
test_data = torchvision.datasets.CIFAR10(root="./test10_dataset", train=False, transform=torchvision.transforms.ToTensor())

# DataLoader()里面的参数:  shuffle:洗牌
test_loader = DataLoader(dataset=test_data, batch_size=64, shuffle=True, num_workers=0, drop_last=False)

# 长测试数据集中第一张图片及target
img, target = test_data[0]
print(img.shape)
print(target)

writer = SummaryWriter('logs_dataloader')
step = 0
for data in test_loader:
    imgs, targets = data
    writer.add_images("test_data", imgs, step)
    step = step + 1

writer.close()

结果展示:

小土堆pytorch学习笔记003 | 下载数据集dataset 及报错处理_第3张图片

你可能感兴趣的:(深度学习,人工智能,深度学习,机器学习,pytorch,python)