nnUnet(代码)-训练部分

学习目标:逐步分析nnunet训练部分

学习内容:training部分

· 拿到训练plans(计划)
· 初始化数据增强参数
· 采用五折交叉验证
· dataset与dataloader/数据加载过程
· 初始化网络
· 初始化优化器与学习率函数

1.nnUNetTrainer(版本一的训练方法)

··· 损失函数:

self.loss = DC_and_CE_loss({'batch_dice': self.batch_dice, 'smooth': 1e-5, 'do_bg': False}, {})

··· 优化器与学习率函数:
优化器用adam
学习率的调整是用的损失函数的加权平均值来判断是否变动的方法

    def initialize_optimizer_and_scheduler(self):
        assert self.network is not None, "self.initialize_network must be called first"
        self.optimizer = torch.optim.Adam(self.network.parameters(), self.initial_lr, weight_decay=self.weight_decay,
                                          amsgrad=True)
        self.lr_scheduler = lr_scheduler.ReduceLROnPlateau(self.optimizer, mode='min', factor=0.2,
                                                           patience=self.lr_scheduler_patience,
                                                           verbose=True, threshold=1e-3,
                                                           threshold_mode="abs")
# 学习率函数设置
self.train_loss_MA_alpha = 0.93  # alpha * old + (1-alpha) * new

    def update_train_loss_MA(self):
        if self.train_loss_MA is None:
            self.train_loss_MA = self.all_tr_losses[-1]
        else:
            self.train_loss_MA = self.train_loss_MA_alpha * self.train_loss_MA + (1 - self.train_loss_MA_alpha) * \
                                 self.all_tr_losses[-1]
# lr scheduler is updated with moving average val loss. should be more robust
self.lr_scheduler.step(self.train_loss_MA)

2.nnUNetTrainerV2(版本二的训练方法)

··· 加强了损失函数(深监督):
还是原来损失,但是添加了一个策略:给每层的损失加一个权重,分辨率越高的权重越大,简单说就是针对中间隐藏层特征透明度不高以及深层网络中浅层以及中间网络难以训练的问题。

################# 封装损失函数进入深度学习(深监督) ############
        # 需要知道网络深度
        # net_numpool = len(self.plans['pool_op_kernel_sizes'])

        # 我们给每个输出一个权重,该权重随着分辨率的降低呈指数递减(除以2)
        # 这使得更高的分辨率输出在损失中有更大的权重
        weights = np.array([1 / (2 ** i) for i in range(self.net_numpool)])

        # 我们不使用最低的2个输出。标准化权重,使其总和为1
        mask = np.array([True] + [True if i < self.net_numpool - 1 else False for i in range(1, self.net_numpool)])
        weights[~mask] = 0
        weights = weights / weights.sum()
        self.ds_loss_weights = weights

        # 封装损失函数
        self.loss = MultipleOutputLoss2(self.loss, self.ds_loss_weights)

··· 重写了优化器与学习率函数
采用SGD与自定义的学习率下降函数

    def initialize_optimizer_and_scheduler(self):
        assert self.network is not None, "self.initialize_network must be called first"
        self.optimizer = torch.optim.SGD(self.network.parameters(), self.initial_lr, weight_decay=self.weight_decay,
                                         momentum=0.99, nesterov=True)
        self.lr_scheduler = None
 def maybe_update_lr(self, epoch=None):

        if epoch is None:
            ep = self.epoch + 1
        else:
            ep = epoch
        self.optimizer.param_groups[0]['lr'] = poly_lr(ep, self.max_num_epochs, self.initial_lr, 0.9)
        
def poly_lr(epoch, max_epochs, initial_lr, exponent=0.9):
    return initial_lr * (1 - epoch / max_epochs)**exponent

··· 重写了数据增强参数

3.后面还有DP等等三四个版本,是基于版本二改变的,主要是通过混合精度进行训练增加训练速度

你可能感兴趣的:(nnUNet,pytorch+Unet)