nnU-Net笔记

import os

# nnU-Net (v1) workflow for Task114_heart_MNMs, written as notebook cells.
# Lines starting with `!` are IPython shell escapes, so this section only runs
# inside Jupyter/IPython, not as a plain Python script.
task_name = 'Task114_heart_MNMs'  # not referenced by the code visible here
# Root folders for raw data, preprocessed data and trained models/results.
raw_data_base_path="/home/ilkay/Documents/ruru/nnUNet/ocmr_task114/nnUNet_raw"
preprocessed_path="/home/ilkay/Documents/ruru/nnUNet/ocmr_task114/nnUNet_preprocessed"
result_folder_path="/home/ilkay/Documents/ruru/nnUNet/ocmr_task114/nnUNet_trained_models"

# Export the folders as environment variables so the nnUNet_* command-line tools
# launched below (as `!` shell commands, i.e. child processes of this kernel)
# can see them.
os.environ['nnUNet_raw_data_base'] = raw_data_base_path
os.environ['nnUNet_preprocessed'] = preprocessed_path
os.environ['RESULTS_FOLDER'] = result_folder_path
# Download the published pretrained model for Task114_heart_MNMs.
! nnUNet_download_pretrained_model Task114_heart_MNMs
# Experiment planning + preprocessing for task 114
# (output presumably goes to the nnUNet_preprocessed folder).
! nnUNet_plan_and_preprocess -t 114 
# Alternative (disabled): the same call with `-pl3d None`, which skips 3D planning.
# ! nnUNet_plan_and_preprocess -t 114 -pl3d None
# Train the 3d_fullres configuration with nnUNetTrainerV2 on task 114, fold 1 only.
! nnUNet_train 3d_fullres nnUNetTrainerV2 114 1
# Ask nnU-Net to select the best configuration among the trained 3d_fullres models.
! nnUNet_find_best_configuration -tr nnUNetTrainerV2 -m 3d_fullres -t 114
# Option notes for nnUNet_predict (English translation of the string below):
#   -i: input (the test images to run inference on);
#   -o: output (where the test-set predictions are written);
#   -t: the numeric ID of your task;
#   -m: the network configuration that was used for training;
#   -f: "4 means using the model trained with five-fold cross-validation".
# NOTE(review): in nnU-Net v1, `-f` selects which fold(s) to use, so `-f 4` uses the
# fold-4 model only, not a five-fold ensemble. Also, only fold 1 is trained in this
# cell, so the fold-4 weights presumably come from the downloaded pretrained model;
# confirm.
'''
       -i: 输入(你的待推理测试集);
       -o: 输出(测试集的推理结果);
       -t: 你的任务对应的数字ID;
       -m: 对应的训练时使用的网络架构;
       -f: 数字4代表使用五折交叉验证训练出的模型
'''
# Run inference on imagesTs and write the predicted segmentations to inferTs.
! nnUNet_predict -i /home/ilkay/Documents/ruru/nnUNet/ocmr_task114/nnUNet_raw/nnUNet_raw_data/Task114_heart_MNMs/imagesTs -o /home/ilkay/Documents/ruru/nnUNet/ocmr_task114/nnUNet_raw/nnUNet_raw_data/Task114_heart_MNMs/inferTs -t 114 -m 3d_fullres -f 4
# Reference labels and the prediction folder, reused for evaluation and visualisation.
test_labels_path = "/home/ilkay/Documents/ruru/nnUNet/ocmr_task114/nnUNet_raw/nnUNet_raw_data/Task114_heart_MNMs/labelsTs"
path_3d_fullres = "/home/ilkay/Documents/ruru/nnUNet/ocmr_task114/nnUNet_raw/nnUNet_raw_data/Task114_heart_MNMs/inferTs"
# Compare predictions against the reference labels for labels 1, 2 and 3.
# `$name` expands the Python variable inside the shell command. The metrics are
# presumably written to summary.json inside the prediction folder, which is the
# file get_results() reads.
! nnUNet_evaluate_folder -ref $test_labels_path -pred $path_3d_fullres -l 1 2 3
import json
import numpy as np
def get_results(pred_path, metric='Dice'):
    """Print and return per-label scores from an nnU-Net ``summary.json``.

    Args
        pred_path : Folder containing ``summary.json``; its
                    ``['results']['mean']`` entry maps each label key to a dict
                    of metric values.
        metric    : Metric key to read for every label (default ``'Dice'``).

    Returns
        (scores, mean): ``scores`` maps each label key to its metric value, and
        ``mean`` is the average over labels (NaN when there are no labels).
    """
    with open(os.path.join(pred_path, "summary.json"), encoding="utf-8") as f:
        summary = json.load(f)

    scores = {}
    for label, metrics in summary['results']['mean'].items():
        print('idx: ', label, f'{metric}:', metrics[metric])
        scores[label] = metrics[metric]

    # np.mean([]) would warn and return NaN; return NaN explicitly instead.
    mean = float(np.mean(list(scores.values()))) if scores else float('nan')
    print('mean:', f'{metric}:', mean)
    return scores, mean
