三元组损失是谷歌公司针对人脸检测提出的一种损失函数,其论文为:A Unified Embedding for Face Recognition and Clustering,论文中的整个结构如下图:
训练的人脸图像送入CNN神经网络进行特征提取,从图中输入的图像batch可以看出,每3个图像一组,分别用蓝色、绿色和红色表示,设输入图像用表示,embedding层是这样定义的,其表示为输入图像的函数,维度为,且由于经过层归一化处理,故为单位向量,故embedding将输入图像投影到了一个欧几里德空间的超球面上。(anchor,蓝色)表示一个人的图像,这个人另外的图像,即这同一个人的图像用(positive,正样本,绿色)表示,而一个其他人的图像用(negative,负样本红色)表示。三元函数的训练目标是使得同一个人的任意两幅图像的距离小于此人图像与任意其他人图像的距离,如下图:
那么对应到embedding层,就是最小化:
其中,起到强制增加正样本和负样本间距离的作用,为了加速训练,通常鼓励粗略地选择误差最大的正样本和误差最小的负样本。下面我们来看一下其对应的tensorflow代码,facenet源代码已在专题1给出,其训练代码对应的是train_tripletloss.py,其内容与之前介绍的sotfmax.py基本相同,这里我们重点说下不同的部分,即lost函数。
prelogits, _ = network.inference(image_batch, args.keep_probability,
phase_train=phase_train_placeholder, bottleneck_layer_size=args.embedding_size,
weight_decay=args.weight_decay)
embeddings = tf.nn.l2_normalize(prelogits, 1, 1e-10, name='embeddings')
# Split embeddings into anchor, positive and negative and calculate triplet loss
anchor, positive, negative = tf.unstack(tf.reshape(embeddings, [-1,3,args.embedding_size]), 3, 1)
triplet_loss = facenet.triplet_loss(anchor, positive, negative, args.alpha)
由专题3分析可以得到,prelogits是全连接层的输出,其维度为batch_size*embedding_size,tf.reshape将其维度转为(batch_size/3)*3*embedding_size,tf.unstack分解为3个(batch_size/3)*embedding_size张量,分别对应anchor、positive和negative,由此也可以看出输入样本是按照anchor、positive和negative输入的,接下来看下损失函数triplet_loss的定义:
def triplet_loss(anchor, positive, negative, alpha):
"""Calculate the triplet loss according to the FaceNet paper
Args:
anchor: the embeddings for the anchor images.
positive: the embeddings for the positive images.
negative: the embeddings for the negative images.
Returns:
the triplet loss according to the FaceNet paper as a float tensor.
"""
with tf.variable_scope('triplet_loss'):
pos_dist = tf.reduce_sum(tf.square(tf.subtract(anchor, positive)), 1)
neg_dist = tf.reduce_sum(tf.square(tf.subtract(anchor, negative)), 1)
basic_loss = tf.add(tf.subtract(pos_dist,neg_dist), alpha)
loss = tf.reduce_mean(tf.maximum(basic_loss, 0.0), 0)
return loss
其实很简单,就是论文中三元损失函数的实现,随后的训练代码和之前sotfmax方法是一样的。