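State-of-the-art semantic segmentation methods are trained almost exclusively on images within a fixed range of resolutions. Their segmentations are inaccurate for very high-resolution images, because bicubic upsampling of a low-resolution segmentation cannot adequately capture high-resolution details along object boundaries. The paper proposes CascadePSP, a novel network that tackles high-resolution segmentation without using any high-resolution training data. It refines and corrects local boundaries wherever possible. Although the network is trained on low-resolution segmentation data, it works equally well on high-resolution images larger than 4K. Quantitative and qualitative studies on different datasets show that CascadePSP produces pixel-accurate segmentation boundaries without any fine-tuning.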
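With 4K UHD (3840×2160) becoming the new industry standard, the resolution of commercial cameras and displays has increased significantly. Despite the demand for high-resolution media, many state-of-the-art computer vision algorithms still struggle with images that have high pixel counts. Semantic image segmentation is one of these tasks: deep semantic segmentation models designed for low-resolution images (e.g. the PASCAL or COCO datasets) usually fail to generalize to higher-resolution scenarios. Specifically, the GPU memory these models use typically grows linearly with the number of pixels, which makes training directly on 4K UHD images practically impossible. At the same time, high-resolution images are hard to annotate, so training data is difficult to obtain. Currently, the two main ways of segmenting 4K images are downsampling and cropping, but downsampling discards fine detail, while cropping destroys the image's global context.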
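To evaluate on very high-resolution images, the authors first annotated a high-resolution dataset, called BIG, with 50 validation and 100 test objects covering the same semantic classes as PASCAL. The model was then tested on the PASCAL VOC 2012, BIG and ADE20K datasets. Experiments show that CascadePSP does not need to be trained on a specific dataset or on the outputs of a specific segmentation model; data augmentation by perturbing the ground truth is sufficient. The CascadePSP model can also be adapted directly to scene parsing, i.e. dense multi-class semantic segmentation.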
1. Refinement Module
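As shown in Figure 2, the refinement module takes an image together with multiple imperfect segmentation masks at different scales and produces a refined segmentation. The multi-scale inputs let the model capture structure and boundary information at different levels, so the network learns to fuse mask features from different scales adaptively and to refine the segmentation at the finest level.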
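All lower-resolution input segmentations are bilinearly upsampled to the same size and concatenated with the RGB image. A PSPNet with a ResNet-50 backbone extracts a stride-8 feature map from this input, and its pyramid pooling uses bin sizes [1, 2, 3, 6], which helps capture global context. Besides the final OS 1 output, the model also produces segmentations at OS 8 and OS 4, skipping OS 2 to leave flexibility for correcting local boundary errors. (OS, the output stride, is the ratio of the input image resolution to the output resolution, so an OS 8 map is 1/8 of the input size along each side.)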
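To make the output strides concrete: training uses 224×224 crops (see the training setup at the end of this section), so the OS 8, OS 4 and OS 1 heads produce 28×28, 56×56 and 224×224 maps. That is where names such as `final_28`, `final_56` and `pred_224` in the code come from. The helper below, `output_sizes`, is a hypothetical name and not part of CascadePSP; it just applies OS = input size / output size.

```python
def output_sizes(crop_size, strides=(8, 4, 1)):
    """Side length of the segmentation map produced at each output stride (OS).

    OS = input size / output size, so an OS-s map of a crop_size x crop_size
    input has crop_size // s pixels per side.
    """
    if any(crop_size % s for s in strides):
        raise ValueError("crop_size must be divisible by every output stride")
    return {s: crop_size // s for s in strides}


print(output_sizes(224))  # {8: 28, 4: 56, 1: 224}
```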
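To recover the pixel-level image detail lost during feature extraction, skip connections from the backbone are used, and the features are fused by upsampling blocks. Each block concatenates the skip-branch features with the bilinearly upsampled features from the main branch and processes them with two ResNet blocks. The segmentation output is produced by two 1×1 convolution layers followed by a sigmoid activation.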
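Because concatenation along the channel axis simply adds channel counts, the decoder's channel arithmetic can be checked without any deep-learning library. The sketch below copies the three arguments of each `PSPUpsample(...)` call in the `RefinementModule` code of this section. It assumes, as those arguments suggest, that the skip features are the standard ResNet-50 outputs at OS 4 (256 channels) and OS 2 (64 channels), followed by the 3-channel RGB image at OS 1. `DECODER` and `trace_channels` are illustrative names only.

```python
# (block, upsampled main-branch channels, skip channels, output channels),
# copied from the PSPUpsample(...) arguments in RefinementModule.
DECODER = [
    ("up_1", 1024, 256, 512),  # OS 8 -> OS 4, skip: ResNet-50 OS 4 feature (assumed 256 ch)
    ("up_2", 512, 64, 256),    # OS 4 -> OS 2, skip: ResNet-50 OS 2 feature (assumed 64 ch)
    ("up_3", 256, 3, 32),      # OS 2 -> OS 1, skip: the RGB image itself
]


def trace_channels(stages, psp_out=1024, rgb=3):
    """Return (block, input channels, output channels) for every decoder step."""
    trace = []
    current = psp_out  # PSPModule(2048, 1024, ...) emits 1024 channels
    for name, main, skip, out in stages:
        if main != current:
            raise ValueError(f"{name}: expected {current} main-branch channels, got {main}")
        trace.append((name, main + skip, out))  # concat adds channel counts
        current = out
    # final_11 / final_21: concatenate RGB again, then two 1x1 convs down to one logit
    trace.append(("final_11", current + rgb, 32))
    trace.append(("final_21", 32, 1))
    return trace


for name, cin, cout in trace_channels(DECODER):
    print(f"{name}: {cin} -> {cout}")
```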
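The best results were obtained with cross-entropy loss for the coarser OS 8 output, L1 + L2 loss for the finer OS 1 output, and the average of the cross-entropy and L1 + L2 losses for OS 4. For better boundary refinement, an L1 loss on the segmentation gradient magnitude is also applied at OS 1. The segmentation gradient is computed with a 3×3 mean filter followed by a Sobel operator.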
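A minimal NumPy sketch of this loss schedule follows. It illustrates the description under stated assumptions and is not the authors' training code: predictions are probabilities already upsampled to the ground-truth size (as the code does for the intermediate outputs with `F.interpolate(..., scale_factor=8)` and `scale_factor=4`), the terms are summed with unit weights, and the 3×3 filters use edge padding. All function names are hypothetical.

```python
import numpy as np

EPS = 1e-7
MEAN_3X3 = np.full((3, 3), 1.0 / 9.0)
SOBEL_X = np.array([[-1.0, 0.0, 1.0],
                    [-2.0, 0.0, 2.0],
                    [-1.0, 0.0, 1.0]])
SOBEL_Y = SOBEL_X.T


def filter3x3(img, kernel):
    """3x3 cross-correlation with edge padding; output has the input's shape."""
    padded = np.pad(img, 1, mode="edge")
    h, w = img.shape
    out = np.zeros((h, w), dtype=float)
    for dy in range(3):
        for dx in range(3):
            out += kernel[dy, dx] * padded[dy:dy + h, dx:dx + w]
    return out


def grad_magnitude(seg):
    """Segmentation gradient: 3x3 mean filter, then Sobel gradient magnitude."""
    smooth = filter3x3(seg, MEAN_3X3)
    gx = filter3x3(smooth, SOBEL_X)
    gy = filter3x3(smooth, SOBEL_Y)
    return np.sqrt(gx ** 2 + gy ** 2)


def bce(pred, gt):
    p = np.clip(pred, EPS, 1.0 - EPS)  # avoid log(0)
    return float(np.mean(-(gt * np.log(p) + (1.0 - gt) * np.log(1.0 - p))))


def l1_l2(pred, gt):
    diff = pred - gt
    return float(np.mean(np.abs(diff)) + np.mean(diff ** 2))


def refinement_loss(pred_os8, pred_os4, pred_os1, gt):
    loss_os8 = bce(pred_os8, gt)                                # coarse: cross-entropy
    loss_os4 = 0.5 * (bce(pred_os4, gt) + l1_l2(pred_os4, gt))  # middle: average of both
    loss_os1 = l1_l2(pred_os1, gt)                              # fine: L1 + L2
    loss_grad = float(np.mean(np.abs(grad_magnitude(pred_os1) - grad_magnitude(gt))))
    return loss_os8 + loss_os4 + loss_os1 + loss_grad
```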
To emphasize the perceptual importance of boundary accuracy, the paper also proposes a new evaluation metric, mean Boundary Accuracy (mBA).
import torch
import torch.nn as nn
import torch.nn.functional as F

# `extractors`, `PSPModule` and `PSPUpsample` are defined elsewhere in the CascadePSP code base
# and are not shown here. extractors.resnet50() accepts the 6-channel input (RGB + 3 segmentation
# maps) and returns the OS 8 feature map f plus the skip features f_1 (OS 2) and f_2 (OS 4).


class RefinementModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.feats = extractors.resnet50()
        # Pyramid pooling with bin sizes 1, 2, 3 and 6
        self.psp = PSPModule(2048, 1024, (1, 2, 3, 6))
        # Upsampling blocks fusing backbone skip features (OS 8 -> 4 -> 2 -> 1)
        self.up_1 = PSPUpsample(1024, 1024+256, 512)
        self.up_2 = PSPUpsample(512, 512+64, 256)
        self.up_3 = PSPUpsample(256, 256+3, 32)
        # Two 1x1 convolutions producing the OS 8 and OS 4 logits
        self.final_28 = nn.Sequential(
            nn.Conv2d(1024, 32, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 1, kernel_size=1),
        )
        self.final_56 = nn.Sequential(
            nn.Conv2d(512, 32, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 1, kernel_size=1),
        )
        # Two 1x1 convolutions producing the OS 1 logits (RGB image concatenated again)
        self.final_11 = nn.Conv2d(32+3, 32, kernel_size=1)
        self.final_21 = nn.Conv2d(32, 1, kernel_size=1)

    def forward(self, x, seg, inter_s8=None, inter_s4=None):
        images = {}

        # First iteration: OS 8 output (skipped when inter_s8 is given)
        if inter_s8 is None:
            p = torch.cat((x, seg, seg, seg), 1)
            f, f_1, f_2 = self.feats(p)
            p = self.psp(f)
            inter_s8 = self.final_28(p)
            r_inter_s8 = F.interpolate(inter_s8, scale_factor=8, mode='bilinear', align_corners=False)
            r_inter_tanh_s8 = torch.tanh(r_inter_s8)
            images['pred_28'] = torch.sigmoid(r_inter_s8)
            images['out_28'] = r_inter_s8
        else:
            r_inter_tanh_s8 = inter_s8

        # Second iteration: OS 8 and OS 4 outputs (skipped when inter_s4 is given)
        if inter_s4 is None:
            p = torch.cat((x, seg, r_inter_tanh_s8, r_inter_tanh_s8), 1)
            f, f_1, f_2 = self.feats(p)
            p = self.psp(f)
            inter_s8_2 = self.final_28(p)
            r_inter_s8_2 = F.interpolate(inter_s8_2, scale_factor=8, mode='bilinear', align_corners=False)
            r_inter_tanh_s8_2 = torch.tanh(r_inter_s8_2)
            p = self.up_1(p, f_2)
            inter_s4 = self.final_56(p)
            r_inter_s4 = F.interpolate(inter_s4, scale_factor=4, mode='bilinear', align_corners=False)
            r_inter_tanh_s4 = torch.tanh(r_inter_s4)
            images['pred_28_2'] = torch.sigmoid(r_inter_s8_2)
            images['out_28_2'] = r_inter_s8_2
            images['pred_56'] = torch.sigmoid(r_inter_s4)
            images['out_56'] = r_inter_s4
        else:
            r_inter_tanh_s8_2 = inter_s8
            r_inter_tanh_s4 = inter_s4

        # Third iteration: OS 1 output (also re-predicts OS 8 and OS 4)
        p = torch.cat((x, seg, r_inter_tanh_s8_2, r_inter_tanh_s4), 1)
        f, f_1, f_2 = self.feats(p)
        p = self.psp(f)
        inter_s8_3 = self.final_28(p)
        r_inter_s8_3 = F.interpolate(inter_s8_3, scale_factor=8, mode='bilinear', align_corners=False)
        p = self.up_1(p, f_2)
        inter_s4_2 = self.final_56(p)
        r_inter_s4_2 = F.interpolate(inter_s4_2, scale_factor=4, mode='bilinear', align_corners=False)
        p = self.up_2(p, f_1)
        p = self.up_3(p, x)

        # Final output
        p = F.relu(self.final_11(torch.cat([p, x], 1)), inplace=True)
        p = self.final_21(p)
        pred_224 = torch.sigmoid(p)

        images['pred_224'] = pred_224
        images['out_224'] = p
        images['pred_28_3'] = torch.sigmoid(r_inter_s8_3)
        images['pred_56_2'] = torch.sigmoid(r_inter_s4_2)
        images['out_28_3'] = r_inter_s8_3
        images['out_56_2'] = r_inter_s4_2
        return images
2. Global and Local Cascade Refinement
Global Step
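The input of the cascade is initialized with the input segmentation, which is replicated to keep the number of input channels fixed. After the first level of the cascade, one of the input channels is replaced by the bilinearly upsampled coarse output. This is repeated up to the last level, whose input contains both the initial segmentation and all outputs of the previous levels.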
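The channel bookkeeping of the global step can be sketched as follows. The slot rule encoded here is inferred from the three `torch.cat` calls in `RefinementModule.forward`, where the fed-back maps are the `tanh` of the upsampled logits. Slot 0 always keeps the initial segmentation. The remaining two slots take the previous level's outputs, coarsest first, and the last one is repeated when there are too few. `cascade_slots` is a hypothetical helper, not part of CascadePSP.

```python
def cascade_slots(initial, prev_outputs, num_slots=3):
    """Segmentation maps concatenated with the RGB image at one cascade level.

    initial: the input segmentation.
    prev_outputs: outputs of the previous level, coarsest first (empty at level 1).
    """
    if not prev_outputs:
        # First level: replicate the input segmentation into every slot
        return (initial,) * num_slots
    rest = list(prev_outputs[: num_slots - 1])
    while len(rest) < num_slots - 1:
        rest.append(rest[-1])  # repeat so the channel count stays fixed
    return (initial, *rest)


levels = [[], ["s8_level1"], ["s8_level2", "s4_level2"]]
for i, prev in enumerate(levels, start=1):
    print(i, cascade_slots("seg", prev))
```

With 3 RGB channels plus these 3 maps, every level sees the same 6-channel input, which is why a single backbone can be reused across levels.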
Local Step
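In the local step, the model takes the two outputs of the last level of the global step and bilinearly resizes both to the original image size W×H. It then takes L×L crops of the image. To avoid boundary artifacts, 16 pixels are trimmed from each side of every crop's output, except on sides that lie on the image border. Crops are taken with a uniform stride of L/2-32, so most pixels are covered by 4 crops; crops that would extend beyond the image border are shifted to align with the last row/column of the image. The image crops are then fed into a 2-level cascade with output strides (OS) 4 and 1. During fusion, the outputs of different patches may disagree because each patch sees different image context; this is resolved by averaging all output values. For even higher-resolution images, the local step can be applied recursively in a coarse-to-fine manner.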
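The crop layout can be reproduced with the pure-Python sketch below, which mirrors the loop in `process_high_res_im`: stride `L//2 - 2*padding`, shifting of crops that overrun the border, de-duplication, and trimming of `padding` pixels on interior sides only. It omits the network call and the code's skipping of nearly uniform crops (under 10% or over 90% foreground). The pixels of skipped crops fall back to the global result unless another crop covers them. `crop_windows` and `coverage` are illustrative names.

```python
def crop_windows(h, w, L=900, padding=16):
    """Return (crop_box, kept_box) pairs as (y0, y1, x0, x1) tuples.

    crop_box is the L x L region fed to the network; kept_box is the part of
    its output written back after trimming `padding` pixels from every side
    that does not touch the image border.
    """
    step = L // 2 - 2 * padding  # L/2 - 32 for the default padding of 16
    seen = set()
    windows = []
    for xi in range(w // step + 1):
        for yi in range(h // step + 1):
            x0, y0 = xi * step, yi * step
            x1, y1 = x0 + L, y0 + L
            # Shift crops that run past the border so they end on the last row/column
            if y1 > h:
                y1, y0 = h, h - L
            if x1 > w:
                x1, x0 = w, w - L
            # Clamp for images smaller than L
            x0, y0 = max(0, x0), max(0, y0)
            x1, y1 = min(w, x1), min(h, y1)
            if (y0, x0) in seen:  # shifting/clamping can produce the same crop twice
                continue
            seen.add((y0, x0))
            ky0 = y0 + padding if y0 != 0 else y0
            kx0 = x0 + padding if x0 != 0 else x0
            ky1 = y1 - padding if y1 != h else y1
            kx1 = x1 - padding if x1 != w else x1
            windows.append(((y0, y1, x0, x1), (ky0, ky1, kx0, kx1)))
    return windows


def coverage(h, w, windows):
    """Count how many kept boxes cover each pixel."""
    counts = [[0] * w for _ in range(h)]
    for _, (y0, y1, x0, x1) in windows:
        for y in range(y0, y1):
            row = counts[y]
            for x in range(x0, x1):
                row[x] += 1
    return counts


wins = crop_windows(200, 200, L=100, padding=8)
cov = coverage(200, 200, wins)
print(len(wins), min(min(r) for r in cov))  # 16 1
```

The small sizes in the usage line only keep the pure-Python coverage count fast; the default `L=900, padding=16` gives the stride L/2-32 described above.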
import torch
import torch.nn.functional as F


def resize_max_side(im, size, method):
    # Rescale so that the longer side equals `size`, keeping the aspect ratio
    h, w = im.shape[-2:]
    max_side = max(h, w)
    ratio = size / max_side
    if method in ['bilinear', 'bicubic']:
        return F.interpolate(im, scale_factor=ratio, mode=method, align_corners=False)
    return F.interpolate(im, scale_factor=ratio, mode=method)


def safe_forward(model, im, seg, inter_s8=None, inter_s4=None):
    # Slightly pads the input so that its height and width are multiples of 8
    # (padded segmentation pixels are set to -1, i.e. background)
    b, _, ph, pw = seg.shape
    if (ph % 8 != 0) or (pw % 8 != 0):
        newH = ((ph//8+1)*8)
        newW = ((pw//8+1)*8)
        p_im = torch.zeros(b, 3, newH, newW, device=im.device)
        p_seg = torch.zeros(b, 1, newH, newW, device=im.device) - 1
        p_im[:,:,0:ph,0:pw] = im
        p_seg[:,:,0:ph,0:pw] = seg
        im = p_im
        seg = p_seg
        if inter_s8 is not None:
            p_inter_s8 = torch.zeros(b, 1, newH, newW, device=im.device) - 1
            p_inter_s8[:,:,0:ph,0:pw] = inter_s8
            inter_s8 = p_inter_s8
        if inter_s4 is not None:
            p_inter_s4 = torch.zeros(b, 1, newH, newW, device=im.device) - 1
            p_inter_s4[:,:,0:ph,0:pw] = inter_s4
            inter_s4 = p_inter_s4

    images = model(im, seg, inter_s8, inter_s4)
    return_im = {}
    # Crop the outputs back to the unpadded size
    for key in ['pred_224', 'pred_28_3', 'pred_56_2']:
        return_im[key] = images[key][:,:,0:ph,0:pw]
    del images
    return return_im


def process_high_res_im(model, im, seg, L=900):
    stride = L//2
    _, _, h, w = seg.shape

    # Global step: run the full cascade on a copy whose longer side is L
    if max(h, w) > L:
        im_small = resize_max_side(im, L, 'area')
        seg_small = resize_max_side(seg, L, 'area')
    elif max(h, w) < L:
        im_small = resize_max_side(im, L, 'bicubic')
        seg_small = resize_max_side(seg, L, 'bilinear')
    else:
        im_small = im
        seg_small = seg

    images = safe_forward(model, im_small, seg_small)
    pred_224 = images['pred_224']
    pred_56 = images['pred_56_2']

    # Local step: refine L x L crops (here a single pass at the original resolution)
    for new_size in [max(h, w)]:
        im_small = resize_max_side(im, new_size, 'area')
        seg_small = resize_max_side(seg, new_size, 'area')
        _, _, h, w = seg_small.shape

        combined_224 = torch.zeros_like(seg_small)
        combined_weight = torch.zeros_like(seg_small)

        # Global outputs resized to this resolution and mapped to [-1, 1]
        r_pred_224 = (F.interpolate(pred_224, size=(h, w), mode='bilinear', align_corners=False)>0.5).float()*2-1
        r_pred_56 = F.interpolate(pred_56, size=(h, w), mode='bilinear', align_corners=False)*2-1

        padding = 16
        step_size = stride - padding*2
        step_len = L

        used_start_idx = {}
        for x_idx in range((w)//step_size+1):
            for y_idx in range((h)//step_size+1):
                start_x = x_idx * step_size
                start_y = y_idx * step_size
                end_x = start_x + step_len
                end_y = start_y + step_len

                # Shift when required
                if end_y > h:
                    end_y = h
                    start_y = h - step_len
                if end_x > w:
                    end_x = w
                    start_x = w - step_len

                # Bound x/y range
                start_x = max(0, start_x)
                start_y = max(0, start_y)
                end_x = min(w, end_x)
                end_y = min(h, end_y)

                # The same crop might appear twice due to bounding/shifting
                start_idx = start_y*w + start_x
                if start_idx in used_start_idx:
                    continue
                used_start_idx[start_idx] = True

                # Take crop
                im_part = im_small[:,:,start_y:end_y, start_x:end_x]
                seg_224_part = r_pred_224[:,:,start_y:end_y, start_x:end_x]
                seg_56_part = r_pred_56[:,:,start_y:end_y, start_x:end_x]

                # Skip nearly uniform crops (under 10% or over 90% foreground)
                seg_part_norm = (seg_224_part>0).float()
                high_thres = 0.9
                low_thres = 0.1
                if (seg_part_norm.mean() > high_thres) or (seg_part_norm.mean() < low_thres):
                    continue

                grid_images = safe_forward(model, im_part, seg_224_part, seg_56_part)
                grid_pred_224 = grid_images['pred_224']

                # Trim `padding` pixels on every side that is not on the image border
                pred_sx = pred_sy = 0
                pred_ex = step_len
                pred_ey = step_len
                if start_x != 0:
                    start_x += padding
                    pred_sx += padding
                if start_y != 0:
                    start_y += padding
                    pred_sy += padding
                if end_x != w:
                    end_x -= padding
                    pred_ex -= padding
                if end_y != h:
                    end_y -= padding
                    pred_ey -= padding

                combined_224[:,:,start_y:end_y, start_x:end_x] += grid_pred_224[:,:,pred_sy:pred_ey,pred_sx:pred_ex]
                del grid_pred_224
                # Used for averaging
                combined_weight[:,:,start_y:end_y, start_x:end_x] += 1

        # Final full-resolution output: average overlapping crops and fall back
        # to the thresholded global result where no crop was processed
        seg_norm = (r_pred_224/2+0.5)
        pred_224 = combined_224 / combined_weight
        pred_224 = torch.where(combined_weight==0, seg_norm, pred_224)

    _, _, h, w = seg.shape
    images = {}
    images['pred_224'] = F.interpolate(pred_224, size=(h, w), mode='bilinear', align_corners=True)
    return images['pred_224']


def process_im_single_pass(model, im, seg, L=900):
    """A single-pass version, i.e. the global step only."""
    _, _, h, w = im.shape
    if max(h, w) < L:
        im = resize_max_side(im, L, 'bicubic')
        seg = resize_max_side(seg, L, 'bilinear')
    if max(h, w) > L:
        im = resize_max_side(im, L, 'area')
        seg = resize_max_side(seg, L, 'area')

    images = safe_forward(model, im, seg)

    # Resize the prediction back to the original resolution
    if max(h, w) < L:
        images['pred_224'] = F.interpolate(images['pred_224'], size=(h, w), mode='area')
    elif max(h, w) > L:
        images['pred_224'] = F.interpolate(images['pred_224'], size=(h, w), mode='bilinear', align_corners=True)
    return images['pred_224']
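The fusion at the end of `process_high_res_im` works in four steps. It accumulates each trimmed crop prediction into `combined_224`, counts contributions in `combined_weight`, divides, and falls back to the global result wherever no crop contributed. A NumPy sketch of that averaging follows, under the hypothetical name `fuse_crops`. It uses `np.divide(..., where=...)` to avoid dividing by zero, whereas the PyTorch code divides first and lets `torch.where` overwrite the resulting NaNs. Note that the code's fallback is the thresholded global mask (`r_pred_224/2+0.5`), while this sketch accepts any global probability map.

```python
import numpy as np


def fuse_crops(global_prob, crop_preds):
    """Average overlapping crop predictions; use the global map where uncovered.

    global_prob: (H, W) global-step result at full resolution.
    crop_preds: list of ((y0, y1, x0, x1), pred) where pred is already
                trimmed to its kept box.
    """
    acc = np.zeros_like(global_prob, dtype=float)
    weight = np.zeros_like(global_prob, dtype=float)
    for (y0, y1, x0, x1), pred in crop_preds:
        acc[y0:y1, x0:x1] += pred
        weight[y0:y1, x0:x1] += 1
    fused = np.divide(acc, weight, out=np.zeros_like(acc), where=weight > 0)
    return np.where(weight == 0, global_prob, fused)


g = np.full((4, 4), 0.2)
out = fuse_crops(g, [((0, 2, 0, 3), np.ones((2, 3))), ((0, 2, 1, 4), np.zeros((2, 3)))])
print(out[0])  # overlapping columns average to 0.5
```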
Datasets: PASCAL VOC 2012, BIG (the high-resolution dataset introduced in the paper), ADE20K
Baseline: PSPNet with ResNet-50
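Training: 224×224 image crops are sampled at random, and the input segmentations are generated by perturbing the ground truth. The perturbation method is shown in the figure below: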