如果对你帮助的话,希望给我个赞~
论文原话:
起初看完论文后并不是很理解。但看完代码后,我对正样本的选取有了新的领悟与体会:如何与全卷积网络结合,是理论与实践相结合的一个很好的例子,可以通过代码来反思并加深对论文思想的理解。
其中FCOS、polarmask也是采用了一种中心采样的结构。这些文中都有提到,全卷积网络可以采用gt_box内的所有点为positive example,但是这样子计算量肯定很大,并且其他靠近bbox的点回归的效果肯定是很差的,因此围绕质心(solo以质心为中心)进行正样本采样是非常合理的。
引用一篇特别棒的转载博客里的图片:博客链接
如图所示,在原图中,蓝色框表示图片等分的格子,这里设置分为5X5个格子。绿色框为目标物体的gt box,黄色框表示缩小到0.2倍数的box,红色框表示负责预测该实例的格子。
下方黑白图为mask分支的target可视化,为了便于显示,这里对不同通道进行了拼接。左边的第一幅图,图中有一个实例,其gt box缩小到0.2倍占据两个格子,因此这两个格子负责预测该实例。
下方的mask分支,只有两个FPN的输出匹配到了该实例,因此在红色格子对应的channel负责预测该实例的mask。第二幅图,图中分布大小不同的实例,可见在FPN输出的mask分支上,从小到大负责不同尺度的实例。
下图出自原文,也很清晰地表达了FPN如何根据不同的gt_areas以及实例所处的网格位置,放入对应的channel上进行预测。首先根据gt_areas将不同的gt分配到不同的FPN层;然后在同一层中,如果有多个实例,就会根据设置好的网格,以某个GT的质心为中心,按0.2 * gt_areas(此时gt_areas已缩放到对应FPN层输出的feature map的大小)的大小进行缩放。
single_stage_ins中实现了backbone(resnet),neck(fpn)以及head(solo_head)的连接以及forward。
import torch.nn as nn
from mmdet.core import bbox2result
from .. import builder
from ..registry import DETECTORS
from .base import BaseDetector
import pdb
@DETECTORS.register_module
class SingleStageInsDetector(BaseDetector):
    """Single-stage instance-segmentation detector (SOLO-style).

    Chains a backbone (e.g. ResNet), an optional neck (e.g. FPN), an
    optional mask feature head and the SOLO bbox/mask head, and defines
    the training / testing forward logic.
    """

    def __init__(self,
                 backbone,
                 neck=None,
                 bbox_head=None,
                 mask_feat_head=None,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None):
        super(SingleStageInsDetector, self).__init__()
        # 1. backbone, e.g. ResNet
        self.backbone = builder.build_backbone(backbone)
        if neck is not None:
            # 2. neck, e.g. FPN
            self.neck = builder.build_neck(neck)
        if mask_feat_head is not None:
            self.mask_feat_head = builder.build_head(mask_feat_head)
        # 3. SOLO head
        self.bbox_head = builder.build_head(bbox_head)
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg
        # pretrained is e.g. 'torchvision://resnet50'
        self.init_weights(pretrained=pretrained)

    def init_weights(self, pretrained=None):
        """Initialise backbone, neck, mask feature head and bbox head."""
        super(SingleStageInsDetector, self).init_weights(pretrained)
        self.backbone.init_weights(pretrained=pretrained)
        if self.with_neck:
            neck_modules = (self.neck if isinstance(self.neck, nn.Sequential)
                            else [self.neck])
            for module in neck_modules:
                module.init_weights()
        if self.with_mask_feat_head:
            mask_modules = (self.mask_feat_head
                            if isinstance(self.mask_feat_head, nn.Sequential)
                            else [self.mask_feat_head])
            for module in mask_modules:
                module.init_weights()
        self.bbox_head.init_weights()

    def extract_feat(self, img):
        """Forward the backbone and (if present) the neck.

        Returns the multi-level feature maps, e.g. for a 2-image batch:
        [2, 256, 200, 304], [2, 256, 100, 152], [2, 256, 50, 76],
        [2, 256, 25, 38], [2, 256, 13, 19].
        """
        feats = self.backbone(img)  # resnet forward
        if self.with_neck:
            feats = self.neck(feats)  # fpn forward
        return feats

    def forward_dummy(self, img):
        """Head forward without loss (used e.g. for FLOPs counting)."""
        feats = self.extract_feat(img)
        return self.bbox_head(feats)

    def forward_train(self,
                      img,
                      img_metas,
                      gt_bboxes,
                      gt_labels,
                      gt_bboxes_ignore=None,
                      gt_masks=None):
        """Compute SOLO training losses for one batch.

        Args:
            img: padded image batch, e.g. [2, 3, 800, 1216] (max h/w
                of the images in the batch).
            img_metas: per-image dicts with 'filename', 'ori_shape',
                'img_shape', 'pad_shape', 'scale_factor', 'flip',
                'img_norm_cfg'.
            gt_bboxes: list (len = batch size) of [num_instances, 4]
                tensors holding [x1, y1, x2, y2].
            gt_labels: list of [num_instances] category-id tensors.
            gt_bboxes_ignore: optional ignored boxes, forwarded to the
                head's loss.
            gt_masks: list of per-image mask arrays shaped
                (num_instances, pad_h, pad_w).

        Returns:
            dict of losses produced by the SOLO head.
        """
        feats = self.extract_feat(img)  # backbone + fpn forward
        # outs = (ins_pred, cate_pred), one entry per FPN level, e.g.
        # cate_pred shapes [2, 80, 40, 40] ... [2, 80, 12, 12].
        outs = self.bbox_head(feats)
        if self.with_mask_feat_head:
            mask_feat_pred = self.mask_feat_head(
                feats[self.mask_feat_head.start_level:
                      self.mask_feat_head.end_level + 1])
            loss_inputs = outs + (mask_feat_pred, gt_bboxes, gt_labels,
                                  gt_masks, img_metas, self.train_cfg)
        else:
            loss_inputs = outs + (gt_bboxes, gt_labels, gt_masks,
                                  img_metas, self.train_cfg)
        # Compute the SOLO loss.
        return self.bbox_head.loss(
            *loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)

    def simple_test(self, img, img_meta, rescale=False):
        """Single-image (no test-time augmentation) inference."""
        feats = self.extract_feat(img)
        # eval=True makes the head apply sigmoid / points_nms.
        outs = self.bbox_head(feats, eval=True)
        if self.with_mask_feat_head:
            mask_feat_pred = self.mask_feat_head(
                feats[self.mask_feat_head.start_level:
                      self.mask_feat_head.end_level + 1])
            seg_inputs = outs + (mask_feat_pred, img_meta, self.test_cfg,
                                 rescale)
        else:
            seg_inputs = outs + (img_meta, self.test_cfg, rescale)
        return self.bbox_head.get_seg(*seg_inputs)

    def aug_test(self, imgs, img_metas, rescale=False):
        """Test-time augmentation is not supported."""
        raise NotImplementedError
注:一次输入的数据打印在最下方。
import mmcv
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import normal_init
from mmdet.ops import DeformConv, roi_align
from mmdet.core import multi_apply, bbox2roi, matrix_nms
from ..builder import build_loss
from ..registry import HEADS
from ..utils import bias_init_with_prob, ConvModule
import pdb
import math
INF = 1e8
from scipy import ndimage
def points_nms(heat, kernel=2):
    """Suppress non-peak scores on the category grid (kernel must be 2).

    A 2x2 max pool (stride 1, padding 1) computes each position's local
    maximum; scores that do not equal their local maximum are zeroed, so
    every 2x2 neighbourhood keeps at most one non-zero response.
    """
    pooled = F.max_pool2d(heat, (kernel, kernel), stride=1, padding=1)
    # (a == b) gives a bool tensor; .float() turns the mask into 0/1.
    keep = (pooled[:, :, :-1, :-1] == heat).float()
    return heat * keep
def dice_loss(input, target):
    """Per-instance dice loss: 1 - dice coefficient.

    Args:
        input: predicted (sigmoid) masks, [num_instances, ...]; flattened
            to [num_instances, H*W].
        target: ground-truth masks with the same leading dimension;
            cast to float after flattening.

    Returns:
        Tensor of shape [num_instances] holding 1 - dice per instance.
    """
    input = input.contiguous().view(input.size()[0], -1)
    target = target.contiguous().view(target.size()[0], -1).float()
    intersection = torch.sum(input * target, 1)
    # 0.001 smoothing keeps the denominator non-zero for empty masks.
    pred_sq = torch.sum(input * input, 1) + 0.001
    target_sq = torch.sum(target * target, 1) + 0.001
    dice = (2 * intersection) / (pred_sq + target_sq)
    # NOTE: removed a leftover debug print that ran on every loss call.
    return 1 - dice
