[SOLO] SOLO: Segmenting Objects by Locations, code reading notes (ECCV 2020)

Segmenting Objects by Locations

If this helps you, a like would be appreciated~

Contents

  • SOLO head network structure
  • Loss function
  • Positive sample selection
  • 1. SOLO/mmdet/models/detectors/single_stage_ins.py
  • 2. SOLO/mmdet/models/anchor_heads/solo_head.py
  • 3. SOLO/mmdet/core/post_processing/matrix_nms.py
  • 4. SOLO/configs/solo/solo_r50_fpn_8gpu_1x.py
  • 5. SOLO/mmdet/models/anchor_heads/__init__.py

SOLO head network structure

[Figure: SOLO head network structure]

Loss function

[Figure: SOLO loss function]
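The mask branch in this code is trained with a Dice loss (the `dice_loss` helper in solo_head.py, listed in section 2). Below is a rough, framework-free sketch of what that helper computes for a single instance. It assumes flattened Python lists instead of tensors; the name `dice_loss_1d` and the toy masks are illustrative and do not come from the repository.

```python
def dice_loss_1d(pred, target, eps=0.001):
    """Dice loss for one instance.

    pred:   flat list of probabilities in [0, 1] (after sigmoid)
    target: flat list of 0/1 ground-truth mask values
    eps:    small constant added to both squared sums, as in the listing
    """
    a = sum(p * t for p, t in zip(pred, target))   # overlap term
    b = sum(p * p for p in pred) + eps             # squared prediction mass
    c = sum(t * t for t in target) + eps           # squared target mass
    return 1.0 - (2.0 * a) / (b + c)


# A perfect prediction gives a loss close to 0, a disjoint one gives 1.
loss_perfect = dice_loss_1d([1.0, 1.0, 0.0, 0.0], [1, 1, 0, 0])
loss_disjoint = dice_loss_1d([0.0, 0.0, 1.0, 1.0], [1, 1, 0, 0])
print(loss_perfect, loss_disjoint)
```

Because the loss is a ratio of overlap to total mass, it does not depend on how many background pixels the mask contains, which is presumably why it suits small instances on large feature maps.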

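Positive sample selection

Quoted from the paper:
[Figure: excerpt from the paper on positive sample selection]
I did not fully understand this passage on first reading. Going through the code gave me a new understanding of how positive samples are selected and how the scheme fits a fully convolutional network. It is a good case of practice meeting theory: reading the code is a way to reflect on the paper's ideas and understand them more deeply.
FCOS and PolarMask also use a form of center sampling. As those papers point out, a fully convolutional network could treat every point inside the gt box as a positive example, but that is computationally expensive, and points close to the box border regress poorly. Sampling positives around a center (SOLO uses the center of mass) is therefore a reasonable choice.
The figure below is borrowed from an excellent reposted blog post (blog link).
In the original image, the blue boxes are the grid cells the image is divided into, here 5×5. The green box is the gt box of the target object, the yellow box is that box shrunk to 0.2 times its size, and the red boxes mark the grid cells responsible for predicting the instance.
The black-and-white images underneath visualize the targets of the mask branch; the channels are tiled together for display. The first image on the left contains one instance whose gt box, shrunk to 0.2 times its size, covers two grid cells, so those two cells are responsible for predicting it.
In the mask branch below it, only two FPN outputs are matched to this instance, so the channels corresponding to the red cells predict its mask. The second image contains instances of different sizes, and the mask branches of the FPN outputs handle the instances from small to large scale.
[Figure: grid cells, gt boxes and mask-branch targets for two example images]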
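The cell selection described above can be sketched in plain Python. The arithmetic mirrors the `solo_target_single` code listed in section 2 (center of mass, a box shrunk by `sigma`, and a limit of one cell around the center cell). The function name `positive_cells` and the toy numbers are hypothetical, and the channel index `i * S + j` follows the paper's convention rather than anything shown in this excerpt.

```python
def positive_cells(center_h, center_w, box_h, box_w, img_h, img_w,
                   num_grid, sigma=0.2):
    """Return the (row, col) grid cells responsible for one instance.

    center_h, center_w: center of mass of the gt mask, in input pixels
    box_h, box_w:       height and width of the gt box
    img_h, img_w:       padded input size (upsampled_size in the listing)
    """
    half_h = 0.5 * box_h * sigma
    half_w = 0.5 * box_w * sigma

    def to_cell(value, size):
        # same arithmetic as the listing: normalized coordinate // cell width
        return int((value / size) // (1.0 / num_grid))

    coord_h = to_cell(center_h, img_h)
    coord_w = to_cell(center_w, img_w)

    # cells covered by the shrunk box, clipped to the grid
    top_box = max(0, to_cell(center_h - half_h, img_h))
    down_box = min(num_grid - 1, to_cell(center_h + half_h, img_h))
    left_box = max(0, to_cell(center_w - half_w, img_w))
    right_box = min(num_grid - 1, to_cell(center_w + half_w, img_w))

    # never extend more than one cell away from the center cell
    top = max(top_box, coord_h - 1)
    down = min(down_box, coord_h + 1)
    left = max(left_box, coord_w - 1)
    right = min(right_box, coord_w + 1)

    return [(i, j) for i in range(top, down + 1)
            for j in range(left, right + 1)]


# Toy instance on an 800 x 1216 input with a 5 x 5 grid.
cells = positive_cells(center_h=400, center_w=700, box_h=400, box_w=500,
                       img_h=800, img_w=1216, num_grid=5)
channels = [i * 5 + j for i, j in cells]
print(cells, channels)
```

With these numbers the shrunk box straddles a vertical cell border, so two neighbouring cells become positive, the same situation as in the first example image.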

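The figure below, taken from the original, shows clearly how an instance is routed to a prediction channel according to its gt_areas and the grid cell it falls in. First, gt_areas decides which FPN level each gt is assigned to. Then, within one level, if there are several instances, each is mapped onto the predefined grid: a region centered on the gt's center of mass and 0.2 times the gt size (with gt_areas rescaled to that level's output feature map) determines the cells that predict it.
[Figure: assignment of instances to FPN levels and grid channels]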
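The level assignment can be sketched in a few lines. In the code, `gt_areas` is the square root of the box area, and an instance goes to every level whose scale range contains that value. The ranges below are the ones noted in the `scale_ranges` comment of the listing in section 2; the helper name `levels_for_box` is hypothetical.

```python
import math

# (lower_bound, upper_bound) per FPN level, as noted in the SOLOHead listing
SCALE_RANGES = ((1, 96), (48, 192), (96, 384), (192, 768), (384, 2048))


def levels_for_box(x1, y1, x2, y2, scale_ranges=SCALE_RANGES):
    """Return the indices of the FPN levels that predict this gt box."""
    scale = math.sqrt((x2 - x1) * (y2 - y1))  # "gt_areas" in the listing
    return [lvl for lvl, (lo, hi) in enumerate(scale_ranges)
            if lo <= scale <= hi]


small = levels_for_box(0, 0, 30, 30)      # scale 30
medium = levels_for_box(0, 0, 100, 100)   # scale 100
large = levels_for_box(0, 0, 500, 500)    # scale 500
print(small, medium, large)
```

Since neighbouring ranges overlap, a mid-sized instance can be a positive sample on two levels at once.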

1. SOLO/mmdet/models/detectors/single_stage_ins.py

single_stage_ins.py wires together the backbone (ResNet), the neck (FPN) and the head (solo_head), and implements their forward pass.

import torch.nn as nn

from mmdet.core import bbox2result
from .. import builder
from ..registry import DETECTORS
from .base import BaseDetector
import pdb

@DETECTORS.register_module
class SingleStageInsDetector(BaseDetector):

    def __init__(self,
                 backbone,
                 neck=None,
                 bbox_head=None,
                 mask_feat_head=None,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None):
        super(SingleStageInsDetector, self).__init__()
        self.backbone = builder.build_backbone(backbone) # 1.build_backbone --> resnet
        if neck is not None:
            self.neck = builder.build_neck(neck) # 2.build_neck --> fpn
        if mask_feat_head is not None:
            self.mask_feat_head = builder.build_head(mask_feat_head)
        #pdb.set_trace()

        self.bbox_head = builder.build_head(bbox_head) # 3.build_head --> solo head

        self.train_cfg = train_cfg
        self.test_cfg = test_cfg
        self.init_weights(pretrained=pretrained) # 'torchvision://resnet50'

    def init_weights(self, pretrained=None):
        super(SingleStageInsDetector, self).init_weights(pretrained)
        self.backbone.init_weights(pretrained=pretrained)
        if self.with_neck:
            if isinstance(self.neck, nn.Sequential):
                for m in self.neck:
                    m.init_weights()
            else:
                self.neck.init_weights()
        if self.with_mask_feat_head:
            if isinstance(self.mask_feat_head, nn.Sequential):
                for m in self.mask_feat_head:
                    m.init_weights()
            else:
                self.mask_feat_head.init_weights()
        #pdb.set_trace()
        self.bbox_head.init_weights()

    # forward pass: extract the backbone and neck features
    def extract_feat(self, img):
        x = self.backbone(img) # resnet forward        
        if self.with_neck:
            x = self.neck(x) # fpn forward
        return x
    '''
    after neck feature map:x
        (Pdb) x[0].shape
        torch.Size([2, 256, 200, 304])
        (Pdb) x[1].shape
        torch.Size([2, 256, 100, 152])
        (Pdb) x[2].shape
        torch.Size([2, 256, 50, 76])
        (Pdb) x[3].shape
        torch.Size([2, 256, 25, 38])
        (Pdb) x[4].shape
        torch.Size([2, 256, 13, 19])

    '''
    def forward_dummy(self, img):
        x = self.extract_feat(img)
        outs = self.bbox_head(x)
        return outs

    def forward_train(self,
                      img,
                      img_metas,
                      gt_bboxes,
                      gt_labels,
                      gt_bboxes_ignore=None,
                      gt_masks=None):
        # 1. img 
            # eg. torch.Size([2, 3, 800, 1216]); 800 and 1216 are the max (padded) h and w over the images in the batch

        # 2. img_metas
            # eg.
            #[
            # {'filename': 'data/coco2017/train2017/000000559012.jpg', 
            #   'ori_shape': (508, 640, 3), 
            #   'img_shape': (800, 1008, 3), 
            #   'pad_shape': (800, 1216, 3), 
            #   'scale_factor': 1.8823529411764706, 
            #   'flip': False, 
            #   'img_norm_cfg': {'mean': array([123.675, 116.28 , 103.53 ], dtype=float32), 
            #   'std': array([58.395, 57.12 , 57.375], dtype=float32), 
            #   'to_rgb': True}}, 
            #
            # {'filename': 'data/coco2017/train2017/000000532426.jpg', 
            #   'ori_shape': (333, 640, 3), 'img_shape': (753, 1333, 3), 
            #   'pad_shape': (800, 1088, 3), 'scale_factor': 2.4024024024024024,
            #   'flip': False, 
            #   'img_norm_cfg': {'mean': array([123.675, 116.28 , 103.53 ], dtype=float32), 
            #   'std': array([58.395, 57.12 , 57.375], dtype=float32), 
            #   'to_rgb': True}}
            # ]
  
        # 3. gt_bboxes
            # eg.
            # gt_bboxes represents  'bbox' of coco datasets
            # type(gt_bboxes) --> list 
            # len(gt_bboxes) --> batch_size(ie. img per gpu) eg. 2
            # type(gt_bboxes[idx]) --> tensor
            # gt_bboxes[idx].size() --> [instances, 4]  '4' represents [x1, y1, x2, y2]
            # [6, 4] [9, 4]


        # 4. gt_labels
            # eg.
            # gt_labels represents 'category_id' of coco datasets
            # type(gt_labels) --> list 
            # len(gt_labels) --> batch_size(img per gpu) eg. 2
            # type(gt_labels[idx]) --> tensor
            # gt_labels[idx].size() --> instances eg. how many categories  gt_bboxes[7 or 13, 4] --> gt_labels[7 or 13]
            # 6 , 9

        # 5. gt_masks
            # eg.
            # type(gt_masks) --> list
            # len(gt_masks) --> batch_size(img per gpu) eg. 2
            # type(gt_masks[idx]) --> list
            # (6, 800, 1216)  (9, 800, 1088) represents (instances, h, w) at pad_shape


        x = self.extract_feat(img) #    forward backbone and  fpn
        # solo_head forward
        outs = self.bbox_head(x) # forward solo_head
        # outs, eg. five levels for each output
        # 1.ins_pred:
        # outs[0][0].size() --> torch.Size([2, 1600, 200, 336])
        # outs[0][1].size() --> torch.Size([2, 1296, 200, 336]) 
        # outs[0][2].size() --> torch.Size([2, 576, 100, 168])
        # outs[0][3].size() --> torch.Size([2, 256, 50, 84])
        # outs[0][4].size() --> torch.Size([2, 144, 50, 84])
        # 

        # 2.cate_pred:
        # outs[1][0].size() --> torch.Size([2, 80, 40, 40])
        # outs[1][1].size() --> torch.Size([2, 80, 36, 36])
        # outs[1][2].size() --> torch.Size([2, 80, 24, 24])
        # outs[1][3].size() --> torch.Size([2, 80, 16, 16])
        # outs[1][4].size() --> torch.Size([2, 80, 12, 12])
        # 

        if self.with_mask_feat_head:
            mask_feat_pred = self.mask_feat_head(
                x[self.mask_feat_head.
                  start_level:self.mask_feat_head.end_level + 1])
            loss_inputs = outs + (mask_feat_pred, gt_bboxes, gt_labels, gt_masks, img_metas, self.train_cfg)
        else:
            loss_inputs = outs + (gt_bboxes, gt_labels, gt_masks, img_metas, self.train_cfg) 
            # tuple len(outs) = 2  len(loss_inputs) = 7

        # compute SOLO loss
        losses = self.bbox_head.loss(
            *loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)
        return losses

    def simple_test(self, img, img_meta, rescale=False):
        x = self.extract_feat(img)
        outs = self.bbox_head(x, eval=True) # when testing , eval = True rescale=True
        if self.with_mask_feat_head: # False
            mask_feat_pred = self.mask_feat_head(
                x[self.mask_feat_head.
                  start_level:self.mask_feat_head.end_level + 1])
            seg_inputs = outs + (mask_feat_pred, img_meta, self.test_cfg, rescale)
        else:
            seg_inputs = outs + (img_meta, self.test_cfg, rescale) # forward backbone fpn and solo_head 
        seg_result = self.bbox_head.get_seg(*seg_inputs) # get_seg()
        return seg_result  

    def aug_test(self, imgs, img_metas, rescale=False):
        raise NotImplementedError
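
The feature-map shapes recorded in the pdb notes above can be reproduced with simple arithmetic. This is a rough sketch that assumes the five FPN outputs have strides 4, 8, 16, 32 and 64 relative to the padded input and that sizes round up; the helper names are hypothetical. The second helper mimics the rescaling that `SOLOHead.split_feats` (section 2) applies to the first and last level.

```python
import math


def fpn_shapes(pad_h, pad_w, strides=(4, 8, 16, 32, 64)):
    """Approximate (h, w) of each FPN output for a padded input."""
    return [(math.ceil(pad_h / s), math.ceil(pad_w / s)) for s in strides]


def split_feats_shapes(shapes):
    """Shapes after split_feats: level 0 is halved, level 4 is resized
    to the size of level 3, and the middle levels are untouched."""
    first = (shapes[0][0] // 2, shapes[0][1] // 2)
    return [first, shapes[1], shapes[2], shapes[3], shapes[3]]


raw = fpn_shapes(800, 1216)
resized = split_feats_shapes(raw)
print(raw)
print(resized)
```

For the 800 x 1216 batch this yields the sizes shown in the `extract_feat` notes, and the resized list matches the `featmap_sizes` comment in `SOLOHead.forward`.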

2. SOLO/mmdet/models/anchor_heads/solo_head.py

Note: the data printed for one input is shown at the bottom.
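
Before the listing, it may help to see what the `points_nms` helper does on a tiny score map. The sketch below re-implements its effect without PyTorch, assuming a single 2D list of scores: with kernel 2, stride 1, padding 1 and the last row and column cropped, each score is compared with the maximum of the 2x2 window formed by itself and its top, left and top-left neighbours. The name `points_nms_2d` and the toy values are illustrative.

```python
def points_nms_2d(heat):
    """Plain-Python equivalent of points_nms for one H x W score map.

    A score is kept only if it equals the max of the window covering
    rows i-1..i and columns j-1..j; everything else becomes 0.
    """
    h, w = len(heat), len(heat[0])
    out = [[0.0] * w for _ in range(h)]
    for i in range(h):
        for j in range(w):
            window = [heat[r][c]
                      for r in (i - 1, i) for c in (j - 1, j)
                      if r >= 0 and c >= 0]  # padding never wins the max
            if heat[i][j] == max(window):
                out[i][j] = heat[i][j]
    return out


heat = [[0.1, 0.9, 0.2],
        [0.3, 0.8, 0.7],
        [0.6, 0.4, 0.5]]
kept = points_nms_2d(heat)
for row in kept:
    print(row)
```

Note that the suppression is one-sided: a score only competes with neighbours above and to the left, so entries in the first row or column can survive even when they are small.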

import mmcv
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import normal_init
from mmdet.ops import DeformConv, roi_align
from mmdet.core import multi_apply, bbox2roi, matrix_nms
from ..builder import build_loss
from ..registry import HEADS
from ..utils import bias_init_with_prob, ConvModule
import pdb
import math
INF = 1e8

from scipy import ndimage

def points_nms(heat, kernel=2):
    # kernel must be 2
    hmax = nn.functional.max_pool2d(
        heat, (kernel, kernel), stride=1, padding=1)
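    keep = (hmax[:, :, :-1, :-1] == heat).float() # a == b on tensors returns a bool mask (True/False); .float() turns it into 1.0/0.0, and .bool() would convert it back
    return heat * keep # a score survives only if it is the max of its 2x2 window (itself plus its top, left and top-left neighbours); all other scores become 0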

def dice_loss(input, target):
    input = input.contiguous().view(input.size()[0], -1) # [instances , w * h]
    target = target.contiguous().view(target.size()[0], -1).float() # [instances , w * h]

    a = torch.sum(input * target, 1)
    b = torch.sum(input * input, 1) + 0.001
    c = torch.sum(target * target, 1) + 0.001
    e = (2 * a) / (b + c)
    print('dice_loss:', 1-e)
    #pdb.set_trace() # [24]
    return 1-e

@HEADS.register_module
class SOLOHead(nn.Module):

    def __init__(self,
                 num_classes,
                 in_channels,
                 seg_feat_channels=256,
                 stacked_convs=4,
                 strides=(4, 8, 16, 32, 64),
                 base_edge_list=(16, 32, 64, 128, 256),
                 scale_ranges=((8, 32), (16, 64), (32, 128), (64, 256), (128, 512)),
                 sigma=0.4,
                 num_grids=None,
                 cate_down_pos=0,
                 with_deform=False,
                 loss_ins=None,
                 loss_cate=None,
                 conv_cfg=None,
                 norm_cfg=None):
        super(SOLOHead, self).__init__()
        self.num_classes = num_classes # 81
        self.seg_num_grids = num_grids # [40, 36, 24, 16, 12]
        self.cate_out_channels = self.num_classes - 1 # 80
        self.in_channels = in_channels #256
        self.seg_feat_channels = seg_feat_channels # 256
        self.stacked_convs = stacked_convs # 7
        self.strides = strides # [8, 8, 16, 32, 32]
        self.sigma = sigma # 0.2
        self.cate_down_pos = cate_down_pos # 0
        self.base_edge_list = base_edge_list # (16, 32, 64, 128, 256)
        self.scale_ranges = scale_ranges # ((1, 96), (48, 192), (96, 384), (192, 768), (384, 2048))
        self.with_deform = with_deform #False
        #loss_cate: {'type': 'FocalLoss', 'use_sigmoid': True, 'gamma': 2.0, 'alpha': 0.25, 'loss_weight': 1.0}

        self.loss_cate = build_loss(loss_cate) # FocalLoss() 
        self.ins_loss_weight = loss_ins['loss_weight'] # 3
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self._init_layers()
        #pdb.set_trace()

    # init  ins_convs, cate_convs, solo_ins_list, solo_cate
    def _init_layers(self):
        norm_cfg = dict(type='GN', num_groups=32, requires_grad=True)
        self.ins_convs = nn.ModuleList()
        self.cate_convs = nn.ModuleList()
        for i in range(self.stacked_convs):
            # CoordConv appends 2 extra input channels (x and y)
            chn = self.in_channels + 2 if i == 0 else self.seg_feat_channels
            self.ins_convs.append(
                ConvModule(
                    chn,
                    self.seg_feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    norm_cfg=norm_cfg,
                    bias=norm_cfg is None))

            chn = self.in_channels if i == 0 else self.seg_feat_channels
            self.cate_convs.append(
                ConvModule(
                    chn,
                    self.seg_feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    norm_cfg=norm_cfg,
                    bias=norm_cfg is None))

        self.solo_ins_list = nn.ModuleList()

        # modification: [h, w, 256] --> [h, w, min(h/s, w/s)^2]
        self.solo_sa_module = nn.ModuleList()


        # [h, w , 256] ---> [h, w, s*s]

        # modified (the previous version is kept in the string below)
        '''
        for seg_num_grid in self.seg_num_grids:
            self.solo_ins_list.append(
                nn.Conv2d(
                    self.seg_feat_channels, seg_num_grid**2, 1))
        '''
        for seg_num_grid in self.seg_num_grids:
            self.solo_ins_list.append(
                nn.Conv2d(
                seg_num_grid**2, seg_num_grid**2, 1))
        # cate branch output: [S, S, 256] --> [S, S, C], with C = cate_out_channels
        self.solo_cate = nn.Conv2d(
            self.seg_feat_channels, self.cate_out_channels, 3, padding=1)
        #pdb.set_trace()
    # initialize weights
    def init_weights(self):
        for m in self.ins_convs:
            normal_init(m.conv, std=0.01)
        for m in self.cate_convs:
            normal_init(m.conv, std=0.01)
        bias_ins = bias_init_with_prob(0.01) # bias_ins
        for m in self.solo_ins_list: 
            normal_init(m, std=0.01, bias=bias_ins)
        bias_cate = bias_init_with_prob(0.01) # -4.59511985013459
        normal_init(self.solo_cate, std=0.01, bias=bias_cate)
        #pdb.set_trace()
    
    def forward(self, feats, eval=False):
        new_feats = self.split_feats(feats) # first rescale feats[0] and feats[4] by interpolation
        # feats:
        # (Pdb) feats[0].size()
        # torch.Size([2, 256, 200, 304])  --->  new_feats[0] [2, 256, 100, 152] (downsampled)
        # (Pdb) feats[1].size()
        # torch.Size([2, 256, 100, 152])
        # (Pdb) feats[3].size()
        # torch.Size([2, 256, 25, 38])
        # (Pdb) feats[4].size()
        # torch.Size([2, 256, 13, 19]) --->  new_feats[4] [2, 256, 25, 38] (upsampled)


        featmap_sizes = [featmap.size()[-2:] for featmap in new_feats] # h, w
         
        # featmap_sizes = [
        #   torch.Size([100, 152]),
        #   torch.Size([100, 152]),
        #   torch.Size([50, 76]),
        #   torch.Size([25, 38]),
        #   torch.Size([25, 38])
        # ]


        upsampled_size = (featmap_sizes[0][0] * 2, featmap_sizes[0][1] * 2) # upsampled_size is the size of the largest original FPN feature map (twice new_feats[0]), eg. [320, 200]
        ins_pred, cate_pred = multi_apply(self.forward_single, new_feats, 
                                          list(range(len(self.seg_num_grids))),
                                          eval=eval, upsampled_size=upsampled_size)

        return ins_pred, cate_pred

    def split_feats(self, feats):
        #len(feats) = 5 (tuple)

        #pdb.set_trace()
        # downsampling interpolation, scale_factor=0.5
        # {'P2': 8, 'P3': 8, 'P4': 16, 'P5': 32, 'P6': 32} ---> from these strides one can infer the size of this input image [, ] --> FPN rescaling
        return (F.interpolate(feats[0], scale_factor=0.5, mode='bilinear'), # torch.Size([2, 256, 160, 100])
                feats[1],  # torch.Size([2, 256, 160, 100])
                feats[2],  # torch.Size([2, 256, 80, 50])
                feats[3],  # torch.Size([2, 256, 40, 25])
                F.interpolate(feats[4], size=feats[3].shape[-2:], mode='bilinear'))# torch.Size([2, 256, 40, 25])


    def forward_single(self, x, idx, eval=False, upsampled_size=None):
        # called 5 times, once per FPN level, to build the head outputs of each level
        # x = torch.Size([2, 256, 160, 100]) 
        # idx = 0
        # upsampled_size = (320, 200)
        
        #pdb.set_trace()
        ins_feat = x
        device = ins_feat.device
        print(device)
        cate_feat = x
        # ins branch
        # concat CoordConv
        x_range = torch.linspace(-1, 1, ins_feat.shape[-1], device=ins_feat.device)
        y_range = torch.linspace(-1, 1, ins_feat.shape[-2], device=ins_feat.device)
        y, x = torch.meshgrid(y_range, x_range)
        y = y.expand([ins_feat.shape[0], 1, -1, -1]) # N, 1, h/strides, w/strides
        x = x.expand([ins_feat.shape[0], 1, -1, -1]) # N, 1, h/strides, w/strides
        coord_feat = torch.cat([x, y], 1) # [N, 2, h, w]

        # channels: 256 --> 258 [N, 256, h, w] --> [N, 258, h, w]
        ins_feat = torch.cat([ins_feat, coord_feat], 1)
        # ins_convs: forward through the 7 stacked convs
        for i, ins_layer in enumerate(self.ins_convs):
            ins_feat = ins_layer(ins_feat)
        #pdb.set_trace()
        # first modification
        
        sa_feat = []
        
        # [152, 100]  --> [160, 120]
        sa_h = math.ceil(ins_feat.size()[2] / self.seg_num_grids[idx])    
        #if (ins_feat.size()[2] % self.seg_num_grids[idx]) != 0:
        #   sa_h = sa_h + 1
        sa_w = math.ceil(ins_feat.size()[3] / self.seg_num_grids[idx])
        #if (ins_feat.size()[3] % self.seg_num_grids[idx]) != 0:
            #sa_w = sa_w + 1
        
        # interpolate
        # after interpolation: ins_feat [2, 256, 160, 120]
        ins_feat = F.interpolate(ins_feat, size=(self.seg_num_grids[idx] * sa_h, self.seg_num_grids[idx] * sa_w), mode='bilinear') 
        # ins_sa_feat [2, 40*40, 160, 120]
        #ins_sa_feat = torch.zeros(ins_feat.size()[0],  self.seg_num_grids[idx] *  self.seg_num_grids[idx], ins_feat.size()[2], ins_feat.size()[3],device=device)
        seg_num_grids = self.seg_num_grids[idx]

        
        abc = []
        for i in range(seg_num_grids):
            for j in range(seg_num_grids):
                weight = ins_feat[:, :, i * sa_h : (i + 1) * sa_h, j * sa_w : (j + 1) * sa_w].repeat(1, 1, seg_num_grids, seg_num_grids)
                abc.append((weight * ins_feat).sum(1))
        ins_pred = torch.stack(abc, dim=1)
        #print(ins_pred.shape)
        
        '''
        An improvement based on my boss's method; this part can be skipped~
        # first modification
        too slow
        for i in range(seg_num_grids * seg_num_grids):
            grid_in_row = i % seg_num_grids
            row = i // seg_num_grids 
            sa = ins_feat[:, :, row*sa_h : row*sa_h + sa_h, grid_in_row*sa_w : grid_in_row*sa_w + sa_w].cuda()
            for j in range(seg_num_grids):
                for k in range(seg_num_grids):
                    ins_sa_feat[:, i, j*sa_h : j*sa_h + sa_h, k*sa_w : k*sa_w + sa_w] = (sa * ins_feat[:, :, j*sa_h : j*sa_h + sa_h, k*sa_w : k*sa_w + sa_w]).sum(dim = 1)             
        ins_sa_feat = ins_sa_feat.cuda()
        '''

        '''
        # second modification
        # --------------------------------------------------------------------------------------------------------------------#
        # 1. split into sa_h * sa_w mask feature maps of size seg_num_grids * seg_num_grids
        # --------------------------------------------------------------------------------------------------------------------#
        mask_list =[]
        for i in range(sa_h):
            for j in range(sa_w):
                mask_list.append(ins_feat[:, :, i::sa_h, j::sa_w]) # mask_list[i].size() = [n, 256, seg_num_grids, seg_num_grids]

		
        #print(len(mask_list)) # len = sa_h * sa_w      
        #pdb.set_trace()

        # --------------------------------------------------------------------------------------------------------------------#
        # 2. self-attention over the sa_h * sa_w feature maps
        # --------------------------------------------------------------------------------------------------------------------#
        all_sa_feat = []
        per_sa_feat = []
        for i in range(sa_h * sa_w):
            ori_n = mask_list[i].size()[0]
            ori_c = mask_list[i].size()[1]
            n_c_hw = mask_list[i].reshape(ori_n, ori_c, -1) # [n, c, hw]
            #tmp_n_c_hw = n_c_hw.clone()
            n_c_hw_T = n_c_hw.permute(0, 2, 1) #[n, hw, c]
            tmp = torch.matmul(n_c_hw_T, n_c_hw) # [n, hw, c] x [n, c, hw] == [n, hw, hw]
            stack_sa_feat = tmp.reshape(ori_n, seg_num_grids * seg_num_grids, seg_num_grids, -1) # [n, s*s, s, s] 
            all_sa_feat.append(stack_sa_feat)

           
        # --------------------------------------------------------------------------------------------------------------------#
        # 3. first concatenate the seg_num_grids element matrices of the same row, eg: xxxxyyyyzzzzccccc  --> xyzc xyzc xyzc
        # --------------------------------------------------------------------------------------------------------------------#

        cat_all_row_feat = []
        for i in range(0, sa_w * sa_h, sa_w):
            cat_row_feat = torch.cat([feat for feat in all_sa_feat[i : i + sa_w]], dim = 3)
            cat_all_row_feat.append(cat_row_feat) 
        #print(len(cat_all_row_feat))
        #pdb.set_trace()

        # --------------------------------------------------------------------------------------------------------------------#
        # 4. first reorder the columns of every entry in cat_all_row_feat
        # --------------------------------------------------------------------------------------------------------------------#

        all_new_row_feat_list = [] # new rows of the 4 tensors after reordering: xyxy abab cdcd fgfg
        for i in range(0, len(cat_all_row_feat)):
            per_new_row_feat_list = [] # eg. xyxy or abab 
            for j in range(0, seg_num_grids):
                per_row_feat = cat_all_row_feat[i][:, :, :, j::seg_num_grids] # Tensor
                per_new_row_feat_list.append(per_row_feat)
            all_new_row_feat_list.append(torch.cat(per_new_row_feat_list, dim = 3)) # after reordering
        #print('len(all_new_row_feat_list):', len(all_new_row_feat_list))
        #pdb.set_trace()

        # --------------------------------------------------------------------------------------------------------------------#
        # 5. building on that, concatenate the reordered rows along dim 2
        # --------------------------------------------------------------------------------------------------------------------#

        #for feat in all_new_row_feat_list:
            #print(feat.size())
        cat_all_col_feat = torch.cat([feat for feat in all_new_row_feat_list], dim = 2)
        #print('cat_all_col_feat.size():', cat_all_col_feat.size())
        #pdb.set_trace()

        # --------------------------------------------------------------------------------------------------------------------#
        # 6. reorder the rows
        # --------------------------------------------------------------------------------------------------------------------#

        per_new_col_feat_list = [] # new rows of the 4 tensors after reordering: xyxy abab cdcd fgfg
        for i in range(0, seg_num_grids):
            # eg. xyxy 
            #     abab                   
            per_col_feat = cat_all_col_feat[:, :, i::seg_num_grids, :] # Tensor
            per_new_col_feat_list.append(per_col_feat)
        all_new_col_feat = torch.cat(per_new_col_feat_list, dim = 2) # after reordering
        ins_sa_feat = all_new_col_feat.to(device)
        #print('ins_sa_feat.size(): ', ins_sa_feat.size())
        #print(ins_sa_feat)
        #pdb.set_trace()
        '''
        # --------------------------------------------------------------------------------------------------------------------#
        # end of modification
        # --------------------------------------------------------------------------------------------------------------------#

        # w x h x 256 --> 2w x 2h x 256
        #ins_feat = F.interpolate(ins_feat, scale_factor=2, mode='bilinear')
        ins_pred = F.interpolate(ins_pred, scale_factor=2, mode='bilinear')
        # eg. torch.Size([2, 1600 or 1296 or 576 or 256 or 144, 2H/strides, 2W/strides])

        # new modification
        ins_pred = self.solo_ins_list[idx](ins_pred) #  [N, 256, 2w, 2h] --> [N, S*S, 2w, 2h]  eg. torch.Size([2, 1600, 200, 304])

        # cate branch
        for i, cate_layer in enumerate(self.cate_convs):
            if i == self.cate_down_pos: # when i == 0
                seg_num_grid = self.seg_num_grids[idx] # [40, 36, 24, 16, 12]
                cate_feat = F.interpolate(cate_feat, size=seg_num_grid, mode='bilinear') # resize to the S x S grid
            cate_feat = cate_layer(cate_feat)

        # channels: 256 --> 80
        cate_pred = self.solo_cate(cate_feat)
        if eval:
            ins_pred = F.interpolate(ins_pred.sigmoid(), size=upsampled_size, mode='bilinear') # note: all 5 FPN levels are interpolated to the same size, upsampled_size; at eval time that is 1/4 of the input image, eg. [1, 1600, 200, 304]
            cate_pred = points_nms(cate_pred.sigmoid(), kernel=2).permute(0, 2, 3, 1) # [N, h, w, c] eg. [1, 40, 40, 80]

        # return the final instance and category predictions of this level
        return ins_pred, cate_pred

    def loss(self,
             ins_preds,
             cate_preds,
             gt_bbox_list,
             gt_label_list,
             gt_mask_list,
             img_metas,
             cfg,
             gt_bboxes_ignore=None):
             
        featmap_sizes = [featmap.size()[-2:] for featmap in
                         ins_preds]
        ins_label_list, cate_label_list, ins_ind_label_list = multi_apply(
            self.solo_target_single,
            gt_bbox_list,
            gt_label_list,
            gt_mask_list,
            featmap_sizes=featmap_sizes)


        #test
        ins_labels = []
        temp_2 = []
        #ins_labels_2  =[]      
        # loops 5 times, once per FPN level
        # ins_labels_level :
        # eg.  ins_labels_level[0].size() torch.Size([1296, 200, 272]) 
        #      ins_labels_level[1].size() torch.Size([1296, 200, 272])

        for ins_labels_level, ins_ind_labels_level in zip(zip(*ins_label_list),zip(*ins_ind_label_list)):
            temp = []
            #pdb.set_trace()
            for ins_labels_level_img, ins_ind_labels_level_img in zip(ins_labels_level, ins_ind_labels_level):
                temp.append(ins_labels_level_img[ins_ind_labels_level_img, ...]) # [instances, 200, 304]
                #pdb.set_trace()
            temp_2 = torch.cat(temp, 0) # all images of the batch, for this level
            ins_labels.append(temp_2)

        # ins
        '''
        # zip(*...) is the inverse of zip(): it regroups the per-image lists by level
        ins_labels = [torch.cat([ins_labels_level_img[ins_ind_labels_level_img, ...]
                                 for ins_labels_level_img, ins_ind_labels_level_img in
                                 zip(ins_labels_level, ins_ind_labels_level)], 0)
                      for ins_labels_level, ins_ind_labels_level in zip(zip(*ins_label_list), zip(*ins_ind_label_list))] # len(ins_label_list) = batchsize
        '''

        '''
        temp_2 = [] 
        for ins_preds_level, ins_ind_labels_level in zip(ins_preds, zip(*ins_ind_label_list)):
            temp = []
            for ins_preds_level_img, ins_ind_labels_level_img in zip(ins_preds_level, ins_ind_labels_level):
                temp.append(ins_preds_level_img[ins_ind_labels_level_img, ...])
            temp_2 = torch.cat(temp, 0)
            ins_preds.append(temp_2)
        pdb.set_trace()
    
        '''
        ins_preds = [torch.cat([ins_preds_level_img[ins_ind_labels_level_img, ...]
                                for ins_preds_level_img, ins_ind_labels_level_img in
                                zip(ins_preds_level, ins_ind_labels_level)], 0)
                     for ins_preds_level, ins_ind_labels_level in zip(ins_preds, zip(*ins_ind_label_list))]
        
        #pdb.set_trace()

        ins_ind_labels = []    
        
        temp_2 = []
        for ins_ind_labels_level in zip(*ins_ind_label_list):
            temp = [] 
            for ins_ind_labels_level_img in ins_ind_labels_level:
                temp.append(ins_ind_labels_level_img.flatten())
            temp_2 = torch.cat(temp)
            ins_ind_labels.append(temp_2)
        #pdb.set_trace()

        '''
        ins_ind_labels = [
            torch.cat([ins_ind_labels_level_img.flatten()
                       for ins_ind_labels_level_img in ins_ind_labels_level])
            for ins_ind_labels_level in zip(*ins_ind_label_list)
        ]
        '''

        flatten_ins_ind_labels = torch.cat(ins_ind_labels) # 3872 * batch_size

        num_ins = flatten_ins_ind_labels.sum() # number of positive samples, i.e. the count of True entries
        #pdb.set_trace()
        
        # dice loss
        loss_ins = []
        # ins branch: compute the loss between the gt ins_labels and the predicted ins_preds
        for input, target in zip(ins_preds, ins_labels): # ins_preds and ins_labels have the same shape; ins_preds[0] holds raw scores, ins_labels[0] holds 0/1
            if input.size()[0] == 0: # no ins
                continue
            input = torch.sigmoid(input) # sigmoid
            loss_ins.append(dice_loss(input, target))
        loss_ins = torch.cat(loss_ins).mean()
        loss_ins = loss_ins * self.ins_loss_weight
        print('loss_ins: ', loss_ins)

        # cate
        cate_labels = [
            torch.cat([cate_labels_level_img.flatten()
                       for cate_labels_level_img in cate_labels_level])
            for cate_labels_level in zip(*cate_label_list)
        ]
        flatten_cate_labels = torch.cat(cate_labels) # 3872 * batch_size
        # cate branch: likewise compute the loss between the gt cate_labels and the predicted cate_preds
        cate_preds = [
            cate_pred.permute(0, 2, 3, 1).reshape(-1, self.cate_out_channels) # [s*s , C]
            for cate_pred in cate_preds
        ]
        '''
            (Pdb) cate_preds[0].size()
                torch.Size([3200, 80]) 3200 = 1600 *2  --> [40, 40, 80]
            (Pdb) cate_preds[1].size()
                torch.Size([2592, 80])
            (Pdb) cate_preds[2].size()
                torch.Size([1152, 80])
            (Pdb) cate_preds[3].size()
                torch.Size([512, 80])
            (Pdb) cate_preds[4].size()
                torch.Size([288, 80])
        '''
        flatten_cate_preds = torch.cat(cate_preds) # [3872 * batch_size, 80]; 3872 is the sum of S*S over the 5 FPN levels

        loss_cate = self.loss_cate(flatten_cate_preds, flatten_cate_labels, avg_factor=num_ins + 1)  # num_ins is the sum of the first dimension of ins_preds over the 5 levels, i.e. the total number of positive instances

        return dict(
            loss_ins=loss_ins,
            loss_cate=loss_cate)

    def solo_target_single(self,
                               gt_bboxes_raw,
                               gt_labels_raw,
                               gt_masks_raw,
                               featmap_sizes=None):

        # processes one image per call: uses gt_areas to decide which FPN level each instance in the image is assigned to
        # gt_bboxes_raw.size() --> [7, 4]
        # gt_labels_raw --> 7
        # gt_masks_raw --> [7, 800, 1024]
        # featmap_sizes --> [torch.Size([200, 336]), torch.Size([200, 336]), torch.Size([100, 168]), torch.Size([50, 84]), torch.Size([50, 84])]

        
        device = gt_labels_raw[0].device # cuda

        # ins
        # compute the gt_areas of per gt in one img.
        # gt_areas.size() --> [instances]
        gt_areas = torch.sqrt((gt_bboxes_raw[:, 2] - gt_bboxes_raw[:, 0]) * (
                gt_bboxes_raw[:, 3] - gt_bboxes_raw[:, 1]))

        ins_label_list = []
        cate_label_list = []
        ins_ind_label_list = []
        for (lower_bound, upper_bound), stride, featmap_size, num_grid \
                in zip(self.scale_ranges, self.strides, featmap_sizes, self.seg_num_grids):

            ins_label = torch.zeros([num_grid ** 2, featmap_size[0], featmap_size[1]], dtype=torch.uint8, device=device) # eg. [40 * 40, 200, 336]
            cate_label = torch.zeros([num_grid, num_grid], dtype=torch.int64, device=device) # [40, 40]
            ins_ind_label = torch.zeros([num_grid ** 2], dtype=torch.bool, device=device) # [1600]
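            # nonzero() returns the indices of the non-zero entries
            # flatten() flattens them to 1-D
            hit_indices = ((gt_areas >= lower_bound) & (gt_areas <= upper_bound)).nonzero().flatten() # indices of the gts predicted at this level, i.e. which instances appear on this level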
            #pdb.set_trace()
            
            if len(hit_indices) == 0:
                ins_label_list.append(ins_label)
                cate_label_list.append(cate_label)
                ins_ind_label_list.append(ins_ind_label)
                continue
            gt_bboxes = gt_bboxes_raw[hit_indices] # gt boxes [x1, y1, x2, y2] whose gt_areas fall in [lower_bound, upper_bound] --> e.g. [1, 4]
            gt_labels = gt_labels_raw[hit_indices] # [instances], labels of those gts --> e.g. [57]
            gt_masks = gt_masks_raw[hit_indices.cpu().numpy(), ...] # [instances, h, w] --> e.g. [1, 800, 1216]

            half_ws = 0.5 * (gt_bboxes[:, 2] - gt_bboxes[:, 0]) * self.sigma # self.sigma = 0.2
            half_hs = 0.5 * (gt_bboxes[:, 3] - gt_bboxes[:, 1]) * self.sigma

            output_stride = stride / 2

            # Handle one instance per iteration.
            for seg_mask, gt_label, half_h, half_w in zip(gt_masks, gt_labels, half_hs, half_ws):
                if seg_mask.sum() < 10:
                    continue
                # mass center
                upsampled_size = (featmap_sizes[0][0] * 4, featmap_sizes[0][1] * 4)
                center_h, center_w = ndimage.measurements.center_of_mass(seg_mask) # centroid of the mask
                coord_w = int((center_w / upsampled_size[1]) // (1. / num_grid)) # map the centroid to grid coordinates, e.g. [659, 398] --> [29, 11] when num_grid = 36
                coord_h = int((center_h / upsampled_size[0]) // (1. / num_grid))

                # left, top, right, down
                top_box = max(0, int(((center_h - half_h) / upsampled_size[0]) // (1. / num_grid)))
                down_box = min(num_grid - 1, int(((center_h + half_h) / upsampled_size[0]) // (1. / num_grid)))
                left_box = max(0, int(((center_w - half_w) / upsampled_size[1]) // (1. / num_grid)))
                right_box = min(num_grid - 1, int(((center_w + half_w) / upsampled_size[1]) // (1. / num_grid)))

                top = max(top_box, coord_h-1) # 6
                down = min(down_box, coord_h+1) # 8
                left = max(coord_w-1, left_box) # 6
                right = min(right_box, coord_w+1) # 8

                
                # cate
                cate_label[top:(down+1), left:(right+1)] = gt_label # e.g. fills rows 6..8 and columns 6..8
                # ins
                seg_mask = mmcv.imrescale(seg_mask, scale=1. / output_stride) # [800, 1088] --> [50, 68]; the mask branch outputs [2h, 2w], so the mask is shrunk by stride / 2 rather than stride
                seg_mask = torch.Tensor(seg_mask)
                for i in range(top, down+1):
                    for j in range(left, right+1):
                        label = int(i * num_grid + j)
                        ins_label[label, :seg_mask.shape[0], :seg_mask.shape[1]] = seg_mask # stored in the matching one of the S*S channels
                        ins_ind_label[label] = True # marks which of the S*S grid cells hold an instance
            ins_label_list.append(ins_label)
            cate_label_list.append(cate_label)
            ins_ind_label_list.append(ins_ind_label)

        return ins_label_list, cate_label_list, ins_ind_label_list

    def get_seg(self, seg_preds, cate_preds, img_metas, cfg, rescale=None): # len(seg_preds):5   len(cate_preds):5
        assert len(seg_preds) == len(cate_preds)
        num_levels = len(cate_preds) # 5
        featmap_size = seg_preds[0].size()[-2:] # max fpn feature map size : [200, 304]

        result_list = []
        for img_id in range(len(img_metas)):
            cate_pred_list = [
                cate_preds[i][img_id].view(-1, self.cate_out_channels).detach() for i in range(num_levels)
            ]
            seg_pred_list = [
                seg_preds[i][img_id].detach() for i in range(num_levels)
            ]
            img_shape = img_metas[img_id]['img_shape']
            scale_factor = img_metas[img_id]['scale_factor']
            ori_shape = img_metas[img_id]['ori_shape']

            cate_pred_list = torch.cat(cate_pred_list, dim=0) # one image at a time, so cate_pred_list.size() --> [3872, 80]
            seg_pred_list = torch.cat(seg_pred_list, dim=0)

            result = self.get_seg_single(cate_pred_list, seg_pred_list,
                                         featmap_size, img_shape, ori_shape, scale_factor, cfg, rescale)
            result_list.append(result)

        return result_list

    # Per-image post-processing.
    def get_seg_single(self,
                       cate_preds, # [3872, 80]
                       seg_preds, # eg. [3872, 200, 304]
                       featmap_size, # eg. [200, 304] max feature map in FPN
                       img_shape, # eg. [800, 1199, 3]
                       ori_shape, # eg. [427, 640, 3]
                       scale_factor,
                       cfg,
                       rescale=False, debug=False):
        assert len(cate_preds) == len(seg_preds)
        # overall info.
        h, w, _ = img_shape
        upsampled_size_out = (featmap_size[0] * 4, featmap_size[1] * 4) # eg. [800, 1216]

        # process.
        inds = (cate_preds > cfg.score_thr) # first filter, e.g. [3872, 80]; score_thr = 0.1; inds is a bool tensor
        # category scores.
        cate_scores = cate_preds[inds] # e.g. [507]; boolean indexing flattens to 1-D, one score per True entry of inds
        if len(cate_scores) == 0:
            return None
        # category labels.
        inds = inds.nonzero() # indices of the True entries; inds.size() --> [507, 2]
        cate_labels = inds[:, 1] # the second column is the class index among the 80 classes; cate_labels --> [507]

        # strides.
        size_trans = cate_labels.new_tensor(self.seg_num_grids).pow(2).cumsum(0)  # tensor([1600, 2896, 3472, 3728, 3872], device='cuda:0')
        strides = cate_scores.new_ones(size_trans[-1])  # [3872], all ones
        n_stage = len(self.seg_num_grids)  # 5
        strides[:size_trans[0]] *= self.strides[0] # the first 1600 entries go from 1 to 8
        for ind_ in range(1, n_stage):
            strides[size_trans[ind_ - 1]:size_trans[ind_]] *= self.strides[ind_] # e.g. sets the 1296 entries in 1600 ~ 2896
        strides = strides[inds[:, 0]] # strides.size() --> [507]; inds[:, 0] is the flattened grid-cell index

        # masks.
        seg_preds = seg_preds[inds[:, 0]]  # [3872, 200, 304] --> [507, 200, 304]
        seg_masks = seg_preds > cfg.mask_thr # mask_thr = 0.5; bool [507, 200, 304], i.e. the binarized masks
        sum_masks = seg_masks.sum((1, 2)).float() # [507, 200, 304] --> [507]; sum((1, 2)) adds up the H * W entries of each channel, giving the mask area

        # filter.
        keep = sum_masks > strides # bool [507]; drop masks whose area does not exceed the stride of their level
        if keep.sum() == 0:
            return None
        # apply the filter
        seg_masks = seg_masks[keep, ...]  # bool [keep.sum(), 200, 304]; masks where keep is True are retained unchanged, the rest are dropped
        seg_preds = seg_preds[keep, ...]
        sum_masks = sum_masks[keep]
        cate_scores = cate_scores[keep]
        cate_labels = cate_labels[keep]

        # mask scoring.
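        seg_scores = (seg_preds * seg_masks.float()).sum((1, 2)) / sum_masks # one value per kept mask: mean predicted probability over the foreground pixels
        cate_scores *= seg_scores # weight the class score by mask confidence, so a confident class with a weak mask ranks lower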

        # sort and keep top nms_pre
        sort_inds = torch.argsort(cate_scores, descending=True)
        if len(sort_inds) > cfg.nms_pre: # keep the top nms_pre (500)
            sort_inds = sort_inds[:cfg.nms_pre]
        seg_masks = seg_masks[sort_inds, :, :]
        seg_preds = seg_preds[sort_inds, :, :]
        sum_masks = sum_masks[sort_inds]
        cate_scores = cate_scores[sort_inds]
        cate_labels = cate_labels[sort_inds]
        # Matrix NMS
        cate_scores = matrix_nms(seg_masks, cate_labels, cate_scores,
                                 kernel=cfg.kernel, sigma=cfg.sigma, sum_masks=sum_masks)

        # filter.
        keep = cate_scores >= cfg.update_thr
        if keep.sum() == 0:
            return None
        seg_preds = seg_preds[keep, :, :]
        cate_scores = cate_scores[keep]
        cate_labels = cate_labels[keep]

        # sort and keep top_k
        sort_inds = torch.argsort(cate_scores, descending=True)
        if len(sort_inds) > cfg.max_per_img:
            sort_inds = sort_inds[:cfg.max_per_img]
        seg_preds = seg_preds[sort_inds, :, :]
        cate_scores = cate_scores[sort_inds]
        cate_labels = cate_labels[sort_inds]

        seg_preds = F.interpolate(seg_preds.unsqueeze(0),
                                  size=upsampled_size_out,
                                  mode='bilinear')[:, :, :h, :w]
        seg_masks = F.interpolate(seg_preds,
                                  size=ori_shape[:2],
                                  mode='bilinear').squeeze(0)
        seg_masks = seg_masks > cfg.mask_thr
        return seg_masks, cate_labels, cate_scores

#----------------------------------------------------------------------------------------#
#self.ins_convs:
'''
ModuleList(
  (0): ConvModule(
    (conv): Conv2d(258, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (gn): GroupNorm(32, 256, eps=1e-05, affine=True)
    (activate): ReLU(inplace=True)
  )
  (1): ConvModule(
    (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (gn): GroupNorm(32, 256, eps=1e-05, affine=True)
    (activate): ReLU(inplace=True)
  )
  (2): ConvModule(
    (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (gn): GroupNorm(32, 256, eps=1e-05, affine=True)
    (activate): ReLU(inplace=True)
  )
  (3): ConvModule(
    (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (gn): GroupNorm(32, 256, eps=1e-05, affine=True)
    (activate): ReLU(inplace=True)
  )
  (4): ConvModule(
    (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (gn): GroupNorm(32, 256, eps=1e-05, affine=True)
    (activate): ReLU(inplace=True)
  )
  (5): ConvModule(
    (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (gn): GroupNorm(32, 256, eps=1e-05, affine=True)
    (activate): ReLU(inplace=True)
  )
  (6): ConvModule(
    (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (gn): GroupNorm(32, 256, eps=1e-05, affine=True)
    (activate): ReLU(inplace=True)
  )
)

'''



#----------------------------------------------------------------------------------------#
#self.cate_convs
'''

ModuleList(
  (0): ConvModule(
    (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (gn): GroupNorm(32, 256, eps=1e-05, affine=True)
    (activate): ReLU(inplace=True)
  )
  (1): ConvModule(
    (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (gn): GroupNorm(32, 256, eps=1e-05, affine=True)
    (activate): ReLU(inplace=True)
  )
  (2): ConvModule(
    (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (gn): GroupNorm(32, 256, eps=1e-05, affine=True)
    (activate): ReLU(inplace=True)
  )
  (3): ConvModule(
    (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (gn): GroupNorm(32, 256, eps=1e-05, affine=True)
    (activate): ReLU(inplace=True)
  )
  (4): ConvModule(
    (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (gn): GroupNorm(32, 256, eps=1e-05, affine=True)
    (activate): ReLU(inplace=True)
  )
  (5): ConvModule(
    (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (gn): GroupNorm(32, 256, eps=1e-05, affine=True)
    (activate): ReLU(inplace=True)
  )
  (6): ConvModule(
    (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (gn): GroupNorm(32, 256, eps=1e-05, affine=True)
    (activate): ReLU(inplace=True)
  )
)

'''
#----------------------------------------------------------------------------------------#
# self.solo_ins_list
'''
ModuleList(
  (0): Conv2d(256, 1600, kernel_size=(1, 1), stride=(1, 1))
  (1): Conv2d(256, 1296, kernel_size=(1, 1), stride=(1, 1))
  (2): Conv2d(256, 576, kernel_size=(1, 1), stride=(1, 1))
  (3): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
  (4): Conv2d(256, 144, kernel_size=(1, 1), stride=(1, 1))
)
'''


'''
ins_pred:
    (Pdb) ins_pred[0].size()
    torch.Size([2, 1600, 200, 304])
    (Pdb) ins_pred[1].size()
    torch.Size([2, 1296, 200, 304])
    (Pdb) ins_pred[2].size()
    torch.Size([2, 576, 100, 152])
    (Pdb) ins_pred[3].size()
    torch.Size([2, 256, 50, 76])
    (Pdb) ins_pred[4].size()
    torch.Size([2, 144, 50, 76])

'''

'''
cate_pred:
    (Pdb) cate_pred[0].size()
    torch.Size([2, 80, 40, 40])
    (Pdb) cate_pred[1].size()
    torch.Size([2, 80, 36, 36])
    (Pdb) cate_pred[2].size()
    torch.Size([2, 80, 24, 24])
    (Pdb) cate_pred[3].size()
    torch.Size([2, 80, 16, 16])
    (Pdb) cate_pred[4].size()
    torch.Size([2, 80, 12, 12])

'''
#----------------------------------------------------------------------------------------#

#def loss


'''
ins_labels
    (Pdb) ins_labels[0].size()
    torch.Size([1, 200, 272])
    (Pdb) ins_labels[1].size()
    torch.Size([0, 200, 272])
    (Pdb) ins_labels[2].size()
    torch.Size([16, 100, 136])
    (Pdb) ins_labels[3].size()
    torch.Size([39, 50, 68])
    (Pdb) ins_labels[4].size()
    torch.Size([18, 50, 68])
'''

'''
ins_preds:
    (Pdb) ins_preds[0].size()
    torch.Size([1, 200, 272])
    (Pdb) ins_preds[1].size()
    torch.Size([0, 200, 272])
    (Pdb) ins_preds[2].size()
    torch.Size([6, 100, 136])
    (Pdb) ins_preds[3].size()
    torch.Size([10, 50, 68])
    (Pdb) ins_preds[4].size()
    torch.Size([6, 50, 68])
'''

'''
ins_ind_labels:
    (Pdb) ins_ind_labels[0].size()
    torch.Size([1600])
    (Pdb) ins_ind_labels[1].size()
    torch.Size([1296])
    (Pdb) ins_ind_labels[2].size()
    torch.Size([576])
    (Pdb) ins_ind_labels[3].size()
    torch.Size([256])
    (Pdb) ins_ind_labels[4].size()
    torch.Size([144])

'''
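
A side note on how instances end up on particular FPN levels in solo_target_single. The sketch below replays only the `hit_indices` test on plain Python numbers. The helper name `levels_for_box` and the `scale_ranges` values are assumptions for illustration (the real values come from the config), so treat it as a toy, not the repository code.

```python
import math

# Assumed scale ranges, one (lower_bound, upper_bound) pair per FPN level.
SCALE_RANGES = ((1, 96), (48, 192), (96, 384), (192, 768), (384, 2048))


def levels_for_box(box, scale_ranges=SCALE_RANGES):
    """Return the indices of all levels whose range contains sqrt(w * h)."""
    x1, y1, x2, y2 = box
    scale = math.sqrt((x2 - x1) * (y2 - y1))  # same quantity as gt_areas
    return [i for i, (lo, hi) in enumerate(scale_ranges) if lo <= scale <= hi]


small = levels_for_box((0, 0, 10, 10))     # scale 10
medium = levels_for_box((0, 0, 100, 100))  # scale 100
```

Because the assumed ranges overlap, one instance can be matched to two neighbouring levels, which is why a single object may show up in the mask targets of more than one FPN output.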


'''
cate_labels:
    (Pdb) cate_labels[0].size()
    torch.Size([1600])
    (Pdb) cate_labels[1].size()
    torch.Size([1296])
    (Pdb) cate_labels[2].size()
    torch.Size([576])
    (Pdb) cate_labels[3].size()
    torch.Size([256])
    (Pdb) cate_labels[4].size()
    torch.Size([144])

'''
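
To make the center sampling in solo_target_single concrete, here is a minimal sketch of the index arithmetic only: centroid to grid cell, then the sigma-shrunk box clamped to the 3x3 neighbourhood around that cell. The function name `center_region_cells` and all input numbers are made up for illustration.

```python
def center_region_cells(center_h, center_w, half_h, half_w, upsampled_size, num_grid):
    """Return inclusive (top, down, left, right) grid indices of the positive cells.

    half_h / half_w are assumed to be already multiplied by sigma,
    like half_hs / half_ws in solo_target_single.
    """
    def to_grid(value, extent):
        # Same expression as in the original code: normalise, then floor-divide by the cell size.
        return int((value / extent) // (1.0 / num_grid))

    height, width = upsampled_size
    coord_h = to_grid(center_h, height)
    coord_w = to_grid(center_w, width)

    # Cells covered by the shrunk box, clipped to the grid.
    top_box = max(0, to_grid(center_h - half_h, height))
    down_box = min(num_grid - 1, to_grid(center_h + half_h, height))
    left_box = max(0, to_grid(center_w - half_w, width))
    right_box = min(num_grid - 1, to_grid(center_w + half_w, width))

    # Never extend more than one cell away from the centroid cell.
    top = max(top_box, coord_h - 1)
    down = min(down_box, coord_h + 1)
    left = max(left_box, coord_w - 1)
    right = min(right_box, coord_w + 1)
    return top, down, left, right


# 800 x 800 input, 40 x 40 grid (20-pixel cells), a tall and narrow shrunk box.
cells = center_region_cells(410, 210, 45, 5, (800, 800), 40)
```

In this toy case the shrunk box spans rows 18 to 22, but the clamp keeps only rows 19 to 21, and a single column because the box is narrower than one cell.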


'''
get_seg
cfg:
    {
        'nms_pre': 500, 
        'score_thr': 0.1,
        'mask_thr': 0.5, 
        'update_thr': 0.05, 
        'kernel': 'gaussian', 
        'sigma': 2.0, 
        'max_per_img': 100}

'''
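
The "mask scoring" step in get_seg_single is easy to check by hand. Below is a small NumPy stand-in for the torch calls, assuming a hypothetical helper name `rescore_with_maskness` and assuming every mask has at least one foreground pixel (in the original code the `sum_masks > strides` filter runs first).

```python
import numpy as np


def rescore_with_maskness(cate_scores, seg_preds, mask_thr=0.5):
    """Multiply each class score by the mean foreground probability of its mask."""
    seg_masks = seg_preds > mask_thr                   # binarize, as with cfg.mask_thr
    sum_masks = seg_masks.sum(axis=(1, 2)).astype(np.float64)
    seg_scores = (seg_preds * seg_masks).sum(axis=(1, 2)) / sum_masks
    return cate_scores * seg_scores


seg_preds = np.array([
    [[0.9, 0.9], [0.1, 0.1]],   # sharp mask: foreground mean 0.9
    [[0.6, 0.6], [0.6, 0.2]],   # hesitant mask: foreground mean 0.6
])
cate_scores = np.array([0.8, 0.8])
rescored = rescore_with_maskness(cate_scores, seg_preds)
```

Both masks start with the same class score, but the hesitant one drops further, so it ranks lower before Matrix NMS.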


'''
sum_masks
tensor([   96.,    96.,    82.,    82.,    82.,   108.,   108.,   108.,    86.,
           86.,    86.,   208.,   227.,   227.,   227.,   134.,   134.,    88.,
           28.,    79.,    79.,   231.,   231.,   231.,   189.,   189.,    31.,
           31.,   125.,   125.,   125.,   158.,   158.,   194.,    99.,    99.,
           74.,   159.,    37.,    37.,    37.,    39.,    39.,   275.,    50.,
           31.,    64.,    64.,    64.,    64.,    66.,    66.,    66.,    66.,
           91.,    91.,    91.,    93.,   192.,   192.,   192.,    46.,    46.,
           46.,    39.,    39.,    51.,    51.,    87.,   140.,   181.,   199.,
           50.,    50.,    50.,    50.,    76.,    20.,    88.,    88.,    84.,
           84.,    84.,   236.,   236.,    94.,   211.,   211.,   252.,    85.,
           98.,    56.,    96.,    96.,    60.,    60.,    60.,    53.,    84.,
           84.,    84.,    84.,   258.,   267.,   304.,    90.,   105.,   105.,
          105.,    75.,    75.,    75.,    53.,    53.,    84.,    84.,   132.,
          274.,   274.,   259.,   259.,   296.,   296.,   296.,   272.,   272.,
          272.,   272.,   112.,   117.,    50.,    87.,   143.,   143.,    80.,
           88.,    88.,   273.,   273.,   320.,   320.,   294.,   364.,   313.,
          355.,   302.,   353.,   353.,    67.,    67.,    42.,    32.,    32.,
           61.,    61.,    61.,    61.,    68.,    68.,    68.,    68.,   168.,
          168.,   168.,    28.,    28.,    28.,    67.,    71.,   139.,   282.,
          304.,    94.,   169.,   135.,   135.,   286.,   331.,   100.,   100.,
          100.,    95.,    95.,   172.,   277.,   277.,   277.,   371.,   380.,
           92.,    92.,   160.,   394.,   394.,   395.,   132.,   132.,   157.,
          295.,   282.,   452.,   468.,    66.,    66.,   209.,    73.,    73.,
           73.,   352.,   360.,   333.,    25.,   205.,   229.,   229.,   229.,
          491.,   491.,   488.,   488.,   488.,   449.,   449.,   234.,   255.,
          255.,   255.,   255.,   630.,   514.,   514.,   514.,   481.,   481.,
          481.,   871.,  1029.,   260.,   260.,   260.,   260.,   639.,   514.,
          484.,   168.,   168.,   415.,    81.,  1120.,  1232.,   418.,   418.,
          128.,   141.,   242.,   242.,    91.,    57.,    57.,    80.,    80.,
           80.,   621.,  1248.,  1315.,   199.,   304.,   210.,    78.,    54.,
           54.,    62.,    62.,    62.,   622.,   697.,   697.,   663.,   663.,
          149.,   118.,   108.,   109.,   109.,   202.,   218.,   218.,   275.,
          275.,   357.,   357.,   357.,   361.,   361.,   102.,   111.,   111.,
          448.,   279.,   356.,   347.,   347.,   271.,   293.,   288.,   288.,
          288.,   277.,   277.,   271.,   271.,   131.,   131.,   162.,   162.,
          162.,   132.,   132.,   107.,   362.,   452.,   452.,   571.,   361.,
          360.,   438.,   714.,   404.,   427.,   613.,   395.,   411.,   438.,
          438.,   471.,   529.,   546.,    52.,    52.,    85.,    85.,    85.,
          181.,   181.,   336.,   359.,   183.,   353.,   370.,    98.,    98.,
           98.,   191.,   191.,   268.,   268.,   340.,   340.,   736.,   346.,
          380.,    94.,    94.,    94.,   179.,   179.,   412.,   437.,   437.,
          437.,  1087.,   560.,   398.,   925.,   925.,   802.,   802.,   802.,
          375.,   834.,   847.,   512.,   944.,   508.,    48.,   274.,    82.,
           82.,    82.,   482.,   444.,   491.,   491.,  1281.,   679.,   679.,
          571.,   571.,   571.,  1403.,   583.,   647.,  1429.,   940.,   721.,
          721.,   313.,  1953.,  3322.,  3694.,  3694.,  2245.,  2187.,  1180.,
         3924.,  3924.,  3963.,  1622.,  2566.,  3506.,  1246.,  2082.,  4032.,
         4067.,   474.,   567.,   567.,  1675.,  2513.,  3013.,  1489.,   709.,
          900.,   900.,   769.,  2537.,   689.,  1485.,  2476.,   416.,  1449.,
          706.,  2477.,  3185.,  3221.,   413.,  2756.,  3230.,  3230.,  3156.,
          424.,   465.,  2933.,  2846.,   474.,   474.,   940.,   940.,   851.,
          851.,   851.,   553.,  1572.,  5856.,  3666.,  4373.,  3937.,  2129.,
         4194.,  4586.,  2788.,  2683.,  4081.,  3171.,  3171.,  3894.,  4206.,
         1353.,  1984.,  3575.,  3303.,  2040.,  3688.,  3688.,  7555.,  8147.,
         9637., 10042.,  7735.,  9848., 10357.,  6124., 10311., 10753.,  5137.,
         4384.,  6858.,  4768.,  4397.,  6499., 10237., 10237.,  9333.,  9333.,
         9033.,  9723.,  9955.], device='cuda:0')

'''

'''strides
tensor([ 8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,
         8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,
         8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,
         8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,
         8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,
         8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,
         8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,
         8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,
         8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,
         8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,
         8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,
         8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,
         8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,
         8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,
         8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,
         8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,
         8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,
         8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,
         8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,
         8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,
         8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,
         8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,
         8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,
         8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,
         8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,
         8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,
         8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,
         8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8., 16., 16.,
        16., 16., 16., 16., 16., 16., 16., 16., 16., 16., 16., 16., 16., 16.,
        16., 16., 16., 16., 16., 16., 16., 16., 16., 16., 16., 16., 16., 16.,
        16., 16., 16., 16., 16., 16., 16., 16., 16., 16., 16., 16., 16., 16.,
        16., 16., 16., 16., 16., 16., 16., 16., 16., 16., 16., 16., 16., 16.,
        16., 16., 16., 16., 16., 16., 16., 16., 16., 16., 16., 16., 16., 16.,
        32., 32., 32., 32., 32., 32., 32., 32., 32., 32., 32., 32., 32., 32.,
        32., 32., 32., 32., 32., 32., 32., 32., 32., 32., 32., 32., 32., 32.,
        32., 32., 32., 32., 32., 32., 32., 32., 32., 32., 32., 32., 32., 32.,
        32., 32., 32.], device='cuda:0')

'''
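
The `strides` dump above comes from expanding one stride per FPN level into one stride per grid cell. A minimal NumPy sketch of that expansion follows; the name `per_cell_strides` is hypothetical, and the grid sizes and strides are assumed values chosen to match the dumps in this post.

```python
import numpy as np


def per_cell_strides(seg_num_grids, strides):
    """Return (cumulative cell counts, stride of every flattened grid cell)."""
    bounds = np.cumsum(np.square(seg_num_grids))  # end offset of each level
    out = np.ones(bounds[-1])
    start = 0
    for end, stride in zip(bounds, strides):
        out[start:end] *= stride                  # cells of one level share its stride
        start = end
    return bounds, out


bounds, cell_strides = per_cell_strides([40, 36, 24, 16, 12], [8, 8, 16, 32, 32])

# Filtering as in get_seg_single: a mask survives only if its area exceeds its cell's stride.
cell_index = np.array([0, 2900, 3800])            # made-up flattened cell indices
sum_masks = np.array([96.0, 10.0, 40.0])          # made-up mask areas
keep = sum_masks > cell_strides[cell_index]
```

The threshold grows with the level, so coarse levels, which are responsible for large objects, discard tiny masks more aggressively.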

3. SOLO/mmdet/core/post_processing/matrix_nms.py

import torch
import pdb

from mmdet.ops.nms import nms_wrapper  # used by multiclass_nms below

def matrix_nms(seg_masks, cate_labels, cate_scores, kernel='gaussian', sigma=2.0, sum_masks=None):
    """Matrix NMS for multi-class masks.

    Args:
        seg_masks (Tensor): shape (n, h, w) bool
        cate_labels (Tensor): shape (n), mask labels in descending order
        cate_scores (Tensor): shape (n), mask scores in descending order
        kernel (str): 'linear' or 'gaussian'
        sigma (float): std in gaussian method
        sum_masks (Tensor): The sum of seg_masks

    Returns:
        Tensor: cate_scores_update, tensors of shape (n)
    """
    n_samples = len(cate_labels) # at most 500 (cfg.nms_pre)
    if n_samples == 0:
        return []
    if sum_masks is None:
        sum_masks = seg_masks.sum((1, 2)).float()
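    seg_masks = seg_masks.reshape(n_samples, -1).float() # [500, 60800]: each instance mask flattened into one row
    # inter. The matrix product multiplies every mask element-wise (values are 1 or 0) with every other mask and sums,
    # so entry (i, j) is the overlap area of mask i and mask j. Two cases to note:
    # (1) Masks at different positions have zero intersection even if they share a class, so class alone is not enough;
    #     same-class masks at the same position are what NMS has to sort out.
    # (2) Instances of different size or class can still overlap and get a non-zero intersection;
    #     cross-class overlaps are zeroed later through label_matrix.
    inter_matrix = torch.mm(seg_masks, seg_masks.transpose(1, 0)) # [500, 60800] @ [60800, 500] = [500, 500]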
    # union.
    sum_masks_x = sum_masks.expand(n_samples, n_samples) # [500, 500]
    # iou.
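    # Union = area_i + area_j - inter; keep only the upper triangle, since the matrix is symmetric.
    iou_matrix = (inter_matrix / (sum_masks_x + sum_masks_x.transpose(1, 0) - inter_matrix)).triu(diagonal=1)
    # label_specific matrix.
    cate_labels_x = cate_labels.expand(n_samples, n_samples) # [500, 500]
    # Row i holds 1 where a mask has the same class as mask i. triu(diagonal=1) further restricts this to masks
    # scored lower than mask i (inputs are sorted by score, descending).
    # Same-class masks at different positions already have IoU 0; this matrix zeroes the IoU of overlapping masks with different labels.
    label_matrix = (cate_labels_x == cate_labels_x.transpose(1, 0)).float().triu(diagonal=1) # [500, 500]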
    
    # IoU decay 
    
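    # iou_matrix * label_matrix keeps only IoUs between masks of the same label, each paired with a higher-scoring mask.
    # Overlapping instances (e.g. one large, one small) have non-zero IoU and must be penalized,
    # while IoUs across different labels are removed, because NMS only compares scores within one class.
    # Result: row i holds the IoU of mask i with every lower-scoring mask of the same class.
    decay_iou = iou_matrix * label_matrix
    '''
    (Pdb) decay_iou = (iou_matrix * label_matrix), upper triangular.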
    tensor([[0.0000, 0.8036, 0.5017,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.4816,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        ...,
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0127],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]],
       device='cuda:0')

    '''


    # IoU compensation
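    compensate_iou, _ = (iou_matrix * label_matrix).max(0)  # column-wise max as in Fast NMS (other-class entries are already 0): entry i is the largest IoU between mask i and any higher-scoring mask of the same class
    # Analysis, using the decay_iou dump above:
    # the first three columns are predictions of the same object as the first mask, sorted by score,
    # so column 0 has max 0 (nothing scores higher than the top mask).
    # In column 2 the max is 0.5017, the IoU with mask 0, not the 0.4816 with mask 1
    # (with a threshold of 0.5, 0.4816 alone would not suppress it).
    # Taking the max is how Fast NMS removes as many boxes as possible.
    '''
    compensate_iou (after the expand + transpose below)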
    tensor([[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.8036, 0.8036, 0.8036,  ..., 0.8036, 0.8036, 0.8036],
        [0.5017, 0.5017, 0.5017,  ..., 0.5017, 0.5017, 0.5017],
        ...,
        [0.0021, 0.0021, 0.0021,  ..., 0.0021, 0.0021, 0.0021],
        [0.0054, 0.0054, 0.0054,  ..., 0.0054, 0.0054, 0.0054],
        [0.0193, 0.0193, 0.0193,  ..., 0.0193, 0.0193, 0.0193]],

    '''

    compensate_iou = compensate_iou.expand(n_samples, n_samples).transpose(1, 0)

    # matrix nms
    if kernel == 'gaussian':
        decay_matrix = torch.exp(-1 * sigma * (decay_iou ** 2))
        compensate_matrix = torch.exp(-1 * sigma * (compensate_iou ** 2))
        decay_coefficient, _ = (decay_matrix / compensate_matrix).min(0)
        # Why min(0), the column-wise minimum:
        # after the exponential, IoU 0 (the top-scoring mask, or no overlap) becomes 1, i.e. no decay,
        # while any mask overlapped by a higher-scoring one gets a factor below 1.
        # The column-wise minimum is the strongest suppression applied to each mask (decay_iou only covers same-label masks).
        '''
        (Pdb) decay_matrix / compensate_matrix
        tensor([[1.0000, 0.2748, 0.6044,  ..., 1.0000, 1.0000, 1.0000],
        [3.6388, 3.6388, 2.2883,  ..., 3.6388, 3.6388, 3.6388],
        [1.6545, 1.6545, 1.6545,  ..., 1.6545, 1.6545, 1.6545],
        ...,
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 0.9997],
        [1.0001, 1.0001, 1.0001,  ..., 1.0001, 1.0001, 1.0001],
        [1.0007, 1.0007, 1.0007,  ..., 1.0007, 1.0007, 1.0007]],
       device='cuda:0')

        '''
    elif kernel == 'linear':
        decay_matrix = (1-decay_iou)/(1-compensate_iou)
        decay_coefficient, _ = decay_matrix.min(0)
    else:
        raise NotImplementedError

    # update the score.
    cate_scores_update = cate_scores * decay_coefficient # Soft-NMS style: lower the scores of same-label masks that overlap a higher-scoring one
    return cate_scores_update


def multiclass_nms(multi_bboxes,
                   multi_scores,
                   score_thr,
                   nms_cfg,
                   max_num=-1,
                   score_factors=None):
    """NMS for multi-class bboxes.

    Args:
        multi_bboxes (Tensor): shape (n, #class*4) or (n, 4)
        multi_scores (Tensor): shape (n, #class), where the 0th column
            contains scores of the background class, but this will be ignored.
        score_thr (float): bbox threshold, bboxes with scores lower than it
            will not be considered.
        nms_cfg (dict): NMS config; 'type' selects the op (default 'nms'), the remaining keys are passed to it
        max_num (int): if there are more than max_num bboxes after NMS,
            only top max_num will be kept.
        score_factors (Tensor): The factors multiplied to scores before
            applying NMS

    Returns:
        tuple: (bboxes, labels), tensors of shape (k, 5) and (k, 1). Labels
            are 0-based.
    """
    num_classes = multi_scores.shape[1]
    bboxes, labels = [], []
    nms_cfg_ = nms_cfg.copy()
    nms_type = nms_cfg_.pop('type', 'nms')
    nms_op = getattr(nms_wrapper, nms_type)
    for i in range(1, num_classes):
        cls_inds = multi_scores[:, i] > score_thr
        if not cls_inds.any():
            continue
        # get bboxes and scores of this class
        if multi_bboxes.shape[1] == 4:
            _bboxes = multi_bboxes[cls_inds, :]
        else:
            _bboxes = multi_bboxes[cls_inds, i * 4:(i + 1) * 4]
        _scores = multi_scores[cls_inds, i]
        if score_factors is not None:
            _scores *= score_factors[cls_inds]
        cls_dets = torch.cat([_bboxes, _scores[:, None]], dim=1)
        cls_dets, _ = nms_op(cls_dets, **nms_cfg_)
        cls_labels = multi_bboxes.new_full((cls_dets.shape[0], ),
                                           i - 1,
                                           dtype=torch.long)
        bboxes.append(cls_dets)
        labels.append(cls_labels)
    if bboxes:
        bboxes = torch.cat(bboxes)
        labels = torch.cat(labels)
        if bboxes.shape[0] > max_num:
            _, inds = bboxes[:, -1].sort(descending=True)
            inds = inds[:max_num]
            bboxes = bboxes[inds]
            labels = labels[inds]
    else:
        bboxes = multi_bboxes.new_zeros((0, 5))
        labels = multi_bboxes.new_zeros((0, ), dtype=torch.long)

    return bboxes, labels
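
To see what matrix_nms does on numbers small enough to verify by hand, here is a NumPy re-implementation of the same steps. This is a sketch for study, not the repository code: the name `matrix_nms_np` and the toy masks are made up, inputs are assumed to be sorted by score in descending order, and every mask is assumed to be non-empty.

```python
import numpy as np


def matrix_nms_np(seg_masks, cate_labels, cate_scores, kernel="gaussian", sigma=2.0):
    """Decay scores of same-class masks that overlap a higher-scoring mask."""
    n = len(cate_labels)
    flat = seg_masks.reshape(n, -1).astype(np.float64)
    inter = flat @ flat.T                          # pairwise overlap areas
    areas = flat.sum(axis=1)
    union = areas[None, :] + areas[:, None] - inter
    iou = np.triu(inter / union, k=1)              # keep (higher score, lower score) pairs
    same = np.triu(cate_labels[None, :] == cate_labels[:, None], k=1)
    decay_iou = iou * same                         # IoU with lower-scoring same-class masks
    compensate = decay_iou.max(axis=0)[:, None]    # per mask: worst overlap from above, as a column
    if kernel == "gaussian":
        decay = np.exp(-sigma * decay_iou ** 2) / np.exp(-sigma * compensate ** 2)
    elif kernel == "linear":
        decay = (1 - decay_iou) / (1 - compensate)
    else:
        raise ValueError(f"unknown kernel: {kernel}")
    return cate_scores * decay.min(axis=0)         # strongest decay per mask


masks = np.array([[[1, 1, 1, 1]],
                  [[1, 1, 1, 0]],
                  [[0, 0, 0, 1]]], dtype=bool)     # three 1 x 4 masks
scores = np.array([0.9, 0.8, 0.7])
same_class = matrix_nms_np(masks, np.array([0, 0, 0]), scores)
mixed_class = matrix_nms_np(masks, np.array([0, 1, 0]), scores)
```

With one shared class, mask 1 (IoU 0.75 with mask 0) is decayed by exp(-2 * 0.75^2) and mask 2 (IoU 0.25) by exp(-2 * 0.25^2). Giving mask 1 a different label leaves its score untouched, which is the role of label_matrix.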


'''
(Pdb) cate_scores * decay_coefficient
tensor([0.7593, 0.1010, 0.1081, 0.5393, 0.0926, 0.0885, 0.4901, 0.4664, 0.4540,
        0.0755, 0.4385, 0.3944, 0.0726, 0.0748, 0.0986, 0.0551, 0.0835, 0.0694,
        0.0822, 0.3194, 0.0600, 0.3141, 0.0594, 0.3115, 0.3114, 0.3086, 0.0771,
        0.0792, 0.0597, 0.0512, 0.0569, 0.3018, 0.0461, 0.0550, 0.0537, 0.0662,
        0.0580, 0.0644, 0.0503, 0.2881, 0.2839, 0.2830, 0.0561, 0.1310, 0.2692,
        0.0652, 0.0694, 0.0505, 0.0410, 0.0464, 0.0665, 0.0409, 0.2440, 0.0407,
        0.0464, 0.0410, 0.2291, 0.0447, 0.1051, 0.2260, 0.2241, 0.2236, 0.2233,
        0.0529, 0.1370, 0.2200, 0.0540, 0.0532, 0.0473, 0.0530, 0.2168, 0.2134,
        0.0678, 0.0478, 0.0384, 0.0407, 0.1161, 0.0320, 0.0619, 0.2025, 0.0388,
        0.0331, 0.0493, 0.0866, 0.0849, 0.0413, 0.0593, 0.0593, 0.0388, 0.0389,
        0.0738, 0.1875, 0.0674, 0.1145, 0.0588, 0.0806, 0.1797, 0.0382, 0.1776,
        0.1751, 0.0489, 0.0511, 0.1743, 0.0815, 0.1741, 0.0582, 0.0925, 0.0317,
        0.0318, 0.1661, 0.1645, 0.0297, 0.1634, 0.1629, 0.0446, 0.0389, 0.0318,
        0.1611, 0.1445, 0.0564, 0.0337, 0.1564, 0.1563, 0.0331, 0.1556, 0.0605,
        0.1533, 0.1526, 0.0254, 0.1477, 0.0477, 0.1507, 0.0379, 0.1504, 0.0312,
        0.0492, 0.1478, 0.0248, 0.1466, 0.0412, 0.0278, 0.0301, 0.0973, 0.0297,
        0.1449, 0.0219, 0.0616, 0.0348, 0.0274, 0.0721, 0.0425, 0.1388, 0.0409,
        0.0231, 0.0848, 0.1382, 0.0488, 0.0265, 0.0326, 0.1361, 0.0220, 0.0898,
        0.0259, 0.0259, 0.0268, 0.0563, 0.1345, 0.1344, 0.0220, 0.0319, 0.0512,
        0.1330, 0.0265, 0.0458, 0.0277, 0.0257, 0.0245, 0.0280, 0.1300, 0.0402,
        0.0307, 0.0460, 0.0315, 0.0277, 0.0173, 0.0657, 0.0251, 0.0230, 0.1267,
        0.1263, 0.0789, 0.0680, 0.0559, 0.0196, 0.0247, 0.0987, 0.1243, 0.0254,
        0.1033, 0.1235, 0.1234, 0.1233, 0.1232, 0.0211, 0.0351, 0.1230, 0.1225,
        0.0211, 0.1211, 0.0752, 0.1207, 0.0759, 0.1200, 0.0432, 0.1198, 0.1191,
        0.0215, 0.0458, 0.1184, 0.0221, 0.1175, 0.0706, 0.0312, 0.1170, 0.1169,
        0.0257, 0.1167, 0.1166, 0.0193, 0.0641, 0.1151, 0.0692, 0.0873, 0.0289,
        0.0330, 0.1137, 0.0447, 0.0257, 0.0675, 0.1123, 0.0252, 0.0519, 0.0219,
        0.0188, 0.0327, 0.1117, 0.1117, 0.0921, 0.0403, 0.0270, 0.0230, 0.0641,
        0.0273, 0.1099, 0.0201, 0.0322, 0.1091, 0.1090, 0.0229, 0.1089, 0.0187,
        0.0216, 0.0307, 0.0513, 0.1080, 0.0260, 0.0855, 0.0441, 0.0188, 0.0972,
        0.1068, 0.0417, 0.0206, 0.0394, 0.0214, 0.0427, 0.0170, 0.0311, 0.0481,
        0.0196, 0.1049, 0.1051, 0.1049, 0.0295, 0.0347, 0.0226, 0.0667, 0.0199,
        0.1041, 0.0246, 0.1038, 0.0241, 0.1033, 0.1028, 0.0212, 0.1021, 0.1022,
        0.1019, 0.0413, 0.0388, 0.0343, 0.0967, 0.0925, 0.0654, 0.1009, 0.0301,
        0.1007, 0.0986, 0.0474, 0.0583, 0.0990, 0.0273, 0.0989, 0.0737, 0.0689,
        0.0187, 0.0231, 0.0982, 0.0522, 0.0132, 0.0973, 0.0387, 0.0971, 0.0937,
        0.0968, 0.0189, 0.0218, 0.0933, 0.0219, 0.0199, 0.0957, 0.0475, 0.0266,
        0.0950, 0.0389, 0.0454, 0.0262, 0.0641, 0.0870, 0.0212, 0.0187, 0.0834,
        0.0931, 0.0431, 0.0929, 0.0929, 0.0703, 0.0193, 0.0459, 0.0211, 0.0926,
        0.0925, 0.0923, 0.0371, 0.0420, 0.0224, 0.0196, 0.0919, 0.0336, 0.0917,
        0.0894, 0.0569, 0.0832, 0.0328, 0.0249, 0.0263, 0.0181, 0.0410, 0.0906,
        0.0159, 0.0402, 0.0183, 0.0168, 0.0171, 0.0204, 0.0160, 0.0897, 0.0323,
        0.0173, 0.0240, 0.0708, 0.0894, 0.0892, 0.0892, 0.0283, 0.0186, 0.0172,
        0.0882, 0.0160, 0.0179, 0.0522, 0.0511, 0.0177, 0.0877, 0.0418, 0.0155,
        0.0606, 0.0868, 0.0867, 0.0485, 0.0258, 0.0143, 0.0359, 0.0804, 0.0457,
        0.0835, 0.0678, 0.0177, 0.0193, 0.0250, 0.0477, 0.0289, 0.0247, 0.0839,
        0.0836, 0.0680, 0.0423, 0.0147, 0.0649, 0.0824, 0.0178, 0.0299, 0.0219,
        0.0161, 0.0152, 0.0422, 0.0242, 0.0266, 0.0808, 0.0453, 0.0557, 0.0807,
        0.0222, 0.0154, 0.0217, 0.0134, 0.0600, 0.0447, 0.0231, 0.0162, 0.0759,
        0.0292, 0.0229, 0.0790, 0.0380, 0.0216, 0.0505, 0.0786, 0.0556, 0.0281,
        0.0469, 0.0556, 0.0233, 0.0726, 0.0175, 0.0303, 0.0774, 0.0770, 0.0462,
        0.0285, 0.0731, 0.0333, 0.0712, 0.0232, 0.0318, 0.0756, 0.0361, 0.0382,
        0.0751, 0.0627, 0.0749, 0.0565, 0.0470, 0.0228, 0.0193, 0.0294, 0.0442,
        0.0434, 0.0538, 0.0726, 0.0562, 0.0260, 0.0227, 0.0721, 0.0325, 0.0717,
        0.0604, 0.0696, 0.0700, 0.0588, 0.0234, 0.0229, 0.0195, 0.0683, 0.0350,
        0.0359, 0.0378, 0.0688, 0.0407, 0.0671], device='cuda:0')
(Pdb) cate_scores
tensor([0.7593, 0.6425, 0.6012, 0.5393, 0.5195, 0.4914, 0.4901, 0.4664, 0.4540,
        0.4468, 0.4385, 0.3944, 0.3913, 0.3701, 0.3569, 0.3558, 0.3473, 0.3448,
        0.3417, 0.3194, 0.3147, 0.3141, 0.3134, 0.3115, 0.3114, 0.3086, 0.3071,
        0.3065, 0.3050, 0.3035, 0.3025, 0.3018, 0.3017, 0.3003, 0.2977, 0.2969,
        0.2946, 0.2934, 0.2930, 0.2881, 0.2839, 0.2830, 0.2733, 0.2713, 0.2692,
        0.2640, 0.2634, 0.2615, 0.2544, 0.2502, 0.2472, 0.2443, 0.2440, 0.2430,
        0.2317, 0.2306, 0.2291, 0.2265, 0.2262, 0.2260, 0.2241, 0.2236, 0.2233,
        0.2222, 0.2207, 0.2202, 0.2192, 0.2191, 0.2173, 0.2169, 0.2168, 0.2134,
        0.2130, 0.2112, 0.2105, 0.2093, 0.2073, 0.2070, 0.2031, 0.2025, 0.2007,
        0.1998, 0.1989, 0.1978, 0.1951, 0.1939, 0.1920, 0.1917, 0.1895, 0.1893,
        0.1876, 0.1875, 0.1847, 0.1839, 0.1827, 0.1817, 0.1797, 0.1786, 0.1776,
        0.1751, 0.1748, 0.1746, 0.1743, 0.1743, 0.1741, 0.1723, 0.1704, 0.1701,
        0.1675, 0.1661, 0.1645, 0.1642, 0.1634, 0.1629, 0.1625, 0.1623, 0.1618,
        0.1611, 0.1607, 0.1599, 0.1583, 0.1564, 0.1563, 0.1557, 0.1556, 0.1541,
        0.1533, 0.1526, 0.1518, 0.1514, 0.1512, 0.1507, 0.1505, 0.1504, 0.1503,
        0.1499, 0.1478, 0.1476, 0.1466, 0.1461, 0.1458, 0.1453, 0.1452, 0.1452,
        0.1449, 0.1438, 0.1419, 0.1411, 0.1405, 0.1392, 0.1391, 0.1388, 0.1386,
        0.1385, 0.1383, 0.1382, 0.1379, 0.1367, 0.1363, 0.1361, 0.1357, 0.1355,
        0.1352, 0.1352, 0.1349, 0.1348, 0.1345, 0.1344, 0.1335, 0.1335, 0.1333,
        0.1330, 0.1326, 0.1323, 0.1317, 0.1313, 0.1312, 0.1309, 0.1301, 0.1298,
        0.1293, 0.1283, 0.1282, 0.1282, 0.1280, 0.1280, 0.1277, 0.1268, 0.1267,
        0.1263, 0.1261, 0.1259, 0.1259, 0.1255, 0.1253, 0.1245, 0.1243, 0.1238,
        0.1237, 0.1235, 0.1234, 0.1233, 0.1233, 0.1231, 0.1230, 0.1230, 0.1226,
        0.1215, 0.1211, 0.1211, 0.1207, 0.1201, 0.1200, 0.1198, 0.1198, 0.1191,
        0.1187, 0.1186, 0.1184, 0.1183, 0.1175, 0.1173, 0.1172, 0.1170, 0.1169,
        0.1168, 0.1167, 0.1166, 0.1164, 0.1153, 0.1151, 0.1150, 0.1145, 0.1141,
        0.1140, 0.1137, 0.1133, 0.1131, 0.1128, 0.1123, 0.1123, 0.1123, 0.1120,
        0.1119, 0.1117, 0.1117, 0.1117, 0.1112, 0.1111, 0.1111, 0.1107, 0.1104,
        0.1103, 0.1099, 0.1097, 0.1093, 0.1091, 0.1090, 0.1089, 0.1089, 0.1086,
        0.1082, 0.1082, 0.1082, 0.1080, 0.1080, 0.1076, 0.1074, 0.1074, 0.1071,
        0.1068, 0.1068, 0.1066, 0.1065, 0.1063, 0.1062, 0.1060, 0.1056, 0.1056,
        0.1054, 0.1053, 0.1051, 0.1049, 0.1049, 0.1044, 0.1044, 0.1041, 0.1041,
        0.1041, 0.1038, 0.1038, 0.1034, 0.1033, 0.1028, 0.1028, 0.1023, 0.1022,
        0.1022, 0.1021, 0.1019, 0.1017, 0.1015, 0.1015, 0.1011, 0.1009, 0.1007,
        0.1007, 0.0996, 0.0996, 0.0993, 0.0990, 0.0990, 0.0989, 0.0988, 0.0988,
        0.0987, 0.0983, 0.0982, 0.0978, 0.0978, 0.0973, 0.0972, 0.0971, 0.0969,
        0.0968, 0.0965, 0.0963, 0.0958, 0.0958, 0.0958, 0.0957, 0.0957, 0.0955,
        0.0950, 0.0947, 0.0946, 0.0942, 0.0940, 0.0940, 0.0938, 0.0935, 0.0933,
        0.0931, 0.0930, 0.0929, 0.0929, 0.0928, 0.0928, 0.0928, 0.0927, 0.0926,
        0.0925, 0.0923, 0.0923, 0.0923, 0.0922, 0.0919, 0.0919, 0.0919, 0.0917,
        0.0916, 0.0913, 0.0912, 0.0911, 0.0908, 0.0907, 0.0907, 0.0906, 0.0906,
        0.0905, 0.0905, 0.0904, 0.0902, 0.0902, 0.0901, 0.0901, 0.0897, 0.0896,
        0.0895, 0.0895, 0.0894, 0.0894, 0.0892, 0.0892, 0.0889, 0.0884, 0.0883,
        0.0882, 0.0881, 0.0879, 0.0878, 0.0878, 0.0877, 0.0877, 0.0875, 0.0873,
        0.0870, 0.0868, 0.0867, 0.0867, 0.0865, 0.0862, 0.0859, 0.0856, 0.0856,
        0.0856, 0.0856, 0.0852, 0.0852, 0.0852, 0.0849, 0.0847, 0.0843, 0.0839,
        0.0836, 0.0834, 0.0833, 0.0831, 0.0830, 0.0824, 0.0824, 0.0822, 0.0818,
        0.0818, 0.0815, 0.0815, 0.0813, 0.0811, 0.0808, 0.0808, 0.0807, 0.0807,
        0.0807, 0.0806, 0.0802, 0.0800, 0.0800, 0.0798, 0.0796, 0.0793, 0.0792,
        0.0791, 0.0790, 0.0790, 0.0790, 0.0789, 0.0787, 0.0786, 0.0784, 0.0783,
        0.0779, 0.0778, 0.0778, 0.0777, 0.0775, 0.0774, 0.0774, 0.0770, 0.0768,
        0.0767, 0.0766, 0.0763, 0.0760, 0.0759, 0.0756, 0.0756, 0.0755, 0.0754,
        0.0752, 0.0751, 0.0749, 0.0749, 0.0740, 0.0738, 0.0736, 0.0736, 0.0733,
        0.0729, 0.0729, 0.0726, 0.0723, 0.0722, 0.0721, 0.0721, 0.0721, 0.0717,
        0.0715, 0.0714, 0.0714, 0.0713, 0.0713, 0.0712, 0.0707, 0.0704, 0.0694,
        0.0690, 0.0689, 0.0688, 0.0678, 0.0671], device='cuda:0')

'''


'''
(Pdb) ans[0].sum()
tensor(26, device='cuda:0')
(Pdb) ans[1].sum()
tensor(26, device='cuda:0')
(Pdb) ans[2].sum()
tensor(26, device='cuda:0')
(Pdb) label_matrix[0].sum()
tensor(25., device='cuda:0')
(Pdb) label_matrix[1].sum()
tensor(24., device='cuda:0')
(Pdb) label_matrix[2].sum()
tensor(23., device='cuda:0')

'''


'''

(Pdb) iou_matrix
tensor([[0.0000, 0.9618, 0.9262,  ..., 0.5556, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.9157,  ..., 0.5608, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.5495, 0.0000, 0.0000],
        ...,
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0082],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]],
       device='cuda:0')
(Pdb) label_matrix
tensor([[0., 1., 1.,  ..., 0., 0., 0.],
        [0., 0., 1.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:0')
(Pdb) iou_matrix * label_matrix
tensor([[0.0000, 0.9618, 0.9262,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.9157,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        ...,
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]],
       device='cuda:0')

'''




'''
(Pdb) decay_iou
tensor([[0.0000, 0.9618, 0.9262,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.9157,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        ...,
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]],
       device='cuda:0')
(Pdb) compensate_iou
tensor([[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.9618, 0.9618, 0.9618,  ..., 0.9618, 0.9618, 0.9618],
        [0.9262, 0.9262, 0.9262,  ..., 0.9262, 0.9262, 0.9262],
        ...,
        [0.1814, 0.1814, 0.1814,  ..., 0.1814, 0.1814, 0.1814],
        [0.5750, 0.5750, 0.5750,  ..., 0.5750, 0.5750, 0.5750],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]],
       device='cuda:0')

'''

'''
(Pdb) decay_matrix 
tensor([[1.0000, 0.1572, 0.1798,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 0.1869,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        ...,
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')
(Pdb) compensate_matrix 
tensor([[1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.1572, 0.1572, 0.1572,  ..., 0.1572, 0.1572, 0.1572],
        [0.1798, 0.1798, 0.1798,  ..., 0.1798, 0.1798, 0.1798],
        ...,
        [0.9363, 0.9363, 0.9363,  ..., 0.9363, 0.9363, 0.9363],
        [0.5162, 0.5162, 0.5162,  ..., 0.5162, 0.5162, 0.5162],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')
(Pdb) decay_coefficient
tensor([1.0000, 0.1572, 0.1798, 1.0000, 0.1783, 0.1802, 1.0000, 1.0000, 1.0000,
        0.1689, 1.0000, 1.0000, 0.1855, 0.2022, 0.2764, 0.1547, 0.2404, 0.2013,
        0.2406, 1.0000, 0.1905, 1.0000, 0.1897, 1.0000, 1.0000, 1.0000, 0.2510,
        0.2585, 0.1958, 0.1687, 0.1882, 1.0000, 0.1527, 0.1832, 0.1805, 0.2230,
        0.1968, 0.2195, 0.1718, 1.0000, 1.0000, 1.0000, 0.2053, 0.4830, 1.0000,
        0.2468, 0.2634, 0.1930, 0.1613, 0.1854, 0.2688, 0.1672, 1.0000, 0.1676,
        0.2001, 0.1778, 0.9999, 0.1972, 0.4646, 0.9999, 1.0000, 1.0000, 1.0000,
        0.2382, 0.6206, 0.9990, 0.2464, 0.2426, 0.2175, 0.2442, 1.0000, 1.0000,
        0.3181, 0.2265, 0.1825, 0.1945, 0.5600, 0.1545, 0.3049, 1.0000, 0.1934,
        0.1655, 0.2479, 0.4379, 0.4354, 0.2128, 0.3090, 0.3096, 0.2048, 0.2055,
        0.3932, 0.9999, 0.3649, 0.6227, 0.3221, 0.4436, 1.0000, 0.2139, 1.0000,
        1.0000, 0.2798, 0.2927, 1.0000, 0.4676, 1.0000, 0.3377, 0.5426, 0.1867,
        0.1898, 1.0000, 1.0000, 0.1810, 1.0000, 1.0000, 0.2745, 0.2394, 0.1964,
        1.0000, 0.8990, 0.3527, 0.2131, 0.9999, 1.0000, 0.2128, 1.0000, 0.3925,
        1.0000, 1.0000, 0.1672, 0.9756, 0.3158, 1.0000, 0.2516, 1.0000, 0.2073,
        0.3283, 1.0000, 0.1679, 1.0000, 0.2823, 0.1908, 0.2075, 0.6698, 0.2049,
        1.0000, 0.1524, 0.4344, 0.2468, 0.1950, 0.5180, 0.3053, 1.0000, 0.2947,
        0.1668, 0.6135, 1.0000, 0.3543, 0.1941, 0.2391, 1.0000, 0.1622, 0.6628,
        0.1915, 0.1915, 0.1984, 0.4174, 1.0000, 1.0000, 0.1648, 0.2391, 0.3841,
        1.0000, 0.2000, 0.3462, 0.2100, 0.1955, 0.1864, 0.2139, 0.9999, 0.3096,
        0.2372, 0.3589, 0.2454, 0.2163, 0.1353, 0.5128, 0.1966, 0.1813, 1.0000,
        1.0000, 0.6258, 0.5399, 0.4436, 0.1565, 0.1968, 0.7929, 1.0000, 0.2052,
        0.8349, 1.0000, 1.0000, 1.0000, 0.9999, 0.1713, 0.2855, 1.0000, 0.9990,
        0.1734, 1.0000, 0.6206, 1.0000, 0.6314, 1.0000, 0.3608, 1.0000, 1.0000,
        0.1814, 0.3864, 0.9998, 0.1868, 1.0000, 0.6018, 0.2664, 1.0000, 1.0000,
        0.2204, 1.0000, 1.0000, 0.1655, 0.5560, 1.0000, 0.6018, 0.7625, 0.2531,
        0.2891, 1.0000, 0.3947, 0.2269, 0.5983, 1.0000, 0.2240, 0.4622, 0.1954,
        0.1679, 0.2926, 0.9999, 1.0000, 0.8288, 0.3631, 0.2429, 0.2077, 0.5807,
        0.2477, 1.0000, 0.1835, 0.2947, 1.0000, 1.0000, 0.2103, 1.0000, 0.1724,
        0.2000, 0.2840, 0.4740, 1.0000, 0.2410, 0.7940, 0.4111, 0.1751, 0.9077,
        1.0000, 0.3905, 0.1929, 0.3705, 0.2016, 0.4020, 0.1601, 0.2947, 0.4558,
        0.1863, 0.9965, 1.0000, 1.0000, 0.2813, 0.3324, 0.2163, 0.6408, 0.1911,
        1.0000, 0.2366, 1.0000, 0.2331, 1.0000, 1.0000, 0.2059, 0.9982, 1.0000,
        0.9972, 0.4041, 0.3809, 0.3376, 0.9527, 0.9110, 0.6471, 1.0000, 0.2990,
        1.0000, 0.9895, 0.4757, 0.5870, 1.0000, 0.2757, 1.0000, 0.7463, 0.6971,
        0.1898, 0.2349, 1.0000, 0.5335, 0.1353, 0.9996, 0.3981, 1.0000, 0.9676,
        1.0000, 0.1954, 0.2259, 0.9738, 0.2285, 0.2074, 1.0000, 0.4963, 0.2780,
        1.0000, 0.4111, 0.4801, 0.2780, 0.6819, 0.9255, 0.2259, 0.2002, 0.8939,
        1.0000, 0.4634, 1.0000, 1.0000, 0.7577, 0.2078, 0.4951, 0.2280, 1.0000,
        1.0000, 1.0000, 0.4014, 0.4548, 0.2429, 0.2128, 1.0000, 0.3658, 1.0000,
        0.9756, 0.6232, 0.9124, 0.3601, 0.2744, 0.2895, 0.2001, 0.4525, 1.0000,
        0.1758, 0.4439, 0.2022, 0.1865, 0.1894, 0.2269, 0.1781, 1.0000, 0.3609,
        0.1929, 0.2681, 0.7913, 0.9999, 1.0000, 1.0000, 0.3181, 0.2103, 0.1950,
        1.0000, 0.1819, 0.2036, 0.5941, 0.5819, 0.2022, 1.0000, 0.4777, 0.1774,
        0.6963, 1.0000, 1.0000, 0.5600, 0.2989, 0.1664, 0.4174, 0.9394, 0.5335,
        0.9756, 0.7929, 0.2073, 0.2270, 0.2930, 0.5621, 0.3410, 0.2926, 1.0000,
        1.0000, 0.8152, 0.5078, 0.1772, 0.7817, 1.0000, 0.2154, 0.3641, 0.2681,
        0.1963, 0.1870, 0.5180, 0.2982, 0.3277, 0.9999, 0.5600, 0.6903, 1.0000,
        0.2754, 0.1911, 0.2704, 0.1668, 0.7497, 0.5600, 0.2895, 0.2049, 0.9588,
        0.3695, 0.2894, 1.0000, 0.4810, 0.2742, 0.6411, 1.0000, 0.7090, 0.3589,
        0.6018, 0.7151, 0.3002, 0.9344, 0.2259, 0.3921, 1.0000, 1.0000, 0.6018,
        0.3710, 0.9543, 0.4373, 0.9361, 0.3053, 0.4208, 1.0000, 0.4777, 0.5065,
        0.9999, 0.8339, 1.0000, 0.7549, 0.6350, 0.3088, 0.2617, 0.3994, 0.6034,
        0.5947, 0.7377, 0.9990, 0.7771, 0.3594, 0.3155, 1.0000, 0.4505, 1.0000,
        0.8445, 0.9756, 0.9810, 0.8240, 0.3274, 0.3215, 0.2753, 0.9701, 0.5041,
        0.5205, 0.5485, 1.0000, 0.5994, 1.0000], device='cuda:0')

'''
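
The dumped tensors fit together as follows. With `kernel='gaussian'` and `sigma=2.0`, each entry of `decay_matrix` is `exp(-sigma * iou^2)`, for example exp(-2 × 0.9618²) ≈ 0.1572. `compensate_matrix` applies the same kernel to the largest IoU that each suppressing mask itself has with a higher-scoring mask, and `decay_coefficient` is the column-wise minimum of their ratio. The block below is a minimal pure-Python sketch of that computation on three same-class masks, using the three IoU values and the three leading `cate_scores` visible in the dump. The function name `matrix_nms_decay` and the list-based layout are illustrative assumptions, not the repository's tensor code.

```python
import math


def matrix_nms_decay(iou, labels, sigma=2.0):
    """Gaussian Matrix NMS decay factors for masks sorted by descending score.

    iou[i][j] is the mask IoU between candidates i and j; only entries with
    i < j are used. Illustrative helper, not the repository API.
    """
    n = len(labels)
    # decay_iou: IoU with higher-scoring candidates of the same class.
    decay_iou = [[iou[i][j] if i < j and labels[i] == labels[j] else 0.0
                  for j in range(n)] for i in range(n)]
    # compensate_iou[i]: the largest IoU that candidate i has with any
    # higher-scoring candidate (column-wise max of decay_iou).
    compensate_iou = [max(decay_iou[k][i] for k in range(n)) for i in range(n)]

    coeff = []
    for j in range(n):
        # Ratio of "how much i suppresses j" to "how much i was suppressed".
        ratios = [
            math.exp(-sigma * decay_iou[i][j] ** 2)
            / math.exp(-sigma * compensate_iou[i] ** 2)
            for i in range(n)
        ]
        coeff.append(min(ratios))
    return coeff


# Upper-triangular IoU values taken from the iou_matrix dump.
iou = [[0.0, 0.9618, 0.9262],
       [0.0, 0.0, 0.9157],
       [0.0, 0.0, 0.0]]
coeff = matrix_nms_decay(iou, labels=[0, 0, 0])

# Scores are decayed rather than discarded.
scores = [0.7593, 0.6425, 0.6012]
updated = [s * c for s, c in zip(scores, coeff)]
```

The top-scoring mask keeps a coefficient of 1, while the two heavily overlapping masks are scaled down to roughly 0.157 and 0.180 of their original score, in line with the first entries of `decay_coefficient` in the dump.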

4. SOLO/configs/solo/solo_r50_fpn_8gpu_1x.py

# model settings
model = dict(
    type='SOLO',
    pretrained='torchvision://resnet50',
    backbone=dict(
        type='ResNet',
        depth=50,
        num_stages=4,
        out_indices=(0, 1, 2, 3), # C2, C3, C4, C5
        frozen_stages=1,
        style='pytorch'),
    neck=dict(
        type='FPN',
        in_channels=[256, 512, 1024, 2048],
        out_channels=256,
        start_level=0,
        num_outs=5),
    bbox_head=dict(
        type='SOLOHead',  # names the registered head class (SOLOHead in solo_head.py); change it to use your own modified head, e.g. SOLOHead_xx
        num_classes=81,
        in_channels=256,
        stacked_convs=7,
        seg_feat_channels=256,
        strides=[8, 8, 16, 32, 32],
        scale_ranges=((1, 96), (48, 192), (96, 384), (192, 768), (384, 2048)),
        sigma=0.2,
        num_grids=[40, 36, 24, 16, 12],
        cate_down_pos=0,
        with_deform=False,
        loss_ins=dict(
            type='DiceLoss',
            use_sigmoid=True,
            loss_weight=3.0),
        loss_cate=dict(
            type='FocalLoss',
            use_sigmoid=True,
            gamma=2.0,
            alpha=0.25,
            loss_weight=1.0),
    ))
# training and testing settings
train_cfg = dict()
test_cfg = dict(
    nms_pre=500,
    score_thr=0.1,
    mask_thr=0.5,
    update_thr=0.05,
    kernel='gaussian',  # gaussian/linear
    sigma=2.0,
    max_per_img=100)
# dataset settings
dataset_type = 'CocoDataset'
data_root = 'data/coco2017/'
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
    dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
    dict(type='RandomFlip', flip_ratio=0.5),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Pad', size_divisor=32),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(1333, 800),
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='Pad', size_divisor=32),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ])
]
data = dict(
    imgs_per_gpu=2,
    workers_per_gpu=2,
    train=dict(
        type=dataset_type,
        ann_file=data_root + 'annotations/instances_train2017.json',
        img_prefix=data_root + 'train2017/',
        pipeline=train_pipeline),
    val=dict(
        type=dataset_type,
        ann_file=data_root + 'annotations/instances_val2017.json',
        img_prefix=data_root + 'val2017/',
        pipeline=test_pipeline),
    test=dict(
        type=dataset_type,
        ann_file=data_root + 'annotations/instances_val2017.json',
        img_prefix=data_root + 'val2017/',
        pipeline=test_pipeline))
# optimizer
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
# learning policy
lr_config = dict(
    policy='step',
    warmup='linear',
    warmup_iters=500,
    warmup_ratio=1.0 / 3,
    step=[9, 11])
# checkpoint saving
checkpoint_config = dict(interval=1)  # save a checkpoint every epoch
# yapf:disable
log_config = dict(
    interval=1,  # log once every `interval` iterations
    hooks=[
        dict(type='TextLoggerHook'),
        # dict(type='TensorboardLoggerHook')
    ])
# yapf:enable
# runtime settings
total_epochs = 12
device_ids = range(8)
dist_params = dict(backend='nccl')
log_level = 'INFO'
work_dir = './work_dirs/solo_release_r50_fpn_8gpu_1x'
load_from = None
resume_from = None
workflow = [('train', 1)]
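
To make `scale_ranges` and `num_grids` more concrete, here is a small sketch of how a ground-truth box could be routed to FPN levels. It assumes the head measures instance scale as sqrt(width × height) of the gt box and keeps every level whose range contains that value; the helper name `assign_levels` is hypothetical. Because neighbouring ranges overlap, a single instance can land on two levels, which matches the earlier observation that one instance was matched by two FPN outputs.

```python
import math

# Values copied from the config above.
SCALE_RANGES = ((1, 96), (48, 192), (96, 384), (192, 768), (384, 2048))
NUM_GRIDS = [40, 36, 24, 16, 12]


def assign_levels(box):
    """Return (level index, grid size) pairs whose scale range contains the box.

    box is (x1, y1, x2, y2) in input-image pixels. The scale is taken as
    sqrt(width * height), which is an assumption about the head's rule.
    """
    x1, y1, x2, y2 = box
    scale = math.sqrt((x2 - x1) * (y2 - y1))
    return [(level, NUM_GRIDS[level])
            for level, (low, high) in enumerate(SCALE_RANGES)
            if low <= scale <= high]


small = assign_levels((10, 10, 50, 50))    # scale 40
medium = assign_levels((0, 0, 100, 100))   # scale 100
large = assign_levels((0, 0, 600, 600))    # scale 600
```

Under these assumptions the small box is handled only by the first level (40×40 grid), while the medium and large boxes are each handled by two adjacent levels.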

5. SOLO/mmdet/models/anchor_heads/__init__.py

from .anchor_head import AnchorHead
from .atss_head import ATSSHead
from .fcos_head import FCOSHead
from .fovea_head import FoveaHead
from .free_anchor_retina_head import FreeAnchorRetinaHead
from .ga_retina_head import GARetinaHead
from .ga_rpn_head import GARPNHead
from .guided_anchor_head import FeatureAdaption, GuidedAnchorHead
from .reppoints_head import RepPointsHead
from .retina_head import RetinaHead
from .retina_sepbn_head import RetinaSepBNHead
from .rpn_head import RPNHead
from .ssd_head import SSDHead
from .solo_head import SOLOHead
from .solov2_head import SOLOv2Head
from .solov2_light_head import SOLOv2LightHead
from .decoupled_solo_head import DecoupledSOLOHead
from .decoupled_solo_light_head import DecoupledSOLOLightHead
from .solo_head_xx import SOLOHead_xx  # import the custom head so that it gets registered
__all__ = [
    'AnchorHead', 'GuidedAnchorHead', 'FeatureAdaption', 'RPNHead',
    'GARPNHead', 'RetinaHead', 'RetinaSepBNHead', 'GARetinaHead', 'SSDHead',
    'FCOSHead', 'RepPointsHead', 'FoveaHead', 'FreeAnchorRetinaHead',
    'ATSSHead', 'SOLOHead', 'SOLOv2Head', 'SOLOv2LightHead',
    'DecoupledSOLOHead', 'DecoupledSOLOLightHead', 'SOLOHead_xx'
]


# Then implement SOLOHead_xx.py and adjust its super() call accordingly. This lets you make small changes while keeping the official files untouched.
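
The registration idea can be illustrated without mmdet installed. The sketch below uses a toy `Registry` class as a stand-in for mmdet's registry, and the `extra_scale` argument is a made-up parameter for illustration only. It shows why importing the new class in `__init__.py` and setting `type='SOLOHead_xx'` in the config is enough to switch heads.

```python
class Registry:
    """Toy stand-in for mmdet's registry: maps a class name to the class."""

    def __init__(self):
        self.modules = {}

    def register_module(self, cls):
        self.modules[cls.__name__] = cls
        return cls


HEADS = Registry()  # hypothetical; the real registry lives in mmdet


@HEADS.register_module
class SOLOHead:
    def __init__(self, num_classes, sigma=0.2):
        self.num_classes = num_classes
        self.sigma = sigma


@HEADS.register_module
class SOLOHead_xx(SOLOHead):
    """Custom variant: reuse the parent and override only what changes."""

    def __init__(self, extra_scale=1.0, **kwargs):
        # Forward the unchanged arguments to the official head.
        super().__init__(**kwargs)
        self.extra_scale = extra_scale  # hypothetical extra option


def build_head(cfg):
    """Instantiate the class named by cfg['type'] with the remaining keys."""
    args = dict(cfg)
    cls = HEADS.modules[args.pop('type')]
    return cls(**args)


head = build_head(dict(type='SOLOHead_xx', num_classes=81, extra_scale=2.0))
```

Since the subclass only adds behaviour on top of `SOLOHead`, the official `solo_head.py` stays unmodified and switching back is a one-line change to `type` in the config.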
