目录
1. 数据加载
2. Dataset
__init__
__getitem__
__len__
测试一下
完整代码
3. Dataset - ImageFolder
最近在使用 Unet 做图像分割,设计到 处理数据有关的工作,查了点资料,做一些简单的总结
在pytorch 中,数据的加载可以通过自定义的数据集对象实现,这里是Dataset 类,实现自定义的数据集需要继承Dataset,并且实现两个方法
其实,之前一直都有用过Dataset类,但是都是直接调库的,所以导致现在对Dataset有点熟悉又有点陌生的感觉
之前下载CIFAR10 数据集的时候,用的都是:
- 这里的torchvision 提供数据集
- torchvision 里面的dataset 就包含了各种的数据集
接下来,通过猫和狗的图像介绍Dataset ,介绍如何处理数据
首先先创建一个文件夹,里面随便上网上下载几张猫和狗的图片,放在同一个文件夹下
这里的猫狗文件名被改了,后面数字是随机输的,目的是通过 ' . ' 前面的dog和cat生成label
然后提前导入下面的库文件
接下来定义初始化方法
init 里面是初始化方法,例如传入图片的路径,或者要不要选择预处理等等
这里并不实际加载图片,只是指定路径,真正的读取图片在getitem方法里面
os.listdir : 会将data下面所以的文件读取,放在imgs里面,打印结果是上面的注释
然后self.imgs 会将imgs里面的路径和root路径 拼接在一块,输出结果如下:
['./data/cat.15454.jpg', './data/cat.445.jpg', './data/cat.46456.jpg', './data/cat.656165.jpg', './data/dog.123.jpg', './data/dog.15564.jpg', './data/dog.4545.jpg', './data/dog.456465.jpg']
imgs 里面是具体文件的路径,root里面是文件夹的路径
上面说过,getitem 是返回一个样本,所以说这里是将结果返回的。那么返回之前做的处理数据的操作,也在__getitem__里面。
这里的img_path 通过self.imgs[index] 会将self.imgs里面的内容一个个读取出来
而self.imgs 里面是下图,每个数据的路径
所以self.imgs[index]会遍历self.imgs 里面的路径,返回给img_path
然后,就可以根据每个路径的id去做label了。将img_path 路径按照 '/ '分割,-1代表取最后一个字符串,如果里面有dog就为1,cat就为0,例如:下面的例子打印的就是Yes
最后,看看有没有预处理transforms ,然后返回data和label就行了
返回样本的个数 = 图片路径的个数
最后的结果如下:
所以通过上面的代码,就可以实现一个自定义自己数据集的办法,并且可以获取
import torch
import torchvision.datasets
from torch.utils.data import Dataset # 继承Dataset类
import os
from PIL import Image
import numpy as np
from torchvision import transforms
# 预处理
data_transform = transforms.Compose([
transforms.Resize((224,224)), # 缩放图像
transforms.ToTensor(), # 转为Tenso
transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5)) # 标准化
])
class DogCat(Dataset): # 数据处理
def __init__(self,root,transforms = None): # 初始化,指定路径,是否预处理等等
#['cat.15454.jpg', 'cat.445.jpg', 'cat.46456.jpg', 'cat.656165.jpg', 'dog.123.jpg', 'dog.15564.jpg', 'dog.4545.jpg', 'dog.456465.jpg']
imgs = os.listdir(root)
self.imgs = [os.path.join(root,img) for img in imgs] # 取出root下所有的文件
self.transforms = data_transform # 图像预处理
def __getitem__(self, index): # 读取图片
img_path = self.imgs[index]
label = 1 if 'dog' in img_path.split('/')[-1] else 0 # dog -> 1,cat -> 0
data = Image.open(img_path)
if self.transforms: # 图像预处理
data = self.transforms(data)
return data,label
def __len__(self):
return len(self.imgs)
dataset = DogCat('./data/',transforms=True)
for img,label in dataset:
print('img:',img.size(),'label:',label)
'''
img: torch.Size([3, 224, 224]) label: 0
img: torch.Size([3, 224, 224]) label: 0
img: torch.Size([3, 224, 224]) label: 0
img: torch.Size([3, 224, 224]) label: 0
img: torch.Size([3, 224, 224]) label: 1
img: torch.Size([3, 224, 224]) label: 1
img: torch.Size([3, 224, 224]) label: 1
img: torch.Size([3, 224, 224]) label: 1
'''
ImageFolder 可以更好的将上述的猫狗打好标签
ImageFolder 假设所有的文件按文件夹保存,每个文件夹下存储同一个类别的图片,文件名为类名
例如:将上述的图片放在不同的文件夹下
文件名的大小写要一致,如首字母大写,都要大写
这样ImageFolder 读取的label就是按照文件名顺序排序成为字典的,也就是{类名:序号}。就是类名+对应的label
可以通过 .class_to_idx 查看
打印结果为:
['Cat', 'Dog']
{'Cat': 0, 'Dog': 1}
Dataset ImageFolder
Number of datapoints: 8
Root location: ./DogCat/
[('./DogCat/Cat\\cat.15454.jpg', 0), ('./DogCat/Cat\\cat.445.jpg', 0), ('./DogCat/Cat\\cat.46456.jpg', 0), ('./DogCat/Cat\\cat.656165.jpg', 0), ('./DogCat/Dog\\dog.123.jpg', 1), ('./DogCat/Dog\\dog.15564.jpg', 1), ('./DogCat/Dog\\dog.4545.jpg', 1), ('./DogCat/Dog\\dog.456465.jpg', 1)]
这个就是为什么 pytorch 搭建AlexNet 对花进行分类 这里面对花分类,文件夹的顺序就是这个类别的顺序
最后就是: