sentence-transformers(SBert)中文文本相似度预测(附代码)

sentence-transformers(SBert)中文文本相似度预测(附代码)_第1张图片

前言

  • 训练文本相似度数据集并进行评估:sentence-transformers(SBert)
  • 预训练模型:chinese-roberta-wwm-ext
  • 数据集:蚂蚁金融文本相似度数据集
  • 前端:Vue2+elementui+axios
  • 后端:flask

训练模型

  1. 创建网络:使用Sbert官方给出的预训练模型sentence_hfl_chinese-roberta-wwm-ext,先载入embedding层进行分词,再载入池化层并传入嵌入后的维度,对模型进行降维压缩,最后载入密集层,选择Than激活函数,输出维度大小为256维。
  2. 获取训练数据:构建出新模型后使用InputExample类存储训练数据,它接受文本对字符串列表和用于指示语义相似性的标签,用标准的Pytorch Dataloader包装train_examples,作用是打乱数据并生成特定大小的批次。
  3. 计算损失函数:对于每个句子对,通过网络传递句子A和句子B,从而产生嵌入u和v,使用余弦相似度计算相似性,并将结果与标准相似度得分进行比较。这样网络就能够进行微调,更好地识别句子的相似性。
  4. 模型调优:通过调用model.fit()来调优模型。向model.fit()中传递train_objective列表(由元组(dataloader, loss_function))组成。也可以传递多个元组,以便在具有不同损失函数的多个数据集上执行多任务学习。在训练过程需要使用sentence_transformers.evaluation评估表现是否有所改善,它包含各种可以传递给fit方法的evaluators。Evaluators会在训练期间定期运行,并且会返回分数,只有得分最高的模型才会存储在磁盘上。

首先运行preprocess.py获取数据,并划分训练集和测试集,之后运行train_sentence_bert.py,使用预训练模型, sbert将数据集用sbert训练相似度任务,得到训练好的模型,最后运行evaluate.py评估训练好的模型,将结果保存在predict.txt中,并输出预测结果。

这部分在详细代码里注释得很全。

后端部分

使用flask编写post接口,接收的数据格式为application/json,将前端传来的两个句子使用训练好的模型对其进行相似度预测,将得到的相似度类型从无法序列化存入json的tensor转成list,并将状态码,信息,数据返回给前端。

from sentence_transformers import SentenceTransformer, util
# 后端接口
from flask import Flask, jsonify, request
import re
# 用当前脚本名称实例化Flask对象,方便flask从该脚本文件中获取需要的内容
app = Flask(__name__)
# 使通过jsonify返回的中文显示正常,否则显示为ASCII码
app.config["JSON_AS_ASCII"] = False
model_path = 'D:/xxx模型路径/'
model = SentenceTransformer(model_path)
@app.route("/evaluate",methods=['POST'])
def evalute_sentence():
    s1 = request.json.get("s1")
    s2 = request.json.get("s2")
    if s1 and s2:
        embedding1 = model.encode(s1, convert_to_tensor=True)
        embedding2 = model.encode(s2, convert_to_tensor=True)
        similarity = util.cos_sim(embedding1, embedding2).tolist()
        return jsonify({"code": 200, "msg": "预测成功", "data": similarity})
    else:
        return jsonify({"code": 400, "msg": "缺少字段"})
if __name__ == '__main__':
    app.run(debug=True)

前端部分

框架使用Vue2,UI框架使用elementui。组件校验用户输入的表单(内容为中文,字数限制32个字,两个句子不为空),只有符合规则的字段才能提交表单。将数据通过Axios调用接口传递给后端,再根据后端接口响应状态进行相应的处理,如果返回状态码200,说明接口调用成功,展示返回的预测值,否则调用失败,页面弹出失败消息提示。

<template>
  <div class="recommend">
    <el-card class="box">
      <h2 class="title">中文文本相似度预测</h2>
      <el-form :model="evaluateForm" :rules="evaluateRules" ref="evaluateForm" class="form">
        <el-form-item prop="s1">
          <el-input
            placeholder="请输入句子一"
            maxlength="32"
            show-word-limit
            v-model="evaluateForm.s1"
            autocomplete="false"
            prefix-icon="el-icon-edit-outline"
          ></el-input>
        </el-form-item>
        <el-form-item prop="s2">
          <el-input
            maxlength="32"
            placeholder="请输入句子二"
            v-model="evaluateForm.s2"
            show-word-limit
            autocomplete="false"
            prefix-icon="el-icon-edit-outline"
          ></el-input>
        </el-form-item>
        <el-form-item class="btn-container">
          <el-button
            type="primary"
            @click="submitForm('evaluateForm')"
            class="btn"
            id="queryButton"
          >开始预测</el-button>
        </el-form-item>
      </el-form>
      <div v-show="result" style="margin-top: 20px">
        <el-progress
          :text-inside="true"
          :stroke-width="26"
          :percentage="result*100 ? result*100 : 0"
          class="el-bg-inner-running"
        ></el-progress>
        <p>预测结果:{{result}}</p>
      </div>
    </el-card>
  </div>
