Triplet Loss原理和代码实现

Triplet Loss原理和代码实现

  • Triplet Loss 原理
  • Triplet Loss 代码实现

Triplet Loss 原理

Triplet loss 最先在FaceNet: A Unifed Embedding for Face Recognition and Clustering中提出,用与训练人脸识别模型,取得了不错的效果。
下面介绍Triplet loss原理:
Triplet Loss原理和代码实现_第1张图片
如上图所示,首先在训练集中,我们有一张行人图像xia,记为anchor;然后选择一个相同行人的另外一张图像xip,记为positive;再选择一个不同行人的一张图像xin,记为negative。
我们想要anchor和positive间的距离要小于anchor和negative之间的距离,即:
在这里插入图片描述
α是margin,T是训练集中所有可能的三元组的集合,基数为N。
于是Triplet loss 就可表示为最小化 L
在这里插入图片描述

Triplet 选择
为了使训练模型能快速收敛,选择triplet是至关重要的,通常为了快速收敛,这样选择triplet:
给定一个anchor: xia
选择的xip 使 在这里插入图片描述 ,这时称其为hard-positive;
同样选择的xin 使 在这里插入图片描述 ,这时称其为hard-negative。

Triplet Loss 代码实现

import torch
from torch import nn
from torch.autograd import Variable

class TripletLoss(nn.Module):
    def __init__(self, margin=0):
        super(TripletLoss, self).__init__()
        self.margin = margin     #  一般取0.3
        self.ranking_loss = nn.MarginRankingLoss(margin=margin)  # 获得一个简单的距离triplet函数

    def forward(self, inputs, labels):
        n = inputs.size(0)    # 获取batch_size
        # Compute pairwise distance, replace by the official when merged
        dist = torch.pow(inputs, 2).sum(dim=1, keepdim=True).expand(n, n) #每个数平方后, 进行加和(通过keepdim保持2维),再扩展成nxn维
        dist = dist + dist.t()       # 这样每个dist[i][j]代表的是第i个特征与第j个特征的平方的和
        dist.addmm_(1, -2, inputs, inputs.t())    # 然后减去2倍的 第i个特征*第j个特征 从而通过完全平方式得到 (a-b)^2
        dist = dist.clamp(min=1e-12).sqrt()  # for numerical stability  #开方

        # For each anchor, find the hardest positive and negative
        mask = labels.expand(n, n).eq(labels.expand(n, n).t())    # 这里dist[i][j] = 1代表i和j的label相同, =0代表i和j的label不相同
        dist_ap, dist_an = [], []
        for i in range(n):
            dist_ap.append(dist[i][mask[i]].max().view(1))       # 在i与所有有相同label的j的距离中找一个最大的
            dist_an.append(dist[i][mask[i] == 0].min().view(1))  # 在i与所有不同label的j的距离找一个最小的
        dist_ap = torch.cat(dist_ap)      # 将list里的tensor拼接成新的tensor
        dist_an = torch.cat(dist_an)

        # Compute ranking hinge loss
        y = dist_an.data.new()
        y.resize_as_(dist_an.data)
        y.fill_(1)       # 声明一个与dist_an相同shape的全1 tensor
        y = Variable(y)
        loss = self.ranking_loss(dist_an, dist_ap, y)
        #prec = (dist_an.data > dist_ap.data).sum() * 1. / y.size(0)
        prec = (dist_an.data > dist_ap).data.float().mean()  #预测
        return loss, prec

你可能感兴趣的:(deep,leraning,Triplet,loss,deep,learning)