使用datasets.ImageFolder()划分数据集并打乱顺序(简单易懂)

一、代码

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
import random

path = 
transforms=
proportion=0.1 #测试集比例
batch_size=32

data = datasets.ImageFolder(path,transforms)
n = len(data)  #数据集总数
n_test = random.sample(range(1, n), int(proportion * n))  #按比例取随机数列表

test_set = torch.utils.data.Subset(data, n_test)  #按照随机数列表取测试集
train_set = torch.utils.data.Subset(data,list(set(range(1, n)).difference(set(n_test))))  #测试集剩下作为训练集

data_train = DataLoader(train_set, batch_size=batch_size, shuffle=True)
data_test=DataLoader(test_set, batch_size=batch_size, shuffle=False)

#输出筛选的训练集labels
for batch_idex, (data, targets) in enumerate(data_test):
    print(batch_idex,targets)

二、测试结果

用了十类的图片数据集测试,结果数据集成功被打乱了!

使用datasets.ImageFolder()划分数据集并打乱顺序(简单易懂)_第1张图片

三、后记

网上其它的代码只进行划分忽略了打乱这个环节,那可能有人会问DataLoader里不是有shuffle吗,为什么不用呢?

  • 因为是先划分的数据集,如果数据集的标签是连续排列的,划分的数据集的标签会出现扎堆现象,后续再在DataLoader时打乱就没效果啦。就像下面这样,测试集将0,1标签都取走了而没有其它标签,这显然不是一个合理的数据集!

使用datasets.ImageFolder()划分数据集并打乱顺序(简单易懂)_第2张图片

list取补集代码:list(set(range(1, n)).difference(set(n_test)))

  • 取完补集最后需要转成list,不然会报错:TypeError: ‘set’ object is not subscriptable

你可能感兴趣的:(深度学习,深度学习,pytorch,人工智能)