医学分割学习记录-AMOS22数据预处理(nnUnet)

研究方向针对血管介入手术,因此只保留了主动脉标签(AMOS中的标签8).

nnUnet对数据的预处理包括cropping,resample,normalization三个步骤.

Github源码:GitHub - MIC-DKFZ/nnUNet

但是Github项目的代码比较复杂,知乎上有一篇文章写得更清晰:

如何针对三维医学图像分割任务进行通用数据预处理:nnUNet中预处理流程总结及代码分析 - 知乎

主要参考这两份资料完成了以下几个部分的处理.

file_list.py: 保留label中的主动脉,生成train_path_list.txt以及val_path_list.txt,即训练集和验证集的文件索引.

import os
import SimpleITK as sitk
import numpy as np
from os.path import join
import random
import config


def process(ct_path, seg_path, keep_label=8, out_ct_path=None, out_seg_path=None):
    """Write copies of a CT volume and its label map in which only one label is kept.

    Voxels whose label differs from ``keep_label`` are set to 0 (background);
    voxels equal to ``keep_label`` keep that value. The default 8 is the label
    this post keeps (the aorta). The CT itself is written unchanged.

    :param ct_path: path of the input CT image.
    :param seg_path: path of the input label image.
    :param keep_label: label value to keep; every other value becomes 0.
    :param out_ct_path: where to write the CT. Defaults to ``ct_path`` with
        every occurrence of ``'data'`` replaced by ``'fixed_data'``.
    :param out_seg_path: where to write the label map; same default rule
        applied to ``seg_path``.
    """
    if out_ct_path is None:
        out_ct_path = ct_path.replace('data', 'fixed_data')
    if out_seg_path is None:
        out_seg_path = seg_path.replace('data', 'fixed_data')

    ct = sitk.ReadImage(ct_path)
    seg = sitk.ReadImage(seg_path)
    seg_array = sitk.GetArrayFromImage(seg)
    # GetSize() is (x, y, z); reversed it equals the numpy array shape (z, y, x).
    print("Ori shape:", tuple(ct.GetSize()[::-1]), seg_array.shape)

    # Keep only the requested label; everything else becomes background.
    seg_array[seg_array != keep_label] = 0
    new_seg = sitk.GetImageFromArray(seg_array)
    # Same voxel grid as the input: copy origin, spacing and direction.
    new_seg.CopyInformation(seg)

    # The default output tree ("fixed_data/...") may not exist yet.
    for out_path in (out_ct_path, out_seg_path):
        os.makedirs(os.path.dirname(out_path) or '.', exist_ok=True)
    sitk.WriteImage(ct, out_ct_path)
    sitk.WriteImage(new_seg, out_seg_path)


class file_list:
    """Build the label-filtered dataset copy and the train/validation index files.

    Expected layout: ``<root>/ct/<name>`` and ``<root>/label/<name>`` share
    file names, both for the raw dataset and for ``fixed_path``.
    """

    def __init__(self, dataset_path, fixed_path, args):
        """
        :param dataset_path: root of the raw dataset, e.g. ``'./data/'``.
        :param fixed_path: root of the processed dataset, e.g. ``'./fixed_data/'``.
        :param args: config object; only ``args.valid_rate`` (fraction of the
            samples used for validation) is read.
        """
        self.dataset_path = dataset_path
        self.fixed_path = fixed_path
        self.valid_rate = args.valid_rate

    def fix_data(self):
        """Run ``process`` on every CT/label pair under ``dataset_path``.

        NOTE: ``process`` is called with only the input paths, so it picks its
        own output location (by default the input path with ``'data'``
        replaced by ``'fixed_data'``), not necessarily ``fixed_path``.
        """
        # exist_ok: also create the subdirectories when fixed_path already exists.
        os.makedirs(join(self.fixed_path, 'ct'), exist_ok=True)
        os.makedirs(join(self.fixed_path, 'label'), exist_ok=True)
        ct_files = os.listdir(join(self.dataset_path, 'ct'))
        total = len(ct_files)
        print('total number of samples is:', total)
        for i, ct_file in enumerate(ct_files, start=1):
            print("==== {} | {}/{} ====".format(ct_file, i, total))
            ct_path = os.path.join(self.dataset_path, 'ct', ct_file)
            seg_path = os.path.join(self.dataset_path, 'label', ct_file)
            process(ct_path, seg_path)

    def write_train_val_name_list(self, seed=None):
        """Randomly split ``<fixed_path>/ct`` into train/val index files.

        Writes ``train_path_list.txt`` and ``val_path_list.txt`` into
        ``fixed_path``. Every sample lands in exactly one of the two lists.

        :param seed: optional seed for a reproducible split. When None, the
            global ``random`` state is used.
        :raises ValueError: if ``valid_rate`` is not in ``[0, 1)``.
        """
        if not 0.0 <= self.valid_rate < 1.0:
            raise ValueError("valid_rate must be in [0, 1), got %r" % (self.valid_rate,))
        # Sort first: os.listdir order is arbitrary, which would make even a
        # seeded shuffle irreproducible across machines.
        data_name_list = sorted(os.listdir(join(self.fixed_path, "ct")))
        data_num = len(data_name_list)
        print('the fixed dataset total numbers of samples is :', data_num)
        if seed is None:
            random.shuffle(data_name_list)
        else:
            random.Random(seed).shuffle(data_name_list)

        split = int(data_num * (1 - self.valid_rate))
        # Slice to the end instead of recomputing the upper bound in floating
        # point, which could round down and silently drop the last sample.
        train_name_list = data_name_list[:split]
        val_name_list = data_name_list[split:]

        self.write_name_list(train_name_list, "train_path_list.txt")
        self.write_name_list(val_name_list, "val_path_list.txt")

    def write_name_list(self, name_list, file_name):
        """Write one ``"<ct path> <label path>"`` line per name to ``<fixed_path>/<file_name>``."""
        lines = [
            os.path.join(self.fixed_path, 'ct', name) + ' '
            + os.path.join(self.fixed_path, 'label', name) + "\n"
            for name in name_list
        ]
        with open(join(self.fixed_path, file_name), 'w') as f:
            f.writelines(lines)



