pytorch使用ImageFolder和random_split读取和划分数据集

实现对数据集的封装和划分,数据集格式如图所示
pytorch使用ImageFolder和random_split读取和划分数据集_第1张图片
import torch
import torch.utils.data
from torchvision import transforms,datasets

使用ImageFolder去读取,返回后的数据路径和标签对应起来

all_dataset = datasets.ImageFolder(’…/data/amazon/images’, transform=data_transform)

使用random_split实现数据集的划分,lengths是一个list,按照对应的数量返回数据个数。

train, test = torch.utils.data.random_split(dataset= all_dataset, lengths=[参数1,参数2])

接着按照正常方式使用DataLoader读取数据,返回的是DataLoader对象

train = torch.utils.data.DataLoader(train, batch_size=4, shuffle=True, num_workers=4)
test = torch.utils.data.DataLoader(test, batch_size=4, shuffle=True, num_workers=4)

你可能感兴趣的:(random_split,ImageFolder,python)