[MICCAI2019] A Partially Reversible U-Net for Memory-Efficient Volumetric Image Segmentation

Author information:
Robin Brügger, CV Lab, ETH Zürich
Code: https://github.com/RobinBruegger/RevTorch
https://github.com/RobinBruegger/PartiallyReversibleUnet


Medical image segmentation commonly relies on 3D networks, and GPU memory consumption often constrains the network architecture and depth, which in turn limits the final accuracy. The paper addresses this problem by borrowing the idea of the reversible block.

Reversible block

The design of this block is quite elegant. The input x is first split along the channel dimension into two halves, x1 and x2. Equation (1) below produces y1 and y2; thanks to the additive structure, x1 and x2 can in turn be recovered from y1 and y2 via Equation (2).
(1)  y1 = x1 + F(x2),    y2 = x2 + G(y1)
(2)  x2 = y2 - G(y1),    x1 = y1 - F(x2)
Here F and G are arbitrary sub-networks whose output shape equals their input shape (f_block and g_block in the code below).
A large share of the memory used during training goes to storing the intermediate activations of the forward pass (they are needed for backpropagation). With reversible blocks these intermediate activations do not have to be stored: only the final outputs are kept, and the intermediates are recomputed from them during the backward pass.
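To make the invertibility concrete, here is a minimal sketch: F and G below are arbitrary toy sub-networks of my choosing (anything that preserves the input shape works). Equation (1) is applied forward, and Equation (2) then recovers the inputs from the outputs alone:

import torch
import torch.nn as nn

# Toy sub-networks; any modules that preserve the input shape work as F and G.
F = nn.Sequential(nn.Conv3d(8, 8, kernel_size=3, padding=1), nn.ReLU())
G = nn.Sequential(nn.Conv3d(8, 8, kernel_size=3, padding=1), nn.ReLU())

x = torch.randn(1, 16, 8, 8, 8)              # 16 channels, split into 2 x 8
x1, x2 = torch.chunk(x, 2, dim=1)

# Equation (1): forward computation
y1 = x1 + F(x2)
y2 = x2 + G(y1)

# Equation (2): reconstruct the inputs from the outputs alone
x2_rec = y2 - G(y1)
x1_rec = y1 - F(x2_rec)

print(torch.allclose(x1, x1_rec, atol=1e-5),
      torch.allclose(x2, x2_rec, atol=1e-5))   # True True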

Method

The paper builds on No-New-Net, the second-place method of the MICCAI BraTS 2018 challenge, and modifies its architecture. The network obtained after introducing reversible blocks is shown below:
(Figure: the partially reversible U-Net architecture)
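As a rough sketch of how one level of such a network could be put together from reversible blocks: the channel count, the number of blocks, and the composition of the F/G sub-networks below are my placeholders rather than the paper's exact configuration, and the ReversibleSequence usage follows my reading of the linked RevTorch repository.

import torch
import torch.nn as nn
import revtorch as rv   # from the linked RevTorch repository

def make_subnetwork(channels):
    # Placeholder F/G design: normalization + nonlinearity + 3x3x3 convolution;
    # the output shape must equal the input shape, as the block requires.
    return nn.Sequential(
        nn.InstanceNorm3d(channels),
        nn.LeakyReLU(inplace=True),
        nn.Conv3d(channels, channels, kernel_size=3, padding=1),
    )

channels = 16   # each block splits its 2*channels input into two halves
blocks = nn.ModuleList([
    rv.ReversibleBlock(make_subnetwork(channels), make_subnetwork(channels))
    for _ in range(2)
])

# The ReversibleSequence wrapper runs these blocks without storing their
# intermediate activations and recomputes them during backpropagation;
# one such sequence replaces the convolution stack of one U-Net level.
level = rv.ReversibleSequence(blocks)

x = torch.randn(1, 2 * channels, 16, 16, 16)
y = level(x)   # same shape as x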

Results

The results are strong. Comparing the first and second rows shows that the reversible blocks save 2.5 GB of GPU memory, which makes training on the full volume feasible on a 12 GB card; the resulting model also outperforms the single-model No-New-Net.
(Table: segmentation accuracy and GPU memory usage of the compared configurations)

Code

The code of the reversible block module is listed further below. It took me a while to roughly understand the backward pass. Calling the output tensor's backward(dy) (gy1.backward(dy2) and fx2.backward(dx1) in the code) applies the chain rule: the local gradients of the sub-network are multiplied by the upstream gradient dy that has already been propagated back from the later layers.
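As a quick, generic PyTorch illustration of what tensor.backward(gradient) does (not part of the paper's code):

import torch

x = torch.tensor([1.0, 2.0], requires_grad=True)
y = x ** 2                       # local derivative dy/dx = 2x
dy = torch.tensor([0.1, 0.5])    # upstream gradient from later layers
y.backward(dy)                   # chain rule: x.grad = dy * 2x
print(x.grad)                    # tensor([0.2000, 2.0000])

The full reversible block module: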

import torch
import torch.nn as nn
#import torch.autograd.function as func

class ReversibleBlock(nn.Module):
    '''
    Elementary building block for building (partially) reversible architectures
    Implementation of the Reversible block described in the RevNet paper
    (https://arxiv.org/abs/1707.04585). Must be used inside a :class:`revtorch.ReversibleSequence`
    for autograd support.
    Arguments:
        f_block (nn.Module): arbitrary subnetwork whose output shape is equal to its input shape
        g_block (nn.Module): arbitrary subnetwork whose output shape is equal to its input shape
    '''

    def __init__(self, f_block, g_block):
        super(ReversibleBlock, self).__init__()
        self.f_block = f_block
        self.g_block = g_block

    def forward(self, x):
        """
        Performs the forward pass of the reversible block. Does not record any gradients.
        :param x: Input tensor. Must be splittable along dimension 1.
        :return: Output tensor of the same shape as the input tensor
        """
        x1, x2 = torch.chunk(x, 2, dim=1)
        y1, y2 = None, None
        with torch.no_grad():
            y1 = x1 + self.f_block(x2)
            y2 = x2 + self.g_block(y1)

        return torch.cat([y1, y2], dim=1)

    def backward_pass(self, y, dy):
        """
        Performs the backward pass of the reversible block.
        Reconstructs the inputs of the forward pass and computes their gradients, and accumulates the
        gradients of the block's parameters in f_block and g_block as a side effect.
        :param y: Outputs of the reversible block
        :param dy: Gradients of the outputs
        :return: A tuple (block input, block input gradients). The block input has the same shape as the block output.
        """
        
        # Split the arguments channel-wise
        y1, y2 = torch.chunk(y, 2, dim=1)
        del y
        assert (not y1.requires_grad), "y1 must already be detached"
        assert (not y2.requires_grad), "y2 must already be detached"
        dy1, dy2 = torch.chunk(dy, 2, dim=1)
        del dy
        assert (not dy1.requires_grad), "dy1 must not require grad"
        assert (not dy2.requires_grad), "dy2 must not require grad"

        # Enable autograd for y1 and y2. This ensures that PyTorch
        # keeps track of ops. that use y1 and y2 as inputs in a DAG
        y1.requires_grad = True
        y2.requires_grad = True

        # Ensures that PyTorch tracks the operations in a DAG
        with torch.enable_grad():
            gy1 = self.g_block(y1)

            # Use autograd framework to differentiate the calculation. The
            # derivatives of the parameters of G are set as a side effect
            gy1.backward(dy2)

        with torch.no_grad():
            x2 = y2 - gy1 # Restore second input of forward()
            del y2, gy1

            # The gradient of x1 is the sum of the gradient of the output
            # y1 as well as the gradient that flows back through G
            # (The gradient that flows back through G is stored in y1.grad)
            dx1 = dy1 + y1.grad
            del dy1
            y1.grad = None

        with torch.enable_grad():
            x2.requires_grad = True
            fx2 = self.f_block(x2)

            # Use autograd framework to differentiate the calculation. The
            # derivatives of the parameters of F are set as a side effect
            fx2.backward(dx1)

        with torch.no_grad():
            x1 = y1 - fx2 # Restore first input of forward()
            del y1, fx2

            # The gradient of x2 is the sum of the gradient of the output
            # y2 as well as the gradient that flows back through F
            # (The gradient that flows back through F is stored in x2.grad)
            dx2 = dy2 + x2.grad
            del dy2
            x2.grad = None

            # Undo the channelwise split
            x = torch.cat([x1, x2.detach()], dim=1)
            dx = torch.cat([dx1, dx2], dim=1)

        return x, dx
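
To see how backward_pass() is hooked into autograd, here is a simplified sketch of a wrapper around a single block. This is my own simplification for illustration: the class name _RevBlockFunction and its details are not the library's implementation (RevTorch does the equivalent for a whole ReversibleSequence).

import torch

class _RevBlockFunction(torch.autograd.Function):
    # Simplified illustration: run one ReversibleBlock while storing only its
    # output, then recover the input and the gradients via backward_pass().

    @staticmethod
    def forward(ctx, x, block):
        y = block(x)              # block.forward() already runs under no_grad
        ctx.block = block
        ctx.y = y.detach()        # keep only the output, not the intermediates
        return y

    @staticmethod
    def backward(ctx, dy):
        # Reconstruct the block input and compute its gradient from the stored
        # output; the parameter gradients of f_block and g_block are accumulated
        # inside backward_pass() as a side effect.
        _, dx = ctx.block.backward_pass(ctx.y, dy)
        return dx, None           # no gradient for the `block` argument

# Usage sketch:
#   block = ReversibleBlock(f_block, g_block)
#   y = _RevBlockFunction.apply(x, block)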

My notes

I think the idea of this paper is excellent. First, it targets a real pain point in medical image analysis: GPU memory. Most researchers are memory-limited, with 12 GB cards being the most common hardware. Second, it brings in the reversible block idea from another field, turns it into a concrete solution for this problem, and backs it up with strong experimental results. It is a good source of inspiration for my own research.
Of course, the memory savings come at the cost of longer training time.