if __name__ == '__main__':
    # Raw volumes live under ./data/{ct,label}; the index files are built for
    # the processed copy under ./amos_data/{ct,label}.
    args = config.args
    tool = file_list('./data/', './amos_data/', args)
    # tool.fix_data()  # optional: regenerate the label-filtered copies first
    tool.write_train_val_name_list()












cropping.py:由CT图像的非零区域生成mask(并填充内部空洞),再用mask的包围盒对ct和label进行crop.本数据集很多数据cropping没有效果,有效果的裁掉的尺寸也不多.原因之一可能是CT的背景(空气)强度本身不为0,非零区域几乎覆盖整个体积;nnU-Net的非零裁剪主要对背景为0的数据(如MRI)效果明显.当然也可能是我的代码有问题.

import SimpleITK as sitk
import numpy as np
import matplotlib.pyplot as plt
import shutil
# from batchgenerators.utilities.file_and_folder_operations import *
from multiprocessing import Pool
from collections import OrderedDict
from os import path
import os


def create_nonzero_mask(data):
    """Return a boolean mask of the non-zero voxels of ``data`` with interior holes filled.

    :param data: either a single volume of shape (Z, Y, X), as passed by
        ``crop_to_nonzero`` in this script, or a channel-first array of shape
        (C, Z, Y, X), in which case a voxel is foreground if any channel is
        non-zero.
    :return: bool array with the spatial shape of ``data``.
    :raises ValueError: if ``data`` is neither 3-D nor 4-D.
    """
    from scipy.ndimage import binary_fill_holes
    data = np.asarray(data)
    if data.ndim == 3:
        nonzero_mask = data != 0
    elif data.ndim == 4:
        # Reduce over the channel axis only, so the mask keeps the spatial shape.
        nonzero_mask = np.any(data != 0, axis=0)
    else:
        raise ValueError("data must have shape (Z, Y, X) or (C, Z, Y, X)")
    return binary_fill_holes(nonzero_mask)

def get_bbox_from_mask(mask, outside_value=0):
    """Return the tight bounding box of all voxels of ``mask`` that differ from ``outside_value``.

    Works for masks of any dimensionality.

    :param mask: array to search.
    :param outside_value: value treated as background.
    :return: one ``[start, stop]`` pair of Python ints per axis, in the
        array's axis order (``stop`` is exclusive, so each pair can be used
        directly as a slice), e.g. ``[[z0, z1], [y0, y1], [x0, x1]]`` for a
        (Z, Y, X) volume.
    :raises ValueError: if no voxel differs from ``outside_value``.
    """
    coords = np.nonzero(np.asarray(mask) != outside_value)
    if coords[0].size == 0:
        raise ValueError(
            "mask contains no voxel different from outside_value=%r" % (outside_value,))
    return [[int(axis_coords.min()), int(axis_coords.max()) + 1] for axis_coords in coords]

