关于pytorch的数据处理-数据加载Dataset

目录

1. 数据加载

2. Dataset 

__init__

__getitem__

__len__

测试一下

完整代码

3. Dataset - ImageFolder


1. 数据加载

最近在使用 Unet 做图像分割,设计到 处理数据有关的工作,查了点资料,做一些简单的总结

在pytorch 中,数据的加载可以通过自定义的数据集对象实现,这里是Dataset 类,实现自定义的数据集需要继承Dataset,并且实现两个方法

  • __getitem__: 返回一个样本
  • __len__: 返回样本的数量

其实,之前一直都有用过Dataset类,但是都是直接调库的,所以导致现在对Dataset有点熟悉又有点陌生的感觉

之前下载CIFAR10 数据集的时候,用的都是:

  •  这里的torchvision 提供数据集
  •  torchvision 里面的dataset 就包含了各种的数据集

2. Dataset 

接下来,通过猫和狗的图像介绍Dataset ,介绍如何处理数据

首先先创建一个文件夹,里面随便上网上下载几张猫和狗的图片,放在同一个文件夹下

这里的猫狗文件名被改了,后面数字是随机输的,目的是通过 ' . ' 前面的dog和cat生成label

关于pytorch的数据处理-数据加载Dataset_第1张图片


然后提前导入下面的库文件

关于pytorch的数据处理-数据加载Dataset_第2张图片

 


__init__

接下来定义初始化方法

关于pytorch的数据处理-数据加载Dataset_第3张图片

 init 里面是初始化方法,例如传入图片的路径,或者要不要选择预处理等等

这里并不实际加载图片,只是指定路径,真正的读取图片在getitem方法里面

os.listdir : 会将data下面所以的文件读取,放在imgs里面,打印结果是上面的注释

然后self.imgs 会将imgs里面的路径和root路径 拼接在一块,输出结果如下:

['./data/cat.15454.jpg', './data/cat.445.jpg', './data/cat.46456.jpg', './data/cat.656165.jpg', './data/dog.123.jpg', './data/dog.15564.jpg', './data/dog.4545.jpg', './data/dog.456465.jpg']

imgs 里面是具体文件的路径,root里面是文件夹的路径

__getitem__

上面说过,getitem 是返回一个样本,所以说这里是将结果返回的。那么返回之前做的处理数据的操作,也在__getitem__里面。

关于pytorch的数据处理-数据加载Dataset_第4张图片

 

这里的img_path 通过self.imgs[index] 会将self.imgs里面的内容一个个读取出来

而self.imgs 里面是下图,每个数据的路径

 所以self.imgs[index]会遍历self.imgs 里面的路径,返回给img_path

打印结果: 

关于pytorch的数据处理-数据加载Dataset_第5张图片

 然后,就可以根据每个路径的id去做label了。将img_path 路径按照 '/ '分割,-1代表取最后一个字符串,如果里面有dog就为1,cat就为0,例如:下面的例子打印的就是Yes

关于pytorch的数据处理-数据加载Dataset_第6张图片

 

最后,看看有没有预处理transforms ,然后返回data和label就行了

__len__

返回样本的个数 = 图片路径的个数 

测试一下

最后的结果如下:

关于pytorch的数据处理-数据加载Dataset_第7张图片

 

所以通过上面的代码,就可以实现一个自定义自己数据集的办法,并且可以获取

完整代码

import torch
import torchvision.datasets
from torch.utils.data import Dataset        # 继承Dataset类
import os
from PIL import Image
import numpy as np
from torchvision import transforms


# 预处理
data_transform = transforms.Compose([
    transforms.Resize((224,224)),           # 缩放图像
    transforms.ToTensor(),                  # 转为Tenso
    transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))       # 标准化
])


class DogCat(Dataset):      # 数据处理
    def __init__(self,root,transforms = None):                  # 初始化,指定路径,是否预处理等等

        #['cat.15454.jpg', 'cat.445.jpg', 'cat.46456.jpg', 'cat.656165.jpg', 'dog.123.jpg', 'dog.15564.jpg', 'dog.4545.jpg', 'dog.456465.jpg']
        imgs = os.listdir(root)

        self.imgs = [os.path.join(root,img) for img in imgs]    # 取出root下所有的文件
        self.transforms = data_transform                        # 图像预处理

    def __getitem__(self, index):       # 读取图片
        img_path = self.imgs[index]
        label = 1 if 'dog' in img_path.split('/')[-1] else 0        #  dog -> 1,cat -> 0

        data = Image.open(img_path)

        if self.transforms:     # 图像预处理
            data = self.transforms(data)

        return data,label

    def __len__(self):
        return len(self.imgs)


dataset = DogCat('./data/',transforms=True)

for img,label in dataset:
    print('img:',img.size(),'label:',label)
'''
img: torch.Size([3, 224, 224]) label: 0
img: torch.Size([3, 224, 224]) label: 0
img: torch.Size([3, 224, 224]) label: 0
img: torch.Size([3, 224, 224]) label: 0
img: torch.Size([3, 224, 224]) label: 1
img: torch.Size([3, 224, 224]) label: 1
img: torch.Size([3, 224, 224]) label: 1
img: torch.Size([3, 224, 224]) label: 1
'''

3. Dataset - ImageFolder

ImageFolder 可以更好的将上述的猫狗打好标签

ImageFolder 假设所有的文件按文件夹保存,每个文件夹下存储同一个类别的图片,文件名为类名

例如:将上述的图片放在不同的文件夹下

文件名的大小写要一致,如首字母大写,都要大写

关于pytorch的数据处理-数据加载Dataset_第8张图片

 

 这样ImageFolder 读取的label就是按照文件名顺序排序成为字典的,也就是{类名:序号}。就是类名+对应的label

可以通过 .class_to_idx 查看

关于pytorch的数据处理-数据加载Dataset_第9张图片

 

打印结果为:

['Cat', 'Dog']



{'Cat': 0, 'Dog': 1}



Dataset ImageFolder
    Number of datapoints: 8
    Root location: ./DogCat/



[('./DogCat/Cat\\cat.15454.jpg', 0), ('./DogCat/Cat\\cat.445.jpg', 0), ('./DogCat/Cat\\cat.46456.jpg', 0), ('./DogCat/Cat\\cat.656165.jpg', 0), ('./DogCat/Dog\\dog.123.jpg', 1), ('./DogCat/Dog\\dog.15564.jpg', 1), ('./DogCat/Dog\\dog.4545.jpg', 1), ('./DogCat/Dog\\dog.456465.jpg', 1)]
 

这个就是为什么 pytorch 搭建AlexNet 对花进行分类 这里面对花分类,文件夹的顺序就是这个类别的顺序

 

最后就是:

关于pytorch的数据处理-数据加载Dataset_第10张图片

 

你可能感兴趣的:(关于PyTorch,使用的,smart,power,pytorch,深度学习,人工智能)