ChatGLM系列六:基于知识库的问答

(配图 1:ChatGLM 系列六 —— 基于知识库的问答)

1、安装 Milvus

下载 milvus-standalone-docker-compose.yml 并保存为 docker-compose.yml:

wget https://github.com/milvus-io/milvus/releases/download/v2.3.2/milvus-standalone-docker-compose.yml -O docker-compose.yml

启动 Milvus:

sudo docker-compose up -d

2、文档预处理(保存为 document_preprocess.py,后面的检索代码会从中导入 get_vector)

import os
import re
import jieba
import torch
import pandas as pd
from pymilvus import utility
from pymilvus import connections, CollectionSchema, FieldSchema, Collection, DataType
from transformers import AutoTokenizer, AutoModel

# Open the "default" Milvus connection used implicitly by every
# Collection(...) and utility.* call in this module.
connections.connect(alias="default", host="localhost", port="19530")

# Collection name, vector width, and the folder scanned for .txt documents.
# The width must equal the length of the vectors produced by get_vector()
# (the bert-base-chinese hidden size).
collection_name, dimension, docs_folder = "document", 768, "./knowledge/"

# Chinese BERT tokenizer/encoder pair used by get_vector() for both
# documents and queries.
tokenizer = AutoTokenizer.from_pretrained("bert-base-chinese")
model = AutoModel.from_pretrained("bert-base-chinese")


# 获取文本的向量
def get_vector(text):
    """Embed *text* with BERT and return its position-0 ([CLS]) vector.

    Tokenization uses truncation, so only the leading tokens of a long
    text contribute. The forward pass runs without gradient tracking. The
    hidden state at sequence position 0 of the single batch row is
    returned as a plain Python list of floats. Its length is expected to
    match ``dimension``.
    """
    encoded = tokenizer(text, padding=True, truncation=True, return_tensors="pt")
    token_ids = encoded["input_ids"]
    with torch.no_grad():
        hidden_states = model(token_ids)[0]
        cls_batch = hidden_states[:, 0, :].numpy()
    return cls_batch.tolist()[0]


def create_collection():
    """(Re)create the Milvus collection that stores the knowledge documents.

    Any existing collection named ``collection_name`` is dropped first, so
    calling this wipes previously imported documents. To append to an
    existing collection instead, return early when it already exists.

    Schema: an auto-generated INT64 primary key ``id``, VARCHAR ``title``
    (max_length 50) and ``content`` (max_length 10000), and a
    ``dimension``-wide FLOAT_VECTOR ``vector``. The vector field is indexed
    with IVF_FLAT under the inner-product ("IP") metric.
    """
    fields = [
        FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True, description="primary id"),
        FieldSchema(name="title", dtype=DataType.VARCHAR, max_length=50),
        FieldSchema(name="content", dtype=DataType.VARCHAR, max_length=10000),
        FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=dimension),
    ]
    schema = CollectionSchema(fields=fields, description="collection schema")

    if utility.has_collection(collection_name):
        # To keep adding documents to the existing collection, `return` here
        # instead; dropping it rebuilds the collection from scratch.
        utility.drop_collection(collection_name)

    # Creation and indexing are identical whether or not a collection
    # existed before, so they are done once here.
    collection = Collection(name=collection_name, schema=schema, using='default', shards_num=2)
    # Searches must use the same "IP" metric as this index.
    default_index = {"index_type": "IVF_FLAT", "params": {"nlist": 2048}, "metric_type": "IP"}
    collection.create_index(field_name="vector", index_params=default_index)
    print(f"Collection {collection_name} created successfully")


