TOIM Loss in AlignPS

This post introduces the TOIM loss function from AlignPS, a CVPR 2021 paper in the field of person re-identification.

Paper: https://arxiv.org/abs/2109.00211

Code: GitHub - daodaofr/AlignPS: Code for CVPR 2021 paper: Anchor-Free Person Search

TOIM

TOIM Loss = OIM Loss + Triplet Loss

[Figure 1: TOIM loss in AlignPS]

OIM Loss

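Step 1. Initialize two lookup tables (Look-Up Tables, LUTs). The first stores the features of labeled persons. The second is a fixed-size circular queue that stores the features of unlabeled persons: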

self.labeled_matching_layer = LabeledMatchingLayerQueue(num_persons=num_person, feat_len=self.in_channels)
self.unlabeled_matching_layer = UnlabeledMatchingLayer(queue_size=queue_size, feat_len=self.in_channels)


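import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function


# Holds the embeddings used for labeled matching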
class LabeledMatchingLayerQueue(nn.Module):
    """
    Labeled matching of OIM loss function.
    """

    def __init__(self, num_persons=5532, feat_len=256):
        """
        Args:
            num_persons (int): Number of labeled persons.
            feat_len (int): Length of the feature extracted by the network.
        """
        super(LabeledMatchingLayerQueue, self).__init__()
        self.register_buffer("lookup_table", torch.zeros(num_persons, feat_len))

    def forward(self, features, pid_labels):
        """
        Args:
            features (Tensor[N, feat_len]): Features of the proposals.
            pid_labels (Tensor[N]): Ground-truth person IDs of the proposals.

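        Returns:
            scores (Tensor[N, num_persons]): Labeled matching scores, namely the similarities
                                             between proposals and labeled persons.
            pos_feats (Tensor[M, feat_len]): Detached LUT features of the labeled person IDs
                                             in this batch (extra samples for the triplet loss).
            pos_pids (Tensor[M]): Person IDs corresponding to pos_feats.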
        """
        scores, pos_feats, pos_pids = LabeledMatching.apply(features, pid_labels, self.lookup_table)
        return scores, pos_feats, pos_pids


# Holds the embeddings used for unlabeled matching
class UnlabeledMatchingLayer(nn.Module):
    """
    Unlabeled matching of OIM loss function.
    """

    def __init__(self, queue_size=5000, feat_len=256):
        """
        Args:
            queue_size (int): Size of the queue saving the features of unlabeled persons.
            feat_len (int): Length of the feature extracted by the network.
        """
        super(UnlabeledMatchingLayer, self).__init__()
        self.register_buffer("queue", torch.zeros(queue_size, feat_len))
        self.register_buffer("tail", torch.tensor(0))

    def forward(self, features, pid_labels):
        """
        Args:
            features (Tensor[N, feat_len]): Features of the proposals.
            pid_labels (Tensor[N]): Ground-truth person IDs of the proposals.

        Returns:
            scores (Tensor[N, queue_size]): Unlabeled matching scores, namely the similarities
                                            between proposals and unlabeled persons.
        """
        scores = UnlabeledMatching.apply(features, pid_labels, self.queue, self.tail)
        return scores
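Both tables are created with register_buffer rather than as nn.Parameter. As a result, the optimizer never updates them, but they still move with .to(device) and are saved in the state_dict. The check below illustrates that behavior. It uses a hypothetical TwoTables module that only mirrors the buffer names and default sizes shown above.

```python
import torch
import torch.nn as nn


class TwoTables(nn.Module):
    """Hypothetical stand-in holding the same buffers as the two matching layers."""

    def __init__(self, num_persons=5532, queue_size=5000, feat_len=256):
        super().__init__()
        # Labeled LUT: one row per annotated identity
        self.register_buffer("lookup_table", torch.zeros(num_persons, feat_len))
        # Unlabeled circular queue plus its write pointer
        self.register_buffer("queue", torch.zeros(queue_size, feat_len))
        self.register_buffer("tail", torch.tensor(0))


m = TwoTables()
# Buffers are not trainable parameters...
print(len(list(m.parameters())))      # 0
# ...but they are part of the checkpoint
print(sorted(m.state_dict().keys()))  # ['lookup_table', 'queue', 'tail']
```

Because the tables live in state_dict, a resumed training run restores the accumulated identity features along with the network weights.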

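Step 2. Multiply the embeddings by the transpose of each LUT. The labeled layer returns labeled_matching_scores together with labeled_matching_reid and labeled_matching_ids, which are the current LUT features and IDs of the labeled persons in the batch. The unlabeled layer returns unlabeled_matching_scores: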

labeled_matching_scores, labeled_matching_reid, labeled_matching_ids = self.labeled_matching_layer(pos_reid, pos_reid_ids)


class LabeledMatching(Function):
    @staticmethod
    def forward(ctx, features, pid_labels, lookup_table, momentum=0.5):
        ctx.save_for_backward(features, pid_labels)
        ctx.lookup_table = lookup_table
        ctx.momentum = momentum

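        scores = features.mm(lookup_table.t())  # (N, num_persons) similarities
        # Also return the current LUT entries of the labeled IDs in this batch,
        # detached so they act as constant extra samples for the triplet loss
        pos_feats = lookup_table.clone().detach()
        # The original snippet used `> 0`, which skips ID 0 even though
        # backward() updates every label >= 0
        pos_idx = pid_labels >= 0
        pos_pids = pid_labels[pos_idx]
        pos_feats = pos_feats[pos_pids]
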
        return scores, pos_feats, pos_pids

    @staticmethod
    def backward(ctx, grad_output, grad_feat, grad_pids):
        features, pid_labels = ctx.saved_tensors
        lookup_table = ctx.lookup_table
        momentum = ctx.momentum

        grad_feats = None
        if ctx.needs_input_grad[0]:
            grad_feats = grad_output.mm(lookup_table)

        # Update lookup table, but not by standard backpropagation with gradients
        for indx, label in enumerate(pid_labels):
            if label >= 0:
                lookup_table[label] = (
                    momentum * lookup_table[label] + (1 - momentum) * features[indx]
                )

        return grad_feats, None, None, None
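The key trick is that the LUT is modified inside backward(). The gradient with respect to the features is still computed from the table as it was before the update. The toy below re-creates that pattern with a tiny 3-identity table. ToyLabeledMatching is a hypothetical, simplified copy that drops the pos_feats output and fixes momentum at 0.5. It is not the repo class.

```python
import torch
from torch.autograd import Function

MOMENTUM = 0.5  # same default as LabeledMatching.forward


class ToyLabeledMatching(Function):
    """Hypothetical, simplified copy of LabeledMatching that returns only the scores."""

    @staticmethod
    def forward(ctx, features, pid_labels, lookup_table):
        ctx.save_for_backward(features, pid_labels)
        ctx.lookup_table = lookup_table  # stored by reference so backward() can mutate it
        return features.mm(lookup_table.t())

    @staticmethod
    def backward(ctx, grad_output):
        features, pid_labels = ctx.saved_tensors
        lut = ctx.lookup_table
        # The gradient is taken w.r.t. the table *before* the momentum update below
        grad_feats = grad_output.mm(lut)
        for i, label in enumerate(pid_labels):
            if label >= 0:  # -1 marks unlabeled persons; they never touch the LUT
                lut[label] = MOMENTUM * lut[label] + (1 - MOMENTUM) * features[i]
        return grad_feats, None, None


lut = torch.eye(3)                                    # 3 identities, 3-dim features
feats = torch.tensor([[0.0, 1.0, 0.0],
                      [0.0, 0.0, 1.0]], requires_grad=True)
pids = torch.tensor([0, -1])                          # box 0 is ID 0, box 1 is unlabeled
lut_before = lut.clone()

scores = ToyLabeledMatching.apply(feats, pids, lut)
scores.sum().backward()

print(feats.grad)  # every row is the column sum of the old LUT: [1., 1., 1.]
print(lut[0])      # 0.5 * [1, 0, 0] + 0.5 * [0, 1, 0] = [0.5, 0.5, 0.]
```

Updating the table outside of autograd keeps the LUT out of the optimizer, while the features still receive a normal gradient.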


unlabeled_matching_scores = self.unlabeled_matching_layer(pos_reid, pos_reid_ids)


class UnlabeledMatching(Function):
    @staticmethod
    def forward(ctx, features, pid_labels, queue, tail):
        ctx.save_for_backward(features, pid_labels)
        ctx.queue = queue
        ctx.tail = tail

        scores = features.mm(queue.t())
        return scores

    @staticmethod
    def backward(ctx, grad_output):
        features, pid_labels = ctx.saved_tensors
        queue = ctx.queue
        tail = ctx.tail

        grad_feats = None
        if ctx.needs_input_grad[0]:
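            grad_feats = grad_output.mm(queue.data)

        # Store only the first 64 dimensions of each unlabeled person's feature.
        # The queue is a ring buffer: once more than queue_size unlabeled persons
        # have been stored, the oldest entries are overwritten (push/pop), so the
        # queue size stays at queue_size.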
        for indx, label in enumerate(pid_labels):
            if label == -1:
                queue[tail, :64] = features[indx, :64]
                tail += 1
                if tail >= queue.size(0):
                    tail -= queue.size(0)

        return grad_feats, None, None, None
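The write loop in UnlabeledMatching.backward treats queue as a ring buffer. tail points at the next slot to overwrite, and when it runs past the end it wraps back to 0, so the newest unlabeled features replace the oldest ones. The sketch below isolates just that bookkeeping. It uses a hypothetical push_unlabeled helper and an artificially small queue, and it copies whole rows, whereas the original copies only the first 64 dimensions.

```python
import torch


def push_unlabeled(queue, tail, features, pid_labels):
    """Hypothetical helper mirroring the ring-buffer update in UnlabeledMatching.backward.

    queue: (queue_size, feat_len) buffer, modified in place
    tail:  0-dim long tensor, the next slot to overwrite, modified in place
    """
    for i, label in enumerate(pid_labels):
        if label == -1:                # only unlabeled persons enter the queue
            queue[tail] = features[i]
            tail += 1                  # in-place on the tensor, like the original
            if tail >= queue.size(0):  # wrap around: the oldest entries get overwritten
                tail -= queue.size(0)


queue = torch.zeros(3, 2)
tail = torch.tensor(0)
feats = torch.arange(8, dtype=torch.float32).view(4, 2)  # rows [0,1], [2,3], [4,5], [6,7]
pids = torch.tensor([-1, -1, 5, -1])                     # the ID-5 row is labeled, so it is skipped
push_unlabeled(queue, tail, feats, pids)                  # fills slots 0, 1, 2; tail wraps to 0
push_unlabeled(queue, tail, feats[:1] + 100, torch.tensor([-1]))  # overwrites slot 0
print(queue)  # [[100., 101.], [2., 3.], [6., 7.]]
print(tail)   # tensor(1)
```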

Step 3. Multiply labeled_matching_scores and unlabeled_matching_scores from Step 2 by 10 each. Concatenate them along dim=1 to get matching_scores, then apply softmax to matching_scores to obtain p_i, the class probability defined in the paper:

labeled_matching_scores *= 10
unlabeled_matching_scores *= 10
matching_scores = torch.cat((labeled_matching_scores, unlabeled_matching_scores), dim=1)
p_i = F.softmax(matching_scores, dim=1)
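Written out, this matches the usual OIM probability. The notation here is assumed, following the original OIM formulation: v_j is the j-th row of the labeled LUT (L identities), u_k is the k-th row of the unlabeled queue (Q entries), and x is the proposal embedding. The factor 10 acts as an inverse temperature 1/τ. The denominator runs over both tables because the two score matrices are concatenated before the softmax.

```latex
p_i = \frac{\exp\left(v_i^{\top} x / \tau\right)}
           {\sum_{j=1}^{L} \exp\left(v_j^{\top} x / \tau\right)
            + \sum_{k=1}^{Q} \exp\left(u_k^{\top} x / \tau\right)},
\qquad \tau = \frac{1}{10}
```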

p_i is then reweighted according to its magnitude, similar to focal loss: smaller values of p_i receive larger weighting factors. This gives focal_p_i:

focal_p_i = (1 - p_i)**2 * p_i.log()
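To see the effect of the (1 - p_i)^2 factor, compare a well-matched proposal with a poorly matched one. The snippet evaluates the expression above on two hand-picked, illustrative probabilities.

```python
import torch

p = torch.tensor([0.9, 0.1])       # confident vs. hard proposal
plain = -p.log()                   # ordinary cross-entropy term
focal = -((1 - p) ** 2 * p.log())  # focal-weighted term, as in focal_p_i
print(plain)          # approx. [0.1054, 2.3026]
print(focal)          # approx. [0.0011, 1.8651]
print(focal / plain)  # the weights themselves: approx. [0.01, 0.81]
```

The confident proposal keeps only about 1% of its loss, while the hard one keeps about 81%. Training therefore concentrates on proposals that are still poorly matched.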

Step 4. Take the negative log-likelihood of focal_p_i with respect to the corresponding labels to obtain the OIM loss:

loss_oim = F.nll_loss(focal_p_i, pid_labels, reduction='none', ignore_index=-1)
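focal_p_i already holds weighted log-probabilities, so F.nll_loss only picks the entry at each target column and negates it. With reduction='none', ignore_index=-1 makes unlabeled proposals contribute exactly 0. The check below uses random scores. The sizes (4 proposals by 6 columns) and labels are arbitrary.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
scores = torch.randn(4, 6)                 # stand-in for matching_scores
pid_labels = torch.tensor([2, 0, -1, 5])   # the third proposal is unlabeled
p = F.softmax(scores, dim=1)
focal_p = (1 - p) ** 2 * p.log()

loss = F.nll_loss(focal_p, pid_labels, reduction='none', ignore_index=-1)

# Manual version: -focal_p[i, target_i] for labeled rows, 0 for ignored rows
manual = torch.zeros(4)
for i, t in enumerate(pid_labels.tolist()):
    if t != -1:
        manual[i] = -focal_p[i, t]
print(loss)
print(manual)
```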

Step 5. During backpropagation, the LUT that stores labeled person features is updated with a momentum rule rather than by gradient descent:

lookup_table[label] = (momentum * lookup_table[label] + (1 - momentum) * features[indx])
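This update is an exponential moving average. With momentum = 0.5, a LUT row that sees the same feature k times in a row keeps only 0.5^k of its old value. The table therefore tracks recent embeddings while damping a single noisy one. Below is a scalar worked example with illustrative numbers.

```python
momentum = 0.5
row = 1.0      # old LUT value (a single dimension, for illustration)
feature = 0.0  # the same new feature observed repeatedly
history = []
for _ in range(3):
    row = momentum * row + (1 - momentum) * feature
    history.append(row)
print(history)  # [0.5, 0.25, 0.125]
```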

Triplet Loss

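Step 1. Take labeled_matching_reid and labeled_matching_ids, which were obtained while computing the OIM loss. Concatenate them with pos_reid and pid_labels respectively. This effectively enlarges the batch, letting the triplet loss mine hard pairs from a larger pool of samples: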

pos_reid = torch.cat((pos_reid, labeled_matching_reid), dim=0)
pid_labels = torch.cat((pid_labels, labeled_matching_ids), dim=0)
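The reason the concatenation helps: in a person-search batch, each identity often appears only once, so its only "positive" would be itself. Appending the LUT entries of the same IDs gives every labeled anchor at least one real positive. The sketch below works at the shape level with random stand-in features. The sizes, the IDs and the positives_per_anchor helper are hypothetical.

```python
import torch


def positives_per_anchor(ids):
    """Number of other samples sharing each sample's ID (the anchor itself excluded)."""
    same = ids.unsqueeze(0) == ids.unsqueeze(1)
    return same.sum(dim=1) - 1


torch.manual_seed(0)
pos_reid = torch.randn(3, 8)                # 3 labeled proposals in the batch
pid_labels = torch.tensor([4, 7, 9])        # every identity appears once
labeled_matching_reid = torch.randn(3, 8)   # stand-in for the LUT rows of the same IDs
labeled_matching_ids = torch.tensor([4, 7, 9])

print(positives_per_anchor(pid_labels))     # tensor([0, 0, 0]): no real positive pair

pos_reid = torch.cat((pos_reid, labeled_matching_reid), dim=0)
pid_labels = torch.cat((pid_labels, labeled_matching_ids), dim=0)

print(pos_reid.shape)                       # torch.Size([6, 8])
print(positives_per_anchor(pid_labels))     # tensor([1, 1, 1, 1, 1, 1])
```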

Step 2. Compute the triplet loss from pos_reid and pid_labels:

loss_tri = self.loss_tri(pos_reid, pid_labels)


class TripletLossFilter(nn.Module):
    """Triplet loss with hard positive/negative mining.

    Reference:
    Hermans et al. In Defense of the Triplet Loss for Person Re-Identification. arXiv:1703.07737.

    Code imported from https://github.com/Cysu/open-reid/blob/master/reid/loss/triplet.py.

    Args:
        margin (float): margin for triplet.
    """
    def __init__(self, margin=0.3):
        super(TripletLossFilter, self).__init__()
        self.margin = margin
        self.ranking_loss = nn.MarginRankingLoss(margin=margin)

    def forward(self, inputs, targets):
        """
        Samples labeled -1 (unlabeled persons) are excluded from the loss.
        Args:
            inputs: feature matrix with shape (batch_size, feat_dim)
            targets: ground truth labels with shape (batch_size,)
        """
        # Drop unlabeled samples (label -1) before mining
        keep = targets != -1
        inputs_new = inputs[keep]
        targets_new = targets[keep]
        # With fewer than two identities there is no valid negative, so return 0
        if targets_new.unique().numel() < 2:
            return torch.zeros((), device=targets.device)

        n = inputs_new.size(0)

        # Compute pairwise distance, replace by the official when merged
        dist = torch.pow(inputs_new, 2).sum(dim=1, keepdim=True).expand(n, n)
        dist = dist + dist.t()
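        # dist = ||a||^2 + ||b||^2 - 2 * a.b; the positional (beta, alpha) overload is deprecated
        dist.addmm_(inputs_new, inputs_new.t(), beta=1, alpha=-2)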
        dist = dist.clamp(min=1e-12).sqrt()  # for numerical stability

        # For each anchor, find the hardest positive and negative
        mask = targets_new.expand(n, n).eq(targets_new.expand(n, n).t())
        dist_ap, dist_an = [], []
        for i in range(n):
            dist_ap.append(dist[i][mask[i]].max())
            dist_an.append(dist[i][mask[i] == 0].min())

        dist_ap = torch.stack(dist_ap)
        dist_an = torch.stack(dist_an)
        # Compute ranking hinge loss
        y = torch.ones_like(dist_an)
        loss = self.ranking_loss(dist_an, dist_ap, y)
        return loss
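The per-anchor Python loop in forward can also be written with masking. Below is a vectorized variant of the same batch-hard mining. It is an alternative formulation, not the repo's code. It assumes unlabeled samples have already been filtered out and that at least two identities remain. torch.cdist replaces the manual distance expansion and omits the 1e-12 clamp.

```python
import torch
import torch.nn.functional as F


def batch_hard_triplet(inputs, targets, margin=0.3):
    """Batch-hard triplet loss; assumes no -1 labels and at least 2 identities."""
    dist = torch.cdist(inputs, inputs)                   # (n, n) Euclidean distances
    same = targets.unsqueeze(0) == targets.unsqueeze(1)  # (n, n) same-identity mask
    # Hardest positive: largest distance among same-ID entries (self included, as in the loop)
    dist_ap = dist.masked_fill(~same, float('-inf')).max(dim=1).values
    # Hardest negative: smallest distance among different-ID entries
    dist_an = dist.masked_fill(same, float('inf')).min(dim=1).values
    return F.relu(dist_ap - dist_an + margin).mean()


torch.manual_seed(0)
x = torch.randn(6, 4)
y = torch.tensor([0, 0, 1, 1, 2, 2])
loss_vec = batch_hard_triplet(x, y)

# Reference: loop-style mining as in TripletLossFilter.forward
d = torch.cdist(x, x)
ap = torch.stack([d[i][y == y[i]].max() for i in range(6)])
an = torch.stack([d[i][y != y[i]].min() for i in range(6)])
loss_ref = torch.nn.MarginRankingLoss(margin=0.3)(an, ap, torch.ones(6))
print(loss_vec, loss_ref)
```

The masked form avoids the Python loop over anchors, which matters once the concatenated batch grows.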

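As a side note, torch.nn.MarginRankingLoss(margin=margin) computes the following, averaged over the batch under the default reduction='mean':

Loss(x_1, x_2, y) = max(0, -y(x_1 - x_2) + margin)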

In the code above, x_1 = d_{an} (dist_an), x_2 = d_{ap} (dist_ap) and y = 1, so it becomes:

Loss(d_{an},d_{ap},y)=max(0,d_{ap}-d_{an}+margin) 
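The check below confirms numerically that nn.MarginRankingLoss with y = 1 reduces to this hinge, averaged over anchors. The distances are made up for illustration.

```python
import torch
import torch.nn as nn

margin = 0.3
dist_ap = torch.tensor([0.5, 1.0, 0.2])  # hardest-positive distances (illustrative)
dist_an = torch.tensor([1.0, 0.9, 0.4])  # hardest-negative distances (illustrative)
y = torch.ones_like(dist_an)

loss = nn.MarginRankingLoss(margin=margin)(dist_an, dist_ap, y)
manual = torch.clamp(dist_ap - dist_an + margin, min=0).mean()
print(loss, manual)  # both approx. 0.1667: hinge terms [0, 0.4, 0.1] averaged
```

The first anchor already satisfies the margin (its negative is 0.5 farther than its positive), so it contributes 0. Only the other two anchors drive the gradient.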
