nnunet(十七) nnUNet_convert_decathlon_task

 移除路径末尾的斜杠

def remove_trailing_slash(filename: str):
    """Return ``filename`` with every trailing ``/`` character removed.

    A path consisting only of slashes therefore becomes the empty string.
    """
    # rstrip drops the whole run of trailing forward slashes in one call.
    return filename.rstrip('/')

 返回对应的子文件夹

def subdirs(folder, join=True, prefix=None, suffix=None, sort=True):
    """List the immediate sub-directories of ``folder``.

    :param folder: directory whose entries are inspected
    :param join: if truthy, each result is ``os.path.join(folder, name)``;
        otherwise only the bare directory name is returned
    :param prefix: if not None, keep only names starting with this prefix
    :param suffix: if not None, keep only names ending with this suffix
    :param sort: if truthy, the result list is sorted before being returned
    :return: list of matching sub-directories
    """
    matches = []
    for entry in os.listdir(folder):
        full_path = os.path.join(folder, entry)
        # Only directories qualify; plain files are skipped.
        if not os.path.isdir(full_path):
            continue
        if prefix is not None and not entry.startswith(prefix):
            continue
        if suffix is not None and not entry.endswith(suffix):
            continue
        matches.append(full_path if join else entry)
    if sort:
        matches.sort()
    return matches

 拆分4D数据,单个模态分开存储

def split_4d_nifti(filename, output_folder):
    """Split a 4D NIfTI file into one 3D file per index of the first array axis.

    A 3D input is copied unchanged to ``output_folder`` with ``_0000`` inserted
    before the extension. A 4D input is written as ``<name>_XXXX.nii.gz`` files,
    where ``XXXX`` is the zero-padded index along the first axis of the array
    returned by ``sitk.GetArrayFromImage`` (the modality, per the surrounding
    file).

    :param filename: path to the input image; the last 7 characters of its
        basename are dropped, i.e. it is assumed to end in ``.nii.gz``
    :param output_folder: directory receiving the output file(s)
    :raises RuntimeError: if the image is neither 3D nor 4D
    """
    source_image = sitk.ReadImage(filename)
    n_dims = source_image.GetDimension()
    # Basename without its 7-character extension (presumably ".nii.gz").
    case_name = filename.split("/")[-1][:-7]

    if n_dims == 3:
        # Nothing to split: copy the file under the index-0 name.
        shutil.copy(filename, join(output_folder, case_name + "_0000.nii.gz"))
        return
    if n_dims != 4:
        raise RuntimeError("Unexpected dimensionality: %d of file %s, cannot split" % (n_dims, filename))

    volumes = sitk.GetArrayFromImage(source_image)
    # Drop the fourth dimension from the geometry so it fits the 3D outputs.
    spacing_3d = tuple(source_image.GetSpacing()[:-1])
    origin_3d = tuple(source_image.GetOrigin()[:-1])
    direction_4d = np.array(source_image.GetDirection()).reshape(4, 4)
    direction_3d = tuple(direction_4d[:-1, :-1].reshape(-1))

    for index, volume in enumerate(volumes):
        volume_image = sitk.GetImageFromArray(volume)
        volume_image.SetSpacing(spacing_3d)
        volume_image.SetOrigin(origin_3d)
        volume_image.SetDirection(direction_3d)
        target = join(output_folder, case_name + "_%04.0d.nii.gz" % index)
        sitk.WriteImage(volume_image, target)

 4D数据转换

