Spatial Path
The green part of the architecture figure is the Spatial Path. Each layer consists of a stride-2 convolution followed by batch normalization and a ReLU activation; with three such layers in total, the extracted feature map is 1/8 the size of the original image.
'''code'''
import torch
import torch.nn as nn

class SpatialPath(nn.Module):
    def __init__(self, *args, **kwargs):
        super(SpatialPath, self).__init__()
        self.conv1 = ConvBNReLU(3, 64, ks=7, stride=2, padding=3)
        self.conv2 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1)
        self.conv3 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1)
        self.conv_out = ConvBNReLU(64, 128, ks=1, stride=1, padding=0)
        self.init_weight()

    def forward(self, x):                 # x: (N, 3, H, W)
        feat = self.conv1(x)              # (N, 64, H/2, W/2)
        feat = self.conv2(feat)           # (N, 64, H/4, W/4)
        feat = self.conv3(feat)           # (N, 64, H/8, W/8)
        feat = self.conv_out(feat)        # (N, 128, H/8, W/8)
        return feat
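The `ConvBNReLU` block and the `init_weight` method are used throughout this post but are not shown in the excerpts. A minimal sketch consistent with the call sites (a plausible reconstruction assuming a plain Conv2d + BatchNorm2d + ReLU block with Kaiming initialization, not necessarily the reference implementation) would be:
'''code'''
class ConvBNReLU(nn.Module):
    def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1, *args, **kwargs):
        super(ConvBNReLU, self).__init__()
        self.conv = nn.Conv2d(in_chan, out_chan, kernel_size=ks,
                              stride=stride, padding=padding, bias=False)
        self.bn = nn.BatchNorm2d(out_chan)
        self.relu = nn.ReLU(inplace=True)
        self.init_weight()

    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))

    def init_weight(self):
        # Kaiming initialization for all child convolutions
        for ly in self.children():
            if isinstance(ly, nn.Conv2d):
                nn.init.kaiming_normal_(ly.weight, a=1)
                if ly.bias is not None:
                    nn.init.constant_(ly.bias, 0)
The other modules in this post (SpatialPath, ContextPath, ARM, FFM, BiSeNetV1) each call `self.init_weight()` as well; the same loop over `self.children()` can be reused as their method body.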
Context Path
The second dashed box is the Context Path, which extracts contextual information by downsampling with a lightweight model plus global average pooling. The authors append a global average pooling layer to the tail of the lightweight model, which supplies the maximum receptive field along with global context information, and fuse the features of the last two stages with a U-shaped structure (an incomplete U-shape). The paper uses Xception39 as the Context Path backbone; the implementation below substitutes ResNet-18.
'''code'''
class ContextPath(nn.Module):
    def __init__(self, *args, **kwargs):
        super(ContextPath, self).__init__()
        self.resnet = Resnet18()
        self.arm16 = AttentionRefinementModule(256, 128)  # see the ARM code below
        self.arm32 = AttentionRefinementModule(512, 128)
        self.conv_head32 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
        self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
        self.conv_avg = ConvBNReLU(512, 128, ks=1, stride=1, padding=0)
        self.up32 = nn.Upsample(scale_factor=2.)  # 2x upsampling
        self.up16 = nn.Upsample(scale_factor=2.)
        self.init_weight()

    def forward(self, x):
        feat8, feat16, feat32 = self.resnet(x)
        '''
        feat8  : (N, 128, H/8,  W/8)
        feat16 : (N, 256, H/16, W/16)
        feat32 : (N, 512, H/32, W/32)
        '''
        avg = torch.mean(feat32, dim=(2, 3), keepdim=True)  # global average pooling
        avg = self.conv_avg(avg)                   # (N, 128, 1, 1)
        feat32_arm = self.arm32(feat32)            # (N, 128, H/32, W/32)
        feat32_sum = feat32_arm + avg              # (N, 128, H/32, W/32), broadcast add
        feat32_up = self.up32(feat32_sum)          # (N, 128, H/16, W/16)
        feat32_up = self.conv_head32(feat32_up)    # (N, 128, H/16, W/16)
        feat16_arm = self.arm16(feat16)            # (N, 128, H/16, W/16)
        feat16_sum = feat16_arm + feat32_up        # (N, 128, H/16, W/16)
        feat16_up = self.up16(feat16_sum)          # (N, 128, H/8, W/8)
        feat16_up = self.conv_head16(feat16_up)    # (N, 128, H/8, W/8)
        return feat16_up, feat32_up  # x8 feature, x16 feature
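The `Resnet18` backbone wrapper is also external to this excerpt. Any wrapper that returns the stride-8/16/32 feature maps with 128/256/512 channels will do; a hypothetical sketch built on torchvision's `resnet18` (matching the shape comments in `ContextPath.forward`) is:
'''code'''
from torchvision.models import resnet18

class Resnet18(nn.Module):
    # must return features at stride 8/16/32 with 128/256/512 channels
    def __init__(self):
        super(Resnet18, self).__init__()
        net = resnet18()
        self.stem = nn.Sequential(net.conv1, net.bn1, net.relu, net.maxpool)
        self.layer1, self.layer2 = net.layer1, net.layer2
        self.layer3, self.layer4 = net.layer3, net.layer4

    def forward(self, x):
        x = self.stem(x)              # (N, 64, H/4, W/4)
        x = self.layer1(x)            # (N, 64, H/4, W/4)
        feat8 = self.layer2(x)        # (N, 128, H/8, W/8)
        feat16 = self.layer3(feat8)   # (N, 256, H/16, W/16)
        feat32 = self.layer4(feat16)  # (N, 512, H/32, W/32)
        return feat8, feat16, feat32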
AttentionRefinementModule
'''ARM Code'''
class AttentionRefinementModule(nn.Module):  # ARM
    def __init__(self, in_chan, out_chan, *args, **kwargs):
        super(AttentionRefinementModule, self).__init__()
        self.conv = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1)  # a Conv + BN + ReLU block
        self.conv_atten = nn.Conv2d(out_chan, out_chan, kernel_size=1, bias=False)
        self.bn_atten = nn.BatchNorm2d(out_chan)
        self.sigmoid_atten = nn.Sigmoid()
        self.init_weight()

    def forward(self, x):  # suppose in_chan=3, out_chan=64 and the input is (N, 3, 14, 14)
        feat = self.conv(x)                                 # (N, 64, 14, 14)
        atten = torch.mean(feat, dim=(2, 3), keepdim=True)  # (N, 64, 1, 1), global average pooling
        atten = self.conv_atten(atten)                      # (N, 64, 1, 1)
        atten = self.bn_atten(atten)                        # (N, 64, 1, 1)
        atten = self.sigmoid_atten(atten)                   # (N, 64, 1, 1), channel attention weights
        out = torch.mul(feat, atten)                        # (N, 64, 14, 14), re-weight each channel
        return out
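With the channel sizes the Context Path actually uses (512 in, 128 out for the stride-32 feature), a quick sanity check looks like this (assuming each module carries the `init_weight` method sketched earlier):
'''code'''
arm = AttentionRefinementModule(512, 128)
feat32 = torch.randn(2, 512, 32, 32)  # e.g. the stride-32 feature of a 1024x1024 input
out = arm(feat32)
print(out.shape)                      # torch.Size([2, 128, 32, 32])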
FeatureFusionModule
FFM: The authors observe that the Spatial Path encodes rich spatial and detail information, while the Context Path provides a large receptive field and mainly encodes contextual information. In other words, the Spatial Path output is low-level and the Context Path output is high-level, so the features of the two paths sit at different levels of representation. A Feature Fusion Module is therefore proposed to fuse them. Its design follows SENet: a channel-attention mechanism re-weights the features, i.e., feature selection and combination.
'''code'''
class FeatureFusionModule(nn.Module):
    def __init__(self, in_chan, out_chan, *args, **kwargs):
        super(FeatureFusionModule, self).__init__()
        self.convblk = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0)
        self.conv1 = nn.Conv2d(out_chan, out_chan // 4,
                               kernel_size=1, stride=1, padding=0, bias=False)
        self.conv2 = nn.Conv2d(out_chan // 4, out_chan,
                               kernel_size=1, stride=1, padding=0, bias=False)
        self.relu = nn.ReLU(inplace=True)
        self.sigmoid = nn.Sigmoid()
        self.init_weight()

    def forward(self, fsp, fcp):
        fcat = torch.cat([fsp, fcp], dim=1)  # concat the features from Spatial Path and Context Path
        feat = self.convblk(fcat)            # (N, C, H, W)
        atten = torch.mean(feat, dim=(2, 3), keepdim=True)  # (N, C, 1, 1)
        atten = self.conv1(atten)            # (N, C/4, 1, 1), squeeze
        atten = self.relu(atten)
        atten = self.conv2(atten)            # (N, C, 1, 1), excite
        atten = self.sigmoid(atten)
        feat_atten = torch.mul(feat, atten)  # (N, C, H, W), re-weighted features
        feat_out = feat_atten + feat         # residual connection
        return feat_out
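A usage sketch with the shapes BiSeNetV1 actually feeds into the module (two 128-channel inputs concatenated to 256 channels):
'''code'''
ffm = FeatureFusionModule(256, 256)
fsp = torch.randn(2, 128, 64, 64)   # Spatial Path output
fcp = torch.randn(2, 128, 64, 64)   # Context Path x8 output
print(ffm(fsp, fcp).shape)          # torch.Size([2, 256, 64, 64])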
BiSeNet V1
'''code'''
class BiSeNetV1(nn.Module):
    def __init__(self, n_classes, output_aux=True, *args, **kwargs):
        super(BiSeNetV1, self).__init__()
        self.cp = ContextPath()
        self.sp = SpatialPath()
        self.ffm = FeatureFusionModule(256, 256)
        self.conv_out = BiSeNetOutput(256, 256, n_classes, up_factor=8)
        ''' BiSeNetOutput: input (C, h, w) -> output (n_classes, 8h, 8w) '''
        self.output_aux = output_aux
        if self.output_aux:
            self.conv_out16 = BiSeNetOutput(128, 64, n_classes, up_factor=8)
            self.conv_out32 = BiSeNetOutput(128, 64, n_classes, up_factor=16)
        self.init_weight()

    def forward(self, x):
        H, W = x.size()[2:]
        feat_cp8, feat_cp16 = self.cp(x)         # (N, 128, H/8, W/8), (N, 128, H/16, W/16)
        feat_sp = self.sp(x)                     # (N, 128, H/8, W/8)
        feat_fuse = self.ffm(feat_sp, feat_cp8)  # (N, 256, H/8, W/8)
        feat_out = self.conv_out(feat_fuse)      # (N, n_classes, H, W)
        if self.output_aux:
            feat_out16 = self.conv_out16(feat_cp8)   # (N, n_classes, H, W)
            feat_out32 = self.conv_out32(feat_cp16)  # (N, n_classes, H, W)
            return feat_out, feat_out16, feat_out32
        # feat_out = feat_out.argmax(dim=1)
        return feat_out
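`BiSeNetOutput` is the remaining undefined piece. A minimal sketch consistent with the inline comment above (a 3x3 ConvBNReLU, a 1x1 classifier, then upsampling by up_factor; a plausible reconstruction, not necessarily the reference code) is:
'''code'''
class BiSeNetOutput(nn.Module):
    def __init__(self, in_chan, mid_chan, n_classes, up_factor=32, *args, **kwargs):
        super(BiSeNetOutput, self).__init__()
        self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1)
        self.conv_out = nn.Conv2d(mid_chan, n_classes, kernel_size=1, bias=True)
        self.up = nn.Upsample(scale_factor=up_factor, mode='bilinear', align_corners=False)

    def forward(self, x):
        return self.up(self.conv_out(self.conv(x)))  # (N, n_classes, up_factor*h, up_factor*w)
OhemCELoss
Training uses an Online Hard Example Mining (OHEM) cross-entropy loss: per-pixel losses are computed without reduction, only pixels whose loss exceeds -log(thresh) are kept, and if fewer than n_min pixels (1/16 of the valid labels) pass the threshold, the n_min hardest pixels are used instead.
'''code'''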
class OhemCELoss(nn.Module):
    def __init__(self, thresh, ignore_lb=255):
        super(OhemCELoss, self).__init__()
        # keep only pixels whose loss exceeds -log(thresh)
        self.thresh = -torch.log(torch.tensor(thresh, requires_grad=False, dtype=torch.float))  # .cuda()
        self.ignore_lb = ignore_lb
        self.criteria = nn.CrossEntropyLoss(ignore_index=ignore_lb, reduction='none')

    def forward(self, logits, labels):
        n_min = labels[labels != self.ignore_lb].numel() // 16  # always keep at least 1/16 of the valid pixels
        loss = self.criteria(logits, labels).view(-1)
        loss_hard = loss[loss > self.thresh]
        if loss_hard.numel() < n_min:
            loss_hard, _ = loss.topk(n_min)  # fall back to the n_min hardest pixels
        return torch.mean(loss_hard)

criteria_pre = OhemCELoss(0.7)
criteria_aux = [OhemCELoss(0.7) for _ in range(2)]  # losses for feat_out16 and feat_out32
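Putting it all together, a minimal end-to-end training step sketch (assuming the helper modules sketched above, including `init_weight`; 19 classes as in Cityscapes):
'''code'''
net = BiSeNetV1(n_classes=19)
img = torch.randn(2, 3, 256, 256)
lb = torch.randint(0, 19, (2, 256, 256))  # dense labels; 255 would mark ignored pixels

out, out16, out32 = net(img)              # each (2, 19, 256, 256)
loss = criteria_pre(out, lb) + sum(c(o, lb) for c, o in zip(criteria_aux, (out16, out32)))
loss.backward()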