关于springboot对接chatglm3-6b大模型的尝试

之前我们通过阿里提供的cloud ai对接了通义千问。cloud ai对接通义千问
那么接下来我们尝试一些别的模型看一下,其实这个文章主要是表达一种对接方式,其他的都大同小异。都可以依此方法进行处理。

一、明确模型参数

本次我们对接的理论支持来自于阿里云提供的文档。阿里云大3-6b模型文档
我们看到他其实支持多种调用方式,包括sdk和http,我本人是不喜欢sdk的,因为会有冲突或者版本之类的问题,不如直接调用三方,把问题都扔到三方侧。所以我们这里来展示一下使用http的调用方式。
而且大模型的chat一般都是流式的,非流式的没啥技术含量而且效果很low。所以我们直接参考这部分内容即可,
关于springboot对接chatglm3-6b大模型的尝试_第1张图片
我们看到他们的服务端其实是支持SSE的推流方式的,具体SSE是啥可以自行百度。
而流式和非流式的区别就在于请求参数的设置。如果你配置了,那大模型端就会给你按照流式响应。
关于springboot对接chatglm3-6b大模型的尝试_第2张图片
在有了以上理论支持之后,我们就来测试一下。

二、代码接入

我们看到他的示例请求参数为:

curl --location 'https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation' \
--header 'Authorization: Bearer ' \  # 这里写你的appkey
--header 'Content-Type: application/json' \
--header 'X-DashScope-SSE: enable' \   # 开启流式
--data '{
    "model": "chatglm3-6b", # 模型名字
    "input":{
        "messages":[      
       
            {
                "role": "user",
                "content": "你好,请介绍一下故宫"
            }
        ]
    },
    "parameters": {
        "result_format": "message"
    }
}'

所以我们可以找到关键点就在以上三处,至于如何申请appkey,可以参考官方。
那么我们接下来就使用okhttp这种支持事件响应的来对接流式的输出。

1、编写返回内容反序列化类

首先我们先来处理返回格式,我决定用一个java类来接受,具体你觉得不灵活可以直接用Json,怎么弄都行。
我们来看一下官网的响应示例格式。

{"output":{"choices":[{"message":{"content":"\n 故宫是中国北京市中心的一座明清两代的皇宫,现已成为博物馆。故宫是中国最具代表性的古建筑之一,也是世界文化遗产之一,以其丰富的文化遗产和精美的建筑艺术而闻名于世界。故宫占地面积达72万平方米,拥有9000多间房屋和70多座建筑,由大小湖泊、宫殿、花园和殿堂组成,是中国古代宫殿建筑之精华。","role":"assistant"},"finish_reason":"stop"}]},"usage":{"total_tokens":105,"input_tokens":24,"output_tokens":81},"request_id":"9d970376-4ba3-98b8-8387-f95702280341"}

我们看到他是个字符串,然后在流式的最后一句他的finish_reason的值是stop,这时候我们就可以结束推流。
OK,我们就来接收一下。

import lombok.Data;

@Data
public class Chatglm36bResponse {
    private Output output;
    private Usage usage;
    private String requestId;

    @Data
    public static class Output {
        private Choice[] choices;

        @Data
        public static class Choice {
            private Message message;
            private String finishReason;


            @Data
            public static class Message {
                private String content;
                private String role;
            }
        }
    }

    @Data
    public static class Usage {
        private int totalTokens;
        private int inputTokens;
        private int outputTokens;
    }
}

2、编写event事件监听器


import com.alibaba.fastjson.JSONObject;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.NoArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import okhttp3.Response;
import okhttp3.sse.EventSource;
import okhttp3.sse.EventSourceListener;

import java.io.IOException;

@EqualsAndHashCode(callSuper = true)
@Data
@Slf4j
@NoArgsConstructor
@AllArgsConstructor
public class ChatEventSourceListener extends EventSourceListener {

    private String clientId;

    @Override
    public void onOpen(EventSource eventSource, Response response) {
        log.info("ChatEventSourceListener onOpen invoke");
        super.onOpen(eventSource, response);
    }

