pytorch构建自己的数据集

ImageFolder可以使用的情况是,比如猫狗识别中,train一个文件夹下已经将猫和狗分为两个不同的文件夹了,那我们可以直接使用ImageFolder来包装成数据集。
但有时候遇到的数据集是train一个文件夹下包含了所有的猫狗图片,那我们就无法使用ImageFolder函数,此时我们可以自己构造一个dataset类。
核心是getitem函数,这个函数的主要功能是根据样本的索引,返回索引对应的一张图片的图像数据X与对应的标签Y,也就是返回一个对应的训练样本。
具体而言,getitem的实现思路比较简单,将索引idx转换为图片的路径,然后用PIL的Image包来读取图片数据,然后将数据用torchvision的transforms转换成tensor并且进行Resize来统一大小(给出的图片尺寸不一致)与归一化,这样一来就可以得到图像数据了。因为训练集中图片的文件名上面带有猫狗的标签,所以标签可以通过对图片文件名split后得到然后转成0,1编码。

class MyDataset(Dataset):
    def __init__(self,data_path:str,train=True,transform=None):
        self.data_path = data_path
        self.train_flag = train
        if transform is None:([    ……
            ])
        else:
            self.transform = transform
        self.path_dir = os.listdir(data_path)
    def __getitem__(self, idx:int):
        img_path = self.path_dir[idx]
        if self.train_flag is True:
            if img_path.split('.')[0] =='dog':
                label = 1
            else:
                label = 0
        else:
            label = int(img_path.split('.')[0])
        label = torch.as_tensor(label,dtype=torch.int64)
        img_path = os.path.join(self.data_path,img_path)
        image = Image.open(img_path)
        image = self.transform(image)
        return image,label
    def __len__(self)->int:
        return len(self.path_dir)
        ```

你可能感兴趣的:(pytorch构建自己的数据集)