PyTorch中的dataset pipeline (Pytorch 如何从数据集中读取数据的?)

数据集存储在我们的电脑硬盘中, Pytorch需要把这些数据从硬盘中读取并组成Pytorch能看懂的dataset形式。 然后用dataloader, 一个batch一个batch的从dataset中读取 并传入后续模型中。本文总结如何构建pytorch中的 dataset 以及如何用dataloader读取dataset.

从tensor中构建dataset

之前文章介绍了如何构建tensor, 那么有了tensor如何构建dataset呢?

在监督学习中, 会有一个Tensor储存 数据的feature, 另一个Tensor储存数据的label。 比如:

t_x = torch.rand([6, 5], dtype=torch.float32)
t_y = torch.arange(6)

t_x中储存了6个数据, 每个数据有5个feature.

t_y 储存了这6个数据的label.

接下来构建dataset, 让pytorch可以通过检索index的方法检索每个数据, 并让t_x与t_y一一配对, 就算打乱顺序t_x与t_y也能一一对应。

pytorch 中有个Dataset类, 只需构建一个Dataset的子类:

from torch.utils.data import Dataset 

class JointDataset(Dataset):
    def __init__(self, x, y):
        self.x = x 
        self.y = y 
        
    def __getitem__(self, idx):
        return self.x[idx], self.y[idx]
    
    def __len__(self):
        return len(self.x)

然后把t_x与t_y传入 上面的custom class 中就可以构建一个pytorch看的懂的dataset了

joint_dataset = JointDataset(t_x, t_y)

for example in joint_dataset:
    print(f"x:{example[0]}, y:{example[1]}")

x:tensor([0.3030, 0.3913, 0.1098, 0.1247, 0.1747]), y:0
x:tensor([0.6247, 0.4709, 0.7010, 0.3407, 0.2678]), y:1
x:tensor([0.1844, 0.7371, 0.8012, 0.9095, 0.6837]), y:2
x:tensor([0.8457, 0.1382, 0.6116, 0.7448, 0.4173]), y:3
x:tensor([0.4306, 0.2952, 0.8508, 0.7258, 0.5765]), y:4
x:tensor([0.4122, 0.2141, 0.5772, 0.9119, 0.8334]), y:5

用DataLoader读取构建的Dataset:

from torch.utils.data import DataLoader 

data_loader = DataLoader(dataset=joint_dataset, batch_size=3, shuffle=True)

Dataloader 会一个批次(batch)一个批次的从构建的dataset中读取, 这里设这的batch_size=3. 在读取数据前, 先将数据打乱:shuffle=True. 

训练模型的时候通长需要训练N个epoch, 即: 在现有的所有数据上训练的次数。在每个epoch中, 应用data_loader:

