具体为:先根据粗糙预测出来的mask,在每个像素上把各类别的预测分数从高到低排序,取分数最高的前2个类别;计算这2个类别的分数之差,差值最小(即在这2个类别上得分都较高)的Top N个像素点作为N个不确定点【1个像素点只能对应1个类别,如果它对应2个类别的分数都很高,说明它很可能是边界点,也是不确定的】
@torch.no_grad()
def sampling_points(mask, N, k=3, beta=0.75, training=True):
    """Pick the points where a coarse segmentation is least certain (PointRend).

    The uncertainty of a location is the negated margin between its best and
    second-best class scores. A location that scores high for two classes at
    once is most likely on an object boundary, so it is the most uncertain.

    Runs under ``torch.no_grad()``: the result is a set of indices /
    coordinates, so no autograd graph is needed for the scoring.

    :param mask: coarse prediction (``out``) of shape ``[B, C, H, W]`` with
        ``C >= 2``, e.g. ``[2, 19, 48, 48]``.
    :param N: number of points to sample.
    :param k: over-generation factor (training only): ``k * N`` uniformly
        random candidate points are scored.
    :param beta: importance ratio (training only): the ``int(beta * N)`` most
        uncertain candidates are kept and the remaining ``N - int(beta * N)``
        points are drawn uniformly at random for coverage.
    :param training: select the training (random, importance-sampled) or the
        inference (deterministic, pixel-grid) strategy.
    :return:
        * ``training=False``: ``(idx, points)``. ``idx`` is ``[B, N']`` flat
          indices into the ``H * W`` grid. ``points`` is ``[B, N', 2]``
          normalized ``(x, y)`` pixel-centre coordinates, with
          ``N' = min(N, H * W)``.
        * ``training=True``: ``points`` of shape ``[B, N, 2]`` holding
          normalized ``(x, y)`` coordinates in ``[0, 1)``.
    """
    assert mask.dim() == 4, "Dim must be N(Batch)CHW"
    device = mask.device
    B, _, H, W = mask.shape

    if not training:
        # Best and runner-up score of every pixel (sorted along the class axis).
        top2 = mask.topk(2, dim=1).values  # [B, 2, H, W]
        uncertainty_map = -1 * (top2[:, 0] - top2[:, 1])
        N = min(H * W, N)
        _, idx = uncertainty_map.reshape(B, -1).topk(N, dim=1)

        # Flat index -> normalized centre of the corresponding pixel.
        H_step, W_step = 1 / H, 1 / W
        points = torch.zeros(B, N, 2, dtype=torch.float, device=device)
        points[:, :, 0] = W_step / 2.0 + (idx % W).to(torch.float) * W_step  # x
        points[:, :, 1] = H_step / 2.0 + (idx // W).to(torch.float) * H_step  # y
        return idx, points

    # Training: importance sampling. Uncertainty must be measured on the
    # predictions interpolated at each candidate point. Interpolating a
    # per-pixel uncertainty map gives different (worse) values between pixels.
    num_important = int(beta * N)
    candidates = torch.rand(B, k * N, 2, device=device)
    # grid_sample wants a [B, n, 1, 2] grid in [-1, 1] (same maths as point_sample).
    grid = (2.0 * candidates - 1.0).unsqueeze(2).to(mask.dtype)
    sampled = F.grid_sample(mask, grid, align_corners=False)[..., 0]  # [B, C, k*N]
    top2 = sampled.topk(2, dim=1).values  # [B, 2, k*N]
    uncertainty = -1 * (top2[:, 0] - top2[:, 1])  # [B, k*N]
    _, idx = uncertainty.topk(num_important, dim=1)  # [B, num_important]
    importance = torch.gather(candidates, 1, idx.unsqueeze(-1).expand(-1, -1, 2))
    coverage = torch.rand(B, N - num_important, 2, device=device)
    return torch.cat([importance, coverage], dim=1)
def forward(self, x, res2, out):
    """Predict refined class scores at the most uncertain points of ``out``.

    ``out`` (the coarse prediction) selects the uncertain points. The coarse
    scores and the ``res2`` features are sampled at those points,
    concatenated, and classified by ``self.mlp``. In eval mode this
    delegates to ``self.inference``.

    Notes from the PointRend paper:

    1. Fine-grained features are interpolated from ``res2`` for DeepLabV3.
    2. During training we sample as many points as there are on a stride 16
       feature map of the input.
    3. To measure prediction uncertainty we use the same strategy during
       training and inference: the difference between the most confident
       and second most confident class probabilities.

    :param x: input image batch; only its spatial size is read here,
        e.g. ``[2, 3, 768, 768]``.
    :param res2: fine-grained (low-level) features, e.g. ``[2, 256, 192, 192]``
        (per the original notes, the first-stage Xception output).
    :param out: coarse prediction from the atrous-convolution head,
        e.g. ``[2, 19, 48, 48]``.
    :return: in training mode ``{"rend": point predictions from self.mlp,
        "points": [B, N, 2] normalized point coordinates}``; in eval mode the
        result of ``self.inference``.
    """
    if not self.training:
        return self.inference(x, res2, out)

    # Point 2 above: one point per cell of a stride-16 map of the input
    # (e.g. 48 * 48 = 2304 for a 768 x 768 input), not just W // 16.
    num_points = (x.shape[-2] // 16) * (x.shape[-1] // 16)
    points = sampling_points(out, num_points, self.k, self.beta)  # [B, N, 2]
    coarse = point_sample(out, points, align_corners=False)  # [B, C, N]
    fine = point_sample(res2, points, align_corners=False)  # [B, C_res2, N]
    feature_representation = torch.cat([coarse, fine], dim=1)  # [B, C + C_res2, N]
    rend = self.mlp(feature_representation)  # e.g. [B, 19, N] per the original notes
    return {"rend": rend, "points": points}
def inference(self, x, res2, out):
    """Upsample ``out`` towards the input resolution, refining it at each step.

    Each subdivision step does the following:

    1. Double the resolution of the prediction, clamped to the size of ``x``.
    2. Pick its ``num_points`` most uncertain pixels.
    3. Re-predict those pixels from coarse + fine features with ``self.mlp``.
    4. Write the new scores back into the upsampled map.

    During inference, subdivision uses N=8192, i.e. the number of points in
    the stride 16 map of a 1024x2048 image (64 * 128 = 8192).

    :param x: input image batch; only its spatial size is read,
        e.g. ``[1, 3, 768, 768]``.
    :param res2: fine-grained features, e.g. ``[1, 256, 192, 192]``
        (per the original notes, the first-stage Xception output).
    :param out: coarse prediction, e.g. ``[1, 19, 48, 48]``.
    :return: ``{"fine": refined prediction}``. It is left unchanged if ``out``
        is already at least as wide as ``x``.
    """
    num_points = 8192
    # "<" instead of "!=": the old test never terminated when repeated doubling
    # could not land exactly on the input width.
    while out.shape[-1] < x.shape[-1]:
        size = (min(out.shape[-2] * 2, x.shape[-2]),
                min(out.shape[-1] * 2, x.shape[-1]))
        out = F.interpolate(out, size=size, mode="bilinear", align_corners=True)
        # Always request the inference form, which returns (idx, points).
        points_idx, points = sampling_points(out, num_points, training=False)
        coarse = point_sample(out, points, align_corners=False)  # [B, C, N]
        fine = point_sample(res2, points, align_corners=False)  # [B, C_res2, N]
        feature_representation = torch.cat([coarse, fine], dim=1)
        rend = self.mlp(feature_representation)  # e.g. [B, 19, N] per the original notes
        B, C, H, W = out.shape
        points_idx = points_idx.unsqueeze(1).expand(-1, C, -1)  # [B, C, N]
        # Overwrite the selected pixels of the flattened map with the refined scores.
        out = (out.reshape(B, C, -1)
               .scatter_(2, points_idx, rend)
               .view(B, C, H, W))
    return {"fine": out}
class PointRendLoss(nn.CrossEntropyLoss):
    """Training loss for PointRend: dense coarse loss plus point loss.

    ``forward(result, gt)`` expects ``result`` to be a dict containing
    (example shapes are from the original notes):

    * ``"coarse"``: coarse logits, e.g. ``[2, 19, 48, 48]``;
    * ``"rend"``: logits predicted at the sampled points, ``[B, C, N]``;
    * ``"points"``: normalized ``(x, y)`` point coordinates in ``[0, 1]``,
      ``[B, N, 2]``.

    ``gt`` is the integer label map ``[B, H, W]``, e.g. ``[2, 768, 768]``.
    Pixels labelled ``ignore_index`` are ignored by both loss terms.
    """

    def __init__(self, aux=True, aux_weight=0.2, ignore_index=-1, **kwargs):
        super(PointRendLoss, self).__init__(ignore_index=ignore_index)
        # NOTE(review): ``aux`` / ``aux_weight`` are stored but ``forward``
        # computes no auxiliary term; extra ``kwargs`` are ignored.
        self.aux = aux
        self.aux_weight = aux_weight
        self.ignore_index = ignore_index

    def forward(self, *inputs, **kwargs):
        """Return ``{"loss": seg_loss + points_loss}`` for ``(result, gt)``."""
        result, gt = tuple(inputs)
        # Dense term: coarse logits upsampled to the label resolution.
        pred = F.interpolate(result["coarse"], gt.shape[-2:], mode="bilinear", align_corners=True)
        seg_loss = F.cross_entropy(pred, gt, ignore_index=self.ignore_index)

        # Point term: look up the label under every sampled point. Nearest
        # sampling keeps labels integral (bilinear would blend class ids).
        gt_map = gt.unsqueeze(1).float()  # [B, 1, H, W]
        grid = (2.0 * result["points"] - 1.0).unsqueeze(2).to(gt_map.dtype)  # [B, N, 1, 2]
        gt_points = F.grid_sample(gt_map, grid, mode="nearest", align_corners=False)
        gt_points = gt_points[:, 0, :, 0].long()  # [B, N]
        points_loss = F.cross_entropy(result["rend"], gt_points, ignore_index=self.ignore_index)

        loss = seg_loss + points_loss
        return dict(loss=loss)