PyTorch 实现自定义医学 3D 图像数据集(Dataset)

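由于是 3D 数据,torchvision 内置的 transform(面向 2D 图像)无法直接使用,所以这里没有用到。
这里的图像数据和标签数据都是 NIfTI(.nii)格式,并且图像已经调整好了窗宽窗位并完成了归一化。图像放在 work_dir/JPEGImages 目录下,标签放在 work_dir/Segmentations 目录下,二者通过相同的文件名一一对应。
如果你的数据没有经过上述预处理,还需要另外加入相应的处理代码。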

import os
from torch.utils.data import Dataset
from torchvision.transforms import (ToTensor,RandomHorizontalFlip,
                            RandomVerticalFlip,RandomRotation,
                            Compose,)
import SimpleITK as sitk
import numpy as np


class DataSetNii(Dataset):
    """Paired 3D image/label volume dataset read from a directory of NIfTI files.

    ``work_dir`` must contain two sub-directories: ``JPEGImages`` (image
    volumes) and ``Segmentations`` (label volumes). An image and its label are
    paired by having the same filename in both directories.

    Each sample is returned as ``(image, label)``:

    * ``image`` -- the loaded volume with a leading channel axis added, cast to
      float32, e.g. (32, 128, 128) -> (1, 32, 128, 128).
    * ``label`` -- a float32 one-hot encoding with ``num_classes`` channels on
      the leading axis, e.g. (32, 128, 128) -> (num_classes, 32, 128, 128).

    No intensity windowing or normalization is performed here; the volumes
    are expected to be preprocessed already.

    Args:
        work_dir: Root directory holding ``JPEGImages`` and ``Segmentations``.
        num_classes: Number of one-hot channels. Label values
            ``0 .. num_classes - 1`` get their own channel; a voxel with any
            other value is all zeros across channels.
        transform: Optional callable applied to the channel-first float32
            image array just before it is returned.
        target_transform: Optional callable applied to the one-hot label
            array just before it is returned.
        loader: Optional callable ``path -> array`` used to read a volume.
            When ``None`` (the default), volumes are read with SimpleITK
            (``sitk.ReadImage`` followed by ``sitk.GetArrayFromImage``).
    """

    def __init__(self, work_dir, num_classes, transform=None,
                 target_transform=None, loader=None):
        self.work_dir = work_dir
        self.num_classes = num_classes
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader

        self.image_dir = os.path.join(work_dir, "JPEGImages")
        self.label_dir = os.path.join(work_dir, "Segmentations")

        # os.listdir returns entries in arbitrary, platform-dependent order;
        # sort so the index -> file mapping is reproducible across runs.
        self.filenames = sorted(os.listdir(self.image_dir))

    @staticmethod
    def _read_nii(path):
        """Read the volume at ``path`` with SimpleITK and return it as a numpy array."""
        return sitk.GetArrayFromImage(sitk.ReadImage(path))

    def __len__(self):
        return len(self.filenames)

    def _one_hot(self, label):
        """Return a float32 one-hot array of shape ``(num_classes, *label.shape)``."""
        # Broadcast a (num_classes, 1, ..., 1) class-index column against the
        # label so this works for volumes of any rank, not only 3D.
        classes = np.arange(self.num_classes).reshape((-1,) + (1,) * label.ndim)
        return (label[np.newaxis] == classes).astype('float32')

    def __getitem__(self, idx):
        name = self.filenames[idx]
        read = self._read_nii if self.loader is None else self.loader

        image = read(os.path.join(self.image_dir, name))
        label = np.asarray(read(os.path.join(self.label_dir, name)))

        # Add a leading channel axis, e.g. (32,128,128) -> (1,32,128,128).
        images = np.expand_dims(image, 0).astype('float32')
        # One-hot encode, e.g. (32,128,128) -> (num_classes,32,128,128).
        labels = self._one_hot(label)

        # The transforms must act on the arrays that are actually returned;
        # applying them to the raw `image`/`label` locals discards their output.
        if self.transform is not None:
            images = self.transform(images)
        if self.target_transform is not None:
            labels = self.target_transform(labels)

        return images, labels


# if __name__ == "__main__":
#     train_dir = "/home/data/hablee_data_dir/new_dir/train_nii_002"

#     dataset = DataSetNii(train_dir,2)

#     image,label = dataset[1]
#     print(image.shape,label.shape)

#     imagetm = sitk.GetImageFromArray(image.squeeze())
#     label1 = sitk.GetImageFromArray(label[0,:,:,:].squeeze())
#     label2 = sitk.GetImageFromArray(label[1,:,:,:].squeeze())
#     sitk.WriteImage(imagetm,"image.nii.gz")
#     sitk.WriteImage(label1,"label1.nii.gz")
#     sitk.WriteImage(label2,"label2.nii.gz")

题外话:PyTorch 和 TensorFlow 在这一点上都很方便——都可以直接用 numpy 数组训练。在 PyTorch 中,Dataset 直接返回 numpy 数组即可,DataLoader 的默认 collate 函数会自动把它们转换成 tensor,不需要自己再手动转一遍。

你可能感兴趣的:(pytorch,医学图像处理,pytorch,3d,python)