Python 常用自定义函数整理

以下函数主要用于记录和保存，方便自己查阅。
（持续更新）

1. 3D 图像处理

def numpy2sitk(arr, sitk_ori_img):
    """Convert a NumPy array into a SimpleITK image that shares a reference geometry.

    Args:
        arr: voxel array handed to ``sitk.GetImageFromArray``.
        sitk_ori_img: reference SimpleITK image whose origin, spacing and
            direction are copied onto the result.

    Returns:
        The new SimpleITK image. Only metadata is copied and no resampling is
        done, so ``arr`` presumably already lies on the reference grid
        (TODO confirm at call sites).
    """
    new_img = sitk.GetImageFromArray(arr)
    # Copy origin, spacing and direction from the reference, in that order.
    for field in ("Origin", "Spacing", "Direction"):
        getattr(new_img, "Set" + field)(getattr(sitk_ori_img, "Get" + field)())
    return new_img
import dicom2nifti
import os
import pydicom as pdic
# DICOM -> NIfTI conversion
# method 1: convert a folder of per-patient DICOM directories with dicom2nifti
def dicom2nii(dicom_path, save_root):
    """Convert every patient sub-folder of *dicom_path* into ``<patient>.nii.gz``.

    Each directory entry of ``dicom_path`` is treated as one patient's DICOM
    folder. It is converted with ``dicom2nifti.convert_directory``, and the
    produced file is stored as ``save_root/<patient>/<patient>.nii.gz``.

    Args:
        dicom_path (str): folder holding one sub-folder per patient. Paths are
            joined with ``os.path.join``, so a trailing separator is optional.
        save_root (str): output root; ``save_root/<patient>/`` is created.

    Raises:
        RuntimeError: if dicom2nifti produced no file for a patient.
    """
    import tempfile

    for patient in sorted(os.listdir(dicom_path)):
        src_dir = os.path.join(dicom_path, patient)
        if not os.path.isdir(src_dir):
            continue  # skip stray files lying next to the patient folders
        out_dir = os.path.join(save_root, patient)
        os.makedirs(out_dir, exist_ok=True)
        # Convert into a fresh scratch folder so the produced file is known
        # exactly, instead of guessing with listdir()[0]. listdir order is
        # arbitrary, and out_dir may already hold files from an earlier run.
        with tempfile.TemporaryDirectory(dir=out_dir) as tmp_dir:
            dicom2nifti.convert_directory(src_dir, tmp_dir)
            created = sorted(os.listdir(tmp_dir))
            if not created:
                raise RuntimeError(f'dicom2nifti produced no output for {src_dir}')
            print(created[0])
            # os.replace also overwrites an existing target on Windows
            # (os.rename raises there).
            os.replace(os.path.join(tmp_dir, created[0]),
                       os.path.join(out_dir, f'{patient}.nii.gz'))
            # Keep any additional series under their generated names rather
            # than losing them when the scratch folder is removed.
            for extra in created[1:]:
                os.replace(os.path.join(tmp_dir, extra), os.path.join(out_dir, extra))
# method 2: convert a single DICOM file with pydicom + SimpleITK
def dcm2nii(dcm_path, nii_path, default_spacing=(0.278875, 0.278875)):
    """Convert a single DICOM file into an image file written by ``sitk.WriteImage``.

    Args:
        dcm_path (str): path of the DICOM file (its path is printed first).
        nii_path (str): output path, e.g. ``*.nii.gz``.
        default_spacing (tuple of float): (row, column) spacing in mm, in the
            DICOM ImagerPixelSpacing order. It is used when the file has no
            ImagerPixelSpacing tag. The default is the value that used to be
            hard-coded here.

    The in-plane spacing comes from ImagerPixelSpacing; the third spacing
    component is set to 1.
    """
    print(dcm_path)
    # dcmread is the supported reader; read_file was deprecated and removed
    # in pydicom 3.
    dcm = pdic.dcmread(dcm_path)
    if 'ImagerPixelSpacing' in dcm:
        sp_2d = dcm.ImagerPixelSpacing
        row_spacing, col_spacing = float(sp_2d[0]), float(sp_2d[1])
    else:
        row_spacing, col_spacing = (float(v) for v in default_spacing)
    img = sitk.GetImageFromArray(dcm.pixel_array)
    # DICOM lists (row spacing, column spacing), i.e. (y, x). SimpleITK
    # spacing is (x, y, z), so the column spacing comes first.
    img.SetSpacing((col_spacing, row_spacing, 1.0))
    sitk.WriteImage(img, nii_path)
