torch.nn.TripletMarginLoss(margin=1.0, p=2.0, eps=1e-06, swap=False, size_average=None, reduce=None, reduction='mean')
功能:创建一个三元组损失函数(triplet loss),用于衡量输入数据 x 1 , x 2 , x 3 x_1,x_2,x_3 x1,x2,x3之间的相对相似性,其中输入样本又分别称为中立样本、正样本以及负样本,具体介绍可见论文《Learning shallow convolutional feature descriptors with triplet losses》
损失函数:
L ( x 1 , x 2 , x 3 ) = max { d ( x 1 , x 2 ) − d ( x 1 , x 3 ) + margin , 0 } L(x_1,x_2,x_3)=\max\{d(x_1,x_2)-d(x_1,x_3)+\text{margin},0\} L(x1,x2,x3)=max{d(x1,x2)−d(x1,x3)+margin,0}
其中:
d ( x i , y i ) = ∣ ∣ x i − y i ∣ ∣ p d(x_i,y_i)=||x_i-y_i||_p d(xi,yi)=∣∣xi−yi∣∣p
该函数的作用就是拉进 x 1 x_1 x1与 x 2 x_2 x2的距离,使它们更加相似,同时推离 x 1 x_1 x1与 x 3 x_3 x3的距离,即使它们更加不同。
输入:
margin
:边界距离,具体含义如公式所示,如果该值越大,则表明 x 1 x_1 x1与 x 2 x_2 x2期望距离越近, x 1 x_1 x1与 x 3 x_3 x3期望距离越远。输入数据类型为浮点数(float),默认1.0;p
:用于计算两个向量距离的范数,具体含义如公式所示。输入数据类型为整数(int),默认2,即欧氏距离;swap
:是否使用距离交换,具体功能可见论文《Learning shallow convolutional feature descriptors with triplet losses》;size_average
与reduce
已被弃用,具体功能由reduction
替代;reduction
:指定损失输出的形式,有三种选择:none
|mean
|sum
。none
:损失不做任何处理,直接输出一个数组;mean
:将得到的损失求平均值再输出,会输出一个数;sum
:将得到的损失求和再输出,会输出一个数。注意:
reduction
设置为none
,则输出的数组维数为1,尺寸为 ( N ) (N) (N)一般用法
import torch
import torch.nn as nn
# reduction设为none便于查看损失计算的结果
triplet_loss = nn.TripletMarginLoss(reduction='none')
x1 = torch.randn(20).reshape(2,10)
x2 = torch.randn(20).reshape(2,10)
x3 = torch.randn(20).reshape(2,10)
loss = triplet_loss(x1, x2, x3)
print(x1)
print(x2)
print(x3)
print(loss)
输出
tensor([[-0.1419, 0.0550, -0.2996, -1.7194, 0.5485, -0.9163, -0.6983, 0.0239,
1.2940, -0.4858],
[ 1.8544, -0.2349, -0.2523, -1.6167, 0.7861, -1.7627, 0.3139, -1.5112,
-0.3378, 0.0059]])
tensor([[-1.5967, 0.4007, 0.1468, -1.0085, -1.4989, 1.7531, 0.0865, -0.9080,
-0.4046, 0.5229],
[-1.8673, -0.4958, 1.0122, -1.8696, 0.1974, -0.8017, -1.0562, -2.1461,
1.7112, -0.6001]])
tensor([[-1.0008, 1.5316, 0.0078, 1.1405, -0.0629, 0.4934, -1.8050, -1.0302,
0.8676, -0.1988],
[ 1.3015, -0.2786, 0.4215, -0.6413, -0.0760, -0.8138, 0.2173, 1.5132,
-0.6389, 0.7173]])
tensor([1.4133, 2.2473])
nn.TripletMarginLoss:https://pytorch.org/docs/stable/generated/torch.nn.TripletMarginLoss.html?highlight=tripletmarginloss#torch.nn.TripletMarginLoss
初步完稿于:2022年3月30日