【Pytorch】Load your own dataset

【Pytorch】Load your own dataset_第1张图片

pytorch 在载入数据时用torchvision.datasets.ImageFolder 配合 torch.utils.data.DataLoader 很方便,但是只能遍历图片和图片的标签,无法灵活的获取图片的其他信息,比如图片的名字,本文介绍如何定义自己的 ImageFolder,在使用 Dataloader 时实现获取图片名字的功能!


文章目录

  • 1 ImageFolder and DataLoader
  • 2 OwnFolder and DataLoader
  • 3 transforms


1 ImageFolder and DataLoader

以分类为例,用 pytorch 的 torchvision.datasets.ImageFolder 配合 torch.utils.data.DataLoader 即可对数据按类别进行读取、预处理、分成 batch

import torchvision
import torch

train_dataset = torchvision.datasets.ImageFolder(
    train_data_pth,
    transforms.Compose([
        transforms.Resize(input_size,interpolation=2), # resize
        transforms.ToTensor(), # ToTensor
        normalize,])) # Normalization

train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size, # set batchsize
    shuffle=False,
    num_workers=n_worker,
    pin_memory=True)

参考

  • pytorch训练自己图像分类数据集
  • 使用pytorch测试单张图片(test single image with pytorch)

ImageFoldertrain_data_pth 是存放数据集的文件夹,文件结构应该如下

train_data_pth
	class1
		xxx.jpg
		...
	class2
		xxx.jpg
		...
	...
	classn
		xxx.jpg
		...

Dataloader 的参数介绍如下

  • dataset:加载的数据集(Dataset对象)
  • batch_size:batch size
  • shuffle:是否将数据打乱
  • sampler: 样本抽样,后续会详细介绍
  • num_workers:使用多进程加载的进程数,0 代表不使用多进程
  • collate_fn: 如何将多个样本数据拼接成一个 batch,一般使用默认的拼接方式即可’
  • pin_memory:是否将数据保存在pin memory 区,pin memory 中的数据转到 GPU 会快一些
  • drop_last:dataset中的数据个数可能不是 batch_size 的整数倍,drop_last 为 True 会将多出来不足一个batch的数据丢弃

参考 pytorch之DataLoader()函数

官网中 Dataloader 的介绍如下(https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader)

【Pytorch】Load your own dataset_第2张图片
【Pytorch】Load your own dataset_第3张图片

在训练和测试时,可用如下循环来对数据进行操作

for batch_images, batch_labels in train_loader:
	pass

把数据集的文件夹建立好,直接调用 ImageFolderDataLoader 来进行数据的载入分批读取确实很方便,但是如果我们想知道哪些图片分类错误了, train_loader loader 中仅有 image(图片) 和 label 属性,没有 image name(图片名称) 属性,有些力不从心!

因此,我们可以自己写 ImageFolder 来实现读取 image、label、image name 的功能,当然熟悉这个流程后,以后可以进行更个性化的操作!

2 OwnFolder and DataLoader

自己写数据读取和预处理,来替代 torchvision.datasets.ImageFolder 的功能,具体实现如下 class Own_Dataset 所示

class Own_Dataset(Dataset):
    def __init__(self, image_label_list, transform=None):
        super().__init__()
        self.samples_list = image_label_list  # xxx.jpg class1 
        self.transform = transform  # pre-processing of data

    def __getitem__(self, index):
        img_name = self.samples_list[index][0] # absolute path of image name
        with open(img_name,"rb") as f:
            img = Image.open(f).convert("RGB") # load image
        label = self.samples_list[index][1] # image label
        if img is None:
            print(img_name)
        if self.transform is not None:
            img = self.transform(img)
        return img, label, img_name

    def __len__(self):
        return len(self.samples_list)

其中 image_label_list 为列表,存放着图片的绝对路径以及标签信息,格式如下

[(/train_data_pth/calss1/1.jpg,class1),
(/train_data_pth/calss1/2.jpg,class1),
...,
(/train_data_pth/calssn/m.jpg,classn)]

想实现更多功能,在 def __getitem__(self, index): 中定义即可,

__getitem__:实例[idx] 时触发

参考:【python】类(11)

配合 DataLoader 使用

train_loader = torch.utils.data.DataLoader(
    Own_Dataset(image_label_list=val_list,
               transform=transforms.Compose([
               	   transforms.Resize(input_size,interpolation=2), # resize
                   transforms.ToTensor(),
                   normalize,])),
    batch_size=test_batch_size,
    shuffle=False,
    num_workers=n_worker,
    pin_memory=True)

训练测试时,就可以访问图片,类别以及图片名信息了,如下所示

for batch_images, batch_labels,batch_names in train_loader:
	pass

3 transforms

下面介绍部分 torchvision.transforms 方法

更多的 torchvision.transforms 方法可以参考官网介绍

https://pytorch.org/docs/stable/torchvision/transforms.html

train_dataset = datasets.ImageFolder(
    train_data_pth,
    transforms.Compose([
        transforms.Resize(scale_size,interpolation=2),
        transforms.RandomRotation(5),
        transforms.ColorJitter(brightness=0.1,contrast=0.1,
                               saturation=0.1,hue=0.1),
        transforms.FiveCrop(input_size),
        transforms.Lambda(lambda crops: torch.stack([transforms.ToTensor()(crop) for crop in crops])),
        transforms.Lambda(lambda crops: torch.stack([transforms.Normalize(
            mean = [0.5,0.5,0.5],
            std = [0.5,0.5,0.5])(crop) for crop in crops]))
    ]))

input_size 和 scale_size 写成元组的形式,eg,(224,224) 和 (256,256)

Normalize 时注意 mean 和 std 一定要除以 255,值介于 0~1 之间

FiveCrop 或者 TenCrop 时,测试代码也需要进行相应的调整,如下

原来

out = net(batch_images)

现在

bs, ncrops, c, h, w, = batch_images.size()
result = net(batch_images.view(-1,c,h,w))
out = result.view(bs,ncrops,-1).mean(1)

你可能感兴趣的:(【Pytorch】Load your own dataset)