pytorch使用ImageFolder和random_split读取和划分数据集

1. 最近重新学习torch知识,想实现对自己的数据集的封装和划分,由于自己的数据集格式如图所示

pytorch使用ImageFolder和random_split读取和划分数据集_第1张图片
层级结构:

|---data
	|---amazon
		|---images
			|---back_pack
				|---frame_0001.jpg
				|---frame_0002.jpg
				|---frame_0002.jpg
				...

2. 首先,如果数据集层级结构是这样的格式,则可以进行如下方式处理

import torch
import torch.utils.data
from torchvision import transforms,datasets


# 定义transforms的一些操作
data_transform = transforms.Compose([
		# Resize后数据的大小为224 * 224
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        # 数据标准化,采用的图片标准化参数
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    
# 使用ImageFolder去读取,返回后的数据路径和标签对应起来
all_dataset = datasets.ImageFolder('../data/amazon/images', transform=data_transform)

# 使用random_split实现数据集的划分,lengths是一个list,按照对应的数量返回数据个数。
# 这儿需要注意的是,lengths的数据量总和等于all_dataset中的数据个数,这儿不是按比例划分的
train, test, valid = torch.utils.data.random_split(dataset= all_dataset, lengths=[2000, 417, 400])

# 接着按照正常方式使用DataLoader读取数据,返回的是DataLoader对象
train = torch.utils.data.DataLoader(train, batch_size=4, shuffle=True, num_workers=4)
test  = torch.utils.data.DataLoader(test,  batch_size=4, shuffle=True, num_workers=4)
valid = torch.utils.data.DataLoader(valid, batch_size=4, shuffle=True, num_workers=4)

3. 进行遍历数据

# 使用迭代器进行迭代数据进行查看,如果这儿报错:The “freeze_support()” line can be omitted if the program 
# is not going to be frozen to produce an executable
# 需要将你要运行的代码块放到main函数中运行即可
for step, (x, y) in enumerate(train):
    print(step)
    print(x.size())
    print(y.size())
    print(x)
    break

输出结果如图:
pytorch使用ImageFolder和random_split读取和划分数据集_第2张图片

4. 总结:

一开始自己是写代码实现数据的读取,划分,并封装成DataLoader,殊不知还有这么好的库函数供使用。。。
库函数处理的思路过程如图:
pytorch使用ImageFolder和random_split读取和划分数据集_第3张图片

5. 如有问题可以留言~

你可能感兴趣的:(Pytorch问题整理)