    @Override
    public void onEvent(EventSource eventSource, String id, String type, String data) {
        log.info("ChatEventSourceListener onEvent invoke");
        Chatglm36bResponse chatglm36bResponse = JSONObject.parseObject(data, Chatglm36bResponse.class);
        Chatglm36bResponse.Output output = chatglm36bResponse.getOutput();
        Chatglm36bResponse.Output.Choice[] choices = output.getChoices();
        for (Chatglm36bResponse.Output.Choice choice : choices) {
            String finishReason = choice.getFinishReason();
            String content = choice.getMessage().getContent();
            log.info("ChatEventSourceListener onEvent finishReason is:{},content is:{}", finishReason, content);

            try {
            	// 给前端推流,前端有组件可以接收这种流。
                SseEmitterUtils.sendMsg(clientId, content);
            } catch (IOException e) {
                throw new RuntimeException(e);
            }
            // 结束了,取消事件,并且结束SSE推流
            if ("stop".equals(finishReason)) {
                eventSource.cancel();
                SseEmitterUtils.completeDelay(clientId);
            }
        }
        super.onEvent(eventSource, id, type, data);
    }

    @Override
    public void onClosed(EventSource eventSource) {
        log.info("ChatEventSourceListener onClosed invoke ******");
        super.onClosed(eventSource);
    }

    @Override
    public void onFailure(EventSource eventSource, Throwable t, Response response) {
        super.onFailure(eventSource, t, response);
        String message = response.message();
        response.close();
        log.info("ChatEventSourceListener onFailure invoke ****** Throwable is:{},res is {}", t.getMessage(),message);
    }
}

我们在每一类事件里面都做了相应的处理。
与之配套的是一个SSE的工具类。

package com.yxy.springbootdemo.utils.sse;


import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;

import java.io.IOException;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.*;

public class SseEmitterUtils {

    private static final Logger logger = LoggerFactory.getLogger(SseEmitterUtils.class);

    private static final ThreadPoolExecutor ssePool =  new ThreadPoolExecutor(
                20,
                200,
                30,
                TimeUnit.SECONDS,
                new LinkedBlockingQueue<>(1000),
                runnable -> new Thread(runnable, "sse-sendMsg-pool"),
                new ThreadPoolExecutor.AbortPolicy()
    );

    // SSE连接关闭延迟时间
    private static final Integer EMITTER_COMPLETE_DELAY_MILLISECONDS = 5000;

    // SSE连接初始化超时时间
    private static final Long EMITTER_TIME_OUT_MILLISECONDS = 600_000L;

    // 缓存 SSE连接
    private static final Map<String, SseEmitter> SSE_CACHE = new ConcurrentHashMap<>();

    /**
     * 获取 SSE连接 默认超时时间EMITTER_TIME_OUT_MILLISECONDS 毫秒
     *
     * @param clientId 客户端 ID
     * @return 连接对象
     */
    public static SseEmitter getConnection(String clientId) {
       return getConnection(clientId,EMITTER_TIME_OUT_MILLISECONDS);
    }

    /**
     * 获取 SSE连接
     *
     * @param clientId 客户端 ID
     * @param timeout  连接超时时间,单位毫秒
     * @return 连接对象
     */
    public static SseEmitter getConnection(String clientId,Long timeout) {
        final SseEmitter sseEmitter = SSE_CACHE.get(clientId);
        if (Objects.nonNull(sseEmitter)) {
            return sseEmitter;
        } else {
            final SseEmitter emitter = new SseEmitter(timeout);

            // 初始化emitter回调
            initSseEmitter(emitter, clientId);

            // 连接建立后,将连接放入缓存
            SSE_CACHE.put(clientId, emitter);
            logger.info("[SseEmitter] 连接已建立,clientId = {}", clientId);
            return emitter;
        }
    }

    /**
     * 关闭指定的流连接
     *
     * @param clientId 客户端 ID
     */
    public static void closeConnection(String clientId) {
        final SseEmitter sseEmitter = SSE_CACHE.get(clientId);
        logger.info("[流式响应-停止生成] 收到客户端关闭连接指令,Emitter is {},clientId = {}", null == sseEmitter ? "NOT-Exist" : "Exist", clientId);
        if (Objects.nonNull(sseEmitter)) {
            SSE_CACHE.remove(clientId);
            sseEmitter.complete();
        }
        try {
            TimeUnit.MILLISECONDS.sleep(EMITTER_COMPLETE_DELAY_MILLISECONDS);
        } catch (InterruptedException ex) {
            logger.error("流式响应异常", ex);
            Thread.currentThread().interrupt();
        }
    }

