方向针对血管介入手术,故只保留了主动脉标签.
nnUnet对数据的预处理包括cropping,resample,normalization三个步骤.
Github源码:GitHub - MIC-DKFZ/nnUNet
但是Github项目的代码比较复杂,知乎有一篇文章写的更清晰:
如何针对三维医学图像分割任务进行通用数据预处理:nnUNet中预处理流程总结及代码分析 - 知乎
主要参考了这两个完成了几个部分的处理.
file_list.py: 保留label中的主动脉,生成train_path_list.txt以及val_path_list.txt,创建训练集以及验证集的索引.
import os
import SimpleITK as sitk
import numpy as np
from os.path import join
import random
import config
def process(ct_path, seg_path):
    """Read one CT volume and its segmentation, keep only the aorta label
    (value 8 in AMOS), and write both to the mirrored 'fixed_data' tree.

    Geometry (origin, spacing, direction) is copied from the source images
    so the outputs stay spatially aligned with the inputs.

    :param ct_path: path of the CT image under './data/ct/'
    :param seg_path: path of the matching label under './data/label/'
    """
    ct = sitk.ReadImage(ct_path)
    ct_array = sitk.GetArrayFromImage(ct)
    seg = sitk.ReadImage(seg_path)
    seg_array = sitk.GetArrayFromImage(seg)
    print("Ori shape:", ct_array.shape, seg_array.shape)
    # Keep only the aorta label (8); one pass replaces the old pair of
    # ">8" / "<8" assignments. (The previous comment about merging liver
    # and tumour labels was a leftover from LiTS and did not match the code.)
    seg_array[seg_array != 8] = 0
    new_ct = sitk.GetImageFromArray(ct_array)
    new_ct.SetOrigin(ct.GetOrigin())
    new_ct.SetSpacing(ct.GetSpacing())
    new_ct.SetDirection(ct.GetDirection())
    new_seg = sitk.GetImageFromArray(seg_array)
    new_seg.SetOrigin(seg.GetOrigin())
    new_seg.SetSpacing(seg.GetSpacing())
    new_seg.SetDirection(seg.GetDirection())
    # Mirror the input layout under './fixed_data/'.
    sitk.WriteImage(new_ct, ct_path.replace('data', 'fixed_data'))
    sitk.WriteImage(new_seg, seg_path.replace('data', 'fixed_data'))
class file_list:
    """Build the fixed dataset and the train/val index files.

    dataset_path: raw dataset root (e.g. './data/') with 'ct' and 'label'
        subdirectories holding matching file names.
    fixed_path: output root (e.g. './fixed_data/') that receives the
        processed volumes and the *_path_list.txt index files.
    args: namespace providing ``valid_rate`` — fraction of samples used
        for validation (must be < 1.0).
    """

    def __init__(self, dataset_path, fixed_path, args):
        self.dataset_path = dataset_path
        self.fixed_path = fixed_path
        self.valid_rate = args.valid_rate

    def fix_data(self):
        """Run process() on every ct/label pair of the raw dataset."""
        # exist_ok fixes the old guard, which skipped creating the 'ct'
        # and 'label' subdirectories whenever fixed_path itself existed.
        os.makedirs(join(self.fixed_path, 'ct'), exist_ok=True)
        os.makedirs(join(self.fixed_path, 'label'), exist_ok=True)
        names = os.listdir(join(self.dataset_path, 'ct'))
        total = len(names)
        print('total number of samples is:', total)
        for i, ct_file in enumerate(names):
            print("==== {} | {}/{} ====".format(ct_file, i + 1, total))
            ct_path = os.path.join(self.dataset_path, 'ct', ct_file)
            seg_path = os.path.join(self.dataset_path, 'label', ct_file)
            process(ct_path, seg_path)

    def write_train_val_name_list(self):
        """Shuffle the fixed dataset and write the train/val index files."""
        data_name_list = os.listdir(join(self.fixed_path, "ct"))
        data_num = len(data_name_list)
        print('the fixed dataset total numbers of samples is :', data_num)
        random.shuffle(data_name_list)
        assert self.valid_rate < 1.0
        # The old upper bound int(data_num*((1-r)+r)) is just data_num.
        split = int(data_num * (1 - self.valid_rate))
        self.write_name_list(data_name_list[:split], "train_path_list.txt")
        self.write_name_list(data_name_list[split:], "val_path_list.txt")

    def write_name_list(self, name_list, file_name):
        """Write one '<ct_path> <label_path>' line per sample."""
        # 'with' closes the file even if a write fails (old code leaked it).
        with open(join(self.fixed_path, file_name), 'w') as f:
            for name in name_list:
                ct_path = os.path.join(self.fixed_path, 'ct', name)
                seg_path = os.path.join(self.fixed_path, 'label', name)
                f.write(ct_path + ' ' + seg_path + "\n")
