pytorch 语义分割-医学图像-脑肿瘤数据集的载入模块

由于最近目标是完成基于深度学习的脑肿瘤语义分割实验,所以需要用到自定义的数据载入,本文参考了一下博客:https://blog.csdn.net/tuiqdymy/article/details/84779716?utm_source=app,一开始是做的眼底图像分割,由于使用的是DRIVE数据集,所以数据量很少,之前也是按照上面这篇博客标注了关于图片id的txt文件,但是这次是应用在kaggle脑肿瘤数据集上,kaggle脑肿瘤数据集百度云下载连接:链接:https://pan.baidu.com/s/12RTIv-RqEZwYCm27Im2Djw  提取码:tave  数据量挺大,再完全按照上面博客的方法实现对数据的载入显然不现实,所以就自己稍加修改,记录一下自己的学习过程。

首先我们可以看一下数据存储的结构叭:

pytorch 语义分割-医学图像-脑肿瘤数据集的载入模块_第1张图片

 pytorch 语义分割-医学图像-脑肿瘤数据集的载入模块_第2张图片

上面的每一个文件夹代表着每一个病人的脑肿瘤数据,每一个文件夹下的图片数据也大约有20几张数据(包含着原始图片和金标准图片),每个文件夹中的图片数据不一定相等。


下面就来看看数据加载模块的代码叭。

import torch
import os, glob
import random, csv
import matplotlib.pylab as plt
import torchvision
from PIL import Image
from  torch.utils.data import Dataset
import numpy as np


class  driveDateset(Dataset):
    def __init__(self, root, ignore_label=255):
        super(driveDateset,self).__init__()


        self.root = root

        self.files = []

        for file in os.listdir(self.root):
            fil = os.path.join(self.root, file)
            for data1 in os.listdir(fil):
                data1_split = data1.split('.')
                data11_split = data1_split[0].split('_')
                for data2 in os.listdir(fil):
                    data2_split = data2.split('.')
                    data22_split = data2_split[0].split('_')
                    if (data11_split[-1]==data22_split[-2]) & (data22_split[-1]=='mask'):
                        img_file = os.path.join(fil,data1)
                        label_file = os.path.join(fil,data2)
                        self.files.append({
                            "img":img_file,
                            "label":label_file,
                            "name":data1_split[0]
                        })

    #返回数据集大小
    def __len__(self):
        return len(self.files)

    #实现数据的下标索引
    def  __getitem__(self, index):
        dataflies = self.files[index]

        '''load the data '''
        name = dataflies["name"]
        image = Image.open(dataflies["img"]).convert('RGB')
        label = Image.open(dataflies["label"]).convert('L')
        size_origin = image.size # w*h

        I = np.asarray(image, np.float32)

        I = I.transpose((2, 0, 1))  # transpose the  H*W*C to C*H*W
        L = np.asarray(np.array(label), np.int64)
        # print(I.shape,L.shape)
        return I.copy(), L.copy(), np.array(size_origin), name

然后我们可以debug一下,看看代码中的各个参数的具体赋值情况,便于理解再次修改应用到自己所需数据集中。

pytorch 语义分割-医学图像-脑肿瘤数据集的载入模块_第3张图片

我们来看看测试的运行效果图:

pytorch 语义分割-医学图像-脑肿瘤数据集的载入模块_第4张图片

上面是原始图片,下面四张图片是对应的金标准图片。

pytorch 语义分割-医学图像-脑肿瘤数据集的载入模块_第5张图片

pytorch 语义分割-医学图像-脑肿瘤数据集的载入模块_第6张图片

pytorch 语义分割-医学图像-脑肿瘤数据集的载入模块_第7张图片

pytorch 语义分割-医学图像-脑肿瘤数据集的载入模块_第8张图片

到这里就测试完毕。下面是测试部分代码:

if __name__ == '__main__':
    DATA_DIRECTORY = "F:\\experiment_code\\U-net_brain\\kaggle_3m\\train"
    Batch_size = 4
    MEAN = (104.008, 116.669, 122.675)
    dst = driveDateset(DATA_DIRECTORY)
    # just for test,  so the mean is (0,0,0) to show the original images.
    # But when we are training a model, the mean should have another value
    trainloader = torch.utils.data.DataLoader(dst, batch_size=Batch_size)
    plt.ion()
    for i, data in enumerate(trainloader):
        imgs, labels, _, _ = data
        if i % 1 == 0:
            img = torchvision.utils.make_grid(imgs).numpy()
            img = img.astype(np.uint8)  # change the dtype from float32 to uint8,
                    # because the plt.imshow() need the uint8
            img = np.transpose(img, (1, 2, 0))  # transpose the C*H*W to H*W*C
            plt.imshow(img)
            plt.show()
            plt.pause(0.5)

                #label = torchvision.utils.make_grid(labels).numpy()
            labels = labels.numpy().astype(np.uint8)  # change the dtype from float32 to uint8,
            #                                       # because the plt.imshow() need the uint8
            for i in range(labels.shape[0]):
                plt.imshow(labels[i], cmap='gray')
                plt.show()
                plt.pause(0.5)

新手上车深度学习,写的不太完美,但是希望这篇博客对大家有用鸭。

你可能感兴趣的:(深度学习,深度学习)