重写(覆写)ImageFolder,直接在一个文件夹下划分验证集与训练集

当现有的用于分类图像数据只有如下的tree路径:

./车辆分类数据集

        Lcar

                L*.jpg

        Lbus

                L*.jpg

        Ltruck

                L*.jpg

如果这是你全部的数据,并且在使用dataloader前你需要根据根目录(即:./车辆分类数据集)来创建训练用的ImageFolder以及验证用的ImageFolder(或者增加了测试用的ImageFolder)。此时,你不需要手动、或者用OS之类命令的代码来划分数据集,只需要重写(覆写)ImageFolder即可。

直接放代码。

# -*- coding: utf-8 -*-
"""
@Time : 2022/8/14 10:09
@Auth : Fanteng Meng
@File :imgfolder.py
@IDE :PyCharm
@Github : https://github.com/FT115

"""
import torch
import random
import torchvision.transforms as transforms
import torchvision.datasets as datasets

normalize = transforms.Normalize(mean=[.5, .5, .5],
                                 std=[.5, .5, .5])

train_transform = transforms.Compose([])
train_transform.transforms.append(transforms.Resize((224, 224)))
train_transform.transforms.append(transforms.ToTensor())
train_transform.transforms.append(transforms.RandomHorizontalFlip(p=0.8))
train_transform.transforms.append(normalize)

val_transform = transforms.Compose([])
val_transform.transforms.append(transforms.Resize((224, 224)))
val_transform.transforms.append(transforms.ToTensor())
val_transform.transforms.append(normalize)

class CustomImageFolder(datasets.ImageFolder):
    def __init__(self, root, transform, mode, train_ratio):
        super(CustomImageFolder, self).__init__(root, transform)
        assert mode in ['train', 'val']
        random.seed(0)
        random.shuffle(self.samples)
        if mode == 'train':
            self.samples = self.samples[:int(train_ratio*len(self))]
            self.targets = [s[1] for s in self.samples]
            self.imgs = self.samples
        elif mode == 'val':
            self.samples = self.samples[int(train_ratio*len(self)):]
            self.targets = [s[1] for s in self.samples]
            self.imgs = self.samples

    def __getitem__(self, index):
        path, target = self.samples[index]
        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return sample, target

    def __len__(self) -> int:
        return len(self.samples)


batch_size = 6

data_path = './实验三数据集/车辆分类数据集'

train_dataset = CustomImageFolder(data_path, train_transform, 'train', 0.7)
val_dataset = CustomImageFolder(data_path, val_transform, 'val', 0.7)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=0)

下面进行分开解释代码中的部分功能。

1.对于所有数据要打乱标签

这是因为原始的self.samples是按照类别的所有图片顺序排列,不打乱标签会使得划分的前百分之的数据类别严重不均衡,打乱标签会变得让数据更均衡。增加种子随机,可以让每次随机打乱的结果始终一致。对应代码中的这一部分:

random.seed(0)
random.shuffle(self.samples)

2.添加了mode、train_ratio参数

用于设置时训练集,还是验证集(如果要增加测试集,可自行添加)。以百分比率为train_ratio划分训练与验证集,前train_ratio作为训练集。只需对打乱的samples进行取相应量的数据即可,self.target按照ImageFolder封装好的官方实现方式重新写即可,即每个samples的索引1的值。self.imgs与self.samples相等,这是封装内部的官方实现,这里直接写就好。对应代码在这:

        if mode == 'train':
            self.samples = self.samples[:int(train_ratio*len(self))]
            self.targets = [s[1] for s in self.samples]
            self.imgs = self.samples
        elif mode == 'val':
            self.samples = self.samples[int(train_ratio*len(self)):]
            self.targets = [s[1] for s in self.samples]
            self.imgs = self.samples

def __getitem__(self, index)与def __len__(self) 保持不变,直接复制过来。因为不变,这里也可以不写。

记录自己学到东西的同时,希望对你有所帮助~

你可能感兴趣的:(开发语言,深度学习,python)