最近用到 nnUNet 这个框架做了一个影像分割的项目。正好学习一下这个框架的源码。
我的电脑是ubuntu,已经安装好了nnUNet框架,并且按照作者提供的这个案例的步骤,对相应步骤的原码进行解读。
使用到数据也是作者在文档中提供的前列腺(prostate)数据集,下载地址:https://drive.google.com/drive/folders/1HqEgzS8BV2c7xYNrZdEAnrHk7osJJ–2
只是记录我读源码的过程,所以文档格式很乱,也可能有错误。
nnUNet 作者使用到的数据集是 NIFITI 格式(.nii.gz)。该数据集比较特征,是一个 4D 形式,具体是一个样本数据集包含了通过APC和T2两种实验技术获取的影像图像,姑且认为1个 nii.gz 文件包含了两张影像图片。
然而,nnUNet 这个框架只支持2D和3D数据格式。因此,为了使数据符合框架,故需要执行下面这个维度转换步骤:
# 用于将4D的数据转换为3D,直白的说是将包含两张3D图片的文件 拆分为 2个 3D图片文件。
nnUNet_convert_decathlon_task -i /xxx/Task04_Hippocampus
我们丛书聚集中找一个样本的文件,通过SimpleITK读取,并输出维度:
import SimpleITK as sitk
img_itk = sitk.ReadImage("./prostate_00.nii.gz")
dim = img_itk.GetDimension()
print("数据维度:", dim)
# 输出结果是 4
img_npy = sitk.GetArrayFromImage(img_itk)
print("文件的图片为:",img_npy.shape)
# 输出结果是 (2, 15, 320, 320)
另外,我们打开下载的数据文件中,打开一个叫 “dataset.json” 的文件(这个文件很重要,使用nnUNet框架的必要输入文件)。重点看下面 modality 这个信息,表示我们的数据图像包括两个来源,即T2和ADC。其他信息这里暂时不展开说。
{
"name": "PROSTATE",
"description": "Prostate transitional zone and peripheral zone segmentation",
"reference": "Radboud University, Nijmegen Medical Centre",
"licence":"CC-BY-SA 4.0",
"relase":"1.0 04/05/2018",
"tensorImageSize": "4D",
"quantitative": [0,1],
"modality": {
"0": "T2",
"1": "ADC"
},
"labels": {
"0": "background",
"1": "PZ",
"2": "TZ"
},
"numTraining": 32,
"numTest": 16,
"training":[{"image":"./imagesTr/prostate_16.nii.gz","label":"./labelsTr/prostate_16.nii.gz"},{"image":"./imagesTr/prostate_04.nii.gz","label":"./labelsTr/prostate_04.nii.gz"},
.......
}
(1) 作者是通过 nnUNet_convert_decathlon_task
这个命令进行维度转换的。我们通过 linux 命令 which
找到这个命令,然后打开对应文件,可以发现这个命令最终对应的code文件是 nnunet/experiment_planning/nnUNet_convert_decathlon_task.py
。
(2)通过上述文件,可以一步步定位到下面这个函数,也就是该功能的核心函数:
def split_4d_nifti(filename, output_folder, add_zeros=False):
img_itk = sitk.ReadImage(filename) # 读取nii.gz数据集
dim = img_itk.GetDimension()
file_base = filename.split("/")[-1]
if dim == 3:
shutil.copy(filename, join(output_folder, file_base[:-7] + "_0000.nii.gz"))
return
elif dim != 4:
raise RuntimeError("Unexpected dimensionality: %d of file %s, cannot split" % (dim, filename))
else:
img_npy = sitk.GetArrayFromImage(img_itk)
spacing = img_itk.GetSpacing() # 获取spacing
origin = img_itk.GetOrigin() # 获取origin
direction = np.array(img_itk.GetDirection()).reshape(4,4) # 获取direction
# now modify these to remove the fourth dimension
spacing = tuple(list(spacing[:-1])) # 这几行都是删除多余维度的meta信息
origin = tuple(list(origin[:-1]))
direction = tuple(direction[:-1, :-1].reshape(-1))
for i, t in enumerate(range(img_npy.shape[0])): # 将4D拆分为3D
img = img_npy[t] # 3D图片矩阵
img_itk_new = sitk.GetImageFromArray(img)
img_itk_new.SetSpacing(spacing) # 这几行是添加meta信息
img_itk_new.SetOrigin(origin)
img_itk_new.SetDirection(direction)
sitk.WriteImage(img_itk_new, join(output_folder, file_base[:-7] + "_%04.0d.nii.gz" % i)) # 写出文件