llm本身是支持在终端流式输出的,以ollama为例
llm = Ollama(base_url="http://localhost:11434",
model="qwen",
callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]),
)
查看源码StreamingStdOutCallbackHandler类中的函数
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
"""Run on new LLM token. Only available when streaming is enabled."""
sys.stdout.write(token)
sys.stdout.flush()
sys.stdout.write(token)表示在终端以此输出每一个token,所以思路是可以将每一个token保存在一个对象中,然后在页面中以此输出。
class ChainStreamHandler(StreamingStdOutCallbackHandler):
def __init__(self):
self.tokens = []
self.str = ''
# 记得结束后这里置true
self.finish = False
def on_llm_new_token(self, token: str, **kwargs):
print(token)
self.str +=token
self.tokens.append(token)
def on_llm_end(self, response: LLMResult, **kwargs: any) -> None:
self.finish = 1
def on_llm_error(self, error: Exception, **kwargs: any) -> None:
print(str(error))
self.tokens.append(str(error))
def generate_tokens(self):
while not self.finish or self.tokens:
if self.tokens:
data = self.tokens.pop(0)
yield data
else:
pass
llm = Ollama(base_url="http://localhost:11434",
model=st.session_state.llm,
callback_manager=CallbackManager([chainStreamHandler]))
重载StreamingStdOutCallbackHandler类,保存token定义generate_tokens函数,返回值对象为生成器对象。可以用streamlit库的write_stream函数流式输出到页面。
def async_thread(func, *args, **kwargs):
thread = threading.Thread(target=func, args=args, kwargs=kwargs)
thread.start()
async_thread(chain, {"input_documents": docs, "question": user_input, "chat_history": st.session_state.chat_history}, return_only_outputs=True)
st.write_stream(chainStreamHandler.generate_tokens())
值得一提的是需要用async_thread异步函数执行,否则会先输出到终端再输出到页面。我想这也许会增加占用资源,但目前也没有得到更好的处理方法。
使用ollama自带函数,llm.stream(),返回一个生成器对象,直接用streamlit库的write_stream()流式输出到界面,但是没法运用知识库,所以我还是选择手动流式输出