Commentary: IPMI 2023 | New work from Hao Chen's group at HKUST | CTO: Rethinking the Role of Boundary Detection in Medical Image Segmentation (qq.com)
Paper: https://arxiv.org/abs/2305.00678
Code: https://github.com/xiaofang007/CTO
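CTO (Convolution, Transformer, and Operator) combines a convolutional neural network, a vision Transformer, and an explicit boundary detection operator to segment images with high accuracy while keeping a good balance between accuracy and efficiency.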
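CTO follows the standard encoder-decoder segmentation paradigm. The encoder uses a popular CNN backbone to capture local semantic information and a lightweight ViT auxiliary network to integrate long-range dependencies. To strengthen boundary learning, the paper further proposes a boundary-guided decoder, which uses the boundary masks produced by a dedicated boundary detection operator as explicit supervision to guide decoding.
Skip connections aggregate low-level encoder features into the decoder. The encoder consists of the mainstream CNN and the auxiliary ViT; the decoder relies on the boundary detection operator to guide its learning.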
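The decoder's learning is guided by the generated boundary masks, and the whole model is trained end to end. CTO first builds a convolutional stream, choosing Res2Net as the backbone.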
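CTO adds an auxiliary stream based on a lightweight Vision Transformer to capture long-range dependencies between image patches. Concretely, it consists of several parallel lightweight Transformer blocks that take feature patches at different scales as input. All Transformer blocks share a similar structure: a patch embedding layer followed by a Transformer encoding layer.
LightViT's patch embedding layer converts the input feature patches into embedding vectors, turning the spatial dimensions into a sequence dimension. The Transformer encoding layer applies self-attention over the patches to capture long-range dependencies between them. With self-attention inside the Transformer block, LightViT can model the interactions between feature patches and thereby extract global context from the image.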
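The self-attention step over patch tokens can be illustrated with a small, dependency-free sketch. It is not LightViT's implementation: it assumes a single head and identity query/key/value projections (each token serves as its own query, key, and value), and the names `softmax` and `self_attention` are chosen here only for illustration.

```python
import math

def softmax(row):
    """Numerically stable softmax over a list of floats."""
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]

def self_attention(tokens):
    """Single-head self-attention with identity Q/K/V projections (toy assumption)."""
    d = len(tokens[0])
    mixed, weights = [], []
    for q in tokens:
        # Scaled dot-product score between this token and every token
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in tokens]
        attn = softmax(scores)
        weights.append(attn)
        # Each output token is an attention-weighted average of all tokens
        mixed.append([sum(w * v[j] for w, v in zip(attn, tokens)) for j in range(d)])
    return mixed, weights

# Three patch tokens with a 2-dimensional embedding
tokens = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
mixed, weights = self_attention(tokens)
```

Every output token mixes information from all patches, which is how distant patches can influence each other in a single layer.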
Boundary Enhanced Module (BEM)
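The two convolutions (horizontal and vertical Sobel kernels) are then applied to the input feature map to obtain the gradient maps M_x and M_y. Next, the gradient maps are normalized by a sigmoid function and fused with the input feature map to obtain the edge-enhanced feature map.
In the fusion formula, ⊙ denotes element-wise multiplication, σ denotes the sigmoid function, and M_xy is the concatenation of M_x and M_y along the channel dimension. The edge-enhanced feature maps can then be fused directly with a simple stack of convolutional layers. Finally, the output feature map is supervised by the ground-truth (GT) boundary map, which suppresses edge responses inside objects and yields boundary-enhanced features.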
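A minimal pure-Python sketch of the Sobel gating idea follows. It mirrors the `run_sobel` function in the listing further down, which uses the gradient magnitude sqrt(g_x^2 + g_y^2) rather than a channel-wise concatenation, and it leaves out the BatchNorm layers and the multi-channel handling of the real code. The helper names `conv3x3` and `edge_enhance` are assumptions made for this example.

```python
import math

SOBEL_X = [[1, 0, -1], [2, 0, -2], [1, 0, -1]]
SOBEL_Y = [[1, 2, 1], [0, 0, 0], [-1, -2, -1]]

def conv3x3(img, kernel):
    """3x3 cross-correlation with zero padding, so the output keeps the input size."""
    h, w = len(img), len(img[0])
    out = [[0.0] * w for _ in range(h)]
    for i in range(h):
        for j in range(w):
            acc = 0.0
            for di in (-1, 0, 1):
                for dj in (-1, 0, 1):
                    y, x = i + di, j + dj
                    if 0 <= y < h and 0 <= x < w:
                        acc += kernel[di + 1][dj + 1] * img[y][x]
            out[i][j] = acc
    return out

def edge_enhance(feat):
    """Gate a single-channel feature map by sigmoid(gradient magnitude)."""
    gx = conv3x3(feat, SOBEL_X)
    gy = conv3x3(feat, SOBEL_Y)
    h, w = len(feat), len(feat[0])
    out = [[0.0] * w for _ in range(h)]
    for i in range(h):
        for j in range(w):
            g = math.hypot(gx[i][j], gy[i][j])   # gradient magnitude, always >= 0
            gate = 1.0 / (1.0 + math.exp(-g))    # sigmoid
            out[i][j] = gate * feat[i][j]
    return out

# A vertical step edge: left half 0, right half 1
feat = [[0.0, 0.0, 0.0, 1.0, 1.0, 1.0] for _ in range(5)]
enhanced = edge_enhance(feat)
```

In the flat interior the gate is sigmoid(0) = 0.5, while next to the step it approaches 1, so responses near edges are preserved and flat regions are attenuated.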
Boundary Inject Module (BIM)
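The boundary-enhanced features produced by the BEM can serve as prior knowledge that improves the representation power of the features generated by the encoder. The BIM introduces a dual-path boundary fusion scheme that strengthens both foreground and background feature representations. Concretely, the BIM receives two inputs: the channel-wise concatenation of the boundary-enhanced features with the corresponding encoder features, and the features of the previous decoder layer. Both inputs are fed into the BIM, which contains two independent paths, one for the foreground representation and one for the background representation.
The foreground attention map is obtained by applying a sigmoid to the feature map of the previous decoder layer; the background attention map is 1 minus the foreground attention map. Finally, the foreground feature F_fg, the background feature F_bg, and the previous decoder layer's feature are concatenated to produce the output of the current layer.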
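The foreground/background split can be sketched on a flat list of values. This is only an illustration under simplifying assumptions: the convolutions inside each path and the final concatenation are omitted, features are plain Python lists, and the function name `dual_path_attention` is hypothetical.

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def dual_path_attention(encoder_feat, prev_decoder):
    """Weight encoder features by foreground and background attention maps."""
    fg_att = [sigmoid(p) for p in prev_decoder]   # foreground attention
    bg_att = [1.0 - a for a in fg_att]            # background attention = 1 - foreground
    f_fg = [a * f for a, f in zip(fg_att, encoder_feat)]
    f_bg = [a * f for a, f in zip(bg_att, encoder_feat)]
    return f_fg, f_bg

encoder_feat = [1.0, 1.0, 1.0]
prev_decoder = [4.0, -4.0, 0.0]   # confident foreground, confident background, undecided
f_fg, f_bg = dual_path_attention(encoder_feat, prev_decoder)
```

Because the two attention maps sum to one at every position, the foreground and background paths partition the encoder response between them, and each path can then specialize in its own region.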
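The overall loss consists of the main interior segmentation loss L_seg and the boundary loss L_bnd. The boundary loss considers only the predictions of the BEM, which takes the high-level and low-level feature maps of the encoder as input.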
Interior Segmentation Loss
L_seg is a weighted sum of the cross-entropy loss L_CE and the mean intersection-over-union (mIoU) loss L_mIoU.
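A weighted sum of this kind might look as follows for a single foreground class. The weights `w_ce` and `w_iou`, their default values of 1.0, and the soft (probability-based) IoU formulation are assumptions made for this sketch; the notes do not state the weights or the exact form used by the authors.

```python
import math

def cross_entropy(probs, targets, eps=1e-7):
    """Mean binary cross-entropy over pixels; probs are foreground probabilities."""
    total = 0.0
    for p, t in zip(probs, targets):
        p = min(max(p, eps), 1.0 - eps)   # clamp to avoid log(0)
        total -= t * math.log(p) + (1.0 - t) * math.log(1.0 - p)
    return total / len(probs)

def soft_iou_loss(probs, targets, eps=1e-7):
    """1 - soft IoU, where intersection and union are computed on probabilities."""
    inter = sum(p * t for p, t in zip(probs, targets))
    union = sum(probs) + sum(targets) - inter
    return 1.0 - (inter + eps) / (union + eps)

def segmentation_loss(probs, targets, w_ce=1.0, w_iou=1.0):
    """Weighted sum of the two terms (weights are hypothetical)."""
    return w_ce * cross_entropy(probs, targets) + w_iou * soft_iou_loss(probs, targets)

targets = [1.0, 1.0, 0.0, 0.0]
good = [0.9, 0.8, 0.1, 0.2]   # mostly correct prediction
bad = [0.1, 0.2, 0.9, 0.8]    # mostly wrong prediction
loss_good = segmentation_loss(good, targets)
loss_bad = segmentation_loss(bad, targets)
```

The cross-entropy term scores each pixel independently, while the IoU term scores the overlap of the predicted region as a whole, so the two terms complement each other.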
Boundary Loss
Boundary detection suffers from class imbalance between foreground and background pixels, so the boundary loss L_bnd uses the Dice loss.
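The effect of the Dice loss under class imbalance can be seen in a tiny example. The formulation below is one common variant (with a smoothing constant `smooth`, an assumption here); the notes do not show the exact formula used in the paper.

```python
def dice_loss(probs, targets, smooth=1e-6):
    """1 - Dice coefficient computed on soft predictions."""
    inter = sum(p * t for p, t in zip(probs, targets))
    denom = sum(probs) + sum(targets)
    return 1.0 - (2.0 * inter + smooth) / (denom + smooth)

# 100 pixels, only 2 of which lie on the boundary
targets = [1.0] * 2 + [0.0] * 98
all_background = [0.0] * 100

# Pixel accuracy looks excellent even though no boundary pixel is found
accuracy = sum((p >= 0.5) == (t >= 0.5) for p, t in zip(all_background, targets)) / 100

loss_trivial = dice_loss(all_background, targets)
loss_perfect = dice_loss(targets, targets)
```

Predicting background everywhere reaches 98% pixel accuracy, yet the Dice loss stays close to 1 because it depends only on the overlap with the rare boundary class.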
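# https://github.com/xiaofang007/CTO/blob/main/CTOTrainer/network/CTO_net.py
# res2net50_v1b_26w_4s, BaseNetwork and FeedForward2D are used below but are not defined in this listing.
import math
from math import log

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
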
class ConvBNR(nn.Module):
    """Convolution followed by BatchNorm and ReLU."""

    def __init__(self, inplanes, planes, kernel_size=3, stride=1, dilation=1, bias=False):
        super(ConvBNR, self).__init__()
        self.block = nn.Sequential(
            nn.Conv2d(inplanes, planes, kernel_size, stride=stride, padding=dilation, dilation=dilation, bias=bias),
            nn.BatchNorm2d(planes),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.block(x)

class Conv1x1(nn.Module):
    """1x1 convolution + BatchNorm + ReLU, used to change the channel count."""

    def __init__(self, inplanes, planes):
        super(Conv1x1, self).__init__()
        self.conv = nn.Conv2d(inplanes, planes, 1)
        self.bn = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x

# Note: a second EAM class is defined later in this listing and shadows this one;
# CTO calls the later version, whose forward takes (x4, x1).
class EAM(nn.Module):
    def __init__(self):
        super(EAM, self).__init__()
        self.reduce1 = Conv1x1(256, 64)
        self.reduce4 = Conv1x1(512, 256)
        self.block = nn.Sequential(
            ConvBNR(320 + 64, 256, 3),
            ConvBNR(256, 256, 3),
            nn.Conv2d(256, 1, 1))

    def forward(self, x1, x11, p2):
        size = x1.size()[2:]
        x1 = self.reduce1(x1)    # 256 -> 64 channels
        x11 = self.reduce1(x11)  # 256 -> 64 channels
        p2 = self.reduce4(p2)    # 512 -> 256 channels
        p2 = F.interpolate(p2, size, mode='bilinear', align_corners=False)
        out = torch.cat((x1, x11), dim=1)
        out = torch.cat((out, p2), dim=1)  # 64 + 64 + 256 = 384 channels
        out = self.block(out)
        return out

class EFM(nn.Module):
    def __init__(self, channel):
        super(EFM, self).__init__()
        # Kernel size of the 1D channel convolution, derived from the channel count and forced to be odd
        t = int(abs((log(channel, 2) + 1) / 2))
        k = t if t % 2 else t + 1
        self.conv2d = ConvBNR(channel, channel, 3)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv1d = nn.Conv1d(1, 1, kernel_size=k, padding=(k - 1) // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, c, att):
        if c.size() != att.size():
            att = F.interpolate(att, c.size()[2:], mode='bilinear', align_corners=False)
        x = c * att + c  # attention-weighted features plus a residual
        x = self.conv2d(x)
        # Channel attention: global average pooling followed by a 1D convolution across channels
        wei = self.avg_pool(x)
        wei = self.conv1d(wei.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)
        wei = self.sigmoid(wei)
        x = x * wei
        return x

class BasicConv2d(nn.Module):
    def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1):
        super(BasicConv2d, self).__init__()
        self.conv = nn.Conv2d(in_planes, out_planes,
                              kernel_size=kernel_size, stride=stride,
                              padding=padding, dilation=dilation, bias=False)
        self.bn = nn.BatchNorm2d(out_planes)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        # self.relu is defined but not applied here; DM applies F.relu to the result where needed
        x = self.conv(x)
        x = self.bn(x)
        return x

class DM(nn.Module):
    def __init__(self):
        super(DM, self).__init__()
        self.predict3 = nn.Sequential(
            nn.Conv2d(128, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.PReLU(),
            nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.PReLU(),
            nn.Conv2d(64, 1, kernel_size=1)
        )
        self.ra2_conv2 = BasicConv2d(64, 64, kernel_size=3, padding=1)
        self.ra2_conv3 = BasicConv2d(64, 64, kernel_size=3, padding=1)
        self.ra2_conv4 = BasicConv2d(64, 1, kernel_size=3, padding=1)

    def forward(self, xr, dualattention):
        # Resize the coarser decoder feature to the resolution of the encoder feature xr
        crop_3 = F.interpolate(dualattention, xr.size()[2:], mode='bilinear', align_corners=False)
        re3_feat = self.predict3(torch.cat([xr, crop_3], dim=1))
        # Background (reverse) attention: 1 - sigmoid of the previous decoder feature
        x = -1 * (torch.sigmoid(crop_3)) + 1
        x = x.expand(-1, 64, -1, -1).mul(xr)
        x = F.relu(self.ra2_conv2(x))
        x = F.relu(self.ra2_conv3(x))
        ra3_feat = self.ra2_conv4(x)
        # The two single-channel maps broadcast over the 64 channels of crop_3
        x = ra3_feat + crop_3 + re3_feat
        return x

class _DAHead(nn.Module):
    # Note: each nn.Sequential below contains only the layers visible in this listing.
    def __init__(self, in_channels, nclass, aux=True, norm_layer=nn.BatchNorm2d, norm_kwargs=None, **kwargs):
        super(_DAHead, self).__init__()
        self.aux = aux
        inter_channels = in_channels // 4
        self.conv_p1 = nn.Sequential(
            nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False),
            norm_layer(inter_channels, **({} if norm_kwargs is None else norm_kwargs)),
        )
        self.conv_c1 = nn.Sequential(
            nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False),
            norm_layer(inter_channels, **({} if norm_kwargs is None else norm_kwargs)),
        )
        self.pam = _PositionAttentionModule(inter_channels, **kwargs)
        self.cam = _ChannelAttentionModule(**kwargs)
        self.conv_p2 = nn.Sequential(
            nn.Conv2d(inter_channels, inter_channels, 3, padding=1, bias=False),
            norm_layer(inter_channels, **({} if norm_kwargs is None else norm_kwargs)),
        )
        self.conv_c2 = nn.Sequential(
            nn.Conv2d(inter_channels, inter_channels, 3, padding=1, bias=False),
            norm_layer(inter_channels, **({} if norm_kwargs is None else norm_kwargs)),
        )
        self.out = nn.Sequential(
            nn.Conv2d(inter_channels, nclass, 1)
        )
        if aux:
            self.conv_p3 = nn.Sequential(
                nn.Conv2d(inter_channels, nclass, 1)
            )
            self.conv_c3 = nn.Sequential(
                nn.Conv2d(inter_channels, nclass, 1)
            )

    def forward(self, x):
        # Position-attention branch
        feat_p = self.conv_p1(x)
        feat_p = self.pam(feat_p)
        feat_p = self.conv_p2(feat_p)
        # Channel-attention branch
        feat_c = self.conv_c1(x)
        feat_c = self.cam(feat_c)
        feat_c = self.conv_c2(feat_c)
        feat_fusion = feat_p + feat_c
        outputs = []
        fusion_out = self.out(feat_fusion)
        outputs.append(fusion_out)  # CTO reads this entry via self.head(...)[0]
        if self.aux:
            p_out = self.conv_p3(feat_p)
            c_out = self.conv_c3(feat_c)
            outputs.append(p_out)
            outputs.append(c_out)
        return tuple(outputs)

def run_sobel(conv_x, conv_y, input):
    g_x = conv_x(input)
    g_y = conv_y(input)
    # Gradient magnitude, squashed by a sigmoid and used to gate the input features
    g = torch.sqrt(torch.pow(g_x, 2) + torch.pow(g_y, 2))
    return torch.sigmoid(g) * input

def get_sobel(in_chan, out_chan):
    # Scharr-style kernels, kept for reference only (the Sobel kernels below are the ones used):
    # filter_x = [[3, 0, -3], [10, 0, -10], [3, 0, -3]]
    # filter_y = [[3, 10, 3], [0, 0, 0], [-3, -10, -3]]
    filter_x = np.array([
        [1, 0, -1],
        [2, 0, -2],
        [1, 0, -1],
    ]).astype(np.float32)
    filter_y = np.array([
        [1, 2, 1],
        [0, 0, 0],
        [-1, -2, -1],
    ]).astype(np.float32)  # float32 so the kernels can serve as convolution weights

    # Tile the 3x3 kernel to the weight shape (out_chan, in_chan, 3, 3)
    filter_x = filter_x.reshape((1, 1, 3, 3))
    filter_x = np.repeat(filter_x, in_chan, axis=1)
    filter_x = np.repeat(filter_x, out_chan, axis=0)
    filter_y = filter_y.reshape((1, 1, 3, 3))
    filter_y = np.repeat(filter_y, in_chan, axis=1)
    filter_y = np.repeat(filter_y, out_chan, axis=0)

    filter_x = torch.from_numpy(filter_x)
    filter_y = torch.from_numpy(filter_y)
    # Fixed (non-trainable) kernels
    filter_x = nn.Parameter(filter_x, requires_grad=False)
    filter_y = nn.Parameter(filter_y, requires_grad=False)
    conv_x = nn.Conv2d(in_chan, out_chan, kernel_size=3, stride=1, padding=1, bias=False)
    conv_x.weight = filter_x
    conv_y = nn.Conv2d(in_chan, out_chan, kernel_size=3, stride=1, padding=1, bias=False)
    conv_y.weight = filter_y
    sobel_x = nn.Sequential(conv_x, nn.BatchNorm2d(out_chan))
    sobel_y = nn.Sequential(conv_y, nn.BatchNorm2d(out_chan))
    return sobel_x, sobel_y

class GlobalFilter(nn.Module):
    def __init__(self, dim=32, h=64, w=33, fp32fft=True):
        super().__init__()
        # Learnable complex filter stored as (real, imaginary) pairs; w = W // 2 + 1 for rfft2
        self.complex_weight = nn.Parameter(
            torch.randn(h, w, dim, 2, dtype=torch.float32) * 0.02
        )
        self.w = w
        self.h = h
        self.fp32fft = fp32fft

    def forward(self, x):
        _, _, a, b = x.size()  # a = height, b = width
        x = x.permute(0, 2, 3, 1).contiguous()  # (B, H, W, C)
        if self.fp32fft:
            dtype = x.dtype
            x = x.to(torch.float32)
        x = torch.fft.rfft2(x, dim=(1, 2), norm="ortho")
        weight = torch.view_as_complex(self.complex_weight)
        x = x * weight  # element-wise filtering in the frequency domain
        x = torch.fft.irfft2(x, s=(a, b), dim=(1, 2), norm="ortho")
        if self.fp32fft:
            x = x.to(dtype)
        x = x.permute(0, 3, 1, 2).contiguous()  # back to (B, C, H, W)
        return x

class ERB(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(ERB, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU()
        self.bn = nn.BatchNorm2d(out_channels)
        self.conv3 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x, relu=True):
        x = self.conv1(x)
        res = self.conv2(x)
        res = self.bn(res)
        res = self.relu(res)
        res = self.conv3(res)
        if relu:
            return self.relu(x + res)
        return x + res

class _PositionAttentionModule(nn.Module):
    """ Position attention module"""

    def __init__(self, in_channels, **kwargs):
        super(_PositionAttentionModule, self).__init__()
        self.conv_b = nn.Conv2d(in_channels, in_channels // 8, 1)
        self.conv_c = nn.Conv2d(in_channels, in_channels // 8, 1)
        self.conv_d = nn.Conv2d(in_channels, in_channels, 1)
        self.alpha = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        batch_size, _, height, width = x.size()
        feat_b = self.conv_b(x).view(batch_size, -1, height * width).permute(0, 2, 1)
        feat_c = self.conv_c(x).view(batch_size, -1, height * width)
        attention_s = self.softmax(torch.bmm(feat_b, feat_c))  # (B, HW, HW)
        feat_d = self.conv_d(x).view(batch_size, -1, height * width)
        feat_e = torch.bmm(feat_d, attention_s.permute(0, 2, 1)).view(batch_size, -1, height, width)
        out = self.alpha * feat_e + x
        return out

class _ChannelAttentionModule(nn.Module):
    """Channel attention module"""

    def __init__(self, **kwargs):
        super(_ChannelAttentionModule, self).__init__()
        self.beta = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        batch_size, _, height, width = x.size()
        feat_a = x.view(batch_size, -1, height * width)
        feat_a_transpose = x.view(batch_size, -1, height * width).permute(0, 2, 1)
        attention = torch.bmm(feat_a, feat_a_transpose)  # (B, C, C)
        attention_new = torch.max(attention, dim=-1, keepdim=True)[0].expand_as(attention) - attention
        attention = self.softmax(attention_new)
        feat_e = torch.bmm(attention, feat_a).view(batch_size, -1, height, width)
        out = self.beta * feat_e + x
        return out

class EAM(nn.Module):
    def __init__(self):
        super(EAM, self).__init__()
        self.reduce1 = Conv1x1(256, 64)
        self.reduce4 = Conv1x1(2048, 256)
        self.block = nn.Sequential(
            ConvBNR(256 + 64, 256, 3),
            ConvBNR(256, 256, 3),
            nn.Conv2d(256, 1, 1))

    def forward(self, x4, x1):
        size = x1.size()[2:]
        x1 = self.reduce1(x1)  # low-level feature: 256 -> 64 channels
        x4 = self.reduce4(x4)  # high-level feature: 2048 -> 256 channels
        x4 = F.interpolate(x4, size, mode='bilinear', align_corners=False)
        out = torch.cat((x4, x1), dim=1)
        out = self.block(out)  # single-channel boundary logits
        return out

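def attention(query, key, value):
    # Scaled dot-product attention; scores are divided by the square root of the per-token feature size
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(
        query.size(-1)
    )
    p_attn = F.softmax(scores, dim=-1)
    p_val = torch.matmul(p_attn, value)
    return p_val, p_attn
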
class MultiHeadedAttention(nn.Module):
    """
    Take in model size and number of heads.
    """

    def __init__(self, patchsize, d_model):
        super().__init__()
        self.patchsize = patchsize
        self.query_embedding = nn.Conv2d(
            d_model, d_model, kernel_size=1, padding=0
        )
        self.value_embedding = nn.Conv2d(
            d_model, d_model, kernel_size=1, padding=0
        )
        self.key_embedding = nn.Conv2d(
            d_model, d_model, kernel_size=1, padding=0
        )
        self.output_linear = nn.Sequential(
            nn.Conv2d(d_model, d_model, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
        )

    def forward(self, x):
        b, c, h, w = x.size()
        d_k = c // len(self.patchsize)  # channels per head, one head per patch size
        output = []
        _query = self.query_embedding(x)
        _key = self.key_embedding(x)
        _value = self.value_embedding(x)
        # With 4 patch sizes and an (8, 256, 64, 64) input, each chunk is (8, 64, 64, 64).
        for (width, height), query, key, value in zip(
            self.patchsize,
            torch.chunk(_query, len(self.patchsize), dim=1),
            torch.chunk(_key, len(self.patchsize), dim=1),
            torch.chunk(_value, len(self.patchsize), dim=1),
        ):
            out_w, out_h = w // width, h // height
            # 1) embedding and reshape: split the map into patches and flatten each patch into a token
            query = query.view(b, d_k, out_h, height, out_w, width)
            query = (
                query.permute(0, 2, 4, 1, 3, 5)
                .contiguous()
                .view(b, out_h * out_w, d_k * height * width)
            )
            key = key.view(b, d_k, out_h, height, out_w, width)
            key = (
                key.permute(0, 2, 4, 1, 3, 5)
                .contiguous()
                .view(b, out_h * out_w, d_k * height * width)
            )
            value = value.view(b, d_k, out_h, height, out_w, width)
            value = (
                value.permute(0, 2, 4, 1, 3, 5)
                .contiguous()
                .view(b, out_h * out_w, d_k * height * width)
            )
            # 2) attention between patch tokens
            y, _ = attention(query, key, value)
            # 3) "Concat" using a view and apply a final linear.
            y = y.view(b, out_h, out_w, d_k, height, width)
            y = y.permute(0, 3, 1, 4, 2, 5).contiguous().view(b, d_k, h, w)
            output.append(y)
        output = torch.cat(output, 1)
        self_attention = self.output_linear(output)
        return self_attention

class TransformerBlock(nn.Module):
    """
    Transformer = MultiHead_Attention + Feed_Forward with sublayer connection
    """

    def __init__(self, patchsize, in_channel=256):
        super().__init__()
        self.attention = MultiHeadedAttention(patchsize, d_model=in_channel)
        self.feed_forward = FeedForward2D(
            in_channel=in_channel, out_channel=in_channel
        )

    def forward(self, rgb):
        self_attention = self.attention(rgb)
        output = rgb + self_attention                # residual connection around attention
        output = output + self.feed_forward(output)  # residual connection around feed-forward
        return output

class PatchTrans(BaseNetwork):
    def __init__(self, in_channel, in_size):
        super(PatchTrans, self).__init__()
        self.in_size = in_size
        patchsize = [
            # The (width, height) entries are missing from this listing;
            # a debug note in MultiHeadedAttention.forward reports four patch sizes.
        ]
        self.t = TransformerBlock(patchsize, in_channel=in_channel)

    def forward(self, enc_feat):
        output = self.t(enc_feat)
        return output

class multi(nn.Module):
    def __init__(self, channel):
        super(multi, self).__init__()  # was super(EFM, self), which raises TypeError on instantiation
        t = int(abs((log(channel, 2) + 1) / 2))
        k = t if t % 2 else t + 1
        self.conv2d = ConvBNR(channel, channel, 3)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv1d = nn.Conv1d(1, 1, kernel_size=k, padding=(k - 1) // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, c, att):
        if c.size() != att.size():
            att = F.interpolate(att, c.size()[2:], mode='bilinear', align_corners=False)
        x = c * att
        # x = self.conv2d(x)
        # wei = self.avg_pool(x)
        # wei = self.conv1d(wei.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)
        # wei = self.sigmoid(wei)
        # x = x * wei
        return x

class CTO(nn.Module):
    def __init__(self, seg_classes):
        super(CTO, self).__init__()
        self.resnet = res2net50_v1b_26w_4s(pretrained=True)
        # if self.training:
        #     self.initialize_weights()
        self.fft = GlobalFilter(dim=3, h=256, w=129, fp32fft=True)
        self.multi_trans = PatchTrans(in_channel=256, in_size=64)
        self.num_class = seg_classes
        self.eam = EAM()
        # Fixed Sobel operators, one pair per encoder stage
        self.sobel_x1, self.sobel_y1 = get_sobel(256, 1)
        self.sobel_x2, self.sobel_y2 = get_sobel(512, 1)
        self.sobel_x3, self.sobel_y3 = get_sobel(1024, 1)
        self.sobel_x4, self.sobel_y4 = get_sobel(2048, 1)
        self.upsample = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
        self.upsample_4 = nn.Upsample(scale_factor=4, mode="bilinear", align_corners=True)
        self.upsample_3 = nn.Upsample(scale_factor=8, mode="bilinear", align_corners=True)
        self.erb_db_1 = ERB(256, self.num_class)
        self.erb_db_2 = ERB(512, self.num_class)
        self.erb_db_3 = ERB(1024, self.num_class)
        self.erb_db_4 = ERB(2048, self.num_class)
        self.head = _DAHead(2048 + 256, 2048, aux=False)
        self.reduce1 = Conv1x1(256, 64)
        self.reduce2 = Conv1x1(512, 64)
        self.reduce3 = Conv1x1(1024, 64)
        self.reduce4 = Conv1x1(2048, 64)
        self.reduce5 = Conv1x1(2048, 1)
        self.dm1 = DM()
        self.dm2 = DM()
        self.dm3 = DM()
        self.dm4 = DM()
        self.predictor1 = nn.Conv2d(64, self.num_class, 1)
        self.predictor2 = nn.Conv2d(64, self.num_class, 1)
        self.predictor3 = nn.Conv2d(64, self.num_class, 1)
        self.predictor4 = nn.Conv2d(64, self.num_class, 1)

    # def initialize_weights(self):
    #     model_state = torch.load('./models/resnet50-19c8e357.pth')
    #     self.resnet.load_state_dict(model_state, strict=False)

    def forward(self, x):
        fft_fea = self.fft(x)  # [B, 3, 256, 256]; computed but not used later in this forward pass
        # x1: [16, 256, 64, 64], x2: [16, 512, 32, 32], x3: [16, 1024, 16, 16], x4: [16, 2048, 8, 8]
        x1, x2, x3, x4 = self.resnet(x)
        trans = self.multi_trans(x1)  # [16, 256, 64, 64]
        # Boundary branch: Sobel-gated low-level and high-level features feed the EAM
        s1 = run_sobel(self.sobel_x1, self.sobel_y1, x1)
        s4 = run_sobel(self.sobel_x4, self.sobel_y4, x4)
        edge = self.eam(s4, s1)
        edge_att = torch.sigmoid(edge)  # [16, 1, 64, 64]
        # Fuse the Transformer stream with the deepest CNN feature
        trans = F.interpolate(trans, x4.size()[2:], mode='bilinear', align_corners=False)  # [16, 256, 8, 8]
        dual_attention = self.head(torch.cat([trans, x4], dim=1))[0]  # [16, 2048, 8, 8]
        # Inject the boundary attention into the encoder features at each scale
        x1a = x1 * edge_att
        edge_att2 = F.interpolate(edge_att, x2.size()[2:], mode='bilinear', align_corners=False)
        x2a = x2 * edge_att2
        edge_att3 = F.interpolate(edge_att, x3.size()[2:], mode='bilinear', align_corners=False)
        x3a = x3 * edge_att3
        # x1a = self.efm1(x1, edge_att)
        # x2a = self.efm2(x2, edge_att)
        # x3a = self.efm3(x3, edge_att)
        # x4a = self.efm4(x4, edge_att)
        x1r = self.reduce1(x1a)  # [16, 64, 64, 64]
        x2r = self.reduce2(x2a)  # [16, 64, 32, 32]
        x3r = self.reduce3(x3a)  # [16, 64, 16, 16]
        dual_attention = self.reduce4(dual_attention)  # [16, 64, 8, 8]
        # Coarse-to-fine decoding
        c3 = self.dm3(x3r, dual_attention)  # [16, 64, 16, 16]
        c2 = self.dm2(x2r, c3)              # [16, 64, 32, 32]
        c1 = self.dm1(x1r, c2)              # [16, 64, 64, 64]
        # Per-scale predictions, upsampled to the input resolution
        o3 = self.predictor3(c3)
        o3 = F.interpolate(o3, scale_factor=16, mode='bilinear', align_corners=False)
        o2 = self.predictor2(c2)
        o2 = F.interpolate(o2, scale_factor=8, mode='bilinear', align_corners=False)
        o1 = self.predictor1(c1)
        o1 = F.interpolate(o1, scale_factor=4, mode='bilinear', align_corners=False)
        oe = F.interpolate(edge_att, scale_factor=4, mode='bilinear', align_corners=False)
        return o3, o2, o1, oe
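
The shape comments in `CTO.forward` imply backbone strides of 4, 8, 16 and 32, and the fixed upsampling factors (4, 8, 16) assume every prediction is restored to the input size. The small check below works through that arithmetic for a square input; the helper name `cto_feature_sizes` is hypothetical and not part of the repository. Note also that `GlobalFilter(h=256, w=129)` matches a 256x256 input, since a real FFT keeps 256 // 2 + 1 = 129 frequency columns.

```python
def cto_feature_sizes(input_size):
    """Side lengths of x1..x4 for a square input, assuming strides 4, 8, 16, 32."""
    strides = (4, 8, 16, 32)
    if input_size % strides[-1] != 0:
        raise ValueError("input size should be divisible by 32")
    return [input_size // s for s in strides]

sizes = cto_feature_sizes(256)

# (feature side length, upsampling factor used in CTO.forward) for each output
heads = {"o1": (sizes[0], 4), "o2": (sizes[1], 8), "o3": (sizes[2], 16), "oe": (sizes[0], 4)}
restored = {name: side * factor for name, (side, factor) in heads.items()}

rfft_width = 256 // 2 + 1  # number of frequency columns kept by torch.fft.rfft2
```

All four outputs come back at 256x256, which is why the segmentation and boundary losses can be computed directly against full-resolution masks.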