Pytorch:如何定义自己创建的数据集

本文为个人知识学习的记录,未来可以复习回顾。

在Pytorch中定义数据集主要涉及到两个主要的类:

  • Datasets

  • DataLoader

1 Datasets

1.1 什么是Datasets?

Datasets是我们用的数据集的库,pytorch自带多种数据集,如Cifar10、MNIST等

1.2 为什么要定义Datasets?

Pytorch中有工具函数torch.utils.Data.DataLoader,通过这个函数我们在准备加载数据集使用mini-batch的时候可以使用多线程并行处理,这样可以加快我们准备数据集的速度。Datasets就是构建这个工具函数的实例参数之一。

1.3 如何定义Datasets?

Dataset类是Pytorch中图像数据集中最为重要的一个类,也是Pytorch中所有数据集加载类中应该继承的父类。其中父类中的两个私有成员函数必须被重载,否则将会触发错误提示:

def __getitem__(self,index):

def __len__(self):
  • __len__:返回数据集的大小
  • __getitem__:编写支持数据集索引的函数

注:

重点是 getitem函数,getitem接收一个index,然后返回图片数据和标签,这个index通常指的是一个list的index,这个list的每个元素就包含了图片数据的路径和标签信息。

如何制作list?

通常的方法是将图片的路径和标签信息存储在一个txt中,然后从该txt中读取。

1.3.1 读取数据的基本流程

  • 制作存储了图片的路径和标签信息的txt
  • 将这些信息转化为list,该list每一个元素对应一个样本
  • 通过getitem函数,读取数据和标签,并返回数据和标签

1.3.2 Datasets的整体框架

from torchvision.utils.data import Dataset

class MyDataset(Dataset):#需要继承Dataset
    def __init__(self):
        # TODO
        # 1. 初始化文件路径或文件名列表。
        #也就是在这个模块里,我们所做的工作就是初始化该类的一些基本参数。
        pass
    def __getitem__(self, index):
        # TODO

        #1.从文件中读取一个数据(例如,使用numpy.fromfile,PIL.Image.open)。
        #2.预处理数据(例如torchvision.Transform)。
        #3.返回数据对(例如图像和标签)。
        #这里需要注意的是,第一步:read one data,是一个data
        pass
    def __len__(self):
        # 应该将0更改为数据集的总大小。

1.4 栗子

此栗子为图像分割实验中创建自己的数据集:

import os
import cv2
import numpy as np
from torch.utils.data import Dataset


# 训练集
class TrainDataset(Dataset):  # 继承Dataset
    def __init__(self, data_path, transform=None):  # 制作一个list,将图片路径以及标签信息存储在一个txt中
        # 读取数据
        self.images = os.listdir(data_path + '/images')
        self.labels = os.listdir(data_path + '/masks')

        # 查看图片与标签数量是否一致
        assert len(self.images) == len(self.labels), 'Number does not match'

        self.transform = transform  # 定义做何种变换

        # 下面这才是重点,将数据与标签连接在一起,为了之后可以索引
        # 构建list列表
        self.images_and_labels = []  # 创建一个空列表
        for i in range(len(self.images)):  # 往空列表里装东西,为了之后索引
            self.images_and_labels.append(
                (data_path + '/images/' + self.images[i], data_path + '/masks/' + self.labels[i])
            )

    def __getitem__(self, index):  # 读取数据和标签,并返回数据和标签
        # 读取数据
        image_path, label_path = self.images_and_labels[index]
        # 图像处理
        image = cv2.imread(image_path)  # 读取图像,(H,W,C)
        image = cv2.resize(image, (224, 224))  # 将图像尺寸变为 224*224

        # 对标签进行处理,从而可以与结果对比,得到损失值
        label = cv2.imread(label_path, 0)  # 读取标签,且为灰度图
        label = cv2.resize(label, (224, 224))  # 将标签尺寸变为 224*224
        # 由于是二值分类,所以有以下操作.背景为0, 目标为1
        label = label / 255  # 调整数值范围为[0,1.0],, 因为神经网络中也做同样处理
        label = label.astype('uint8')  # 因为转换为了整型,只需要整数部分,所以数值不为1的,全置0。

        # one-hot编码
        label = np.eye(2)[label]  # 此处矩阵由二维变成了三维度
        label = np.array(list(map(lambda x: abs(x-1), label))).astype('float32')  # 0变为1,1变为0
        label = label.transpose(2, 0, 1)  # (H,W,C) => (C,H,W)

        if self.transform is not None:
            image = self.transform(image)
        return image, label  # 返回索引

    def __len__(self):  # 必须写,返回数据集的长度
        return len(self.images)


# 测试集
class TestDataset(Dataset):
    def __init__(self, data_path, transform=None):
        self.images = os.listdir(data_path + '/images')
        self.transform = transform
        self.imgs = []
        for i in range(len(self.images)):
            # self.imgs.append(data_path + '/images/' + self.images[i])
            self.imgs.append(os.path.join(data_path, 'images/', self.images[i]))

    def __getitem__(self, item):
        img_path = self.imgs[item]
        img = cv2.imread(img_path)
        img = cv2.resize(img, (224, 224))

        if self.transform is not None:
            img = self.transform(img)
        return img

    def __len__(self):
        return len(self.images)


if __name__ == '__main__':
    img = cv2.imread('../data/train/masks/150.jpg', 0)
    img = cv2.resize(img, (16, 16))
    img2 = img / 255
    cv2.imshow('pic1', img2)
    cv2.waitKey()
    print(img2)

    img3 = img2.astype('uint8')
    cv2.imshow('pic2', img3)
    cv2.waitKey()
    print(img3)

    # 下面开始矩阵就变成了3维
    hot1 = np.eye(2)[img3]  # 对标签矩阵的每个元素都做了编码,(0,1)背景元素,(1,0)目标元素
    print(hot1)
    print(hot1.ndim)
    print(hot1.shape)  # (16,16,2) C=16,H=16,W=16

    hot2 = np.array(list(map(lambda x: abs(x - 1), hot1))) # 变换一下位置。(1,0)背景元素,(0,1)目标元素
    print(hot2)
    print(hot2.ndim)
    print(hot2.shape)  # (16,16,2) C=16,H=16,W=16

    hot3 = hot2.transpose(2, 0, 1)
    print(hot3)  # (C=2,H=16,W=16)

2 DataLoader

Dataset类是读入数据集数据并且对读入的数据进行了索引。但是光有这个功能是不够用的,在实际的加载数据集的过程中,我们的数据量往往都很大,对此我们还需要一下几个功能:

  • 可以分批次读取:batch-size
  • 可以对数据进行随机读取,可以对数据进行洗牌操作(shuffling),打乱数据集内数据分布的顺序
  • 可以并行加载数据(利用多核处理器加快载入数据的效率)

此时就需要Dataloader类,常用操作有:batch_size(每个batch的大小), shuffle(是否进行shuffle操作),num_workers(加载数据的时候使用几个子进程)。Dataloader这个类并不需要我们自己设计代码,我们只需要利用DataLoader类读取我们设计好的Dataset子类即可:

from torchvision.utils.data import DataLoader

train_loader = DataLoader(dataset=train_data, batch_size=6, shuffle=True ,num_workers=4)
test_loader = DataLoader(dataset=test_data, batch_size=6, shuffle=False,num_workers=4)

参考:https://blog.csdn.net/sinat_42239797/article/details/90641659

 

 

 

 

 

 

 

 

 

 

你可能感兴趣的:(修仙之路:pytorch篇)