def crop_to_bbox(image, bbox):
    """Cut the sub-volume described by ``bbox`` out of a 3-D array.

    :param image: 3-D array.
    :param bbox: three ``[start, stop]`` pairs, one per array axis.
    :return: ``image`` sliced along each axis (numpy basic slicing).
    """
    assert len(image.shape) == 3, "only supports 3d images"
    axis_slices = tuple(slice(bbox[axis][0], bbox[axis][1]) for axis in range(3))
    return image[axis_slices]

def crop_to_nonzero(data, seg=None, nonzero_label=-1):
    """Crop ``data`` (and ``seg``) to the bounding box of the non-zero region of ``data``.

    :param data: 3-D image array.
    :param seg: 3-D label array on the same grid as ``data``, or None.
    :param nonzero_label: written into the cropped label map at voxels that
        are background (0) in ``seg`` and lie outside the non-zero mask of
        ``data``.
    :return: ``(cropped_data, cropped_seg, bbox)``. ``cropped_seg`` is a new
        array, so the caller's ``seg`` is left untouched. If ``seg`` is None,
        ``cropped_seg`` is 0 inside the mask and ``nonzero_label`` outside it.
    """
    nonzero_mask = create_nonzero_mask(data)
    bbox = get_bbox_from_mask(nonzero_mask, 0)
    mask = crop_to_bbox(nonzero_mask, bbox)
    data = crop_to_bbox(data, bbox)
    if seg is None:
        seg = np.where(mask, 0, nonzero_label)
    else:
        # crop_to_bbox slices, which yields a view for numpy arrays; copy so
        # the assignment below does not write into the caller's array.
        seg = crop_to_bbox(seg, bbox).copy()
        seg[(seg == 0) & (mask == 0)] = nonzero_label
    return data, seg, bbox

def _to_image_like(array, reference, start_index):
    """Wrap a cropped array as an image on ``reference``'s grid, starting at voxel ``start_index`` (x, y, z)."""
    image = sitk.GetImageFromArray(array)
    image.SetDirection(reference.GetDirection())
    image.SetSpacing(reference.GetSpacing())
    # The crop begins at voxel `start_index` of the reference, so the new
    # origin is that voxel's physical position, not the old origin.
    image.SetOrigin(reference.TransformIndexToPhysicalPoint(start_index))
    return image


def cropping(file_ct, file_seg, out_dir='./crop_data', ct_pixel_type=None):
    """Crop every CT/label pair to the non-zero region of the CT and save the results.

    :param file_ct: directory of CT volumes.
    :param file_seg: directory of label volumes with the same file names.
    :param out_dir: output root; results go to ``<out_dir>/ct`` and
        ``<out_dir>/label`` (created if missing).
    :param ct_pixel_type: SimpleITK pixel type to read the CT as. None (the
        default) keeps the stored type; casting CT to an 8-bit type would not
        preserve its intensity range.
    """
    ct_out_dir = path.join(out_dir, 'ct')
    seg_out_dir = path.join(out_dir, 'label')
    os.makedirs(ct_out_dir, exist_ok=True)
    os.makedirs(seg_out_dir, exist_ok=True)

    names = sorted(os.listdir(file_ct))
    total = len(names)
    for n, f in enumerate(names):
        ct_file = path.join(file_ct, f)
        if ct_pixel_type is None:
            ct_img = sitk.ReadImage(ct_file)
        else:
            ct_img = sitk.ReadImage(ct_file, ct_pixel_type)
        seg_img = sitk.ReadImage(path.join(file_seg, f), sitk.sitkInt16)

        cropped_ct, cropped_seg, bbox = crop_to_nonzero(
            sitk.GetArrayFromImage(ct_img), sitk.GetArrayFromImage(seg_img), nonzero_label=-1)
        # bbox is in numpy (z, y, x) order; SimpleITK indices are (x, y, z).
        start_index = (bbox[2][0], bbox[1][0], bbox[0][0])

        sitk.WriteImage(_to_image_like(cropped_ct, ct_img, start_index), path.join(ct_out_dir, f))
        sitk.WriteImage(_to_image_like(cropped_seg, seg_img, start_index), path.join(seg_out_dir, f))
        print('n=', n, '/', total)
        print('name:', f)