if __name__ == '__main__':
    # Two-step workflow: (1) run fix_data() once to keep only the aorta
    # label and copy the dataset, (2) write the train/val index files.
    data_path = './data/'
    fixed_path = './amos_data/'
    args = config.args
    tool = file_list(data_path, fixed_path, args)
    # tool.fix_data()  # step 1 — uncomment on the first run
    tool.write_train_val_name_list()
cropping.py:通过label产生mask,再通过mask对ct和label进行crop.本数据集很多数据cropping没有效果,有效果的也只能裁掉很少的尺寸,也许是我的代码有问题.
import SimpleITK as sitk
import numpy as np
import matplotlib.pyplot as plt
import shutil
# from batchgenerators.utilities.file_and_folder_operations import *
from multiprocessing import Pool
from collections import OrderedDict
from os import path
import os
def create_nonzero_mask(data):
    """Return a boolean mask of the non-zero region of ``data``.

    Accepts a plain 3D volume (Z, Y, X) — the case used by cropping() in
    this file — or a channel-first 4D array (C, X, Y, Z), where the mask
    is the union of the non-zero voxels over all channels.  Internal holes
    are filled so the bounding box covers the whole body region.

    Bug fixed: the old version allocated the mask with the full ``data``
    shape and OR-ed each ``data[c]`` (one dimension smaller) into it; the
    broadcast spread every slice's mask across the whole first axis, so
    the resulting bounding box never shrank along that axis and cropping
    had almost no effect.
    """
    from scipy.ndimage import binary_fill_holes
    assert data.ndim in (3, 4), "data must have shape (C, X, Y, Z) or shape (X, Y, Z)"
    if data.ndim == 3:
        nonzero_mask = data != 0
    else:
        nonzero_mask = np.zeros(data.shape[1:], dtype=bool)
        for c in range(data.shape[0]):
            nonzero_mask |= data[c] != 0
    return binary_fill_holes(nonzero_mask)
def get_bbox_from_mask(mask, outside_value=0):
    """Compute the tight bounding box of ``mask != outside_value``.

    Returns [[zmin, zmax], [xmin, xmax], [ymin, ymax]] where each pair is
    a half-open index range suitable for slicing.
    """
    zs, xs, ys = np.where(mask != outside_value)
    return [
        [int(zs.min()), int(zs.max()) + 1],
        [int(xs.min()), int(xs.max()) + 1],
        [int(ys.min()), int(ys.max()) + 1],
    ]
def crop_to_bbox(image, bbox):
    """Slice a 3D ``image`` down to the half-open ranges in ``bbox``."""
    assert len(image.shape) == 3, "only supports 3d images"
    (z0, z1), (x0, x1), (y0, y1) = bbox
    return image[z0:z1, x0:x1, y0:y1]
def crop_to_nonzero(data, seg, nonzero_label=-1):
    """Crop ``data`` and ``seg`` to the bounding box of data's non-zero
    region.

    :param data: 3D CT volume
    :param seg: 3D segmentation aligned with ``data``
    :param nonzero_label: written into the cropped seg wherever both the
        seg is 0 and the non-zero mask is 0 (true background)
    :return: (cropped data, cropped seg, bbox)
    """
    full_mask = create_nonzero_mask(data)
    bbox = get_bbox_from_mask(full_mask, 0)
    cropped_mask = crop_to_bbox(full_mask, bbox)
    cropped_data = crop_to_bbox(data, bbox)
    cropped_seg = crop_to_bbox(seg, bbox)
    # Mark voxels outside the body but inside the box so normalisation
    # can zero them later.
    cropped_seg[(cropped_seg == 0) & (cropped_mask == 0)] = nonzero_label
    return cropped_data, cropped_seg, bbox
