Datasets live on our computer's hard drive. PyTorch needs to read this data from disk and organize it into a Dataset that PyTorch can understand; a DataLoader then reads from that dataset batch by batch and feeds the batches into the downstream model. This post summarizes how to build a PyTorch dataset and how to read it with a DataLoader.
A previous post covered how to construct tensors. Given tensors, how do we build a dataset from them?
In supervised learning, one tensor stores the features of the data and another tensor stores the labels. For example:
import torch

t_x = torch.rand([6, 5], dtype=torch.float32)
t_y = torch.arange(6)
t_x stores 6 samples, each with 5 features.
t_y stores the labels of these 6 samples.
Next we build a dataset so that PyTorch can look up each sample by index, with t_x and t_y paired one-to-one, so that even after shuffling every feature row still matches its label.
PyTorch provides a Dataset class; we only need to write a subclass of Dataset:
from torch.utils.data import Dataset

class JointDataset(Dataset):
    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __getitem__(self, idx):
        # return the (feature, label) pair at position idx
        return self.x[idx], self.y[idx]

    def __len__(self):
        # number of samples in the dataset
        return len(self.x)
Then pass t_x and t_y into the custom class above to obtain a dataset that PyTorch understands:
joint_dataset = JointDataset(t_x, t_y)

for example in joint_dataset:
    print(f"x:{example[0]}, y:{example[1]}")
x:tensor([0.3030, 0.3913, 0.1098, 0.1247, 0.1747]), y:0
x:tensor([0.6247, 0.4709, 0.7010, 0.3407, 0.2678]), y:1
x:tensor([0.1844, 0.7371, 0.8012, 0.9095, 0.6837]), y:2
x:tensor([0.8457, 0.1382, 0.6116, 0.7448, 0.4173]), y:3
x:tensor([0.4306, 0.2952, 0.8508, 0.7258, 0.5765]), y:4
x:tensor([0.4122, 0.2141, 0.5772, 0.9119, 0.8334]), y:5
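Note that for a pure-tensor case like this one, PyTorch also ships a built-in torch.utils.data.TensorDataset that does the same index-based pairing, so the custom class is not strictly required here; a minimal sketch:

from torch.utils.data import TensorDataset

# pairs t_x and t_y by index, equivalent to JointDataset above
tensor_dataset = TensorDataset(t_x, t_y)
print(tensor_dataset[0])   # (first feature row, first label)

The rest of this post keeps using the custom JointDataset.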
from torch.utils.data import DataLoader
data_loader = DataLoader(dataset=joint_dataset, batch_size=3, shuffle=True)
The DataLoader reads from the dataset we built one batch at a time; here we set batch_size=3. With shuffle=True, the data is shuffled before it is read.
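Before writing the full training loop, it can help to pull out a single batch and check its shape; a quick sketch (not part of the original code):

batch_x, batch_y = next(iter(data_loader))
print(batch_x.shape)   # torch.Size([3, 5]): 3 samples per batch, 5 features each
print(batch_y.shape)   # torch.Size([3])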
When training a model we usually train for N epochs, i.e. N passes over all of the available data. In each epoch we iterate over data_loader:
for epoch in range(2):
    print('\n')
    print(f'epoch {epoch+1}')
    for i, batch in enumerate(data_loader, start=1):
        print(f'batch {i}:, x:, {batch[0]}, \n y: {batch[1]}')
epoch 1
batch 1:, x:, tensor([[0.4122, 0.2141, 0.5772, 0.9119, 0.8334],
[0.1844, 0.7371, 0.8012, 0.9095, 0.6837],
[0.8457, 0.1382, 0.6116, 0.7448, 0.4173]]),
y: tensor([5, 2, 3])
batch 2:, x:, tensor([[0.3030, 0.3913, 0.1098, 0.1247, 0.1747],
[0.6247, 0.4709, 0.7010, 0.3407, 0.2678],
[0.4306, 0.2952, 0.8508, 0.7258, 0.5765]]),
y: tensor([0, 1, 4])
epoch 2
batch 1:, x:, tensor([[0.3030, 0.3913, 0.1098, 0.1247, 0.1747],
[0.4306, 0.2952, 0.8508, 0.7258, 0.5765],
[0.1844, 0.7371, 0.8012, 0.9095, 0.6837]]),
y: tensor([0, 4, 2])
batch 2:, x:, tensor([[0.4122, 0.2141, 0.5772, 0.9119, 0.8334],
[0.8457, 0.1382, 0.6116, 0.7448, 0.4173],
[0.6247, 0.4709, 0.7010, 0.3407, 0.2678]]),
y: tensor([5, 3, 1])
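As a side note, the __len__ method we defined is what lets the DataLoader work out how many batches make up one epoch; a quick check (the numbers in the comments assume the 6-sample dataset above):

print(len(joint_dataset))   # 6 samples
print(len(data_loader))     # 2 batches per epoch with batch_size=3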
Now suppose we have an image classification task, for example on the BMW-10 dataset, which contains images of 10 BMW models, each stored in its own folder. How do we read these images and their labels from disk and build a dataset that PyTorch understands?
First we read the file paths with pathlib and visualize a few of the images:
import pathlib
imgdir_path = pathlib.Path('bmw10_ims')
image_list = sorted([str(path) for path in imgdir_path.rglob('*.jpg')])
print(image_list)
['bmw10_ims/1/150079887.jpg', 'bmw10_ims/1/150080038.jpg', 'bmw10_ims/1/150080476.jpg',
...,'bmw10_ims/8/149389446.jpg', 'bmw10_ims/8/149389742.jpg', 'bmw10_ims/8/149389834.jpg']
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np

fig = plt.figure(figsize=(10, 5))
for i, file in enumerate(image_list[:6]):
    img = Image.open(file)
    print('Image shape:', np.array(img).shape)
    ax = fig.add_subplot(2, 3, i+1)
    ax.set_xticks([]); ax.set_yticks([])
    ax.imshow(img)
    ax.set_title(pathlib.Path(file).name, size=15)
plt.tight_layout()
plt.show()
Image shape: (480, 640, 3)
Image shape: (360, 424, 3)
Image shape: (768, 1024, 3)
Image shape: (768, 1024, 3)
Image shape: (360, 480, 3)
Image shape: (183, 275, 3)
Next, build the labels for the images from the folder names:
# pathlib.Path("bmw10_ims/7/149461474.jpg").parts == ('bmw10_ims', '7', '149461474.jpg')
labels = list(pathlib.Path(path).parts[-2] for path in image_list)
print(labels)
['1', '1', '1', '1',... '8', '8', '8', '8']
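The labels here are folder names kept as strings. For training a classifier one usually maps them to integer class indices; a minimal sketch (class_to_idx and int_labels are assumed helper names, not part of the original code):

# hypothetical mapping from folder name to integer class index
class_to_idx = {name: idx for idx, name in enumerate(sorted(set(labels)))}
int_labels = [class_to_idx[name] for name in labels]
print(int_labels[:4])   # e.g. [0, 0, 0, 0] for the first folder

For simplicity, the dataset below keeps the original string labels.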
Now build the dataset:
class ImageDataset(Dataset):
    def __init__(self, file_list, labels):
        self.file_list = file_list
        self.labels = labels

    def __getitem__(self, index):
        file = self.file_list[index]
        label = self.labels[index]
        return file, label

    def __len__(self):
        return len(self.labels)

image_dataset = ImageDataset(image_list, labels)

for file, label in image_dataset:
    print(file, label)
bmw10_ims/1/150079887.jpg 1
bmw10_ims/1/150080038.jpg 1
...
bmw10_ims/5/149124761.thumb.jpg 5
bmw10_ims/5/149124940.jpg 5
...
bmw10_ims/8/149389742.jpg 8
bmw10_ims/8/149389834.jpg 8
The input images usually also need preprocessing, such as normalization, resizing, and cropping:
import torchvision.transforms as transforms

img_height, img_width = 128, 128
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((img_height, img_width)),
])
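For the normalization mentioned above, transforms.Normalize can be appended to the pipeline; a sketch, assuming the commonly used ImageNet channel statistics (the mean/std values and the name transform_norm are assumptions, not from the original):

transform_norm = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((img_height, img_width)),
    # per-channel normalization with (assumed) ImageNet mean/std
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

The code below keeps the un-normalized transform so the visualized images display naturally.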
The preprocessing (together with loading the image) is usually put inside the dataset itself:
class ImageDataset(Dataset):
    def __init__(self, file_list, labels, transform=None):
        self.file_list = file_list
        self.labels = labels
        self.transform = transform

    def __getitem__(self, index):
        img = Image.open(self.file_list[index])
        if self.transform is not None:
            img = self.transform(img)
        label = self.labels[index]
        return img, label

    def __len__(self):
        return len(self.labels)

image_dataset = ImageDataset(image_list, labels, transform)
Visualize this dataset:
fig = plt.figure(figsize=(10, 6))
for i, example in enumerate(image_dataset):
    if i == 6:
        break
    ax = fig.add_subplot(2, 3, i+1)
    ax.set_xticks([]); ax.set_yticks([])
    print(example[0].numpy().shape)
    ax.imshow(example[0].numpy().transpose((1, 2, 0)))
    ax.set_title(f'{example[1]}', size=15)
plt.tight_layout()
plt.show()
(3, 128, 128)
(3, 128, 128)
(3, 128, 128)
(3, 128, 128)
(3, 128, 128)
(3, 128, 128)
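Finally, this image dataset can be wrapped in a DataLoader just like the tensor dataset earlier. One caveat: the labels here are still strings, so the default collate function would return each batch's labels as a tuple of strings; the sketch below assumes the integer labels (int_labels) from the earlier mapping, which are not part of the original code:

image_dataset = ImageDataset(image_list, int_labels, transform)
image_loader = DataLoader(dataset=image_dataset, batch_size=32, shuffle=True)

for batch_imgs, batch_labels in image_loader:
    print(batch_imgs.shape)    # e.g. torch.Size([32, 3, 128, 128])
    print(batch_labels.shape)  # e.g. torch.Size([32])
    break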
Reference: Machine Learning with PyTorch and Scikit-Learn (book).