if __name__ == '__main__':
    # Crop the label-filtered volumes produced by file_list.py.
    data = './fixed_data/ct/'
    seg = './fixed_data/label/'
    cropping(data, seg)


resampling and norm.py:Resampling将数据集的Spacing统一到同一个目标Spacing,目标Spacing的计算分各向异性和非各向异性两种情况;Normalization用前景体素强度的0.5%和99.5%分位数截断后再做z-score标准化.这部分内容知乎文章写得比较清晰.

import SimpleITK as sitk
import numpy as np
import matplotlib.pyplot as plt
import shutil
# from batchgenerators.utilities.file_and_folder_operations import *
from multiprocessing import Pool
from collections import OrderedDict
import os
from os import path
from skimage.transform import resize
def _read_spacing(image_path):
    """Return the voxel spacing of an image file, reading only its header."""
    reader = sitk.ImageFileReader()
    reader.SetFileName(image_path)
    reader.ReadImageInformation()
    return reader.GetSpacing()


def spacing(file, spacing_file='spacing.txt'):
    """Choose the dataset's target voxel spacing (nnU-Net heuristic).

    Per-volume spacings are read from ``spacing_file``: one row per volume,
    columns in SimpleITK (x, y, z) order. If that file does not exist, the
    spacings of all images in directory ``file`` are collected and saved there.

    The target is the per-axis median spacing. If that median is anisotropic
    (largest axis > 3x the smallest), the coarsest axis is set instead to the
    10th percentile of that axis' spacings. This is never coarser than the
    median.

    :return: ``(target, flag)`` where ``target`` has shape (1, 3) and ``flag``
        is 1 for an anisotropic dataset, 0 otherwise.
    :raises ValueError: if no spacings are available.
    """
    if path.exists(spacing_file):
        Sp = np.loadtxt(spacing_file, ndmin=2)
    else:
        spacings = [_read_spacing(path.join(file, f)) for f in sorted(os.listdir(file))]
        if not spacings:
            raise ValueError('no images found in %s' % file)
        Sp = np.array(spacings, dtype=float)
        np.savetxt(spacing_file, Sp)
    if Sp.size == 0:
        raise ValueError('no spacings in %s' % spacing_file)

    # Statistics are taken per COLUMN (axis), over all volumes (rows).
    median = np.percentile(Sp, 50, axis=0)
    target = median.copy()
    if np.max(median) > 3 * np.min(median):
        flag = 1
        print('各向异性数据集')
        worst_axis = int(np.argmax(median))
        target[worst_axis] = np.percentile(Sp[:, worst_axis], 10)
    else:
        flag = 0
        print('非各向异性数据集')
    Sp_target = target[None, :]
    print('Sp_target=', Sp_target)
    return Sp_target, flag

def _resample_to_spacing(image, out_spacing, interpolator, default_value):
    """Resample ``image`` onto a grid with ``out_spacing`` over (approximately) the same extent.

    Origin and direction are kept. The output size is the input size scaled by
    ``input_spacing / out_spacing`` and rounded. Output voxels that map outside
    the input get ``default_value``.
    """
    in_spacing = image.GetSpacing()
    in_size = image.GetSize()
    out_size = [int(np.round(size * (sp_in / sp_out)))
                for size, sp_in, sp_out in zip(in_size, in_spacing, out_spacing)]
    resampler = sitk.ResampleImageFilter()
    resampler.SetOutputSpacing([float(s) for s in out_spacing])
    resampler.SetSize(out_size)
    resampler.SetOutputDirection(image.GetDirection())
    resampler.SetOutputOrigin(image.GetOrigin())
    resampler.SetTransform(sitk.Transform())
    resampler.SetDefaultPixelValue(float(default_value))
    resampler.SetInterpolator(interpolator)
    return resampler.Execute(image)


