小白一个,理解错误欢迎大佬指正。下面的流程按语义分割框架deeplabv3 + PointRend做的注释。deeplabv3 的主干网络是xception65。
传统的语义分割网络在进行一系列卷积和池化之后,会得到一定分辨率的 feature map。这个 feature map 的大小一般为原图的 1/8、1/16 或 1/32 等,其上的每个点都已经有了类别预测,即知道了某个位置归属于哪一类。然后通过一定的上采样方法将其恢复到原图大小,就得到了原图的语义分割结果。可以想象,上采样后的物体边缘会有不准确的情况,PointRend 就是要修正这些边缘:它将 feature map 上的点按照一定规则做不确定性(不稳定性)排序,找出最不稳定的 N 个点(认为其类别归属不明确、处于边界混乱处),然后对这些点进行精修。可见,这个方法是建立在某种语义分割结果之上的工作。
具体代码为:points = sampling_points(out, x.shape[-1] // 16, self.k, self.beta)
具体代码为: coarse = point_sample(out, points, align_corners=False)
fine = point_sample(res2, points, align_corners=False)
c. 将 N 个点对应位置的特征拼接(粘合)到一起,由 torch.cat 函数实现。例如 C1 的特征是 [1, 19, 8096],C2 的特征是 [1, 256, 8096],那么拼接结果就是 [1, 275, 8096](19 + 256 = 275)。
具体代码为: feature_representation = torch.cat([coarse, fine], dim=1)
具体代码为: rend = self.mlp(feature_representation)
class PointHead(nn.Module):
    """PointRend point head (annotated walkthrough version).

    Re-predicts class scores at selected points: the coarse prediction ``out``
    and the low-level feature map ``res2`` are both sampled at the same points,
    concatenated along the channel axis and classified point by point with
    ``self.mlp`` (a ``Conv1d`` with kernel size 1).

    Args:
        in_c: input channels of the point MLP. Must equal the channel count of
            ``out`` plus that of ``res2`` (e.g. 275 = 19 classes + 256 ``res2``
            channels -- confirm against the backbone in use).
        num_classes: number of classes predicted for each point.
        k: over-generation multiplier passed to ``sampling_points`` in training.
        beta: fraction of importance (most uncertain) points in training.
    """

    def __init__(self, in_c=275, num_classes=19, k=3, beta=0.75):
        # nn.Module.__init__ must run before any sub-module is assigned,
        # otherwise registering self.mlp raises AttributeError.
        super().__init__()
        self.mlp = nn.Conv1d(in_c, num_classes, 1)
        self.k = k
        self.beta = beta

    def forward(self, x, res2, out):
        """Predict scores at sampled points (training) or refine ``out`` (eval).

        1. Fine-grained features are interpolated from ``res2`` (DeepLabV3).
        2. During training, points are chosen on ``out`` with the
           importance/coverage strategy of ``sampling_points``.
        3. Prediction uncertainty is measured the same way in training and
           inference: the difference between the most confident and the
           second most confident class scores.

        NOTE(review): the paper samples "as many points as there are on a
        stride 16 feature map of the input", i.e. (H // 16) * (W // 16); this
        code uses ``x.shape[-1] // 16`` (input width / 16). Confirm intended.

        Args:
            x: input image batch; only its spatial size is read.
            res2: low-level feature map providing fine-grained features.
            out: coarse per-class prediction.

        Returns:
            Training mode: ``{"rend": per-point class scores,
            "points": sampled (x, y) coordinates in [0, 1]}``.
            Eval mode: the result of :meth:`inference`.
        """
        if not self.training:
            return self.inference(x, res2, out)

        # Locations of the points to train on.
        points = sampling_points(out, x.shape[-1] // 16, self.k, self.beta)
        # Coarse, high-level (deep) features at those points.
        coarse = point_sample(out, points, align_corners=False)
        # Fine-grained, low-level (shallow) features at the same points.
        fine = point_sample(res2, points, align_corners=False)
        # Concatenate along channels: C_out + C_res2 must equal in_c.
        feature_representation = torch.cat([coarse, fine], dim=1)
        # Point-wise MLP: assigns class scores to every sampled point.
        rend = self.mlp(feature_representation)
        return {"rend": rend, "points": points}

    def inference(self, x, res2, out):
        """Refine ``out`` by adaptive subdivision up to the width of ``x``.

        ``out`` is the coarse prediction (per-class scores at reduced
        resolution). Each step upsamples it 2x bilinearly, selects the
        ``num_points`` most uncertain locations, re-predicts them with the
        point MLP and overwrites their scores in ``out``.

        The paper uses N=8096 "(i.e., the number of points in the stride 16
        map of a 1024x2048 image)"; note that 1024 * 2048 / 16**2 = 8192.

        Returns:
            ``{"fine": refined prediction}``, upsampled until its width
            reaches that of ``x``.
        """
        num_points = 8096
        target_h, target_w = x.shape[-2], x.shape[-1]

        # Loop while `out` is narrower than the input. Each step's size is
        # clamped to the input size, so the loop terminates even when the size
        # ratio is not a power of two (the former `!=` test looped forever).
        while out.shape[-1] < target_w:
            size = (min(out.shape[-2] * 2, target_h),
                    min(out.shape[-1] * 2, target_w))
            # With align_corners=True the result depends only on the output
            # size, so this matches scale_factor=2 whenever nothing is clamped.
            out = F.interpolate(out, size=size, mode="bilinear", align_corners=True)
            # Inference-mode sampling returns (flat indices, coordinates) of the
            # most uncertain cells; request it explicitly so calling this
            # method while the module is in training mode still works.
            points_idx, points = sampling_points(out, num_points, training=False)
            coarse = point_sample(out, points, align_corners=False)
            fine = point_sample(res2, points, align_corners=False)
            feature_representation = torch.cat([coarse, fine], dim=1)
            # New class scores for the selected points, shape [B, num_classes, P].
            rend = self.mlp(feature_representation)

            B, C, H, W = out.shape
            # Repeat each flat index across the C channels so that scatter_
            # performs out_flat[b, c, points_idx[b, c, j]] = rend[b, c, j],
            # i.e. replaces the whole score vector at every selected location.
            points_idx = points_idx.unsqueeze(1).expand(-1, C, -1)
            out = (out.reshape(B, C, -1)
                   .scatter_(2, points_idx, rend)
                   .view(B, C, H, W))
        return {"fine": out}
python pointrend.py
from collections import OrderedDict
from torchvision.models._utils import IntermediateLayerGetter
from torchvision.models.utils import load_state_dict_from_url
from torchvision.models.segmentation._utils import _SimpleSegmentationModel
from torchvision.models.segmentation.deeplabv3 import DeepLabHead
from torchvision.models.segmentation.fcn import FCNHead
#from .resnet import resnet103, resnet53
from torchvision.models import resnet50, resnet101
from torchvision.models.resnet import ResNet, Bottleneck
import torch.nn as nn
class ResNetXX3(ResNet):
    """torchvision ``ResNet`` whose stem convolution is 3x3 with stride 1.

    Everything is inherited from ``ResNet`` except ``conv1``, which is
    replaced by a 3 -> 64 channel, 3x3, stride-1, padding-1 convolution, so
    the stem convolution itself does not downsample.
    """

    def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None):
        # Forward by keyword so a reordering of ResNet's optional parameters
        # cannot silently shift arguments into the wrong slots.
        super().__init__(block, layers,
                         num_classes=num_classes,
                         zero_init_residual=zero_init_residual,
                         groups=groups,
                         width_per_group=width_per_group,
                         replace_stride_with_dilation=replace_stride_with_dilation,
                         norm_layer=norm_layer)
        # Replacement stem: 3 -> 64 channels, kernel 3, stride 1, padding 1.
        self.conv1 = nn.Conv2d(3, 64, 3, 1, 1, bias=False)
        # Initialise the new conv the way torchvision initialises ResNet convs.
        nn.init.kaiming_normal_(self.conv1.weight, mode='fan_out', nonlinearity='relu')
def resnet53(pretrained=False, progress=True, **kwargs):
    r"""ResNet-50-layout :class:`ResNetXX3` (3x3 stride-1 stem).

    Builds ``Bottleneck`` blocks with ``[3, 4, 6, 3]`` layers, the ResNet-50
    layout from `"Deep Residual Learning for Image Recognition"
    <https://arxiv.org/pdf/1512.03385.pdf>`_.

    Args:
        pretrained (bool): accepted for signature compatibility with
            torchvision's ``resnet50`` but ignored -- no pretrained weights
            are loaded; the model is always randomly initialised.
        progress (bool): ignored for the same reason.
        **kwargs: forwarded to :class:`ResNetXX3`.
    """
    return ResNetXX3(Bottleneck, [3, 4, 6, 3], **kwargs)
def resnet103(pretrained=False, progress=True, **kwargs):
    r"""ResNet-101-layout :class:`ResNetXX3` (3x3 stride-1 stem).

    Builds ``Bottleneck`` blocks with ``[3, 4, 23, 3]`` layers, the ResNet-101
    layout from `"Deep Residual Learning for Image Recognition"
    <https://arxiv.org/pdf/1512.03385.pdf>`_.

    Args:
        pretrained (bool): accepted for signature compatibility with
            torchvision's ``resnet101`` but ignored -- no pretrained weights
            are loaded; the model is always randomly initialised.
        progress (bool): ignored for the same reason.
        **kwargs: forwarded to :class:`ResNetXX3`.
    """
    return ResNetXX3(Bottleneck, [3, 4, 23, 3], **kwargs)
class SmallDeepLab(_SimpleSegmentationModel):
    """Segmentation model that returns all backbone features plus a coarse map.

    ``forward`` runs the backbone, applies the classifier to its ``"out"``
    entry and stores that prediction under ``"coarse"`` in the same dict.
    No resizing is performed here.
    """

    def forward(self, input_):
        features = self.backbone(input_)
        coarse_prediction = self.classifier(features["out"])
        features["coarse"] = coarse_prediction
        return features
def deeplabv3(pretrained=False, resnet="res103", head_in_ch=2048, num_classes=21):
    """Build a DeepLabV3 model that also exposes ``layer2`` features.

    Args:
        pretrained: accepted for API compatibility but currently ignored --
            the backbone is always built with ``pretrained=False``.
        resnet: backbone key, one of "res53", "res103", "res50", "res101".
        head_in_ch: input channels of the DeepLab head; must match the
            channel count of the backbone's ``layer4`` output.
        num_classes: number of segmentation classes.

    Returns:
        A :class:`SmallDeepLab` whose forward returns a dict with "res2"
        (``layer2`` output), "out" (``layer4`` output) and "coarse"
        (head prediction on "out").

    Raises:
        KeyError: if ``resnet`` is not a supported key.
    """
    backbones = {
        "res53": resnet53,
        "res103": resnet103,
        "res50": resnet50,
        "res101": resnet101,
    }
    try:
        backbone_fn = backbones[resnet]
    except KeyError:
        raise KeyError("unknown resnet {!r}; expected one of {}".format(
            resnet, sorted(backbones))) from None

    # Replace the strides of layer3 and layer4 with dilation.
    backbone = backbone_fn(pretrained=False,
                           replace_stride_with_dilation=[False, True, True])
    net = SmallDeepLab(
        # IntermediateLayerGetter returns the outputs of layer2 and layer4,
        # renamed to "res2" and "out".
        backbone=IntermediateLayerGetter(
            backbone, return_layers={'layer2': 'res2', 'layer4': 'out'}),
        classifier=DeepLabHead(head_in_ch, num_classes),
    )
    return net
if __name__ == "__main__":
    import torch

    # Smoke test: build the model and print the shape of every output.
    # Fall back to CPU so the demo also runs on machines without CUDA.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    x = torch.randn(3, 3, 512, 1024).to(device)
    net = deeplabv3(False).to(device)
    result = net(x)
    for k, v in result.items():
        print(k, v.shape)
import torch
import torch.nn.functional as F
def point_sample(input, point_coords, **kwargs):
    """Sample ``input`` at normalized point coordinates.

    From Detectron2, point_features.py#19. A wrapper around
    :func:`torch.nn.functional.grid_sample` that also supports 3D
    ``point_coords`` tensors. Unlike ``grid_sample`` it assumes
    ``point_coords`` lie inside the [0, 1] x [0, 1] square.

    Args:
        input (Tensor): shape (N, C, H, W), features on an H x W grid.
        point_coords (Tensor): shape (N, P, 2) or (N, Hgrid, Wgrid, 2) with
            [0, 1] x [0, 1] normalized (x, y) point coordinates.
        **kwargs: forwarded to ``grid_sample`` (e.g. ``align_corners``).

    Returns:
        Tensor: shape (N, C, P) or (N, C, Hgrid, Wgrid) with the features at
        ``point_coords``, bilinearly interpolated from ``input`` the same way
        as ``grid_sample``.
    """
    add_dim = point_coords.dim() == 3
    if add_dim:
        # grid_sample needs a 4D grid: treat the P points as a P x 1 grid.
        point_coords = point_coords.unsqueeze(2)
    # Map [0, 1] coordinates to grid_sample's [-1, 1] convention.
    output = F.grid_sample(input, 2.0 * point_coords - 1.0, **kwargs)
    if add_dim:
        output = output.squeeze(3)
    return output


def sampling_points(mask, N, k=3, beta=0.75, training=True):
    """Select points (paper 3.1, "Point Selection for Inference and Training").

    In training: "The sampling strategy selects N points on a feature map to
    train on." ``k * N`` random points are over-generated; the ``int(beta * N)``
    most uncertain are kept and the rest are uniformly random coverage points.
    In inference: "then selects the N most uncertain points" of the grid.

    Uncertainty is ``-(top1 - top2)`` of the class scores, so it is largest
    where the two best classes are closest.

    Args:
        mask (Tensor): [B, C, H, W] class scores, C >= 2.
        N (int): number of points. In training the paper uses "as many points
            as there are on a stride 16 feature map of the input".
        k (int): over-generation multiplier (training only).
        beta (float): ratio of importance points (training only).
        training (bool): selects the training or the inference strategy.

    Returns:
        Training: Tensor [B, N, 2] of (x, y) coordinates in [0, 1]; the first
        ``int(beta * N)`` are importance points, the rest coverage points.
        Inference: tuple ``(idx, points)`` where ``idx`` [B, N'] holds flat
        indices into the H*W grid and ``points`` [B, N', 2] the matching
        cell-center (x, y) coordinates, with N' = min(N, H * W).
    """
    assert mask.dim() == 4, "Dim must be N(Batch)CHW"
    device = mask.device
    B, _, H, W = mask.shape

    if not training:
        H_step, W_step = 1 / H, 1 / W
        N = min(H * W, N)
        # Only the two best scores per cell are needed.
        top2, _ = mask.topk(2, dim=1)
        uncertainty_map = -1 * (top2[:, 0] - top2[:, 1])
        _, idx = uncertainty_map.view(B, -1).topk(N, dim=1)
        points = torch.zeros(B, N, 2, dtype=torch.float, device=device)
        # Flat index -> (column, row) -> normalized cell-center (x, y).
        points[:, :, 0] = W_step / 2.0 + (idx % W).to(torch.float) * W_step
        points[:, :, 1] = H_step / 2.0 + (idx // W).to(torch.float) * H_step
        return idx, points

    # Official Comment : point_features.py#92
    # It is crucial to calculate uncertanty based on the sampled prediction value for the points.
    # Calculating uncertainties of the coarse predictions first and sampling them for points leads
    # to worse results. To illustrate the difference: a sampled point between two coarse predictions
    # with -1 and 1 logits has 0 logit prediction and therefore 0 uncertainty value, however, if one
    # calculates uncertainties for the coarse predictions first (-1 and -1) and sampe it for the
    # center point, they will get -1 unceratinty.
    # Accordingly the raw scores are sampled first and ranked afterwards
    # (sorting classes per cell before sampling would violate the above).
    over_generation = torch.rand(B, k * N, 2, device=device)
    over_generation_map = point_sample(mask, over_generation, align_corners=False)  # [B, C, kN]
    top2, _ = over_generation_map.topk(2, dim=1)
    uncertainty_map = -1 * (top2[:, 0] - top2[:, 1])  # [B, kN]
    num_important = int(beta * N)
    _, idx = uncertainty_map.topk(num_important, dim=-1)
    # Pick the (x, y) pairs of the most uncertain over-generated points.
    importance = torch.gather(over_generation, 1, idx.unsqueeze(-1).expand(-1, -1, 2))
    coverage = torch.rand(B, N - num_important, 2, device=device)
    return torch.cat([importance, coverage], 1)
import torch
import torch.nn as nn
import torch.nn.functional as F
from sampling_points import sampling_points, point_sample
class PointHead(nn.Module):
    """PointRend point head.

    Re-predicts class scores at selected points: the coarse prediction ``out``
    and the low-level feature map ``res2`` are both sampled at the same points,
    concatenated along the channel axis and classified point by point with
    ``self.mlp`` (a ``Conv1d`` with kernel size 1).

    Args:
        in_c: input channels of the point MLP. Must equal the channel count of
            ``out`` plus that of ``res2`` (e.g. 533 = 21 classes + 512 ``res2``
            channels -- confirm against the backbone in use).
        num_classes: number of classes predicted for each point.
        k: over-generation multiplier passed to ``sampling_points`` in training.
        beta: fraction of importance (most uncertain) points in training.
    """

    def __init__(self, in_c=533, num_classes=21, k=3, beta=0.75):
        # nn.Module.__init__ must run before any sub-module is assigned,
        # otherwise registering self.mlp raises AttributeError.
        super().__init__()
        self.mlp = nn.Conv1d(in_c, num_classes, 1)
        self.k = k
        self.beta = beta

    def forward(self, x, res2, out):
        """Predict scores at sampled points (training) or refine ``out`` (eval).

        1. Fine-grained features are interpolated from ``res2`` (DeepLabV3).
        2. During training, points are chosen on ``out`` with the
           importance/coverage strategy of ``sampling_points``.
        3. Prediction uncertainty is measured the same way in training and
           inference: the difference between the most confident and the
           second most confident class scores.

        The mode follows ``self.training`` (set via ``train()`` / ``eval()``).

        NOTE(review): the paper samples "as many points as there are on a
        stride 16 feature map of the input", i.e. (H // 16) * (W // 16); this
        code uses ``x.shape[-1] // 16`` (e.g. 32 points for a 512-wide input).
        Confirm intended.

        Returns:
            Training mode: ``{"rend": per-point class scores,
            "points": sampled (x, y) coordinates in [0, 1]}``.
            Eval mode: the result of :meth:`inference`.
        """
        if not self.training:
            return self.inference(x, res2, out)

        points = sampling_points(out, x.shape[-1] // 16, self.k, self.beta)
        coarse = point_sample(out, points, align_corners=False)
        fine = point_sample(res2, points, align_corners=False)
        # Concatenate along channels: C_out + C_res2 must equal in_c.
        feature_representation = torch.cat([coarse, fine], dim=1)
        # [B, in_c, P] -> [B, num_classes, P].
        rend = self.mlp(feature_representation)
        return {"rend": rend, "points": points}

    def inference(self, x, res2, out):
        """Refine ``out`` by adaptive subdivision up to the width of ``x``.

        Each step upsamples the coarse prediction 2x bilinearly, selects the
        ``num_points`` most uncertain locations, re-predicts them with the
        point MLP and overwrites their scores in ``out``.

        The paper uses N=8096 "(i.e., the number of points in the stride 16
        map of a 1024x2048 image)"; note that 1024 * 2048 / 16**2 = 8192.

        Returns:
            ``{"fine": refined prediction}``, upsampled until its width
            reaches that of ``x``.
        """
        num_points = 8096
        target_h, target_w = x.shape[-2], x.shape[-1]

        # Loop while `out` is narrower than the input. Each step's size is
        # clamped to the input size, so the loop terminates even when the size
        # ratio is not a power of two (the former `!=` test looped forever).
        while out.shape[-1] < target_w:
            size = (min(out.shape[-2] * 2, target_h),
                    min(out.shape[-1] * 2, target_w))
            # With align_corners=True the result depends only on the output
            # size, so this matches scale_factor=2 whenever nothing is clamped.
            out = F.interpolate(out, size=size, mode="bilinear", align_corners=True)
            # Inference-mode sampling returns (flat indices, coordinates); request
            # it explicitly so this method also works if called in training mode.
            points_idx, points = sampling_points(out, num_points, training=False)
            coarse = point_sample(out, points, align_corners=False)
            fine = point_sample(res2, points, align_corners=False)
            feature_representation = torch.cat([coarse, fine], dim=1)
            rend = self.mlp(feature_representation)

            B, C, H, W = out.shape
            # Repeat each flat index across the C channels so that scatter_
            # performs out_flat[b, c, points_idx[b, c, j]] = rend[b, c, j].
            points_idx = points_idx.unsqueeze(1).expand(-1, C, -1)
            out = (out.reshape(B, C, -1)
                   .scatter_(2, points_idx, rend)
                   .view(B, C, H, W))
        return {"fine": out}
class PointRend(nn.Module):
    """Segmentation backbone followed by a PointRend point head.

    Args:
        backbone: callable mapping the input batch to a dict that contains at
            least "res2" (fine-grained features) and "coarse" (coarse
            prediction).
        head: callable ``head(x, res2, coarse)`` returning a dict whose
            entries are merged into the backbone's output dict.
    """

    def __init__(self, backbone, head):
        # Must run before attributes are assigned (nn.Module bookkeeping);
        # without it both assignment of sub-modules and calling fail.
        super().__init__()
        self.backbone = backbone
        self.head = head

    def forward(self, x):
        """Return the backbone's output dict updated with the head's outputs."""
        result = self.backbone(x)
        result.update(self.head(x, result["res2"], result["coarse"]))
        return result
if __name__ == "__main__":
    # Smoke test on random input: build DeepLabV3 + PointHead and print the
    # shape of every output. eval() selects PointHead's subdivision-inference
    # path, which produces the refined "fine" prediction.
    x = torch.randn(3, 3, 256, 512)
    from deeplab import deeplabv3
    net = PointRend(deeplabv3(False), PointHead())
    net.eval()
    with torch.no_grad():
        out = net(x)
    for k, v in out.items():
        print(k, v.shape)