python 代码
async def knowledge_base_chat(
        query: str = Body(..., description="用户输入", examples=["你好"]),
        knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]),
        top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"),
        # NOTE(review): the description says the range is 0-1, but validation
        # accepts up to 2 (le=2). Confirm which range the scores actually use.
        score_threshold: float = Body(
            SCORE_THRESHOLD,
            description="知识库匹配相关度阈值,取值范围在0-1之间,SCORE越小,相关度越高,取到1相当于不筛选,建议设置在0.5左右",
            ge=0,
            le=2
        ),
        history: List[History] = Body(
            [],
            description="历史对话",
            examples=[[
                {"role": "user",
                 "content": "我们来玩成语接龙,我先来,生龙活虎"},
                {"role": "assistant",
                 "content": "虎头虎脑"}]]
        ),
        stream: bool = Body(False, description="流式输出"),
        model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
        temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
        max_tokens: Optional[int] = Body(
            None,
            description="限制LLM生成Token数量,默认None代表模型最大值"
        ),
        prompt_name: str = Body(
            "default",
            description="使用的prompt模板名称(在configs/prompt_config.py中配置)"
        ),
        request: Request = None,
):
    """Answer ``query`` using documents retrieved from a knowledge base.

    Looks up the knowledge base by name, retrieves up to ``top_k`` matching
    documents, fills the ``knowledge_base_chat`` prompt template (the
    ``"empty"`` template when nothing was retrieved) and runs it through the
    selected LLM together with ``history``.

    Returns:
        ``BaseResponse(code=404)`` if the knowledge base does not exist.
        Otherwise a ``StreamingResponse`` whose chunks are JSON strings:

        * ``stream=True``: one ``{"answer": token}`` per generated token,
          then a final ``{"docs": [...]}`` with the de-duplicated sources.
        * ``stream=False``: a single ``{"answer": ..., "docs": [...]}``.
    """
    kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
    if kb is None:
        return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")

    history = [History.from_data(h) for h in history]

    async def knowledge_base_chat_iterator(
            query: str,
            top_k: int,
            history: Optional[List[History]],
            model_name: str = LLM_MODELS[0],
            prompt_name: str = prompt_name,
    ) -> AsyncIterable[str]:
        """Yield the JSON-encoded answer (token by token or whole) and its sources."""
        nonlocal max_tokens
        callback = AsyncIteratorCallbackHandler()
        # A non-positive token limit is treated as "no limit".
        if isinstance(max_tokens, int) and max_tokens <= 0:
            max_tokens = None

        model = get_ChatOpenAI(
            model_name=model_name,
            temperature=temperature,
            max_tokens=max_tokens,
            callbacks=[callback],
        )
        docs = search_docs(query, knowledge_base_name, top_k, score_threshold)
        context = "\n".join(doc.page_content for doc in docs)

        # Fall back to the "empty" template when retrieval found nothing.
        template_name = prompt_name if docs else "empty"
        prompt_template = get_prompt_template("knowledge_base_chat", template_name)
        input_msg = History(role="user", content=prompt_template).to_msg_template(False)
        chat_prompt = ChatPromptTemplate.from_messages(
            [h.to_msg_template() for h in history] + [input_msg])
        chain = LLMChain(prompt=chat_prompt, llm=model)

        # Run the chain as a background task; tokens arrive through `callback`.
        task = asyncio.create_task(wrap_done(
            chain.acall({"context": context, "question": query}),
            callback.done),
        )

        # De-duplicate source snippets while keeping retrieval order.
        # NOTE(review): rstrip() strips any trailing characters from the set
        # {' ', '[', '链', '接', ']', ':', '\n'}, not the literal suffix
        # ' [链接]:\n'. Confirm this is intended.
        source_documents = list(dict.fromkeys(
            doc.page_content.rstrip(' [链接]:\n') + '\n' for doc in docs
        ))

        try:
            if stream:
                # NOTE(review): the response is declared text/event-stream, but
                # chunks are bare JSON objects with no "data:" framing or
                # delimiter. Clients must not rely on chunk boundaries.
                async for token in callback.aiter():
                    yield json.dumps({"answer": token}, ensure_ascii=False)
                yield json.dumps({"docs": source_documents}, ensure_ascii=False)
            else:
                answer = "".join([token async for token in callback.aiter()])
                yield json.dumps({"answer": answer,
                                  "docs": source_documents},
                                 ensure_ascii=False)
            await task
        finally:
            # If the client disconnected or iteration stopped early, don't let
            # the LLM call keep running in the background.
            if not task.done():
                task.cancel()

    return StreamingResponse(knowledge_base_chat_iterator(query=query,
                                                          top_k=top_k,
                                                          history=history,
                                                          model_name=model_name,
                                                          prompt_name=prompt_name),
                             media_type="text/event-stream")
vue.js 代码
<template>
  <div>
    <h2>Streamed Responses:</h2>
    <!-- Render each entry of `messages` in its own row, keyed by position. -->
    <div v-for="(chunk, i) in messages" :key="i">{{ chunk }}</div>
  </div>
</template>
<script>
export default {
  data() {
    return {
      messages: [],
      reader: null, // reader for the response body stream
      controller: null, // AbortController for the in-flight request
    };
  },
  mounted() {
    this.postDataAndStreamResponse();
  },
  beforeUnmount() {
    // Abort the request even if the body reader has not been obtained yet.
    if (this.controller) {
      this.controller.abort();
    }
    if (this.reader) {
      // After abort() the stream is already errored and cancel() rejects; ignore that.
      this.reader.cancel().catch(() => {});
    }
  },
  methods: {
    async postDataAndStreamResponse() {
      this.controller = new AbortController();
      try {
        const response = await fetch('http://XXXXX/chat/knowledge_base_chat', {
          method: 'POST',
          headers: {
            'Content-Type': 'application/json',
            'Accept': 'text/event-stream',
          },
          body: JSON.stringify({
            "query": "怎么打官司",
            "knowledge_base_name": "samples",
            "top_k": 5,
            "score_threshold": 0.5,
            "history": [{
              "role": "user",
              "content": "我们来玩成语接龙,我先来,生龙活虎"
            }, {
              "role": "assistant",
              "content": "虎头虎脑"
            }],
            "stream": true,
            "model_name": "qwen-api",
            "temperature": 0.5,
            "max_tokens": 1000,
            "prompt_name": "default"
          }),
          signal: this.controller.signal,
        });
        // Don't stream an HTTP error page (e.g. a 422 validation error) as chat output.
        if (!response.ok || !response.body) {
          throw new Error(`HTTP ${response.status} ${response.statusText}`);
        }
        this.reader = response.body.getReader();
        await this.readStream();
      } catch (error) {
        if (error && error.name === 'AbortError') return; // cancelled on unmount
        console.error('Stream fetch error:', error);
        // User-facing error handling can be added here.
      }
    },
    async readStream() {
      const decoder = new TextDecoder();
      try {
        while (true) { // iterate instead of recursing
          const { value, done } = await this.reader.read();
          if (done) break; // no more data
          // stream: true keeps partial multi-byte characters buffered between chunks.
          const text = decoder.decode(value, { stream: true });
          if (text) this.messages.push(text);
        }
        // Flush any bytes still buffered by the decoder at end of stream.
        const rest = decoder.decode();
        if (rest) this.messages.push(rest);
      } catch (error) {
        if (error && error.name === 'AbortError') return; // cancelled on unmount
        console.error('Stream read error:', error);
        // User-facing error handling can be added here.
      }
    },
  },
};
</script>