def resample_image(itk_image, out_spacing, seg, flag):
    """Resample a CT image and its label map to ``out_spacing``.

    When ``flag == 1`` (anisotropic data), the CT is resampled in two steps:
    1. In-plane (x, y) with a 3rd-order B-spline, keeping the original z spacing.
    2. Along z with nearest neighbour.

    Otherwise all axes are resampled at once with a 3rd-order B-spline. The
    label map always uses nearest neighbour, so no new label values appear.

    :param itk_image: CT image.
    :param out_spacing: target spacing in SimpleITK (x, y, z) order; shape
        (3,) or (1, 3) as returned by ``spacing``.
    :param seg: label image.
    :param flag: 1 for an anisotropic dataset, 0 otherwise (from ``spacing``).
    :return: ``(new_ct, new_seg)``.
    """
    target = [float(s) for s in np.asarray(out_spacing).reshape(-1)]
    # Pad voxels outside the input with the image minimum (background) rather
    # than an arbitrary constant.
    background = float(np.min(sitk.GetArrayFromImage(itk_image)))
    if flag == 1:
        in_plane = [target[0], target[1], itk_image.GetSpacing()[2]]
        in_plane_img = _resample_to_spacing(
            itk_image, in_plane, sitk.sitkBSplineResamplerOrder3, background)
        new_ct = _resample_to_spacing(
            in_plane_img, target, sitk.sitkNearestNeighbor, background)
    else:
        new_ct = _resample_to_spacing(
            itk_image, target, sitk.sitkBSplineResamplerOrder3, background)
    new_seg = _resample_to_spacing(seg, target, sitk.sitkNearestNeighbor, 0)
    return new_ct, new_seg



def resampling(path_c, out_dir='./new_data', seg_dir=None):
    """Resample every CT in ``path_c`` and its label map to the dataset target spacing.

    The target spacing and the anisotropy flag come from ``spacing(path_c)``.

    :param path_c: directory of CT volumes. They are read in their stored
        pixel type, so the intensity range is preserved.
    :param out_dir: output root; results go to ``<out_dir>/ct`` and
        ``<out_dir>/label`` (created if missing).
    :param seg_dir: directory of label volumes with the same file names.
        Defaults to the directory ``label`` next to ``path_c``.
    """
    Spacing_target, flag = spacing(path_c)
    if seg_dir is None:
        seg_dir = path.join(path.dirname(path.normpath(path_c)), 'label')
    ct_out_dir = path.join(out_dir, 'ct')
    seg_out_dir = path.join(out_dir, 'label')
    os.makedirs(ct_out_dir, exist_ok=True)
    os.makedirs(seg_out_dir, exist_ok=True)

    for n, f in enumerate(sorted(os.listdir(path_c))):
        data = sitk.ReadImage(path.join(path_c, f))
        seg = sitk.ReadImage(path.join(seg_dir, f), sitk.sitkInt16)
        new_data, new_seg = resample_image(data, Spacing_target, seg, flag)
        sitk.WriteImage(new_data, path.join(ct_out_dir, f))
        sitk.WriteImage(new_seg, path.join(seg_out_dir, f))
        print('n=', n, 'name=', f)

def vox(data, seg):
    """Sample CT intensities inside the foreground (label > 0) of one case.

    :param data: path of the CT image.
    :param seg: path of the label image (assumed to be on the CT's grid).
    :return: 1-D array with every 10th foreground voxel value, in C order.
    """
    labels = sitk.GetArrayFromImage(sitk.ReadImage(seg, sitk.sitkInt16))
    # NOTE(review): reading the CT as UInt8 cannot represent the HU range of
    # typical CT data; confirm this cast is intended.
    intensities = sitk.GetArrayFromImage(sitk.ReadImage(data, sitk.sitkUInt8))
    samples = intensities[labels > 0][::10].flatten()
    print(samples.shape)
    return samples

def _foreground_sample(ct_array, seg_array, step=10):
    """Return every ``step``-th CT value (C order) where the label is > 0."""
    return ct_array[seg_array > 0][::step]