for epoch in range(2):
    print('\n')
    print(f'epoch {epoch+1}')
    for i, batch in enumerate(data_loader, start=1):
        print(f'batch {i}:, x:, {batch[0]},
              \n          y: {batch[1]}')


epoch 1
batch 1:, x:, tensor([[0.4122, 0.2141, 0.5772, 0.9119, 0.8334],
        [0.1844, 0.7371, 0.8012, 0.9095, 0.6837],
        [0.8457, 0.1382, 0.6116, 0.7448, 0.4173]]), 
          y: tensor([5, 2, 3])
batch 2:, x:, tensor([[0.3030, 0.3913, 0.1098, 0.1247, 0.1747],
        [0.6247, 0.4709, 0.7010, 0.3407, 0.2678],
        [0.4306, 0.2952, 0.8508, 0.7258, 0.5765]]), 
          y: tensor([0, 1, 4])


epoch 2
batch 1:, x:, tensor([[0.3030, 0.3913, 0.1098, 0.1247, 0.1747],
        [0.4306, 0.2952, 0.8508, 0.7258, 0.5765],
        [0.1844, 0.7371, 0.8012, 0.9095, 0.6837]]), 
          y: tensor([0, 4, 2])
batch 2:, x:, tensor([[0.4122, 0.2141, 0.5772, 0.9119, 0.8334],
        [0.8457, 0.1382, 0.6116, 0.7448, 0.4173],
        [0.6247, 0.4709, 0.7010, 0.3407, 0.2678]]), 
          y: tensor([5, 3, 1])

从硬盘中构建Dataset

比如我们有个图片数据集需要分类,比如BMW-10 dataset. 这个数据集有 11种BMW车 存储在11个文件夹下。 如何从硬盘中读取这些图片及其相应的label 并构建一个pytorch看的懂的 dataset呢? 

首先我们先用pathlib读取数据,并可视化一些图片: 

import pathlib 

imgdir_path = pathlib.Path('bmw10_ims')

image_list = sorted([str(path) for path in imgdir_path.rglob('*.jpg')])


['bmw10_ims/1/150079887.jpg', 'bmw10_ims/1/150080038.jpg', 'bmw10_ims/1/150080476.jpg', 
...,'bmw10_ims/8/149389446.jpg', 'bmw10_ims/8/149389742.jpg', 'bmw10_ims/8/149389834.jpg']
import matplotlib.pyplot as plt 
from PIL import Image 
import numpy as np

fig = plt.figure(figsize=(10, 5)) 
for i, file in enumerate(image_list[:6]): 
    img = Image.open(file)
    print('Image shape:', np.array(img).shape)
    ax = fig.add_subplot(2, 3, i+1)
    ax.set_xticks([]); ax.set_yticks([])
    ax.imshow(img)
    ax.set_title(pathlib.Path(file).name, size=15) 
plt.tight_layout()
plt.show()


Image shape: (480, 640, 3)
Image shape: (360, 424, 3)
Image shape: (768, 1024, 3)
Image shape: (768, 1024, 3)
Image shape: (360, 480, 3)
Image shape: (183, 275, 3)

PyTorch中的dataset pipeline (Pytorch 如何从数据集中读取数据的?)_第1张图片

构建图片的label: 

#Pathlib.Path("bmw10_ims/7/149461474.jpg").parts = ('bmw10_ims', '7', '149461474.jpg')
labels = list(pathlib.Path(path).parts[-2] for path in image_list)

print(labels)


['1', '1', '1', '1',... '8', '8', '8', '8']

 

现在来构建dataset: 

class ImageDataset(Dataset):
    def __init__(self, file_list, labels):
        self.file_list = file_list 
        self.labels = labels 
        
    def __getitem__(self, index):
        file = self.file_list[index]
        label = self.labels[index]
        return file, label  
    
    def __len__(self):
        return len(self.labels)
    
image_dataset = ImageDataset(image_list, labels)
for file, label in image_dataset:
    print(file, label)


bmw10_ims/1/150079887.jpg 1
bmw10_ims/1/150080038.jpg 1
...
bmw10_ims/5/149124761.thumb.jpg 5
bmw10_ims/5/149124940.jpg 5
...
bmw10_ims/8/149389742.jpg 8
bmw10_ims/8/149389834.jpg 8

一般需要对输入的图片进行pre-processing 比如 nomoralization, resize, crop等:

import torchvision.transforms as transforms 
img_height, image_width = 128, 128 
transform = transforms.Compose([transforms.ToTensor(), 
                                transforms.Resize((img_height, image_width)),
                               ])

一般把预处理放到dataset中: 

class ImageDataset(Dataset):
    def __init__(self, file_list, labels, transform=None):
        self.file_list = file_list 
        self.labels = labels 
        self.transform = transform 
        
    def __getitem__(self, index):
        img = Image.open(self.file_list[index])
        if self.transform is not None:
            img = self.transform(img)
        label = self.labels[index]
        return img, label 
    
    def __len__(self):
        return len(self.labels)

image_dataset = ImageDataset(image_list, labels, transform)

可视化这个Dataset:

fig = plt.figure(figsize=(10, 6))
for i, example in enumerate(image_dataset):
    if i == 6:
        break
    ax = fig.add_subplot(2, 3, i+1)
    ax.set_xticks([]); ax.set_yticks([])
    print(example[0].numpy().shape)
    ax.imshow(example[0].numpy().transpose((1, 2, 0)))
    ax.set_title(f'{example[1]}', size=15)
    
plt.tight_layout()
plt.show()


(3, 128, 128)
(3, 128, 128)
(3, 128, 128)
(3, 128, 128)
(3, 128, 128)
(3, 128, 128)

PyTorch中的dataset pipeline (Pytorch 如何从数据集中读取数据的?)_第2张图片

参考自:  Machine Learning with PyTorch and Scikit-Learn Book 

你可能感兴趣的:(pytorch,深度学习,人工智能)