Pytorch使用方法一自定义 dataset 时,需要重写 __len__
和 __getitem__
__len__
提供 dataset 的大小__getitem__
提供 dataset 的索引在 Python 对象中,需要重写的双下划线开头和结尾的属性称为特殊属性,常见的有对象的名称:__name__
。
另外对象的方法也属于属性,因此以双下划线开头和结尾的方法称为特殊方法,例如上述需要重写的 __len__
与 __getitem__
便是两个特殊方法。
常见的特殊属性和特殊方法,可见我的另一篇博客:Python系列 | 常见的特殊属性与特殊方法
示例数据集来源于:LFW人脸数据集
LFW人脸数据集以人名作为文件名,文件夹下为相对应的人脸图像:
图片均以 “人名_000x” 的形式命名,以Abdullah为例:
若读者对 os 模块不太熟悉,可参考我另一篇博客:Python系列 | os模块常用命令
由于数据集的特殊性,每张图片都处于二级文件下,因此在正式定义 dataset 之前,有必要对数据集进行一定处理,将所有图像整合至一个文件夹中,具体代码如下:
import os
import shutil
def make_file(path):
if os.path.exists(path):
os.rmdir(path)
os.mkdir(path)
else:
os.mkdir(path)
def main():
root = os.path.join(os.getcwd(), 'lfw')
image_file = os.listdir(root)
image_set = list()
for file in image_file:
image_path = os.path.join(root, file)
image_list = os.listdir(image_path)
for image in image_list:
image_set.append(os.path.join(image_path, image))
new_path = os.path.join(os.getcwd(), 'lfw_dataset')
make_file(new_path)
for path in image_set:
shutil.copy(path, new_path)
print('Done !')
if __name__ == '__main__':
main()
实现结果:
将所有图像汇总后,即可使用方法一或方法二定义 dataset 。
使用 torch.utils.data.Dataset
import torchvision.transforms as transforms
from torch.utils.data import Dataset
import os
from PIL import Image
import torch
import matplotlib.pyplot as plt
class LfwDataset(Dataset): # 继承Dataset,复写__getitem__和__len__
def __init__(self, root, transform):
self.root = root
self.images = [os.path.join(self.root, path) for path in os.listdir(self.root)] # 图像路径集合
self.transform = transform # transform
def __len__(self):
return len(self.images)
def __getitem__(self, item):
image_path = self.images[item] # 图像索引,获取单张图像路径
image = Image.open(image_path)
_, image_name = os.path.split(image_path)
label, _ = image_name.split('.')
label = label[:-5]
if self.transform is not None:
image = self.transform(image)
return image, label
if __name__ == '__main__':
transform = transforms.Compose([
transforms.ToTensor(),
transforms.RandomHorizontalFlip(0.5)
])
lfw_dataset = LfwDataset(root=r'.\lfw_dataset', transform=transform)
data_loader = torch.utils.data.DataLoader(lfw_dataset, batch_size=12, shuffle=True)
for i, (inputs, target) in enumerate(data_loader):
if i == 0: # 输出一部分看看
print(inputs.shape)
print(target)
plt.figure(figsize=(12, 16))
for num in range(12): # 确认图像是否可以正确读取
plt.subplot(3, 4, num + 1)
plt.imshow(inputs[num].permute([1, 2, 0]))
plt.title(target[num], size=13)
plt.axis('off')
plt.tight_layout()
plt.show()
else:
break
打印结果:
torch.Size([12, 3, 250, 250])
('Christopher_Conyers', 'Michael_Jackson', 'Edward_Johnson', 'Heizo_Takenaka', 'Ai_Sugiyama', 'Lawrence_MacAulay', 'Geno_Auriemma', 'Bustam_A_Zedan_Aljanabi', 'Colin_Powell', 'Hugh_Grant', 'Ellen_Martin', 'Billy_Sollie')
图像输出结果:
可见,已自定义完成 dataset 。
使用 torchvision.datasets.ImageFolder
方法二较方法一要更为方便,但 torchvision.datasets.ImageFolder
要求图片文件以下图格式进行排列:
也就是说,每个类别的图像要各自为一个文件夹,这也正好符合本示例 LFW 人脸数据集的特点。
这里还有几个注意点:
dataset.classes
中torch.utils.data.DataLoader
加载 dataset 时,其类别标签返回的是相应类别的索引,而非类别标签本身import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import os
import matplotlib.pyplot as plt
import torch
def main():
root = os.path.join(os.getcwd(), 'lfw')
transform = transforms.Compose([
transforms.ToTensor(),
transforms.RandomHorizontalFlip(0.5)
])
lfw_dataset = torchvision.datasets.ImageFolder(root=root, transform=transform)
lfw_dataloader = DataLoader(lfw_dataset, batch_size=12, shuffle=True)
for i, (inputs, target_index_set) in enumerate(lfw_dataloader):
if i == 0:
print(f'inputs.shape : {inputs.shape}')
print(f'target_index_set: {target_index_set}')
plt.figure(figsize=(12, 16))
for num in range(12):
plt.subplot(3, 4, num + 1)
plt.imshow(inputs[num].permute([1, 2, 0]))
plt.title(lfw_dataset.classes[target_index_set[num]], size=13)
plt.axis('off')
plt.tight_layout()
plt.show()
else:
break
if __name__ == '__main__':
main()
打印结果:
inputs.shape : torch.Size([12, 3, 250, 250])
target_index_set: tensor([ 59, 3995, 3620, 3092, 1155, 4900, 5564, 5639, 809, 5685, 1995, 4257])
图片输出结果:
可见,方法二的自定义 dataset 要方便很多,只是对数据的存储方式有一定要求。