PyTorch学习笔记:nn.TripletMarginLoss——三元组损失

PyTorch学习笔记:nn.TripletMarginLoss——三元组损失

torch.nn.TripletMarginLoss(margin=1.0, p=2.0, eps=1e-06, swap=False, size_average=None, reduce=None, reduction='mean')

功能:创建一个三元组损失函数(triplet loss),用于衡量输入数据 x 1 , x 2 , x 3 x_1,x_2,x_3 x1,x2,x3之间的相对相似性,其中输入样本又分别称为中立样本、正样本以及负样本,具体介绍可见论文《Learning shallow convolutional feature descriptors with triplet losses》

损失函数
L ( x 1 , x 2 , x 3 ) = max ⁡ { d ( x 1 , x 2 ) − d ( x 1 , x 3 ) + margin , 0 } L(x_1,x_2,x_3)=\max\{d(x_1,x_2)-d(x_1,x_3)+\text{margin},0\} L(x1,x2,x3)=max{d(x1,x2)d(x1,x3)+margin,0}
其中:
d ( x i , y i ) = ∣ ∣ x i − y i ∣ ∣ p d(x_i,y_i)=||x_i-y_i||_p d(xi,yi)=∣∣xiyip
该函数的作用就是拉进 x 1 x_1 x1 x 2 x_2 x2的距离,使它们更加相似,同时推离 x 1 x_1 x1 x 3 x_3 x3的距离,即使它们更加不同。

输入:

  • margin:边界距离,具体含义如公式所示,如果该值越大,则表明 x 1 x_1 x1 x 2 x_2 x2期望距离越近, x 1 x_1 x1 x 3 x_3 x3期望距离越远。输入数据类型为浮点数(float),默认1.0;
  • p:用于计算两个向量距离的范数,具体含义如公式所示。输入数据类型为整数(int),默认2,即欧氏距离;
  • swap:是否使用距离交换,具体功能可见论文《Learning shallow convolutional feature descriptors with triplet losses》;
  • size_averagereduce已被弃用,具体功能由reduction替代;
  • reduction:指定损失输出的形式,有三种选择:none|mean|sumnone:损失不做任何处理,直接输出一个数组;mean:将得到的损失求平均值再输出,会输出一个数;sum:将得到的损失求和再输出,会输出一个数。

注意:

  • 输入的三个样本数据维数必须为二维 ( N , D ) (N,D) (N,D),其中第二个维度 D D D表示向量长度;
  • 如果reduction设置为none,则输出的数组维数为1,尺寸为 ( N ) (N) (N)

代码案例

一般用法

import torch
import torch.nn as nn

# reduction设为none便于查看损失计算的结果
triplet_loss = nn.TripletMarginLoss(reduction='none')
x1 = torch.randn(20).reshape(2,10)
x2 = torch.randn(20).reshape(2,10)
x3 = torch.randn(20).reshape(2,10)
loss = triplet_loss(x1, x2, x3)
print(x1)
print(x2)
print(x3)
print(loss)

输出

tensor([[-0.1419,  0.0550, -0.2996, -1.7194,  0.5485, -0.9163, -0.6983,  0.0239,
          1.2940, -0.4858],
        [ 1.8544, -0.2349, -0.2523, -1.6167,  0.7861, -1.7627,  0.3139, -1.5112,
         -0.3378,  0.0059]])
tensor([[-1.5967,  0.4007,  0.1468, -1.0085, -1.4989,  1.7531,  0.0865, -0.9080,
         -0.4046,  0.5229],
        [-1.8673, -0.4958,  1.0122, -1.8696,  0.1974, -0.8017, -1.0562, -2.1461,
          1.7112, -0.6001]])
tensor([[-1.0008,  1.5316,  0.0078,  1.1405, -0.0629,  0.4934, -1.8050, -1.0302,
          0.8676, -0.1988],
        [ 1.3015, -0.2786,  0.4215, -0.6413, -0.0760, -0.8138,  0.2173,  1.5132,
         -0.6389,  0.7173]])
tensor([1.4133, 2.2473])

官方文档

nn.TripletMarginLoss:https://pytorch.org/docs/stable/generated/torch.nn.TripletMarginLoss.html?highlight=tripletmarginloss#torch.nn.TripletMarginLoss

初步完稿于:2022年3月30日

你可能感兴趣的:(PyTorch学习笔记,pytorch,学习,深度学习)