@HEADS.register_module
class SOLOHead(nn.Module):
    """SOLO instance segmentation head (with a modified mask branch).

    Per FPN level the head predicts:
      * cate_pred: [N, num_classes - 1, S, S] category scores on an SxS grid;
      * ins_pred:  [N, S*S, 2h, 2w] instance masks, one channel per grid cell.

    The mask branch deviates from stock SOLO: after the conv tower the
    feature map is split into SxS cells and a cell-wise attention product
    produces the S*S mask channels directly (see forward_single).
    """

    def __init__(self,
                 num_classes,
                 in_channels,
                 seg_feat_channels=256,
                 stacked_convs=4,
                 strides=(4, 8, 16, 32, 64),
                 base_edge_list=(16, 32, 64, 128, 256),
                 scale_ranges=((8, 32), (16, 64), (32, 128), (64, 256), (128, 512)),
                 sigma=0.4,
                 num_grids=None,
                 cate_down_pos=0,
                 with_deform=False,
                 loss_ins=None,
                 loss_cate=None,
                 conv_cfg=None,
                 norm_cfg=None):
        super(SOLOHead, self).__init__()
        self.num_classes = num_classes                 # incl. background, e.g. 81
        self.seg_num_grids = num_grids                 # S per level, e.g. [40, 36, 24, 16, 12]
        self.cate_out_channels = self.num_classes - 1  # foreground classes, e.g. 80
        self.in_channels = in_channels                 # e.g. 256
        self.seg_feat_channels = seg_feat_channels     # e.g. 256
        self.stacked_convs = stacked_convs             # e.g. 7
        self.strides = strides                         # e.g. [8, 8, 16, 32, 32]
        self.sigma = sigma                             # centre-region shrink factor, e.g. 0.2
        self.cate_down_pos = cate_down_pos             # e.g. 0
        self.base_edge_list = base_edge_list
        # e.g. ((1, 96), (48, 192), (96, 384), (192, 768), (384, 2048))
        self.scale_ranges = scale_ranges
        self.with_deform = with_deform
        # e.g. FocalLoss(use_sigmoid=True, gamma=2.0, alpha=0.25, loss_weight=1.0)
        self.loss_cate = build_loss(loss_cate)
        self.ins_loss_weight = loss_ins['loss_weight']  # e.g. 3
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self._init_layers()

    def _init_layers(self):
        """Build the ins/cate conv towers and the per-level output convs."""
        norm_cfg = dict(type='GN', num_groups=32, requires_grad=True)
        self.ins_convs = nn.ModuleList()
        self.cate_convs = nn.ModuleList()
        for i in range(self.stacked_convs):
            # CoordConv: the first ins conv takes 2 extra (x, y) channels.
            chn = self.in_channels + 2 if i == 0 else self.seg_feat_channels
            self.ins_convs.append(
                ConvModule(
                    chn,
                    self.seg_feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    norm_cfg=norm_cfg,
                    bias=norm_cfg is None))
            chn = self.in_channels if i == 0 else self.seg_feat_channels
            self.cate_convs.append(
                ConvModule(
                    chn,
                    self.seg_feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    norm_cfg=norm_cfg,
                    bias=norm_cfg is None))
        self.solo_ins_list = nn.ModuleList()
        self.solo_sa_module = nn.ModuleList()  # reserved, currently unused
        # Modified mask output conv: S*S -> S*S channels per level (the
        # stock SOLO head maps seg_feat_channels -> S*S instead).
        for seg_num_grid in self.seg_num_grids:
            self.solo_ins_list.append(
                nn.Conv2d(
                    seg_num_grid**2, seg_num_grid**2, 1))
        # Category output: [N, 256, h, w] -> [N, num_classes - 1, h, w].
        self.solo_cate = nn.Conv2d(
            self.seg_feat_channels, self.cate_out_channels, 3, padding=1)

    def init_weights(self):
        """Normal-init all convs; output biases use focal-loss prior 0.01."""
        for m in self.ins_convs:
            normal_init(m.conv, std=0.01)
        for m in self.cate_convs:
            normal_init(m.conv, std=0.01)
        bias_ins = bias_init_with_prob(0.01)
        for m in self.solo_ins_list:
            normal_init(m, std=0.01, bias=bias_ins)
        bias_cate = bias_init_with_prob(0.01)  # approx. -4.595
        normal_init(self.solo_cate, std=0.01, bias=bias_cate)

    def forward(self, feats, eval=False):
        """Run the head on all FPN levels.

        Returns (ins_pred, cate_pred), each a 5-element per-level list.
        """
        # Rescale the outermost levels so the effective strides become
        # (8, 8, 16, 32, 32).
        new_feats = self.split_feats(feats)
        featmap_sizes = [featmap.size()[-2:] for featmap in new_feats]
        # Largest level upsampled by 2, i.e. input size / 4.
        upsampled_size = (featmap_sizes[0][0] * 2, featmap_sizes[0][1] * 2)
        ins_pred, cate_pred = multi_apply(self.forward_single, new_feats,
                                          list(range(len(self.seg_num_grids))),
                                          eval=eval, upsampled_size=upsampled_size)
        return ins_pred, cate_pred

    def split_feats(self, feats):
        """Downsample feats[0] by 2 and upsample feats[4] to feats[3]'s size."""
        return (F.interpolate(feats[0], scale_factor=0.5, mode='bilinear'),
                feats[1],
                feats[2],
                feats[3],
                F.interpolate(feats[4], size=feats[3].shape[-2:], mode='bilinear'))

    def forward_single(self, x, idx, eval=False, upsampled_size=None):
        """Forward one FPN level (called 5 times, once per level).

        Args:
            x: level feature map, e.g. [2, 256, 100, 152].
            idx: level index; selects S = seg_num_grids[idx].
            eval: if True, sigmoid + resize ins_pred to upsampled_size and
                apply points_nms to cate_pred.
            upsampled_size: (2H, 2W) of the largest level (input / 4).

        Returns:
            (ins_pred [N, S*S, 2h, 2w], cate_pred [N, C, S, S]) — in eval
            mode cate_pred is permuted to [N, S, S, C].
        """
        ins_feat = x
        cate_feat = x
        # ---- instance branch ----
        # CoordConv: append normalised x/y coordinate channels (256 -> 258).
        x_range = torch.linspace(-1, 1, ins_feat.shape[-1], device=ins_feat.device)
        y_range = torch.linspace(-1, 1, ins_feat.shape[-2], device=ins_feat.device)
        y, x = torch.meshgrid(y_range, x_range)
        y = y.expand([ins_feat.shape[0], 1, -1, -1])  # [N, 1, h, w]
        x = x.expand([ins_feat.shape[0], 1, -1, -1])  # [N, 1, h, w]
        coord_feat = torch.cat([x, y], 1)             # [N, 2, h, w]
        ins_feat = torch.cat([ins_feat, coord_feat], 1)
        # Conv tower (stacked_convs layers).
        for i, ins_layer in enumerate(self.ins_convs):
            ins_feat = ins_layer(ins_feat)
        # Modified branch: resize so H and W divide evenly into S cells of
        # size sa_h x sa_w each.
        sa_h = math.ceil(ins_feat.size()[2] / self.seg_num_grids[idx])
        sa_w = math.ceil(ins_feat.size()[3] / self.seg_num_grids[idx])
        ins_feat = F.interpolate(
            ins_feat,
            size=(self.seg_num_grids[idx] * sa_h, self.seg_num_grids[idx] * sa_w),
            mode='bilinear')
        # Cell-wise attention: tile each grid cell's patch over the whole
        # map and take the channel-wise dot product with the features,
        # yielding one mask channel per grid cell -> [N, S*S, H', W'].
        seg_num_grids = self.seg_num_grids[idx]
        per_cell_maps = []
        for i in range(seg_num_grids):
            for j in range(seg_num_grids):
                weight = ins_feat[:, :, i * sa_h:(i + 1) * sa_h,
                                  j * sa_w:(j + 1) * sa_w].repeat(
                                      1, 1, seg_num_grids, seg_num_grids)
                per_cell_maps.append((weight * ins_feat).sum(1))
        ins_pred = torch.stack(per_cell_maps, dim=1)
        # NOTE(review): two earlier, slower drafts of this attention step
        # (a per-cell python loop and a matmul-based rearrangement) were
        # removed here; they were dead, commented-out code.
        # Upsample masks by 2, then a 1x1 conv mapping S*S -> S*S channels.
        ins_pred = F.interpolate(ins_pred, scale_factor=2, mode='bilinear')
        ins_pred = self.solo_ins_list[idx](ins_pred)
        # ---- category branch ----
        for i, cate_layer in enumerate(self.cate_convs):
            if i == self.cate_down_pos:  # resize to the SxS grid first (i == 0)
                seg_num_grid = self.seg_num_grids[idx]
                cate_feat = F.interpolate(cate_feat, size=seg_num_grid, mode='bilinear')
            cate_feat = cate_layer(cate_feat)
        # Channels 256 -> num_classes - 1.
        cate_pred = self.solo_cate(cate_feat)
        if eval:
            # Resize every level's masks to the common size (input / 4).
            ins_pred = F.interpolate(ins_pred.sigmoid(), size=upsampled_size, mode='bilinear')
            cate_pred = points_nms(cate_pred.sigmoid(), kernel=2).permute(0, 2, 3, 1)
        return ins_pred, cate_pred

    def loss(self,
             ins_preds,
             cate_preds,
             gt_bbox_list,
             gt_label_list,
             gt_mask_list,
             img_metas,
             cfg,
             gt_bboxes_ignore=None):
        """Dice loss on instance masks + focal loss on category grids."""
        featmap_sizes = [featmap.size()[-2:] for featmap in
                         ins_preds]
        # Per-image, per-level targets (see solo_target_single).
        ins_label_list, cate_label_list, ins_ind_label_list = multi_apply(
            self.solo_target_single,
            gt_bbox_list,
            gt_label_list,
            gt_mask_list,
            featmap_sizes=featmap_sizes)
        # Per level, gather the gt masks of every positive grid cell over
        # the whole batch: each entry becomes [num_pos, 2h, 2w].
        # zip(*x) transposes the per-image lists into per-level tuples.
        ins_labels = [torch.cat([ins_labels_level_img[ins_ind_labels_level_img, ...]
                                 for ins_labels_level_img, ins_ind_labels_level_img in
                                 zip(ins_labels_level, ins_ind_labels_level)], 0)
                      for ins_labels_level, ins_ind_labels_level in
                      zip(zip(*ins_label_list), zip(*ins_ind_label_list))]
        # Same gathering for the predicted masks.
        ins_preds = [torch.cat([ins_preds_level_img[ins_ind_labels_level_img, ...]
                                for ins_preds_level_img, ins_ind_labels_level_img in
                                zip(ins_preds_level, ins_ind_labels_level)], 0)
                     for ins_preds_level, ins_ind_labels_level in
                     zip(ins_preds, zip(*ins_ind_label_list))]
        ins_ind_labels = [
            torch.cat([ins_ind_labels_level_img.flatten()
                       for ins_ind_labels_level_img in ins_ind_labels_level])
            for ins_ind_labels_level in zip(*ins_ind_label_list)
        ]
        # [sum(S^2) * batch_size] flags; the sum counts positive cells.
        flatten_ins_ind_labels = torch.cat(ins_ind_labels)
        num_ins = flatten_ins_ind_labels.sum()
        # Dice loss between sigmoid predictions and binary gt masks.
        loss_ins = []
        for pred, target in zip(ins_preds, ins_labels):
            if pred.size()[0] == 0:  # no positives on this level
                continue
            pred = torch.sigmoid(pred)
            loss_ins.append(dice_loss(pred, target))
        loss_ins = torch.cat(loss_ins).mean()
        loss_ins = loss_ins * self.ins_loss_weight
        # Category (focal) loss over all grid cells of all levels.
        cate_labels = [
            torch.cat([cate_labels_level_img.flatten()
                       for cate_labels_level_img in cate_labels_level])
            for cate_labels_level in zip(*cate_label_list)
        ]
        flatten_cate_labels = torch.cat(cate_labels)
        # [N, C, S, S] -> [N*S*S, C] per level, then concat all levels,
        # e.g. 3200 + 2592 + 1152 + 512 + 288 rows for a 2-image batch.
        cate_preds = [
            cate_pred.permute(0, 2, 3, 1).reshape(-1, self.cate_out_channels)
            for cate_pred in cate_preds
        ]
        flatten_cate_preds = torch.cat(cate_preds)
        loss_cate = self.loss_cate(flatten_cate_preds, flatten_cate_labels,
                                   avg_factor=num_ins + 1)
        return dict(
            loss_ins=loss_ins,
            loss_cate=loss_cate)

    def solo_target_single(self,
                           gt_bboxes_raw,
                           gt_labels_raw,
                           gt_masks_raw,
                           featmap_sizes=None):
        """Build per-level SOLO targets for ONE image.

        Args:
            gt_bboxes_raw: [num_instances, 4] boxes [x1, y1, x2, y2].
            gt_labels_raw: [num_instances] category ids.
            gt_masks_raw: (num_instances, pad_h, pad_w) binary masks.
            featmap_sizes: mask-branch output size per level.

        Returns:
            Per-level lists: ins_label [S*S, 2h, 2w] gt masks per grid
            cell, cate_label [S, S] category ids, ins_ind_label [S*S]
            positive-cell flags.
        """
        device = gt_labels_raw[0].device
        # sqrt(box area) decides which FPN level an instance maps to.
        gt_areas = torch.sqrt((gt_bboxes_raw[:, 2] - gt_bboxes_raw[:, 0]) * (
            gt_bboxes_raw[:, 3] - gt_bboxes_raw[:, 1]))
        ins_label_list = []
        cate_label_list = []
        ins_ind_label_list = []
        for (lower_bound, upper_bound), stride, featmap_size, num_grid \
                in zip(self.scale_ranges, self.strides, featmap_sizes, self.seg_num_grids):
            ins_label = torch.zeros([num_grid ** 2, featmap_size[0], featmap_size[1]],
                                    dtype=torch.uint8, device=device)
            cate_label = torch.zeros([num_grid, num_grid], dtype=torch.int64, device=device)
            ins_ind_label = torch.zeros([num_grid ** 2], dtype=torch.bool, device=device)
            # Indices of instances whose scale falls in this level's range.
            hit_indices = ((gt_areas >= lower_bound) &
                           (gt_areas <= upper_bound)).nonzero().flatten()
            if len(hit_indices) == 0:
                ins_label_list.append(ins_label)
                cate_label_list.append(cate_label)
                ins_ind_label_list.append(ins_ind_label)
                continue
            gt_bboxes = gt_bboxes_raw[hit_indices]
            gt_labels = gt_labels_raw[hit_indices]
            gt_masks = gt_masks_raw[hit_indices.cpu().numpy(), ...]
            # Half extents of the sigma-shrunk centre region (sigma e.g. 0.2).
            half_ws = 0.5 * (gt_bboxes[:, 2] - gt_bboxes[:, 0]) * self.sigma
            half_hs = 0.5 * (gt_bboxes[:, 3] - gt_bboxes[:, 1]) * self.sigma
            # The mask branch outputs at stride/2 resolution (2h x 2w).
            output_stride = stride / 2
            # Process one instance at a time.
            for seg_mask, gt_label, half_h, half_w in zip(gt_masks, gt_labels, half_hs, half_ws):
                if seg_mask.sum() < 10:  # skip tiny / empty masks
                    continue
                upsampled_size = (featmap_sizes[0][0] * 4, featmap_sizes[0][1] * 4)
                # Mass centre in input-image coordinates.
                # (top-level spelling; ndimage.measurements is deprecated)
                center_h, center_w = ndimage.center_of_mass(seg_mask)
                # Grid cell containing the centre, e.g. (659, 398) -> (29, 11)
                # with num_grid = 36.
                coord_w = int((center_w / upsampled_size[1]) // (1. / num_grid))
                coord_h = int((center_h / upsampled_size[0]) // (1. / num_grid))
                # Grid extent of the shrunk centre region, clamped to the
                # grid and to at most one cell around the centre cell.
                top_box = max(0, int(((center_h - half_h) / upsampled_size[0]) // (1. / num_grid)))
                down_box = min(num_grid - 1, int(((center_h + half_h) / upsampled_size[0]) // (1. / num_grid)))
                left_box = max(0, int(((center_w - half_w) / upsampled_size[1]) // (1. / num_grid)))
                right_box = min(num_grid - 1, int(((center_w + half_w) / upsampled_size[1]) // (1. / num_grid)))
                top = max(top_box, coord_h - 1)
                down = min(down_box, coord_h + 1)
                left = max(coord_w - 1, left_box)
                right = min(right_box, coord_w + 1)
                # Category target for every cell of the centre region.
                cate_label[top:(down + 1), left:(right + 1)] = gt_label
                # Mask target at the level's output resolution (stride/2,
                # because the branch predicts at 2h x 2w).
                seg_mask = mmcv.imrescale(seg_mask, scale=1. / output_stride)
                seg_mask = torch.Tensor(seg_mask)
                for i in range(top, down + 1):
                    for j in range(left, right + 1):
                        # Store the mask in the channel of grid cell (i, j).
                        label = int(i * num_grid + j)
                        ins_label[label, :seg_mask.shape[0], :seg_mask.shape[1]] = seg_mask
                        ins_ind_label[label] = True
            ins_label_list.append(ins_label)
            cate_label_list.append(cate_label)
            ins_ind_label_list.append(ins_ind_label)
        return ins_label_list, cate_label_list, ins_ind_label_list

    def get_seg(self, seg_preds, cate_preds, img_metas, cfg, rescale=None):
        """Decode segmentation results for every image in the batch."""
        assert len(seg_preds) == len(cate_preds)
        num_levels = len(cate_preds)  # 5
        # Largest mask-branch size — in eval all levels share it, e.g. [200, 304].
        featmap_size = seg_preds[0].size()[-2:]
        result_list = []
        for img_id in range(len(img_metas)):
            # Concatenate all levels for one image:
            # [sum(S^2), C] scores and [sum(S^2), h, w] masks.
            cate_pred_list = [
                cate_preds[i][img_id].view(-1, self.cate_out_channels).detach()
                for i in range(num_levels)
            ]
            seg_pred_list = [
                seg_preds[i][img_id].detach() for i in range(num_levels)
            ]
            img_shape = img_metas[img_id]['img_shape']
            scale_factor = img_metas[img_id]['scale_factor']
            ori_shape = img_metas[img_id]['ori_shape']
            cate_pred_list = torch.cat(cate_pred_list, dim=0)
            seg_pred_list = torch.cat(seg_pred_list, dim=0)
            result = self.get_seg_single(cate_pred_list, seg_pred_list,
                                         featmap_size, img_shape, ori_shape,
                                         scale_factor, cfg, rescale)
            result_list.append(result)
        return result_list

    def get_seg_single(self,
                       cate_preds,
                       seg_preds,
                       featmap_size,
                       img_shape,
                       ori_shape,
                       scale_factor,
                       cfg,
                       rescale=False, debug=False):
        """Post-process one image: threshold, Matrix NMS, resize to ori_shape.

        Args:
            cate_preds: [sum(S^2), C] sigmoid category scores, e.g. [3872, 80].
            seg_preds: [sum(S^2), h, w] sigmoid masks at input/4 resolution.
            featmap_size: (h, w) of the mask predictions, e.g. (200, 304).
            img_shape: resized (pre-pad) input shape (H, W, 3).
            ori_shape: original image shape (H0, W0, 3).

        Returns:
            (seg_masks, cate_labels, cate_scores), or None when nothing
            survives the filters.
        """
        assert len(cate_preds) == len(seg_preds)
        h, w, _ = img_shape
        upsampled_size_out = (featmap_size[0] * 4, featmap_size[1] * 4)  # e.g. [800, 1216]
        # 1st filter: score threshold (bool mask over [sum(S^2), C]).
        inds = (cate_preds > cfg.score_thr)
        cate_scores = cate_preds[inds]  # 1-D tensor of surviving scores
        if len(cate_scores) == 0:
            return None
        # nonzero() rows are (grid_cell_index, class_index).
        inds = inds.nonzero()
        cate_labels = inds[:, 1]
        # Per-cell stride lookup: the first 1600 cells (40x40) get stride 8, etc.
        size_trans = cate_labels.new_tensor(self.seg_num_grids).pow(2).cumsum(0)
        strides = cate_scores.new_ones(size_trans[-1])
        n_stage = len(self.seg_num_grids)
        strides[:size_trans[0]] *= self.strides[0]
        for ind_ in range(1, n_stage):
            strides[size_trans[ind_ - 1]:size_trans[ind_]] *= self.strides[ind_]
        strides = strides[inds[:, 0]]
        # Binarise the masks and drop those smaller than their stride.
        seg_preds = seg_preds[inds[:, 0]]
        seg_masks = seg_preds > cfg.mask_thr  # e.g. mask_thr = 0.5
        sum_masks = seg_masks.sum((1, 2)).float()  # foreground pixels per mask
        keep = sum_masks > strides
        if keep.sum() == 0:
            return None
        seg_masks = seg_masks[keep, ...]
        seg_preds = seg_preds[keep, ...]
        sum_masks = sum_masks[keep]
        cate_scores = cate_scores[keep]
        cate_labels = cate_labels[keep]
        # Mask scoring: weight the class score by the mean foreground
        # confidence inside the binary mask.
        seg_scores = (seg_preds * seg_masks.float()).sum((1, 2)) / sum_masks
        cate_scores *= seg_scores
        # Keep the top nms_pre candidates by score.
        sort_inds = torch.argsort(cate_scores, descending=True)
        if len(sort_inds) > cfg.nms_pre:
            sort_inds = sort_inds[:cfg.nms_pre]
        seg_masks = seg_masks[sort_inds, :, :]
        seg_preds = seg_preds[sort_inds, :, :]
        sum_masks = sum_masks[sort_inds]
        cate_scores = cate_scores[sort_inds]
        cate_labels = cate_labels[sort_inds]
        # Matrix NMS decays the scores of overlapping masks.
        cate_scores = matrix_nms(seg_masks, cate_labels, cate_scores,
                                 kernel=cfg.kernel, sigma=cfg.sigma, sum_masks=sum_masks)
        keep = cate_scores >= cfg.update_thr
        if keep.sum() == 0:
            return None
        seg_preds = seg_preds[keep, :, :]
        cate_scores = cate_scores[keep]
        cate_labels = cate_labels[keep]
        # Keep the top max_per_img and resize masks to the original image:
        # first to input/1 (crop padding to h x w), then to ori_shape.
        sort_inds = torch.argsort(cate_scores, descending=True)
        if len(sort_inds) > cfg.max_per_img:
            sort_inds = sort_inds[:cfg.max_per_img]
        seg_preds = seg_preds[sort_inds, :, :]
        cate_scores = cate_scores[sort_inds]
        cate_labels = cate_labels[sort_inds]
        seg_preds = F.interpolate(seg_preds.unsqueeze(0),
                                  size=upsampled_size_out,
                                  mode='bilinear')[:, :, :h, :w]
        seg_masks = F.interpolate(seg_preds,
                                  size=ori_shape[:2],
                                  mode='bilinear').squeeze(0)
        seg_masks = seg_masks > cfg.mask_thr
        return seg_masks, cate_labels, cate_scores
#----------------------------------------------------------------------------------------#
#self.ins_convs:
'''
ModuleList(
(0): ConvModule(
(conv): Conv2d(258, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(gn): GroupNorm(32, 256, eps=1e-05, affine=True)
(activate): ReLU(inplace=True)
)
(1): ConvModule(
(conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(gn): GroupNorm(32, 256, eps=1e-05, affine=True)
(activate): ReLU(inplace=True)
)
(2): ConvModule(
(conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(gn): GroupNorm(32, 256, eps=1e-05, affine=True)
(activate): ReLU(inplace=True)
)
(3): ConvModule(
(conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(gn): GroupNorm(32, 256, eps=1e-05, affine=True)
(activate): ReLU(inplace=True)
)
(4): ConvModule(
(conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(gn): GroupNorm(32, 256, eps=1e-05, affine=True)
(activate): ReLU(inplace=True)
)
(5): ConvModule(
(conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(gn): GroupNorm(32, 256, eps=1e-05, affine=True)
(activate): ReLU(inplace=True)
)
(6): ConvModule(
(conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(gn): GroupNorm(32, 256, eps=1e-05, affine=True)
(activate): ReLU(inplace=True)
)
)
'''
#----------------------------------------------------------------------------------------#
#self.cate_convs
'''
ModuleList(
(0): ConvModule(
(conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(gn): GroupNorm(32, 256, eps=1e-05, affine=True)
(activate): ReLU(inplace=True)
)
(1): ConvModule(
(conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(gn): GroupNorm(32, 256, eps=1e-05, affine=True)
(activate): ReLU(inplace=True)
)
(2): ConvModule(
(conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(gn): GroupNorm(32, 256, eps=1e-05, affine=True)
(activate): ReLU(inplace=True)
)
(3): ConvModule(
(conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(gn): GroupNorm(32, 256, eps=1e-05, affine=True)
(activate): ReLU(inplace=True)
)
(4): ConvModule(
(conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(gn): GroupNorm(32, 256, eps=1e-05, affine=True)
(activate): ReLU(inplace=True)
)
(5): ConvModule(
(conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(gn): GroupNorm(32, 256, eps=1e-05, affine=True)
(activate): ReLU(inplace=True)
)
(6): ConvModule(
(conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(gn): GroupNorm(32, 256, eps=1e-05, affine=True)
(activate): ReLU(inplace=True)
)
'''
#----------------------------------------------------------------------------------------#
# self.solo_ins_list
'''
ModuleList(
(0): Conv2d(256, 1600, kernel_size=(1, 1), stride=(1, 1))
(1): Conv2d(256, 1296, kernel_size=(1, 1), stride=(1, 1))
(2): Conv2d(256, 576, kernel_size=(1, 1), stride=(1, 1))
(3): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
(4): Conv2d(256, 144, kernel_size=(1, 1), stride=(1, 1))
)
'''
'''
ins_pred:
(Pdb) ins_pred[0].size()
torch.Size([2, 1600, 200, 304])
(Pdb) ins_pred[1].size()
torch.Size([2, 1296, 200, 304])
(Pdb) ins_pred[2].size()
torch.Size([2, 576, 100, 152])
(Pdb) ins_pred[3].size()
torch.Size([2, 256, 50, 76])
(Pdb) ins_pred[4].size()
torch.Size([2, 144, 50, 76])
'''
'''
cate_pred:
(Pdb) cate_pred[0].size()
torch.Size([2, 80, 40, 40])
(Pdb) cate_pred[1].size()
torch.Size([2, 80, 36, 36])
(Pdb) cate_pred[2].size()
torch.Size([2, 80, 24, 24])
(Pdb) cate_pred[3].size()
torch.Size([2, 80, 16, 16])
(Pdb) cate_pred[4].size()
torch.Size([2, 80, 12, 12])
'''
#----------------------------------------------------------------------------------------#
#def loss
'''
ins_labels
(Pdb) ins_labels[0].size()
torch.Size([1, 200, 272])
(Pdb) ins_labels[1].size()
torch.Size([0, 200, 272])
(Pdb) ins_labels[2].size()
torch.Size([16, 100, 136])
(Pdb) ins_labels[3].size()
torch.Size([39, 50, 68])
(Pdb) ins_labels[4].size()
torch.Size([18, 50, 68])
'''
'''
ins_preds:
(Pdb) ins_preds[0].size()
torch.Size([1, 200, 272])
(Pdb) ins_preds[1].size()
torch.Size([0, 200, 272])
(Pdb) ins_preds[2].size()
torch.Size([6, 100, 136])
(Pdb) ins_preds[3].size()
torch.Size([10, 50, 68])
(Pdb) ins_preds[4].size()
torch.Size([6, 50, 68])
'''
'''
ins_ind_labels:
(Pdb) ins_ind_labels[0].size()
torch.Size([1600])
(Pdb) ins_ind_labels[1].size()
torch.Size([1296])
(Pdb) ins_ind_labels[2].size()
torch.Size([576])
(Pdb) ins_ind_labels[3].size()
torch.Size([256])
(Pdb) ins_ind_labels[4].size()
torch.Size([144])
'''
'''
cate_labels:
(Pdb) cate_labels[0].size()
torch.Size([1600])
(Pdb) cate_labels[1].size()
torch.Size([1296])
(Pdb) cate_labels[2].size()
torch.Size([576])
(Pdb) cate_labels[3].size()
torch.Size([256])
(Pdb) cate_labels[4].size()
torch.Size([144])
'''
'''
get_seg
cfg:
{
'nms_pre': 500,
'score_thr': 0.1,
'mask_thr': 0.5,
'update_thr': 0.05,
'kernel': 'gaussian',
'sigma': 2.0,
'max_per_img': 100}
'''
'''
sum_masks
tensor([ 96., 96., 82., 82., 82., 108., 108., 108., 86.,
86., 86., 208., 227., 227., 227., 134., 134., 88.,
28., 79., 79., 231., 231., 231., 189., 189., 31.,
31., 125., 125., 125., 158., 158., 194., 99., 99.,
74., 159., 37., 37., 37., 39., 39., 275., 50.,
31., 64., 64., 64., 64., 66., 66., 66., 66.,
91., 91., 91., 93., 192., 192., 192., 46., 46.,
46., 39., 39., 51., 51., 87., 140., 181., 199.,
50., 50., 50., 50., 76., 20., 88., 88., 84.,
84., 84., 236., 236., 94., 211., 211., 252., 85.,
98., 56., 96., 96., 60., 60., 60., 53., 84.,
84., 84., 84., 258., 267., 304., 90., 105., 105.,
105., 75., 75., 75., 53., 53., 84., 84., 132.,
274., 274., 259., 259., 296., 296., 296., 272., 272.,
272., 272., 112., 117., 50., 87., 143., 143., 80.,
88., 88., 273., 273., 320., 320., 294., 364., 313.,
355., 302., 353., 353., 67., 67., 42., 32., 32.,
61., 61., 61., 61., 68., 68., 68., 68., 168.,
168., 168., 28., 28., 28., 67., 71., 139., 282.,
304., 94., 169., 135., 135., 286., 331., 100., 100.,
100., 95., 95., 172., 277., 277., 277., 371., 380.,
92., 92., 160., 394., 394., 395., 132., 132., 157.,
295., 282., 452., 468., 66., 66., 209., 73., 73.,
73., 352., 360., 333., 25., 205., 229., 229., 229.,
491., 491., 488., 488., 488., 449., 449., 234., 255.,
255., 255., 255., 630., 514., 514., 514., 481., 481.,
481., 871., 1029., 260., 260., 260., 260., 639., 514.,
484., 168., 168., 415., 81., 1120., 1232., 418., 418.,
128., 141., 242., 242., 91., 57., 57., 80., 80.,
80., 621., 1248., 1315., 199., 304., 210., 78., 54.,
54., 62., 62., 62., 622., 697., 697., 663., 663.,
149., 118., 108., 109., 109., 202., 218., 218., 275.,
275., 357., 357., 357., 361., 361., 102., 111., 111.,
448., 279., 356., 347., 347., 271., 293., 288., 288.,
288., 277., 277., 271., 271., 131., 131., 162., 162.,
162., 132., 132., 107., 362., 452., 452., 571., 361.,
360., 438., 714., 404., 427., 613., 395., 411., 438.,
438., 471., 529., 546., 52., 52., 85., 85., 85.,
181., 181., 336., 359., 183., 353., 370., 98., 98.,
98., 191., 191., 268., 268., 340., 340., 736., 346.,
380., 94., 94., 94., 179., 179., 412., 437., 437.,
437., 1087., 560., 398., 925., 925., 802., 802., 802.,
375., 834., 847., 512., 944., 508., 48., 274., 82.,
82., 82., 482., 444., 491., 491., 1281., 679., 679.,
571., 571., 571., 1403., 583., 647., 1429., 940., 721.,
721., 313., 1953., 3322., 3694., 3694., 2245., 2187., 1180.,
3924., 3924., 3963., 1622., 2566., 3506., 1246., 2082., 4032.,
4067., 474., 567., 567., 1675., 2513., 3013., 1489., 709.,
900., 900., 769., 2537., 689., 1485., 2476., 416., 1449.,
706., 2477., 3185., 3221., 413., 2756., 3230., 3230., 3156.,
424., 465., 2933., 2846., 474., 474., 940., 940., 851.,
851., 851., 553., 1572., 5856., 3666., 4373., 3937., 2129.,
4194., 4586., 2788., 2683., 4081., 3171., 3171., 3894., 4206.,
1353., 1984., 3575., 3303., 2040., 3688., 3688., 7555., 8147.,
9637., 10042., 7735., 9848., 10357., 6124., 10311., 10753., 5137.,
4384., 6858., 4768., 4397., 6499., 10237., 10237., 9333., 9333.,
9033., 9723., 9955.], device='cuda:0')
'''
'''strides
tensor([ 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8.,
8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8.,
8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8.,
8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8.,
8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8.,
8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8.,
8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8.,
8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8.,
8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8.,
8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8.,
8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8.,
8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8.,
8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8.,
8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8.,
8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8.,
8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8.,
8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8.,
8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8.,
8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8.,
8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8.,
8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8.,
8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8.,
8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8.,
8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8.,
8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8.,
8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8.,
8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8.,
8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 16., 16.,
16., 16., 16., 16., 16., 16., 16., 16., 16., 16., 16., 16., 16., 16.,
16., 16., 16., 16., 16., 16., 16., 16., 16., 16., 16., 16., 16., 16.,
16., 16., 16., 16., 16., 16., 16., 16., 16., 16., 16., 16., 16., 16.,
16., 16., 16., 16., 16., 16., 16., 16., 16., 16., 16., 16., 16., 16.,
16., 16., 16., 16., 16., 16., 16., 16., 16., 16., 16., 16., 16., 16.,
32., 32., 32., 32., 32., 32., 32., 32., 32., 32., 32., 32., 32., 32.,
32., 32., 32., 32., 32., 32., 32., 32., 32., 32., 32., 32., 32., 32.,
32., 32., 32., 32., 32., 32., 32., 32., 32., 32., 32., 32., 32., 32.,
32., 32., 32.], device='cuda:0')
'''
import torch
import pdb
def matrix_nms(seg_masks, cate_labels, cate_scores, kernel='gaussian', sigma=2.0, sum_masks=None):
    """Matrix NMS for multi-class masks.

    Instead of hard suppression, every mask's score is softly decayed by
    its overlap with higher-scored masks of the same category.

    Args:
        seg_masks (Tensor): shape (n, h, w), bool masks, sorted by score
            in descending order.
        cate_labels (Tensor): shape (n), category label per mask.
        cate_scores (Tensor): shape (n), score per mask, descending.
        kernel (str): 'gaussian' or 'linear' decay kernel.
        sigma (float): std of the gaussian kernel.
        sum_masks (Tensor, optional): precomputed per-mask pixel counts;
            derived from ``seg_masks`` when not provided.

    Returns:
        Tensor: updated scores of shape (n). Returns an empty list when
        there are no input masks (kept for backward compatibility).
    """
    n_samples = len(cate_labels)  # at most cfg.nms_pre (e.g. 500)
    if n_samples == 0:
        return []
    if sum_masks is None:
        sum_masks = seg_masks.sum((1, 2)).float()
    # Flatten every mask so all pairwise intersections reduce to a single
    # matrix product: inter_matrix[i, j] = |mask_i ∩ mask_j|.
    seg_masks = seg_masks.reshape(n_samples, -1).float()
    inter_matrix = torch.mm(seg_masks, seg_masks.transpose(1, 0))
    # union = |A| + |B| - |A ∩ B|; broadcast the per-mask areas.
    sum_masks_x = sum_masks.expand(n_samples, n_samples)
    # Keep only the upper triangle: each pair is counted once, with the
    # row index being the higher-scored mask of the pair.
    iou_matrix = (inter_matrix / (sum_masks_x + sum_masks_x.transpose(1, 0) - inter_matrix)).triu(diagonal=1)
    # label_matrix[i, j] = 1 iff masks i and j share the same category
    # (upper-triangular as well, mirroring iou_matrix).
    cate_labels_x = cate_labels.expand(n_samples, n_samples)
    label_matrix = (cate_labels_x == cate_labels_x.transpose(1, 0)).float().triu(diagonal=1)
    # IoU decay: overlap of each mask with same-class, higher-scored masks.
    decay_iou = iou_matrix * label_matrix
    # IoU compensation: for every mask, the largest IoU it has with any
    # higher-scored mask of its own class (column-wise max) — this is the
    # fast-NMS-style term that normalizes the decay.
    compensate_iou, _ = (iou_matrix * label_matrix).max(0)
    compensate_iou = compensate_iou.expand(n_samples, n_samples).transpose(1, 0)
    # Matrix NMS: compute the per-mask decay coefficient.
    if kernel == 'gaussian':
        decay_matrix = torch.exp(-1 * sigma * (decay_iou ** 2))
        compensate_matrix = torch.exp(-1 * sigma * (compensate_iou ** 2))
        # min(0): take the strongest decay any higher-scored duplicate
        # imposes on this mask.
        decay_coefficient, _ = (decay_matrix / compensate_matrix).min(0)
    elif kernel == 'linear':
        decay_matrix = (1 - decay_iou) / (1 - compensate_iou)
        decay_coefficient, _ = decay_matrix.min(0)
    else:
        raise NotImplementedError
    # Soft-NMS style update: duplicates of the top-scored mask shrink
    # towards zero instead of being removed outright.
    cate_scores_update = cate_scores * decay_coefficient
    return cate_scores_update
def multiclass_nms(multi_bboxes,
                   multi_scores,
                   score_thr,
                   nms_cfg,
                   max_num=-1,
                   score_factors=None):
    """NMS over multi-class bboxes.

    Args:
        multi_bboxes (Tensor): shape (n, #class*4) or (n, 4).
        multi_scores (Tensor): shape (n, #class); column 0 holds the
            background class and is skipped.
        score_thr (float): bboxes scoring below this are discarded.
        nms_cfg (dict): NMS config; the ``type`` key selects the op
            resolved from ``nms_wrapper`` (expected to be imported at
            module level), the remaining keys are passed to that op.
        max_num (int): keep at most this many bboxes after NMS;
            non-positive values (the default) keep everything.
        score_factors (Tensor, optional): factors multiplied into the
            scores before NMS.

    Returns:
        tuple: ``(bboxes, labels)`` of shapes (k, 5) and (k,); labels
        are 0-based.
    """
    num_classes = multi_scores.shape[1]
    bboxes, labels = [], []
    nms_cfg_ = nms_cfg.copy()
    nms_type = nms_cfg_.pop('type', 'nms')
    nms_op = getattr(nms_wrapper, nms_type)
    for i in range(1, num_classes):  # column 0 is background
        cls_inds = multi_scores[:, i] > score_thr
        if not cls_inds.any():
            continue
        # Per-class boxes: either class-agnostic (n, 4) boxes shared by
        # all classes, or the class-specific 4-column slice.
        if multi_bboxes.shape[1] == 4:
            _bboxes = multi_bboxes[cls_inds, :]
        else:
            _bboxes = multi_bboxes[cls_inds, i * 4:(i + 1) * 4]
        _scores = multi_scores[cls_inds, i]
        if score_factors is not None:
            _scores *= score_factors[cls_inds]
        cls_dets = torch.cat([_bboxes, _scores[:, None]], dim=1)
        cls_dets, _ = nms_op(cls_dets, **nms_cfg_)
        cls_labels = multi_bboxes.new_full((cls_dets.shape[0], ),
                                           i - 1,
                                           dtype=torch.long)
        bboxes.append(cls_dets)
        labels.append(cls_labels)
    if bboxes:
        bboxes = torch.cat(bboxes)
        labels = torch.cat(labels)
        # Guard on max_num > 0: previously, with the default max_num=-1,
        # `shape[0] > max_num` was always true and `inds[:max_num]` became
        # `inds[:-1]`, silently dropping the lowest-scored detection.
        if max_num > 0 and bboxes.shape[0] > max_num:
            _, inds = bboxes[:, -1].sort(descending=True)
            inds = inds[:max_num]
            bboxes = bboxes[inds]
            labels = labels[inds]
    else:
        bboxes = multi_bboxes.new_zeros((0, 5))
        labels = multi_bboxes.new_zeros((0, ), dtype=torch.long)
    return bboxes, labels
'''
(Pdb) cate_scores * decay_coefficient
tensor([0.7593, 0.1010, 0.1081, 0.5393, 0.0926, 0.0885, 0.4901, 0.4664, 0.4540,
0.0755, 0.4385, 0.3944, 0.0726, 0.0748, 0.0986, 0.0551, 0.0835, 0.0694,
0.0822, 0.3194, 0.0600, 0.3141, 0.0594, 0.3115, 0.3114, 0.3086, 0.0771,
0.0792, 0.0597, 0.0512, 0.0569, 0.3018, 0.0461, 0.0550, 0.0537, 0.0662,
0.0580, 0.0644, 0.0503, 0.2881, 0.2839, 0.2830, 0.0561, 0.1310, 0.2692,
0.0652, 0.0694, 0.0505, 0.0410, 0.0464, 0.0665, 0.0409, 0.2440, 0.0407,
0.0464, 0.0410, 0.2291, 0.0447, 0.1051, 0.2260, 0.2241, 0.2236, 0.2233,
0.0529, 0.1370, 0.2200, 0.0540, 0.0532, 0.0473, 0.0530, 0.2168, 0.2134,
0.0678, 0.0478, 0.0384, 0.0407, 0.1161, 0.0320, 0.0619, 0.2025, 0.0388,
0.0331, 0.0493, 0.0866, 0.0849, 0.0413, 0.0593, 0.0593, 0.0388, 0.0389,
0.0738, 0.1875, 0.0674, 0.1145, 0.0588, 0.0806, 0.1797, 0.0382, 0.1776,
0.1751, 0.0489, 0.0511, 0.1743, 0.0815, 0.1741, 0.0582, 0.0925, 0.0317,
0.0318, 0.1661, 0.1645, 0.0297, 0.1634, 0.1629, 0.0446, 0.0389, 0.0318,
0.1611, 0.1445, 0.0564, 0.0337, 0.1564, 0.1563, 0.0331, 0.1556, 0.0605,
0.1533, 0.1526, 0.0254, 0.1477, 0.0477, 0.1507, 0.0379, 0.1504, 0.0312,
0.0492, 0.1478, 0.0248, 0.1466, 0.0412, 0.0278, 0.0301, 0.0973, 0.0297,
0.1449, 0.0219, 0.0616, 0.0348, 0.0274, 0.0721, 0.0425, 0.1388, 0.0409,
0.0231, 0.0848, 0.1382, 0.0488, 0.0265, 0.0326, 0.1361, 0.0220, 0.0898,
0.0259, 0.0259, 0.0268, 0.0563, 0.1345, 0.1344, 0.0220, 0.0319, 0.0512,
0.1330, 0.0265, 0.0458, 0.0277, 0.0257, 0.0245, 0.0280, 0.1300, 0.0402,
0.0307, 0.0460, 0.0315, 0.0277, 0.0173, 0.0657, 0.0251, 0.0230, 0.1267,
0.1263, 0.0789, 0.0680, 0.0559, 0.0196, 0.0247, 0.0987, 0.1243, 0.0254,
0.1033, 0.1235, 0.1234, 0.1233, 0.1232, 0.0211, 0.0351, 0.1230, 0.1225,
0.0211, 0.1211, 0.0752, 0.1207, 0.0759, 0.1200, 0.0432, 0.1198, 0.1191,
0.0215, 0.0458, 0.1184, 0.0221, 0.1175, 0.0706, 0.0312, 0.1170, 0.1169,
0.0257, 0.1167, 0.1166, 0.0193, 0.0641, 0.1151, 0.0692, 0.0873, 0.0289,
0.0330, 0.1137, 0.0447, 0.0257, 0.0675, 0.1123, 0.0252, 0.0519, 0.0219,
0.0188, 0.0327, 0.1117, 0.1117, 0.0921, 0.0403, 0.0270, 0.0230, 0.0641,
0.0273, 0.1099, 0.0201, 0.0322, 0.1091, 0.1090, 0.0229, 0.1089, 0.0187,
0.0216, 0.0307, 0.0513, 0.1080, 0.0260, 0.0855, 0.0441, 0.0188, 0.0972,
0.1068, 0.0417, 0.0206, 0.0394, 0.0214, 0.0427, 0.0170, 0.0311, 0.0481,
0.0196, 0.1049, 0.1051, 0.1049, 0.0295, 0.0347, 0.0226, 0.0667, 0.0199,
0.1041, 0.0246, 0.1038, 0.0241, 0.1033, 0.1028, 0.0212, 0.1021, 0.1022,
0.1019, 0.0413, 0.0388, 0.0343, 0.0967, 0.0925, 0.0654, 0.1009, 0.0301,
0.1007, 0.0986, 0.0474, 0.0583, 0.0990, 0.0273, 0.0989, 0.0737, 0.0689,
0.0187, 0.0231, 0.0982, 0.0522, 0.0132, 0.0973, 0.0387, 0.0971, 0.0937,
0.0968, 0.0189, 0.0218, 0.0933, 0.0219, 0.0199, 0.0957, 0.0475, 0.0266,
0.0950, 0.0389, 0.0454, 0.0262, 0.0641, 0.0870, 0.0212, 0.0187, 0.0834,
0.0931, 0.0431, 0.0929, 0.0929, 0.0703, 0.0193, 0.0459, 0.0211, 0.0926,
0.0925, 0.0923, 0.0371, 0.0420, 0.0224, 0.0196, 0.0919, 0.0336, 0.0917,
0.0894, 0.0569, 0.0832, 0.0328, 0.0249, 0.0263, 0.0181, 0.0410, 0.0906,
0.0159, 0.0402, 0.0183, 0.0168, 0.0171, 0.0204, 0.0160, 0.0897, 0.0323,
0.0173, 0.0240, 0.0708, 0.0894, 0.0892, 0.0892, 0.0283, 0.0186, 0.0172,
0.0882, 0.0160, 0.0179, 0.0522, 0.0511, 0.0177, 0.0877, 0.0418, 0.0155,
0.0606, 0.0868, 0.0867, 0.0485, 0.0258, 0.0143, 0.0359, 0.0804, 0.0457,
0.0835, 0.0678, 0.0177, 0.0193, 0.0250, 0.0477, 0.0289, 0.0247, 0.0839,
0.0836, 0.0680, 0.0423, 0.0147, 0.0649, 0.0824, 0.0178, 0.0299, 0.0219,
0.0161, 0.0152, 0.0422, 0.0242, 0.0266, 0.0808, 0.0453, 0.0557, 0.0807,
0.0222, 0.0154, 0.0217, 0.0134, 0.0600, 0.0447, 0.0231, 0.0162, 0.0759,
0.0292, 0.0229, 0.0790, 0.0380, 0.0216, 0.0505, 0.0786, 0.0556, 0.0281,
0.0469, 0.0556, 0.0233, 0.0726, 0.0175, 0.0303, 0.0774, 0.0770, 0.0462,
0.0285, 0.0731, 0.0333, 0.0712, 0.0232, 0.0318, 0.0756, 0.0361, 0.0382,
0.0751, 0.0627, 0.0749, 0.0565, 0.0470, 0.0228, 0.0193, 0.0294, 0.0442,
0.0434, 0.0538, 0.0726, 0.0562, 0.0260, 0.0227, 0.0721, 0.0325, 0.0717,
0.0604, 0.0696, 0.0700, 0.0588, 0.0234, 0.0229, 0.0195, 0.0683, 0.0350,
0.0359, 0.0378, 0.0688, 0.0407, 0.0671], device='cuda:0')
(Pdb) cate_scores
tensor([0.7593, 0.6425, 0.6012, 0.5393, 0.5195, 0.4914, 0.4901, 0.4664, 0.4540,
0.4468, 0.4385, 0.3944, 0.3913, 0.3701, 0.3569, 0.3558, 0.3473, 0.3448,
0.3417, 0.3194, 0.3147, 0.3141, 0.3134, 0.3115, 0.3114, 0.3086, 0.3071,
0.3065, 0.3050, 0.3035, 0.3025, 0.3018, 0.3017, 0.3003, 0.2977, 0.2969,
0.2946, 0.2934, 0.2930, 0.2881, 0.2839, 0.2830, 0.2733, 0.2713, 0.2692,
0.2640, 0.2634, 0.2615, 0.2544, 0.2502, 0.2472, 0.2443, 0.2440, 0.2430,
0.2317, 0.2306, 0.2291, 0.2265, 0.2262, 0.2260, 0.2241, 0.2236, 0.2233,
0.2222, 0.2207, 0.2202, 0.2192, 0.2191, 0.2173, 0.2169, 0.2168, 0.2134,
0.2130, 0.2112, 0.2105, 0.2093, 0.2073, 0.2070, 0.2031, 0.2025, 0.2007,
0.1998, 0.1989, 0.1978, 0.1951, 0.1939, 0.1920, 0.1917, 0.1895, 0.1893,
0.1876, 0.1875, 0.1847, 0.1839, 0.1827, 0.1817, 0.1797, 0.1786, 0.1776,
0.1751, 0.1748, 0.1746, 0.1743, 0.1743, 0.1741, 0.1723, 0.1704, 0.1701,
0.1675, 0.1661, 0.1645, 0.1642, 0.1634, 0.1629, 0.1625, 0.1623, 0.1618,
0.1611, 0.1607, 0.1599, 0.1583, 0.1564, 0.1563, 0.1557, 0.1556, 0.1541,
0.1533, 0.1526, 0.1518, 0.1514, 0.1512, 0.1507, 0.1505, 0.1504, 0.1503,
0.1499, 0.1478, 0.1476, 0.1466, 0.1461, 0.1458, 0.1453, 0.1452, 0.1452,
0.1449, 0.1438, 0.1419, 0.1411, 0.1405, 0.1392, 0.1391, 0.1388, 0.1386,
0.1385, 0.1383, 0.1382, 0.1379, 0.1367, 0.1363, 0.1361, 0.1357, 0.1355,
0.1352, 0.1352, 0.1349, 0.1348, 0.1345, 0.1344, 0.1335, 0.1335, 0.1333,
0.1330, 0.1326, 0.1323, 0.1317, 0.1313, 0.1312, 0.1309, 0.1301, 0.1298,
0.1293, 0.1283, 0.1282, 0.1282, 0.1280, 0.1280, 0.1277, 0.1268, 0.1267,
0.1263, 0.1261, 0.1259, 0.1259, 0.1255, 0.1253, 0.1245, 0.1243, 0.1238,
0.1237, 0.1235, 0.1234, 0.1233, 0.1233, 0.1231, 0.1230, 0.1230, 0.1226,
0.1215, 0.1211, 0.1211, 0.1207, 0.1201, 0.1200, 0.1198, 0.1198, 0.1191,
0.1187, 0.1186, 0.1184, 0.1183, 0.1175, 0.1173, 0.1172, 0.1170, 0.1169,
0.1168, 0.1167, 0.1166, 0.1164, 0.1153, 0.1151, 0.1150, 0.1145, 0.1141,
0.1140, 0.1137, 0.1133, 0.1131, 0.1128, 0.1123, 0.1123, 0.1123, 0.1120,
0.1119, 0.1117, 0.1117, 0.1117, 0.1112, 0.1111, 0.1111, 0.1107, 0.1104,
0.1103, 0.1099, 0.1097, 0.1093, 0.1091, 0.1090, 0.1089, 0.1089, 0.1086,
0.1082, 0.1082, 0.1082, 0.1080, 0.1080, 0.1076, 0.1074, 0.1074, 0.1071,
0.1068, 0.1068, 0.1066, 0.1065, 0.1063, 0.1062, 0.1060, 0.1056, 0.1056,
0.1054, 0.1053, 0.1051, 0.1049, 0.1049, 0.1044, 0.1044, 0.1041, 0.1041,
0.1041, 0.1038, 0.1038, 0.1034, 0.1033, 0.1028, 0.1028, 0.1023, 0.1022,
0.1022, 0.1021, 0.1019, 0.1017, 0.1015, 0.1015, 0.1011, 0.1009, 0.1007,
0.1007, 0.0996, 0.0996, 0.0993, 0.0990, 0.0990, 0.0989, 0.0988, 0.0988,
0.0987, 0.0983, 0.0982, 0.0978, 0.0978, 0.0973, 0.0972, 0.0971, 0.0969,
0.0968, 0.0965, 0.0963, 0.0958, 0.0958, 0.0958, 0.0957, 0.0957, 0.0955,
0.0950, 0.0947, 0.0946, 0.0942, 0.0940, 0.0940, 0.0938, 0.0935, 0.0933,
0.0931, 0.0930, 0.0929, 0.0929, 0.0928, 0.0928, 0.0928, 0.0927, 0.0926,
0.0925, 0.0923, 0.0923, 0.0923, 0.0922, 0.0919, 0.0919, 0.0919, 0.0917,
0.0916, 0.0913, 0.0912, 0.0911, 0.0908, 0.0907, 0.0907, 0.0906, 0.0906,
0.0905, 0.0905, 0.0904, 0.0902, 0.0902, 0.0901, 0.0901, 0.0897, 0.0896,
0.0895, 0.0895, 0.0894, 0.0894, 0.0892, 0.0892, 0.0889, 0.0884, 0.0883,
0.0882, 0.0881, 0.0879, 0.0878, 0.0878, 0.0877, 0.0877, 0.0875, 0.0873,
0.0870, 0.0868, 0.0867, 0.0867, 0.0865, 0.0862, 0.0859, 0.0856, 0.0856,
0.0856, 0.0856, 0.0852, 0.0852, 0.0852, 0.0849, 0.0847, 0.0843, 0.0839,
0.0836, 0.0834, 0.0833, 0.0831, 0.0830, 0.0824, 0.0824, 0.0822, 0.0818,
0.0818, 0.0815, 0.0815, 0.0813, 0.0811, 0.0808, 0.0808, 0.0807, 0.0807,
0.0807, 0.0806, 0.0802, 0.0800, 0.0800, 0.0798, 0.0796, 0.0793, 0.0792,
0.0791, 0.0790, 0.0790, 0.0790, 0.0789, 0.0787, 0.0786, 0.0784, 0.0783,
0.0779, 0.0778, 0.0778, 0.0777, 0.0775, 0.0774, 0.0774, 0.0770, 0.0768,
0.0767, 0.0766, 0.0763, 0.0760, 0.0759, 0.0756, 0.0756, 0.0755, 0.0754,
0.0752, 0.0751, 0.0749, 0.0749, 0.0740, 0.0738, 0.0736, 0.0736, 0.0733,
0.0729, 0.0729, 0.0726, 0.0723, 0.0722, 0.0721, 0.0721, 0.0721, 0.0717,
0.0715, 0.0714, 0.0714, 0.0713, 0.0713, 0.0712, 0.0707, 0.0704, 0.0694,
0.0690, 0.0689, 0.0688, 0.0678, 0.0671], device='cuda:0')
'''
'''
(Pdb) ans[0].sum()
tensor(26, device='cuda:0')
(Pdb) ans[1].sum()
tensor(26, device='cuda:0')
(Pdb) ans[2].sum()
tensor(26, device='cuda:0')
(Pdb) label_matrix[0].sum()
tensor(25., device='cuda:0')
(Pdb) label_matrix[1].sum()
tensor(24., device='cuda:0')
(Pdb) label_matrix[2].sum()
tensor(23., device='cuda:0')
'''
'''
(Pdb) iou_matrix
tensor([[0.0000, 0.9618, 0.9262, ..., 0.5556, 0.0000, 0.0000],
[0.0000, 0.0000, 0.9157, ..., 0.5608, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, ..., 0.5495, 0.0000, 0.0000],
...,
[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0082],
[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
device='cuda:0')
(Pdb) label_matrix
tensor([[0., 1., 1., ..., 0., 0., 0.],
[0., 0., 1., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
...,
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]], device='cuda:0')
(Pdb) iou_matrix * label_matrix
tensor([[0.0000, 0.9618, 0.9262, ..., 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.9157, ..., 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
...,
[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
device='cuda:0')
'''
'''
(Pdb) decay_iou
tensor([[0.0000, 0.9618, 0.9262, ..., 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.9157, ..., 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
...,
[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
device='cuda:0')
(Pdb) compensate_iou
tensor([[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
[0.9618, 0.9618, 0.9618, ..., 0.9618, 0.9618, 0.9618],
[0.9262, 0.9262, 0.9262, ..., 0.9262, 0.9262, 0.9262],
...,
[0.1814, 0.1814, 0.1814, ..., 0.1814, 0.1814, 0.1814],
[0.5750, 0.5750, 0.5750, ..., 0.5750, 0.5750, 0.5750],
[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
device='cuda:0')
'''
'''
(Pdb) decay_matrix
tensor([[1.0000, 0.1572, 0.1798, ..., 1.0000, 1.0000, 1.0000],
[1.0000, 1.0000, 0.1869, ..., 1.0000, 1.0000, 1.0000],
[1.0000, 1.0000, 1.0000, ..., 1.0000, 1.0000, 1.0000],
...,
[1.0000, 1.0000, 1.0000, ..., 1.0000, 1.0000, 1.0000],
[1.0000, 1.0000, 1.0000, ..., 1.0000, 1.0000, 1.0000],
[1.0000, 1.0000, 1.0000, ..., 1.0000, 1.0000, 1.0000]],
device='cuda:0')
(Pdb) compensate_matrix
tensor([[1.0000, 1.0000, 1.0000, ..., 1.0000, 1.0000, 1.0000],
[0.1572, 0.1572, 0.1572, ..., 0.1572, 0.1572, 0.1572],
[0.1798, 0.1798, 0.1798, ..., 0.1798, 0.1798, 0.1798],
...,
[0.9363, 0.9363, 0.9363, ..., 0.9363, 0.9363, 0.9363],
[0.5162, 0.5162, 0.5162, ..., 0.5162, 0.5162, 0.5162],
[1.0000, 1.0000, 1.0000, ..., 1.0000, 1.0000, 1.0000]],
device='cuda:0')
(Pdb) decay_coefficient
tensor([1.0000, 0.1572, 0.1798, 1.0000, 0.1783, 0.1802, 1.0000, 1.0000, 1.0000,
0.1689, 1.0000, 1.0000, 0.1855, 0.2022, 0.2764, 0.1547, 0.2404, 0.2013,
0.2406, 1.0000, 0.1905, 1.0000, 0.1897, 1.0000, 1.0000, 1.0000, 0.2510,
0.2585, 0.1958, 0.1687, 0.1882, 1.0000, 0.1527, 0.1832, 0.1805, 0.2230,
0.1968, 0.2195, 0.1718, 1.0000, 1.0000, 1.0000, 0.2053, 0.4830, 1.0000,
0.2468, 0.2634, 0.1930, 0.1613, 0.1854, 0.2688, 0.1672, 1.0000, 0.1676,
0.2001, 0.1778, 0.9999, 0.1972, 0.4646, 0.9999, 1.0000, 1.0000, 1.0000,
0.2382, 0.6206, 0.9990, 0.2464, 0.2426, 0.2175, 0.2442, 1.0000, 1.0000,
0.3181, 0.2265, 0.1825, 0.1945, 0.5600, 0.1545, 0.3049, 1.0000, 0.1934,
0.1655, 0.2479, 0.4379, 0.4354, 0.2128, 0.3090, 0.3096, 0.2048, 0.2055,
0.3932, 0.9999, 0.3649, 0.6227, 0.3221, 0.4436, 1.0000, 0.2139, 1.0000,
1.0000, 0.2798, 0.2927, 1.0000, 0.4676, 1.0000, 0.3377, 0.5426, 0.1867,
0.1898, 1.0000, 1.0000, 0.1810, 1.0000, 1.0000, 0.2745, 0.2394, 0.1964,
1.0000, 0.8990, 0.3527, 0.2131, 0.9999, 1.0000, 0.2128, 1.0000, 0.3925,
1.0000, 1.0000, 0.1672, 0.9756, 0.3158, 1.0000, 0.2516, 1.0000, 0.2073,
0.3283, 1.0000, 0.1679, 1.0000, 0.2823, 0.1908, 0.2075, 0.6698, 0.2049,
1.0000, 0.1524, 0.4344, 0.2468, 0.1950, 0.5180, 0.3053, 1.0000, 0.2947,
0.1668, 0.6135, 1.0000, 0.3543, 0.1941, 0.2391, 1.0000, 0.1622, 0.6628,
0.1915, 0.1915, 0.1984, 0.4174, 1.0000, 1.0000, 0.1648, 0.2391, 0.3841,
1.0000, 0.2000, 0.3462, 0.2100, 0.1955, 0.1864, 0.2139, 0.9999, 0.3096,
0.2372, 0.3589, 0.2454, 0.2163, 0.1353, 0.5128, 0.1966, 0.1813, 1.0000,
1.0000, 0.6258, 0.5399, 0.4436, 0.1565, 0.1968, 0.7929, 1.0000, 0.2052,
0.8349, 1.0000, 1.0000, 1.0000, 0.9999, 0.1713, 0.2855, 1.0000, 0.9990,
0.1734, 1.0000, 0.6206, 1.0000, 0.6314, 1.0000, 0.3608, 1.0000, 1.0000,
0.1814, 0.3864, 0.9998, 0.1868, 1.0000, 0.6018, 0.2664, 1.0000, 1.0000,
0.2204, 1.0000, 1.0000, 0.1655, 0.5560, 1.0000, 0.6018, 0.7625, 0.2531,
0.2891, 1.0000, 0.3947, 0.2269, 0.5983, 1.0000, 0.2240, 0.4622, 0.1954,
0.1679, 0.2926, 0.9999, 1.0000, 0.8288, 0.3631, 0.2429, 0.2077, 0.5807,
0.2477, 1.0000, 0.1835, 0.2947, 1.0000, 1.0000, 0.2103, 1.0000, 0.1724,
0.2000, 0.2840, 0.4740, 1.0000, 0.2410, 0.7940, 0.4111, 0.1751, 0.9077,
1.0000, 0.3905, 0.1929, 0.3705, 0.2016, 0.4020, 0.1601, 0.2947, 0.4558,
0.1863, 0.9965, 1.0000, 1.0000, 0.2813, 0.3324, 0.2163, 0.6408, 0.1911,
1.0000, 0.2366, 1.0000, 0.2331, 1.0000, 1.0000, 0.2059, 0.9982, 1.0000,
0.9972, 0.4041, 0.3809, 0.3376, 0.9527, 0.9110, 0.6471, 1.0000, 0.2990,
1.0000, 0.9895, 0.4757, 0.5870, 1.0000, 0.2757, 1.0000, 0.7463, 0.6971,
0.1898, 0.2349, 1.0000, 0.5335, 0.1353, 0.9996, 0.3981, 1.0000, 0.9676,
1.0000, 0.1954, 0.2259, 0.9738, 0.2285, 0.2074, 1.0000, 0.4963, 0.2780,
1.0000, 0.4111, 0.4801, 0.2780, 0.6819, 0.9255, 0.2259, 0.2002, 0.8939,
1.0000, 0.4634, 1.0000, 1.0000, 0.7577, 0.2078, 0.4951, 0.2280, 1.0000,
1.0000, 1.0000, 0.4014, 0.4548, 0.2429, 0.2128, 1.0000, 0.3658, 1.0000,
0.9756, 0.6232, 0.9124, 0.3601, 0.2744, 0.2895, 0.2001, 0.4525, 1.0000,
0.1758, 0.4439, 0.2022, 0.1865, 0.1894, 0.2269, 0.1781, 1.0000, 0.3609,
0.1929, 0.2681, 0.7913, 0.9999, 1.0000, 1.0000, 0.3181, 0.2103, 0.1950,
1.0000, 0.1819, 0.2036, 0.5941, 0.5819, 0.2022, 1.0000, 0.4777, 0.1774,
0.6963, 1.0000, 1.0000, 0.5600, 0.2989, 0.1664, 0.4174, 0.9394, 0.5335,
0.9756, 0.7929, 0.2073, 0.2270, 0.2930, 0.5621, 0.3410, 0.2926, 1.0000,
1.0000, 0.8152, 0.5078, 0.1772, 0.7817, 1.0000, 0.2154, 0.3641, 0.2681,
0.1963, 0.1870, 0.5180, 0.2982, 0.3277, 0.9999, 0.5600, 0.6903, 1.0000,
0.2754, 0.1911, 0.2704, 0.1668, 0.7497, 0.5600, 0.2895, 0.2049, 0.9588,
0.3695, 0.2894, 1.0000, 0.4810, 0.2742, 0.6411, 1.0000, 0.7090, 0.3589,
0.6018, 0.7151, 0.3002, 0.9344, 0.2259, 0.3921, 1.0000, 1.0000, 0.6018,
0.3710, 0.9543, 0.4373, 0.9361, 0.3053, 0.4208, 1.0000, 0.4777, 0.5065,
0.9999, 0.8339, 1.0000, 0.7549, 0.6350, 0.3088, 0.2617, 0.3994, 0.6034,
0.5947, 0.7377, 0.9990, 0.7771, 0.3594, 0.3155, 1.0000, 0.4505, 1.0000,
0.8445, 0.9756, 0.9810, 0.8240, 0.3274, 0.3215, 0.2753, 0.9701, 0.5041,
0.5205, 0.5485, 1.0000, 0.5994, 1.0000], device='cuda:0')
'''
# model settings
model = dict(
    type='SOLO',
    pretrained='torchvision://resnet50',
    backbone=dict(
        type='ResNet',
        depth=50,
        num_stages=4,
        out_indices=(0, 1, 2, 3),  # C2, C3, C4, C5
        frozen_stages=1,
        style='pytorch'),
    neck=dict(
        type='FPN',
        in_channels=[256, 512, 1024, 2048],
        out_channels=256,
        start_level=0,
        num_outs=5),
    bbox_head=dict(
        type='SOLOHead',  # 'SOLOHead' maps to the file of the same name (SOLOHead.py); change `type` to point at your own modified SOLOHead_xx.py
        num_classes=81,
        in_channels=256,
        stacked_convs=7,
        seg_feat_channels=256,
        strides=[8, 8, 16, 32, 32],
        scale_ranges=((1, 96), (48, 192), (96, 384), (192, 768), (384, 2048)),
        sigma=0.2,
        num_grids=[40, 36, 24, 16, 12],
        cate_down_pos=0,
        with_deform=False,
        loss_ins=dict(
            type='DiceLoss',
            use_sigmoid=True,
            loss_weight=3.0),
        loss_cate=dict(
            type='FocalLoss',
            use_sigmoid=True,
            gamma=2.0,
            alpha=0.25,
            loss_weight=1.0),
    ))
# training and testing settings
train_cfg = dict()
test_cfg = dict(
    nms_pre=500,
    score_thr=0.1,
    mask_thr=0.5,
    update_thr=0.05,
    kernel='gaussian',  # gaussian/linear
    sigma=2.0,
    max_per_img=100)
# dataset settings
dataset_type = 'CocoDataset'
data_root = 'data/coco2017/'
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
    dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
    dict(type='RandomFlip', flip_ratio=0.5),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Pad', size_divisor=32),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(1333, 800),
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='Pad', size_divisor=32),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ])
]
data = dict(
    imgs_per_gpu=2,
    workers_per_gpu=2,
    train=dict(
        type=dataset_type,
        ann_file=data_root + 'annotations/instances_train2017.json',
        img_prefix=data_root + 'train2017/',
        pipeline=train_pipeline),
    val=dict(
        type=dataset_type,
        ann_file=data_root + 'annotations/instances_val2017.json',
        img_prefix=data_root + 'val2017/',
        pipeline=test_pipeline),
    test=dict(
        type=dataset_type,
        ann_file=data_root + 'annotations/instances_val2017.json',
        img_prefix=data_root + 'val2017/',
        pipeline=test_pipeline))
# optimizer
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
# learning policy
lr_config = dict(
    policy='step',
    warmup='linear',
    warmup_iters=500,
    warmup_ratio=1.0 / 3,
    step=[9, 11])
# save
checkpoint_config = dict(interval=1)  # checkpoint saving interval (reflected in the log files)
# yapf:disable
log_config = dict(
    interval=1,  # print logs once every `interval` iterations
    hooks=[
        dict(type='TextLoggerHook'),
        # dict(type='TensorboardLoggerHook')
    ])
# yapf:enable
# runtime settings
total_epochs = 12
device_ids = range(8)
dist_params = dict(backend='nccl')
log_level = 'INFO'
work_dir = './work_dirs/solo_release_r50_fpn_8gpu_1x'
load_from = None
resume_from = None
workflow = [('train', 1)]
from .anchor_head import AnchorHead
from .atss_head import ATSSHead
from .fcos_head import FCOSHead
from .fovea_head import FoveaHead
from .free_anchor_retina_head import FreeAnchorRetinaHead
from .ga_retina_head import GARetinaHead
from .ga_rpn_head import GARPNHead
from .guided_anchor_head import FeatureAdaption, GuidedAnchorHead
from .reppoints_head import RepPointsHead
from .retina_head import RetinaHead
from .retina_sepbn_head import RetinaSepBNHead
from .rpn_head import RPNHead
from .ssd_head import SSDHead
from .solo_head import SOLOHead
from .solov2_head import SOLOv2Head
from .solov2_light_head import SOLOv2LightHead
from .decoupled_solo_head import DecoupledSOLOHead
from .decoupled_solo_light_head import DecoupledSOLOLightHead
# Register the custom head module (fixed: was `improt`, a SyntaxError).
from .solo_head_xx import SOLOHead_xx

__all__ = [
    'AnchorHead', 'GuidedAnchorHead', 'FeatureAdaption', 'RPNHead',
    'GARPNHead', 'RetinaHead', 'RetinaSepBNHead', 'GARetinaHead', 'SSDHead',
    'FCOSHead', 'RepPointsHead', 'FoveaHead', 'FreeAnchorRetinaHead',
    'ATSSHead', 'SOLOHead', 'SOLOv2Head', 'SOLOv2LightHead',
    'DecoupledSOLOHead', 'DecoupledSOLOLightHead',
    # Trailing comma added: without it, implicit string concatenation
    # produced the bogus entry 'DecoupledSOLOLightHeadSOLOHead_xx'.
    'SOLOHead_xx',
]
# Then implement SOLOHead_xx.py (adjusting the corresponding super() calls),
# so the official file is kept intact while small changes live in the new head.