def cropping(file_ct, file_seg):
    """Crop every ct/label pair to the CT's non-zero bounding box and save
    the results under './crop_data/' (same file names).

    NOTE(review): the CT is read as sitkUInt8 — for HU-valued CT this
    clamps intensities; confirm the inputs really are 8-bit.
    """
    for index, name in enumerate(os.listdir(file_ct)):
        ct_file = path.join(file_ct, name)
        seg_file = path.join(file_seg, name)

        ct_image = sitk.ReadImage(ct_file, sitk.sitkUInt8)
        # Remember the source geometry so it can be restored after cropping.
        ct_origin = ct_image.GetOrigin()
        ct_spacing = ct_image.GetSpacing()
        ct_direction = ct_image.GetDirection()
        ct_array = sitk.GetArrayFromImage(ct_image)

        seg_image = sitk.ReadImage(seg_file, sitk.sitkInt16)
        seg_origin = seg_image.GetOrigin()
        seg_spacing = seg_image.GetSpacing()
        seg_direction = seg_image.GetDirection()
        seg_array = sitk.GetArrayFromImage(seg_image)

        cropped_ct, cropped_seg, _ = crop_to_nonzero(ct_array, seg_array, nonzero_label=-1)

        out_ct = sitk.GetImageFromArray(cropped_ct)
        out_ct.SetDirection(ct_direction)
        out_ct.SetSpacing(ct_spacing)
        out_ct.SetOrigin(ct_origin)
        out_path = path.join('./crop_data/ct/', name)
        sitk.WriteImage(out_ct, out_path)

        out_seg = sitk.GetImageFromArray(cropped_seg)
        out_seg.SetDirection(seg_direction)
        out_seg.SetSpacing(seg_spacing)
        out_seg.SetOrigin(seg_origin)
        sitk.WriteImage(out_seg, out_path.replace('ct', 'label'))

        print('n=', index, '/199')
        print('name:', name)
if __name__ == '__main__':
    # Crop the label-fixed dataset to each CT's non-zero bounding box.
    # (Dead debugging/experiment code that used to live here was removed.)
    data = './fixed_data/ct/'
    seg = './fixed_data/label/'
    cropping(data, seg)
resampling and norm.py:Resampling将数据集的Spacing统一,Spacing的计算分为各向异性数据和非各项异性数据,该部分内容在知乎文章写的比较清晰.
import SimpleITK as sitk
import numpy as np
import matplotlib.pyplot as plt
import shutil
# from batchgenerators.utilities.file_and_folder_operations import *
from multiprocessing import Pool
from collections import OrderedDict
import os
from os import path
from skimage.transform import resize
def spacing(file):
    """Compute the dataset's target spacing for resampling, nnU-Net style.

    Loads the per-case spacings cached in 'spacing.txt' (one row per case,
    columns = (x, y, z) spacing).  If the per-axis median spacing is
    anisotropic (max axis > 3x min axis), the in-plane target is the
    median of the first column and the out-of-plane target the 10th
    percentile of the last column; otherwise the target is the per-axis
    median.

    :param file: unused — kept for interface compatibility (spacings were
        collected once and cached in 'spacing.txt')
    :return: (Sp_target, flag) — target spacing with shape (1, 3), and
        flag 1 for anisotropic datasets, 0 otherwise
    """
    Sp = np.loadtxt('spacing.txt')
    Sp_m = np.percentile(Sp, 50, axis=0, keepdims=True)
    if np.max(Sp_m) > 3 * np.min(Sp_m):
        flag = 1
        print('各向异性数据集')
        # Bug fixed: the old code took percentiles of rows Sp[1] / Sp[3]
        # (two arbitrary cases) instead of the spacing columns.
        Sp_target1 = np.percentile(Sp[:, 0], 50, keepdims=True)
        Sp_target2 = np.percentile(Sp[:, 2], 10, keepdims=True)
        Sp_target = np.array([Sp_target1, Sp_target1, Sp_target2]).transpose(1, 0)
    else:
        flag = 0
        print('非各向异性数据集')
        Sp_target = Sp_m
    print('Sp_target=', Sp_target)
    return Sp_target, flag
