流式传输让大模型的API接入体验更加流畅

一、什么是流式传输?

        流式传输(Streaming)是一种在计算机科学中用于数据传输的概念。它指的是将数据以连续的、流式的方式传输,而不是一次性传输完整的数据块。在流式传输中,数据被分割成小的数据块(或者字节流),并且这些小数据块被持续地传送到接收端,接收端可以立即处理这些数据,而不需要等待整个数据传输完成。

        流式传输的优势在于它能够提供更快的响应时间和更低的延迟。相比较于一次性传输完整数据,流式传输允许接收端立即处理已经传输的部分数据,从而提供更好的用户体验。这种方式尤其在实时应用、多媒体传输(如音频、视频)、网络直播等场景中非常有用。

        在网络编程中,流式传输也常常用于大文件的传输,可以边读取文件,边传输文件的内容,而不需要等待整个文件被读取到内存中。

        用Python语言的WebSocket接入大模型的API即可实现流式传输,不需要等待大模型将全部内容推理完才一次性显示,让使用者能够看到AI逐行(逐个字)生成回答,避免推理时间较长时,给用户一种卡顿、假死的感觉。

二、讯飞星火大模型的示例

import _thread as thread
import base64
import datetime
import hashlib
import hmac
import json
from urllib.parse import urlparse
import ssl
from datetime import datetime
from time import mktime
from urllib.parse import urlencode
from wsgiref.handlers import format_date_time

import websocket  # 使用websocket_client
answer = ""


class Ws_Param(object):
    # 初始化
    def __init__(self, APPID, APIKey, APISecret, Spark_url):
        self.APPID = APPID
        self.APIKey = APIKey
        self.APISecret = APISecret
        self.host = urlparse(Spark_url).netloc
        self.path = urlparse(Spark_url).path
        self.Spark_url = Spark_url
        self.Streaming = True

    # 生成url
    def create_url(self):
        # 生成RFC1123格式的时间戳
        now = datetime.now()
        date = format_date_time(mktime(now.timetuple()))

        # 拼接字符串
        signature_origin = "host: " + self.host + "\n"
        signature_origin += "date: " + date + "\n"
        signature_origin += "GET " + self.path + " HTTP/1.1"

        # 进行hmac-sha256进行加密
        signature_sha = hmac.new(self.APISecret.encode('utf-8'), signature_origin.encode('utf-8'),
                                 digestmod=hashlib.sha256).digest()

        signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding='utf-8')

        authorization_origin = f'api_key="{self.APIKey}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha_base64}"'

        authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')

        # 将请求的鉴权参数组合为字典
        v = {
            "authorization": authorization,
            "date": date,
            "host": self.host
        }
        # 拼接鉴权参数,生成url
        url = self.Spark_url + '?' + urlencode(v)
        # 此处打印出建立连接时候的url,参考本demo的时候可取消上方打印的注释,比对相同参数时生成的url与自己代码生成的url是否一致
        return url


# 收到websocket错误的处理
def on_error(ws, error):
    print("### error:", error)


# 收到websocket关闭的处理
def on_close(ws,one,two):
    print(" ")


# 收到websocket连接建立的处理
def on_open(ws):
    thread.start_new_thread(run, (ws,))


def run(ws, *args):
    data = json.dumps(gen_params(appid=ws.appid, domain=ws.domain,question=ws.question))
    ws.send(data)


# 收到websocket消息的处理
def on_message(ws, message):
    # print(message)
    data = json.loads(message)
    code = data['header']['code']
    if code != 0:
        print(f'请求错误: {code}, {data}')
        ws.close()
    else:
        choices = data["payload"]["choices"]
        status = choices["status"]
        content = choices["text"][0]["content"]
        print(content, end="")
        global answer
        answer += content
        # print(1)
        if status == 2:
            ws.close()

# 当WebSocket连接接收到消息时,on_message 函数被调用,并将接收到的消息作为参数传递给 message 变量。
#
# json.loads(message) 用于将收到的消息解析为Python字典,使得你可以方便地处理消息的内容。
#
# choices = data["payload"]["choices"] 从解析后的消息中获取到了 "choices" 部分的数据。
#
# status = choices["status"] 获取到了 "status" 字段的值。
#
# content = choices["text"][0]["content"] 获取到了 "text" 字段中第一个元素的 "content" 字段的值。
#
# print(content, end="") 将获取到的内容打印出来。end="" 参数的作用是在打印结束时不换行,这样可以保持在同一行上逐个字地打印。
#
# global answer 语句表明 answer 是一个全局变量,answer += content 将获取到的内容逐个字地添加到 answer 变量中。
#
# if status == 2: 检查 "status" 的值是否为2,如果是,表示消息接收完毕,此时调用 ws.close() 关闭WebSocket连接。
#
# 通过这样的处理,代码在控制台上逐个字地打印服务器返回的信息,并且保持在同一行上,实现了逐个字显示的效果。


def gen_params(appid, domain,question):
    """
    通过appid和用户的提问来生成请参数
    """
    data = {
        "header": {
            "app_id": appid,
            "uid": "1234"
        },
        "parameter": {
            "chat": {
                "domain": domain,
                "random_threshold": 0.5,
                "max_tokens": 2048,
                "auditing": "default"
            }
        },
        "payload": {
            "message": {
                "text": question
            }
        }
    }
    return data


def main(appid, api_key, api_secret, Spark_url,domain, question):
    # print("星火:")
    wsParam = Ws_Param(appid, api_key, api_secret, Spark_url)
    websocket.enableTrace(False)
    # websocket.enableTrace(False)的作用是禁用WebSocket库的调试日志输出,使得在正式的生产环境中运行
    # 时,日志信息更加简洁,不受WebSocket库的详细调试信息干扰。
    wsUrl = wsParam.create_url()
    ws = websocket.WebSocketApp(wsUrl, on_message=on_message, on_error=on_error, on_close=on_close, on_open=on_open)
    # 通过将这些回调函数传递给WebSocketApp对象,你可以在不同的连接事件发生时执行相应的操作,
    # 使得你的WebSocket客户端能够处理消息、错误、连接关闭等情况,
    # 从而实现更复杂的WebSocket应用逻辑。
    ws.appid = appid
    ws.question = question
    ws.domain = domain
    ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
    # ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE}) 的作用是启动 WebSocketApp 对象的事件循环,
    # 使得 WebSocket 客户端能够持续运行并处理来自服务器的消息,直到连接关闭或发生错误为止。
    #
    # 具体来说,run_forever() 函数会一直运行,直到 WebSocket 连接被关闭或者发生错误。在这个过程中,
    # 它会监听来自服务器的消息,并根据注册的回调函数
    # (例如 on_message、on_error、on_close 等)来处理不同的事件。
    #
    # 另外,sslopt={"cert_reqs": ssl.CERT_NONE} 是一个可选参数,用于指定 SSL/TLS 连接的选项。
    # 在这个例子中,{"cert_reqs": ssl.CERT_NONE} 表示在建立 SSL/TLS 连接时不验证服务器的 SSL 证书。
    # 这通常用于测试或者在开发过程中,但在生产环境中,为了安全起见,
    # 建议使用正确的 SSL/TLS 证书并进行验证,以确保通信的安全性。


你可能感兴趣的:(python,开发语言)