由于是3D数据,torchvision内置的transform(主要面向2D图像)没有用到。
这里的图像数据和标签数据都是nii数据,并且已经调整好了窗宽窗位和归一化。
如果你的数据没有上述操作,还需要另外加入代码处理。
import os
from torch.utils.data import Dataset
from torchvision.transforms import (ToTensor,RandomHorizontalFlip,
RandomVerticalFlip,RandomRotation,
Compose,)
import SimpleITK as sitk
import numpy as np
class DataSetNii(Dataset):
    """Dataset of paired 3D image / segmentation volumes stored as NIfTI files.

    Images are read from ``<work_dir>/JPEGImages`` and labels from
    ``<work_dir>/Segmentations``. A label file must carry the same file name
    as the image it belongs to.

    Each sample is returned as ``(image, label)`` where

    * ``image`` is the volume with a leading channel axis added, cast to
      float32 (e.g. ``(1, 32, 128, 128)``);
    * ``label`` is the one-hot encoded segmentation as float32 with the class
      axis first (e.g. ``(num_classes, 32, 128, 128)``).

    No intensity windowing or normalisation is done here; the volumes are
    returned as stored on disk apart from the dtype cast.

    Args:
        work_dir: Directory containing the ``JPEGImages`` and
            ``Segmentations`` sub-directories.
        num_classes: Number of classes; label values are expected to be the
            integers ``0 .. num_classes - 1``.
        transform: Optional callable applied to the channel-first float32
            image array; its return value is what the dataset returns.
        target_transform: Optional callable applied to the one-hot float32
            label array; its return value is what the dataset returns.
    """

    def __init__(self, work_dir, num_classes, transform=None, target_transform=None):
        self.work_dir = work_dir
        self.num_classes = num_classes
        self.transform = transform
        self.target_transform = target_transform
        self.image_dir = os.path.join(work_dir, "JPEGImages")
        self.label_dir = os.path.join(work_dir, "Segmentations")
        # os.listdir returns entries in arbitrary order; sort so that a given
        # index maps to the same file on every run and every machine.
        self.filenames = sorted(os.listdir(self.image_dir))

    def __len__(self):
        """Return the number of image files found in the image directory."""
        return len(self.filenames)

    @staticmethod
    def _read_volume(path):
        """Read the image file at *path* and return its voxels as a numpy array."""
        return sitk.GetArrayFromImage(sitk.ReadImage(path))

    @staticmethod
    def _one_hot(label, num_classes):
        """One-hot encode an integer label array, class axis first, as float32."""
        # Broadcast the label against a column of class ids: channel ``i`` of
        # the result is 1.0 where ``label == i`` and 0.0 elsewhere.
        class_ids = np.arange(num_classes).reshape((-1,) + (1,) * label.ndim)
        return (label[np.newaxis] == class_ids).astype('float32')

    def __getitem__(self, idx):
        """Load sample *idx* and return it as ``(image, label)``.

        The transforms, when given, are applied to the arrays that are
        actually returned (previously their results were discarded).
        """
        filename = self.filenames[idx]
        image = self._read_volume(os.path.join(self.image_dir, filename))
        label = self._read_volume(os.path.join(self.label_dir, filename))

        # Add the channel axis in front: (D, H, W) -> (1, D, H, W).
        images = np.expand_dims(image, 0).astype('float32')
        # One-hot encode the label: (D, H, W) -> (num_classes, D, H, W).
        labels = self._one_hot(label, self.num_classes)

        if self.transform is not None:
            images = self.transform(images)
        if self.target_transform is not None:
            labels = self.target_transform(labels)
        return images, labels
# if __name__ == "__main__":
# train_dir = "/home/data/hablee_data_dir/new_dir/train_nii_002"
# dataset = DataSetNii(train_dir,2)
# image,label = dataset[1]
# print(image.shape,label.shape)
# imagetm = sitk.GetImageFromArray(image.squeeze())
# label1 = sitk.GetImageFromArray(label[0,:,:,:].squeeze())
# label2 = sitk.GetImageFromArray(label[1,:,:,:].squeeze())
# sitk.WriteImage(imagetm,"image.nii.gz")
# sitk.WriteImage(label1,"label1.nii.gz")
# sitk.WriteImage(label2,"label2.nii.gz")
题外话:PyTorch和TensorFlow在这一点上都很方便:数据集里可以直接返回numpy数组(PyTorch的DataLoader在组batch时会用默认的collate自动把它们转成tensor),不用在这里手动转成自家的tensor,不然还得转一遍。