BraTS2021 + nnU-Net setup

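nnU-Net: https://github.com/MIC-DKFZ/nnUNet (the commands in this post are for nnU-Net v1, i.e. the nnunet package with nnUNet_plan_and_preprocess and nnUNet_train; the repository's default branch is now nnU-Net v2, which uses different commands)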

BraTS2021: BRaTS 2021 Task 1 Dataset (Kaggle)

Installing apex

git clone https://github.com/NVIDIA/apex
cd apex
pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./

Installing hiddenlayer

pip install --upgrade git+https://github.com/FabianIsensee/hiddenlayer.git@more_plotted_details#egg=hiddenlayer
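Before moving on, it can help to confirm that both packages are importable from the Python environment nnU-Net will run in. This is a minimal sketch using only the standard library; check_installed is a hypothetical helper, and it only checks that a top-level module can be found, not that apex's CUDA extensions were actually compiled.

```python
import importlib.util


def check_installed(modules):
    """Return {module_name: True/False} depending on whether the module can be found."""
    # find_spec returns None for a top-level module that is not installed
    return {name: importlib.util.find_spec(name) is not None for name in modules}


if __name__ == "__main__":
    # "apex" and "hiddenlayer" are the import names of the two packages installed above
    for name, ok in check_installed(["apex", "hiddenlayer"]).items():
        print(f"{name}: {'OK' if ok else 'missing'}")
```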

Reorganize the downloaded BraTS2021 dataset into the format nnU-Net expects.
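In nnU-Net v1's raw format, each image channel is stored as <case>_<XXXX>.nii.gz in imagesTr and each label as <case>.nii.gz in labelsTr. The renaming rule can be written as a small pure function, which is easy to test before moving thousands of files. CHANNELS and nnunet_target are hypothetical helpers; the channel order (t1, t2, flair, t1ce) follows this post's script, and what matters is that it matches the modality order given in dataset.json.

```python
# Modality suffix -> nnU-Net channel index (same order as the conversion script in this post)
CHANNELS = {"t1": 0, "t2": 1, "flair": 2, "t1ce": 3}
SUFFIX = ".nii.gz"


def nnunet_target(filename):
    """Map a BraTS2021 file name to (subfolder, new_name), or None if the file is not used."""
    if not filename.endswith(SUFFIX):
        return None
    stem = filename[: -len(SUFFIX)]               # e.g. "BraTS2021_00000_t1ce"
    case_id, _, modality = stem.rpartition("_")    # -> "BraTS2021_00000", "t1ce"
    if modality == "seg":
        return "labelsTr", f"{case_id}{SUFFIX}"
    if modality in CHANNELS:
        return "imagesTr", f"{case_id}_{CHANNELS[modality]:04d}{SUFFIX}"
    return None
```

For example, nnunet_target("BraTS2021_00000_t1ce.nii.gz") gives ("imagesTr", "BraTS2021_00000_0003.nii.gz"), the same name the str.replace calls in the script produce.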

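Besides the naming and folder-layout conventions, two more points need attention. First, labels must be numbered consecutively: the original dataset uses 0, 1, 2, 4, which must be remapped to 0, 1, 2, 3. Second, the coordinate origins of images and labels must match. The label-remapping step below writes the labels back with sitk.GetImageFromArray, which does not keep the original origin (the new image gets origin (0,0,0)), so the origin of the training images is also set to (0,0,0) to keep them aligned with the labels.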
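The label remapping can also be done with a lookup table, which additionally rejects values that should not occur. This is a minimal NumPy sketch; BRATS_TO_NNUNET and remap_labels are hypothetical names, and the mapping is the one described above. Note that it does not require every case to contain every class: some BraTS cases lack one of the tumor sub-regions, which is fine.

```python
import numpy as np

# BraTS labels: 0 background, 1 NCR, 2 ED, 4 ET -> consecutive 0, 1, 2, 3
BRATS_TO_NNUNET = {0: 0, 1: 1, 2: 2, 4: 3}


def remap_labels(seg):
    """Remap BraTS label values to consecutive integers; raise on unexpected values."""
    seg = np.asarray(seg)
    unexpected = set(np.unique(seg).tolist()) - set(BRATS_TO_NNUNET)
    if unexpected:
        raise ValueError(f"unexpected label values: {sorted(unexpected)}")
    # Lookup table indexed by the old label value
    lut = np.zeros(max(BRATS_TO_NNUNET) + 1, dtype=np.uint8)
    for old, new in BRATS_TO_NNUNET.items():
        lut[old] = new
    return lut[seg.astype(np.int64)]
```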

import os
import shutil

import SimpleITK as sitk

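path = 'C:\\Users\\Dylon\\Downloads\\archive\\BraTS2021_Training_Data'  # extracted Kaggle data
# The task folder name must match the one used later (Task501_BraTS)
train_path = 'C:\\Users\\Dylon\\Downloads\\Task501_BraTS\\imagesTr'
label_path = 'C:\\Users\\Dylon\\Downloads\\Task501_BraTS\\labelsTr'
os.makedirs(train_path, exist_ok=True)  # shutil.move does not create missing folders
os.makedirs(label_path, exist_ok=True)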
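# Rename <case>_<modality>.nii.gz to <case>_<XXXX>.nii.gz (channel order t1, t2, flair, t1ce)
# and move the segmentations to labelsTr as <case>.nii.gz
for root, dirs, files in os.walk(path):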
    for name in files:
        if name.find('t1.nii.gz') != -1:
            name2 = name.replace('t1', '0000')
            shutil.move(os.path.join(root, name), os.path.join(train_path, name2))
        elif name.find('t2.nii.gz') != -1:
            name2 = name.replace('t2', '0001')
            shutil.move(os.path.join(root, name), os.path.join(train_path, name2))
        elif name.find('flair.nii.gz') != -1:
            name2 = name.replace('flair', '0002')
            shutil.move(os.path.join(root, name), os.path.join(train_path, name2))
        elif name.find('t1ce.nii.gz') != -1:
            name2 = name.replace('t1ce', '0003')
            shutil.move(os.path.join(root, name), os.path.join(train_path, name2))
        elif name.find('seg.nii.gz') != -1:
            name2 = name.replace('_seg', '')
            shutil.move(os.path.join(root, name), os.path.join(label_path, name2))

def read_img(img_path):
    return sitk.GetArrayFromImage(sitk.ReadImage(img_path))

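# Make labels consecutive: BraTS uses 0, 1, 2, 4, so remap 4 (ET) to 3.
# GetImageFromArray creates a new image with default geometry (origin (0,0,0), spacing (1,1,1),
# identity direction), which is why the images' origin is reset to (0,0,0) in the next step.
for root, dirs, files in os.walk(label_path):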
    for name in files:
        p = os.path.join(root, name)
        img = read_img(p)
        img[img == 4] = 3
        sitk.WriteImage(sitk.GetImageFromArray(img), p)

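# Reset the origin of every image to (0,0,0) so that it matches the rewritten labels
# (BraTS volumes are resampled to 1 mm isotropic, so the spacing already matches)
for root, dirs, files in os.walk(train_path):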
    for name in files:
        p = os.path.join(root, name)
        img = sitk.ReadImage(p)
        img.SetOrigin((0.0, 0.0, 0.0))
        sitk.WriteImage(img, p)
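After the conversion, a quick check that every case has all four channels and a label saves a failed preprocessing run later. This is a sketch using only the standard library; find_incomplete_cases and IMAGE_RE are hypothetical helpers that rely only on the naming convention above, and nnUNet_plan_and_preprocess with --verify_dataset_integrity remains the authoritative check.

```python
import os
import re
from collections import defaultdict

# nnU-Net image names: <case>_<XXXX>.nii.gz, labels: <case>.nii.gz
IMAGE_RE = re.compile(r"^(?P<case>.+)_(?P<channel>\d{4})\.nii\.gz$")
SUFFIX = ".nii.gz"


def find_incomplete_cases(images_dir, labels_dir, n_channels=4):
    """Return {case_id: description} for every case that lacks a channel or a label."""
    channels = defaultdict(set)
    for name in os.listdir(images_dir):
        match = IMAGE_RE.match(name)
        if match:
            channels[match.group("case")].add(int(match.group("channel")))
    labels = {n[: -len(SUFFIX)] for n in os.listdir(labels_dir) if n.endswith(SUFFIX)}

    expected = set(range(n_channels))
    problems = {}
    for case in sorted(set(channels) | labels):
        issues = []
        missing = sorted(expected - channels.get(case, set()))
        if missing:
            issues.append(f"missing channels {missing}")
        if case not in labels:
            issues.append("missing label")
        if issues:
            problems[case] = "; ".join(issues)
    return problems
```

With the paths from the script, find_incomplete_cases(train_path, label_path) should return an empty dict.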

Generate the dataset.json file automatically

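from nnunet.dataset_conversion.utils import generate_dataset_json

# Relative paths: run this from the $nnUNet_raw_data_base folder
p1 = 'nnUNet_raw_data/Task501_BraTS/dataset.json'
p2 = 'nnUNet_raw_data/Task501_BraTS/imagesTr'
p3 = 'nnUNet_raw_data/Task501_BraTS/imagesTs'  # must exist (it may be empty); pass None if there is no test set

# The modality order must match the _0000 ... _0003 suffixes assigned above
generate_dataset_json(p1, p2, p3, ('t1', 't2', 'flair', 't1ce'),
                      labels={0: 'background', 1: 'NCR', 2: 'ED', 3: 'ET'},
                      dataset_name='BraTS21', license='hands off!')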
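The generated file can be sanity-checked before preprocessing. This sketch assumes the field names labels, modality, numTraining and training as written by nnU-Net v1's generate_dataset_json (compare with your own generated file); check_dataset_json is a hypothetical helper.

```python
import json


def check_dataset_json(path, expected_channels=4):
    """Return a list of problems in an nnU-Net v1 style dataset.json (empty list = looks fine)."""
    with open(path, encoding="utf-8") as f:
        ds = json.load(f)
    problems = []
    # JSON object keys are strings, so convert label ids back to int before comparing
    label_ids = sorted(int(k) for k in ds["labels"])
    if label_ids != list(range(len(label_ids))):
        problems.append(f"label ids not consecutive: {label_ids}")
    if len(ds["modality"]) != expected_channels:
        problems.append(f"expected {expected_channels} modalities, found {len(ds['modality'])}")
    if ds["numTraining"] != len(ds["training"]):
        problems.append("numTraining does not match the number of training entries")
    return problems
```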

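Preprocessing (--verify_dataset_integrity checks the file naming, the label values and the image/label geometry before planning)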

nnUNet_plan_and_preprocess -t 501 --verify_dataset_integrity
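nnU-Net v1 finds its folders through the environment variables nnUNet_raw_data_base, nnUNet_preprocessed and RESULTS_FOLDER, and expects the raw task under nnUNet_raw_data_base/nnUNet_raw_data/Task501_BraTS. A small pre-flight check before preprocessing could look like this; missing_nnunet_env is a hypothetical helper.

```python
import os

# Environment variables read by nnU-Net v1
REQUIRED_VARS = ("nnUNet_raw_data_base", "nnUNet_preprocessed", "RESULTS_FOLDER")


def missing_nnunet_env(env=None, task="Task501_BraTS"):
    """Return a list of problems: unset variables or a missing raw task folder."""
    env = os.environ if env is None else env
    problems = [var for var in REQUIRED_VARS if not env.get(var)]
    base = env.get("nnUNet_raw_data_base")
    if base:
        task_dir = os.path.join(base, "nnUNet_raw_data", task)
        if not os.path.isdir(task_dir):
            problems.append(f"task folder not found: {task_dir}")
    return problems


if __name__ == "__main__":
    print(missing_nnunet_env() or "environment looks OK")
```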

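Training (3D full-resolution U-Net, fold 0 of the 5-fold cross-validation; --npz also saves the softmax predictions of the validation cases, which nnUNet_find_best_configuration needs later)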

nnUNet_train 3d_fullres nnUNetTrainerV2 Task501_BraTS 0 --npz
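The command above trains only fold 0; nnU-Net v1 uses folds 0 to 4 for its 5-fold cross-validation. A sketch that builds the command for every fold, so they can be printed or launched one after another; train_commands is a hypothetical helper, and it only builds argument lists without running anything.

```python
import shlex


def train_commands(task="Task501_BraTS", folds=range(5), network="3d_fullres",
                   trainer="nnUNetTrainerV2", save_npz=True):
    """Build one nnUNet_train argument list per fold."""
    cmds = []
    for fold in folds:
        cmd = ["nnUNet_train", network, trainer, task, str(fold)]
        if save_npz:
            cmd.append("--npz")
        cmds.append(cmd)
    return cmds


if __name__ == "__main__":
    for cmd in train_commands():
        # Copy-paste the printed line, or run it with subprocess.run(cmd, check=True)
        print(shlex.join(cmd))
```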
