pytorch_数据读取_dataset

一、Dataset是什么?

Dataset类似建立一个数组,建立数据集和数据标签之间的联系(就像数组下标和元素之间的联系)。

二、什么时候用Dataset?

1、引入datasets内部封装的数据集。

例如:CIFAR10是一个关于图片的数据,下面代码就是它的引入

data = datasets.CIFAR10("./data/", transform=transform, train=True, download=True)

2、引入自己的数据集,这里使用了使用ImageFolder这个API。

FaceDataset = datasets.ImageFolder('./data', transform=img_transform)

3、ImageFolder介绍

(1)ImageFolder的基本概念

ImageFolder对文件夹类型的数据集进行引入,这里文件夹内部存储的数据集要求是同一类型的图片。

(2)参数介绍

ImageFolder(root, transform=None, target_transform=None, loader=default_loader)
参数:
root:指定保持图片的文件夹路径
transform:对PIL Image进行的转换操作
target_transform:对label的转换

loader:给定路径后如何读取图片,默认读取为RGB格式的PIL Image对象

三、如何定义一个Dataset

1、我们定义的数据集都要继承自torch.utils.data.Dataset ,所以下面的代码步骤必不可少。

from torch.utils.data import Dataset
class MyData(Dataset):

2、我们这里使用的数据集一般是MAP数据集

(1)一个Map式的数据集必须要重写getitem(self, index),len(self) 两个内建方法。

(2)用来表示从索引到样本的映射(Map)。就是将你的数据集打包成一个数组,可以精准地定位长度,并随时可以返回特定位置的数据。

举个例子,当使用dataset[idx]命令时,可以在你的硬盘中读取你的数据集中第idx张图片以及其标签
(如果有的话);len(dataset)则会返回这个数据集的容量。

(3)定义的代码

class CustomDataset(data.Dataset):#需要继承data.Dataset
    def __init__(self):#而且这里的self标签可以建立self.xxx达到xxx全局化的作用
        # TODO
        # 1. Initialize file path or list of file names.
        pass
    def __getitem__(self, index):
        # TODO
        # 1. Read one data from file (e.g. using numpy.fromfile, PIL.Image.open).
        # 2. Preprocess the data (e.g. torchvision.Transform).
        # 3. Return a data pair (e.g. image and label).
        #这里需要注意的是,第一步:read one data,是一个data
        pass
    def __len__(self):
        # You should change 0 to the total size of your dataset.
        return 0

(4)例子,这里我们的图片文件储存在“./data/faces/”文件夹下,图片的名字并不是从1开始,而是从final_train_tag_dict.txt这个文件保存的字典中读取,label信息也是用这个文件中读取。大家可以照着上面的注释阅读这段代码

from torch.utils import data
import numpy as np
from PIL import Image


class face_dataset(data.Dataset):
    def __init__(self):
        self.file_path = './data/faces/'
        f=open("final_train_tag_dict.txt","r")
        self.label_dict=eval(f.read())
        f.close()

    def __getitem__(self,index):
        label = list(self.label_dict.values())[index-1]
        img_id = list(self.label_dict.keys())[index-1]
        img_path = self.file_path+str(img_id)+".jpg"
        img = np.array(Image.open(img_path))
        return img,label

    def __len__(self):
        return len(self.label_dict)

(5)一个完整的读取文件的例子


from torch.utils.data import Dataset
from PIL import Image
import os
class MyData(Dataset):

    def __init__(self,root_dir,label_dir):
        self.root_dir=root_dir
        self.label_dir=label_dir
        self.path=os.path.join(self.root_dir,self.label_dir)
        self.img_path =os.listdir(self.path)


    def __getitem__(self, idx):
        img_name = self.img_path[idx]
        img_item_path = os.path.join(self.root_dir,self.label_dir,img_name)#URL的拼接
        img=Image.open(img_item_path)
        label=self.label_dir
        return img,label
    def __len__(self):
        return len(self.img_path)


root_dir="练手数据集/val"
ant_label_dir="ants"
bees_label_dir="bees"
ants_dataset=MyData(root_dir,ant_label_dir)
bees_dataset=MyData(root_dir,bees_label_dir)
train_dataset = ants_dataset + bees_dataset #将两个数据集合并。
img,label=train_dataset[123]
img.show() #可以展示图片


这里的self在class类里面的作用就是def init(self,root_dir,label_dir)将初始化的root_dir和label_dir在class类内部进行公有化。
例如:接下来在__getitem__函数内部就直接进行使用了,img_item_path = os.path.join(self.root_dir,self.label_dir,img_name)这里就体现了它的公有化。

你可能感兴趣的:(pytorch,pytorch,python)