# Report the per-label Dice scores for the prediction folder evaluated above.
get_results(pred_path=path_3d_fullres)
import cv2
from PIL import Image
import imageio
import nibabel as nib
import numpy as np


# RGB colour per class index, used when drawing segmentation overlays.
_LABEL_PALETTE = (
    (31, 0, 255),
    (0, 159, 255),
    (255, 19, 0),
    (0, 255, 178),
)


def label_color(label):
    """Return the RGB colour for class ``label`` as a new list.

    ``label`` indexes a fixed 4-entry palette; an out-of-range index raises
    IndexError.
    """
    return list(_LABEL_PALETTE[label])
    

def draw_mask(image, box, mask, label=None, color=None, binarize_threshold=0.5,
              alpha=0.5, draw_border=False):
    """ Alpha-blends a binarized mask, resized into ``box``, onto ``image`` in place.

    Args
        image              : H x W x 3 image to draw on (uint8 as used in this file); modified in place.
        box                : Sequence of at least 4 ints (x1, y1, x2, y2); the mask is resized to
                             this region of the image. Extra values are ignored.
        mask               : 2D array; resized to the box size, binarized and drawn over the image.
        label              : Optional class index. When given, the colour is ``label_color(label)``
                             and overrides ``color``.
        color              : RGB colour used when ``label`` is None; defaults to (0, 255, 0).
        binarize_threshold : Resized-mask values strictly above this are treated as foreground.
        alpha              : Weight of the mask colour in the blend (0.5 = equal mix with the image).
        draw_border        : If True, also blend a white outline (mask minus its 5x5 erosion)
                             over the result.

    Returns
        None; ``image`` is updated in place.
    """
    if label is not None:
        color = label_color(label)
    if color is None:
        color = (0, 255, 0)

    x1, y1, x2, y2 = box[:4]

    # cv2.resize does not accept int64 input, so resize in float32; dsize is (width, height).
    resized = cv2.resize(mask.astype(np.float32), (x2 - x1, y2 - y1))
    binary = (resized > binarize_threshold).astype(np.uint8)

    # Place the box-sized mask into a full-image canvas.
    full_mask = np.zeros((image.shape[0], image.shape[1]), np.uint8)
    full_mask[y1:y2, x1:x2] = binary

    # Colour the mask (cast to uint8 as before), then blend only the pixels whose
    # coloured value is non-zero in at least one channel.
    colored = (full_mask[..., None] * np.asarray(color)).astype(np.uint8)
    ys, xs = np.nonzero(colored.any(axis=2))
    image[ys, xs, :] = (1 - alpha) * image[ys, xs, :] + alpha * colored[ys, xs, :]

    if draw_border:
        # Outline = mask minus its 5x5 erosion (no uint8 underflow: erosion only removes pixels).
        border = full_mask - cv2.erode(full_mask, np.ones((5, 5), np.uint8), iterations=1)
        ys, xs = np.nonzero(border)
        image[ys, xs, :] = 0.2 * image[ys, xs, :] + 0.8 * 255

    
def normalize_minmax(img):
    """Linearly rescale ``img`` so that its minimum maps to 0 and its maximum to 1.

    Args
        img : NumPy array of any numeric dtype.

    Returns
        Floating-point array of the same shape. A constant image (max == min)
        returns all zeros instead of the NaN array a 0/0 division would give.
    """
    mi, ma = img.min(), img.max()
    if ma == mi:
        # A flat image has no contrast to stretch; avoid dividing by zero.
        return np.zeros_like(img, dtype=np.float64)
    return (img - mi) / (ma - mi)