    /**
     * 推送消息
     *
     * @param clientId 客户端 ID
     * @param msg      消息
     * @return 连接是否存在
     * @throws IOException IO异常
     */
    public static boolean sendMsg(String clientId, String msg) throws IOException {
        final SseEmitter sseEmitter = SSE_CACHE.get(clientId);
        if (Objects.nonNull(sseEmitter)) {
            try {
                sseEmitter.send(msg);
            } catch (Exception e) {
                logger.error("[流式响应-停止生成] ");
                return true;
            }
            return false;
        } else {
            return true;
        }
    }

    /**
     * 异步推送消息 TODO 目前未实现提供回调
     *
     * @param clientId 客户端 ID
     * @param msg      消息
     * @return 连接是否存在
     * @throws IOException IO异常
     */
    public static boolean sendMsgAsync(String clientId, String msg){
        final SseEmitter sseEmitter = SSE_CACHE.get(clientId);
        if (Objects.nonNull(sseEmitter)) {
            try {
                ssePool.submit(()->{
                    try {
                        sseEmitter.send(msg);
                    } catch (IOException e) {
                        logger.error("[流式响应-停止生成] ");
                    }
                });
            } catch (Exception e) {
                logger.error("[流式响应-停止生成] ");
                return true;
            }
            return false;
        } else {
            return true;
        }
    }

    /**
     * 立即关闭SseEmitter,可能存在推流不完全的情况,谨慎使用
     *
     * @param clientId
     */
    public static void complete(String clientId) {
        completeDelay(clientId,0);
    }

    /**
     * 延迟关闭 SseEmitter,延迟一定时长时为了尽量保证最后一次推送数据被前端完整接收
     *
     * @param clientId 客户端ID
     */
    public static void completeDelay(String clientId) {
        completeDelay(clientId,EMITTER_COMPLETE_DELAY_MILLISECONDS);
    }

    /**
     * 延迟关闭 SseEmitter,延迟指定时长时为了尽量保证最后一次推送数据被前端完整接收
     *
     * @param clientId 客户端ID
     */
    public static void completeDelay(String clientId,Integer delayMilliSeconds) {
        final SseEmitter sseEmitter = SSE_CACHE.get(clientId);
        if (Objects.nonNull(sseEmitter)) {
            try {
                TimeUnit.MILLISECONDS.sleep(delayMilliSeconds);
                sseEmitter.complete();
            } catch (InterruptedException ex) {
                logger.error("流式响应异常", ex);
                Thread.currentThread().interrupt();
            }
        }
    }

    /**
     * 初始化 SSE连接 设置一些属性和回调之类的
     *
     * @param emitter 连接对象
     * @param clientId 客户端 ID
     */
    private static void initSseEmitter(SseEmitter emitter, String clientId){
        // 设置SSE的超时回调
        emitter.onTimeout(() -> {
            logger.info("[SseEmitter] 连接已超时,正准备关闭,clientId = {}", clientId);
            SSE_CACHE.remove(clientId);
        });

        // 设置SSE的结束回调
        emitter.onCompletion(() -> {
            logger.info("[SseEmitter] 连接已释放,clientId = {}", clientId);
            SSE_CACHE.remove(clientId);
        });

        // 设置SSE的异常回调
        emitter.onError(throwable -> {
            logger.error("[SseEmitter] 连接已异常,正准备关闭,clientId = {}", clientId);
            SSE_CACHE.remove(clientId);
        });
    }
}

3、编写调用接口


import com.alibaba.fastjson.JSONArray;
import com.alibaba.fastjson.JSONObject;

import okhttp3.MediaType;
import okhttp3.OkHttpClient;
import okhttp3.Request;
import okhttp3.RequestBody;
import okhttp3.sse.EventSource;
import okhttp3.sse.EventSourceListener;
import okhttp3.sse.EventSources;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;

import java.util.concurrent.CompletableFuture;

@RestController
@RequestMapping("/chat")
public class StreamChatController {

    @PostMapping("/send")
    public SseEmitter sendMessage(@RequestParam String username, @RequestParam String message) {

        SseEmitter sseEmitter = SseEmitterUtils.getConnection(username);

        CompletableFuture.runAsync(()->send(username,message));

        return sseEmitter;
    }

