springboot集成sse实现后端流式输出消息

springboot集成sse实现后端流式输出消息

一、前言

在实现人工智能聊天的过程中,往往不难发现,主流的输出方式都是采用流式输出,而在后端实现流式输出可以有两种实现思路,第一种是采用sse,第二种是采用websocket,本编将为大家介绍 sse实现流式输出消息

二、源码资料

还没有springboot专栏集的相关源码资料的同学,可以前往 springboot集成sse实现流式输出下载源码

三、实现思路

1、创建sse对象,前端传入唯一的id,作为缓存该session对象的key

2、创建一个供前端提问的接口,传入对应的唯一id,后端拿到id后去缓存获取session对象

3、如果需要停止和AI会话,还可以创建一个关闭session会话的接口

四、重难点讲解

1、创建sse对象接口

在创建sse接口的过程中,注意需要做到每一页面会话必须唯一,确保前后端交互一致
package com.jiuzhou.controller;

import com.jiuzhou.common.dto.ChatDto;
import com.jiuzhou.common.vo.ChatVo;
import com.jiuzhou.service.SseService;
import com.unfbx.chatgpt.exception.BaseException;
import com.unfbx.chatgpt.exception.CommonError;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Controller;
import org.springframework.web.bind.annotation.*;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;

import javax.servlet.http.HttpServletResponse;
import java.util.Map;

/**
 * github地址 http://www.github.com/wanyushu
 * gitee地址 http://www.gitee.com/wanyushu
 *
 * @author yushu
 * @email [email protected]
 * @date 2024/1/3 9:31
 */
@Controller
@Slf4j
public class ChatController {

    private final SseService sseService;

    public ChatController(SseService sseService) {
        this.sseService = sseService;
    }

    /**
     * 创建sse连接
     *
     * @param headers
     * @return
     */
    @CrossOrigin
    @GetMapping("/createSse")
    public SseEmitter createConnect(@RequestHeader Map<String, String> headers) {
        String uid = getUid(headers);
        return sseService.createSse(uid);
    }

    /**
     * 聊天接口
     *
     * @param chatRequest
     * @param headers
     */
    @CrossOrigin
    @PostMapping("/chat")
    @ResponseBody
    public ChatVo sseChat(@RequestBody ChatDto chatRequest, @RequestHeader Map<String, String> headers, HttpServletResponse response) {
        String uid = getUid(headers);
        return sseService.sseChat(uid, chatRequest);
    }
    /**
     * 关闭连接
     *
     * @param headers
     */
    @CrossOrigin
    @GetMapping("/closeSse")
    public void closeConnect(@RequestHeader Map<String, String> headers) {

        try{
            String uid = getUid(headers);
            sseService.closeSse(uid);
        }catch (Exception e){
            e.printStackTrace();
        }

    }

    /**
     * 获取uid
     *
     * @param headers
     * @return
     */
    private String getUid(Map<String, String> headers) {
        String uid = headers.get("uid");
        if (null==uid) {
            throw new BaseException(CommonError.SYS_ERROR);
        }
        return uid;
    }


}

2、创建接口

package com.jiuzhou.service;

import com.jiuzhou.common.dto.ChatDto;
import com.jiuzhou.common.vo.ChatVo;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;

public interface SseService {
    /**
     * 创建SSE
     * @param uid
     * @return
     */
    SseEmitter createSse(String uid);


    /**
     * 客户端发送消息到服务端
     * @param uid
     * @param chatDto
     */
    ChatVo sseChat(String uid, ChatDto chatDto);

    /**
     * 关闭SSE
     * @param uid
     */
    void closeSse(String uid);

}

3、创建实现类

package com.jiuzhou.service.impl;

import cn.hutool.core.util.StrUtil;
import cn.hutool.json.JSONUtil;
import com.jiuzhou.common.dto.ChatDto;
import com.jiuzhou.common.listener.OpenAISSEEventSourceListener;
import com.jiuzhou.common.vo.ChatVo;
import com.jiuzhou.service.SseService;
import com.jiuzhou.utils.LocalCache;
import com.unfbx.chatgpt.OpenAiStreamClient;
import com.unfbx.chatgpt.entity.chat.ChatCompletion;
import com.unfbx.chatgpt.entity.chat.Message;
import com.unfbx.chatgpt.exception.BaseException;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

