一个更加强力的ReID Baseline
Bag of Tricks and A Strong Baseline for Deep Person Re-identification
arXiv:https://arxiv.org/abs/1903.07071
https://github.com/michuanhaohao/reid-strong-baseline
工程下的文件较多,看起来可能有些复杂。项目采用 config(yml 配置文件)的方式对各种实验情况进行设置和训练,便于做对比实验。不了解 yml 配置方式的可以参考如下链接:
https://blog.csdn.net/m0_37192554/article/details/103024960
ResNet50 主干只取到 layer4 卷积层为止,之后经过全局平均池化(GAP),再 reshape(flatten)成向量,以此代替全连接层得到全局特征;该全局特征作为 triplet loss 的输入。
然后,全局特征接入 BNNeck 结构(一个 BatchNorm1d 层),得到归一化后的特征(用 BN 归一化,而不是直接做 L2 归一化),再接入全连接的分类输出层。
训练时模型返回两项:分类输出层(全连接层)的分类得分 cls_score,以及用于 triplet loss 的全局特征 global_feat。
测试时只返回特征:neck_feat 为 'after' 时返回 BN 之后的特征,否则返回 BN 之前的全局特征。
# Excerpt from the model's __init__ (the enclosing class and method header are not shown).
# Global pooling applied to the backbone output in forward(); output is (b, C, 1, 1).
self.gap = nn.AdaptiveAvgPool2d(1)
# self.gap = nn.AdaptiveMaxPool2d(1)
self.num_classes = num_classes
# neck: 'no' -> classifier sits directly on the pooled feature;
#       'bnneck' -> a BatchNorm1d ("BNNeck") is inserted before the classifier.
self.neck = neck
# neck_feat: read by forward() at test time; 'after' returns the post-BN feature,
# any other value returns the pre-BN global feature.
self.neck_feat = neck_feat
if self.neck == 'no':
# Plain linear classifier (with bias) on the pooled feature.
# self.in_planes is set elsewhere in __init__; presumably 2048 for resnet50 - TODO confirm.
self.classifier = nn.Linear(self.in_planes, self.num_classes)
# self.classifier = nn.Linear(self.in_planes, self.num_classes, bias=False) # new add by luo
# self.classifier.apply(weights_init_classifier) # new add by luo
elif self.neck == 'bnneck':
# BNNeck: batch-normalize the pooled feature before classification.
self.bottleneck = nn.BatchNorm1d(self.in_planes)
# Freeze the BN bias (beta) so it is never updated by the optimizer: "no shift".
self.bottleneck.bias.requires_grad_(False) # no shift
# Bias-free linear classifier on the BN-normalized feature.
self.classifier = nn.Linear(self.in_planes, self.num_classes, bias=False)
# weights_init_kaiming / weights_init_classifier are defined elsewhere in the project.
self.bottleneck.apply(weights_init_kaiming)
self.classifier.apply(weights_init_classifier)
# NOTE(review): any other neck value leaves self.classifier (and self.bottleneck) undefined.
def forward(self, x):
    """Run the backbone and return training outputs or the test-time embedding.

    The backbone output is globally pooled and flattened into ``global_feat``
    (the feature used for the triplet loss). With ``neck == 'bnneck'`` it is
    additionally passed through the BatchNorm1d bottleneck to give ``feat``;
    with ``neck == 'no'`` ``feat`` is ``global_feat`` itself.

    Args:
        x: input batch passed to ``self.base``.

    Returns:
        Training mode: ``(cls_score, global_feat)`` where ``cls_score`` is the
        classifier output computed on ``feat``.
        Eval mode: ``feat`` when ``self.neck_feat == 'after'``, otherwise
        ``global_feat``.

    Raises:
        ValueError: if ``self.neck`` is neither ``'no'`` nor ``'bnneck'``.
    """
    # Backbone feature map -> global pooling -> (b, C, 1, 1); (b, 2048, 1, 1) for resnet50.
    global_feat = self.gap(self.base(x))
    # Flatten to (b, C) so it can feed BatchNorm1d / Linear.
    global_feat = global_feat.view(global_feat.shape[0], -1)

    if self.neck == 'no':
        feat = global_feat
    elif self.neck == 'bnneck':
        # BNNeck: BatchNorm1d over the pooled feature; the classifier sees the normalized feature.
        feat = self.bottleneck(global_feat)
    else:
        # Without this, `feat` would be unbound and fail with an opaque UnboundLocalError.
        raise ValueError(
            "unsupported neck: {!r} (expected 'no' or 'bnneck')".format(self.neck))

    if self.training:
        cls_score = self.classifier(feat)
        # The pre-BN global feature goes to the triplet loss.
        return cls_score, global_feat

    # Test time: choose which embedding to return.
    if self.neck_feat == 'after':
        return feat
    return global_feat
