[Semantic Segmentation Series] PointRend Source Code Annotations

I'm a beginner here, so corrections from more experienced readers are welcome if I've misunderstood anything. The walkthrough below annotates the semantic segmentation pipeline DeepLabV3 + PointRend. The prose describes Xception65 as the DeepLabV3 backbone; note that the runnable code in section 5 uses ResNet variants instead.

Schematic: read the walkthrough below first, then come back to this figure; the main flow of the code should then be much clearer. Note, though, that the schematic does not correspond exactly to the code.

In the code, the fine-grained feature map is 1/4 the size of the input image, whereas in the figure it is the same size as the input. Everything downstream of that is the same.

(Figure 1: PointRend schematic)

1. Why PointRend was proposed:

A conventional semantic segmentation network, after a series of convolutions and poolings, produces a feature map at some reduced resolution, typically 1/8, 1/16, or 1/32 of the input image. Every point on that map already carries a class label, i.e., we know which class each pixel belongs to. The map is then upsampled back to the input resolution to obtain the final segmentation, and as you might expect, object boundaries come out inaccurate after upsampling. PointRend's job is to repair those boundaries: it ranks the points on the feature map by uncertainty according to a fixed rule, picks out the N most unstable points (those whose class assignment is ambiguous, typically on messy boundaries), and re-predicts them. In other words, the method is a refinement step applied on top of an existing segmentation result.
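
To make the uncertainty ranking concrete, here is a minimal toy sketch (my own illustration, not code from the repository) of the score used throughout: the negative gap between the two highest class scores at each pixel, with topk then picking the most uncertain positions.

import torch

# Toy per-class logits on a 4x4 "feature map": [B, num_classes, H, W].
logits = torch.randn(1, 21, 4, 4)

top2, _ = logits.topk(2, dim=1)            # two highest class scores per pixel
uncertainty = -(top2[:, 0] - top2[:, 1])   # small gap => high uncertainty

# Indices of the 5 most uncertain pixels (flattened over H*W).
_, idx = uncertainty.view(1, -1).topk(5, dim=1)
print(idx.shape)                           # torch.Size([1, 5])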

2. PointRend training flow:

a. Select N points on the feature map. During training this is a biased random sampling (see sampling_points in the full source below): k*N candidate points are drawn at random, the beta*N most uncertain of them are kept, and the remaining (1 - beta)*N are drawn uniformly for coverage. In the code, N = x.shape[-1] // 16 at training time; the figure of 8096 points applies to inference.

The corresponding code: points = sampling_points(out, x.shape[-1] // 16, self.k, self.beta)

b. Extract the features of those N points from the backbone's feature maps.

Take Xception65 as the backbone, for example. The network outputs c1, c2, c3, c4, where c1 is a higher-resolution feature map (1/4 of the input) and c4 is the final feature map (1/16). The features at the N sampled points are extracted from both of these maps.

The corresponding code:
coarse = point_sample(out, points, align_corners=False)
fine = point_sample(res2, points, align_corners=False)
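
point_sample (full definition in section 5) is a thin wrapper over F.grid_sample that accepts coordinates normalized to [0, 1] instead of [-1, 1]. A minimal sketch of what the wrapper does, on a toy feature map of my own:

import torch
import torch.nn.functional as F

feat = torch.arange(16, dtype=torch.float).view(1, 1, 4, 4)  # [B, C, H, W]
points = torch.tensor([[[0.5, 0.5], [0.25, 0.75]]])          # [B, P, 2] in [0, 1]

# grid_sample wants [-1, 1] coords and a 4D grid, hence the rescale and unsqueeze.
grid = 2.0 * points.unsqueeze(2) - 1.0                       # [B, P, 1, 2]
sampled = F.grid_sample(feat, grid, align_corners=False).squeeze(3)
print(sampled.shape)                                         # torch.Size([1, 1, 2])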

c. Concatenate the per-point features, implemented with torch.cat. For example, if the coarse features are [1, 19, 8096] and the fine features are [1, 256, 8096], the result is [1, 275, 8096].

The corresponding code: feature_representation = torch.cat([coarse, fine], dim=1)

d. Use an MLP to make the refined per-point prediction, as sketched below.

The corresponding code: rend = self.mlp(feature_representation)
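
Steps c and d in one toy sketch (shapes follow the section-4 defaults of 19 classes and 256-channel fine features; the runnable code in section 5 uses 21 classes and 512-channel res2 features, hence in_c=533 there). Note the "MLP" is just a 1x1 Conv1d, i.e., a pointwise linear layer shared across all sampled points:

import torch
import torch.nn as nn

P = 8096                                  # number of sampled points
coarse = torch.randn(1, 19, P)            # per-point coarse class logits
fine = torch.randn(1, 256, P)             # per-point fine-grained features

feats = torch.cat([coarse, fine], dim=1)  # [1, 275, P]

mlp = nn.Conv1d(275, 19, 1)               # 1x1 conv == per-point MLP
rend = mlp(feats)
print(rend.shape)                         # torch.Size([1, 19, 8096])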

3. PointRend inference flow:

This differs from the training code; the details are covered in the annotated key code below.

4. Annotated key PointRend code:

class PointHead(nn.Module):
    def __init__(self, in_c=275, num_classes=19, k=3, beta=0.75):
        super().__init__()
        self.mlp = nn.Conv1d(in_c, num_classes, 1)
        self.k = k
        self.beta = beta

    def forward(self, x, res2, out):
        """
        1. Fine-grained features are interpolated from res2 for DeeplabV3
        2. During training we sample as many points as there are on a stride 16 feature map of the input
        3. To measure prediction uncertainty
           we use the same strategy during training and inference: the difference between the most
           confident and second most confident class probabilities.
        """
        if not self.training:
            return self.inference(x, res2, out)

        points = sampling_points(out, x.shape[-1] // 16, self.k, self.beta)  # select the coordinates of the sampled points

        coarse = point_sample(out, points, align_corners=False)  # sample the coarse C4 features: high-level (deep) features
        fine = point_sample(res2, points, align_corners=False)   # sample the fine C1 features: low-level (shallow) features

        feature_representation = torch.cat([coarse, fine], dim=1)  # concatenate coarse and fine features

        rend = self.mlp(feature_representation)  # MLP prediction: each sampled point gets reassigned a class

        return {"rend": rend, "points": points}

    @torch.no_grad()
    def inference(self, x, res2, out):
        """
        During inference, subdivision uses N=8096
        (i.e., the number of points in the stride 16 map of a 1024×2048 image)
        """
        num_points = 8096
        # Here the input `out` is the coarse classification result: the high-level
        # features passed through the final per-class convolution, i.e. a rough
        # segmentation map of shape [1, 21, h, w], where 21 is the number of classes
        # and h, w are the input dimensions after pooling. The loop below repeatedly
        # upsamples `out`, selects its most uncertain points, re-predicts them with
        # the MLP, and overwrites the old values in `out` with the new predictions,
        # until `out` reaches the input resolution.
        while out.shape[-1] != x.shape[-1]:  # stop once the small map `out` has been upsampled to the input size
            out = F.interpolate(out, scale_factor=2, mode="bilinear", align_corners=True)  # first upsample `out` by a factor of 2

            points_idx, points = sampling_points(out, num_points, training=self.training)  # pick the most uncertain points on `out`

            coarse = point_sample(out, points, align_corners=False)  # as in training: sample the uncertain points on the high-level features
            fine = point_sample(res2, points, align_corners=False)   # as in training: sample the uncertain points on the low-level features

            feature_representation = torch.cat([coarse, fine], dim=1)  # concatenate features

            rend = self.mlp(feature_representation)  # as in training; rend is [1, 21, 8096]: 21 classes, 8096 points

            B, C, H, W = out.shape
            points_idx = points_idx.unsqueeze(1).expand(-1, C, -1)

            # scatter_ writes the re-predicted class values back into `out` at the
            # uncertain points: out[b, c, points_idx[b, c, p]] = rend[b, c, p]
            out = (out.reshape(B, C, -1)
                      .scatter_(2, points_idx, rend)
                      .view(B, C, H, W))

            
        return {"fine": out}
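
For anyone else unsure about scatter_, here is a minimal toy demonstration (my own, with made-up shapes) of exactly the pattern above: along dim 2, out[b, c, idx[b, c, p]] = rend[b, c, p].

import torch

out = torch.zeros(1, 2, 6)                       # flattened coarse map, [B, C, H*W]
rend = torch.tensor([[[1.0, 2.0], [3.0, 4.0]]])  # new per-point logits, [B, C, P]
points_idx = torch.tensor([[1, 4]])              # flat pixel positions, [B, P]

idx = points_idx.unsqueeze(1).expand(-1, 2, -1)  # broadcast over channels: [B, C, P]
out.scatter_(2, idx, rend)                       # write rend into out at those positions
print(out)
# tensor([[[0., 1., 0., 0., 2., 0.],
#          [0., 3., 0., 0., 4., 0.]]])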

5. Complete runnable PointRend source:

The code consists of three files.

Run it with:

python pointrend.py

1. This code goes in deeplab.py.

from torchvision.models._utils import IntermediateLayerGetter
from torchvision.models.segmentation._utils import _SimpleSegmentationModel
from torchvision.models.segmentation.deeplabv3 import DeepLabHead
from torchvision.models import resnet50, resnet101
from torchvision.models.resnet import ResNet, Bottleneck

import torch.nn as nn


class ResNetXX3(ResNet):
    def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None):
        super().__init__(block, layers, num_classes, zero_init_residual,
                         groups, width_per_group, replace_stride_with_dilation,
                         norm_layer)
        self.conv1 = nn.Conv2d(3, 64, 3, 1, 1, bias=False)
        nn.init.kaiming_normal_(self.conv1.weight, mode='fan_out', nonlinearity='relu')


def resnet53(pretrained=False, progress=True, **kwargs):
    r"""ResNet-50 model from
    `"Deep Residual Learning for Image Recognition" `_
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return ResNetXX3(Bottleneck, [3, 4, 6, 3], **kwargs)


def resnet103(pretrained=False, progress=True, **kwargs):
    r"""ResNet-101 model from
    `"Deep Residual Learning for Image Recognition" `_
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return ResNetXX3(Bottleneck, [3, 4, 23, 3], **kwargs)


class SmallDeepLab(_SimpleSegmentationModel):
    def forward(self, input_):
        result = self.backbone(input_)
        result["coarse"] = self.classifier(result["out"])
        return result


def deeplabv3(pretrained=False, resnet="res103", head_in_ch=2048, num_classes=21):
    resnet = {
        "res53":  resnet53,
        "res103": resnet103,
        "res50":  resnet50,
        "res101": resnet101
    }[resnet]

    net = SmallDeepLab(  # IntermediateLayerGetter grabs layer2 and layer4 of the resnet and renames them 'res2' and 'out'
        backbone=IntermediateLayerGetter(
            resnet(pretrained=False, replace_stride_with_dilation=[False, True, True]),
            return_layers={'layer2': 'res2', 'layer4': 'out'}
        ),
        classifier=DeepLabHead(head_in_ch, num_classes)
    )
    return net


if __name__ == "__main__":
    import torch
    x = torch.randn(3, 3, 512, 1024).cuda()
    net = deeplabv3(False).cuda()
    result = net(x)
    for k, v in result.items():
        print(k, v.shape)

2. This code goes in sampling_points.py.

import torch
import torch.nn.functional as F


def point_sample(input, point_coords, **kwargs):
    """
    From Detectron2, point_features.py#19

    A wrapper around :function:`torch.nn.functional.grid_sample` to support 3D point_coords tensors.
    Unlike :function:`torch.nn.functional.grid_sample` it assumes `point_coords` to lie inside
    [0, 1] x [0, 1] square.

    Args:
        input (Tensor): A tensor of shape (N, C, H, W) that contains feature maps on an H x W grid.
        point_coords (Tensor): A tensor of shape (N, P, 2) or (N, Hgrid, Wgrid, 2) that contains
        [0, 1] x [0, 1] normalized point coordinates.

    Returns:
        output (Tensor): A tensor of shape (N, C, P) or (N, C, Hgrid, Wgrid) that contains
            features for points in `point_coords`. The features are obtained via bilinear
            interpolation from `input`, the same way as :function:`torch.nn.functional.grid_sample`.
    """
    add_dim = False
    if point_coords.dim() == 3:
        add_dim = True
        point_coords = point_coords.unsqueeze(2)
    output = F.grid_sample(input, 2.0 * point_coords - 1.0, **kwargs)
    if add_dim:
        output = output.squeeze(3)
    return output


@torch.no_grad()
def sampling_points(mask, N, k=3, beta=0.75, training=True):
    """
    Follows 3.1. Point Selection for Inference and Training

    In Train:, `The sampling strategy selects N points on a feature map to train on.`

    In Inference, `then selects the N most uncertain points`

    Args:
        mask(Tensor): [B, C, H, W]
        N(int): `During training we sample as many points as there are on a stride 16 feature map of the input`
        k(int): Over generation multiplier
        beta(float): ratio of importance points
        training(bool): flag

    Return:
        During training: sampled point coords, a [B, N, 2] Tensor.
        During inference: a tuple (flat indices [B, N], point coords [B, N, 2]).
    """
    assert mask.dim() == 4, "Dim must be N(Batch)CHW"
    device = mask.device
    B, _, H, W = mask.shape
    mask, _ = mask.sort(1, descending=True)

    if not training:
        H_step, W_step = 1 / H, 1 / W
        N = min(H * W, N)
        uncertainty_map = -1 * (mask[:, 0] - mask[:, 1])
        _, idx = uncertainty_map.view(B, -1).topk(N, dim=1)

        points = torch.zeros(B, N, 2, dtype=torch.float, device=device)
        points[:, :, 0] = W_step / 2.0 + (idx  % W).to(torch.float) * W_step
        points[:, :, 1] = H_step / 2.0 + (idx // W).to(torch.float) * H_step
        return idx, points

    # Official Comment : point_features.py#92
    # It is crucial to calculate uncertainty based on the sampled prediction values for the points.
    # Calculating uncertainties of the coarse predictions first and sampling them for points leads
    # to worse results. To illustrate the difference: a sampled point between two coarse predictions
    # with -1 and 1 logits has a 0 logit prediction and therefore 0 uncertainty value; however, if one
    # calculates uncertainties for the coarse predictions first (-1 and -1) and samples them for the
    # center point, they will get -1 uncertainty.

    over_generation = torch.rand(B, k * N, 2, device=device)
    over_generation_map = point_sample(mask, over_generation, align_corners=False)

    uncertainty_map = -1 * (over_generation_map[:, 0] - over_generation_map[:, 1])
    _, idx = uncertainty_map.topk(int(beta * N), -1)

    shift = (k * N) * torch.arange(B, dtype=torch.long, device=device)

    idx += shift[:, None]

    importance = over_generation.view(-1, 2)[idx.view(-1), :].view(B, int(beta * N), 2)
    coverage = torch.rand(B, N - int(beta * N), 2, device=device)
    return torch.cat([importance, coverage], 1).to(device)
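
Before moving on, a quick shape sanity check for sampling_points in inference mode (a throwaway snippet of mine, not part of the three files; it assumes sampling_points.py is on your path). With training=False it returns both the flat indices and the normalized coordinates of the N most uncertain points:

import torch
from sampling_points import sampling_points

logits = torch.randn(1, 21, 16, 32)  # coarse per-class logits, [B, C, H, W]
idx, points = sampling_points(logits, 64, training=False)
print(idx.shape, points.shape)       # torch.Size([1, 64]) torch.Size([1, 64, 2])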

3. This code goes in pointrend.py.


import torch
import torch.nn as nn
import torch.nn.functional as F

from sampling_points import sampling_points, point_sample


class PointHead(nn.Module):
    def __init__(self, in_c=533, num_classes=21, k=3, beta=0.75):
        super().__init__()
        self.mlp = nn.Conv1d(in_c, num_classes, 1)
        self.k = k
        self.beta = beta

    def forward(self, x, res2, out):
        """
        1. Fine-grained features are interpolated from res2 for DeeplabV3
        2. During training we sample as many points as there are on a stride 16 feature map of the input
        3. To measure prediction uncertainty
           we use the same strategy during training and inference: the difference between the most
           confident and second most confident class probabilities.
        """
        self.training = False  # force the inference path for this demo (normally you would call net.eval() instead)
        if not self.training:
            return self.inference(x, res2, out)

        points = sampling_points(out, x.shape[-1] // 16, self.k, self.beta)
        #print("points", points.shape)  [3, 32, 2]   32 points

        coarse = point_sample(out, points, align_corners=False)
        fine = point_sample(res2, points, align_corners=False)

        feature_representation = torch.cat([coarse, fine], dim=1)
        print("feature_representation = ", feature_representation.shape)
        rend = self.mlp(feature_representation)  # input shape 533 x 32, output shape 21 x 32

        return {"rend": rend, "points": points}

    @torch.no_grad()
    def inference(self, x, res2, out):
        """
        During inference, subdivision uses N=8096
        (i.e., the number of points in the stride 16 map of a 1024×2048 image)
        """
        num_points = 8096
        
        print("x = ", x.shape)
        print(" res2 = ", res2.shape)

        while out.shape[-1] != x.shape[-1]:
            out = F.interpolate(out, scale_factor=2, mode="bilinear", align_corners=True)
            print("out old = ", out.shape)
            points_idx, points = sampling_points(out, num_points, training=self.training)

            coarse = point_sample(out, points, align_corners=False)
            fine = point_sample(res2, points, align_corners=False)

            feature_representation = torch.cat([coarse, fine], dim=1)

            rend = self.mlp(feature_representation)

            B, C, H, W = out.shape
            points_idx = points_idx.unsqueeze(1).expand(-1, C, -1)
            out = (out.reshape(B, C, -1)
                      .scatter_(2, points_idx, rend)
                      .view(B, C, H, W))
            print("out new = ", out.shape)

        return {"fine": out}


class PointRend(nn.Module):
    def __init__(self, backbone, head):
        super().__init__()
        self.backbone = backbone
        self.head = head

    def forward(self, x):
        result = self.backbone(x)
        print("x = ", x.shape)

        #print("result : %s" %  result)
        result.update(self.head(x, result["res2"], result["coarse"]))
        return result


if __name__ == "__main__":
    x = torch.randn(3, 3, 256, 512)
    from deeplab import deeplabv3
    net = PointRend(deeplabv3(False), PointHead())
    out = net(x)
   
    for k, v in out.items():
        print("=========")
        print(k, v.shape)
