pytorch学习(5)——torchvision中的数据集使用

1 数据集Dataset下载与使用

进入pytorch官网(需要梯子):https://pytorch.org

Docs -> Torchvision

pytorch学习(5)——torchvision中的数据集使用_第1张图片

(1)点击Models and pre-trained weights可以下载已经训练好的神经网络。

(2)点击Datasets可以下载数据集。包括Image classification(图像分类)、Image detection or segmentation(图像检测和分割)、Optical Flow(光流)、Stereo Matching(立体匹配)、Image pairs(图像对)、Image captioning(图像描述、图像字幕)、Video classification(视频分类)、Video prediction(视频预测)、Base classes for custom datasets(自定义数据集的基类)和Transforms v2。

链接:https://pytorch.org/vision/stable/datasets.html

1.1 下载CIFAR10数据集

链接:https://www.cs.toronto.edu/~kriz/cifar.html

pytorch学习(5)——torchvision中的数据集使用_第2张图片

使用Python脚本下载:
root为安装路径。

import torchvision

train_set = torchvision.datasets.CIFAR10(root="G:\\dataset_CIFAR10",
                                         train=True,
                                         download=True)

test_set = torchvision.datasets.CIFAR10(root="G:\\dataset_CIFAR10",
                                         train=False,
                                         download=True)

或者使用迅雷下载。
pytorch学习(5)——torchvision中的数据集使用_第3张图片

1.2 调用数据集

python代码:

import torchvision

train_set = torchvision.datasets.CIFAR10(root="G:\\dataset_CIFAR10",
                                         train=True,
                                         download=False)

test_set = torchvision.datasets.CIFAR10(root="G:\\dataset_CIFAR10",
                                         train=False,
                                         download=False)

print("test_set[0]: " + str(test_set[0]))
print("test_set.classes: " + str(test_set.classes))

img, target = test_set[0]
print("img: " + str(img))
print("target: " + str(target))
print("test_set.classes[target]: " + str(test_set.classes[target]))
img.show()

输出结果:

test_set[0]: (, 3)
test_set.classes: [‘airplane’, ‘automobile’, ‘bird’, ‘cat’, ‘deer’, ‘dog’, ‘frog’, ‘horse’, ‘ship’, ‘truck’]
img:
target: 3
test_set.classes[target]: cat
Process finished with exit code 0
pytorch学习(5)——torchvision中的数据集使用_第4张图片

1.3 tensorboard内查看图像

python脚本代码:

import torchvision
from torch.utils.tensorboard import SummaryWriter

dataset_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor()
])

train_set = torchvision.datasets.CIFAR10(root="G:\\dataset_CIFAR10",
                                         train=True,
                                         transform=dataset_transform,
                                         download=False)

test_set = torchvision.datasets.CIFAR10(root="G:\\dataset_CIFAR10",
                                         train=False,
                                         transform=dataset_transform,
                                         download=False)

writer = SummaryWriter("p10")
for i in range(10):
    img, target = test_set[i]
    writer.add_image("test_set", img, i)

writer.close()

脚本运行后,点击Terminal,输入:

tensorboard --logdir=“p10”

进入网页即可观察到数据集内的图像。
pytorch学习(5)——torchvision中的数据集使用_第5张图片

2 数据加载器Dataloader使用

Dataset类似于扑克牌堆,Dataloader类似于洗牌抓牌,如果shuffle为True图片需要洗牌(每次图片顺序不一样),shuffle为False图片不需要洗牌(每次图片顺序一样)。num_workers多进程,一般设置为0(windows下大于0会出现问题?)

imagenet_data = torchvision.datasets.ImageNet('path/to/imagenet_root/')
data_loader = torch.utils.data.DataLoader(imagenet_data,
                                          batch_size=4,
                                          shuffle=True,
                                          num_workers=args.nThreads)

链接: https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader

pytorch学习(5)——torchvision中的数据集使用_第6张图片

python脚本代码:

import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

# 准备的测试数据集
test_data = torchvision.datasets.CIFAR10(root="G:\\Anaconda\\pycharm_pytorch\\learning_project\\dataset_CIFAR10",
                                         train=False,
                                         transform=torchvision.transforms.ToTensor(),
                                         download=False)

test_loader = DataLoader(dataset=test_data, batch_size=4, shuffle=True, num_workers=0, drop_last=False)

# 测试数据集中的第一张图片及target
img, target = test_data[0]
print(img.shape)
print(target)

writer = SummaryWriter("dataloader")
step = 1
for data in test_loader:
    imgs, targets = data
    print("imgs.shape: " + str(imgs.shape))
    print("targets: " + str(targets))
    writer.add_images("test_data", imgs, step)
    step = step + 1

writer.close()

注意:添加多张图片函数为add_images()。

脚本运行后,点击Terminal,输入:

tensorboard --logdir=“dataloader”

进入网页即可观察到数据集内的图像。
pytorch学习(5)——torchvision中的数据集使用_第7张图片

如果将batch_size修改为9,图片展示如下:
pytorch学习(5)——torchvision中的数据集使用_第8张图片

如果将shuffle设置为False,则不对数据集进行洗牌,python代码如下:

import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

# 准备的测试数据集
test_data = torchvision.datasets.CIFAR10(root="G:\\dataset_CIFAR10",
                                         train=False,
                                         transform=torchvision.transforms.ToTensor(),
                                         download=False)

test_loader = DataLoader(dataset=test_data, batch_size=64, shuffle=False, num_workers=0, drop_last=False)

writer = SummaryWriter("dataloader")

for epoch in range(2):
    step_new = 1
    for data in test_loader:
        imgs, targets = data
        writer.add_images("Epoch-{}".format(epoch), imgs, step_new)
        step_new = step_new + 1

writer.close()

图片展示如下:

pytorch学习(5)——torchvision中的数据集使用_第9张图片

你可能感兴趣的:(pytorch,学习,人工智能)