# 调用自定义函数
def resample_image(itk_image, out_spacing,seg,flag):
    """Resample a CT image and its segmentation to ``out_spacing``.

    Anisotropic two-stage scheme (nnU-Net style):
      1. B-spline resample the CT in-plane only (z spacing kept),
      2. nearest-neighbour resample the result along z to the target,
    then nearest-neighbour resample the segmentation directly.

    :param itk_image: SimpleITK CT image
    :param out_spacing: target spacing, shape (1, 3); squeezed below
    :param seg: SimpleITK segmentation aligned with ``itk_image``
    :param flag: anisotropy flag from spacing(); must be 1 — isotropic
        datasets are rejected by the assert below
    :return: (resampled CT, resampled segmentation)
    """
    # NOTE(review): assert used for control flow — it disappears under
    # ``python -O``; an explicit exception would be safer.
    assert flag==1, "非各项异性"
    out_spacing=out_spacing.squeeze(0)
    original_spacing = itk_image.GetSpacing()
    original_size = itk_image.GetSize()
    # Stage-1 target: new in-plane spacing, original z spacing.
    out_spacing1=np.array([out_spacing[0],out_spacing[1],original_spacing[2]])
    out_size1 = [
        int(np.round(original_size[0] * (original_spacing[0] / out_spacing1[0]))),
        int(np.round(original_size[1] * (original_spacing[1] / out_spacing1[1]))),
        int(np.round(original_size[2] * (original_spacing[2] / out_spacing1[2])))
    ]
    # The size computation above could also be written in one line:
    #out_size = [int(round(osz*ospc/nspc)) for osz,ospc,nspc in zip(original_size, original_spacing, out_spacing)]
    resample = sitk.ResampleImageFilter()
    resample.SetOutputSpacing(out_spacing1)
    resample.SetSize(out_size1)
    resample.SetOutputDirection(itk_image.GetDirection())
    resample.SetOutputOrigin(itk_image.GetOrigin())
    resample.SetTransform(sitk.Transform())
    # NOTE(review): GetPixelIDValue() returns a pixel *type* code, not a
    # background intensity — this sets an arbitrary padding value; 0 (or
    # the image minimum) is probably intended.
    resample.SetDefaultPixelValue(itk_image.GetPixelIDValue())
    resample.SetInterpolator(sitk.sitkBSplineResamplerOrder3)
    new_img=resample.Execute(itk_image)
    # Stage 2: bring the z axis to the target spacing, nearest neighbour
    # to avoid interpolation artefacts across thick slices.
    original_spacing2 = new_img.GetSpacing()
    original_size2 = new_img.GetSize()
    out_spacing2=np.array([out_spacing[0],out_spacing[1],out_spacing[2]])
    out_size2 = [
        int(np.round(original_size2[0] * (original_spacing2[0] / out_spacing2[0]))),
        int(np.round(original_size2[1] * (original_spacing2[1] / out_spacing2[1]))),
        int(np.round(original_size2[2] * (original_spacing2[2] / out_spacing2[2])))
    ]
    resample2 = sitk.ResampleImageFilter()
    resample2.SetOutputSpacing(out_spacing2)
    resample2.SetSize(out_size2)
    resample2.SetOutputDirection(itk_image.GetDirection())
    resample2.SetOutputOrigin(itk_image.GetOrigin())
    resample2.SetTransform(sitk.Transform())
    # NOTE(review): same GetPixelIDValue() concern as above.
    resample2.SetDefaultPixelValue(itk_image.GetPixelIDValue())
    resample2.SetInterpolator(sitk.sitkNearestNeighbor)
    new_ct=resample2.Execute(new_img)
    # Segmentation: single nearest-neighbour pass straight to the target
    # spacing, padding with background label 0.
    original_spacing_seg = seg.GetSpacing()
    original_size_seg = seg.GetSize()
    out_spacing_seg = out_spacing
    out_size_seg = [
        int(np.round(original_size_seg[0] * (original_spacing_seg[0] / out_spacing_seg[0]))),
        int(np.round(original_size_seg[1] * (original_spacing_seg[1] / out_spacing_seg[1]))),
        int(np.round(original_size_seg[2] * (original_spacing_seg[2] / out_spacing_seg[2])))
    ]
    resample3 = sitk.ResampleImageFilter()
    resample3.SetOutputSpacing(out_spacing_seg)
    resample3.SetSize(out_size_seg)
    resample3.SetOutputDirection(seg.GetDirection())
    resample3.SetOutputOrigin(seg.GetOrigin())
    resample3.SetTransform(sitk.Transform())
    resample3.SetDefaultPixelValue(0)
    resample3.SetInterpolator(sitk.sitkNearestNeighbor)
    new_seg=resample3.Execute(seg)
    return new_ct,new_seg
