nnUNet Code Study: Data Preprocessing (Part 2)

After cropping away each case's black borders and saving the results as npz and pkl files, nnUNet's next step is to analyze basic information at the dataset level. All of this is collected in **nnunet\experiment_planning\DatasetAnalyzer.py**, which this post walks through. The module:

  1. collects the size and spacing of every case in the dataset
  2. computes the intensity distribution of each modality (over foreground voxels)
  3. computes the fraction of the volume that remains after cropping

Once these statistics have been computed, the dataset information is saved with pickle (pkl).
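
For orientation, this is roughly how the class is driven from nnUNet's planning entry point; a minimal sketch, with a made-up task folder path:

from nnunet.experiment_planning.DatasetAnalyzer import DatasetAnalyzer

# folder with the cropped .npz/.pkl files plus dataset.json (hypothetical path)
cropped_out_dir = "/data/nnUNet_cropped_data/Task000_Example"
dataset_analyzer = DatasetAnalyzer(cropped_out_dir, overwrite=False)
dataset_properties = dataset_analyzer.analyze_dataset(collect_intensityproperties=True)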

The DatasetAnalyzer module


# imports required by this class (matching the original nnUNet v1 source)
from collections import OrderedDict
from multiprocessing import Pool
import pickle

import numpy as np
from batchgenerators.utilities.file_and_folder_operations import *  # join, isfile, load_json, load_pickle, save_pickle
from nnunet.configuration import default_num_threads
from nnunet.preprocessing.cropping import get_patient_identifiers_from_cropped_files
from skimage.morphology import label


class DatasetAnalyzer(object):
    def __init__(self, folder_with_cropped_data, overwrite=True, num_processes=default_num_threads):
        """
        :param folder_with_cropped_data:
        :param overwrite: If True then precomputed values will not be used and instead recomputed from the data.
        False will allow loading of precomputed values. This may be dangerous though if some of the code of this class
        was changed, therefore the default is True.
        """
        self.num_processes = num_processes
        self.overwrite = overwrite
        self.folder_with_cropped_data = folder_with_cropped_data
        self.sizes = self.spacings = None
        self.patient_identifiers = get_patient_identifiers_from_cropped_files(self.folder_with_cropped_data)  # collect the ids of all cropped cases
        assert isfile(join(self.folder_with_cropped_data, "dataset.json")), \
            "dataset.json needs to be in folder_with_cropped_data"
        self.props_per_case_file = join(self.folder_with_cropped_data, "props_per_case.pkl")
        self.intensityproperties_file = join(self.folder_with_cropped_data, "intensityproperties.pkl")

    # load the properties of a single cropped case
    def load_properties_of_cropped(self, case_identifier):
        with open(join(self.folder_with_cropped_data, "%s.pkl" % case_identifier), 'rb') as f:
            properties = pickle.load(f)
        return properties

    # for each candidate region (a set of class labels), check whether its voxels form a single connected component
    @staticmethod
    def _check_if_all_in_one_region(seg, regions):
        res = OrderedDict()
        for r in regions:
            new_seg = np.zeros(seg.shape)
            for c in r:
                new_seg[seg == c] = 1
            labelmap, numlabels = label(new_seg, return_num=True)
            if numlabels != 1:
                res[tuple(r)] = False
            else:
                res[tuple(r)] = True
        return res
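
To see what this check computes, here is a tiny standalone demo of the same skimage labeling call (a sketch, not part of nnUNet):

import numpy as np
from skimage.morphology import label

# toy 2D segmentation: class 1 forms two separate blobs, class 2 forms one
seg = np.array([[1, 1, 0, 2],
                [0, 0, 0, 2],
                [1, 0, 0, 0]])

_, numlabels = label(seg == 1, return_num=True)
print(numlabels)  # 2 -> class 1 is NOT all in one region
_, numlabels = label(seg == 2, return_num=True)
print(numlabels)  # 1 -> class 2 is all in one region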

    @staticmethod
    def _collect_class_and_region_sizes(seg, all_classes, vol_per_voxel):
        volume_per_class = OrderedDict()
        region_volume_per_class = OrderedDict()
        for c in all_classes:
            region_volume_per_class[c] = []
            volume_per_class[c] = np.sum(seg == c) * vol_per_voxel
            labelmap, numregions = label(seg == c, return_num=True)
            for l in range(1, numregions + 1):
                region_volume_per_class[c].append(np.sum(labelmap == l) * vol_per_voxel)
        return volume_per_class, region_volume_per_class

    def _get_unique_labels(self, patient_identifier):
        seg = np.load(join(self.folder_with_cropped_data, patient_identifier) + ".npz")['data'][-1]
        unique_classes = np.unique(seg)
        return unique_classes

    def _load_seg_analyze_classes(self, patient_identifier, all_classes):
        """
        1) what class is in this training case?
        2) what is the size distribution for each class?
        3) what is the region size of each class?
        4) check if all in one region
        :return:
        """
        seg = np.load(join(self.folder_with_cropped_data, patient_identifier) + ".npz")['data'][-1]
        pkl = load_pickle(join(self.folder_with_cropped_data, patient_identifier) + ".pkl")
        vol_per_voxel = np.prod(pkl['itk_spacing'])

        # ad 1)
        unique_classes = np.unique(seg)

        # ad 4) check if all in one region; this feels somewhat redundant
        regions = list()
        regions.append(list(all_classes))
        for c in all_classes:
            regions.append((c, ))

        all_in_one_region = self._check_if_all_in_one_region(seg, regions)  # for each region: one connected component, or several?

        # 2 & 3) region sizes
        volume_per_class, region_sizes = self._collect_class_and_region_sizes(seg, all_classes, vol_per_voxel)  # physical volume of each class and of its connected components

        return unique_classes, all_in_one_region, volume_per_class, region_sizes

    def get_classes(self):
        datasetjson = load_json(join(self.folder_with_cropped_data, "dataset.json"))
        return datasetjson['labels']

    # analyze the segmentations
    def analyse_segmentations(self):
        class_dct = self.get_classes()

        if self.overwrite or not isfile(self.props_per_case_file):
            p = Pool(self.num_processes)
            res = p.map(self._get_unique_labels, self.patient_identifiers)  # unique labels present in each case
            p.close()
            p.join()

            # note: results are stored in an OrderedDict keyed by case id, which is handier than a plain list
            props_per_patient = OrderedDict()
            # record which classes each segmentation actually contains
            for p, unique_classes in zip(self.patient_identifiers, res):
                props = dict()
                props['has_classes'] = unique_classes
                props_per_patient[p] = props

            save_pickle(props_per_patient, self.props_per_case_file)
        else:
            props_per_patient = load_pickle(self.props_per_case_file)
        return class_dct, props_per_patient

    # from the previously saved properties, collect each case's size and spacing after cropping
    def get_sizes_and_spacings_after_cropping(self):
        sizes = []
        spacings = []
        for c in self.patient_identifiers:
            properties = self.load_properties_of_cropped(c)
            sizes.append(properties["size_after_cropping"])
            spacings.append(properties["original_spacing"])

        return sizes, spacings

    def get_modalities(self):
        datasetjson = load_json(join(self.folder_with_cropped_data, "dataset.json"))
        modalities = datasetjson["modality"]
        modalities = {int(k): modalities[k] for k in modalities.keys()}
        return modalities
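
Both helpers above simply read fields of dataset.json; for a single-modality CT task the relevant entries look roughly like this (an illustrative example in the Medical Segmentation Decathlon format, not taken from a specific task):

# dataset.json (excerpt, illustrative):
#   "modality": {"0": "CT"},
#   "labels":   {"0": "background", "1": "liver", "2": "cancer"}
#
# get_classes() returns the "labels" dict unchanged (string keys), while
# get_modalities() casts the keys to int and returns {0: "CT"}.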

    def get_size_reduction_by_cropping(self):
        size_reduction = OrderedDict()
        for p in self.patient_identifiers:
            props = self.load_properties_of_cropped(p)
            shape_before_crop = props["original_size_of_raw_data"]
            shape_after_crop = props['size_after_cropping']
            size_red = np.prod(shape_after_crop) / np.prod(shape_before_crop)
            size_reduction[p] = size_red  # fraction of the original volume kept after cropping
        return size_reduction
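
As a quick numeric example: a raw volume of 512x512x300 voxels cropped down to 400x400x280 gives np.prod([400, 400, 280]) / np.prod([512, 512, 300]) ≈ 0.57, i.e. cropping kept about 57% of the original voxels.
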
    def _get_voxels_in_foreground(self, patient_identifier, modality_id):
        all_data = np.load(join(self.folder_with_cropped_data, patient_identifier) + ".npz")['data']
        modality = all_data[modality_id]
        mask = all_data[-1] > 0  # boolean foreground mask (seg > 0)
        # keep only the foreground voxels and take every 10th one to reduce the
        # cost of the later statistics; returns a plain Python list
        voxels = list(modality[mask][::10])  # no need to take every voxel
        return voxels

    # compute intensity statistics for a list of voxels: median, mean, standard
    # deviation, min, max and the 0.5 / 99.5 percentiles
    @staticmethod
    def _compute_stats(voxels):
        if len(voxels) == 0:
            return np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan
        median = np.median(voxels)
        mean = np.mean(voxels)
        sd = np.std(voxels)
        mn = np.min(voxels)
        mx = np.max(voxels)
        percentile_99_5 = np.percentile(voxels, 99.5)
        percentile_00_5 = np.percentile(voxels, 00.5)
        return median, mean, sd, mn, mx, percentile_99_5, percentile_00_5
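
The two helpers above can be exercised on synthetic data to see exactly what enters the statistics; a minimal sketch with made-up values:

import numpy as np

rng = np.random.default_rng(0)
modality = rng.normal(loc=100.0, scale=15.0, size=(8, 64, 64))  # fake image volume
seg = np.zeros((8, 64, 64), dtype=np.int64)
seg[2:6, 10:40, 10:40] = 1  # fake foreground region

mask = seg > 0
voxels = list(modality[mask][::10])  # the same 1-in-10 subsampling as above

print(len(voxels))  # 360 instead of 3600 foreground voxels
print(np.median(voxels), np.mean(voxels), np.std(voxels))
print(np.percentile(voxels, 99.5), np.percentile(voxels, 0.5))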

    # collect per-modality intensity statistics over the whole dataset
    def collect_intensity_properties(self, num_modalities):
        if self.overwrite or not isfile(self.intensityproperties_file):
            # note the multiprocessing usage here: intermediate per-case results are
            # returned to the parent process and post-processed there; a pattern worth reusing
            p = Pool(self.num_processes)

            results = OrderedDict()
            for mod_id in range(num_modalities):
                results[mod_id] = OrderedDict()
                # pair every case id with this modality id and gather foreground voxels
                v = p.starmap(self._get_voxels_in_foreground,
                              zip(self.patient_identifiers, [mod_id] * len(self.patient_identifiers)))

                w = []
                for iv in v:
                    w += iv
                # w pools the foreground voxels of ALL cases for this modality,
                # so this line yields the global, dataset-wide statistics
                median, mean, sd, mn, mx, percentile_99_5, percentile_00_5 = self._compute_stats(w)
                local_props = p.map(self._compute_stats, v)  # per-case statistics

                props_per_case = OrderedDict()

                # store the per-case statistics for this modality
                for i, pat in enumerate(self.patient_identifiers):
                    props_per_case[pat] = OrderedDict()
                    props_per_case[pat]['median'] = local_props[i][0]
                    props_per_case[pat]['mean'] = local_props[i][1]
                    props_per_case[pat]['sd'] = local_props[i][2]
                    props_per_case[pat]['mn'] = local_props[i][3]
                    props_per_case[pat]['mx'] = local_props[i][4]
                    props_per_case[pat]['percentile_99_5'] = local_props[i][5]
                    props_per_case[pat]['percentile_00_5'] = local_props[i][6]

                results[mod_id]['local_props'] = props_per_case
                results[mod_id]['median'] = median
                results[mod_id]['mean'] = mean
                results[mod_id]['sd'] = sd
                results[mod_id]['mn'] = mn
                results[mod_id]['mx'] = mx
                results[mod_id]['percentile_99_5'] = percentile_99_5
                results[mod_id]['percentile_00_5'] = percentile_00_5

            p.close()
            p.join()
            # save the intensity statistics to a pkl file
            save_pickle(results, self.intensityproperties_file)
        else:
            results = load_pickle(self.intensityproperties_file)
        return results
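
The Pool.starmap pattern used above, pairing every case id with one fixed modality id, is worth remembering on its own; a stripped-down, self-contained sketch:

from multiprocessing import Pool

def work(case_id, modality_id):
    # stand-in for _get_voxels_in_foreground
    return "%s_mod%d" % (case_id, modality_id)

if __name__ == '__main__':
    case_ids = ["case_00", "case_01", "case_02"]
    with Pool(2) as p:
        # starmap unpacks each (case_id, modality_id) tuple into work(...)
        out = p.starmap(work, zip(case_ids, [0] * len(case_ids)))
    print(out)  # ['case_00_mod0', 'case_01_mod0', 'case_02_mod0']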

    # top-level analysis of the whole dataset
    def analyze_dataset(self, collect_intensityproperties=True):
        # get all spacings and sizes, returned as lists over all cases
        sizes, spacings = self.get_sizes_and_spacings_after_cropping()

        # get all classes and what classes are in what patients
        # class min size
        # region size per class
        # read dataset.json and return its "labels" entry (a dict)
        classes = self.get_classes()
        all_classes = [int(i) for i in classes.keys() if int(i) > 0]

        # modalities: {modality index: modality name}, parsed from dataset.json
        modalities = self.get_modalities()

        # collect intensity information
        if collect_intensityproperties:
            intensityproperties = self.collect_intensity_properties(len(modalities))  # per-modality intensity statistics
        else:
            intensityproperties = None

        # size reduction by cropping: per-case ratio of volume after / before
        size_reductions = self.get_size_reduction_by_cropping()

        dataset_properties = dict()
        dataset_properties['all_sizes'] = sizes
        dataset_properties['all_spacings'] = spacings
        dataset_properties['all_classes'] = all_classes
        dataset_properties['modalities'] = modalities  # {idx: modality name}
        dataset_properties['intensityproperties'] = intensityproperties
        dataset_properties['size_reductions'] = size_reductions  # {patient_id: size_reduction}
        # collected here: sizes, spacings, class list, modalities, intensity statistics
        # and the crop size reductions of the whole dataset; note that this only
        # summarizes the images themselves, not statistics derived from the masks
        save_pickle(dataset_properties, join(self.folder_with_cropped_data, "dataset_properties.pkl"))
        return dataset_properties
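
Once analyze_dataset has run, the saved statistics can be inspected directly; a sketch assuming the same cropped_out_dir as in the first snippet:

from batchgenerators.utilities.file_and_folder_operations import join, load_pickle

props = load_pickle(join(cropped_out_dir, "dataset_properties.pkl"))
print(props['all_spacings'][:3])                 # spacings of the first three cases
print(props['all_classes'])                      # foreground class ids
print(props['intensityproperties'][0]['mean'])   # global foreground mean of modality 0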
