nnUNet基础分割网络框架

在分析完 nnUNet 的基础训练框架之后,接着分析 nnUNet 的基础分割网络,2D、3D 的分割网络都是在此基础上构建的。此部分代码位于 `nnunet/network_architecture/neural_network.py` 中。

# Base neural-network scaffold shared by the segmentation networks.
class NeuralNetwork(nn.Module):
    """Minimal ``nn.Module`` base that adds device query/placement helpers."""

    def __init__(self):
        super().__init__()

    def get_device(self):
        """Return ``"cpu"`` or the device index of the first parameter.

        The location is read from the first parameter yielded by
        ``self.parameters()``.
        """
        first_param_device = next(self.parameters()).device
        return "cpu" if first_param_device.type == "cpu" else first_param_device.index

    def set_device(self, device):
        """Move the module to the CPU if *device* is ``"cpu"``, else call ``self.cuda(device)``."""
        if device != "cpu":
            self.cuda(device)
        else:
            self.cpu()

    def forward(self, x):
        """Not implemented in the base class; subclasses provide the forward pass."""
        raise NotImplementedError

分割网络框架(以下各函数均带有 self 参数,是分割网络类的成员方法,这里省略了类的定义):

def predict_3D(self, x: np.ndarray, do_mirroring: bool, mirror_axes: Tuple[int, ...] = (0, 1, 2),
               use_sliding_window: bool = False,
               step_size: float = 0.5, patch_size: Tuple[int, ...] = None, regions_class_order: Tuple[int, ...] = None,
               use_gaussian: bool = False, pad_border_mode: str = "constant",
               pad_kwargs: dict = None, all_in_gpu: bool = False,
               verbose: bool = True, mixed_precision: bool = True) -> Tuple[np.ndarray, np.ndarray]:
    """
    Predict a single 3D image, dispatching on ``self.conv_op``.

    Networks built from ``nn.Conv3d`` and from ``nn.Conv2d`` are both handled; the matching internal routine
    is selected automatically. The caller is responsible for putting the network into eval mode - if it is
    still in training mode only a warning is printed.

    :param x: input array with four axes, (c, x, y, z)
    :param do_mirroring: if True, use mirroring as test time augmentation
    :param mirror_axes: axes used for mirroring. May not contain values > 1 for a Conv2d network or > 2 for a
    Conv3d network
    :param use_sliding_window: if True run the tiled (sliding window) prediction, otherwise predict the image
    in one go. Defaults to False in this signature
    :param step_size: distance between neighbouring tiles as a fraction of patch_size, must be <= 1
    :param patch_size: patch size handed to the internal prediction routine
    :param regions_class_order: forwarded unchanged to the internal prediction routine
    :param use_gaussian: forwarded to the tiled routines, which use it to weight the centre of a patch higher
    than its border. Defaults to False in this signature
    :param pad_border_mode: padding mode forwarded to the internal prediction routine
    :param pad_kwargs: padding kwargs; ``{'constant_values': 0}`` is used when None
    :param all_in_gpu: forwarded to the internal prediction routine
    :param verbose: if True print debugging information
    :param mixed_precision: if True run inside ``autocast()``, otherwise inside ``no_op()``
    :return: whatever the selected internal prediction routine returns
    """
    # release cached GPU memory before every prediction
    torch.cuda.empty_cache()

    assert step_size <= 1, ('step_size must be smaller than 1. Otherwise there will be a gap between consecutive '
                            'predictions')

    if verbose:
        print("debug: mirroring", do_mirroring, "mirror_axes", mirror_axes)

    assert self.get_device() != "cpu", "CPU not implemented"

    if pad_kwargs is None:
        pad_kwargs = {'constant_values': 0}

    # Mirror axes used to be given as (2, 3, 4) for a 3d network a long time ago. Reject anything that still
    # follows that old convention
    if len(mirror_axes):
        if self.conv_op == nn.Conv2d and max(mirror_axes) > 1:
            raise ValueError("mirror axes. duh")
        if self.conv_op == nn.Conv3d and max(mirror_axes) > 2:
            raise ValueError("mirror axes. duh")

    if self.training:
        print('WARNING! Network is in train mode during inference. This may be intended, or not...')

    assert len(x.shape) == 4, "data must have shape (c,x,y,z)"

    # optional mixed precision
    context = autocast if mixed_precision else no_op

    with context(), torch.no_grad():
        if self.conv_op == nn.Conv3d:
            if use_sliding_window:
                # tiled (sliding window) prediction
                return self._internal_predict_3D_3Dconv_tiled(x, step_size, do_mirroring, mirror_axes,
                                                              patch_size, regions_class_order, use_gaussian,
                                                              pad_border_mode, pad_kwargs=pad_kwargs,
                                                              all_in_gpu=all_in_gpu, verbose=verbose)
            # predict the whole image at once
            return self._internal_predict_3D_3Dconv(x, patch_size, do_mirroring, mirror_axes,
                                                    regions_class_order, pad_border_mode,
                                                    pad_kwargs=pad_kwargs, verbose=verbose)

        if self.conv_op == nn.Conv2d:
            if use_sliding_window:
                # a 2D network applied to 3D data (tiled)
                return self._internal_predict_3D_2Dconv_tiled(x, patch_size, do_mirroring, mirror_axes,
                                                              step_size, regions_class_order, use_gaussian,
                                                              pad_border_mode, pad_kwargs, all_in_gpu, False)
            return self._internal_predict_3D_2Dconv(x, patch_size, do_mirroring, mirror_axes,
                                                    regions_class_order, pad_border_mode, pad_kwargs,
                                                    all_in_gpu, False)

        raise RuntimeError("Invalid conv op, cannot determine what dimensionality (2d/3d) the network is")