</template>

<script>
import api from "@/api/index"
export default {
  data () {
    return {
      evaluateForm: {
        s1: "",
        s2: ""
      },
      evaluateRules: { // 评估表单校验规则
        s1: [
          { required: true, message: '请输入中文句子', trigger: 'blur', pattern: /^[\u4E00-\u9FA5]+$/ },
        ],
        s2: [
          { required: true, message: '请输入中文句子', trigger: 'blur', pattern: /^[\u4E00-\u9FA5]+$/ },
        ],
      },
      result: undefined,
    }
  },
  methods: {
    postEvaluate () { // 调用接口
      api.postEvaluate(this.evaluateForm)
        .then((res) => {
          if (!res) {
            return
          }
          console.log("res", res)
          if (res.data.code !== 200) {
            this.$message({
              message: "请求失败",
              type: "error"
            })
            return
          }
          let data = res.data.data[0]
          this.result = data[0]
          console.log("this.result", this.result)
          this.$message({
            message: "预测成功!",
            type: "success"
          })

        })
        .catch((error) => {
          this.$message.error('资源获取错误!')
        })
    },
    submitForm (formName) { // 提交表单
      this.$refs[formName].validate((valid) => {
        if (valid) {
          this.postEvaluate()
        } else {
          this.$message({
            message: "请按要求填写",
            type: "warning"
          })
          console.log('error in submit form')
          return false
        }
      })
      document.getElementById("queryButton").blur()
    },
  }

}
</script>

<style lang="scss" scoped>
.recommend {
  width: 100%;
  height: 100%;
  text-align: center;
  display: flex;
  text-align: center;
  flex-direction: column;
  align-items: center;
  justify-content: center;
  overflow: hidden;
  background: #00416a 0 / cover fixed; /* fallback for old browsers */
  background: -webkit-linear-gradient(
    to right,
    #00416a,
    #e4e5e6
  ); /* Chrome 10-25, Safari 5.1-6 */
  background: linear-gradient(
    to right,
    #00416a,
    #e4e5e6
  ); /* W3C, IE 10+/ Edge, Firefox 16+, Chrome 26+, Opera 12+, Safari 7+ */
  .box {
    width: 48%;
    height: 60%;
    position: relative;
    background: hsla(0, 0%, 100%, 0.3);
    z-index: 5;
    padding: 10px 20px;
    // display: flex;
    // flex-direction: column;
    // justify-content: center;
    box-sizing: border-box;
    &::before {
      content: '';
      position: absolute;
      top: 0;
      right: 0;
      bottom: 0;
      left: 0;
      filter: blur(20px);
    }
    .title {
      color: #143b54;
    }
    .btn-container {
      margin: 10px auto;
      .btn {
        width: 100%;
        border-radius: 20px;
      }
    }
  }
}
::v-deep .el-card {
  border: 0;
  box-shadow: 0 5px 16px 0 rgb(0 0 0 / 30%);
}
::v-deep .el-progress-bar__outer {
  border: 0;
  background-color: transparent;
  // background-color: #abcbe0;
}
::v-deep .el-bg-inner-running .el-progress-bar__inner {
  background: #9cecfb; /* fallback for old browsers */
  background: -webkit-linear-gradient(
    to left,
    #0052d4,
    #65c7f7,
    #9cecfb
  ); /* Chrome 10-25, Safari 5.1-6 */
  background: linear-gradient(
    to left,
    #0052d4,
    #65c7f7,
    #9cecfb
  ); /* W3C, IE 10+/ Edge, Firefox 16+, Chrome 26+, Opera 12+, Safari 7+ */
}
</style>

预训练模型比较

paraphrase-multilingual-MiniLM-L12-v2
参数设置:epochs=1,batch_size=16
特点:作为sbert官方多语言预训练模型,已带有BERT层和池化层,可直接用数据评估,但未经纯中文文本训练,准确率较低

sentence-transformers(SBert)中文文本相似度预测(附代码)_第2张图片

chinese-electra-180g-small-discriminator
参数设置:epochs=1, batch_size=16
特点:运行时间快,准确率尚可

sentence-transformers(SBert)中文文本相似度预测(附代码)_第3张图片

chinese-electra-180g-small-discriminator
参数设置:epochs=20, batch_size=16
特点:20次迭代比1次迭代有效果,但差别不大

sentence-transformers(SBert)中文文本相似度预测(附代码)_第4张图片

chinese-electra-180g-small-discriminator
参数设置:epochs=1,batch_size=8
特点:比batch_size=16时效果更好

sentence-transformers(SBert)中文文本相似度预测(附代码)_第5张图片

chinese-roberta-wwm-ext
参数设置:epochs=1,batch_size=8
特点:迭代1次和20次准确率无差别,稳定且效果在所有模型中最好,缺点是体积大运行速度慢

sentence-transformers(SBert)中文文本相似度预测(附代码)_第6张图片

最后

代码已上传至sbert中文文本相似度预测,欢迎star!

你可能感兴趣的:(课设,python,vue,vue.js,文本相似度,sbert)