def split_4d(input_folder, num_processes=default_num_threads, overwrite_task_output_id=None):
    """Convert a Medical Segmentation Decathlon task folder for nnU-Net.

    Every ``.nii.gz`` file in ``imagesTr`` (and ``imagesTs`` if present) is
    passed to ``split_4d_nifti``. ``labelsTr`` and ``dataset.json`` are copied
    unchanged. The result is written to
    ``join(nnUNet_raw_data, "TaskXXX_<name>")`` with a three-digit task id.
    An existing output folder with that name is deleted first.

    :param input_folder: path to a ``TaskXX_NAME`` folder (two-digit id)
        containing ``imagesTr``, ``labelsTr`` and ``dataset.json``
    :param num_processes: size of the worker pool used for splitting
    :param overwrite_task_output_id: task id to use for the output folder;
        if None, the id parsed from ``input_folder`` is used
    :raises AssertionError: if the folder layout or name does not match
    """
    assert isdir(join(input_folder, "imagesTr")) and isdir(join(input_folder, "labelsTr")) and \
           isfile(join(input_folder, "dataset.json")), \
        "The input folder must be a valid Task folder from the Medical Segmentation Decathlon with at least the " \
        "imagesTr and labelsTr subfolders and the dataset.json file"

    # Strip trailing slashes so the last path component is the task name.
    while input_folder.endswith("/"):
        input_folder = input_folder[:-1]

    full_task_name = input_folder.split("/")[-1]

    assert full_task_name.startswith("Task"), "The input folder must point to a folder that starts with TaskXX_"

    first_underscore = full_task_name.find("_")
    # "TaskXX_" puts the underscore at index 6, so the id has exactly two digits.
    assert first_underscore == 6, "Input folder start with TaskXX with XX being a 2-digit id: 00, 01, 02 etc"

    input_task_id = int(full_task_name[4:6])
    if overwrite_task_output_id is None:
        overwrite_task_output_id = input_task_id

    task_name = full_task_name[7:]

    # The output folder uses a three-digit, zero-padded task id.
    output_folder = join(nnUNet_raw_data, "Task%03.0d_" % overwrite_task_output_id + task_name)

    # Start from a clean slate: any previous conversion result is discarded.
    if isdir(output_folder):
        shutil.rmtree(output_folder)

    files = []
    output_dirs = []

    maybe_mkdir_p(output_folder)
    for subdir in ["imagesTr", "imagesTs"]:
        curr_out_dir = join(output_folder, subdir)
        if not isdir(curr_out_dir):
            os.mkdir(curr_out_dir)
        curr_dir = join(input_folder, subdir)
        if not isdir(curr_dir):
            # Only imagesTr is guaranteed by the assertion above; a task
            # without imagesTs simply contributes no test images.
            continue
        nii_files = sorted(join(curr_dir, i) for i in os.listdir(curr_dir) if i.endswith(".nii.gz"))
        files.extend(nii_files)
        output_dirs.extend([curr_out_dir] * len(nii_files))

    shutil.copytree(join(input_folder, "labelsTr"), join(output_folder, "labelsTr"))

    p = Pool(num_processes)
    try:
        p.starmap(split_4d_nifti, zip(files, output_dirs))
    finally:
        # Release the worker processes even if a split fails.
        p.close()
        p.join()
    shutil.copy(join(input_folder, "dataset.json"), output_folder)

 

 

#    Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
#    Licensed under the Apache License, Version 2.0 (the "License");
#    you may not use this file except in compliance with the License.
#    You may obtain a copy of the License at
#
#        http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS,
#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#    See the License for the specific language governing permissions and
#    limitations under the License.
from batchgenerators.utilities.file_and_folder_operations import *
from nnunet.configuration import default_num_threads
from nnunet.experiment_planning.utils import split_4d
from nnunet.utilities.file_endings import remove_trailing_slash


def crawl_and_remove_hidden_from_decathlon(folder):
    """Validate a decathlon task folder and delete dot-prefixed files from it.

    The last path component must start with ``Task`` and the folder must
    contain the sub-folders ``imagesTr``, ``imagesTs`` and ``labelsTr``.
    Every file returned by ``subfiles(..., prefix=".")`` for the task folder
    and for those three sub-folders is then removed.

    :param folder: path to the task folder; trailing slashes are ignored
    :raises AssertionError: if the folder does not look like a decathlon task
    """
    not_decathlon_msg = "This does not seem to be a decathlon folder. Please give me a " \
                        "folder that starts with TaskXX and has the subfolders imagesTr, " \
                        "labelsTr and imagesTs"

    folder = remove_trailing_slash(folder)
    assert folder.split('/')[-1].startswith("Task"), not_decathlon_msg

    present = subfolders(folder, join=False)
    for required in ('imagesTr', 'imagesTs', 'labelsTr'):
        assert required in present, not_decathlon_msg

    # Remove dot-prefixed files from the task folder and its data folders.
    locations = (folder, join(folder, 'imagesTr'), join(folder, 'labelsTr'), join(folder, 'imagesTs'))
    for location in locations:
        for hidden_file in subfiles(location, prefix="."):
            os.remove(hidden_file)


def main():
    """Command-line entry point: clean an MSD task folder, then split its 4D data.

    Parses ``-i`` (input folder, required), ``-p`` (number of processes) and
    ``-output_task_id`` (optional output task id), removes hidden files from
    the input folder and finally calls ``split_4d``.
    """
    import argparse

    description = ("The MSD provides data as 4D Niftis with the modality being the first"
                   " dimension. We think this may be cumbersome for some users and "
                   "therefore expect 3D niftixs instead, with one file per modality. "
                   "This utility will convert 4D MSD data into the format nnU-Net "
                   "expects")
    input_help = ("Input folder. Must point to a TaskXX_TASKNAME folder as downloaded from the MSD "
                  "website")
    processes_help = ("Use this to specify how many processes are used to run the script. "
                      "Default is %d") % default_num_threads
    task_id_help = ("If specified, this will overwrite the task id in the output folder. If unspecified, the "
                    "task id of the input folder will be used.")

    cli = argparse.ArgumentParser(description=description)
    cli.add_argument("-i", help=input_help, required=True)
    cli.add_argument("-p", required=False, default=default_num_threads, type=int, help=processes_help)
    cli.add_argument("-output_task_id", required=False, default=None, type=int, help=task_id_help)
    options = cli.parse_args()

    # Validate the folder layout and drop hidden files before converting.
    crawl_and_remove_hidden_from_decathlon(options.i)

    split_4d(options.i, options.p, options.output_task_id)


# Run the conversion only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()

你可能感兴趣的:(segmentation,nnunet)