Pytorch 构建自己的数据集 输入与标签皆为图片

 

构建数据集大概步骤为:用各种方法(cv2,PIL,skimage等等)读取成对图片→通过某方法,返回类似(输入,标签)形式→转为tensor形式→传入Dataloader形成映射。

 

现在我需要一个成对图片的数据集,即输入与标签都是图片,且输入与标签在命名上完全相同。Pytorch中则是使用TORCH.UTILS.DATA下的Dataloader方法来构建数据集,即:

myDataset = torch.utils.data.DataLoader(dataset)

此处dataset支持的数据集形式之一便是映射(Map-style datasets)形式,这也正是我的数据集所需要的样子。而为了构建这样形式的数据集,需要对torch.utils.data中的Dataset类进行继承,并覆写__getitem__()方法,该方法的作用是对于给定的键值(key)返回对应的数据。当然,也可以根据需要覆写__len__()方法,用于得到数据集的长度(数量)。

 

 

那么接下来的事情就很明确了,构造一个继承torch.utils.data.Data的类,其返回值的形式为(input,label):

import torch
import torchvision
from torch.utils.data import Dataset
import os
from PIL import Image
from matplotlib import pyplot as plt

#对读取的图片采取的处理方法,详情自行搜索transforms的用法
transforms_imag=torchvision.transforms.Compose([torchvision.transforms.Resize([64,64]),
                                                torchvision.transforms.ToTensor()])
#输入与标签图片所在的目录
input_root='./pic/input/'
label_root='./pic/label/'

class MyDataset(Dataset):#继承了Dataset子类
    def __init__(self,input_root,label_root,transform=None):
        #分别读取输入/标签图片的路径信息
        self.input_root=input_root
        self.input_files=os.listdir(input_root)#列出指定路径下的所有文件

        self.label_root=label_root
        self.label_files=os.listdir(label_root)

        self.transforms=transform
    def __len__(self):
        #获取数据集大小
        return len(self.input_files)
    def __getitem__(self, index):
        #根据索引(id)读取对应的图片
        input_img_path=os.path.join(self.input_root,self.input_files[index])
        input_img=Image.open(input_img_path)
        #视频教程使用skimage来读取的图片,但我在之后使用transforms处理图片时会报错
        #所以在此我修改为使用PIL形式读取的图片

        label_img_path=os.path.join(self.label_root,self.label_files[index])
        label_img=Image.open(label_img_path)

        if self.transforms:
            #transforms方法如果有就先处理,然后再返回最后结果
            input_img=self.transforms(input_img)
            label_img=self.transforms(label_img)
        
        return (input_img,label_img)#返回成对的数据

        ###把以下代码放在return前,反注释后运行
        # # test only for PIL#
        # input_img.show()
        # label_img.show()
        # # test only for PIL#



接下来便是运行了:

dataset_train=MyDataset(input_root, label_root, transform=transforms_imag)
trainloader=torch.utils.data.DataLoader(dataset_train)

至此,一个自己的数据集便构建完成了,根据我在transforms里的处理,该数据集的图片皆被转换为了pytorch可以处理的tensor格式,并把图片尺寸修改为64x64大小。

 

可以使用以下方法读取数据集中的数据:

for b_index, (data,label) in enumerate(trainloader):
    x = data
    y = label

以上就是构建一个自己的数据集的过程了,这个方法也比较通用,应该会有不错的泛用性。

 

关于ImageFolder

有这么个读取图片的方法,torchvision.datasets.ImageFolder,因为看起来很美好,所以一直在纠结怎么用它来一步到位,但喂入Dataloader后根本不是我想要的结果。该方法大概的意思是会根据你的文件夹来为图片自动添加标签。


最后倒推一下,整体思路是这样的:

1.Pytorch中使用卷积等运算时,比如torch.nn.Conv2d,它要求的输入格式为(N,C,H,W),其中N代表batch size。

2.输入格式是一个四维的量,而平时读取图片的方法能够获得的只有C,H,W这三个量,所以得想办法在原有的(C,H,W)上再多加一维。

3.此时Pytorch的torch.utils.data.DataLoader就提供了该方法,透过对其中参数batch_size的设置(默认为1),这样图片的格式就可以转换为带有N这一维了。

4.为了使用Dataloader(dataset),要将输入的dataset改成符合标准的格式。接着便是根据对dataset的注解来自行写相对应的类方法。

5.同时也别忘记pytorch计算时是使用的tensor(张量)格式的数据,所以在读取图片后要记得转换格式。上面代码中transforms的方法就包含了这一步。

 


参考:

1.如何在Pytorch中自建图片数据集(油管)

2.文档对torch.utils.data.Dataloader的解释

3.文档对torch.utils.data.Dataset的解释

4.文档对torchvision.datasets.ImageFolder的解释

5.文档对torch.nn.Conv2d的解释

 

你可能感兴趣的:(python,pytorch)