Keras中的三输入模型的损失函数Triplet Loss

Triploss函数来自罗浩博士的知乎分享与github分享:https://www.zhihu.com/question/46943328/answer/175040246

def triplet_loss(y_true, y_pred):
    y_pred = K.l2_normalize(y_pred,axis=1)
    batch = batch_size
    #print(batch)
    ref1 = y_pred[0:batch,:]
    pos1 = y_pred[batch:batch+batch,:]
    neg1 = y_pred[batch+batch:3*batch,:]
    dis_pos = K.sum(K.square(ref1 - pos1), axis=1, keepdims=True)
    dis_neg = K.sum(K.square(ref1 - neg1), axis=1, keepdims=True)
    dis_pos = K.sqrt(dis_pos)
    dis_neg = K.sqrt(dis_neg)
    a1 = 17
    d1 = dis_pos + K.maximum(0.0, dis_pos - dis_neg + a1)
    return K.mean(d1)

这里比较巧妙地将一个batch_size扩充成3倍,设batch_size = N,那么每批(1-3*N)中,1-N依次放 anchor1,anchor2,…anchorN;(N+1)-2*n里依次放pos1,pos2,…posN,(2*N+1)-3*N里依次放neg1,neg2,…,negN。

所以虽然模型是三输入的,但是实际上却还是单输入的,这样有个好处是可以利用在Imagenet上预训练好的参数进行初始化,加速训练。亲测在某一种图像的分类上,由于样本库比较小,不使用预训练好的参数初始化的话,最高只能达到60%左右,但是使用了预训练好的参数初始化后,一轮训练后正确率即可达到95+%。可见Imagenet比赛上获奖的这些模型的泛化能力极强。
此Triplet Loss的使用见上一篇博客http://blog.csdn.net/yjy728/article/details/79569807

你可能感兴趣的:(机器学习)