目录
1. 介绍
2. 数据处理 dataset
2.1 预处理
2.2 加载数据
2.2.1 初始化
2.2.2 返回数据
2.2.3 样本数量
3. 测试一下
4. 完整代码
之前介绍完了Unet网络的搭建,接下来说一下要解决的任务。
本章介绍的是:数据的加载处理
下面是整个项目:
本项目参考这篇文章:UNet模型训练,深度解析! ,网络做了一些优化和更改,整个项目完成会上传到CSDN,数据可以在链接里面获取
因为data数据只有30张,并且没有test集,所以这里手工分类了一下。将对应的image和label取出来放到test里面即可,这里21张用于train,9张用于test
样本图片:
对应label:
有关内容可以参考:关于pytorch的数据处理-数据加载Dataset
因为UNet 网络,我们希望的输入是480*480的灰度图,所以预处理的时候要改变一个size
图像本身就是灰度图,所以这里不需要转换
最后要将图像转为Tensor
这里没有用数据增强:翻转、随即裁剪等等。因为这里不确定随机的翻转对image和label是否是一致的。
这里可以通过设置字典,对image,进行normalization
观察下目录结构,后面用得到
这里如果定义加载类的话,需要继承 from torch.utils.data import Dataset 里面的Dataset
初始化init 方法里面实现的是初始化相关的操作,例如指定文件的路径和预处理等等
这里root指定要处理数据的目录,这里指定的是train里面的image
想要获得image下具体图片的路径就要将root + imgs ,也就是self.imgs
getitem 是返回一个样本,那么既然这个方法返回的就是我们需要的每个样本,那么读取每个图像,甚至对图像操作都应该在getitem里面
首先,self.imgs 是个列表,里面存放的是整个训练图片的路径。根据index索引获取每个图片,
因为train和test里面的图像和标签都是相同的文件名,观察每个图片的路径,只需要将train替换成label就可以获取图像对应的标签图像了
通过上面的open获取每个对应的图片和图片的label
这里就是简单的预处理
需要注意的是,因为这里的label不是二值图片,所以需要转换一下。因为预处理的ToTensor会将像素 / 255 变成0-1之间,所以这里将大于等于0.5的设置为1,小于0.5的设置为0
最后返回image和label就行了
image:
label:
code:
import os
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image
transform = transforms.Compose([
transforms.Resize((480,480)), # 缩放图像
transforms.ToTensor(), # 转为Tensor
])
# 数据处理文件
class Data_Loader(Dataset): # 加载数据
def __init__(self, root, transforms = transform): # 指定路径、预处理等等
imgs = os.listdir(root) # 获取root文件下的文件
self.imgs = [os.path.join(root,img) for img in imgs] # 获取每个文件的路径
self.transforms = transforms # 预处理
def __getitem__(self, index): # 读取图片,返回一条样本
image_path = self.imgs[index] # 根据index读取图片
label_path = image_path.replace('image', 'label') # 把路径中的image替换成label,就找到对应数据的label
image = Image.open(image_path) # 读取图片和对应的label图
label = Image.open(label_path)
if self.transforms: # 判断是否预处理
image = self.transforms(image)
label = self.transforms(label)
label[label>=0.5] = 1 # 这里转为二值图片
label[label< 0.5] = 0
return image, label
def __len__(self): # 返回样本的数量
return len(self.imgs)
# if __name__ == "__main__":
#
# dataset = Data_Loader("./data/test/image") # 加载数据
#
# for image,label in dataset:
# print(image)
# print('image size:',image.size()) # image size: torch.Size([1, 480, 480])
# print(label)
# print('label size:',label.size()) # label size: torch.Size([1, 480, 480])
# break