def vox_all(file):
    """Normalise every CT in ``file`` using intensity statistics pooled over the dataset.

    Steps:
    1. Pool every 10th foreground voxel (label > 0) of all cases and save the
       pool to ``voxels_all.txt``.
    2. Clip each CT to the [0.5, 99.5] percentiles of the pool.
    3. z-score each CT with the pool's mean and std.
    4. Set voxels with a negative label (outside the crop mask) to 0.
    5. Write each result to its input path with ``'new_data'`` replaced by
       ``'new_data2'``.

    Labels are read from the CT path with ``'ct'`` replaced by ``'label'``.
    CTs are read in their stored pixel type, the same way in both passes, so
    the statistics and the data they are applied to agree.

    :raises ValueError: if ``file`` contains no images or there are no
        foreground voxels.
    """
    names = os.listdir(file)
    samples = []
    for n, f in enumerate(names, start=1):
        file_ct = path.join(file, f)
        file_seg = file_ct.replace('ct', 'label')
        ct_array = sitk.GetArrayFromImage(sitk.ReadImage(file_ct))
        seg_array = sitk.GetArrayFromImage(sitk.ReadImage(file_seg, sitk.sitkInt16))
        foreground = _foreground_sample(ct_array, seg_array)
        print(foreground.shape)
        samples.append(foreground.astype(np.float64))
        print('n=', n)
    if not samples:
        raise ValueError('no images found in %s' % file)
    # Concatenate once instead of np.append per case (which copies every time).
    pooled = np.concatenate(samples)
    if pooled.size == 0:
        raise ValueError('no foreground voxels found in %s' % file)

    mean = np.mean(pooled)
    std = np.std(pooled)
    np.savetxt('voxels_all.txt', pooled)
    print('mean_vox=', mean, 'std=', std)
    percentile_99_5 = np.percentile(pooled, 99.5)
    percentile_00_5 = np.percentile(pooled, 0.5)
    scale = max(std, 1e-8)  # guard against a constant foreground

    for n, f in enumerate(names, start=1):
        file_ct = path.join(file, f)
        file_seg = file_ct.replace('ct', 'label')
        ct = sitk.ReadImage(file_ct)
        seg_array = sitk.GetArrayFromImage(sitk.ReadImage(file_seg, sitk.sitkInt16))
        # Work in float64 so clipping to float percentiles is well defined.
        ct_array = sitk.GetArrayFromImage(ct).astype(np.float64)
        ct_array = np.clip(ct_array, percentile_00_5, percentile_99_5)
        ct_array = (ct_array - mean) / scale
        ct_array[seg_array < 0] = 0
        ct_new = sitk.GetImageFromArray(ct_array)
        ct_new.CopyInformation(ct)
        path_ct = file_ct.replace('new_data', 'new_data2')
        os.makedirs(path.dirname(path_ct) or '.', exist_ok=True)
        sitk.WriteImage(ct_new, path_ct)
        print('n=', n, 'name=', f)

def resample_seg(file):
    """Resample every label map in directory ``file`` to the dataset target spacing.

    The target comes from ``spacing(file)``. Nearest-neighbour interpolation is
    used so no new label values appear, and voxels outside the input become
    background (0). Each result is written to its input path with every
    ``'data'`` replaced by ``'new_data'``; the parent directory is created if
    needed.
    """
    Sp_target, _ = spacing(file)
    out_spacing = [float(s) for s in np.asarray(Sp_target).reshape(-1)]
    print('TARGET=', Sp_target)
    for n, f in enumerate(os.listdir(file)):
        real_url = path.join(file, f)
        seg = sitk.ReadImage(real_url, sitk.sitkInt16)
        in_spacing = seg.GetSpacing()
        in_size = seg.GetSize()
        out_size = [int(np.round(size * (sp_in / sp_out)))
                    for size, sp_in, sp_out in zip(in_size, in_spacing, out_spacing)]

        resampler = sitk.ResampleImageFilter()
        # The output spacing must be set explicitly; otherwise the filter uses
        # its default spacing while the size was computed for the target.
        resampler.SetOutputSpacing(out_spacing)
        resampler.SetSize(out_size)
        resampler.SetOutputDirection(seg.GetDirection())
        resampler.SetOutputOrigin(seg.GetOrigin())
        resampler.SetTransform(sitk.Transform())
        resampler.SetDefaultPixelValue(0)
        resampler.SetInterpolator(sitk.sitkNearestNeighbor)
        new_seg = resampler.Execute(seg)

        path_new = real_url.replace('data', 'new_data')
        os.makedirs(path.dirname(path_new) or '.', exist_ok=True)
        sitk.WriteImage(new_seg, path_new)
        print('n=', n, 'name=', f)


if __name__ == '__main__':
    # Final step: intensity-normalise the resampled CT volumes.
    # The earlier steps of this script are run separately when needed:
    #   resampling(file)   /   resample_seg(file)
    file = './new_data/ct/'
    vox_all(file)




写的代码比较稚嫩.为了便于debug,各个步骤是分开写的,每一步都保存了中间数据,用3D Slicer可以很直观地看出有没有问题.

之后该写Dataloader和网络了,先在U-Net上试一试.之前跑通的3D U-Net是以整幅图像作为输入训练的,而AMOS数据集处理前后各样本的Size都不统一,因此需要改为基于Patch的训练.

AMOS数据集这样预处理之后Size比较大,训练速度估计会慢很多,这也是一个待解决的问题.网络结构也需要创新.怎样才能获得又快又好的效果呢?

你可能感兴趣的:(医学图像分割学习记录,学习)