torchvision中数据集的使用

Torchvision.dataset

import torchvision
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms

# ctrl+p 查看参数
# torchvision.datasets当中提供了很多数据集,具体参数或其他信息可以官网查看
# 也可以直接查看文件中数据集的相关代码,里面会有数据集的下载链接

dataset_transform = transforms.Compose([
    transforms.ToTensor()
])

train_set = torchvision.datasets.CIFAR10(root='./data', train=True, transform=dataset_transform, download=False)
test_set = torchvision.datasets.CIFAR10(root='./data', train=False, transform=dataset_transform, download=False)

# print(test_set[0])
# print(test_set.classes)
# # 数据集返回的是图片和对应标签序号,标签名在.classes中
# img, target = test_set[0]
# print(img)
# print(target)
# print(test_set.classes[target])

writer = SummaryWriter('p4')
for i in range(10):
    img, target = test_set[i]
    writer.add_image('test_set', img, i)

writer.close()

你可能感兴趣的:(pytorch,pytorch,python)