解读:SPM: 一种即插即用的形状先验模块,可轻松嵌入任意编解码架构,助力涨点并显著改善分割效果! (qq.com)
论文:https://arxiv.org/abs/2303.17967
代码:https://github.com/AlexYouXin/Explicit-Shape-Priors
基于UNet的网络在医学图像分割领域逐步占据主导地位。然而,卷积神经网络(CNNs)面临两个限制:
现有的方法不能很好地同时解决这两个限制。因此,本文提出了一种新的形状先验模块(SPM),它可以引入形状先验来提高基于UNet的模型的分割性能。显式形状先验由全局形状先验和局部形状先验组成。
为了评估SPM的有效性,在三个具有挑战性的公共数据集上进行了实验。SPM性能优异。此外,SPM在经典的细胞神经网络和最近的基于Transformer的主干上表现出了突出的泛化能力,可以作为不同数据集分割任务的即插即用结构。
如何解决CNN感受野有限的问题呢?本文开始探索形状先验(shape priors
)对分割性能的影响。
在医学图像中,不同的器官或病灶通常具有特定的形状和结构,这些形状和结构信息对于分割模型来说非常关键,因此先前的许多工作尝试利用形状先验来设计分割模型,以获得具有解剖形状信息的更好掩模(mask
)。就是引入形状先验可以帮助分割模型在分割过程中更好地考虑和利用目标物体的形状信息,从而提高分割性能。
为此,本文集中探讨了三种带有形状先验的分割模型:
atlas-based models
)statistical-based models
)UNet-based models
)论文认为,前两种方法的泛化能力较差,而 UNet-based 模型由于相比于前两者泛化性能要好,但由于它是倾向于使用隐式形状先验,这在不同形状的器官上缺乏良好的可解释性和泛化能力。综上所述,本文提出了一种新的形状先验模块(Shape Prior Module, SPM
),它可以显示地引入形状先验,以促进 UNet-based 模型的分割性能。(具体分析见论文)
论文在三个具有挑战性的公共数据集上进行实验,验证了SPM的有效性。SPM也表现出很强的泛化性,可作为不同数据集分割任务的即插即用结构。
来源:
SPM: 一种即插即用的形状先验模块,可轻松嵌入任意编解码架构,助力涨点并显著改善分割效果!
隐式形状先验通常是通过在模型中加入先验信息,例如特定的损失函数或正则化项来实现的。这些隐式的形状先验通常难以解释,因为它们是通过一些特殊的方式集成到模型中的,而不是直接考虑目标物体的形状信息。例如,在基于 UNet 的模型中,可以通过使用 Dice 损失函数来强制模型更加注重目标物体的轮廓信息,从而隐式地考虑了形状先验信息。
相反,显式形状先验则直接将形状先验信息作为输入提供给模型。例如,在本文中,作者提出了一个新的形状先验模块,它明确地将形状先验信息作为输入,并利用这些信息来引导模型更好地分割目标物体。这种显式的形状先验可以更好地解释和调整,因为它们直接考虑了目标物体的形状和结构信息。
将可学习的重复形状先验S引入U形神经网络。具体地,S被用作与图像组合的网络的输入。网络的输出是由S生成的预测掩码和注意力图。然后注意力图的通道可以提供真实标签区域的丰富形状信息。显式形状先验模型可以描述如下:
其中,F表示推理期间的前向传播,S表示构造图像空间I和标签空间L之间的映射的连续形状先验。这里,S在训练过程中随着图像GT对的变化而更新。一旦训练完成,可学习的形状先验就被固定,这可以随着推理过程中输入补丁的变化而动态地生成精细的形状先验。精细形状先验作为注意力图,可以定位感兴趣的区域,并抑制背景区域。此外,一小部分不准确的基本事实不会显著影响S的学习,显示了该范式的稳健性。
图1所示,本文模型是一个分层的U形网络,它由类ResNet编码器、基于Resblock的解码器和形状先验模块(SPM)组成。SPM通过引入可学习形状先验,为每个类别施加解剖形状约束来增强网络的表示能力。SPM是一个即插即用模块,可以灵活地插入其他网络结构,以提高分割性能。
图2所示,SPM的输入包括原始跳跃特征Fo和原始形状先验So,经过“特征提纯”后会生成对应的增强跳跃特征Fe和增强形状先验Se 。最终,通过这些增强后的特征和先验,模型会生成更加精准的分割掩膜。与DETR不同,SPM会与多尺度特征进行交互,而不仅仅是来自编码器最深层的特征。因此,在跳跃连接之前的分层编码特征在经过SPM处理后将获得更多的形状信息。增强形状先验由两个部分组成:
旨在引入能够定位目标区域的显式形状先验的基础上,形状先验的大小So是N×空间维度。N表示类的数量,空间维度与补丁大小有关。为了缓解感受野有限的缺点,本工作考虑了形状先验内的长程依赖性。具体而言,提出了自更新块(SUB)来对类之间的关系进行建模,并生成具有N个通道之间相互作用的全局形状先验。受自注意机制的启发,构建了N类之间的自注意Smap的亲和图,以描述形状先验的每个通道之间的相似性和依赖性关系。再采用Smap加权Vs,随后经过多层感知机MLP和层归一化处理,得到全局形状先验。
引入显式形状先验給SUB带来了全局上下文信息,但不具有精确的形状和轮廓信息。因为SUB缺乏归纳偏置,无法建模局部视觉结构和定位各种不同尺度的对象。
为了解决这个限制,论文提出交叉更新块CUB。受到卷积核固有的局部性和尺度不变性的归纳偏置的启发,基于卷积的 CUB 为 SPM 注入归纳偏置,以获得更精确的局部形状信息。此外,基于编码器中卷积特征具有定位区分性区域的显著潜力的事实,论文在原始跳跃特征Fo和形状先验So之间进行交互。
具体来说,
综上所述,原始形状先验通过引入全局和局部特征进行增强。
上图展示了跳跃特征对明确形状先验的影响。其中:
将形状先验分解为来自 SUB 和 CUB 的两个组成部分,即全局形状先验和局部形状先验:从图7可以观察到,得益于自注意力模块,全局形状先验具有全局的感受野,包含上下文和纹理。然而,SUB 的结构缺乏归纳偏差来模拟局部视觉结构。全局形状先验负责对 GT 区域进行粗定位。而由 CUB 生成的局部形状先验可以通过引入卷积核提供更精细的形状信息,这些卷积核具有局部归纳偏差。
# https://github.com/AlexYouXin/Explicit-Shape-Priors/blob/main/networks/ACDC/SPM.py
class self_update_block(nn.Module):
def __init__(self, config):
super(self_update_block, self).__init__()
num_layers = 2
self.layer = nn.ModuleList()
self.encoder_norm = LayerNorm(config.n_patches, eps=1e-6)
for _ in range(num_layers):
layer = Block(config)
self.layer.append(copy.deepcopy(layer))
def forward(self, refined_shape_prior):
for layer_block in self.layer:
refined_shape_prior = layer_block(refined_shape_prior)
encoded = self.encoder_norm(refined_shape_prior)
return encoded
class cross_update_block(nn.Module):
def __init__(self, n_class):
super(cross_update_block, self).__init__()
self.n_class = n_class
self.softmax = Softmax(dim=-1)
def forward(self, refined_shape_prior, feature):
class_feature = torch.matmul(feature.flatten(2), refined_shape_prior.flatten(2).transpose(-1, -2))
# scale
class_feature = class_feature / math.sqrt(self.n_class)
class_feature = self.softmax(class_feature)
class_feature = torch.einsum("ijk, iklhw->ijlhw", class_feature, refined_shape_prior)
class_feature = feature + class_feature
return class_feature
class Attention(nn.Module):
def __init__(self, config):
super(Attention, self).__init__()
self.num_attention_heads = config.transformer.num_heads
self.attention_head_size = int(config.n_patches / self.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.query = Linear(config.n_patches, config.n_patches)
self.key = Linear(config.n_patches, config.n_patches)
self.value = Linear(config.n_patches, config.n_patches)
self.out = Linear(config.n_patches, config.n_patches)
self.attn_dropout = Dropout(config.transformer["attention_dropout_rate"])
self.proj_dropout = Dropout(config.transformer["attention_dropout_rate"])
self.softmax = Softmax(dim=-1)
self.position_embeddings = nn.Parameter(torch.randn(1, self.num_attention_heads, config.n_classes, config.n_classes))
def transpose_for_scores(self, x):
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(*new_x_shape)
return x.permute(0, 2, 1, 3)
def forward(self, hidden_states):
mixed_query_layer = self.query(hidden_states)
mixed_key_layer = self.key(hidden_states)
mixed_value_layer = self.value(hidden_states)
# print(mixed_query_layer.shape)
query_layer = self.transpose_for_scores(mixed_query_layer)
key_layer = self.transpose_for_scores(mixed_key_layer)
value_layer = self.transpose_for_scores(mixed_value_layer)
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
attention_scores = attention_scores + self.position_embeddings # RPE
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
attention_probs = self.softmax(attention_scores)
# weights = attention_probs if self.vis else None
attention_probs = self.attn_dropout(attention_probs)
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(*new_context_layer_shape)
attention_output = self.out(context_layer)
attention_output = self.proj_dropout(attention_output)
return attention_output
class Mlp(nn.Module):
def __init__(self, config):
super(Mlp, self).__init__()
self.fc1 = Linear(config.n_patches, config.hidden_size)
self.fc2 = Linear(config.hidden_size, config.n_patches)
self.act_fn = ACT2FN["gelu"]
self.dropout = Dropout(config.transformer["dropout_rate"])
self._init_weights()
def _init_weights(self):
nn.init.xavier_uniform_(self.fc1.weight)
nn.init.xavier_uniform_(self.fc2.weight)
nn.init.normal_(self.fc1.bias, std=1e-6)
nn.init.normal_(self.fc2.bias, std=1e-6)
def forward(self, x):
x = self.fc1(x)
x = self.act_fn(x)
x = self.dropout(x)
x = self.fc2(x)
x = self.dropout(x)
return x
class Block(nn.Module):
def __init__(self, config):
super(Block, self).__init__()
self.attention_norm = LayerNorm(config.n_patches, eps=1e-6)
self.ffn_norm = LayerNorm(config.n_patches, eps=1e-6)
self.ffn = Mlp(config)
self.attn = Attention(config)
def forward(self, x):
h = x
x = self.attention_norm(x)
x = self.attn(x)
x = x + h
h = x
x = self.ffn_norm(x)
x = self.ffn(x)
x = x + h
return x
class SPM(nn.Module):
def __init__(self, config, in_channel, scale):
super(SPM, self).__init__()
self.scale = scale
self.SUB = self_update_block(config)
self.CUB = cross_update_block(config.n_classes)
self.resblock1 = DecoderResBlock(in_channel, in_channel)
self.resblock2 = DecoderResBlock(in_channel, in_channel)
self.resblock3 = DecoderResBlock(in_channel, config.n_classes)
self.h = config.h
self.w = config.w
self.l = config.l
self.dim = in_channel
def forward(self, feature, refined_shape_prior):
# print(refined_shape_prior.size())
b, n_class, _ = refined_shape_prior.size()
B = feature.size()[0]
refined_shape_prior = self.SUB(refined_shape_prior)
previous_class_center = refined_shape_prior
refined_shape_prior = F.interpolate(refined_shape_prior.contiguous().view(b, n_class, self.h, self.w, self.l), scale_factor=self.scale, mode="trilinear")
feature = self.resblock1(feature)
feature = self.resblock2(feature)
class_feature = self.CUB(refined_shape_prior, feature)
# b * N * H/i * W/i * L/i
refined_shape_prior = F.interpolate(self.resblock3(class_feature), scale_factor=(1.0 / self.scale[0], 1.0 / self.scale[1], 1.0 / self.scale[2]), mode="trilinear").flatten(2) + previous_class_center
return class_feature, refined_shape_prior
class Conv3dbn(nn.Sequential):
def __init__(
self,
in_channels,
out_channels,
kernel_size,
padding=0,
stride=1,
use_batchnorm=True,
):
conv = nn.Conv3d(
in_channels,
out_channels,
kernel_size,
stride=stride,
padding=padding,
bias=not (use_batchnorm),
)
bn = nn.BatchNorm3d(out_channels)
super(Conv3dbn, self).__init__(conv, bn)
class Conv3dReLU(nn.Sequential):
def __init__(
self,
in_channels,
out_channels,
kernel_size,
padding=0,
stride=1,
use_batchnorm=True,
):
conv = nn.Conv3d(
in_channels,
out_channels,
kernel_size,
stride=stride,
padding=padding,
bias=not (use_batchnorm),
)
relu = nn.ReLU(inplace=True)
bn = nn.BatchNorm3d(out_channels)
super(Conv3dReLU, self).__init__(conv, bn, relu)
class DecoderResBlock(nn.Module):
def __init__(
self,
in_channels,
out_channels,
use_batchnorm=True,
):
super().__init__()
self.conv1 = Conv3dReLU(
in_channels,
out_channels,
kernel_size=1,
padding=0,
use_batchnorm=use_batchnorm,
)
self.conv2 = Conv3dReLU(
out_channels,
out_channels,
kernel_size=3,
padding=1,
use_batchnorm=use_batchnorm,
)
self.conv3 = Conv3dbn(
in_channels,
out_channels,
kernel_size=1,
padding=0,
use_batchnorm=use_batchnorm,
)
self.up = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True)
self.relu = nn.ReLU(inplace=True)
def forward(self, x, skip=None):
feature_in = self.conv3(x)
x = self.conv1(x)
x = self.conv2(x)
x = x + feature_in
x = self.relu(x)
# x = self.se_block(x)
return x