在实现人工智能聊天的过程中,往往不难发现,主流的输出方式都是采用流式输出,而在后端实现流式输出可以有两种实现思路,第一种是采用sse,第二种是采用websocket,本编将为大家介绍 sse实现流式输出消息
还没有springboot专栏集的相关源码资料的同学,可以前往 springboot集成sse实现流式输出下载源码
1、创建sse对象,前端传入唯一的id,作为缓存该session对象的key
2、创建一个供前端提问的接口,传入对应的唯一id,后端拿到id后去缓存获取session对象
3、如果需要停止和AI会话,还可以创建一个关闭session会话的接口
在创建sse接口的过程中,注意需要做到每一页面会话必须唯一,确保前后端交互一致
package com.jiuzhou.controller;
import com.jiuzhou.common.dto.ChatDto;
import com.jiuzhou.common.vo.ChatVo;
import com.jiuzhou.service.SseService;
import com.unfbx.chatgpt.exception.BaseException;
import com.unfbx.chatgpt.exception.CommonError;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Controller;
import org.springframework.web.bind.annotation.*;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import javax.servlet.http.HttpServletResponse;
import java.util.Map;
/**
* github地址 http://www.github.com/wanyushu
* gitee地址 http://www.gitee.com/wanyushu
*
* @author yushu
* @email [email protected]
* @date 2024/1/3 9:31
*/
@Controller
@Slf4j
public class ChatController {
private final SseService sseService;
public ChatController(SseService sseService) {
this.sseService = sseService;
}
/**
* 创建sse连接
*
* @param headers
* @return
*/
@CrossOrigin
@GetMapping("/createSse")
public SseEmitter createConnect(@RequestHeader Map<String, String> headers) {
String uid = getUid(headers);
return sseService.createSse(uid);
}
/**
* 聊天接口
*
* @param chatRequest
* @param headers
*/
@CrossOrigin
@PostMapping("/chat")
@ResponseBody
public ChatVo sseChat(@RequestBody ChatDto chatRequest, @RequestHeader Map<String, String> headers, HttpServletResponse response) {
String uid = getUid(headers);
return sseService.sseChat(uid, chatRequest);
}
/**
* 关闭连接
*
* @param headers
*/
@CrossOrigin
@GetMapping("/closeSse")
public void closeConnect(@RequestHeader Map<String, String> headers) {
try{
String uid = getUid(headers);
sseService.closeSse(uid);
}catch (Exception e){
e.printStackTrace();
}
}
/**
* 获取uid
*
* @param headers
* @return
*/
private String getUid(Map<String, String> headers) {
String uid = headers.get("uid");
if (null==uid) {
throw new BaseException(CommonError.SYS_ERROR);
}
return uid;
}
}
package com.jiuzhou.service;
import com.jiuzhou.common.dto.ChatDto;
import com.jiuzhou.common.vo.ChatVo;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
public interface SseService {
/**
* 创建SSE
* @param uid
* @return
*/
SseEmitter createSse(String uid);
/**
* 客户端发送消息到服务端
* @param uid
* @param chatDto
*/
ChatVo sseChat(String uid, ChatDto chatDto);
/**
* 关闭SSE
* @param uid
*/
void closeSse(String uid);
}
package com.jiuzhou.service.impl;
import cn.hutool.core.util.StrUtil;
import cn.hutool.json.JSONUtil;
import com.jiuzhou.common.dto.ChatDto;
import com.jiuzhou.common.listener.OpenAISSEEventSourceListener;
import com.jiuzhou.common.vo.ChatVo;
import com.jiuzhou.service.SseService;
import com.jiuzhou.utils.LocalCache;
import com.unfbx.chatgpt.OpenAiStreamClient;
import com.unfbx.chatgpt.entity.chat.ChatCompletion;
import com.unfbx.chatgpt.entity.chat.Message;
import com.unfbx.chatgpt.exception.BaseException;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
/**
* 描述:
*/
@Service
@Slf4j
public class SseServiceImpl implements SseService {
private final OpenAiStreamClient openAiStreamClient;
public SseServiceImpl(OpenAiStreamClient openAiStreamClient) {
this.openAiStreamClient = openAiStreamClient;
}
@Override
public SseEmitter createSse(String uid) {
//默认30秒超时,设置为0L则永不超时
SseEmitter sseEmitter = new SseEmitter(0l);
//完成后回调
sseEmitter.onCompletion(() -> {
log.info("[{}]结束连接...................", uid);
LocalCache.CACHE.remove(uid);
});
//超时回调
sseEmitter.onTimeout(() -> {
log.info("[{}]连接超时...................", uid);
});
//异常回调
sseEmitter.onError(
throwable -> {
try {
log.info("[{}]连接异常,{}", uid, throwable.toString());
sseEmitter.send(SseEmitter.event()
.id(uid)
.name("发生异常!")
.data(Message.builder().content("发生异常请重试!").build())
.reconnectTime(3000));
LocalCache.CACHE.put(uid, sseEmitter);
} catch (IOException e) {
e.printStackTrace();
}
}
);
LocalCache.CACHE.put(uid, sseEmitter);
log.info("[{}]创建sse连接成功!", uid);
return sseEmitter;
}
@Override
public void closeSse(String uid) {
SseEmitter sse = (SseEmitter) LocalCache.CACHE.get(uid);
if (sse != null) {
sse.complete();
//移除
LocalCache.CACHE.remove(uid);
}
}
@Override
public ChatVo sseChat(String uid, ChatDto chatRequest) {
if (StrUtil.isBlank(chatRequest.getQuestion())) {
log.info("参数异常,msg为null", uid);
throw new BaseException("参数异常,msg不能为空~");
}
String messageContext = (String) LocalCache.CACHE.get("msg" + uid);
List<Message> messages = new ArrayList<>();
if (StrUtil.isNotBlank(messageContext)) {
messages = JSONUtil.toList(messageContext, Message.class);
if (messages.size() >= 5) {
messages = messages.subList(1, 5);
}
Message currentMessage = Message.builder().content(chatRequest.getQuestion()).role(Message.Role.USER).build();
messages.add(currentMessage);
} else {
Message currentMessage = Message.builder().content(chatRequest.getQuestion()).role(Message.Role.USER).build();
messages.add(currentMessage);
}
SseEmitter sseEmitter = (SseEmitter) LocalCache.CACHE.get(uid);
if (sseEmitter == null) {
log.info("聊天消息推送失败uid:[{}],没有创建连接,请重试。", uid);
throw new BaseException("聊天消息推送失败uid:[{}],没有创建连接,请重试。~");
}
OpenAISSEEventSourceListener openAIEventSourceListener = new OpenAISSEEventSourceListener(sseEmitter);
ChatCompletion completion = ChatCompletion
.builder()
.messages(messages)
.model(ChatCompletion.Model.GPT_3_5_TURBO_0613.getName())
.build();
openAiStreamClient.streamChatCompletion(completion, openAIEventSourceListener);
LocalCache.CACHE.put("msg" + uid, JSONUtil.toJsonStr(messages), LocalCache.TIMEOUT);
ChatVo response = new ChatVo();
response.setQuestionTokens(completion.tokens());
return response;
}
}
注意:OpenAISSEEventSourceListener 可以是其它第三方平台的数据流
1、在创建sse的过程中需用到缓存自动失效工具类,请注意对应的客户端需做好定时心跳重连
2、sse客户的session 缓存可以在实战项目中用token解密后的用户id作为唯一key
3、消息的会话线可以设置合适的会话长度
4、前后端分离的接口需做好允许跨域操作
如果大佬们有什么优化建议,请在评论区帮忙留言,顺便点赞关注哦