def resample_img(itk_image, out_spacing=(2.0, 2.0, 2.0), is_label=False,
                 default_pixel_value=0.0):
    """Resample *itk_image* onto a grid with spacing *out_spacing*.

    Origin and direction are kept. The voxel count per axis is scaled by
    ``old_spacing / new_spacing``, so the new grid covers (up to rounding)
    the same physical extent.

    Args:
        itk_image (sitk.Image): image to resample.
        out_spacing (sequence of float): target spacing, one value per image
            dimension (default: 2 mm isotropic, for 3D images).
        is_label (bool): use nearest-neighbour interpolation (keeps label
            values intact) instead of B-spline.
        default_pixel_value (float): value written to output voxels that map
            outside the input image.

    Returns:
        sitk.Image: the resampled image.

    Raises:
        ValueError: if ``out_spacing`` does not have one entry per dimension.
    """
    original_spacing = itk_image.GetSpacing()
    original_size = itk_image.GetSize()
    out_spacing = [float(s) for s in out_spacing]
    if len(out_spacing) != len(original_size):
        raise ValueError(
            f'out_spacing has {len(out_spacing)} values for a '
            f'{len(original_size)}-D image')

    # At least one voxel per axis so extreme downsampling stays valid.
    out_size = [
        max(1, int(np.round(size * (spacing / new_spacing))))
        for size, spacing, new_spacing in zip(original_size, original_spacing, out_spacing)
    ]

    resample = sitk.ResampleImageFilter()
    resample.SetOutputSpacing(out_spacing)
    resample.SetSize(out_size)
    resample.SetOutputDirection(itk_image.GetDirection())
    resample.SetOutputOrigin(itk_image.GetOrigin())
    resample.SetTransform(sitk.Transform())
    # The fill value must be a pixel value. GetPixelIDValue() is the pixel
    # *type* id, which would inject a spurious intensity/label on the border.
    resample.SetDefaultPixelValue(default_pixel_value)

    if is_label:
        resample.SetInterpolator(sitk.sitkNearestNeighbor)
    else:
        resample.SetInterpolator(sitk.sitkBSpline)

    return resample.Execute(itk_image)
def _roi_window(first, last, size, length):
    """Return a half-open ``[lo, hi)`` window of *size* voxels centred on ``[first, last]``, shifted to fit in ``[0, length)``."""
    size = min(size, length)  # an axis shorter than the patch is kept whole
    lo = first - (size - (last - first)) // 2
    lo = min(max(lo, 0), length - size)
    return lo, lo + size


def roi_extract(img_path, lab_path, patch_size):
    """Crop image and label to a window centred on the label's foreground.

    For every axis, a window of ``patch_size`` voxels is centred on the
    bounding box of the non-zero label voxels. A window that would leave the
    volume is shifted back inside it. If the foreground is larger than
    ``patch_size``, the window is a centred sub-range of it. If the axis is
    shorter than ``patch_size``, the whole axis is kept.

    The crop is done with SimpleITK slicing (as in ``slice_slide_crop``), so
    spacing and direction are kept and the origin moves to the crop corner.
    The outputs therefore stay aligned with the input in physical space.

    Args:
        img_path (str): image path.
        lab_path (str): label path (same grid as the image).
        patch_size (int): target edge length of the crop, in voxels.

    Returns:
        tuple(sitk.Image, sitk.Image): cropped image and cropped label.

    Raises:
        ValueError: if ``patch_size < 1`` or the label has no non-zero voxel.
    """
    if patch_size < 1:
        raise ValueError(f'patch_size must be >= 1, got {patch_size}')
    img = sitk.ReadImage(img_path)
    lab = sitk.ReadImage(lab_path)
    lab_arr = sitk.GetArrayFromImage(lab)
    if not lab_arr.any():
        raise ValueError(f'label {lab_path} contains no foreground voxel')

    windows = []  # one (lo, hi) pair per numpy axis
    for axis in range(lab_arr.ndim):
        other_axes = tuple(a for a in range(lab_arr.ndim) if a != axis)
        occupied = np.flatnonzero(np.any(lab_arr, axis=other_axes))
        lo, hi = _roi_window(int(occupied[0]), int(occupied[-1]),
                             int(patch_size), lab_arr.shape[axis])
        windows.append(slice(lo, hi))

    # GetArrayFromImage orders axes in reverse of SimpleITK indexing, so
    # reverse the per-axis windows before slicing the sitk images.
    crop = tuple(reversed(windows))
    return img[crop], lab[crop]
