Pytorch继承Dataset编写自己的Getdata

作者记录方便查询

处理训练数据集可以通过继承Dataset编写自己的Getdata进行处理,或者通过TensorDataset直接构建,TensorDataset可以避免很多代码量,但可操作性较少,因此总结一下常用的继承Dataset编写自己的Getdata(慢慢总结中。。。)

情况一:数据类别对应相应的文件夹,即每个文件夹对应一个数据的类别。

这里使用ped2进行举例

from torch.utils.data import Dataset
import cv2
import pathlib
import torch
class GetData(Dataset):
    def __init__(self, path):# path是分类文件夹的上一级

        # 适用于数据根据文件夹分类的情况
        data_root = pathlib.Path(path)
        all_image_paths = list(data_root.glob('*/*'))# 返回所有分类文件夹下面的文件名称,可以增添文件类型进行限制
        self.all_image_paths = [str(path) for path in all_image_paths]# 将文件名转换成字符串形式
        label_names = sorted(item.name for item in data_root.glob('*/') if item.is_dir())# 标签列表
        label_to_index = dict((label, index) for index, label in enumerate(label_names))# 
        self.all_image_labels = [label_to_index[path.parent.name] for path in all_image_paths]# 给每个样本打标签

    def __getitem__(self, index):
        img = cv2.imread(self.all_image_paths[index])# 这里是图片的读取方式,可以换成其他文件的形式
        label = self.all_image_labels[index]
        img = torch.tensor(img, dtype=torch.float32)
        label = torch.tensor(label)
        return img, label

    def __len__(self):
        return len(self.all_image_paths)
# 例子:(r'D:\暑假学习\自编码器异常检测\datasets\ped2\training\frames')

你可能感兴趣的:(pytorch,深度学习,python)