[Code] Using CenterNet (Continued) (Detailed Explanation of Parts 5, 6, and 7) (Part 5)

Continuing from the previous post, this series goes through Parts 5, 6, and 7 in detail. This post covers Part 5: the model from construction to inference, and the full path the data takes from images to output and dets.

I. Review

Part 5 runs the images through the network and produces the outputs:

      output, dets, forward_time = self.process(images, return_time=True)

The process method is in ctdet.py:

  def process(self, images, return_time=False):
    with torch.no_grad():
      output = self.model(images)[-1]
      hm = output['hm'].sigmoid_()
      wh = output['wh']
      reg = output['reg'] if self.opt.reg_offset else None
      if self.opt.flip_test:
        hm = (hm[0:1] + flip_tensor(hm[1:2])) / 2
        wh = (wh[0:1] + flip_tensor(wh[1:2])) / 2
        reg = reg[0:1] if reg is not None else None
      torch.cuda.synchronize()
      forward_time = time.time()
      dets = ctdet_decode(hm, wh, reg=reg, K=self.opt.K)
      
    if return_time:
      return output, dets, forward_time
    else:
      return output, dets
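When opt.flip_test is set, images holds two samples (the original and its horizontal mirror), the hm and wh predictions of the two are averaged, and reg is only taken from the un-flipped sample. flip_tensor is a small helper from the repo; a minimal sketch of what it does (assumed behavior, not the verbatim code):

import torch

def flip_tensor(x):
    # mirror an NCHW feature map along the width axis (dim 3)
    return torch.flip(x, dims=[3])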

First, images is passed through model, which yields output. output has three parts:

{'hm':  1*80*128*128,
 'reg': 1*2*128*128,
 'wh':  1*2*128*128}

Note that only hm (the heatmap) depends on the classes (80 of them); reg (the offsets x_off and y_off) and wh (width and height) are class-agnostic.

Then ctdet_decode decodes these maps into dets, a 1*100*6 tensor.

Finally, output, dets, and forward_time are returned.
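To make these shapes concrete, here is a small (hypothetical) check that could be dropped into process right after decoding:

# assuming `output` and `dets` as produced in process() above
for name, tensor in output.items():
    print(name, tuple(tensor.shape))
# hm (1, 80, 128, 128), reg (1, 2, 128, 128), wh (1, 2, 128, 128)

print(tuple(dets.shape))   # (1, 100, 6): 4 box coordinates + 1 score + 1 class id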

II. Detailed Explanation

This splits into two parts: the first passes images through model to get output; the second decodes output with ctdet_decode.

1. self.model(images)[-1]

1. Building self.model

In BaseDetector:

    self.model = create_model(opt.arch, opt.heads, opt.head_conv)
    self.model = load_model(self.model, opt.load_model) 

The two functions involved come from models.model.
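For the COCO ctdet demo, the arguments these two functions receive are roughly the following (values taken from the repo's typical defaults, so treat them as illustrative rather than authoritative; the checkpoint path is a placeholder):

from models.model import create_model, load_model

arch = 'dla_34'                          # backbone string, parsed into 'dla' + 34 layers
heads = {'hm': 80, 'wh': 2, 'reg': 2}    # heatmap, box size, and center-offset heads
head_conv = 256                          # channels of the intermediate head convolution

model = create_model(arch, heads, head_conv)
model = load_model(model, 'path/to/ctdet_coco_dla_2x.pth')  # placeholder checkpoint path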

1. create_model

The two key lines are:

  get_model = _model_factory[arch]
  model = get_model(num_layers=num_layers, heads=heads, head_conv=head_conv)
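For context, before these two lines create_model also splits the arch string into a family name and a layer count. A paraphrased sketch of the whole function (details may vary between repo versions):

def create_model(arch, heads, head_conv):
  # 'dla_34' -> arch='dla', num_layers=34; an arch without '_' gets num_layers=0
  num_layers = int(arch[arch.find('_') + 1:]) if '_' in arch else 0
  arch = arch[:arch.find('_')] if '_' in arch else arch
  get_model = _model_factory[arch]   # e.g. 'dla' maps to pose_dla_dcn.get_pose_net
  model = get_model(num_layers=num_layers, heads=heads, head_conv=head_conv)
  return model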

arch selects get_model from _model_factory. In the demo this resolves to the get_pose_net function of networks/pose_dla_dcn.py, which is defined as:

def get_pose_net(num_layers, heads, head_conv=256, down_ratio=4):
  model = DLASeg('dla{}'.format(num_layers), heads,
                 pretrained=True,
                 down_ratio=down_ratio,
                 final_kernel=1,
                 last_level=5,
                 head_conv=head_conv)
  return model

This builds a DLA-based network (DLASeg); the backbone is likely the Deep Layer Aggregation architecture described here: https://blog.csdn.net/wuyubinbin/article/details/80622762

2. load_model, used to load the pretrained weights (to be revisited):

def load_model(model, model_path, optimizer=None, resume=False, 
               lr=None, lr_step=None):
  start_epoch = 0
  checkpoint = torch.load(model_path, map_location=lambda storage, loc: storage)
  print('loaded {}, epoch {}'.format(model_path, checkpoint['epoch']))
  state_dict_ = checkpoint['state_dict']
  state_dict = {}
  
  # convert data_parallal to model
  for k in state_dict_:
    if k.startswith('module') and not k.startswith('module_list'):
      state_dict[k[7:]] = state_dict_[k]
    else:
      state_dict[k] = state_dict_[k]
  model_state_dict = model.state_dict()

  # check loaded parameters and created model parameters
  for k in state_dict:
    if k in model_state_dict:
      if state_dict[k].shape != model_state_dict[k].shape:
        print('Skip loading parameter {}, required shape{}, '\
              'loaded shape{}.'.format(
          k, model_state_dict[k].shape, state_dict[k].shape))
        state_dict[k] = model_state_dict[k]
    else:
      print('Drop parameter {}.'.format(k))
  for k in model_state_dict:
    if not (k in state_dict):
      print('No param {}.'.format(k))
      state_dict[k] = model_state_dict[k]
  model.load_state_dict(state_dict, strict=False)

  # resume optimizer parameters
  if optimizer is not None and resume:
    if 'optimizer' in checkpoint:
      optimizer.load_state_dict(checkpoint['optimizer'])
      start_epoch = checkpoint['epoch']
      start_lr = lr
      for step in lr_step:
        if start_epoch >= step:
          start_lr *= 0.1
      for param_group in optimizer.param_groups:
        param_group['lr'] = start_lr
      print('Resumed optimizer with start lr', start_lr)
    else:
      print('No optimizer parameters in checkpoint.')
  if optimizer is not None:
    return model, optimizer, start_epoch
  else:
    return model
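Two hypothetical calls showing how load_model is typically used, matching the signature above (the checkpoint paths and lr values are placeholders):

# inference: only the network weights are needed
model = load_model(model, 'path/to/ctdet_coco_dla_2x.pth')

# resuming training: also restore the optimizer state, the starting epoch,
# and re-apply the lr schedule up to that epoch
model, optimizer, start_epoch = load_model(
    model, 'path/to/model_last.pth', optimizer,
    resume=True, lr=1.25e-4, lr_step=[90, 120])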

2. The model's forward()

    def forward(self, x):
        # x: 1*3*512*512
        x = self.base(x)
        # x: a list of six feature maps
        # [1*16*512*512, 1*32*256*256, 1*64*128*128,
        #  1*128*64*64, 1*256*32*32, 1*512*16*16]
        x = self.dla_up(x)
        y = []
        for i in range(self.last_level - self.first_level):
            y.append(x[i].clone())
        # y: [1*64*128*128, 1*128*64*64, 1*256*32*32]
        self.ida_up(y, 0, len(y))
        # y: [1*64*128*128, 1*64*128*128, 1*64*128*128]
        z = {}
        for head in self.heads:
            z[head] = self.__getattr__(head)(y[-1])
        # z = {'hm': 1*80*128*128, 'reg': 1*2*128*128, 'wh': 1*2*128*128}
        return [z]
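Each entry in self.heads is a small convolutional branch attached to the 64-channel map y[-1]. A sketch of how such a head is typically constructed in DLASeg (based on the repo's pattern with head_conv=256; the helper name make_head is hypothetical):

import torch.nn as nn

def make_head(in_channels=64, head_conv=256, out_channels=80, final_kernel=1):
    # 3x3 conv -> ReLU -> 1x1 conv down to this head's number of output channels
    return nn.Sequential(
        nn.Conv2d(in_channels, head_conv, kernel_size=3, padding=1, bias=True),
        nn.ReLU(inplace=True),
        nn.Conv2d(head_conv, out_channels, kernel_size=final_kernel,
                  stride=1, padding=final_kernel // 2, bias=True))

# e.g. the 'hm' head maps 1*64*128*128 -> 1*80*128*128,
# while 'wh' and 'reg' (out_channels=2) map to 1*2*128*128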

2. ctdet_decode

ctdet_decode is in models.decode. The detections it finally produces are the concatenation of bboxes, scores, and clses:

    detections = torch.cat([bboxes, scores, clses], dim=2)

Here bboxes uses the top-left/bottom-right corner format and is a 1*100*4 FloatTensor. scores is a 1*100*1 FloatTensor with values in [0, 1], sorted in descending order. clses is also a 1*100*1 tensor of integers giving the class index. The detailed decoding process is covered in the earlier post:

https://mp.csdn.net/postedit/91955759
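For convenience, here is a simplified sketch of the decoding idea (not the repo's exact code; for example, the real implementation first takes a per-class top-K before the overall top-K and uses small gather helpers):

import torch
import torch.nn.functional as F

def ctdet_decode_sketch(hm, wh, reg=None, K=100):
    """Simplified sketch: heatmap peaks become box centers; wh/reg are read out there."""
    batch, cat, height, width = hm.size()

    # 3x3 max-pooling keeps only local maxima on the heatmap (a cheap NMS)
    hmax = F.max_pool2d(hm, kernel_size=3, stride=1, padding=1)
    hm = hm * (hmax == hm).float()

    # top-K peaks over all classes and positions
    scores, inds = torch.topk(hm.view(batch, -1), K)            # both 1 x K
    clses = (inds // (height * width)).float().unsqueeze(2)     # 1 x K x 1
    inds = inds % (height * width)
    ys = (inds // width).float()                                 # 1 x K
    xs = (inds % width).float()                                  # 1 x K

    # gather wh (and the optional sub-pixel offsets reg) at the peak locations
    idx = inds.unsqueeze(1).expand(batch, 2, K)
    wh = wh.view(batch, 2, -1).gather(2, idx).permute(0, 2, 1)   # 1 x K x 2
    if reg is not None:
        reg = reg.view(batch, 2, -1).gather(2, idx).permute(0, 2, 1)
        xs = xs + reg[..., 0]
        ys = ys + reg[..., 1]

    # center +/- half width/height gives top-left and bottom-right corners
    bboxes = torch.stack([xs - wh[..., 0] / 2, ys - wh[..., 1] / 2,
                          xs + wh[..., 0] / 2, ys + wh[..., 1] / 2], dim=2)  # 1 x K x 4
    return torch.cat([bboxes, scores.unsqueeze(2), clses], dim=2)            # 1 x K x 6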

 
