使用Dataset、Dataloader自定义数据集

pytorch中的数据pipeline设计:生产者消费者模式,分为sampler、dataset、dataloaderlter、dataloader四个抽象层次

1、sampler:(采样器)负责生成读取index序列  采样(可以自定义控制采样顺序)

2、dataset:负责根据index读取相应数据并执行预处理(负责处理索引(index)到样本(sample)映射的一个类(class))

3、dataloaderlter:负责协调多进程执行dataset

4、dataloader:最顶层的抽象 通过index找出一条数据出来 index——>record

 

理解:

索引——>(dataset)给出相应的x和y

实现dataset的方法:

1、map-style datasets(监督学习中样本往往是确定的)

2、iterable-style datasets(迭代器,用于样本不确定)

trochvision包含了 1.常用数据集;2.常用模型框架;3.数据转换方法。其中它提供的数据集就已经是一个Dataset类 了。torchvison.datasets就是专门提供各类常用数据集的模块。

CLASS torch.utils.data.DataLoader(dataset,batch_size=1,shuffle=None,sampler=None,batch_sampler=None,num_workers=0,collate_fn=None,pin_memory=False,drop_last=False,timeout=0,worker_init_fn=None,multiprocessing_context=None,generator=None,*,prefetch_factor=2,persistent_workers=False,pin_memory_device=")

dataset:输入的数据类型,单条记录

batch_size:批训练数据量的大小,pytorch训练模型输入的数据量大小,为1时一行一行输入

num_workers:进程数

collate_fn:将一小段数据合并成数据列表,默认设置是False。如果设置成True,系统会在返回前会将张量数据(Tensors)复制到CUDA内存中


1、Dataset与自定义数据集的关系

Dataset:父类,是所有开发人员训练、测试使用的所有数据集的一个模板或是抽象

自定义的数据集:子类,具体的数据集,继承Dataset父类的所有方法和属性

class Dataset(object):
    """An abstract class representing a Dataset.
    All other datasets should subclass it. All subclasses should override
    ``__len__``, that provides the size of the dataset, and ``__getitem__``,
    supporting integer indexing in range from 0 to len(self) exclusive.
    """
 
    def __getitem__(self, index):
        raise NotImplementedError
 
    def __len__(self):
        raise NotImplementedError
 
    def __add__(self, other):
        return ConcatDataset([self, other])

2、Dataset重写原理

需要自定义三个函数:_init_  、_getitem_ 、_len_  (getitem和len是子类必须继承的)

def _init_ :初始化,把数据作为一个参数传给类

def _getitem_:根据索引获取样本对(x,y) 索引为(0,len(dataset)-1),根据数据集长度从0开始的索引序列;模型通过这个函数获取一对样本对

def _len_:表示数据集的长度,最终训练时用到的数据集的样本个数

# 自定义Dataset的基本模板
class ExampleDataset(Dataset): #自定义一个类
    def __init__(self, data): #初始化,把数据作为一个参数传递给类;
        self.data = data
     def __len__(self):
         return len(self.data)  #返回数据的长度
    
    def __getitem__(self, idx):
         x= ...
         y= ...
         return x, y
        #return self.data[idx]  #根据索引返回数据


3、自定义Dataset

完整代码:

ps:root路径可以转换成其他数据集

(1)从root开始,传入数据集的目录

我选择的是一些指纹图片,分成了三个文件夹,图片是png;每个文件夹下放一些照片

使用Dataset、Dataloader自定义数据集_第1张图片

(2)进入read_split_data函数,对数据集进行划分,将其分为(train_images_path, train_images_label, val_images_path, val_images_label)返回

数据(x)和标签(y)进行标识

 会遍历文件夹,一个文件夹为一个类别

(3)进入MyDataSet自定义数据集

import os
import torch
from torchvision import transforms
from PIL import Image
import torch
from torch.utils.data import Dataset
import os
import json
import pickle
import random
import matplotlib.pyplot as plt

def read_split_data(root: str, val_rate: float = 0.2):    #val_rate划分验证集所占所有用户的比例
    random.seed(0)  # 保证随机结果可复现
    assert os.path.exists(root), "dataset root: {} does not exist.".format(root)#判断路径是否存在

    # 遍历文件夹,一个文件夹对应一个类别
    finger_class = [cla for cla in os.listdir(root) if os.path.isdir(os.path.join(root, cla))]
    # 排序,保证顺序一致
    finger_class.sort()
    # 生成类别名称以及对应的数字索引
    class_indices = dict((k, v) for v, k in enumerate(finger_class))
    json_str = json.dumps(dict((val, key) for key, val in class_indices.items()), indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    train_images_path = []  # 存储训练集的所有图片路径
    train_images_label = []  # 存储训练集图片对应索引信息
    val_images_path = []  # 存储验证集的所有图片路径
    val_images_label = []  # 存储验证集图片对应索引信息
    every_class_num = []  # 存储每个类别的样本总数
    supported = [".jpg", ".JPG", ".png", ".PNG"]  # 支持的文件后缀类型
    # 遍历每个文件夹下的文件
    for cla in finger_class:
        cla_path = os.path.join(root, cla)
        # 遍历获取supported支持的所有文件路径
        images = [os.path.join(root, cla, i) for i in os.listdir(cla_path)
                  if os.path.splitext(i)[-1] in supported]
        # 获取该类别对应的索引
        image_class = class_indices[cla]
        # 记录该类别的样本数量
        every_class_num.append(len(images))
        # 按比例随机采样验证样本
        val_path = random.sample(images, k=int(len(images) * val_rate))

        for img_path in images:
            if img_path in val_path:  # 如果该路径在采样的验证集样本中则存入验证集
                val_images_path.append(img_path)
                val_images_label.append(image_class)
            else:  # 否则存入训练集
                train_images_path.append(img_path)
                train_images_label.append(image_class)

    print("{} images were found in the dataset.".format(sum(every_class_num)))
    print("{} images for training.".format(len(train_images_path)))
    print("{} images for validation.".format(len(val_images_path)))

    plot_image = False
    #plot_image = True
    if plot_image:
        # 绘制每种类别个数柱状图
        plt.bar(range(len(finger_class)), every_class_num, align='center')
        # 将横坐标0,1,2,3,4替换为相应的类别名称
        plt.xticks(range(len(finger_class)), finger_class)
        # 在柱状图上添加数值标签
        for i, v in enumerate(every_class_num):
            plt.text(x=i, y=v + 5, s=str(v), ha='center')
        # 设置x坐标
        plt.xlabel('image class')
        # 设置y坐标
        plt.ylabel('number of images')
        # 设置柱状图的标题
        plt.title('finger class distribution')
        plt.show()

    return train_images_path, train_images_label, val_images_path, val_images_label


from PIL import Image
import torch
from torch.utils.data import Dataset


class MyDataSet(Dataset):
    """自定义数据集"""

    def __init__(self, images_path: list, images_class: list, transform=None):
        self.images_path = images_path
        self.images_class = images_class
        self.transform = transform

    def __len__(self):   #训练集的样本个数
        return len(self.images_path)

    def __getitem__(self, item):   #传入一个索引
        img = Image.open(self.images_path[item])   #self.images_path[item]获得对应的一张图片的路径
        # RGB为彩色图片,L为灰度图片
        #if img.mode != 'RGB':
        #    raise ValueError("image: {} isn't RGB mode.".format(self.images_path[item]))
        img = img.convert('RGB')  #修改的,把图片变为rgb
        label = self.images_class[item]

        if self.transform is not None:
            img = self.transform(img)

        return img, label

    @staticmethod
    def collate_fn(batch):
        # 官方实现的default_collate可以参考
        # https://github.com/pytorch/pytorch/blob/67b7e751e6b5931a9f45274653f4f653a4e6cdf6/torch/utils/data/_utils/collate.py
        images, labels = tuple(zip(*batch))

        images = torch.stack(images, dim=0)
        labels = torch.as_tensor(labels)
        return images, labels


root = "/gemini/finger/data"  # 数据集所在根目录
def main():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))

    train_images_path, train_images_label, val_images_path, val_images_label = read_split_data(root) #划分训练、验证集

    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
        "val": transforms.Compose([transforms.Resize(256),
                                   transforms.CenterCrop(224),
                                   transforms.ToTensor(),
                                   transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}
    #传入训练集所有图像路径的列表、所有样本的标签信息、预处理方法(transform)
    train_data_set = MyDataSet(images_path=train_images_path,
                               images_class=train_images_label,
                               transform=data_transform["train"])

    batch_size = 8
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using {} dataloader workers'.format(nw))
    train_loader = torch.utils.data.DataLoader(train_data_set,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               num_workers=nw,
                                               collate_fn=train_data_set.collate_fn)

    #plot_data_loader_image(train_loader)

    for step, data in enumerate(train_loader):
        images, labels = data


if __name__ == '__main__':
    main()

 

你可能感兴趣的:(联邦学习,python,开发语言)