1. Introduction
https://github.com/michuanhaohao/reid-strong-baseline
is the person re-identification code accompanying the paper Bag of Tricks and a Strong Baseline for Deep Person Re-Identification, which improves on the standard baseline and achieves better re-ID performance. What follows is a walkthrough of the training part of the code.
2. Preparing the dataset
The concrete data-loading code differs from dataset to dataset and is not covered in detail here. It is enough to know that one item of the resulting ImageDataset has the format:
return img, pid, camid, img_path
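where pid is the person identity and camid is the camera ID. For reference, a minimal sketch of such a dataset class (simplified; the repo's actual ImageDataset follows roughly this pattern, with extra retry logic for unreadable images):

from PIL import Image
from torch.utils.data import Dataset

class ImageDataset(Dataset):
    """dataset is a list of (img_path, pid, camid) tuples; transform is an optional torchvision transform."""

    def __init__(self, dataset, transform=None):
        self.dataset = dataset
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        img_path, pid, camid = self.dataset[index]
        img = Image.open(img_path).convert('RGB')   # load the image as RGB
        if self.transform is not None:
            img = self.transform(img)
        return img, pid, camid, img_path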
3. Network construction
Under the default configuration, the network structure is as described below.
First, look at the network's forward function:
def forward(self, x):
    global_feat = self.gap(self.base(x))  # (b, 2048, 1, 1)
    global_feat = global_feat.view(global_feat.shape[0], -1)  # flatten to (bs, 2048)
    if self.neck == 'no':
        feat = global_feat
    elif self.neck == 'bnneck':
        feat = self.bottleneck(global_feat)  # normalize for angular softmax
    if self.training:
        cls_score = self.classifier(feat)
        return cls_score, global_feat  # global feature for triplet loss
    else:
        if self.neck_feat == 'after':
            # print("Test with feature after BN")
            return feat
        else:
            # print("Test with feature before BN")
            return global_feat
It helps to read this alongside the network-structure figure in the paper (the figure is not reproduced here).
The network roughly breaks down into the following parts:
3.1 backbone
By default the backbone is ResNet-50, the only modification being that the stride of ResNet-50's last down-sampling convolution is set to 1 (the "last stride" trick), which enlarges the spatial resolution of the final feature map.
3.2 gap
The gap code is as follows:
self.gap = nn.AdaptiveAvgPool2d(1)
It is simply an AdaptiveAvgPool2d layer, i.e. global average pooling. After the gap layer the result is flattened to shape (bs, 2048), giving global_feat, as the shape sketch below shows.
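A quick shape sketch (assuming a 256x128 input, for which the last-stride-1 ResNet-50 produces a 16x8 feature map):

import torch
import torch.nn as nn

gap = nn.AdaptiveAvgPool2d(1)
x = torch.randn(8, 2048, 16, 8)   # backbone output for a batch of 8
global_feat = gap(x)              # (8, 2048, 1, 1): each channel averaged over all 16*8 positions
global_feat = global_feat.view(global_feat.shape[0], -1)
print(global_feat.shape)          # torch.Size([8, 2048])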
3.3 neck
The default configuration is bnneck, whose structure is as follows:
self.bottleneck = nn.BatchNorm1d(self.in_planes)
self.bottleneck.bias.requires_grad_(False) # no shift
self.bottleneck.apply(weights_init_kaiming)
It is just a BatchNorm1d layer with its shift (bias) frozen, initialized by the repo's weights_init_kaiming helper (Kaiming initialization itself is available in PyTorch as nn.init.kaiming_normal_). Feeding global_feat through this layer produces feat.
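For reference, the repo's weights_init_kaiming helper looks roughly as follows (note that for BatchNorm layers it actually falls back to constant initialization, weight 1 and bias 0, rather than a Kaiming scheme):

import torch.nn as nn

def weights_init_kaiming(m):
    classname = m.__class__.__name__
    if classname.find('Linear') != -1:
        nn.init.kaiming_normal_(m.weight, a=0, mode='fan_out')
        nn.init.constant_(m.bias, 0.0)
    elif classname.find('Conv') != -1:
        nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in')
        if m.bias is not None:
            nn.init.constant_(m.bias, 0.0)
    elif classname.find('BatchNorm') != -1:
        if m.affine:
            nn.init.constant_(m.weight, 1.0)  # BN weight set to 1, not Kaiming
            nn.init.constant_(m.bias, 0.0)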
3.4 classifier
def weights_init_classifier(m):
    classname = m.__class__.__name__
    if classname.find('Linear') != -1:
        nn.init.normal_(m.weight, std=0.001)
        if m.bias is not None:  # the classifier below is bias-free, so this branch is skipped
            nn.init.constant_(m.bias, 0.0)

self.classifier = nn.Linear(self.in_planes, self.num_classes, bias=False)
self.classifier.apply(weights_init_classifier)
It is just a fully connected layer; the number of output neurons is the number of classes, and in person re-identification the "classes" are the person identities in the training set.
3.5 Network output
In the end, the network returns the classification scores cls_score produced by the classifier, together with global_feat, which has not passed through the bottleneck.
4. Loss functions
In the default configuration there is only the triplet loss, without the center loss (why? the author does not know either). But since the paper proposes the center loss, we look at the loss code for the case where both the triplet loss and the center loss are present. Also, label smoothing is on by default.
def make_loss_with_center(cfg, num_classes):    # modified by gu
    ...
    elif cfg.MODEL.METRIC_LOSS_TYPE == 'triplet_center':
        triplet = TripletLoss(cfg.SOLVER.MARGIN)  # triplet loss
        center_criterion = CenterLoss(num_classes=num_classes, feat_dim=feat_dim, use_gpu=True)  # center loss

    if cfg.MODEL.IF_LABELSMOOTH == 'on':
        xent = CrossEntropyLabelSmooth(num_classes=num_classes)     # new add by luo
        print("label smooth on, numclasses:", num_classes)

    def loss_func(score, feat, target):
        ...
        elif cfg.MODEL.METRIC_LOSS_TYPE == 'triplet_center':
            if cfg.MODEL.IF_LABELSMOOTH == 'on':
                return xent(score, target) + \
                        triplet(feat, target)[0] + \
                        cfg.SOLVER.CENTER_LOSS_WEIGHT * center_criterion(feat, target)

    return loss_func, center_criterion
In short, the loss is the sum of three terms: the triplet loss, the center loss, and CrossEntropyLabelSmooth, with cfg.SOLVER.CENTER_LOSS_WEIGHT controlling the weight of the center loss.
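As a toy illustration of the weighting (the loss values below are made up; CENTER_LOSS_WEIGHT defaults to 0.0005 in the repo's configs, if memory serves): the raw center loss is typically orders of magnitude larger than the other two terms, and the weight scales it into the same range.

import torch

xent_val = torch.tensor(6.62)      # made-up ID loss value
triplet_val = torch.tensor(0.35)   # made-up triplet loss value
center_val = torch.tensor(1250.0)  # made-up (large) raw center loss value
total = xent_val + triplet_val + 0.0005 * center_val
print(total)  # tensor(7.5950); the weighted center term contributes only 0.625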
4.1 TripletLoss
The paper formulates the triplet loss as

$L_{Tri} = [d_p - d_n + \alpha]_+$

where $d_p$ and $d_n$ are the feature distances of the positive pair and the negative pair, and $\alpha$ is the margin.
The TripletLoss code is as follows; calling the object runs its __call__ method:
class TripletLoss(object):

    def __init__(self, margin=None):
        self.margin = margin
        if margin is not None:
            self.ranking_loss = nn.MarginRankingLoss(margin=margin)
        else:
            self.ranking_loss = nn.SoftMarginLoss()

    def __call__(self, global_feat, labels, normalize_feature=False):
        if normalize_feature:
            global_feat = normalize(global_feat, axis=-1)
        dist_mat = euclidean_dist(global_feat, global_feat)
        dist_ap, dist_an = hard_example_mining(dist_mat, labels)
        y = dist_an.new().resize_as_(dist_an).fill_(1)
        if self.margin is not None:
            loss = self.ranking_loss(dist_an, dist_ap, y)
        else:
            loss = self.ranking_loss(dist_an - dist_ap, y)
        return loss, dist_ap, dist_an
Parameters:
global_feat: see section 3.2 of this article
labels: the ground-truth labels, shape [batch_size]; a sample of class i has label i
By default no feature normalization is performed, so the procedure is as follows:
1. Compute the Euclidean distance between global_feat and itself.
The euclidean_dist function:
def euclidean_dist(x, y):
    """
    Args:
      x: pytorch Variable, with shape [m, d]
      y: pytorch Variable, with shape [n, d]
    Returns:
      dist: pytorch Variable, with shape [m, n]
    """
    m, n = x.size(0), y.size(0)
    xx = torch.pow(x, 2).sum(1, keepdim=True).expand(m, n)
    yy = torch.pow(y, 2).sum(1, keepdim=True).expand(n, m).t()
    dist = xx + yy
    # the repo uses the deprecated signature dist.addmm_(1, -2, x, y.t());
    # the modern equivalent is:
    dist.addmm_(x, y.t(), beta=1, alpha=-2)
    dist = dist.clamp(min=1e-12).sqrt()  # for numerical stability
    return dist
Sum the squares of each row of x, then expand the [m, 1] result to [m, n] to get xx; do the same for y to get yy. Then

dist = sqrt(clamp(xx + yy - 2 * x @ y^T, min=1e-12))

which is just the expansion $\|x_i - y_j\|^2 = \|x_i\|^2 + \|y_j\|^2 - 2\,x_i \cdot y_j$, the vector analogue of $(a-b)^2 = a^2 + b^2 - 2ab$.
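As a quick sanity check (a sketch, assuming euclidean_dist as defined above), the result matches PyTorch's built-in torch.cdist up to numerical error:

import torch

torch.manual_seed(0)
x = torch.randn(5, 2048)
y = torch.randn(7, 2048)
d1 = euclidean_dist(x, y)     # the function above
d2 = torch.cdist(x, y)        # PyTorch's built-in pairwise Euclidean distance
print(torch.allclose(d1, d2, atol=1e-4))  # True (up to the 1e-12 clamp)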
2. With the distance matrix dist_mat in hand, run hard_example_mining on dist_mat and the labels.
The hard_example_mining code:
def hard_example_mining(dist_mat, labels, return_inds=False):
    """For each anchor, find the hardest positive and negative sample.
    Args:
      dist_mat: pytorch Variable, pair wise distance between samples, shape [N, N]
      labels: pytorch LongTensor, with shape [N]
      return_inds: whether to return the indices. Save time if `False`(?)
    Returns:
      dist_ap: pytorch Variable, distance(anchor, positive); shape [N]
      dist_an: pytorch Variable, distance(anchor, negative); shape [N]
      p_inds: pytorch LongTensor, with shape [N];
        indices of selected hard positive samples; 0 <= p_inds[i] <= N - 1
      n_inds: pytorch LongTensor, with shape [N];
        indices of selected hard negative samples; 0 <= n_inds[i] <= N - 1
    NOTE: Only consider the case in which all labels have same num of samples,
      thus we can cope with all anchors in parallel.
    """

    assert len(dist_mat.size()) == 2
    assert dist_mat.size(0) == dist_mat.size(1)
    N = dist_mat.size(0)

    # shape [N, N]
    is_pos = labels.expand(N, N).eq(labels.expand(N, N).t())
    is_neg = labels.expand(N, N).ne(labels.expand(N, N).t())

    # `dist_ap` means distance(anchor, positive)
    # both `dist_ap` and `relative_p_inds` with shape [N, 1]
    dist_ap, relative_p_inds = torch.max(
        dist_mat[is_pos].contiguous().view(N, -1), 1, keepdim=True)
    # `dist_an` means distance(anchor, negative)
    # both `dist_an` and `relative_n_inds` with shape [N, 1]
    dist_an, relative_n_inds = torch.min(
        dist_mat[is_neg].contiguous().view(N, -1), 1, keepdim=True)
    # shape [N]
    dist_ap = dist_ap.squeeze(1)
    dist_an = dist_an.squeeze(1)

    if return_inds:
        # shape [N, N]
        ind = (labels.new().resize_as_(labels)
               .copy_(torch.arange(0, N).long())
               .unsqueeze(0).expand(N, N))
        # shape [N, 1]
        p_inds = torch.gather(
            ind[is_pos].contiguous().view(N, -1), 1, relative_p_inds.data)
        n_inds = torch.gather(
            ind[is_neg].contiguous().view(N, -1), 1, relative_n_inds.data)
        # shape [N]
        p_inds = p_inds.squeeze(1)
        n_inds = n_inds.squeeze(1)
        return dist_ap, dist_an, p_inds, n_inds

    return dist_ap, dist_an
Based on the ground-truth labels, sample pairs are split into positives (two samples with the same label) and negatives (different labels). For each anchor, the function then returns the Euclidean distance to its hardest positive (the farthest same-identity sample, via torch.max) and to its hardest negative (the closest different-identity sample, via torch.min). A small example of the masks is sketched below.
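A tiny worked sketch (made-up labels) of the is_pos mask built inside hard_example_mining; note the NOTE in the docstring: every identity must contribute the same number of samples per batch, so dist_mat[is_pos].view(N, -1) is rectangular.

import torch

labels = torch.tensor([0, 0, 1, 1])  # two identities, two samples each
N = labels.size(0)
# is_pos[i][j] is True iff labels[i] == labels[j]
is_pos = labels.expand(N, N).eq(labels.expand(N, N).t())
print(is_pos.int())
# tensor([[1, 1, 0, 0],
#         [1, 1, 0, 0],
#         [0, 0, 1, 1],
#         [0, 0, 1, 1]], dtype=torch.int32)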
3. Compute the ranking loss.
y is an all-ones tensor with the same shape as dist_an (the negative-pair distances).
By default a margin is given, so the ranking loss is PyTorch's nn.MarginRankingLoss:
self.ranking_loss = nn.MarginRankingLoss(margin=margin)
loss = self.ranking_loss(dist_an, dist_ap, y)
https://www.jianshu.com/p/579a0f4cbf24
The article above summarizes several of the loss functions implemented in PyTorch, MarginRankingLoss among them.
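For concreteness, nn.MarginRankingLoss computes mean(max(0, -y*(x1 - x2) + margin)); with y = 1, x1 = dist_an, x2 = dist_ap this is exactly the triplet hinge mean(max(0, dist_ap - dist_an + margin)). A tiny example with made-up distances:

import torch
import torch.nn as nn

dist_ap = torch.tensor([0.8, 1.5])  # made-up anchor-positive distances for two anchors
dist_an = torch.tensor([1.2, 1.1])  # made-up anchor-negative distances
y = torch.ones_like(dist_an)        # y = 1 means "dist_an should rank above dist_ap"
loss = nn.MarginRankingLoss(margin=0.3)(dist_an, dist_ap, y)
print(loss)  # tensor(0.3500) = mean(max(0, 0.8-1.2+0.3), max(0, 1.5-1.1+0.3))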
4.2 center loss
In the paper, the center loss is

$\mathcal{L}_{C} = \frac{1}{2}\sum_{j=1}^{B}\left\| \boldsymbol{f}_{t_j} - \boldsymbol{c}_{y_j} \right\|_2^2$

where $B$ is the batch size, $\boldsymbol{f}_{t_j}$ is the feature of the $j$-th sample and $\boldsymbol{c}_{y_j}$ is the learned center of its class $y_j$.
In the code, the center loss is implemented as a module:
class CenterLoss(nn.Module):
    """
    Args:
        num_classes (int): number of classes.
        feat_dim (int): feature dimension.
    """

    def __init__(self, num_classes=751, feat_dim=2048, use_gpu=True):
        ...
        if self.use_gpu:
            self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim).cuda())
        else:
            self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim))

    def forward(self, x, labels):
        """
        Args:
            x: feature matrix with shape (batch_size, feat_dim).
            labels: ground truth labels with shape (batch_size).
        """
        ...
        batch_size = x.size(0)
        distmat = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(batch_size, self.num_classes) + \
                  torch.pow(self.centers, 2).sum(dim=1, keepdim=True).expand(self.num_classes, batch_size).t()
        # modern form of the repo's deprecated distmat.addmm_(1, -2, x, self.centers.t())
        distmat.addmm_(x, self.centers.t(), beta=1, alpha=-2)

        classes = torch.arange(self.num_classes).long()
        if self.use_gpu: classes = classes.cuda()
        labels = labels.unsqueeze(1).expand(batch_size, self.num_classes)
        mask = labels.eq(classes.expand(batch_size, self.num_classes))

        dist = distmat * mask.float()
        loss = dist.clamp(min=1e-12, max=1e+12).sum() / batch_size
        return loss
centers is a num_classes (number of classes) × feat_dim (feature dimension) tensor of normally distributed random values, wrapped in nn.Parameter so the centers are learned. distmat is the squared Euclidean distance between x (during training this is global_feat, see section 3.5) and centers.
labels holds the ground-truth labels as plain integers such as 1, 3, 4, 5, not one-hot vectors; mask expands each label into a one-hot class vector, i.e. each row looks like [0, 0, ..., 1, ...].
dist = distmat * mask.float() keeps, in each row of distmat, only the column of that sample's own class and zeroes out everything else.
Finally, loss = the sum of the remaining (clamped) elements of dist divided by batch_size, which the sketch below verifies.
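A sanity sketch (using the repo's full CenterLoss with use_gpu=False): the loss equals the batch mean of each sample's squared distance to its own class center. Note that the paper's 1/2 factor is absent in this implementation.

import torch

torch.manual_seed(0)
center_criterion = CenterLoss(num_classes=3, feat_dim=4, use_gpu=False)
x = torch.randn(5, 4)
labels = torch.tensor([0, 1, 2, 0, 1])
# squared distance of each sample to the center of its own class, averaged over the batch
manual = ((x - center_criterion.centers[labels]) ** 2).sum() / x.size(0)
print(torch.allclose(center_criterion(x, labels), manual))  # True (up to the 1e-12 clamp)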
4.3 CrossEntropyLabelSmooth
This loss is the re-identification loss, i.e. the paper's $L_{ID}$, which measures whether the person is classified correctly. With label smoothing, the paper replaces the one-hot target with

$q_i = \begin{cases} 1 - \frac{N-1}{N}\varepsilon & \text{if } i = y \\ \varepsilon / N & \text{otherwise} \end{cases}$

which is exactly what the code below computes as $(1-\varepsilon)\,q + \varepsilon/N$ on a one-hot $q$.
The code is as follows:
class CrossEntropyLabelSmooth(nn.Module):
    """Cross entropy loss with label smoothing regularizer.
    Reference:
    Szegedy et al. Rethinking the Inception Architecture for Computer Vision. CVPR 2016.
    Equation: y = (1 - epsilon) * y + epsilon / K.
    Args:
        num_classes (int): number of classes.
        epsilon (float): weight.
    """

    def __init__(self, num_classes, epsilon=0.1, use_gpu=True):
        super(CrossEntropyLabelSmooth, self).__init__()
        self.num_classes = num_classes
        self.epsilon = epsilon
        self.use_gpu = use_gpu
        self.logsoftmax = nn.LogSoftmax(dim=1)

    def forward(self, inputs, targets):
        """
        Args:
            inputs: prediction matrix (before softmax) with shape (batch_size, num_classes)
            targets: ground truth labels with shape (batch_size)
        """
        log_probs = self.logsoftmax(inputs)
        targets = torch.zeros(log_probs.size()).scatter_(1, targets.unsqueeze(1).data.cpu(), 1)
        if self.use_gpu: targets = targets.cuda()
        targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes
        loss = (- targets * log_probs).mean(0).sum()
        return loss
In essence: take the log-softmax of the logits, with targets corresponding to the smoothed distribution $q$ in the paper; the negated elementwise product, summed per sample and averaged over the batch, gives the re-ID loss $L_{ID}$.
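A quick check (a sketch, CPU-only): with epsilon = 0.1 and N = 4 classes, a label of class 2 is smoothed to q = [0.025, 0.025, 0.925, 0.025], and the loss equals the batch mean of -(q * log_probs).sum(dim=1).

import torch

criterion = CrossEntropyLabelSmooth(num_classes=4, use_gpu=False)
logits = torch.tensor([[2.0, 0.5, 3.0, -1.0]])  # made-up logits for one sample
target = torch.tensor([2])
log_probs = torch.log_softmax(logits, dim=1)
q = torch.tensor([[0.025, 0.025, 0.925, 0.025]])  # (1 - 0.1) * onehot + 0.1 / 4
print(torch.allclose(criterion(logits, target), -(q * log_probs).sum(dim=1).mean()))  # True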
5. Optimizer
In the repo's default configs the optimizer is Adam (the code also supports SGD, in which case a momentum argument is passed); the parameter settings, learning-rate warmup and decay, and so on are not detailed here, see the paper. The optimizer construction with center loss is as follows:
def make_optimizer_with_center(cfg, model, center_criterion):
    params = []
    for key, value in model.named_parameters():
        if not value.requires_grad:
            continue
        lr = cfg.SOLVER.BASE_LR
        weight_decay = cfg.SOLVER.WEIGHT_DECAY
        if "bias" in key:
            lr = cfg.SOLVER.BASE_LR * cfg.SOLVER.BIAS_LR_FACTOR
            weight_decay = cfg.SOLVER.WEIGHT_DECAY_BIAS
        params += [{"params": [value], "lr": lr, "weight_decay": weight_decay}]
    if cfg.SOLVER.OPTIMIZER_NAME == 'SGD':
        optimizer = getattr(torch.optim, cfg.SOLVER.OPTIMIZER_NAME)(params, momentum=cfg.SOLVER.MOMENTUM)
    else:
        optimizer = getattr(torch.optim, cfg.SOLVER.OPTIMIZER_NAME)(params)
    optimizer_center = torch.optim.SGD(center_criterion.parameters(), lr=cfg.SOLVER.CENTER_LR)
    return optimizer, optimizer_center
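A minimal sketch of the per-parameter-group pattern used above (the lr and weight-decay values are assumptions standing in for the cfg.SOLVER.* settings): each parameter gets its own dict, so biases can receive a different learning rate and weight decay.

import torch
import torch.nn as nn

model = nn.Linear(8, 4)
params = [
    {"params": [model.weight], "lr": 3.5e-4, "weight_decay": 5e-4},
    {"params": [model.bias],   "lr": 7.0e-4, "weight_decay": 0.0},  # bias: lr * BIAS_LR_FACTOR, no decay
]
optimizer = torch.optim.Adam(params)
print([g["lr"] for g in optimizer.param_groups])  # [0.00035, 0.0007]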
6. Training
Once the network structure, loss functions and optimizers are understood, the training logic follows: each step is a forward pass followed by a backward pass. The relevant code:
def _update(engine, batch):
    model.train()
    optimizer.zero_grad()
    optimizer_center.zero_grad()
    img, target = batch
    img = img.to(device) if torch.cuda.device_count() >= 1 else img
    target = target.to(device) if torch.cuda.device_count() >= 1 else target
    score, feat = model(img)
    loss = loss_fn(score, feat, target)
    # print("Total loss is {}, center loss is {}".format(loss, center_criterion(feat, target)))
    loss.backward()
    optimizer.step()
    # undo the CENTER_LOSS_WEIGHT scaling on the center gradients, so the centers
    # themselves are updated by the unweighted center loss
    # (the variable name `cetner_loss_weight` is a typo carried over from the repo)
    for param in center_criterion.parameters():
        param.grad.data *= (1. / cetner_loss_weight)
    optimizer_center.step()

    # compute acc
    acc = (score.max(1)[1] == target).float().mean()
    return loss.item(), acc.item()
The training code uses ignite's Engine; for the details you still need to read the code, but roughly it is event-driven: handlers are registered for events and fired when those events occur. For example:
trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpointer, {'model': model,
                                                                 'optimizer': optimizer,
                                                                 'center_param': center_criterion,
                                                                 'optimizer_center': optimizer_center})
The author's guess was that this event_handler is something like the Handler in Android, i.e. a cross-thread messaging mechanism; in ignite, however, handlers are simply callbacks invoked synchronously when the corresponding event (here, the end of an epoch) fires, as the sketch below illustrates.
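To make this concrete, here is a minimal wiring sketch (names such as train_loader and the epoch count are assumptions; the repo's do_train_with_center does considerably more):

from ignite.engine import Engine, Events

trainer = Engine(_update)  # _update above becomes the per-batch process function

@trainer.on(Events.ITERATION_COMPLETED)
def log_iteration(engine):
    loss, acc = engine.state.output          # whatever _update returned for this batch
    print("epoch {}: loss {:.3f}, acc {:.3f}".format(engine.state.epoch, loss, acc))

trainer.run(train_loader, max_epochs=120)    # fires the registered handlers as events occur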