Pytorch学习笔记(1)--加载数据

Pytorch加载数据

pytorch加载数值主要分为两个部分:

  1. 复写Dataset类建立一个my_dataset类用于将数据集从硬盘中逐条读取
  2. 利用Dataloader模块读取数据

构建Dataset:

构建dataset类主要需要几个步骤:
1.初始化def init,一般传入数据文件地址和transform
2 grtitem方法, 后续被dataloader调用的方法,主要传入index,抽出每一个样本,而且对样本的特征和标签做区分。同时还会处理传入的数据。
3 __len__用于return数据的数量,会被调用。
4自定义的一个方法 get_info,用于将数据从硬盘中读取出来,并且将数据读取出来生成一个列表,以图片为例【(图片, 对应标签)】

import os
from torch.utils.data import Dataset
from  PIL import Image
label_c_d = {'cats': 0, 'dogs': 1}

class cat_dog_dataset(Dataset):
    def __init__(self, data_dir, transform=None):
        #data_dir中存储所有文件路径和标签
        self.data_dir = self.get_img_info(data_dir)
        self.label_name = {'cats': 0, 'dogs': 1}
        self.transform = transform
    def __getitem__(self, index):
        '''
        __getitem__是Dataset内部定义的方法用于读取单个
        :param index:  输入索引
        :return: 返回一个样本
        '''
        path_img, label = self.data_dir[index]
        img = Image.open(path_img).convert('RGB')

        if self.transform is not None:
            img = self.transform(img)
        return img, label

    def __len__(self):
        return len(self.data_dir)

    @staticmethod
    def get_img_info(data_dir):
        '''

        :param data_dir:  数据文件路径:本例中到training_set/training_set
        :return: 一个包含所有数据(image_path, int(label)原组的列表
        '''
        data_info =list()

        # os.walk:返回一个含三个值的序列,1这个序列遍历了data_dir的文件路径, 2data_dir里全部子文件的路径, 和3data_dir里全部文件的文件名

        for root, dirs, _ in os.walk(data_dir):
            for dir in dirs:
                image_names = os.listdir(os.path.join(root, dir))
                image_names = list(filter(lambda x: x.endswith('jpg'), image_names))

                for i in range(len(image_names)):
                    image_name = image_names[i]
                    image_path = os.path.join(root, dir, image_name)
                    label = label_c_d[dir]
                    data_info.append((image_path, int(label)))
        return data_info

Dataloader模块

Dataloader模快利用sample(采样器)对dataset返回的数据进行采样和分batch。

# 构建MyDataset实例
train_data = cat_dog_dataset(data_dir=train_dir, transform=train_transform)
valid_data = cat_dog_dataset(data_dir=valid_dir, transform=valid_transform)

# 构建DataLoder
train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(dataset=valid_data, batch_size=BATCH_SIZE)

你可能感兴趣的:(Pytroch,pytorch)