损失函数部分用到了函数嵌套(闭包):外层函数在内部定义 loss_func 并将其返回,调用方拿到后直接传参调用。函数嵌套的用法可参考以下讲解:
https://www.cnblogs.com/xiaxiaoxu/p/9785687.html
# Excerpt from the training setup; the earlier branches of this if/elif chain are elided ("...").
...
elif cfg.MODEL.IF_WITH_CENTER == 'yes':
print('Train with center loss, the loss type is', cfg.MODEL.METRIC_LOSS_TYPE)
loss_func, center_criterion = make_loss_with_center(cfg, num_classes) # modified by gu; returns the nested (closure) loss function plus the CenterLoss module
optimizer, optimizer_center = make_optimizer_with_center(cfg, model, center_criterion) # optimizers (presumably one for the model and one for the center-loss parameters - TODO confirm)
# scheduler = WarmupMultiStepLR(optimizer, cfg.SOLVER.STEPS, cfg.SOLVER.GAMMA, cfg.SOLVER.WARMUP_FACTOR,
# cfg.SOLVER.WARMUP_ITERS, cfg.SOLVER.WARMUP_METHOD)
######## make_loss_with_center(cfg, num_classes) 的函数代码,这里用到了函数嵌套(闭包)
### 函数最后 return loss_func, center_criterion;外部接收返回的内部函数后,直接传参调用即可,例如 loss = loss_fn(score, feat, target)(loss_fn 即这里返回的 loss_func)
# NOTE(review): stray excerpt - this label-smoothing setup duplicates the same lines inside
# make_loss_with_center below and relies on cfg / num_classes from that function's scope.
if cfg.MODEL.IF_LABELSMOOTH == 'on':
xent = CrossEntropyLabelSmooth(num_classes=num_classes) # new add by luo
print("label smooth on, numclasses:", num_classes)
def make_loss_with_center(cfg, num_classes):  # modified by gu
    """Build the combined ID + metric loss used when training with center loss.

    Args:
        cfg: experiment config. Reads cfg.MODEL.NAME, cfg.MODEL.METRIC_LOSS_TYPE,
            cfg.MODEL.IF_LABELSMOOTH, cfg.SOLVER.MARGIN and
            cfg.SOLVER.CENTER_LOSS_WEIGHT.
        num_classes: number of identity classes.

    Returns:
        (loss_func, center_criterion): ``loss_func(score, feat, target)`` is a
        nested function returning
        ID loss [+ triplet loss] + CENTER_LOSS_WEIGHT * center loss;
        ``center_criterion`` is the CenterLoss module, returned so the caller
        can build an optimizer for it.

    Raises:
        ValueError: if METRIC_LOSS_TYPE is not 'center' or 'triplet_center'.
    """
    metric_type = cfg.MODEL.METRIC_LOSS_TYPE
    # Validate up front so an unsupported type fails with a clear error instead
    # of an UnboundLocalError on center_criterion at the return statement.
    if metric_type not in ('center', 'triplet_center'):
        raise ValueError(
            'expected METRIC_LOSS_TYPE with center should be center, triplet_center '
            'but got {}'.format(metric_type))

    # Feature dimension fed to CenterLoss: 512 for resnet18/34, 2048 otherwise.
    if cfg.MODEL.NAME in ('resnet18', 'resnet34'):
        feat_dim = 512
    else:
        feat_dim = 2048

    # Construct triplet before center, matching the original construction order.
    triplet = TripletLoss(cfg.SOLVER.MARGIN) if metric_type == 'triplet_center' else None  # triplet loss
    center_criterion = CenterLoss(num_classes=num_classes, feat_dim=feat_dim, use_gpu=True)  # center loss

    # Pick the ID (classification) loss once instead of re-checking cfg per call.
    if cfg.MODEL.IF_LABELSMOOTH == 'on':
        id_loss = CrossEntropyLabelSmooth(num_classes=num_classes)  # new add by luo
        print("label smooth on, numclasses:", num_classes)
    else:
        id_loss = F.cross_entropy

    center_weight = cfg.SOLVER.CENTER_LOSS_WEIGHT

    def loss_func(score, feat, target):
        """Return ID loss (+ triplet loss) + weighted center loss for one batch."""
        loss = id_loss(score, target)
        if triplet is not None:
            # TripletLoss output is indexed with [0]; presumably (loss, ...) - verify.
            loss = loss + triplet(feat, target)[0]
        return loss + center_weight * center_criterion(feat, target)

    return loss_func, center_criterion