nnUnetV2:使用自定义网络

前言

2023年3月17日,nnUnet迎来重大更新。紧接着不久,Facebook推出大一统多模态分割模型Segment Anything。喜忧参半,喜的是一直关注的医学图像分割仓库更新了,忧的是以后分割的赛道变了,小打小闹的堆模块水文章估计不行了,各种微雕大模型的工作会逐渐应用到医学图像分割领域。

闲话少说,回到本文的主题:怎么在新版nnUnetV2使用自定义网络。

基本知识

nnUnetV2默认使用深监督,意味着自定义网络输出应为一个列表形式。然而,在网络推理时,我们只需要最高分辨率的输出,不需要多层次输出。在nnUnetV2官方实现中,使用deep_supervision参数控制是否多层次输出。综上所述,自定义网络需要满足两个条件:

  • 支持多层次输出

  • 使用变量deep_supervision控制输出类型

实战

这里提供一种对已有网络包装的方法,仅供参考

import torch.nn as nn

class custom_net(nn.Module):

    def __init__(self,):
        super(custom_net, self).__init__()
        self.deep_supervision = True
        # 使用你自己的网络
        self.model = None

    def forward(self, x):
        output = self.model(x)
        if self.deep_supervision:
            return [output, ]
        else:
            return output

将自定义网络嵌套进主框架。打开文件 nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py

替换函数 build_network_architecture

    def build_network_architecture(self, plans_manager: PlansManager,
                                   dataset_json,
                                   configuration_manager: ConfigurationManager,
                                   num_input_channels,
                                   enable_deep_supervision: bool = True) -> nn.Module:
        from dynamic_network_architectures.initialization.weight_init import InitWeights_He
        model = custom_net()
        model.apply(InitWeights_He(1e-2))
        return model

你可能感兴趣的:(医学图像,深度学习,人工智能)