Following on from the previous post, Parts 5, 6, and 7 will be explained in detail. This post covers Part 5: the full path of the model from construction to testing, and of the data from images to output and dets.
Part 5 feeds the images into the network for inference and produces the outputs:
output, dets, forward_time = self.process(images, return_time=True)
The process method lives in ctdet.py:
def process(self, images, return_time=False):
    with torch.no_grad():
        output = self.model(images)[-1]
        hm = output['hm'].sigmoid_()
        wh = output['wh']
        reg = output['reg'] if self.opt.reg_offset else None
        if self.opt.flip_test:
            hm = (hm[0:1] + flip_tensor(hm[1:2])) / 2
            wh = (wh[0:1] + flip_tensor(wh[1:2])) / 2
            reg = reg[0:1] if reg is not None else None
        torch.cuda.synchronize()
        forward_time = time.time()
        dets = ctdet_decode(hm, wh, reg=reg, K=self.opt.K)

    if return_time:
        return output, dets, forward_time
    else:
        return output, dets
First, images is passed through the model to obtain output. output has three parts:
{'hm': 1*80*128*128,
 'reg': 1*2*128*128,
 'wh': 1*2*128*128}
As you can see, only hm (the heatmap) depends on the classes (80 of them); reg (the offsets x_off and y_off) and wh (width and height) are class-agnostic.
Then ctdet_decode decodes these into dets, a 1*100*6 tensor.
Finally, output, dets, and forward_time are returned.
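As a quick sanity check of these shapes, one could run something like the following. This is a hypothetical snippet, not part of the repository; it assumes model and images have already been built as in the demo, with a COCO model and a 1*3*512*512 input:

import torch
from models.decode import ctdet_decode

with torch.no_grad():
    output = model(images)[-1]              # dict with the three heads
for head, tensor in output.items():
    print(head, tuple(tensor.shape))
# hm  (1, 80, 128, 128)
# wh  (1, 2, 128, 128)
# reg (1, 2, 128, 128)

dets = ctdet_decode(output['hm'].sigmoid_(), output['wh'],
                    reg=output['reg'], K=100)
print(dets.shape)                           # torch.Size([1, 100, 6])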
This splits into two parts: the first is running images through the model to get output; the second is decoding with ctdet_decode.
In BaseDetector:
self.model = create_model(opt.arch, opt.heads, opt.head_conv)
self.model = load_model(self.model, opt.load_model)
The two functions involved come from models.model.
1. create_model
The two key lines are:
get_model = _model_factory[arch]
model = get_model(num_layers=num_layers, heads=heads, head_conv=head_conv)
arch is used to look up get_model from the factory. In the demo, this resolves to the get_pose_net function in networks/pose_dla_dcn, which is defined as:
def get_pose_net(num_layers, heads, head_conv=256, down_ratio=4):
    model = DLASeg('dla{}'.format(num_layers), heads,
                   pretrained=True,
                   down_ratio=down_ratio,
                   final_kernel=1,
                   last_level=5,
                   head_conv=head_conv)
    return model
This is a DLA network; the architecture is probably the one described here: https://blog.csdn.net/wuyubinbin/article/details/80622762
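Putting the two lines of create_model together, here is a rough sketch of what it does for this demo. The heads values are an assumption based on the COCO ctdet setting (80 classes for hm, 2 channels each for wh and reg), and the details may differ from models/model.py:

_model_factory = {'dla': get_pose_net}                  # arch name -> constructor

def create_model(arch, heads, head_conv):
    # 'dla_34' -> base arch 'dla', num_layers 34
    num_layers = int(arch[arch.find('_') + 1:]) if '_' in arch else 0
    arch = arch[:arch.find('_')] if '_' in arch else arch
    get_model = _model_factory[arch]
    return get_model(num_layers=num_layers, heads=heads, head_conv=head_conv)

# Demo-style call: DLA-34 backbone, COCO heads, 256-channel head conv.
model = create_model('dla_34', heads={'hm': 80, 'wh': 2, 'reg': 2}, head_conv=256)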
2. load_model, which loads the pretrained model (to be examined in more detail later):
def load_model(model, model_path, optimizer=None, resume=False,
               lr=None, lr_step=None):
    start_epoch = 0
    checkpoint = torch.load(model_path, map_location=lambda storage, loc: storage)
    print('loaded {}, epoch {}'.format(model_path, checkpoint['epoch']))
    state_dict_ = checkpoint['state_dict']
    state_dict = {}
    # convert data_parallal to model
    for k in state_dict_:
        if k.startswith('module') and not k.startswith('module_list'):
            state_dict[k[7:]] = state_dict_[k]
        else:
            state_dict[k] = state_dict_[k]
    model_state_dict = model.state_dict()
    # check loaded parameters and created model parameters
    for k in state_dict:
        if k in model_state_dict:
            if state_dict[k].shape != model_state_dict[k].shape:
                print('Skip loading parameter {}, required shape{}, '
                      'loaded shape{}.'.format(
                          k, model_state_dict[k].shape, state_dict[k].shape))
                state_dict[k] = model_state_dict[k]
        else:
            print('Drop parameter {}.'.format(k))
    for k in model_state_dict:
        if not (k in state_dict):
            print('No param {}.'.format(k))
            state_dict[k] = model_state_dict[k]
    model.load_state_dict(state_dict, strict=False)
    # resume optimizer parameters
    if optimizer is not None and resume:
        if 'optimizer' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer'])
            start_epoch = checkpoint['epoch']
            start_lr = lr
            for step in lr_step:
                if start_epoch >= step:
                    start_lr *= 0.1
            for param_group in optimizer.param_groups:
                param_group['lr'] = start_lr
            print('Resumed optimizer with start lr', start_lr)
        else:
            print('No optimizer parameters in checkpoint.')
    if optimizer is not None:
        return model, optimizer, start_epoch
    else:
        return model
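For reference, the two typical ways of calling load_model; the checkpoint file names, learning rate, and lr_step values below are placeholders, not taken from the repo:

import torch

# Test/demo: only the weights matter, so just pass the model and the checkpoint path.
model = create_model('dla_34', heads={'hm': 80, 'wh': 2, 'reg': 2}, head_conv=256)
model = load_model(model, 'ctdet_coco_dla_2x.pth')

# Resuming training: also restore the optimizer state and the start epoch;
# the learning rate is stepped down by 10x for every lr_step already passed.
optimizer = torch.optim.Adam(model.parameters(), lr=1.25e-4)
model, optimizer, start_epoch = load_model(
    model, 'model_last.pth', optimizer, resume=True,
    lr=1.25e-4, lr_step=[90, 120])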
The forward pass of DLASeg (the model returned by get_pose_net, defined in pose_dla_dcn.py) shows where the hm/wh/reg outputs come from; the shapes in the comments assume a 1*3*512*512 input:
def forward(self, x):
    # x = 1*3*512*512
    x = self.base(x)
    # x is now a list of six feature maps = 1* [16*512*512, 32*256*256, 64*128*128, 128*64*64, 256*32*32, 512*16*16]
    x = self.dla_up(x)
    # x = 1* [64*128*128, 128*64*64, 256*32*32]
    y = []
    for i in range(self.last_level - self.first_level):
        y.append(x[i].clone())
    self.ida_up(y, 0, len(y))
    # y = 1* [64*128*128, 64*128*128, 64*128*128]
    z = {}
    for head in self.heads:
        z[head] = self.__getattr__(head)(y[-1])
    # z = {'hm' : 1*80*128*128,
    #      'reg': 1*2*128*128,
    #      'wh' : 1*2*128*128}
    return [z]
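Each head retrieved by self.__getattr__(head) above is a small convolutional branch created in DLASeg.__init__: roughly a 3*3 conv, a ReLU, and a 1*1 conv (final_kernel=1). The following is a simplified sketch; build_heads is a name introduced here for illustration, while the real code builds the branches inline:

import torch.nn as nn

def build_heads(heads, head_conv=256, in_channels=64):
    """Simplified sketch of the per-head branches in DLASeg.__init__."""
    branches = nn.ModuleDict()
    for head, classes in heads.items():                 # e.g. {'hm': 80, 'wh': 2, 'reg': 2}
        branches[head] = nn.Sequential(
            nn.Conv2d(in_channels, head_conv, kernel_size=3, padding=1, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(head_conv, classes, kernel_size=1, bias=True))   # final_kernel=1
    return branches

# Applying branches['hm'] to the 1*64*128*128 map y[-1] gives 1*80*128*128, i.e. z['hm'].
branches = build_heads({'hm': 80, 'wh': 2, 'reg': 2})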
ctdet_decode lives in models.decode. The final detections tensor is the concatenation of bboxes, scores, and clses:
detections = torch.cat([bboxes, scores, clses], dim=2)
Here bboxes is in top-left/bottom-right corner format, a 1*100*4 FloatTensor; scores is a 1*100*1 FloatTensor with values in [0, 1], sorted in descending order; clses is also a 1*100*1 tensor whose entries are integers denoting the class. For the details of the decoding process, see the earlier post:
https://mp.csdn.net/postedit/91955759
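To make the decoding concrete without repeating that post, here is a condensed sketch of the idea behind ctdet_decode. simple_ctdet_decode is an illustrative rewrite, not the repo code; it skips some reshaping and the per-class top-K of the original:

import torch
import torch.nn.functional as F

def simple_ctdet_decode(hm, wh, reg=None, K=100):
    """Keep local maxima of the heatmap, take the top-K peaks,
    and turn center + wh (+ offset) into corner-format boxes."""
    batch, num_classes, height, width = hm.size()

    # 1. Pseudo-NMS: keep only scores equal to their 3x3 neighbourhood maximum.
    hmax = F.max_pool2d(hm, kernel_size=3, stride=1, padding=1)
    hm = hm * (hmax == hm).float()

    # 2. Top-K peaks over all classes and positions.
    scores, inds = torch.topk(hm.view(batch, -1), K)       # both B x K
    clses = (inds // (height * width)).float()             # class index of each peak
    inds = inds % (height * width)                         # spatial index of each peak
    ys = (inds // width).float()
    xs = (inds % width).float()

    # 3. Gather wh (and offsets) at the peak locations.
    def gather(feat, inds):
        feat = feat.view(batch, feat.size(1), -1)          # B x C x (H*W)
        inds = inds.unsqueeze(1).expand(-1, feat.size(1), -1)
        return feat.gather(2, inds).permute(0, 2, 1)       # B x K x C

    wh = gather(wh, inds)                                  # B x K x 2
    if reg is not None:
        off = gather(reg, inds)                            # B x K x 2
        xs, ys = xs + off[..., 0], ys + off[..., 1]

    # 4. Centers + sizes -> corners, then concatenate scores and classes.
    bboxes = torch.stack([xs - wh[..., 0] / 2, ys - wh[..., 1] / 2,
                          xs + wh[..., 0] / 2, ys + wh[..., 1] / 2], dim=2)
    return torch.cat([bboxes, scores.unsqueeze(2), clses.unsqueeze(2)], dim=2)   # B x K x 6

Filtering the rows by the score column (column 4) and mapping the box coordinates from the 128*128 feature map back to the original image is then the job of the post-processing step covered later.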