pytorch中几种Dataset读取数据重写实现

通过重写Dataset类,对自己制作的数据集进行读取后传给DataLoader。主要用来完成从哪里读取数据和标签的功能。主要是__getitem__(返回数据集和标签)和__len__(返回数据的长度)这两个方法。

import numpy as np
import torch
import os
from PIL import Image
from torch.utils.data import Dataset

class MyDataset_1(Dataset):
    """
    通过包含数据路径和标签的txt文件读取
    txt_path:txt文本路径, 该文本包含了图像的路径信息, 以及标签信息
    transform: 数据处理,对图像进行随机裁剪, 以及转换成tensor
    """
    def __init__(self, txt_path, transform=None, target_transform=None):
        super(MyDataset_1, self).__init__()
        fh = open(txt_path)
        imgs = []
        # 一行一行读取txt文件
        for line in fh:
            line = line.rstrip()  # 这一行就是图像的路径以及标签

            words = line.split()
            imgs.append((words[0], int(words[1])))
            self.imgs = imgs
            self.transform = transform
            self.target_transform = target_transform

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, index):
        fn, label = self.imgs[index]  # 通过index索引返回一个图像路径fn与标签label
        img = Image.open(fn)
        if self.transform:
            img = self.transform(img)
        return img, label

class MyDataset_2(Dataset):
    """
    通过标签文件读取,csv文件前面是numpy数据,最后一列是label
    """
    def __init__(self, csv_file):
        super(MyDataset_2, self).__init__()

        #  xy是一个容器, 通过读取一个包含数据和标签信息的文件
        xy = np.loadtxt(csv_file, delimiter=',', dtype=np.float32)

        self.x_data = torch.from_numpy(xy[:, 0:-1])
        self.y_data = torch.from_numpy(xy[:, -1])

        self.len = len(xy)  # 给后面的__len__()使用

    def __len__(self):
        return self.len

    def __getitem__(self, index):
        return self.x_data[index], self.y_data[index]

class MyDataset_3(Dataset):
    """
    每个文件夹是一个类,每个文件夹中都是该类的图片(这种方法就等同于torchvision.datasets.ImageFolder)
    """
    def __init__(self, dirname, transform=None):
        super(MyDataset_3, self).__init__()
        self.classes = os.listdir(dirname)  # 有多少个目录就等于多少个类别,这边获得类别名
        self.images = []
        self.transform = transform
        for i, classes in enumerate(self.classes):
            classes_path = os.path.join(dirname, classes)  # 类别目录
            for image_name in os.listdir(classes_path):  # 便利该类别中的图片
                self.images.append((os.path.join(classes_path, image_name), i))  # 获得图片路径和类别名索引

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        image_name, classes = self.images[index]
        image = Image.open(image_name)
        if self.transform:
            image = self.transform(image)
        return image, classes



你可能感兴趣的:(pytorch,pytorch,dataset)