# Build a Gaussian map with the same shape as the patch: the patch centre gets the largest weight,
# the border regions get lower weights.
@staticmethod
def _get_gaussian(patch_size, sigma_scale=1. / 8) -> np.ndarray:
    """Return a float32 Gaussian importance map of shape *patch_size*.

    A unit impulse placed at index ``size // 2`` on every axis is smoothed with
    ``gaussian_filter`` using ``sigma = size * sigma_scale`` per axis, then
    rescaled so its maximum is 1. Entries that end up exactly 0 are replaced by
    the smallest non-zero entry.
    """
    impulse = np.zeros(patch_size)
    impulse[tuple(size // 2 for size in patch_size)] = 1
    per_axis_sigma = [size * sigma_scale for size in patch_size]

    weights = gaussian_filter(impulse, per_axis_sigma, 0, mode='constant', cval=0)
    weights = (weights / np.max(weights)).astype(np.float32)

    # the map must not contain zeros, otherwise dividing by the accumulated weights may produce nans
    smallest_nonzero = np.min(weights[weights != 0])
    weights[weights == 0] = smallest_nonzero

    return weights
# Compute, for every axis, where the sliding-window patches start for the given patch_size.
@staticmethod
def _compute_steps_for_sliding_window(patch_size: Tuple[int, ...], image_size: Tuple[int, ...], step_size: float) -> List[List[int]]:
    """Return the patch start coordinates along every axis.

    :param patch_size: patch extent per axis
    :param image_size: image extent per axis, must be >= patch_size on every axis
    :param step_size: target distance between neighbouring patches as a fraction of patch_size, 0 < step_size <= 1
    :return: one list of integer start coordinates per axis. The first entry is always 0 and, when an axis
    needs more than one step, the last entry is image_size - patch_size
    :raises AssertionError: if the image is smaller than the patch on any axis or step_size is out of range
    """
    # all() is required here: asserting the bare list would always pass because a non-empty list is truthy
    assert all(i >= j for i, j in zip(image_size, patch_size)), \
        "image size must be as large or larger than patch_size"
    # step_size is a fraction of the patch size
    assert 0 < step_size <= 1, 'step_size must be larger than 0 and smaller or equal to 1'

    # The step width is patch_size * step_size at most but can be narrower. Example: image size 110, patch size
    # 32 and step_size 0.5 give a target of 16 voxels -> 6 steps at 0, 16, 31, 47, 62, 78
    target_step_sizes_in_voxels = [i * step_size for i in patch_size]

    # number of steps needed along every axis
    num_steps = [int(np.ceil((i - k) / j)) + 1 for i, j, k in zip(image_size, target_step_sizes_in_voxels, patch_size)]

    steps = []
    for dim, patch_extent in enumerate(patch_size):
        # the largest start coordinate that still keeps the patch inside the image
        max_step_value = image_size[dim] - patch_extent
        if num_steps[dim] > 1:
            # spread the steps evenly between 0 and max_step_value
            actual_step_size = max_step_value / (num_steps[dim] - 1)
        else:
            # irrelevant: the single step sits at 0
            actual_step_size = 0

        # start coordinate of every patch along this axis
        steps.append([int(np.round(actual_step_size * i)) for i in range(num_steps[dim])])

    return steps
# When mirroring is enabled for inference, the input is flipped along the requested axes, predicted, and the
# prediction is flipped back to the original orientation. A 3D input is predicted up to 8 times and the
# results are averaged.
def _internal_maybe_mirror_and_pred_3D(self, x: Union[np.ndarray, torch.tensor], mirror_axes: tuple, do_mirroring: bool = True, mult: np.ndarray or torch.tensor = None) -> torch.tensor:
    """Predict one batch, optionally averaging over mirrored copies.

    :param x: input with five axes, (b, c, x, y, z); numpy array or torch tensor
    :param mirror_axes: spatial axes (0, 1, 2) that may be mirrored; they map to tensor dims (2, 3, 4)
    :param do_mirroring: if False only the un-mirrored input is predicted
    :param mult: optional weights the accumulated result is multiplied with
    :return: cuda tensor of shape [1, self.num_classes] + spatial shape of x (not a numpy array)
    """
    assert len(x.shape) == 5, 'x must be (b, c, x, y, z)'

    # everything below runs on the GPU; x and mult are moved there if necessary
    device = self.get_device()
    x = to_cuda(maybe_to_torch(x), gpu_id=device)
    result_torch = torch.zeros([1, self.num_classes] + list(x.shape[2:]),
                               dtype=torch.float).cuda(device, non_blocking=True)

    if mult is not None:
        mult = to_cuda(maybe_to_torch(mult), gpu_id=device)

    # (mirror axes that must all be enabled, tensor dims to flip), listed in evaluation order
    flip_variants = (
        ((), ()),
        ((2,), (4,)),
        ((1,), (3,)),
        ((2, 1), (4, 3)),
        ((0,), (2,)),
        ((0, 2), (4, 2)),
        ((0, 1), (3, 2)),
        ((0, 1, 2), (4, 3, 2)),
    )

    if do_mirroring:
        num_results = 2 ** len(mirror_axes)
    else:
        # only the plain forward pass
        flip_variants = flip_variants[:1]
        num_results = 1

    for required_axes, flip_dims in flip_variants:
        if not all(axis in mirror_axes for axis in required_axes):
            continue

        if flip_dims:
            # flip the input, predict, flip the prediction back
            pred = self.inference_apply_nonlin(self(torch.flip(x, flip_dims)))
            result_torch += 1 / num_results * torch.flip(pred, flip_dims)
        else:
            pred = self.inference_apply_nonlin(self(x))
            result_torch += 1 / num_results * pred

    if mult is not None:
        result_torch[:, :] *= mult

    return result_torch
# Sliding-window inference on a 3D volume; driven by patch_size and step_size.
def _internal_predict_3D_3Dconv_tiled(self, x: np.ndarray, step_size: float, do_mirroring: bool, mirror_axes: tuple,
                                      patch_size: tuple, regions_class_order: tuple, use_gaussian: bool,
                                      pad_border_mode: str, pad_kwargs: dict, all_in_gpu: bool,
                                      verbose: bool) -> Tuple[np.ndarray, np.ndarray]:
    """Predict a (c, x, y, z) volume patch by patch and merge the patch predictions.

    The volume is passed through ``pad_nd_image`` with ``patch_size`` as the target shape, tiled with the
    start coordinates from ``_compute_steps_for_sliding_window``, and every tile is predicted with
    ``_internal_maybe_mirror_and_pred_3D``. Tile outputs are summed together with their weights, the padding
    is cropped again and the sum is divided by the accumulated weights.

    :param x: input array with four axes, (c, x, y, z)
    :param step_size: distance between neighbouring tiles as a fraction of patch_size
    :param do_mirroring: forwarded to ``_internal_maybe_mirror_and_pred_3D``
    :param mirror_axes: forwarded to ``_internal_maybe_mirror_and_pred_3D``
    :param patch_size: tile size, must not be None
    :param regions_class_order: if None the segmentation is the argmax over channel 0; otherwise channel i is
    thresholded at 0.5 and the matching voxels are set to regions_class_order[i], later entries overwriting
    earlier ones
    :param use_gaussian: weight every tile with a Gaussian importance map (only applied if there is more than
    one tile)
    :param pad_border_mode: forwarded to ``pad_nd_image``
    :param pad_kwargs: forwarded to ``pad_nd_image``
    :param all_in_gpu: keep the data and the accumulation buffers on the GPU (as half precision tensors)
    instead of in numpy arrays
    :param verbose: if True print progress information
    :return: (predicted_segmentation, class_probabilities), both numpy arrays
    """
    assert len(x.shape) == 4, "x must be (c, x, y, z)"
    assert self.get_device() != "cpu"
    if verbose:
        print("step_size:", step_size)
        print("do mirror:", do_mirroring)

    assert patch_size is not None, "patch_size cannot be None for tiled prediction"

    # For sliding window inference the image must at least be as large as the patch size, so it is padded
    # if it is smaller. It does not matter whether the shape is divisible by 2**num_pool as long as the
    # patch size is
    data, slicer = pad_nd_image(x, patch_size, pad_border_mode, pad_kwargs, True, None)
    data_shape = data.shape  # still c, x, y, z

    # compute the steps for sliding window
    steps = self._compute_steps_for_sliding_window(patch_size, data_shape[1:], step_size)
    num_tiles = len(steps[0]) * len(steps[1]) * len(steps[2])

    if verbose:
        print("data shape:", data_shape)
        print("patch size:", patch_size)
        print("steps (x, y, and z):", steps)
        print("number of tiles:", num_tiles)

    weight_with_gaussian = use_gaussian and num_tiles > 1

    # The Gaussian only needs to be computed once per patch size. Computing it can take a while due to the
    # large sigma in gaussian_filter, so it is cached on the instance
    if weight_with_gaussian:
        if self._gaussian_3d is None or not all(
                i == j for i, j in zip(patch_size, self._patch_size_for_gaussian_3d)):
            if verbose:
                print('computing Gaussian')
            gaussian_importance_map = self._get_gaussian(patch_size, sigma_scale=1. / 8)

            self._gaussian_3d = gaussian_importance_map
            self._patch_size_for_gaussian_3d = patch_size
        else:
            if verbose:
                print("using precomputed Gaussian")
            gaussian_importance_map = self._gaussian_3d

        gaussian_importance_map = torch.from_numpy(gaussian_importance_map).cuda(self.get_device(),
                                                                                 non_blocking=True)
    else:
        gaussian_importance_map = None

    # The uniform weight is added to a patch-sized slice, so it must have the patch shape (identical to the
    # padded image shape when there is only a single tile)
    uniform_weight_shape = tuple(int(i) for i in patch_size)

    if all_in_gpu:
        # If we run the inference in GPU only (meaning all tensors are allocated on the GPU, this reduces
        # CPU-GPU communication but requires more GPU memory) we need to preallocate a few things on GPU
        if weight_with_gaussian:
            # half precision for the outputs should be good enough. If the outputs here are half, the
            # gaussian_importance_map should be as well
            gaussian_importance_map = gaussian_importance_map.half()

            # make sure we did not round anything to 0
            gaussian_importance_map[gaussian_importance_map == 0] = gaussian_importance_map[
                gaussian_importance_map != 0].min()

            add_for_nb_of_preds = gaussian_importance_map
        else:
            add_for_nb_of_preds = torch.ones(uniform_weight_shape, device=self.get_device())

        if verbose:
            print("initializing result array (on GPU)")
        aggregated_results = torch.zeros([self.num_classes] + list(data.shape[1:]), dtype=torch.half,
                                         device=self.get_device())

        if verbose:
            print("moving data to GPU")
        data = torch.from_numpy(data).cuda(self.get_device(), non_blocking=True)

        if verbose:
            print("initializing result_numsamples (on GPU)")
        aggregated_nb_of_predictions = torch.zeros([self.num_classes] + list(data.shape[1:]), dtype=torch.half,
                                                   device=self.get_device())
    else:
        if weight_with_gaussian:
            add_for_nb_of_preds = self._gaussian_3d
        else:
            add_for_nb_of_preds = np.ones(uniform_weight_shape, dtype=np.float32)
        aggregated_results = np.zeros([self.num_classes] + list(data.shape[1:]), dtype=np.float32)
        aggregated_nb_of_predictions = np.zeros([self.num_classes] + list(data.shape[1:]), dtype=np.float32)

    for lb_x in steps[0]:
        ub_x = lb_x + patch_size[0]
        for lb_y in steps[1]:
            ub_y = lb_y + patch_size[1]
            for lb_z in steps[2]:
                ub_z = lb_z + patch_size[2]

                predicted_patch = self._internal_maybe_mirror_and_pred_3D(
                    data[None, :, lb_x:ub_x, lb_y:ub_y, lb_z:ub_z], mirror_axes, do_mirroring,
                    gaussian_importance_map)[0]

                if all_in_gpu:
                    predicted_patch = predicted_patch.half()
                else:
                    predicted_patch = predicted_patch.cpu().numpy()

                aggregated_results[:, lb_x:ub_x, lb_y:ub_y, lb_z:ub_z] += predicted_patch
                aggregated_nb_of_predictions[:, lb_x:ub_x, lb_y:ub_y, lb_z:ub_z] += add_for_nb_of_preds

    # we reverse the padding here (remember that we padded the input to be at least as large as the patch size)
    slicer = tuple(
        [slice(0, aggregated_results.shape[i]) for i in
         range(len(aggregated_results.shape) - (len(slicer) - 1))] + slicer[1:])
    aggregated_results = aggregated_results[slicer]
    aggregated_nb_of_predictions = aggregated_nb_of_predictions[slicer]

    # computing the class_probabilities by dividing the aggregated result with result_numsamples
    class_probabilities = aggregated_results / aggregated_nb_of_predictions

    if regions_class_order is None:
        predicted_segmentation = class_probabilities.argmax(0)
    else:
        if all_in_gpu:
            class_probabilities_here = class_probabilities.detach().cpu().numpy()
        else:
            class_probabilities_here = class_probabilities
        predicted_segmentation = np.zeros(class_probabilities_here.shape[1:], dtype=np.float32)
        for i, c in enumerate(regions_class_order):
            predicted_segmentation[class_probabilities_here[i] > 0.5] = c

    if all_in_gpu:
        if verbose:
            print("copying results to CPU")

        if regions_class_order is None:
            predicted_segmentation = predicted_segmentation.detach().cpu().numpy()

        class_probabilities = class_probabilities.detach().cpu().numpy()

    if verbose:
        print("prediction done")
    return predicted_segmentation, class_probabilities

以patch形式进行推理是最重要也是最难的,其他的很简单这里就不过多记录了

你可能感兴趣的:(nnUNet,深度学习,神经网络,自然语言处理)