def resampling(path_c):
    """Resample every CT under ``path_c`` (and its matching label) to the
    dataset target spacing and write both under './new_data/'."""
    target_spacing, flag = spacing(path_c)
    for count, name in enumerate(os.listdir(path_c)):
        ct_url = path.join(path_c, name)
        label_url = ct_url.replace('ct', 'label')
        ct_image = sitk.ReadImage(ct_url, sitk.sitkUInt8)
        seg_image = sitk.ReadImage(label_url, sitk.sitkInt16)
        resampled_ct, resampled_seg = resample_image(ct_image, target_spacing, seg_image, flag)
        out_ct = path.join('./new_data/ct/', name)
        sitk.WriteImage(resampled_ct, out_ct)
        sitk.WriteImage(resampled_seg, out_ct.replace('ct', 'label'))
        print('n=', count, 'name=', name)
def vox(data, seg):
    """Sample foreground CT intensities from one case.

    Reads the label volume, keeps every 10th CT voxel inside the
    (label > 0) mask and returns them as a flat array.
    """
    label_image = sitk.ReadImage(seg, sitk.sitkInt16)
    foreground = sitk.GetArrayFromImage(label_image) > 0
    # print(foreground.shape)
    ct_image = sitk.ReadImage(data, sitk.sitkUInt8)
    intensities = sitk.GetArrayFromImage(ct_image)[foreground][::10].flatten()
    print(intensities.shape)
    return intensities
def vox_all(file):
    """Compute global foreground intensity statistics, then clip and
    z-score-normalise every CT, writing results to the 'new_data2' tree.

    Pass 1: pool subsampled foreground voxels from every case (via vox()),
    save them to 'voxels_all.txt' and derive the global mean/std plus the
    0.5 / 99.5 intensity percentiles.
    Pass 2: clip each CT to the percentile window, normalise with the
    global statistics, zero out regions marked -1 by cropping, and save
    with the original geometry.

    Fixes: the locals no longer shadow the builtin ``list`` or the
    function's own name ``vox_all``.
    """
    pooled = []
    n = 1
    names = os.listdir(file)
    for f in names:
        file_ct = path.join(file, f)
        file_seg = file_ct.replace('ct', 'label')
        pooled = np.append(pooled, vox(file_ct, file_seg))
        print('n=', n)
        n = n + 1
    mean = np.mean(pooled)
    std = np.std(pooled)
    np.savetxt('voxels_all.txt', pooled)
    print('mean_vox=', mean, 'std=', std)
    percentile_99_5 = np.percentile(pooled, 99.5)
    percentile_00_5 = np.percentile(pooled, 0.5)
    n = 1
    for f in names:
        file_ct = path.join(file, f)
        file_seg = file_ct.replace('ct', 'label')
        ct = sitk.ReadImage(file_ct, sitk.sitkUInt8)
        seg = sitk.ReadImage(file_seg, sitk.sitkInt16)
        ct_array = sitk.GetArrayFromImage(ct)
        seg_array = sitk.GetArrayFromImage(seg)
        # Clip outliers, then z-score with the dataset-wide statistics.
        ct_array = np.clip(ct_array, percentile_00_5, percentile_99_5)
        ct_array = (ct_array - mean) / std
        # Voxels labelled -1 by crop_to_nonzero (outside the body) -> 0.
        ct_array[seg_array < 0] = 0
        ct_new = sitk.GetImageFromArray(ct_array)
        ct_new.SetDirection(ct.GetDirection())
        ct_new.SetOrigin(ct.GetOrigin())
        ct_new.SetSpacing(ct.GetSpacing())
        sitk.WriteImage(ct_new, file_ct.replace('new_data', 'new_data2'))
        print('n=', n, 'name=', f)
        n = n + 1