def init_knowledge():
    """Embed every ``.txt`` file under ``docs_folder`` and insert it into Milvus.

    Each file, read as UTF-8, becomes one row:

    * ``title``   - the file name without its extension;
    * ``content`` - the text with whitespace runs collapsed to one space,
      split with jieba and re-joined with single spaces;
    * ``vector``  - ``get_vector(title + content)``.

    When no ``.txt`` file is found, a message is printed and nothing is
    inserted. An empty DataFrame has no columns to match the collection
    schema.

    NOTE(review): ``get_vector`` truncates its input, so only the start of
    a long document is embedded. Contents longer than the ``content``
    field's max_length (or titles over ``title``'s) will be rejected by
    Milvus on insert. Consider chunking documents if that happens.
    """
    collection = Collection(collection_name)
    docs = []
    for root, _dirs, files in os.walk(docs_folder):
        for file in files:
            # Only plain-text documents are imported.
            if not file.endswith(".txt"):
                continue
            file_path = os.path.join(root, file)
            with open(file_path, "r", encoding="utf-8") as f:
                content = f.read()
            # Collapse every whitespace run (newlines, tabs, ...) into one space.
            content = re.sub(r"\s+", " ", content)
            title = os.path.splitext(file)[0]
            # Word-segment with jieba and re-join the tokens with spaces.
            content = " ".join(jieba.lcut(content))
            vector = get_vector(title + content)
            docs.append({"title": title, "content": content, "vector": vector})

    if not docs:
        print(f"No .txt documents found under {docs_folder}; nothing inserted")
        return

    # Insert text and vectors together; column names match the schema fields
    # (the auto_id primary key is generated by Milvus).
    df = pd.DataFrame(docs)
    collection.insert(df)
    print("Documents inserted successfully")


if __name__ == "__main__":
    # Rebuild the collection (schema + index), then import ./knowledge/*.txt.
    for step in (create_collection, init_knowledge):
        step()

3、知识库匹配(保存为 knowledge_query.py)

通过向量索引库检索出与问题最相似的文档。

import torch
from document_preprocess import get_vector
from pymilvus import Collection

# Handle to the existing "document" collection. It is loaded once at
# import time so that search/query calls below can run.
collection = Collection("document")
collection.load()

# Preferred torch device: CUDA, then Apple MPS, then CPU.
# NOTE(review): DEVICE is not used anywhere in this module.
if torch.cuda.is_available():
    DEVICE = "cuda"
elif torch.backends.mps.is_available():
    DEVICE = "mps"
else:
    DEVICE = "cpu"


# 定义查询函数
def search_similar_text(input_text, top_k=3):
    """Return the knowledge-base documents most similar to *input_text*.

    The text is embedded with ``get_vector`` and searched against the
    ``vector`` field using the inner-product ("IP") metric. The matching
    rows are then fetched by primary key and printed.

    Args:
        input_text: The user's question.
        top_k: How many documents to return (default 3).

    Returns:
        A list of dicts with ``id``, ``content`` and ``title`` keys, ordered
        from most to least similar. Empty when the search finds nothing.
    """
    input_vector = get_vector(input_text)
    # Vector search: ids of the top_k closest documents, best match first.
    similarity = collection.search(
        data=[input_vector],
        anns_field="vector",
        param={"metric_type": "IP", "params": {"nprobe": 10}, "offset": 0},
        limit=top_k,
        expr=None,
        consistency_level="Strong"
    )
    ids = list(similarity[0].ids)
    if not ids:
        # Nothing matched; an `id in []` query would be pointless.
        print([])
        return []

    # Fetch the stored fields of those documents by id.
    res = collection.query(
        expr=f"id in {ids}",
        offset=0,
        limit=len(ids),
        output_fields=["id", "content", "title"],
        consistency_level="Strong"
    )
    # query() is not documented to return rows in the order of `ids`, so
    # restore the similarity ranking explicitly.
    rank = {pk: pos for pos, pk in enumerate(ids)}
    res = sorted(res, key=lambda row: rank.get(row["id"], len(rank)))
    print(res)
    return res


if __name__ == "__main__":
    # Interactive check: print the documents matched for one question.
    # (Both lines use spaces; the original mixed a tab and spaces -> TabError.)
    question = input('Please enter your question: ')
    search_similar_text(question)

4、完成回答

将检索到的文档拼接进提示词,交给 ChatGLM 生成答案。下面先给出核心的 predict 函数,随后是基于 Gradio 的完整 Web 界面代码。