    public void send(String username,String message){
        OkHttpClient client = new OkHttpClient();
        JSONObject inputJson = new JSONObject();
        JSONArray messagesArray = new JSONArray();
        JSONObject systemMessage = new JSONObject();
        systemMessage.put("role", "system");
        systemMessage.put("content", "You are a helpful assistant.");
        messagesArray.add(systemMessage);

        JSONObject userMessage = new JSONObject();
        userMessage.put("role", "user");
        userMessage.put("content", message);
        messagesArray.add(userMessage);

        inputJson.put("messages", messagesArray);

        JSONObject payloadJson = new JSONObject();
        payloadJson.put("model", "chatglm3-6b");
        payloadJson.put("input", inputJson);

        JSONObject parametersJson = new JSONObject();
        parametersJson.put("result_format", "message");
        payloadJson.put("parameters", parametersJson);

        String json = payloadJson.toString();

        RequestBody body = RequestBody.create(MediaType.parse("application/json"),json);

        Request request = new Request.Builder()
                .url("https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation")
                .post(body)
                .addHeader("Authorization", "Bearer 你得API-KEY")
                .addHeader("Content-Type", "application/json")
                .addHeader("X-DashScope-SSE", "enable")
                .build();

        // 创建事件监听器
        EventSourceListener eventSourceListener = new ChatEventSourceListener(username);

        EventSource.Factory factory = EventSources.createFactory(client);
        // 创建事件
        EventSource eventSource = factory.newEventSource(request, eventSourceListener);
        // 与服务器建立连接
        eventSource.request();
    }
}

4、编写前端

我这个有点粗糙,实际效果比这好的多。

DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <title>SSE Chattitle>
head>
<body>
<h1>YXY-Chath1>

<div id="chat-messages">div>
<form id="message-form">
    <input type="text" id="message-input" placeholder="输入消息">
    <button type="submit">发送button>
form>

<script>
    const chatMessages = document.getElementById('chat-messages');
    const messageForm = document.getElementById('message-form');
    const messageInput = document.getElementById('message-input');

    // 连接到聊天室
    const connectToChat = () => {
        const username = prompt('Enter your username:');
        const eventSource = new EventSource(`/chat/connect?username=${encodeURIComponent(username)}`);

        // 接收来自服务器的消息
        eventSource.onmessage = function(event) {
            const message = event.data;
            displayMessage(message);
        };

        // 处理连接错误
        eventSource.onerror = function(event) {
            console.error('EventSource error:', event);
            eventSource.close();
        };

        // 提交消息表单
        messageForm.addEventListener('submit', function(event) {
            event.preventDefault();
            const message = messageInput.value.trim();
            if (message !== '') {
                sendMessage(username, message);
                messageInput.value = '';
            }
        });
    };

    // 发送消息到服务器
    const sendMessage = (username, message) => {
        fetch(`/chat/send?username=${encodeURIComponent(username)}&message=${encodeURIComponent(message)}`, {
            method: 'POST'
        })
        .catch(error => console.error('Error sending message:', error));
    };

    // 在界面上显示消息
    const displayMessage = (message) => {
        const messageElement = document.createElement('div');
        messageElement.textContent = message;
        chatMessages.appendChild(messageElement);
    };
    // 发起连接
    connectToChat();

script>
body>
html>

5、发起调用

关于springboot对接chatglm3-6b大模型的尝试_第3张图片
我们看到其实是成功了,但是前端没有把流数据渲染上去,我不太懂前端,后面改一改试试。

三、总结

我们这只是其中一种模型的对接,其实别的也都差不多,都是基于流可以用http来操作,你可以在你的项目中建立一个AI中台,来对接各种模型,给别的服务提供调用。只是需要看明白每种模型的参数。
而且我们目前只是简单的实现,还存在很多问题,比如okhttp客户端没有做池化,每次都是new出来的。
CompletableFuture的异步调用没有指定线程池,还是共用的默认池,这样会导致可能被别的业务影响。
等等细节问题,我们这里先不做处理,后面如果真的要用,可以着手细节处的优化。

你可能感兴趣的:(#,springcloud,#,springboot,JAVA,spring,boot,后端,语言模型)