def resample_seg(file):
    """Resample every segmentation under ``file`` to the dataset target
    spacing with nearest-neighbour interpolation and write each result to
    the mirrored 'new_data' tree.
    """
    Sp = np.loadtxt('spacing.txt')
    # Bug fixed: the target spacing must come from the spacing *columns*
    # (median in-plane, 10th-percentile out-of-plane), not rows Sp[1]/Sp[3].
    Sp_target1 = np.percentile(Sp[:, 0], 50, keepdims=True)
    Sp_target2 = np.percentile(Sp[:, 2], 10, keepdims=True)
    Sp_target = np.array([Sp_target1, Sp_target1, Sp_target2]).transpose(1, 0)
    print('TARGET=', Sp_target)
    n = 0
    for f in os.listdir(file):
        real_url = path.join(file, f)
        seg = sitk.ReadImage(real_url, sitk.sitkInt16)
        original_spacing_seg = seg.GetSpacing()
        original_size_seg = seg.GetSize()
        out_spacing_seg = Sp_target.squeeze(0)
        out_size_seg = [
            int(np.round(original_size_seg[0] * (original_spacing_seg[0] / out_spacing_seg[0]))),
            int(np.round(original_size_seg[1] * (original_spacing_seg[1] / out_spacing_seg[1]))),
            int(np.round(original_size_seg[2] * (original_spacing_seg[2] / out_spacing_seg[2])))
        ]
        resample3 = sitk.ResampleImageFilter()
        # Bug fixed: the output spacing was never set, so the filter fell
        # back to (1, 1, 1) mm and disagreed with the size computed above.
        resample3.SetOutputSpacing([float(s) for s in out_spacing_seg])
        resample3.SetSize(out_size_seg)
        resample3.SetOutputDirection(seg.GetDirection())
        resample3.SetOutputOrigin(seg.GetOrigin())
        resample3.SetTransform(sitk.Transform())
        # Bug fixed: GetPixelIDValue() is a pixel *type* code, not a
        # background intensity; labels must pad with background 0.
        resample3.SetDefaultPixelValue(0)
        resample3.SetInterpolator(sitk.sitkNearestNeighbor)
        new_seg = resample3.Execute(seg)
        sitk.WriteImage(new_seg, real_url.replace('data', 'new_data'))
        print('n=', n, 'name=', f)
        n = n + 1
if __name__ == '__main__':
    # Normalise the resampled CTs. Run `resampling(file)` first to
    # populate './new_data/'. (Dead experiment code removed.)
    file = './new_data/ct/'
    vox_all(file)
写的代码比较稚嫩,而且为了便于debug,一步一步是分开写的,每一步都保存了数据,通过3Dslicer可以很直观的看出来有没有问题.
之后该写Dataloader和Net了,先在Unet上试一试.之前跑通的3DUnet是整个数据集训练的,AMOS数据集处理前后的Size都是不统一的,需要写基于Patch的网络训练.
AMOS数据集这样预处理之后Size比较大,训练速度估计会慢很多,可能也是个待解决的问题.网络结构也需要创新.怎么能获得又快又好的效果呢?