from transformers import AutoModel, AutoTokenizer
from knowledge_query import search_similar_text


# ChatGLM-6B INT4 checkpoint: tokenizer plus model cast to fp16, moved to
# the GPU and switched to eval mode. trust_remote_code lets transformers run
# the model code that ships with the checkpoint.
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b-int4", trust_remote_code=True)
model = (
    AutoModel.from_pretrained("THUDM/chatglm-6b-int4", trust_remote_code=True)
    .half()
    .cuda()
    .eval()
)


def predict(input, max_length=2048, top_p=0.7, temperature=0.95, history=None, chatbot=None):
    """Stream ChatGLM's answer to *input*, grounded on matching documents.

    The documents returned by ``search_similar_text`` are pasted into a
    prompt that tells the model to answer only from that material (in
    Chinese, without fabrication). The answer is streamed from
    ``model.stream_chat``.

    Args:
        input: The user's question.
        max_length: Generation setting forwarded to ``model.stream_chat``.
        top_p: Generation setting forwarded to ``model.stream_chat``.
        temperature: Generation setting forwarded to ``model.stream_chat``.
        history: Prior conversation turns for ``model.stream_chat``; a fresh
            list is used when omitted.
        chatbot: List of ``(question, answer)`` pairs to append this exchange
            to; a fresh list is used when omitted.

    Yields:
        ``(chatbot, history)`` after every streamed chunk. The last pair in
        ``chatbot`` holds the raw question and the answer generated so far.
    """
    # Fresh per-call lists: a mutable default would be shared between calls.
    history = [] if history is None else history
    chatbot = [] if chatbot is None else chatbot
    chatbot.append((input, ""))

    res = search_similar_text(input)
    prompt_template = f"""基于以下已知信息,简洁和专业的来回答用户的问题。
如果无法从中得到答案,请说 "当前会话仅支持解决一个类型的问题,请清空历史信息重试",不允许在答案中添加编造成分,答案请使用中文。

已知内容:
{res}

问题:
{input}
"""
    query = prompt_template
    for response, history in model.stream_chat(tokenizer, query, history, max_length=max_length, top_p=top_p,
                                               temperature=temperature):
        chatbot[-1] = (input, response)
        yield chatbot, history

from transformers import AutoModel, AutoTokenizer
import gradio as gr
import mdtex2html

from knowledge_query import search_similar_text

# ChatGLM-6B INT4 checkpoint, run in fp16 on the GPU in eval mode.
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b-int4", trust_remote_code=True)
model = AutoModel.from_pretrained("THUDM/chatglm-6b-int4", trust_remote_code=True).half().cuda().eval()

# True while the next question should be answered from the knowledge base;
# predict() clears it after the first knowledge-grounded question.
is_knowledge = True

"""Override Chatbot.postprocess"""


def postprocess(self, y):
    """Render each chat ``(message, response)`` pair to HTML.

    Used as a replacement for ``gr.Chatbot.postprocess``. Every non-``None``
    entry of each pair in *y* is passed through ``mdtex2html.convert``;
    ``None`` entries stay ``None``. The list is updated in place and
    returned. A ``None`` *y* yields an empty list.
    """
    if y is None:
        return []
    for idx, pair in enumerate(y):
        message, response = pair
        y[idx] = tuple(None if part is None else mdtex2html.convert(part) for part in (message, response))
    return y


# Monkey-patch Gradio so every Chatbot renders messages via postprocess().
setattr(gr.Chatbot, "postprocess", postprocess)


# Characters that Markdown/HTML would interpret inside a fenced code block,
# mapped to replacements that display them literally. Applied in one pass
# by parse_text().
_CODE_ESCAPES = str.maketrans({
    "`": "\\`",
    "<": "&lt;",
    ">": "&gt;",
    " ": "&nbsp;",
    "*": "&ast;",
    "_": "&lowbar;",
    "-": "&#45;",
    ".": "&#46;",
    "!": "&#33;",
    "(": "&#40;",
    ")": "&#41;",
    "$": "&#36;",
})


