Pytorch:数据集构建、加载、划分

在利用神经网络对自己的数据进行分析时,首先要对数据进行处理,构建用来训练和测试的数据集,并对其进行加载。具体方法如下:

本文中,我们将几类类别不同的数据分别放在不同的文件夹中,并将所有类别的文件夹放在一个大文件夹里,为构建数据集,我们首先根据文件路径和对应文件的标签生成一个.txt文件:

import os
def generate(dir, label):
    files = os.listdir(dir)
    files.sort()
    listText = open('all_data.txt', 'a')
    for file in files:
        fileType = os.path.split(file)
        if fileType[1] == '.txt':
            continue
        file1 = os.path.join(dir, file)
        print(file1)
        name = file1 + ' ' + str(int(label)) + '\n'
        listText.write(name)
    listText.close()


outer_path = 'D:/'  #存储数据的大文件夹路径

if __name__ == '__main__':
    i = 0
    folderlist = os.listdir(outer_path)  # 列举所有类别文件夹
    for folder in folderlist:
        generate(os.path.join(outer_path, folder), i)
        i += 1

下面进行数据集构建:

#%% 定义MyDataset类,读取数据集
class MyDataset(Dataset):
    def __init__(self, txt, transform=None, target_transform=None):
        super(MyDataset, self).__init__()
        fh = open(txt, 'r')
        imgs = []
        for line in fh:
            line = line.rstrip()
            words = line.split()  # 通过指定分隔符对字符串进行切片,默认为所有的空字符,包括空格、换行、制表符等
            imgs.append((words[0], int(words[1])))  #将txt文件中的数据按行读取并存入imgs列表中。
        self.img = imgs
        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        fn, label = self.img[index]  
        img1 = Image.open(fn)
        if self.transform is not None:
            img1 = self.transform(img1)
        return img1, label

    def __len__(self):
        return len(self.img)


all_data1 = MyDataset(txt="all_data.txt", transform=transforms.ToTensor())  #txt为保存图片路径和标签的.txt文件路径。
这里,我们直接读取所有数据,下一步将利用torch.utils.data.random_split()进行数据集划分:
#%%数据集划分
train_size = int(0.8 * len(all_data1))  #获取训练集长度
test_size = len(all_data1) - train_size  #测试集长度
train_data, test_data = torch.utils.data.random_split(all_data1, [train_size, test_size])
数据集加载:
train_loader = torch.utils.data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_data, batch_size=BATCH_SIZE, shuffle=True)

至此,完成数据集构建,划分,加载。

你可能感兴趣的:(Pytorch:数据集构建、加载、划分)