Having previously written a walkthrough of the SOLO demo code, let's now work through the training process. The starting point is tools/train.py: this is where training begins, and it is the script you run from the command line:
from mmdet.apis import set_random_seed, train_detector
def main():
args = parse_args()
cfg = Config.fromfile(args.config)
# set cudnn_benchmark
if cfg.get('cudnn_benchmark', False):
torch.backends.cudnn.benchmark = True
# update configs according to CLI args
if args.work_dir is not None:
cfg.work_dir = args.work_dir
if args.resume_from is not None:
cfg.resume_from = args.resume_from
cfg.gpus = args.gpus
if args.autoscale_lr:
# apply the linear scaling rule (https://arxiv.org/abs/1706.02677)
cfg.optimizer['lr'] = cfg.optimizer['lr'] * cfg.gpus / 8
# init distributed env first, since logger depends on the dist info.
if args.launcher == 'none':
distributed = False
else:
distributed = True
init_dist(args.launcher, **cfg.dist_params)
# create work_dir
mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
# init the logger before other steps
timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
log_file = osp.join(cfg.work_dir, '{}.log'.format(timestamp))
logger = get_root_logger(log_file=log_file, log_level=cfg.log_level)
# log some basic info
logger.info('Distributed training: {}'.format(distributed))
logger.info('MMDetection Version: {}'.format(__version__))
logger.info('Config:\n{}'.format(cfg.text))
# set random seeds
if args.seed is not None:
logger.info('Set random seed to {}, deterministic: {}'.format(
args.seed, args.deterministic))
set_random_seed(args.seed, deterministic=args.deterministic)
model = build_detector(
cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
datasets = [build_dataset(cfg.data.train)]
if len(cfg.workflow) == 2:
datasets.append(build_dataset(cfg.data.val))
if cfg.checkpoint_config is not None:
# save mmdet version, config file content and class names in
# checkpoints as meta data
cfg.checkpoint_config.meta = dict(
mmdet_version=__version__,
config=cfg.text,
CLASSES=datasets[0].CLASSES)
# add an attribute for visualization convenience
model.CLASSES = datasets[0].CLASSES
train_detector(
model,
datasets,
cfg,
distributed=distributed,
validate=args.validate,
timestamp=timestamp)
The heart of this file is the main function. It loads the config file, builds the SOLO model from it (build_detector), loads the dataset (build_dataset), and enters the SOLO training module (train_detector). There is not much detail in this part: model construction is the same as in the demo, and dataset loading can be set aside for now, so the core is the train_detector module.
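To ground all those cfg.* accesses, here is a hypothetical fragment of the kind of config file main() consumes (mmdet 1.x conventions; the field names match the code above, but the values are illustrative rather than taken from a specific SOLO config):
# Hypothetical config fragment (mmdet 1.x style; values illustrative).
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
lr_config = dict(policy='step', warmup='linear', warmup_iters=500, step=[9, 11])
checkpoint_config = dict(interval=1)
log_config = dict(interval=50, hooks=[dict(type='TextLoggerHook')])
total_epochs = 12
workflow = [('train', 1)]  # a second ('val', 1) entry would trigger the val-dataset branch above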
Next, let's step into mmdet/apis/train.py and take a close look at the train_detector function:
def train_detector(model,
dataset,
cfg,
distributed=False,
validate=False,
timestamp=None):
logger = get_root_logger(cfg.log_level)
# start training
if distributed:
_dist_train(
model,
dataset,
cfg,
validate=validate,
logger=logger,
timestamp=timestamp)
else:
_non_dist_train(
model,
dataset,
cfg,
validate=validate,
logger=logger,
timestamp=timestamp)
def _non_dist_train(model,
dataset,
cfg,
validate=False,
logger=None,
timestamp=None):
if validate:
raise NotImplementedError('Built-in validation is not implemented '
'yet in not-distributed training. Use '
'distributed training or test.py and '
'*eval.py scripts instead.')
# prepare data loaders
dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]
data_loaders = [
build_dataloader(
ds,
cfg.data.imgs_per_gpu,
cfg.data.workers_per_gpu,
cfg.gpus,
dist=False) for ds in dataset
]
# put model on gpus
model = MMDataParallel(model, device_ids=range(cfg.gpus)).cuda()
# build runner
optimizer = build_optimizer(model, cfg.optimizer)
runner = Runner(
model, batch_processor, optimizer, cfg.work_dir, logger=logger)
# an ugly workaround to make the .log and .log.json filenames the same
runner.timestamp = timestamp
# fp16 setting
fp16_cfg = cfg.get('fp16', None)
if fp16_cfg is not None:
optimizer_config = Fp16OptimizerHook(
**cfg.optimizer_config, **fp16_cfg, distributed=False)
else:
optimizer_config = cfg.optimizer_config
runner.register_training_hooks(cfg.lr_config, optimizer_config,
cfg.checkpoint_config, cfg.log_config)
if cfg.resume_from:
runner.resume(cfg.resume_from)
elif cfg.load_from:
runner.load_checkpoint(cfg.load_from)
runner.run(data_loaders, cfg.workflow, cfg.total_epochs)
The function first branches on whether training is distributed. Assume non-distributed training here, so _non_dist_train is called. It builds the dataloaders, puts the model on the GPUs, and creates an instance of the Runner class (Runner is worth a look if you are curious: it is one of the more important containers in MMDetection. It is fairly low-level, though, so you need not dig deep; it is enough to know that it is the tool that ties together data, model, training schedule, evaluation, and inference, and during training it is what unifies these parts). Notice that the model-related pieces, i.e. the model, the batch processor, and the optimizer, are passed in when Runner is instantiated, while the data_loaders, the workflow, and the number of training epochs are passed in on the last line, when run is executed.
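One practical detail is easy to miss here: in non-distributed mode each dataloader batches cfg.data.imgs_per_gpu images per GPU and MMDataParallel scatters them across cfg.gpus devices, so the effective batch size is their product. This is also why --autoscale_lr in tools/train.py divides by 8. A quick check with illustrative numbers:
# Illustrative numbers only; the real values come from the config and CLI.
imgs_per_gpu, gpus = 2, 4
effective_batch = imgs_per_gpu * gpus  # 8 images consumed per iteration
base_lr = 0.01                         # assumed tuned for an 8-GPU setup
scaled_lr = base_lr * gpus / 8         # 0.005, matching cfg.optimizer['lr'] * cfg.gpus / 8
print(effective_batch, scaled_lr)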
Execution now enters runner.run. To understand run, we need a quick look at the Runner class, which lives in mmcv (outside the MMDetection project itself, but mmcv is a prerequisite for running MMDetection). The Runner class is defined in mmcv/runner/epoch_based_runner.py:
@RUNNERS.register_module()
class EpochBasedRunner(BaseRunner):
"""Epoch-based Runner.
This runner train models epoch by epoch.
"""
def run_iter(self, data_batch, train_mode, **kwargs):
if self.batch_processor is not None:  # this branch is taken in our setup
outputs = self.batch_processor(
self.model, data_batch, train_mode=train_mode, **kwargs)
elif train_mode:
outputs = self.model.train_step(data_batch, self.optimizer,
**kwargs)
else:
outputs = self.model.val_step(data_batch, self.optimizer, **kwargs)
if not isinstance(outputs, dict):
raise TypeError('"batch_processor()" or "model.train_step()" '
'and "model.val_step()" must return a dict')
if 'log_vars' in outputs:
self.log_buffer.update(outputs['log_vars'], outputs['num_samples'])
self.outputs = outputs
def train(self, data_loader, **kwargs):
self.model.train()
self.mode = 'train'
self.data_loader = data_loader
self._max_iters = self._max_epochs * len(self.data_loader)
self.call_hook('before_train_epoch')
time.sleep(2) # Prevent possible deadlock during epoch transition
# train for one epoch
for i, data_batch in enumerate(self.data_loader):
self._inner_iter = i
self.call_hook('before_train_iter')
### the core call ###
self.run_iter(data_batch, train_mode=True)
self.call_hook('after_train_iter')
self._iter += 1
self.call_hook('after_train_epoch')
self._epoch += 1
@torch.no_grad()
def val(self, data_loader, **kwargs):
self.model.eval()
self.mode = 'val'
self.data_loader = data_loader
self.call_hook('before_val_epoch')
time.sleep(2) # Prevent possible deadlock during epoch transition
for i, data_batch in enumerate(self.data_loader):
self._inner_iter = i
self.call_hook('before_val_iter')
self.run_iter(data_batch, train_mode=False)
self.call_hook('after_val_iter')
self.call_hook('after_val_epoch')
def run(self, data_loaders, workflow, max_epochs=None, **kwargs):
"""Start running.
Args:
data_loaders (list[:obj:`DataLoader`]): Dataloaders for training
and validation.
workflow (list[tuple]): A list of (phase, epochs) to specify the
running order and epochs. E.g, [('train', 2), ('val', 1)] means
running 2 epochs for training and 1 epoch for validation,
iteratively.
"""
assert isinstance(data_loaders, list)
assert mmcv.is_list_of(workflow, tuple)
assert len(data_loaders) == len(workflow)
if max_epochs is not None:
warnings.warn(
'setting max_epochs in run is deprecated, '
'please set max_epochs in runner_config', DeprecationWarning)
self._max_epochs = max_epochs
assert self._max_epochs is not None, (
'max_epochs must be specified during instantiation')
for i, flow in enumerate(workflow):
mode, epochs = flow  # if the phase is 'train', epochs is the number of train epochs per cycle (supports interleaved train/val, e.g. 2 train epochs followed by 1 val epoch)
if mode == 'train':
self._max_iters = self._max_epochs * len(data_loaders[i])
break
work_dir = self.work_dir if self.work_dir is not None else 'NONE'
self.logger.info('Start running, host: %s, work_dir: %s',
get_host_info(), work_dir)
self.logger.info('workflow: %s, max: %d epochs', workflow,
self._max_epochs)
self.call_hook('before_run')
############# the important part starts here #############
while self.epoch < self._max_epochs:  # train until the maximum epoch
for i, flow in enumerate(workflow):
mode, epochs = flow
# data_loaders[i] is the train or the val dataloader
if isinstance(mode, str): # self.train()
if not hasattr(self, mode):
raise ValueError(
f'runner has no method named "{mode}" to run an '
'epoch')
epoch_runner = getattr(self, mode)  # resolves to self.train or self.val
else:
raise TypeError(
'mode in workflow must be a str, but got {}'.format(
type(mode)))
for _ in range(epochs):
# break once training is complete, i.e. both conditions below hold
if mode == 'train' and self.epoch >= self._max_epochs:
break
epoch_runner(data_loaders[i], **kwargs)
time.sleep(1) # wait for some hooks like loggers to finish
self.call_hook('after_run')
@RUNNERS.register_module()
class Runner(EpochBasedRunner):
"""Deprecated name of EpochBasedRunner."""
def __init__(self, *args, **kwargs):
warnings.warn(
'Runner was deprecated, please use EpochBasedRunner instead')
super().__init__(*args, **kwargs)
As the code shows, Runner is just a thin subclass of EpochBasedRunner, so the focus shifts to EpochBasedRunner, which in turn inherits from BaseRunner (mostly relevant only at initialization). The key method is run: it first inspects the workflow to determine the train/val arrangement, then starts training. Assuming the workflow has no val phase, epoch_runner always resolves to train, so the train method is called directly. Inside train, a loop runs one epoch, and the most important line in that loop calls run_iter. Since our config supplies a batch processor, batch_processor is called; it is the function that was passed in when Runner was first instantiated, so we will come back to it in a moment. Spoiler: the output obtained here is the loss, and once the loss is available, a series of hook operations run backpropagation and the gradient-descent update of the network parameters.
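The remark about hooks doing the parameter update deserves one concrete illustration. When register_training_hooks runs in _non_dist_train, mmcv installs an OptimizerHook built from cfg.optimizer_config, and its after_train_iter is essentially the following (a simplified sketch; the real mmcv hook additionally supports gradient clipping):
from mmcv.runner import Hook

class SimplifiedOptimizerHook(Hook):
    # Essentially what mmcv's OptimizerHook does once per training iteration.
    def after_train_iter(self, runner):
        runner.optimizer.zero_grad()
        runner.outputs['loss'].backward()  # runner.outputs was set by run_iter above
        runner.optimizer.step()
This is the bridge between run_iter storing self.outputs and the weights actually changing.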
So let's return to mmdet/apis/train.py and look at the batch_processor function:
def batch_processor(model, data, train_mode):
"""Process a data batch.
This method is required as an argument of Runner, which defines how to
process a data batch and obtain proper outputs. The first 3 arguments of
batch_processor are fixed.
Args:
model (nn.Module): A PyTorch model.
data (dict): The data batch in a dict.
train_mode (bool): Training mode or not. It may be useless for some
models.
Returns:
dict: A dict containing losses and log vars.
"""
losses = model(**data) # compute the losses
loss, log_vars = parse_losses(losses)
outputs = dict(
loss=loss, log_vars=log_vars, num_samples=len(data['img'].data))
return outputs
def parse_losses(losses):
log_vars = OrderedDict()
for loss_name, loss_value in losses.items():
if isinstance(loss_value, torch.Tensor):
log_vars[loss_name] = loss_value.mean()
elif isinstance(loss_value, list):
log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value)
else:
raise TypeError(
'{} is not a tensor or list of tensors'.format(loss_name))
loss = sum(_value for _key, _value in log_vars.items() if 'loss' in _key)
log_vars['loss'] = loss
for loss_name, loss_value in log_vars.items():
# reduce loss when distributed training
if dist.is_available() and dist.is_initialized():
loss_value = loss_value.data.clone()
dist.all_reduce(loss_value.div_(dist.get_world_size()))
log_vars[loss_name] = loss_value.item()
return loss, log_vars
model(**data) is the crucial step: the current batch is fed into SOLO, the SingleStageInsDetector, and its forward pass runs; since we are training, the forward_train method is invoked. (parse_losses plays a minor role; it essentially produces a more complete summary of the loss terms, and can be ignored for now.)
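Although parse_losses can be set aside, its contract is worth one toy sanity check. The snippet below is standalone and mirrors the logic above rather than importing it:
from collections import OrderedDict
import torch

losses = OrderedDict(
    loss_ins=torch.tensor([0.8, 0.6]),                 # Tensor -> mean() = 0.7
    loss_cate=[torch.tensor(0.2), torch.tensor(0.1)],  # list   -> sum of means = 0.3
)
log_vars = OrderedDict()
for name, value in losses.items():
    log_vars[name] = value.mean() if isinstance(value, torch.Tensor) \
        else sum(v.mean() for v in value)
total = sum(v for k, v in log_vars.items() if 'loss' in k)
print(float(total))  # 1.0 -- the scalar that is ultimately backpropagated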
forward_train is defined in mmdet/models/detectors/single_stage_ins.py:
def forward_train(self,
img,
img_metas,
gt_bboxes,
gt_labels,
gt_bboxes_ignore=None,
gt_masks=None):
x = self.extract_feat(img)
outs = self.bbox_head(x)
if self.with_mask_feat_head:
mask_feat_pred = self.mask_feat_head(
x[self.mask_feat_head.
start_level:self.mask_feat_head.end_level + 1])
loss_inputs = outs + (mask_feat_pred, gt_bboxes, gt_labels, gt_masks, img_metas, self.train_cfg)
else:
loss_inputs = outs + (gt_bboxes, gt_labels, gt_masks, img_metas, self.train_cfg)
losses = self.bbox_head.loss(
*loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)
return losses
As the code shows, the batch of images is first propagated forward to obtain the predictions outs (plus mask_feat_pred when the model has a mask feature head; that branch belongs to SOLOv2-style variants, while vanilla SOLO takes the else branch). These predictions, together with the GT annotations in the batch, are passed to self.bbox_head.loss to compute the loss. We now have the class and mask predictions that one batch of images yields after passing through SOLO, along with the corresponding ground truth, so the last step is to compute the loss value. Once we have it, backpropagation can run.
Which brings us, finally, to the loss function (the important part), in mmdet/models/anchor_heads/solo_head.py (super-detailed annotated version):
def loss(self,
ins_preds,
cate_preds,
gt_bbox_list,
gt_label_list,
gt_mask_list,
img_metas,
cfg,
gt_bboxes_ignore=None):
featmap_sizes = [featmap.size()[-2:] for featmap in
ins_preds]
# print(featmap_sizes) [torch.Size([200, 304]), torch.Size([200, 304]), torch.Size([100, 152]), torch.Size([50, 76]), torch.Size([50, 76])]
# print(gt_label_list) # (n, 1): n GT instances, each entry a class index
# print(gt_bbox_list) #(n, 4)
# print(gt_mask_list[0].shape) # (n, 1216, 800)
# for i in range(len(featmap_sizes)):
# print(ins_preds[i].shape)
# print(cate_preds[i].shape)
#"""最后两个维度有很多种,取决于img的大小,这一部分的预处理还需要再看一下,这里以(200,304)这组为例"""
# torch.Size([1, 1600, 200, 304])
# torch.Size([1, 80, 40, 40])
# torch.Size([1, 1296, 200, 304])
# torch.Size([1, 80, 36, 36])
# torch.Size([1, 576, 100, 152])
# torch.Size([1, 80, 24, 24])
# torch.Size([1, 256, 50, 76])
# torch.Size([1, 80, 16, 16])
# torch.Size([1, 144, 50, 76])
# torch.Size([1, 80, 12, 12])
ins_label_list, cate_label_list, ins_ind_label_list = multi_apply(
self.solo_target_single,
gt_bbox_list,
gt_label_list,
gt_mask_list,
featmap_sizes=featmap_sizes)
# for i in range(5):
# print(ins_label_list[0][i].shape)
# print(ins_ind_label_list[0][i].shape)
# torch.Size([1600, 200, 304])
# torch.Size([1600])
# torch.Size([1296, 200, 304])
# torch.Size([1296])
# torch.Size([576, 100, 152])
# torch.Size([576])
# torch.Size([256, 50, 76])
# torch.Size([256])
# torch.Size([144, 50, 76])
# torch.Size([144])
# ins: for each level, use ins_ind_label_list to pick out the masks whose grid cell is True, collecting them in ins_labels
ins_labels = [torch.cat([ins_labels_level_img[ins_ind_labels_level_img, ...]
for ins_labels_level_img, ins_ind_labels_level_img in
zip(ins_labels_level, ins_ind_labels_level)], 0)
for ins_labels_level, ins_ind_labels_level in zip(zip(*ins_label_list), zip(*ins_ind_label_list))]
# for i in range(5):
# print(ins_labels[i].shape) # the five iterations yield m1..m5 masks, i.e. the number of True entries in ins_ind_label_list[0][i]
# print(ins_ind_label_list[0][i].sum()) # m1..m5 in turn
###########For Example##########
# torch.Size([5, 272, 200])
# tensor(5, device='cuda:0')
# torch.Size([10, 272, 200])
# tensor(10, device='cuda:0')
# torch.Size([2, 136, 100])
# tensor(2, device='cuda:0')
# torch.Size([0, 68, 50])
# tensor(0, device='cuda:0')
# torch.Size([0, 68, 50])
# tensor(0, device='cuda:0')
# same as above, except the GT ins_label_list is replaced by the predictions ins_preds
ins_preds = [torch.cat([ins_preds_level_img[ins_ind_labels_level_img, ...]
for ins_preds_level_img, ins_ind_labels_level_img in
zip(ins_preds_level, ins_ind_labels_level)], 0)
for ins_preds_level, ins_ind_labels_level in zip(ins_preds, zip(*ins_ind_label_list))]
### at this point the GT masks (ins_labels) and the predicted masks (ins_preds) share the same layout
ins_ind_labels = [
torch.cat([ins_ind_labels_level_img.flatten()
for ins_ind_labels_level_img in ins_ind_labels_level])
for ins_ind_labels_level in zip(*ins_ind_label_list)
]
flatten_ins_ind_labels = torch.cat(ins_ind_labels)
num_ins = flatten_ins_ind_labels.sum() # equals m1 + m2 + m3 + m4 + m5
# dice loss
loss_ins = []
for input, target in zip(ins_preds, ins_labels):
if input.size()[0] == 0:
continue
input = torch.sigmoid(input)
loss_ins.append(dice_loss(input, target))
loss_ins = torch.cat(loss_ins).mean()
loss_ins = loss_ins * self.ins_loss_weight
# cate
cate_labels = [
torch.cat([cate_labels_level_img.flatten()
for cate_labels_level_img in cate_labels_level])
for cate_labels_level in zip(*cate_label_list)
]
# for i in range(5):
# print(cate_labels[i].shape)
# torch.Size([1600])
# torch.Size([1296])
# torch.Size([576])
# torch.Size([256])
# torch.Size([144])
flatten_cate_labels = torch.cat(cate_labels)
# print(flatten_cate_labels.shape) # 3872 = 1600 + 1296 + 576 + 256 + 144
# for i in range(5):
# print(cate_preds[i].shape) # (1, 80, 40, 40)
#
# torch.Size([1, 80, 40, 40])
# torch.Size([1, 80, 36, 36])
# torch.Size([1, 80, 24, 24])
# torch.Size([1, 80, 16, 16])
# torch.Size([1, 80, 12, 12])
cate_preds = [
cate_pred.permute(0, 2, 3, 1).reshape(-1, self.cate_out_channels) # (x, 80), where x is 40*40, 36*36, ...
for cate_pred in cate_preds
]
flatten_cate_preds = torch.cat(cate_preds)
# print(flatten_cate_preds.shape) # (3872, 80)
loss_cate = self.loss_cate(flatten_cate_preds, flatten_cate_labels, avg_factor=num_ins + 1)
return dict(
loss_ins=loss_ins,
loss_cate=loss_cate)
### Process the GT: according to each feature map's size, assign gt_bbox, gt_label and gt_mask to the feature map of the matching scale; len(list) = 5 ###
def solo_target_single(self,
gt_bboxes_raw,
gt_labels_raw,
gt_masks_raw,
featmap_sizes=None):
device = gt_labels_raw[0].device
# ins
# print(gt_bboxes_raw) #(n, 4)
# print(gt_labels_raw) #(n, 1)
# print(gt_masks_raw.shape) #(n, 1216, 800)
gt_areas = torch.sqrt((gt_bboxes_raw[:, 2] - gt_bboxes_raw[:, 0]) * (
gt_bboxes_raw[:, 3] - gt_bboxes_raw[:, 1]))
# print(gt_areas) # (n,): sqrt of each instance's bbox area
ins_label_list = []
cate_label_list = []
ins_ind_label_list = []
# solve for each of the five feature-map levels in turn
for (lower_bound, upper_bound), stride, featmap_size, num_grid \
in zip(self.scale_ranges, self.strides, featmap_sizes, self.seg_num_grids):
# print(lower_bound) [1, 48, 96, 192, 384]
# print(upper_bound) [96, 192, 384, 768, 2048]
# print(stride) [8, 8, 16, 32, 32]
# print(featmap_size)
# print(num_grid) [40, 36, 24, 16, 12]
# print('=========================================')
ins_label = torch.zeros([num_grid ** 2, featmap_size[0], featmap_size[1]], dtype=torch.uint8, device=device)
# print(ins_label.shape) [(1600,200,304), (1296,200,304), (576,100,152), (256,50,76), (144,50,76)]
cate_label = torch.zeros([num_grid, num_grid], dtype=torch.int64, device=device)
# print(cate_label.shape) [(40,40), (36,36), (24,24), (16,16), (12,12)]
ins_ind_label = torch.zeros([num_grid ** 2], dtype=torch.bool, device=device)
# print(ins_ind_label.shape) [1600, 1296, 576, 256, 144]
hit_indices = ((gt_areas >= lower_bound) & (gt_areas <= upper_bound)).nonzero().flatten() # indices of the boxes whose scale falls inside this level's range, m of them (m <= n)
if len(hit_indices) == 0: # no boxes at this scale: append the all-zero tensors and move on to the next level
ins_label_list.append(ins_label)
cate_label_list.append(cate_label)
ins_ind_label_list.append(ins_ind_label)
continue
gt_bboxes = gt_bboxes_raw[hit_indices]
gt_labels = gt_labels_raw[hit_indices]
gt_masks = gt_masks_raw[hit_indices.cpu().numpy(), ...]
# Center sampling: a sigma-scaled half-width/height region around the GT center lets one instance hit several nearby grids (about 3 positives on average)
half_ws = 0.5 * (gt_bboxes[:, 2] - gt_bboxes[:, 0]) * self.sigma
half_hs = 0.5 * (gt_bboxes[:, 3] - gt_bboxes[:, 1]) * self.sigma
# print(half_ws) # [m]
# print(half_hs) # [m]
# mass center
gt_masks_pt = torch.from_numpy(gt_masks).to(device=device)
center_ws, center_hs = center_of_mass(gt_masks_pt)
# print(center_ws, center_hs) # mass centers of the GT instances, shape [m]
valid_mask_flags = gt_masks_pt.sum(dim=-1).sum(dim=-1) > 0
# print(valid_mask_flags) # a vector of m booleans (True for non-empty masks)
output_stride = stride / 2 #[4, 4, 8, 16, 16]
for seg_mask, gt_label, half_h, half_w, center_h, center_w, valid_mask_flag in zip(gt_masks, gt_labels, half_hs, half_ws, center_hs, center_ws, valid_mask_flags):
if not valid_mask_flag:
continue
upsampled_size = (featmap_sizes[0][0] * 4, featmap_sizes[0][1] * 4)
# print(upsampled_size) # (800, 1216): 4x the size of the top-level feature map
# coord_w/coord_h: which grid cell the GT mass center falls into
coord_w = int((center_w / upsampled_size[1]) // (1. / num_grid))
coord_h = int((center_h / upsampled_size[0]) // (1. / num_grid))
# left, top, right, down: let one instance cover several nearby grid cells
top_box = max(0, int(((center_h - half_h) / upsampled_size[0]) // (1. / num_grid)))
down_box = min(num_grid - 1, int(((center_h + half_h) / upsampled_size[0]) // (1. / num_grid)))
left_box = max(0, int(((center_w - half_w) / upsampled_size[1]) // (1. / num_grid)))
right_box = min(num_grid - 1, int(((center_w + half_w) / upsampled_size[1]) // (1. / num_grid)))
# clamp to at most one cell away from (coord_h, coord_w)
top = max(top_box, coord_h-1)
down = min(down_box, coord_h+1)
left = max(coord_w-1, left_box)
right = min(right_box, coord_w+1)
cate_label[top:(down+1), left:(right+1)] = gt_label
# ins
seg_mask = mmcv.imrescale(seg_mask, scale=1. / output_stride) # rescale the mask down to the feature-map resolution
seg_mask = torch.from_numpy(seg_mask).to(device=device)
for i in range(top, down+1):
for j in range(left, right+1):
label = int(i * num_grid + j)
ins_label[label, :seg_mask.shape[0], :seg_mask.shape[1]] = seg_mask # write the mask into this grid cell's channel
ins_ind_label[label] = True # mark this grid cell as a positive
ins_label_list.append(ins_label)
cate_label_list.append(cate_label)
ins_ind_label_list.append(ins_ind_label)
# print(len(ins_label_list)) # 5
return ins_label_list, cate_label_list, ins_ind_label_list
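Before summarizing, a quick numeric check of the grid-assignment arithmetic above, with illustrative values consistent with the shape comments:
# Illustrative check of coord_w/coord_h (values chosen to match the running example).
num_grid = 40                      # S for the top level
upsampled_size = (800, 1216)       # 4x the top-level feature map, as noted above
center_w, center_h = 620.0, 215.0  # a hypothetical mass center of one instance
coord_w = int((center_w / upsampled_size[1]) // (1. / num_grid))  # 620/1216*40 ~ 20.39 -> 20
coord_h = int((center_h / upsampled_size[0]) // (1. / num_grid))  # 215/800*40  ~ 10.75 -> 10
# The instance's center cell is (10, 20); the sigma-scaled half-width/height can
# extend the positive region to at most one neighbouring cell per side, per the
# top/down/left/right clamping above.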
The comments in the code above are fairly detailed, if a bit messy; apologies for that. The overall idea: we have the network's forward outputs cate_preds and ins_preds, plus the GT data gt_bbox_list, gt_label_list, and gt_mask_list. Keep in mind that cate_preds and ins_preds come from the anchor head applied to five feature maps, effectively a list concatenating five scales, whereas the GT data are annotations for the image as a whole. So: (1) First run solo_target_single, whose purpose is to assign every GT instance to one of the five feature-map levels according to scale_ranges (by comparing the sqrt-area of each instance's bbox against the scale ranges to decide which level the instance lands on). (2) Inside solo_target_single: compute each instance's bbox size, assign it to the level whose scale_range matches, rescale the GT to that level's feature-map size, derive the grid index that the instance mask's mass center falls into, and thereby determine which grid cells are responsible for predicting that instance (the positives); the GT is thus decomposed into a per-level form across the five levels. (3) Back in loss: using the positive indices, select the masks belonging to positive grid cells (filtering both GT and predictions); after the predicted masks are normalized with a sigmoid, the two sides can be compared with the Dice Loss, which serves as the segmentation loss. (4) Flatten the GT category labels into a [3872] vector and the predicted category scores into a [3872, 80] matrix (3872 is specific to this example's grid sizes) and compute the Focal Loss between them. (5) Return the Dice Loss and Focal Loss values, which flow back into the parse_losses function above to produce the final loss value.
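One piece not shown above is dice_loss itself. Below is a minimal sketch consistent with the Dice loss defined in the SOLO paper, L_Dice = 1 - D(p, q) with D = 2·Σ(p·q) / (Σp² + Σq²); the repository's actual implementation may differ in details such as the smoothing constant:
import torch

def dice_loss(input, target, eps=1e-3):
    # input:  (m, H, W) sigmoid mask predictions for the m positive grid cells
    # target: (m, H, W) binary GT masks aligned with input
    # returns an (m,) per-instance loss; the caller concatenates and averages
    input = input.reshape(input.size(0), -1)
    target = target.reshape(target.size(0), -1).float()
    a = torch.sum(input * target, dim=1)         # sum of p*q
    b = torch.sum(input * input, dim=1) + eps    # sum of p^2 (eps avoids 0/0)
    c = torch.sum(target * target, dim=1) + eps  # sum of q^2
    return 1 - 2 * a / (b + c)
Pairing Dice loss for the masks with Focal Loss for the S×S category grid fits the box-free design: Dice is robust to the heavy foreground/background pixel imbalance inside each mask, while Focal Loss copes with the grid-level class imbalance (most of the 3872 cells in this example are negatives).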
That is roughly it for the training code. Some details of Runner and the hooks were not covered in depth; dig in if you are interested. From a purely algorithmic standpoint, the solo_head part is without doubt the core. Space is limited, so if anything is unclear or there are gaps, feel free to leave a comment.