Training nnUNet and Running Inference on Your Own Dataset

(This assumes the required environment is already set up and that you are working on Linux with Linux commands.)

I. Training

(1) Install hiddenlayer

 pip install --upgrade git+https://github.com/nanohanno/hiddenlayer.git@bugfix/get_trace_graph#egg=hiddenlayer

(Enter it as a single line.)

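(2) Install nnUNet
1. Pick a location with plenty of disk space (preferably more than 100 GB) and run mkdir home there.
2. Enter the home folder, run mkdir nnUNetFrame, then enter the nnUNetFrame folder.
3. In a terminal inside nnUNetFrame, run the command below. (The commands in this post are nnU-Net v1 commands. Recent versions of the repository are nnU-Net v2, whose commands are named nnUNetv2_*, so make sure you are using the v1 code.)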

git clone https://github.com/MIC-DKFZ/nnUNet.git

4.

 cd nnUNet

5.

 pip install -e .

(The trailing dot must be typed too.)
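To confirm that pip install -e . worked, you can check that the nnUNet command-line tools used later in this post are on your PATH. This is a minimal sketch using Python's shutil.which; the command list is just the commands that appear in this post (nnU-Net v1 names), and check_nnunet_commands is a hypothetical helper, not part of nnUNet.

```python
import shutil

# Console commands used later in this post (nnU-Net v1 names)
NNUNET_COMMANDS = (
    "nnUNet_convert_decathlon_task",
    "nnUNet_plan_and_preprocess",
    "nnUNet_train",
    "nnUNet_predict",
)


def check_nnunet_commands(commands=NNUNET_COMMANDS):
    """Return {command: full path, or None if the command is not on PATH}."""
    return {cmd: shutil.which(cmd) for cmd in commands}


if __name__ == "__main__":
    for cmd, path in check_nnunet_commands().items():
        print(f"{cmd}: {path if path else 'NOT FOUND'}")
```

If any command prints NOT FOUND, the install step (or the active Python environment) is the first thing to check.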

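(3) Organize your dataset
1. First create the folders shown in the figure below. The Task folder's number should be two digits, such as 02 or 13 (the exact number is up to you), followed by your dataset's name, e.g. NPZ here.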
[Figure 1: the folder structure to create]

DATASET is where your data lives, and nnUNet is where the source code lives.
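2. imagesTr holds the training images, labelsTr the training labels, imagesTs the test images, and labelsTs the test labels. The image file names must follow a fixed naming convention, as shown below.
[Figure 2: image file naming]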
Each label file has the same name as its image.
[Figure 3: label file naming]
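Since every label must share its image's file name, a quick check before going further can save a failed run later. A minimal sketch, assuming the Task folder layout described above; find_unpaired and check_task_folder are hypothetical helpers, and the path in the example is a placeholder.

```python
import os


def find_unpaired(image_files, label_files):
    """Compare .nii.gz names from imagesTr and labelsTr.

    Returns (images_without_label, labels_without_image), both sorted.
    """
    images = {f for f in image_files if f.endswith(".nii.gz")}
    labels = {f for f in label_files if f.endswith(".nii.gz")}
    return sorted(images - labels), sorted(labels - images)


def check_task_folder(task_dir):
    """List imagesTr and labelsTr of a raw Task folder and report unpaired files."""
    images = os.listdir(os.path.join(task_dir, "imagesTr"))
    labels = os.listdir(os.path.join(task_dir, "labelsTr"))
    return find_unpaired(images, labels)


if __name__ == "__main__":
    # Hypothetical path: replace with your own Task folder
    task_dir = "/path/to/DATASET/nnUNet_raw/nnUNet_raw_data/Task01_NPZ"
    if os.path.isdir(task_dir):
        no_label, no_image = check_task_folder(task_dir)
        print("images without label:", no_label)
        print("labels without image:", no_image)
```

Both lists should be empty; anything else usually means a typo in a file name.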

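(4) Set the paths nnUNet reads from. Go into the home folder you created, run vim .bashrc, and add the following to .bashrc (if you add these lines to ~/.bashrc instead, they are loaded automatically in every new terminal):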

export nnUNet_raw_data_base="/*/home/nnUNetFrame/DATASET/nnUNet_raw"
export nnUNet_preprocessed="/*/home/nnUNetFrame/DATASET/nnUNet_preprocessed"
export RESULTS_FOLDER="/*/home/nnUNetFrame/DATASET/nnUNet_trained_models"

Here /* stands for the parent directory of your home folder. Press Esc, type :wq to save and exit, then run

source .bashrc

to reload the file. nnUNet now knows where your data is, because these lines set environment variables. Be careful not to mistype the folder names.
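A typo in one of these paths usually only shows up much later, so it can help to check them right away. A minimal sketch, assuming the three variable names from the .bashrc lines above; check_env is a hypothetical helper. A "does not exist" entry may simply mean you have not created that folder yet.

```python
import os

# The three variables set in .bashrc above (nnU-Net v1 names)
REQUIRED_VARS = ("nnUNet_raw_data_base", "nnUNet_preprocessed", "RESULTS_FOLDER")


def check_env(environ=os.environ):
    """Return {name: problem} for each variable that is unset or not an existing directory."""
    problems = {}
    for name in REQUIRED_VARS:
        value = environ.get(name)
        if not value:
            problems[name] = "not set"
        elif not os.path.isdir(value):
            problems[name] = f"directory does not exist: {value}"
    return problems


if __name__ == "__main__":
    print(check_env() or "all nnUNet paths are set")
```

Run it from the same terminal where you ran source .bashrc, since environment variables are per shell.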
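(5) Generate the dataset's JSON file
Run get_json.py in the Task folder. Before running it, go into the home folder and add nnUNet to PYTHONPATH, for example by entering the following in a terminal in the home folder: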

export PYTHONPATH=$PYTHONPATH:/*/home/nnUNetFrame/nnUNet

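As above, /* is the parent directory path; this step adds nnUNet to the PYTHONPATH environment variable.
get_json.py is shown below. Some values (paths, modality, and label names) need to be changed to match your own dataset.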

import os
from typing import Tuple

import numpy as np
from batchgenerators.utilities.file_and_folder_operations import save_json, subfiles


def get_identifiers_from_splitted_files(folder: str):
    # Case identifiers are the file names without the ".nii.gz" suffix (7 characters)
    uniques = np.unique([i[:-7] for i in subfiles(folder, suffix='.nii.gz', join=False)])
    return uniques


def generate_dataset_json(output_file: str, imagesTr_dir: str, imagesTs_dir: str, modalities: Tuple,
                          labels: dict, dataset_name: str, license: str = "Hebut AI", dataset_description: str = "",
                          dataset_reference="oai-zib", dataset_release='11/2021'):
    """
    :param output_file: This needs to be the full path to the dataset.json you intend to write, so
    output_file='DATASET_PATH/dataset.json' where the folder DATASET_PATH points to is the one with the
    imagesTr and labelsTr subfolders
    :param imagesTr_dir: path to the imagesTr folder of that dataset
    :param imagesTs_dir: path to the imagesTs folder of that dataset. Can be None
    :param modalities: tuple of strings with modality names. must be in the same order as the images (first entry
    corresponds to _0000.nii.gz, etc). Example: ('T1', 'T2', 'FLAIR').
    :param labels: dict with int->str (key->value) mapping the label IDs to label names. Note that 0 is always
    supposed to be background! Example: {0: 'background', 1: 'edema', 2: 'enhancing tumor'}
    :param dataset_name: The name of the dataset. Can be anything you want
    :param license:
    :param dataset_description:
    :param dataset_reference: website of the dataset, if available
    :param dataset_release:
    :return:
    """
    train_identifiers = get_identifiers_from_splitted_files(imagesTr_dir)

    if imagesTs_dir is not None:
        test_identifiers = get_identifiers_from_splitted_files(imagesTs_dir)
    else:
        test_identifiers = []

    json_dict = {}
    json_dict['name'] = dataset_name
    json_dict['description'] = dataset_description
    json_dict['tensorImageSize'] = "3D"
    json_dict['reference'] = dataset_reference
    json_dict['licence'] = license
    json_dict['release'] = dataset_release
    # e.g. {"0": "CT"} for a single CT channel; keys are strings in dataset.json
    json_dict['modality'] = {str(i): modalities[i] for i in range(len(modalities))}
    # e.g. {"0": "background", "1": "NPZ"}; label 0 must be background
    json_dict['labels'] = {str(k): v for k, v in labels.items()}

    json_dict['numTraining'] = len(train_identifiers)
    json_dict['numTest'] = len(test_identifiers)
    json_dict['training'] = [
        {'image': "./imagesTr/%s.nii.gz" % i, "label": "./labelsTr/%s.nii.gz" % i} for i in train_identifiers]
    json_dict['test'] = ["./imagesTs/%s.nii.gz" % i for i in test_identifiers]

    if not output_file.endswith("dataset.json"):
        print("WARNING: output file name is not dataset.json! This may be intentional or not. You decide. "
              "Proceeding anyways...")
    save_json(json_dict, output_file)


if __name__ == "__main__":
    # Change this to your own Task folder
    task_dir = '/root/autodl-tmp/home/nnUNetFrame/DATASET/nnUNet_raw/nnUNet_raw_data/Task01_NPZ'
    output_file = os.path.join(task_dir, 'dataset.json')
    imagesTr_dir = os.path.join(task_dir, 'imagesTr')
    imagesTs_dir = os.path.join(task_dir, 'imagesTs')

    modalities = ('CT',)  # one entry per input channel, in _0000, _0001, ... order
    labels = {
        0: "background",
        1: "NPZ"
    }

    generate_dataset_json(output_file, imagesTr_dir, imagesTs_dir, modalities, labels,
                          dataset_name="NPZ", dataset_description="NPZ")

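After running it, the output looks like this:
[Figure: output of running get_json.py]
and the dataset.json file has been created in the Task folder.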
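Before converting, it can be worth checking that the generated dataset.json is consistent with the files on disk. A minimal sketch, assuming the keys written by get_json.py above; check_dataset_dict and check_dataset_json are hypothetical helpers, and the path is a placeholder.

```python
import json
import os


def check_dataset_dict(ds, file_exists):
    """Return a list of problems in a loaded dataset.json dict (empty list = looks consistent).

    file_exists takes a relative path such as "./imagesTr/NPZ_001.nii.gz".
    """
    problems = []
    if ds["numTraining"] != len(ds["training"]):
        problems.append("numTraining does not match the training list")
    if ds["numTest"] != len(ds["test"]):
        problems.append("numTest does not match the test list")
    if ds["labels"].get("0") != "background":
        problems.append('label "0" should be "background"')
    for case in ds["training"]:
        for key in ("image", "label"):
            if not file_exists(case[key]):
                problems.append("missing file: " + case[key])
    for path in ds["test"]:
        if not file_exists(path):
            problems.append("missing file: " + path)
    return problems


def check_dataset_json(task_dir):
    """Load task_dir/dataset.json and check the listed files relative to task_dir."""
    with open(os.path.join(task_dir, "dataset.json")) as f:
        ds = json.load(f)
    return check_dataset_dict(ds, lambda rel: os.path.isfile(os.path.join(task_dir, rel)))


if __name__ == "__main__":
    # Hypothetical path: replace with your own Task folder
    task_dir = "/path/to/DATASET/nnUNet_raw/nnUNet_raw_data/Task01_NPZ"
    if os.path.isfile(os.path.join(task_dir, "dataset.json")):
        print(check_dataset_json(task_dir) or "dataset.json looks consistent")
```

Passing file_exists as a function keeps the checks testable without real image files.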
(6) Convert the dataset. nnUNet is very strict about dataset naming, so the data has to be converted to its standard format:

nnUNet_convert_decathlon_task -i /*/home/nnUNetFrame/DATASET/nnUNet_raw/nnUNet_raw_data/Task01_NPZ

After conversion a Task001_NPZ folder appears, with the contents shown below.
[Figure 4: contents of the Task001_NPZ folder]
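Two naming changes are visible in this step: the task ID becomes three digits (Task01_NPZ becomes Task001_NPZ), and, as the get_json.py docstring notes, images are expected to carry a modality suffix such as _0000. The sketch below only illustrates those two naming rules; both functions are hypothetical and do not replace nnUNet_convert_decathlon_task, which does the actual conversion.

```python
def converted_task_name(task_folder):
    """'Task01_NPZ' -> 'Task001_NPZ': the task ID is zero-padded to three digits."""
    prefix, name = task_folder.split("_", 1)
    if not prefix.startswith("Task"):
        raise ValueError("expected a folder name like Task01_NPZ, got " + task_folder)
    return "Task%03d_%s" % (int(prefix[len("Task"):]), name)


def converted_image_name(image_file, modality_index=0):
    """'NPZ_001.nii.gz' -> 'NPZ_001_0000.nii.gz': a 4-digit modality index is appended."""
    suffix = ".nii.gz"
    if not image_file.endswith(suffix):
        raise ValueError("expected a .nii.gz file, got " + image_file)
    return "%s_%04d%s" % (image_file[:-len(suffix)], modality_index, suffix)


if __name__ == "__main__":
    print(converted_task_name("Task01_NPZ"))       # Task001_NPZ
    print(converted_image_name("NPZ_001.nii.gz"))  # NPZ_001_0000.nii.gz
```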

(7) Preprocess the dataset (one of the core ideas of nnUNet)

nnUNet_plan_and_preprocess -t 01

01 is the number of your Task folder.
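The commands in this post pass the task as either 01 or 1; both refer to the converted Task001_NPZ folder. A minimal sketch of that lookup, assuming the number is zero-padded to three digits and matched against folder names; find_task_folder is a hypothetical helper, not nnUNet's own code.

```python
def find_task_folder(task_id, folder_names):
    """Pick the folder for a numeric task ID: 1, '1' and '01' all select 'Task001_...'."""
    prefix = "Task%03d_" % int(task_id)
    matches = [name for name in folder_names if name.startswith(prefix)]
    if len(matches) != 1:
        raise ValueError("expected exactly one folder starting with %s, found %s" % (prefix, matches))
    return matches[0]


if __name__ == "__main__":
    folders = ["Task001_NPZ", "Task002_Other"]  # hypothetical folder list
    print(find_task_folder("01", folders))  # Task001_NPZ
```

This is also why two folders with the same number (for example Task001_NPZ and Task001_Other) would be ambiguous.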
(8) Train

nnUNet_train 3d_fullres nnUNetTrainerV2 01 4

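01 is the task number and 4 is the fold. nnUNet uses 5-fold cross-validation with folds numbered 0 to 4, so 4 is the fifth fold.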
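To make "fold" concrete: in k-fold cross-validation every case lands in the validation set of exactly one fold, and each fold trains on the rest. The sketch below illustrates that idea with a seeded shuffle; it is not nnUNet's own splitting code (nnUNet generates and stores its own split), and five_fold_splits is a hypothetical helper.

```python
import random


def five_fold_splits(case_ids, n_folds=5, seed=0):
    """Return a list of (train, val) pairs; each case is in exactly one val set."""
    ids = sorted(case_ids)
    random.Random(seed).shuffle(ids)  # deterministic shuffle for a given seed
    splits = []
    for fold in range(n_folds):
        val = sorted(ids[fold::n_folds])  # every n_folds-th case goes to this fold's val set
        train = sorted(set(ids) - set(val))
        splits.append((train, val))
    return splits


if __name__ == "__main__":
    cases = ["NPZ_%03d" % i for i in range(1, 11)]  # hypothetical case IDs
    for fold, (train, val) in enumerate(five_fold_splits(cases)):
        print("fold", fold, "val:", val)
```

Training all five folds (0 to 4) gives five models, each validated on cases it never saw during training.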

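The trained model files appear inside nnUNet_trained_models, as shown below. After training starts, check that these files exist: if they do, the earlier steps were correct and you can safely wait for the results; if not, one of the earlier steps went wrong and you need to find out which.
[Figure 5: model files in nnUNet_trained_models]
Training is fairly slow: on my dataset one epoch takes more than 70 seconds on an RTX 3090, and training runs for 1000 epochs, i.e. more than 19 hours in total, so be patient.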

II. Inference

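(1) Create the folders for the data
[Figure: the folders used for inference]
infersTs is an empty folder where nnUNet writes its predictions; labelsTs holds the test-set label files.
(2) In a terminal, enter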

nnUNet_predict -i /*/home/nnUNetFrame/DATASET/nnUNet_raw/nnUNet_raw_data/Task001_NPZ/imagesTs/ -o /*/home/nnUNetFrame/DATASET/nnUNet_raw/nnUNet_raw_data/Task001_NPZ/infersTs -t 1 -m 3d_fullres -f 4

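Adjust the paths to your own dataset. This command runs nnUNet inference: -i is the folder with the test images, -o is the folder where the predictions are written, -t is the task number, -m is the network configuration used for inference, and -f is the cross-validation fold whose trained model is used.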
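If you launch inference from a Python script, building the argument list explicitly keeps each flag next to its meaning. This sketch only reuses the flags from the command above; predict_command is a hypothetical helper, and the path in the example is a placeholder.

```python
import os
import shlex


def predict_command(task_dir, task_id, fold, config="3d_fullres", out_name="infersTs"):
    """Build the nnUNet_predict argument list used in this post (it is not executed here)."""
    return [
        "nnUNet_predict",
        "-i", os.path.join(task_dir, "imagesTs"),  # test images
        "-o", os.path.join(task_dir, out_name),    # where predictions are written
        "-t", str(task_id),                        # task number
        "-m", config,                              # network configuration
        "-f", str(fold),                           # cross-validation fold of the model
    ]


if __name__ == "__main__":
    # Hypothetical path: replace with your own converted Task folder
    cmd = predict_command("/path/to/nnUNet_raw_data/Task001_NPZ", 1, 4)
    print(shlex.join(cmd))
    # To actually run it: subprocess.run(cmd, check=True)
```

Passing a list (rather than one string) to subprocess.run avoids shell quoting problems with paths that contain spaces.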
(3) Compute the Dice score
nnUNet's inference step does not compute Dice on the test set, so you need to compute it yourself.
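For reference, the dice function in the script computes, for each class c, a smoothed Dice score, where p_{i,c} and g_{i,c} are the one-hot prediction and label at voxel i and epsilon is the 1e-5 smoothing term used in the code:

```latex
\mathrm{Dice}_c = \frac{2\sum_i p_{i,c}\, g_{i,c} + \epsilon}{\sum_i g_{i,c} + \sum_i p_{i,c} + \epsilon},
\qquad \epsilon = 10^{-5}
```

For example, if the label has 4 foreground voxels, the prediction has 3, and all 3 overlap, the foreground Dice is (2*3 + epsilon) / (4 + 3 + epsilon), about 6/7, or 0.857. When a class is absent from both images the score is epsilon / epsilon = 1.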
Dice computation for 2D and 3D:

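import os
from glob import glob

import SimpleITK as sitk
import torch
import torch.nn.functional as F


def dice(predict, soft_y):
    """
    Get the Dice score of each class.
    predict and soft_y are one-hot (or soft) tensors of the same shape:
    [N, C, H, W] for 2D or [N, C, D, H, W] for 3D, where C is the number of classes.
    """
    tensor_dim = len(predict.size())
    num_class = list(predict.size())[1]
    if tensor_dim == 5:
        soft_y = soft_y.permute(0, 2, 3, 4, 1)
        predict = predict.permute(0, 2, 3, 4, 1)
    elif tensor_dim == 4:
        soft_y = soft_y.permute(0, 2, 3, 1)
        predict = predict.permute(0, 2, 3, 1)
    else:
        raise ValueError("{0:}D tensor not supported".format(tensor_dim))

    # Flatten to [num_voxels, C] so that each column is one class
    soft_y = torch.reshape(soft_y, (-1, num_class))
    predict = torch.reshape(predict, (-1, num_class))

    y_vol = torch.sum(soft_y, dim=0)
    p_vol = torch.sum(predict, dim=0)
    intersect = torch.sum(soft_y * predict, dim=0)
    # 1e-5 avoids division by zero when a class is absent from both images
    dice_score = (2.0 * intersect + 1e-5) / (y_vol + p_vol + 1e-5)
    return dice_score


def to_one_hot(label_map, num_class):
    """[D, H, W] (or [H, W]) label map -> [1, C, D, H, W] (or [1, C, H, W]) float tensor."""
    one_hot = F.one_hot(label_map.long(), num_class)  # class axis is appended last
    return one_hot.movedim(-1, 0).unsqueeze(0).float()


if __name__ == "__main__":
    # Only match .nii.gz so that any other files in the folders are skipped
    infer_path = "/root/autodl-tmp/home/nnUNetFrame/DATASET/nnUNet_raw/nnUNet_raw_data/Task001_NPZ/infersTs/*.nii.gz"  # predictions
    label_path = "/root/autodl-tmp/home/nnUNetFrame/DATASET/nnUNet_raw/nnUNet_raw_data/Task001_NPZ/labelsTs/*.nii.gz"  # test-set labels
    num_class = 2  # background + NPZ; must match the labels in dataset.json
    infer = sorted(glob(infer_path))
    label = sorted(glob(label_path))
    # Files are paired by sorted order, so both folders must contain the same case names
    assert len(infer) == len(label), "number of predictions and labels differ"
    score_avg = 0
    for i in range(len(label)):
        inf, lab = infer[i], label[i]
        print(i, os.path.basename(inf), os.path.basename(lab))
        inf, lab = sitk.ReadImage(inf, sitk.sitkFloat32), sitk.ReadImage(lab, sitk.sitkFloat32)
        inf, lab = sitk.GetArrayFromImage(inf), sitk.GetArrayFromImage(lab)
        inf, lab = torch.from_numpy(inf), torch.from_numpy(lab)
        inf, lab = to_one_hot(inf, num_class), to_one_hot(lab, num_class)
        score = dice(inf, lab)  # one Dice value per class; index 0 is background
        print(score)
        score_avg += score
    score_avg /= len(label)
    print("avg dice is ", score_avg)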
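If you already have the prediction and label as NumPy arrays (for example from sitk.GetArrayFromImage), the same per-class Dice, including the 1e-5 smoothing term, can be computed directly from the integer label maps without building one-hot tensors. A minimal NumPy sketch; dice_per_class is a hypothetical helper, not part of nnUNet.

```python
import numpy as np


def dice_per_class(pred, gt, num_classes, eps=1e-5):
    """Per-class Dice between two integer label maps of the same shape (2D or 3D)."""
    if pred.shape != gt.shape:
        raise ValueError(f"shape mismatch: {pred.shape} vs {gt.shape}")
    scores = []
    for c in range(num_classes):
        p = pred == c  # voxels predicted as class c
        g = gt == c    # voxels labelled as class c
        intersect = np.logical_and(p, g).sum()
        scores.append((2.0 * intersect + eps) / (p.sum() + g.sum() + eps))
    return np.array(scores)


if __name__ == "__main__":
    gt = np.array([[0, 1, 1], [0, 1, 1]])
    pred = np.array([[0, 1, 1], [0, 0, 1]])
    print(dice_per_class(pred, gt, num_classes=2))
```

Comparing class masks with == is equivalent to taking one column of the one-hot encoding, so the result matches the torch version for hard (non-soft) predictions.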

References
(See these blog posts for the other detailed steps.)

https://blog.csdn.net/weixin_42061636/article/details/107623757
https://blog.csdn.net/u014264373/article/details/117417790
https://blog.csdn.net/weixin_42061636/article/details/107719274