def parse_text(text):
    """Convert chat text into HTML for the Gradio chatbot.

    copy from https://github.com/GaiZhenbiao/ChuanhuChatGPT/

    Empty lines are dropped. Each line containing a ``` fence alternately
    opens a ``<pre><code class="language-...">`` block, with the text after
    the fence as the language, or closes it. Inside a code block, characters
    that Markdown/HTML would interpret are replaced with entities. Every
    non-fence line after the first is prefixed with ``<br>``.
    """
    lines = [line for line in text.split("\n") if line != ""]
    count = 0  # fences seen so far; odd means we are inside a code block
    for i, line in enumerate(lines):
        if "```" in line:
            count += 1
            items = line.split('`')
            if count % 2 == 1:
                lines[i] = f'<pre><code class="language-{items[-1]}">'
            else:
                lines[i] = '<br></code></pre>'
        elif i > 0:
            if count % 2 == 1:
                line = line.translate(_CODE_ESCAPES)
            lines[i] = "<br>" + line
    return "".join(lines)


def predict(input, chatbot, max_length, top_p, temperature, history):
    """Stream ChatGLM's answer to *input* into the chatbot.

    While ``is_knowledge`` is set, the question is wrapped in a prompt that
    embeds the best-matching knowledge-base documents, and the flag is then
    cleared. Follow-up questions in the same session are sent as plain chat
    turns that rely on ``history``.

    Yields:
        ``(chatbot, history)`` after each streamed chunk.
    """
    global is_knowledge
    chatbot.append((parse_text(input), ""))
    query = input
    if is_knowledge:
        res = search_similar_text(input)
        prompt_template = f"""基于以下已知信息,简洁和专业的来回答用户的问题。
如果无法从中得到答案,请说 "当前会话仅支持解决一个类型的问题,请清空历史信息重试",不允许在答案中添加编造成分,答案请使用中文。

已知内容:
{res}

问题:
{input}
"""
        query = prompt_template
        is_knowledge = False
    for response, history in model.stream_chat(tokenizer, query, history, max_length=max_length, top_p=top_p,
                                               temperature=temperature):
        chatbot[-1] = (parse_text(input), parse_text(response))
        yield chatbot, history


def reset_user_input():
    """Return a Gradio update that clears the input textbox."""
    return gr.update(value='')


def reset_state():
    """Clear the conversation and re-enable knowledge-base lookup.

    Returns empty values for the chatbot and history outputs. Setting
    ``is_knowledge`` back to True routes the next question through the
    knowledge base again. The prompt's "clear the history and retry"
    instruction depends on this; the original set it to False, so lookup
    never came back after clearing.
    """
    global is_knowledge
    is_knowledge = True
    return [], []


with gr.Blocks() as demo:
    gr.HTML("""<h1 align="center">ChatGLM</h1>""")

    chatbot = gr.Chatbot()
    with gr.Row():
        with gr.Column(scale=4):
            with gr.Column(scale=12):
                # NOTE(review): `.style()` is Gradio 3.x API; Gradio 4 removed it
                # in favour of passing container=False to gr.Textbox directly.
                user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=10).style(
                    container=False)
            with gr.Column(min_width=32, scale=1):
                submitBtn = gr.Button("Submit", variant="primary")
        with gr.Column(scale=1):
            emptyBtn = gr.Button("Clear History")
            max_length = gr.Slider(0, 4096, value=2048, step=1.0, label="Maximum length", interactive=True)
            top_p = gr.Slider(0, 1, value=0.7, step=0.01, label="Top P", interactive=True)
            temperature = gr.Slider(0, 1, value=0.95, step=0.01, label="Temperature", interactive=True)

    history = gr.State([])

    submitBtn.click(predict, [user_input, chatbot, max_length, top_p, temperature, history], [chatbot, history],
                    show_progress=True)
    submitBtn.click(reset_user_input, [], [user_input])

    emptyBtn.click(reset_state, outputs=[chatbot, history], show_progress=True)

demo.queue().launch(share=False, inbrowser=True)

你可能感兴趣的:(自然语言处理)