Pytorch 继承Dataset加载自己的数据集

1、应用场景

在使用 Pytorch 做分类任务的时候,一般会用自带的torchvision.datasets.ImageFolder()函数,但是这个对数据存储方式有要求,不一定适合自己,如果考虑加载自己的数据,就要考虑重写Dataset类了。

ImageFolder 对数据存储方式要求:
        root/dog/xxx.png
        root/dog/xxy.png
        root/dog/xxz.png
        ... ...
        root/cat/123.png
        root/cat/nsdf3.png
        root/cat/asd932_.png

2、定制自己的数据加载方式

一般我们不想来回移动数据,知道图片的路径即可,告诉模型在哪里自己去拿,是比较好的方式。所以我们只要继承Dataset类,重新实现一下即可。

大致方法可分为三步:
  1. 把图片的路径和label整理到文本中(什么文本都可以,方式也不限,但要方便自己解析)。
  2. 将数据信息,解析,并存到list中。
  3. 重新实现,__getitem__() 函数,读取每条数据和标签,并返回。
train.txt --(第1列是数据路径,第2列标签)
        root/dog/xxx.png	0
        root/dog/xxy.png	0
        root/dog/xxz.png	0
        root/cat/123.png	1
        root/cat/nsdf3.png	1
        root/cat/asd932_.png	1
具体代码实现
#!/usr/bin/python
# -*- coding: UTF-8 -*-

from PIL import Image
from torch.utils.data import Dataset
import torchvision.transforms as transforms
import torch


__all__ = ['MyDataset']


class MyDataset(Dataset):

    def __init__(self, dataPath, transform=None, target_transform=None):
        imgsPath = open(dataPath, 'r')
        imgs = []
        for line in imgsPath:
            line = line.rstrip()
            words = line.split()
            imgs.append((words[0], int(words[1])))
        self.imgs = imgs
        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        fn, label = self.imgs[index]
        img = Image.open(fn).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            label = self.transform(label)
        return img, label

    def __len__(self):
        return len(self.imgs)


if __name__ == '__main__':

    transform_train = transforms.Compose([transforms.Resize(256),  # 重置图像分辨率
                                          transforms.RandomResizedCrop(224),  # 随机裁剪
                                          transforms.RandomHorizontalFlip(),  # 以概率p水平翻转
                                          transforms.RandomVerticalFlip(),  # 以概率p垂直翻转
                                          transforms.ToTensor(),])
    trainset = MyDataset(dataPath='train.txt', transform=transform_train)  # 训练集
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=8, shuffle=True, num_workers=2, pin_memory=True)
    for step, (tx, ty) in enumerate(trainloader, 0):
        print('---test---', tx, ty)

声明: 总结学习,有问题或不当之处,可以批评指正哦,谢谢。

优秀的参考链接

[1]:https://github.com/tensor-yu/PyTorch_Tutorial
[2]:https://blog.csdn.net/u011995719/article/details/85102770

你可能感兴趣的:(Python,PyTorch)