def _overlay_labels(image, label_slice, box, labels=(1, 2, 3)):
    """Draw every label in ``labels`` present in ``label_slice`` onto ``image`` (in place) and return it."""
    for label in labels:
        class_mask = (label_slice == label).astype("int")
        if class_mask.any():  # skip classes absent from this slice
            draw_mask(image, box, class_mask, label=label)
    return image


def draw_gif(data_path, gt_path, preds_path, gif_path, labels=(1, 2, 3), frame_size=512):
    """Write a GIF (plus per-slice JPEGs) comparing reference labels with predictions.

    For every index ``i`` along the last axis of the image volume, a frame is built
    from three side-by-side panels: ground-truth overlay | raw slice | prediction
    overlay. Each frame is resized to ``(3 * frame_size, frame_size)`` pixels, saved
    as ``<gif_path without extension>/<i:04d>.jpeg`` and appended to the GIF.

    Args
        data_path  : NIfTI image volume.
        gt_path    : NIfTI reference segmentation.
        preds_path : NIfTI predicted segmentation.
        gif_path   : Output GIF path; its parent folder is created if needed.
        labels     : Label values to draw (default 1, 2, 3).
        frame_size : Output frame height in pixels (the width is three times this).

    Assumes the three volumes share the same shape, with slices along the last axis.
    TODO: confirm.
    """
    gif_dir = os.path.dirname(gif_path)
    if gif_dir:  # os.makedirs('') raises for a bare file name
        os.makedirs(gif_dir, exist_ok=True)
    # Per-slice JPEGs go to a folder named after the GIF (e.g. case.gif -> case/).
    # splitext only strips the extension; str.replace would hit '.gif' anywhere in the path.
    frames_dir = os.path.splitext(gif_path)[0]
    os.makedirs(frames_dir, exist_ok=True)

    # get_data() was removed in nibabel 5. np.asanyarray(img.dataobj) is its documented
    # replacement and does not force a float64 cast the way get_fdata() does.
    data = np.asanyarray(nib.load(data_path).dataobj)
    gt = np.asanyarray(nib.load(gt_path).dataobj)
    preds = np.asanyarray(nib.load(preds_path).dataobj)

    windowed = normalize_minmax(data) * 255
    height, width = windowed.shape[:2]
    box = [0, 0, width, height]  # the whole slice, as (x1, y1, x2, y2)

    images = []
    for i in range(windowed.shape[-1]):
        gray = Image.fromarray(windowed[..., i].astype("uint8")).convert("RGB")
        base = np.array(gray)
        gt_panel = _overlay_labels(base.copy(), gt[..., i], box, labels)
        pred_panel = _overlay_labels(base.copy(), preds[..., i], box, labels)

        stacked = np.hstack([gt_panel, base, pred_panel]).astype("uint8")
        frame = Image.fromarray(stacked).resize((frame_size * 3, frame_size))
        frame.save(os.path.join(frames_dir, str(i).zfill(4) + ".jpeg"))
        images.append(frame)

    # NOTE(review): depending on the installed imageio version/plugin, GIF ``duration``
    # may be read as seconds or milliseconds; confirm the intended frame rate.
    imageio.mimsave(gif_path, images, duration=0.3)
test_images_path = "/home/ilkay/Documents/ruru/nnUNet/ocmr_task114/nnUNet_raw/nnUNet_raw_data/Task114_heart_MNMs/imagesTs"
# Visualisations are written to <task folder>/draw_result, next to imagesTs.
draw_result_path = os.path.join(os.path.dirname(test_images_path), "draw_result")

# Input images are named <case>_0000.nii.gz (channel 0). Keying on that suffix
# avoids bogus case names from other channel files, and sorting makes the
# selection below deterministic (os.listdir order is arbitrary).
channel_suffix = "_0000.nii.gz"
instances = sorted(
    name[:-len(channel_suffix)]
    for name in os.listdir(test_images_path)
    if name.endswith(channel_suffix)
)

max_cases = 100  # visualise at most this many cases
for instance in instances[:max_cases]:
    data_path = os.path.join(test_images_path, f"{instance}{channel_suffix}")
    gt_path = os.path.join(test_labels_path, f"{instance}.nii.gz")
    preds_path = os.path.join(path_3d_fullres, f"{instance}.nii.gz")
    gif_path = os.path.join(draw_result_path, f"{instance}.gif")

    draw_gif(data_path, gt_path, preds_path, gif_path)




UNet笔记_第1张图片
UNet笔记_第2张图片

你可能感兴趣的:(MRI,pytorch,深度学习)