Python加载带有csv标签文件的图片数据集

import os

os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
import pandas as pd
from skimage import io
import torch
from torch.utils.data import Dataset


class MyDataset(Dataset):
    def __init__(self, csv_file, root_dir, transform=None):
        # 读取标签文件
        self.annotations = pd.read_csv(os.path.join(root_dir, csv_file))
        # 定义文件目录
        self.root_dir = root_dir
        # 定义transform
        self.transform = transform

    # 返回数据集长度
    def __len__(self):
        return len(self.annotations)

    # 获取数据的方法,会和Dataloader连用
    def __getitem__(self, index):
        # 获取图片路径,0表示csv文件的第一列
        img_path = os.path.join(self.root_dir, self.annotations.iloc[index, 0])
        # 读取图片
        image = io.imread(img_path)
        # 获取图片对应的标签,1表示csv文件的第二列
        label = torch.tensor(int(self.annotations.iloc[index, 1]))

        # 如果使用时附加了transform参数,则对图片应用转换
        if self.transform:
            image = self.transform(image)

        # 返回图片和标签
        return image, label

下面使用这个类:

并将图片数据集转化为tensor数据类型,方面后来的神经网络学习

# 实例化,输入文件名称,路径和需要的数据转换
csv_file = r".Annotations.csv" # 添加绝对路径
root_dir = r"Images"
dataset = MyDataset(csv_file=csv_file, root_dir=root_dir, transform=transforms.ToTensor())
len_dataset = len(dataset)
print(len_dataset)
# print(type(dataset))

你可能感兴趣的:(python,开发语言)