深度学习模型的数据读取(python+torch.utils.data.Dataset)

在用python编写CNN、UNet、VGG等模型时,需要先对数据集进行处理,即先读取image和label并进行预处理(比如设置图像大小等)。

在这之前已经将BraTS2018的nii格式文件保存为png格式:

https://blog.csdn.net/weixin_43330946/article/details/89576759

下面的代码是我最近写UNet用于分割BraTS2018数据集的代码。但是一般情况下,数据预处理的封装都差不多。


torch.utils.data.Dataset主要包括三个部分:

https://blog.csdn.net/weixin_43330946/article/details/89598204

1、__init__()

主要用于参数初始化,个人理解就是参数定义。根据自己网络所需要的参数进行定义。

2、__getitem__()

获取数据集的序列。数据的预处理就在这个部分编写。

3、__len()__

返回数据集的长度。一般都是固定的,没什么影响。

import numpy as np
import torch
import torch.utils.data
import os
import torchvision
from PIL import Image
from torchvision import Variable

class med(torch.utils.data.Dataset):
    def __init__(self,img_dir,anno_dir,transform=None):
        img_ilist = []    #图像的列表
        img_alist = []    #label的列表

        self.img_dir = img_dir
        self.anno_dir = anno_dir
        
        for subdir in os.listdir(img_dir):    #遍历文件夹中的所有文件及文件夹
            for file in os.listdir(os.path.join(img_dir,subdir)):    #os.path.join是添加路径。BraTS2018的数据集的子目录也是文件夹,所以需要两个for循环。
                img_ilist.append(os.path.join(img_dir,file))
                img_alist.append(os.path.join(anno_dir,file))
        self.img_ilist = img_ilist
        self.img_alist = img_alist
        self.transform = transform

    def __getitem__(self,index):
        img_ilist = self.img_ilist
        img_alist = self.img_alist
        
        img_ipath = self.img_ilist[index]    #获取img_ilist的序列
        imgi = Image.open(img_ipath).convert('L')
        #imgi = no.load(img_ipath)    #如果数据集保存成了npy格式,就用这一句
        imgi = imgi.resize((240,240))    #根据自己的需要改变图像大小
        
        img_apath = self.img_alist[index]    #获取img_alist的序列
        imga = Image.open(img_apath).convert('L')
        #imga = no.load(img_apath)    #如果数据集保存成了npy格式,就用这一句
        imga = imga.resize((240,240))    #根据自己的需要改变图像大小

        if self.transform is not None:
            imgi = self.transform(imgi)
            imga = self.transform(imga)
            imagei = imgi#.type(torch.FloatTensor)    #井号后面的内容只针对npy数据集
            imagea = imga#.type(torch.FloatTensor)    #井号后面的内容只针对npy数据集
        return imagei,imagea

    def __len__(self):
        return len(self.img_ilist)

 

你可能感兴趣的:(深度学习之代码之UNet)