sentence_transformers 教程

文档:Search — Sentence-Transformers documentation

用途:

该模主要用来做句子嵌入,下游常用来做语意匹配

losses.CosineSimilarityLoss

计算出样本的余弦相似度,和label做MSE损失

from sentence_transformers import SentenceTransformer, InputExample, losses
from torch.utils.data import DataLoader

#Define the model. Either from scratch of by loading a pre-trained model
model = SentenceTransformer('distilbert-base-nli-mean-tokens')

#Define your train examples. You need more than just two examples...
train_examples = [InputExample(texts=['My first sentence', 'My second sentence'], label=0.8),
    InputExample(texts=['Another pair', 'Unrelated sentence'], label=0.3)]

#Define your train dataset, the dataloader and the train loss
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=16)
train_loss = losses.CosineSimilarityLoss(model)

#Tune the model
model.fit(train_objectives=[(train_dataloader, train_loss)], epochs=1, warmup_steps=100)

MultipleNegativesRankingLoss

对比损失,同一批次的,其它样本视为负样本,分别两两求余弦相似度,最后做交叉熵损失,正样本的得分应该最高

train_examples = [InputExample(texts=['Anchor 1', 'Positive 1']),
                InputExample(texts=['Anchor 2', 'Positive 2'])]
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=32)
train_loss = losses.MultipleNegativesRankingLoss(model=model)

你可能感兴趣的:(深度学习,pytorch,人工智能)