【PyTorch自定义Dataloader步骤解析】

PyTorch自定义Dataloader步骤解析

  • 摘要
  • 1 所用数据集介绍
  • 2 自定义Dataloader
    • 2.1 读取txt文件路径和标签
    • 2.2 路径和label分别放入list中
    • 2.3 补充完整路径
    • 2.4 组合上面三步,写成一个class结构
    • 2.5 实例化dataloader
    • 2.6 验证一下所得dataloader

摘要

当我们用Pytorch训练深度学习模型时,通常需要将大量的数据集加载到模型中进行训练。而在Pytorch中,我们可以使用DataLoader来对数据进行批处理,以及对数据集进行随机排序、并行加载等操作。下图为dataloader作用示意图。
【PyTorch自定义Dataloader步骤解析】_第1张图片但是,在某些情况下,我们需要使用自定义的数据集,并需要根据自己的需求自定义DataLoader,这时候了解如何自定义dataloader就变得尤为重要。本文将介绍如何在Pytorch中自定义dataloader,以及如何处理自定义数据集。

1 所用数据集介绍

本文采用102花卉数据集(数据集同博客:https://blog.csdn.net/fly_ddaa/article/details/130071408),共包括两个文件夹(train_filelist、val_filelist)和两个.txt文件。
【PyTorch自定义Dataloader步骤解析】_第2张图片
本文不同之处在于花卉图片并未按照label值创建子文件夹从而方便读取,而放在大文件中。同时label值是存储在.txt文件中,因此我们需要单独进行处理才能生成Dataloader,如下图:
【PyTorch自定义Dataloader步骤解析】_第3张图片【PyTorch自定义Dataloader步骤解析】_第4张图片

2 自定义Dataloader

2.1 读取txt文件路径和标签

首先,我们定义一个加载标注文件的函数。它打开给定路径下的标注文件并将其读取为一个列表,其中每个元素是一个图像文件名及其相应的标签值。函数会将这些信息存储在一个Python字典中,以便在训练时可以方便地使用。值得注意的是,标注文件中的标签值必须以整数类型的numpy数组形式给出,因为这是PyTorch所要求的标签类型。

def load_annoation(ann_file):
    data_infos = {}
    with open(ann_file) as f:
        samples = [x.strip().split(' ') for x in f.readlines()]
        for filname,label in samples:
            data_infos[filname] = np.array(label,dtype=np.int64)
    return data_infos

学习时,我们可将返回值输出打印来更好理解该函数所进行的操作,代码如下:

data_infos = load_annoation('./flower_data/train.txt')
for key, value in data_infos.items():
    print(key, type(value))

结果如下:
【PyTorch自定义Dataloader步骤解析】_第5张图片

2.2 路径和label分别放入list中

非必须,但建议参照pytorch官方文档进行操作,代码如下:

image_name = list(data_infos.keys())
image_label = list(data_infos.values())

2.3 补充完整路径

DataLoader 并不会直接储存路径。在使用 DataLoader 加载数据时,需要提供一个数据集对象,该数据集对象中包含了每个数据样本的路径信息。代码如下:

basic_dir = './flower_data/'
train_dir = basic_dir + 'train_filelist'
val_dir = basic_dir + 'val_filelist'

合并成完成路径:

image_path = [os.path.join(train_dir,img) for img in image_name]

2.4 组合上面三步,写成一个class结构

这是一个自定义的PyTorch数据集类,其中root_dir是数据集根目录的路径,ann_file是包含图像文件名及其对应标签的文本文件的路径。其中from torch.utils.data import DataLoader,Dataset必须要写,FlowerDateset (Dataset)中FlowerDateset可以修改,Dataset不能修改。

from torch.utils.data import DataLoader,Dataset
class FlowerDateset(Dataset):
    def __init__(self,root_dir,ann_file,transform = None):
        self.root_dir = root_dir
        self.ann_file = ann_file
        self.transform = transform
        self.data_infos = self.load_annoation()
        self.image_path = [os.path.join(root_dir,img) for img in self.data_infos.keys()]
        self.label = [label for label in list(data_infos.values())]
    
    def load_annoation(self):
        data_infos = {}
        with open(self.ann_file) as f:
            samples = [x.strip().split(' ') for x in f.readlines()]
            for filname,label in samples:
                data_infos[filname] = np.array(label,dtype=np.int64)
        return data_infos
    
    def __getitem__(self,idx):
        image = Image.open(self.image_path[idx])
        label = self.label[idx]
        if self.transform:
            image = self.transform(image)
        label = torch.from_numpy(np.array(label))
        return image,label
    
    def __len__(self):
        return len(self.image_path)

在类初始化的时候,通过load_annotation方法读取文本文件中的图像文件名和对应标签信息,并存储在data_infos字典中,字典的键是图像文件名,值是对应的标签。

在__getitem__方法中,通过读取图像文件的路径,打开图像,并获取它的标签。如果数据集需要应用变换,则对图像应用变换。最后,将图像和标签以元组的形式返回。

在__len__方法中,返回数据集的图像数量,即image_path列表的长度。

2.5 实例化dataloader

train_dataset = FlowerDateset(root_dir=train_dir,ann_file='./flower_data/train.txt',transform=data_transforms['train'])
valid_dataset = FlowerDateset(root_dir=val_dir,ann_file='./flower_data/val.txt',transform=data_transforms['valid'])

train_dataloader = DataLoader(train_dataset,batch_size=64,shuffle=True)
val_dataloader = DataLoader(valid_dataset,batch_size=64,shuffle=True)

2.6 验证一下所得dataloader

第一种写法,需要用numpy转换为np.array

image,label = iter(train_dataloader).next()
sample = image[0].numpy().squeeze()
sample = sample.transpose(1,2,0)
sample *= np.array((0.229, 0.224, 0.225))
sample += np.array((0.485, 0.456, 0.406))
plt.imshow(sample)

第二种写法,无需用numpy转换为np.array

image,label = iter(val_dataloader).next()
sample = image[1].squeeze()
sample = sample.permute(1,2,0)
sample *= torch.tensor((0.229, 0.224, 0.225))
sample += torch.tensor((0.485, 0.456, 0.406))
plt.imshow(sample)

结果如下:
【PyTorch自定义Dataloader步骤解析】_第6张图片
关注博主,私聊可得完整代码和数据集!!

你可能感兴趣的:(深度学习,深度学习,人工智能,python)