pytorch环境根据CSV文件加载数据集(Kaggle-test)

文章目录

  • 前言
  • 一、数据集预处理
  • 二、读取和加载步骤
    • 1.继承Dataset
    • 2.构造自己的Dataset
    • 3.load数据集
  • 总结


前言

最近在做李沐发布的Kaggle树叶分类竞赛,在处理和加载Classify-Leaves数据集的时候遇到了一些问题,真是巧妇难为无米之炊啊,现在记录下来,希望可以帮助到更多的初学者!


一、数据集预处理

数据集由一个images文件和三个csv文件组成
pytorch环境根据CSV文件加载数据集(Kaggle-test)_第1张图片
打开train.csv文件可以看到图片类别的标签是字符串,不满足深度学习训练的要求:

pytorch环境根据CSV文件加载数据集(Kaggle-test)_第2张图片
通过如下代码可将标签映射成数字并添加在label列后:

train_file = pd.read_csv('../classify-leaves/train.csv') # 读取的文件是DataFarame格式

data = train_file['label'].values # 将文件中label这一列拿出来 把格式转换为numpy格式

data_list = list(set(data)) # 去除重复项,转换为list形式

dict = {} 
for i in range(len(data_list)):
    dict[data_list[i]] = i # 构造映射字典

train_file['label_number'] = train_file['label'].map(dict) # 利用映射将字符串类别转换为数字

train_file.to_csv('../classify-leaves/train_.csv', index=False) # 将修改的文件保存到本地 不保存序号

修改后的train.csv文件:pytorch环境根据CSV文件加载数据集(Kaggle-test)_第3张图片 一般官方数据集是通过torchvison.datasets读取,然后用DataLoader函数加载,例如MNIST等,但是Classify-Leaves数据集不可以直接通过torchvison.datasets读取,因此我们需要定义自己的Dataset。

二、读取和加载步骤

1.继承Dataset

代码如下(示例):

class Dataset(object):
#这个函数就是根据索引,迭代的读取路径和标签
def __init__(self, csv_file, root_dir, transform=None):
    raise NotImplementedError

def __getitem__(self, index):
    raise NotImplementedError

#返回数据的长度
def __len__(self):
    raise NotImplementedError

2.构造自己的Dataset

代码如下:

class MyDataset(Dataset):

    def __init__(self, csv_file, root_dir, transform=None):
        """
            csv_file: 标签文件的路径.
            root_dir: 所有图片的路径.
            transform: 一系列transform操作
        """
        self.data_frame = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len(self.data_frame) # 返回数据集长度

    def __getitem__(self, idx):
        img_path = os.path.join(self.root_dir, 
        						self.data_frame.iloc[idx, 0]) #获取图片所在路径
        img = Image.open(img_path).convert('RGB') # 防止有些图片是RGBA格式
        
        label_number = self.data_frame.iloc[idx, 2] # 获取图片的类别标签
        
        if self.transform:
            img = self.transform(img)

        return img, label_number # 返回图片和标签

定义好Dataset后,我们就可以进行下一步了

3.load数据集

#读取数据集
train_dataset = MyDataset(csv_file='../classify-leaves/train_.csv',
                          root_dir='../classify-leaves',
                          transform=torchvision.transforms.ToTensor())
train_iter = DataLoader(train_dataset, batch_size=128) # 加载数据集
for X, y in train_iter: # 迭代batchsize中的数据
    print(X.shape) # torch.Size([128, 3, 224, 224])
    print(y.shape) # torch.Size([128])
    break

总结

以上就是今天要讲的内容,本文仅仅简单介绍了针对特定的数据集如何构造合适的Dataset,以方便通过DataLoader函数来加载。

你可能感兴趣的:(python,pytorch)