3D 图像裁剪

由于 GPU/CPU 内存不足以一次容纳整幅 3D 图像，需要先将其切分为小块（patch），再保存为 NumPy 数组文件（.npy）。

一、裁剪图像

import matplotlib.pyplot as plt
import nibabel as nib
import os
import numpy as np
from tensorflow.keras.utils import to_categorical


def crop(img, label=None, patch_size=(128, 128, 16), min_foreground=0.05):
    """Split a 3D volume (and optionally its label volume) into patches.

    The x and y axes are tiled without overlap, so ``img.shape[0]`` and
    ``img.shape[1]`` must be divisible by ``patch_size[0]`` and
    ``patch_size[1]``. The z axis need not be divisible by ``patch_size[2]``:
    it is tiled from the front, and for every (x, y) tile one extra patch is
    cut flush with the back (``[..., -d:]``) to cover the remainder, so the
    last two z patches of a tile may overlap.

    Patch order: all front-aligned patches (loop order x-tile, y-tile,
    z-tile), followed by all back-aligned tail patches (loop order x-tile,
    y-tile).

    Args:
        img: 3D array of shape (x, y, z).
        label: Optional 3D array with the same shape as ``img``. When given
            (training split), a patch is kept only if the mean of its label
            patch (the foreground fraction for a 0/1 mask) is at least
            ``min_foreground``; the same patches are taken from ``img`` and
            ``label`` so they stay paired. When None (prediction split),
            every patch is kept.
        patch_size: Patch size ``(w, h, d)``.
        min_foreground: Keep threshold used for label-based filtering.

    Returns:
        ``patches`` when ``label`` is None, otherwise
        ``(patches, label_patches)``. Each is an array of shape
        (n, w, h, min(d, z)). ``n`` may be 0, in which case an empty 4D array
        is still returned so results from several volumes can be
        concatenated along axis 0.

    Raises:
        AssertionError: if x or y is not divisible by the patch size.
        ValueError: if ``label`` is given with a shape different from ``img``.
    """
    w, h, d = patch_size

    def _windows(shape):
        # Yield the (x, y, z) slice tuple of every patch, in output order.
        x, y, z = shape
        assert (x % w == 0 and y % h == 0), "patch_size[0]和patch[1]不能被x,y整除"
        # ceil(z/d) z positions per tile; the last one is the back-aligned tail.
        n_front = int(np.ceil(z / d)) - 1
        tiles = [(slice(i * w, (i + 1) * w), slice(j * h, (j + 1) * h))
                 for i in range(x // w) for j in range(y // h)]
        for xs, ys in tiles:
            for k in range(n_front):
                yield xs, ys, slice(k * d, (k + 1) * d)
        # Remainder along z: cut from the back so the patch keeps depth d.
        for xs, ys in tiles:
            yield xs, ys, slice(-d, None)

    windows = list(_windows(img.shape))
    if label is not None:
        if label.shape != img.shape:
            raise ValueError(
                f"label.shape={label.shape} does not match img.shape={img.shape}")
        # Decide keep/drop once and apply it to both arrays so that image and
        # label patches stay aligned one-to-one.
        windows = [win for win in windows
                   if np.mean(label[win]) >= min_foreground]

    def _gather(arr):
        # Stack the selected windows of `arr` into an (n, w, h, depth) array.
        patches = [arr[win] for win in windows]
        if not patches:
            # Keep a 4D shape so np.concatenate over volumes still works.
            return np.empty((0, w, h, min(d, arr.shape[2])), dtype=arr.dtype)
        return np.array(patches)

    if label is None:
        ans = _gather(img)
        print(f'crop_img.shape={ans.shape}')
        return ans
    ans1, ans2 = _gather(img), _gather(label)
    print(f'crop_img.shape={ans1.shape}, _crop(label).shape={ans2.shape}')
    return ans1, ans2



def standardize(img):
    """Z-score standardise every 2D (x, y) slice of a volume independently.

    Accepts img.shape = [x, y, z] (single channel) or [x, y, z, c]. For each
    z index (and each channel), the slice ``img[:, :, z]`` /
    ``img[:, :, z, c]`` is shifted to zero mean and scaled to unit standard
    deviation. A constant slice (std == 0) is mapped to all zeros instead of
    NaN.

    Args:
        img: Array with at least 3 dimensions; axis 2 is z and axis 3, when
            present, is the channel.

    Returns:
        float64 array with the same shape as ``img``: the standardised image.

    Raises:
        ValueError: if ``img`` has fewer than 3 dimensions.
    """
    img = np.asarray(img, dtype=np.float64)
    if img.ndim < 3:
        raise ValueError(
            f"expected an array with at least 3 dims, got shape {img.shape}")
    # Statistics are per (z, c) slice: reduce over every axis except 2 and 3.
    axes = tuple(a for a in range(img.ndim) if a not in (2, 3))
    centered = img - img.mean(axis=axes, keepdims=True)
    std = centered.std(axis=axes, keepdims=True)
    # Avoid 0/0 -> NaN on constant slices (e.g. uniform background layers).
    return np.divide(centered, std, out=np.zeros_like(centered), where=std != 0)


if __name__ == '__main__':
    DataSetDir = "D:/files/datasets/liver/"
    # Image and mask files share the same filename, so each mask name also
    # locates its image. Sort so the patch order in the saved arrays is
    # reproducible (os.listdir returns entries in arbitrary order).
    mask_dir = os.path.join(DataSetDir, "masks")
    filenames = sorted(os.listdir(mask_dir))
    train_x = []
    train_y = []
    for filename in filenames:
        image_path = os.path.join(DataSetDir, "origin", filename)
        label_path = os.path.join(DataSetDir, "masks", filename)
        print(f'img:{image_path}', f'label: {label_path}')

        image = nib.load(image_path)
        img_arry = image.get_fdata()  # e.g. [512,512,129]
        img_arry = standardize(img_arry)

        label = nib.load(label_path)
        label_arry = label.get_fdata()  # e.g. [512,512,129]

        imgs, labels = crop(img_arry, label_arry)
        train_x.append(imgs)
        train_y.append(labels)

    if not train_x:
        # np.concatenate([]) would otherwise fail with an opaque message.
        raise SystemExit(f"no mask files found in {mask_dir}")

    x_train = np.concatenate(train_x, axis=0)
    y_train = np.concatenate(train_y, axis=0)
    print(f'x_train.shape={x_train.shape}, y_train.shape={y_train.shape}')
    # Add a trailing channel axis.
    x_train = np.expand_dims(x_train, axis=-1)
    # One-hot encode into 2 classes.
    # NOTE(review): assumes the masks contain only the values 0 and 1 — confirm
    # for this dataset (any larger class id would not fit num_classes=2).
    y_train = to_categorical(y_train, num_classes=2)
    print(np.unique(y_train))

    np.save('x_train.npy', x_train)
    np.save('y_train.npy', y_train)

    

二、载入图像

import numpy as np
import matplotlib.pyplot as plt


if __name__ == '__main__':
    train_x = np.load('x_train.npy')
    train_y = np.load('y_train.npy')
    print(train_x.shape, train_y.shape)

    # Which patch, and which z-slice inside it, to display.
    sample, z_slice = 54, 2

    plt.figure(figsize=(3, 4))
    plt.subplot(1, 2, 1)
    plt.axis('off')
    # Select channel 0 explicitly so imshow receives a 2D array instead of a
    # (w, h, 1) one.
    plt.imshow(train_x[sample, :, :, z_slice, 0], cmap='gray')
    plt.subplot(1, 2, 2)
    plt.axis('off')
    # Channel 1 of the one-hot label, i.e. the mask for class 1.
    plt.imshow(train_y[sample, :, :, z_slice, 1], cmap='gray')
    plt.show()

（示意图：训练集中第 54 个 patch 的第 3 层 z 切片（左）及其对应的类别 1 标签（右）。）

你可能感兴趣的:(医学图像处理,计算机视觉,python,opencv)