# Combining datasets in Python (PyTorch Dataset), using the os module for path handling

# import torch
# import torch
# print(torch.cuda.is_available())
from torch.utils.data import  Dataset
# Purpose of Dataset: build a dataset that provides __getitem__(self, index) to fetch
# an image and its label by index, and __len__(self) to report the dataset's length.
# import cv2

from PIL import Image
import  os
# os is used for file-path handling

# Define a class that inherits from Dataset
class MyDataset(Dataset):
    """Image dataset backed by a single label folder.

    Every regular file inside ``root_dir/label_dir`` is one sample, and the
    folder name ``label_dir`` is used as the label of every sample.
    """

    def __init__(self, root_dir, label_dir):  # label_dir: folder name, doubles as the label
        """Index the files stored under ``root_dir/label_dir``.

        Args:
            root_dir: Directory that contains the label folder.
            label_dir: Name of the folder to read; also returned as the label.

        Raises:
            OSError: If ``root_dir/label_dir`` cannot be listed.
        """
        self.root_dir = root_dir
        self.label_dir = label_dir
        self.path = os.path.join(self.root_dir, self.label_dir)
        # os.listdir returns entries in arbitrary order, so sort them to keep
        # the index -> sample mapping reproducible between runs and machines.
        # Sub-directories are skipped because they cannot be opened as images.
        self.img_path = sorted(
            name
            for name in os.listdir(self.path)
            if os.path.isfile(os.path.join(self.path, name))
        )

    def __getitem__(self, idx):
        """Return the ``(image, label)`` pair stored at position ``idx``.

        Args:
            idx: Position of the sample in the sorted file listing.

        Returns:
            A tuple of the loaded PIL image and the label folder name.

        Raises:
            IndexError: If ``idx`` is out of range.
        """
        img_name = self.img_path[idx]
        img_item_path = os.path.join(self.path, img_name)
        # Image.open is lazy and keeps the file open; load the pixel data
        # inside the context manager so the file handle is released here
        # instead of lingering until garbage collection.
        with Image.open(img_item_path) as img:
            img.load()
        return img, self.label_dir

    def __len__(self):
        """Return the number of indexed image files."""
        return len(self.img_path)


# Training-set root and the per-class folders (each folder name is the label).
root_dir = "2020SchoolLearning/PytorchLearning/train"
ants_label_dir, bees_label_dir = "ants_image", "bees_image"

# Build one dataset per class, ants first and then bees.
ant_data, bee_data = (
    MyDataset(root_dir, ants_label_dir),
    MyDataset(root_dir, bees_label_dir),
)

# Adding two Dataset objects concatenates them into a single dataset.
train_dataset = ant_data + bee_data

# Related topic: machine learning