def slice_slide_crop(img_path, lab_path, patch_num, patch_size,
                     img_save_dir='0001_0001/crop_img',
                     lab_save_dir='0001_0001/crop_lab',
                     prefix='0001_0001'):
    """Cut ``patch_num + 1`` overlapping slabs of ``patch_size`` slices and save them.

    Slabs are taken along SimpleITK axis 2 (``img[:, :, a:b]``) with stride
    ``(depth - patch_size) // patch_num``. Because the stride is floored, up
    to ``patch_num - 1`` trailing slices may not be covered.

    Args:
        img_path (str): image path.
        lab_path (str): label path (same grid as the image).
        patch_num (int): number of strides; ``patch_num + 1`` slabs are saved.
        patch_size (int): slab thickness in slices.
        img_save_dir (str): output folder for image slabs (created if missing).
        lab_save_dir (str): output folder for label slabs (created if missing).
        prefix (str): slab file names are ``{prefix}_s{i}.nii.gz``.
            The defaults of the last three arguments are the previously
            hard-coded locations.

    Returns:
        bool: True when the slabs were written. False (with a printed message)
        when the slabs would not overlap or ``patch_size`` exceeds the depth.
    """
    img = sitk.ReadImage(img_path)
    lab = sitk.ReadImage(lab_path)
    depth = img.GetSize()[2]
    if patch_size > depth:
        # A negative stride would otherwise produce wrapped, bogus slabs.
        print(f'patch_size {patch_size} exceeds the {depth} available slices')
        return False
    if patch_num * patch_size <= depth:
        print('patch_size/patch_num is too small')
        return False

    stride = (depth - patch_size) // patch_num
    os.makedirs(img_save_dir, exist_ok=True)
    os.makedirs(lab_save_dir, exist_ok=True)
    for i in range(patch_num + 1):
        start = i * stride
        stop = start + patch_size
        file_name = f'{prefix}_s{i}.nii.gz'
        sitk.WriteImage(img[:, :, start:stop], os.path.join(img_save_dir, file_name))
        sitk.WriteImage(lab[:, :, start:stop], os.path.join(lab_save_dir, file_name))
    return True

2. PyTorch 相关处理

from collections import OrderedDict
def load_distributed_model(model_path, model=None, device='cuda'):
    """Load a checkpoint saved from a DataParallel/DistributedDataParallel model.

    Wrapped models save parameters as ``module.<name>``. That prefix is
    stripped (only where present) so the weights load into a plain model.

    Args:
        model_path (str): path of the saved ``state_dict``.
        model (torch.nn.Module, optional): model to load into. Defaults to a
            new ``DGCNN()``; replace with your own model class.
        device (str or torch.device): device for the model and the loaded
            tensors (default ``'cuda'``).

    Returns:
        torch.nn.Module: the model with the loaded weights.
    """
    device = torch.device(device)
    if model is None:
        model = DGCNN()  # project-specific default model
    model = model.to(device)
    # map_location lets a checkpoint saved on another GPU/CPU load on device.
    # torch.load unpickles the file: only load checkpoints you trust.
    state_dict = torch.load(model_path, map_location=device)
    prefix = 'module.'
    new_state_dict = OrderedDict(
        (k[len(prefix):] if k.startswith(prefix) else k, v)
        for k, v in state_dict.items()
    )
    model.load_state_dict(new_state_dict)
    return model

3. 常见报错处理

# Workaround for sitk.ReadImage raising an orthogonal (orthonormal)
# coordinate-system error: re-write the qform/sform information of every
# patient's NIfTI file in place.
# NOTE(review): `root` and `nib` (presumably nibabel) are defined outside this
# snippet. `root` must end with a path separator, because paths are built by
# plain string concatenation (root + patient).
for patient in os.listdir(root):
    new_path = root+patient
    img = nib.load(new_path + f'/{patient}.nii.gz')
    # Read each affine and set it back so the qform/sform header info is
    # re-written (presumably re-encoded from the affine by nibabel; verify).
    qform = img.get_qform()
    img.set_qform(qform)
    sfrom = img.get_sform()  # (sic) holds the sform affine
    img.set_sform(sfrom)
    # Saves to the same path it was loaded from, overwriting the original.
    nib.save(img, new_path + f'/{patient}.nii.gz')
# Convert an image's spacing and origin from mm to cm
def spacing_adjust(nii_data):
    """Rescale an image's spacing and origin from millimetres to centimetres.

    Both values are divided by 10 and written back as lists. The image is
    modified in place, and the same object is returned for convenience.

    Args:
        nii_data: image object exposing Get/SetSpacing and Get/SetOrigin
            (a SimpleITK image in this file).

    Returns:
        The same ``nii_data`` object.
    """
    spacing_mm, origin_mm = nii_data.GetSpacing(), nii_data.GetOrigin()
    nii_data.SetSpacing([value / 10. for value in spacing_mm])
    nii_data.SetOrigin([value / 10. for value in origin_mm])
    return nii_data

你可能感兴趣的：精简记录，python