作者信息:
Robin Brügger, CV Lab,ETH Zürich
代码:https://github.com/RobinBruegger/RevTorch
https://github.com/RobinBruegger/PartiallyReversibleUnet
医疗影像常用3D网络,显存占用经常制约了网络结构与深度,从而对最终精度产生影响。文章主要借鉴了reversible block 的思路来解决上述问题。
该block设计很巧妙。输入x 按通道数先分成两组,x1, x2。利用如下公式(1),得到y1,y2,由于特殊的结构设计,x1,x2反过来又可以由公式(2) 通过y1,y2计算得到。
网络训练时显存占用很大一部分是储存前向传播的中间结果(因为反向传播时需要用到),使用 reversible block 后,中间结果无需保存,只要保存最后输出的结果,中间结果都可以反推得到。
文章基于MICCAI Brats18挑战赛第二名 No-New-Net 的结构进行改进,引入reversible block后的网络结构如下:
结果很好,第一二行比较可以看到使用reversible block后,显存节约2.5G,使得在12G显存下使用full volume 训练成为可能,与No-New-Net的单模型比也要强。
reversible block模块部分的代码如下,反向传播的代码花了一定时间才大致了解。f.backward(dy)
是链式法则的意思:把f.backward()
得到的梯度乘上之前层反传得到的梯度dy
,可以参考这个资料
import torch
import torch.nn as nn
#import torch.autograd.function as func
class ReversibleBlock(nn.Module):
'''
Elementary building block for building (partially) reversible architectures
Implementation of the Reversible block described in the RevNet paper
(https://arxiv.org/abs/1707.04585). Must be used inside a :class:`revtorch.ReversibleSequence`
for autograd support.
Arguments:
f_block (nn.Module): arbitrary subnetwork whos output shape is equal to its input shape
g_block (nn.Module): arbitrary subnetwork whos output shape is equal to its input shape
'''
def __init__(self, f_block, g_block):
super(ReversibleBlock, self).__init__()
self.f_block = f_block
self.g_block = g_block
def forward(self, x):
"""
Performs the forward pass of the reversible block. Does not record any gradients.
:param x: Input tensor. Must be splittable along dimension 1.
:return: Output tensor of the same shape as the input tensor
"""
x1, x2 = torch.chunk(x, 2, dim=1)
y1, y2 = None, None
with torch.no_grad():
y1 = x1 + self.f_block(x2)
y2 = x2 + self.g_block(y1)
return torch.cat([y1, y2], dim=1)
def backward_pass(self, y, dy):
"""
Performs the backward pass of the reversible block.
Calculates the derivatives of the block's parameters in f_block and g_block, as well as the inputs of the
forward pass and its gradients.
:param y: Outputs of the reversible block
:param dy: Derivatives of the outputs
:return: A tuple of (block input, block input derivatives). The block inputs are the same shape as the block outptus.
"""
# Split the arguments channel-wise
y1, y2 = torch.chunk(y, 2, dim=1)
del y
assert (not y1.requires_grad), "y1 must already be detached"
assert (not y2.requires_grad), "y2 must already be detached"
dy1, dy2 = torch.chunk(dy, 2, dim=1)
del dy
assert (not dy1.requires_grad), "dy1 must not require grad"
assert (not dy2.requires_grad), "dy2 must not require grad"
# Enable autograd for y1 and y2. This ensures that PyTorch
# keeps track of ops. that use y1 and y2 as inputs in a DAG
y1.requires_grad = True
y2.requires_grad = True
# Ensures that PyTorch tracks the operations in a DAG
with torch.enable_grad():
gy1 = self.g_block(y1)
# Use autograd framework to differentiate the calculation. The
# derivatives of the parameters of G are set as a side effect
gy1.backward(dy2)
with torch.no_grad():
x2 = y2 - gy1 # Restore first input of forward()
del y2, gy1
# The gradient of x1 is the sum of the gradient of the output
# y1 as well as the gradient that flows back through G
# (The gradient that flows back through G is stored in y1.grad)
dx1 = dy1 + y1.grad
del dy1
y1.grad = None
with torch.enable_grad():
x2.requires_grad = True
fx2 = self.f_block(x2)
# Use autograd framework to differentiate the calculation. The
# derivatives of the parameters of F are set as a side effec
fx2.backward(dx1)
with torch.no_grad():
x1 = y1 - fx2 # Restore second input of forward()
del y1, fx2
# The gradient of x2 is the sum of the gradient of the output
# y2 as well as the gradient that flows back through F
# (The gradient that flows back through F is stored in x2.grad)
dx2 = dy2 + x2.grad
del dy2
x2.grad = None
# Undo the channelwise split
x = torch.cat([x1, x2.detach()], dim=1)
dx = torch.cat([dx1, dx2], dim=1)
return x, dx
我觉得这篇文章思路很棒,一是本文针对到了医疗影像处理的一个痛点,即显存占用。大部分研究者显存受限,12G为最常用的设备。二是他引入了其他领域的reversible block的思路,该问题提出了一个解决思路,并且最终的实验结果也很好。本文对我的研究思路有很好的启发。
当然,节约的显存是以更长的训练时间为代价的。