/**
 * 描述:
 */
@Service
@Slf4j
public class SseServiceImpl implements SseService {

    private final OpenAiStreamClient openAiStreamClient;

    public SseServiceImpl(OpenAiStreamClient openAiStreamClient) {
        this.openAiStreamClient = openAiStreamClient;
    }

    @Override
    public SseEmitter createSse(String uid) {
        //默认30秒超时,设置为0L则永不超时
        SseEmitter sseEmitter = new SseEmitter(0l);
        //完成后回调
        sseEmitter.onCompletion(() -> {
            log.info("[{}]结束连接...................", uid);
            LocalCache.CACHE.remove(uid);
        });
        //超时回调
        sseEmitter.onTimeout(() -> {
            log.info("[{}]连接超时...................", uid);
        });
        //异常回调
        sseEmitter.onError(
                throwable -> {
                    try {
                        log.info("[{}]连接异常,{}", uid, throwable.toString());
                        sseEmitter.send(SseEmitter.event()
                                .id(uid)
                                .name("发生异常!")
                                .data(Message.builder().content("发生异常请重试!").build())
                                .reconnectTime(3000));
                        LocalCache.CACHE.put(uid, sseEmitter);
                    } catch (IOException e) {
                        e.printStackTrace();
                    }
                }
        );
        LocalCache.CACHE.put(uid, sseEmitter);
        log.info("[{}]创建sse连接成功!", uid);
        return sseEmitter;
    }

    @Override
    public void closeSse(String uid) {
        SseEmitter sse = (SseEmitter) LocalCache.CACHE.get(uid);
        if (sse != null) {
            sse.complete();
            //移除
            LocalCache.CACHE.remove(uid);
        }
    }

    @Override
    public ChatVo sseChat(String uid, ChatDto chatRequest) {
        if (StrUtil.isBlank(chatRequest.getQuestion())) {
            log.info("参数异常,msg为null", uid);
            throw new BaseException("参数异常,msg不能为空~");
        }
        String messageContext = (String) LocalCache.CACHE.get("msg" + uid);
        List<Message> messages = new ArrayList<>();
        if (StrUtil.isNotBlank(messageContext)) {
            messages = JSONUtil.toList(messageContext, Message.class);
            if (messages.size() >= 5) {
                messages = messages.subList(1, 5);
            }
            Message currentMessage = Message.builder().content(chatRequest.getQuestion()).role(Message.Role.USER).build();
            messages.add(currentMessage);
        } else {
            Message currentMessage = Message.builder().content(chatRequest.getQuestion()).role(Message.Role.USER).build();
            messages.add(currentMessage);
        }

        SseEmitter sseEmitter = (SseEmitter) LocalCache.CACHE.get(uid);

        if (sseEmitter == null) {
            log.info("聊天消息推送失败uid:[{}],没有创建连接,请重试。", uid);
            throw new BaseException("聊天消息推送失败uid:[{}],没有创建连接,请重试。~");
        }
        OpenAISSEEventSourceListener openAIEventSourceListener = new OpenAISSEEventSourceListener(sseEmitter);
        ChatCompletion completion = ChatCompletion
                .builder()
                .messages(messages)
                .model(ChatCompletion.Model.GPT_3_5_TURBO_0613.getName())
                .build();
        openAiStreamClient.streamChatCompletion(completion, openAIEventSourceListener);
        LocalCache.CACHE.put("msg" + uid, JSONUtil.toJsonStr(messages), LocalCache.TIMEOUT);
        ChatVo response = new ChatVo();
        response.setQuestionTokens(completion.tokens());
        return response;
    }
}
注意:OpenAISSEEventSourceListener 可以是其它第三方平台的数据流

五、演示效果

1、下载 从零开始搭建AI聊天 前端源码,发送消息
springboot集成sse实现后端流式输出消息_第1张图片

六、总结及注意事项

1、在创建sse的过程中需用到缓存自动失效工具类,请注意对应的客户端需做好定时心跳重连

2、sse客户的session 缓存可以在实战项目中用token解密后的用户id作为唯一key

3、消息的会话线可以设置合适的会话长度

4、前后端分离的接口需做好允许跨域操作

如果大佬们有什么优化建议,请在评论区帮忙留言,顺便点赞关注哦

你可能感兴趣的:(springboot使用技能,spring,boot,后端,java,人工智能)