流式请求 GPT，并将结果流式推送到前端页面
使用 POST 请求数据。由于 GPT 以 event-stream（SSE）方式返回数据，每行的格式为 `data: {...}`，需要先去掉行首的 `data:` 前缀再解析 JSON。
/**
 * Streams a chat completion from the OpenAI API with Apache HttpClient
 * ({@code org.apache.http.client.methods.HttpPost}) and logs each content delta as it arrives.
 *
 * <p>The response body is an event stream: payload lines look like {@code data: {json}},
 * events are separated by blank lines, and the stream ends with {@code data: [DONE]}.
 *
 * @param messagesBOList the conversation messages to send
 */
@SneakyThrows
private void chatStream(List messagesBOList) {
    ChatParamBO build = ChatParamBO.builder()
            .temperature(0.7)
            .model("gpt-3.5-turbo")
            .messages(messagesBOList)
            .stream(true)
            .build();
    String body = JsonUtils.toJson(build);
    log.debug("chat request body: {}", body);

    HttpPost httpPost = new HttpPost("https://api.openai.com/v1/chat/completions");
    // Placeholder credential.
    httpPost.setHeader("Authorization", "xxxxxxxxxxxx");
    httpPost.setHeader("Content-Type", "application/json; charset=UTF-8");
    httpPost.setEntity(new StringEntity(body, "utf-8"));

    // try-with-resources closes the client and the response on every path (the client used to leak).
    try (CloseableHttpClient httpclient = HttpClients.createDefault();
         CloseableHttpResponse response = httpclient.execute(httpPost)) {
        HttpEntity entity = response.getEntity();
        if (entity == null) {
            return;
        }
        // Decode explicitly as UTF-8; the platform default charset would garble non-ASCII content.
        try (BufferedReader reader = new BufferedReader(
                new InputStreamReader(entity.getContent(), java.nio.charset.StandardCharsets.UTF_8))) {
            String line;
            while ((line = reader.readLine()) != null) {
                if (line.isEmpty()) {
                    continue; // blank line = event separator
                }
                if (!line.startsWith("data:")) {
                    // Not an event payload (e.g. an error body); surface it instead of dropping it.
                    log.warn("unexpected stream line: {}", line);
                    continue;
                }
                // Strip only the leading field name; content may itself contain "data:".
                String payload = line.substring("data:".length()).trim();
                if ("[DONE]".equals(payload)) {
                    break; // end-of-stream marker, not JSON
                }
                try {
                    ChatResultBO chatResultBO = JsonUtils.toObject(payload, ChatResultBO.class);
                    String content = chatResultBO.getChoices().get(0).getDelta().getContent();
                    // Chunks without content would otherwise be logged as "null".
                    if (content != null) {
                        log.info(content);
                    }
                } catch (Exception e) {
                    log.warn("skip unparsable stream chunk: {}", payload, e);
                }
            }
        }
    }
}
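这里用到了 OkHttp，需要先引入相关的 Maven 依赖（版本号按项目实际情况填写）：
<dependency>
    <groupId>com.squareup.okhttp3</groupId>
    <artifactId>okhttp</artifactId>
</dependency>
<dependency>
    <groupId>com.squareup.okhttp3</groupId>
    <artifactId>okhttp-sse</artifactId>
</dependency>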
// 定义see接口
Request request = new Request.Builder().url("https://api.openai.com/v1/chat/completions")
.header("Authorization","xxx")
.post(okhttp3.RequestBody.create(okhttp3.MediaType.parse("application/json; charset=utf-8"),param.toJSONString()))
.build();
OkHttpClient okHttpClient = new OkHttpClient.Builder()
.connectTimeout(10, TimeUnit.MINUTES)
.readTimeout(10, TimeUnit.MINUTES)//这边需要将超时显示设置长一点,不然刚连上就断开,之前以为调用方式错误被坑了半天
.build();
// 实例化EventSource,注册EventSource监听器
RealEventSource realEventSource = new RealEventSource(request, new EventSourceListener() {
@Override
public void onOpen(EventSource eventSource, Response response) {
log.info("onOpen");
}
@SneakyThrows
@Override
public void onEvent(EventSource eventSource, String id, String type, String data) {
// log.info("onEvent");
log.info(data);//请求到的数据
}
@Override
public void onClosed(EventSource eventSource) {
log.info("onClosed");
// emitter.complete();
}
@Override
public void onFailure(EventSource eventSource, Throwable t, Response response) {
log.info("onFailure,t={},response={}",t,response);//这边可以监听并重新打开
// emitter.complete();
}
});
realEventSource.connect(okHttpClient);//真正开始请求的一步
原理是先建立连接，然后通过这个连接不断地发送消息即可。
import javax.websocket.Session;
import lombok.Data;
/**
 * One connected WebSocket client: the session used to push data to it and the URI it
 * connected with.
 *
 * <p>Lombok {@code @Data} generates getters, setters, {@code equals}/{@code hashCode} and
 * {@code toString} from the fields below.
 */
@Data
public class WebSocketClient {
    // Session of this client's connection; data is sent to the client through it.
    private Session session;
    // URI the client connected with.
    private String uri;
}
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.socket.server.standard.ServerEndpointExporter;
/**
 * Spring configuration enabling standard (JSR-356) {@code @ServerEndpoint} classes such as the
 * chat WebSocket endpoint.
 */
@Configuration
public class WebSocketConfig {

    /**
     * Declares the {@link ServerEndpointExporter}, which registers {@code @ServerEndpoint}
     * beans with the container's WebSocket runtime.
     *
     * @return a new exporter instance
     */
    @Bean
    public ServerEndpointExporter serverEndpointExporter() {
        final ServerEndpointExporter exporter = new ServerEndpointExporter();
        return exporter;
    }
}
@Slf4j
@Component
@ServerEndpoint("/websocket/chat/{chatId}")
public class ChatWebsocketService {
static final ConcurrentHashMap> webSocketClientMap= new ConcurrentHashMap<>();
private String chatId;
/**
* 连接建立成功时触发,绑定参数
* @param session 与某个客户端的连接会话,需要通过它来给客户端发送数据
* @param chatId 商户ID
*/
@OnOpen
public void onOpen(Session session, @PathParam("chatId") String chatId){
WebSocketClient client = new WebSocketClient();
client.setSession(session);
client.setUri(session.getRequestURI().toString());
List webSocketClientList = webSocketClientMap.get(chatId);
if(webSocketClientList == null){
webSocketClientList = new ArrayList<>();
}
webSocketClientList.add(client);
webSocketClientMap.put(chatId, webSocketClientList);
this.chatId = chatId;
}
/**
* 收到客户端消息后调用的方法
*
* @param message 客户端发送过来的消息
*/
@OnMessage
public void onMessage(String message) {
log.info("chatId = {},message = {}",chatId,message);
// 回复消息
this.chatStream(BaseUtil.newList(ChatParamMessagesBO.builder().content(message).role("user").build()));
// this.sendMessage(chatId,message+"233");
}
/**
* 连接关闭时触发,注意不能向客户端发送消息了
* @param chatId
*/
@OnClose
public void onClose(@PathParam("chatId") String chatId){
webSocketClientMap.remove(chatId);
}
/**
* 通信发生错误时触发
* @param session
* @param error
*/
@OnError
public void onError(Session session, Throwable error) {
System.out.println("发生错误");
error.printStackTrace();
}
/**
* 向客户端发送消息
* @param chatId
* @param message
*/
public void sendMessage(String chatId,String message){
try {
List webSocketClientList = webSocketClientMap.get(chatId);
if(webSocketClientList!=null){
for(WebSocketClient webSocketServer:webSocketClientList){
webSocketServer.getSession().getBasicRemote().sendText(message);
}
}
} catch (IOException e) {
e.printStackTrace();
throw new RuntimeException(e.getMessage());
}
}
/**
* 流式调用查询gpt
* @param messagesBOList
* @throws IOException
*/
@SneakyThrows
private void chatStream(List messagesBOList) {
// TODO 和GPT的访问请求
}
}
本质上也是基于订阅—推送（由服务端主动推送）的方式：
SseEmitter
import org.springframework.cloud.context.config.annotation.RefreshScope;
import org.springframework.validation.annotation.Validated;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import java.util.Set;
import java.util.function.Consumer;
import javax.annotation.Resource;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;
/**
 * HTTP endpoints for server-sent-event (SSE) connections: open a connection, push a message to
 * one connection, and remove a connection. Connection bookkeeping is delegated to
 * {@link SseBizService}.
 */
@Validated
@RestController
@RequestMapping("/api/sse")
@Slf4j
@RefreshScope // bean is refreshed when the refreshable configuration changes
public class SseController {

    @Resource
    private SseBizService sseBizService;

    /**
     * Creates an SSE connection for a conversation and returns its emitter.
     *
     * @param conversationId id the connection is registered under
     * @return the emitter Spring MVC keeps open as the streaming response
     */
    @GetMapping(value = "/connect", produces = "text/event-stream; charset=utf-8")
    public SseEmitter connect(String conversationId) {
        // Timeout 0 = never expires. Otherwise the default (30 s) applies, and an unfinished
        // request fails with AsyncRequestTimeoutException.
        SseEmitter sseEmitter = new SseEmitter(0L);
        // NOTE(review): the callbacks remove by conversationId only, so callbacks of a superseded
        // emitter (same id reconnected) also remove the newer registration.
        sseEmitter.onCompletion(completionCallBack(conversationId));
        sseEmitter.onError(errorCallBack(conversationId));
        sseEmitter.onTimeout(timeoutCallBack(conversationId));
        log.info("创建新的sse连接,当前用户:{}", conversationId);
        sseBizService.addConnect(conversationId, sseEmitter);
        sseBizService.sendMsg(conversationId, "链接成功");
        return sseEmitter;
    }

    /**
     * Sends a message to one connection (unicast).
     *
     * @param conversationId target connection id
     * @param msg            message data
     */
    @GetMapping(value = "/send", produces = "text/event-stream; charset=utf-8")
    public void sendMessage(String conversationId, String msg) {
        sseBizService.sendMsg(conversationId, msg);
    }

    /**
     * Removes a connection.
     *
     * @param conversationId connection id to remove
     */
    @GetMapping(value = "/disconnection", produces = "text/event-stream; charset=utf-8")
    public void removeUser(String conversationId) {
        log.info("移除用户:{}", conversationId);
        sseBizService.deleteConnect(conversationId);
    }

    /**
     * Publishes a message to every connection whose id starts with {@code groupId} (multicast).
     * Currently a no-op.
     *
     * @param groupId id prefix identifying the group
     * @param message message content
     */
    public void groupSendMessage(String groupId, String message) {
        // TODO: not implemented. The emitter cache is private to SseBizService, which offers no way
        // to enumerate connections; add one there, then send to ids starting with groupId.
    }

    /**
     * Sends a message to every connection (broadcast). Currently a no-op.
     *
     * @param message message content
     */
    public void batchSendMessage(String message) {
        // TODO: not implemented, for the same reason as groupSendMessage.
    }

    /**
     * Sends a message to each of the given connections.
     *
     * @param message message content
     * @param ids     target connection ids
     */
    public void batchSendMessage(String message, Set<String> ids) {
        ids.forEach(userId -> sendMessage(userId, message));
    }

    /** Callback run when the emitter completes: unregisters the connection. */
    private Runnable completionCallBack(String userId) {
        return () -> {
            log.info("结束连接:{}", userId);
            removeUser(userId);
        };
    }

    /** Callback run when the emitter times out: unregisters the connection. */
    private Runnable timeoutCallBack(String userId) {
        return () -> {
            log.info("连接超时:{}", userId);
            removeUser(userId);
        };
    }

    /** Callback run when the emitter fails: logs the cause and unregisters the connection. */
    private Consumer<Throwable> errorCallBack(String userId) {
        return throwable -> {
            log.warn("连接异常:{}", userId, throwable);
            removeUser(userId);
        };
    }
}
import org.springframework.cloud.context.config.annotation.RefreshScope;
import org.springframework.http.MediaType;
import org.springframework.stereotype.Component;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;
/**
 * In-memory registry of SSE connections keyed by connection id, plus a connection counter.
 *
 * <p>Deliberately not {@code @RefreshScope}: a refresh would rebuild this bean and silently drop
 * every registered emitter, and the class has no refreshable configuration values.
 */
@Component
@Slf4j
public class SseBizService {

    /** Number of currently registered connections. */
    private final AtomicInteger count = new AtomicInteger(0);

    /** Emitters keyed by connection id (kept in memory; could be moved to a shared store). */
    private final Map<String, SseEmitter> sseCache = new ConcurrentHashMap<>();

    /**
     * Registers (or replaces) the emitter for a connection id.
     *
     * @param id         connection id
     * @param sseEmitter emitter to register
     */
    public void addConnect(String id, SseEmitter sseEmitter) {
        // Count only ids that were not registered yet; replacing an emitter keeps the count.
        if (sseCache.put(id, sseEmitter) == null) {
            count.incrementAndGet();
        }
    }

    /**
     * Removes the emitter for a connection id. Safe to call repeatedly: completion, error and
     * timeout callbacks may all call it for the same connection.
     *
     * @param id connection id
     */
    public void deleteConnect(String id) {
        // Decrement only if something was actually removed, so the count never drifts below zero.
        if (sseCache.remove(id) != null) {
            count.decrementAndGet();
        }
    }

    /**
     * Returns the number of currently registered connections.
     *
     * @return connection count
     */
    public int getConnectCount() {
        return count.get();
    }

    /**
     * Sends a message to the connection with the given id; no-op if it is not registered.
     * If the send fails, that emitter is unregistered and the {@link java.io.IOException}
     * propagates (thrown sneakily).
     *
     * @param id  connection id
     * @param msg message data
     */
    @SneakyThrows
    public void sendMsg(String id, String msg) {
        // A single lookup avoids the containsKey/get race with a concurrent deleteConnect.
        SseEmitter emitter = sseCache.get(id);
        if (emitter == null) {
            return;
        }
        try {
            emitter.send(msg, MediaType.TEXT_EVENT_STREAM);
        } catch (java.io.IOException e) {
            log.warn("用户[{}]推送异常:{}", id, e.getMessage());
            // Remove only this emitter, not one registered after it for the same id.
            if (sseCache.remove(id, emitter)) {
                count.decrementAndGet();
            }
            throw e;
        }
    }
}
/**
 * Handles data pushed by the server (default "message" events).
 * Equivalent alternative: source.onmessage = function (event) {}
 * The server ends the stream with the sentinel "[DONE]". On that event the EventSource is
 * closed, and the sentinel is not rendered as message content.
 */
source.addEventListener('message', function (e) {
  if (e.data === '[DONE]') {
    source.close();
    return;
  }
  setMessageInnerHTML(e.data);
});
后端:
/**
 * Streams a chat completion from the upstream model API to the browser as server-sent events.
 *
 * <p>Each upstream content delta is forwarded as one SSE event whose data is the JSON of a
 * {@code ChatContentBO}. The literal {@code [DONE]} is forwarded as the last event (also on
 * upstream failure, so the front-end, which closes on {@code [DONE]}, stops waiting), after which
 * the emitter is completed.
 *
 * @param conversationId conversation id (currently unused: the message list is a placeholder)
 * @return the emitter Spring MVC keeps open as the streaming response
 */
@GetMapping(value = "/stream/sse", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
public SseEmitter completionsStream(@RequestParam String conversationId) {
    // TODO: load the conversation's messages; this list is currently always empty.
    List messagesBOList = new ArrayList();
    ChatParamBO build = ChatParamBO.builder()
            .temperature(0.7)
            .stream(true)
            .model("xxxx")
            .messages(messagesBOList)
            .build();
    // 0 = no Spring-side timeout, so a long generation is not cut off by the default async
    // timeout. The OkHttp read timeout below still bounds upstream silence.
    SseEmitter emitter = new SseEmitter(0L);
    Request request = new Request.Builder().url("xxx")
            .header("Authorization", "xxxx")
            .post(okhttp3.RequestBody.create(
                    okhttp3.MediaType.parse("application/json; charset=utf-8"), JsonUtils.toJson(build)))
            .build();
    // NOTE(review): an OkHttpClient owns a connection pool and dispatcher; consider sharing one
    // instance instead of building it per request.
    OkHttpClient okHttpClient = new OkHttpClient.Builder()
            .connectTimeout(10, TimeUnit.MINUTES)
            // Must be long, otherwise the stream drops right after connecting.
            .readTimeout(10, TimeUnit.MINUTES)
            .build();
    // Full answer, accumulated for the final log line.
    StringBuffer sb = new StringBuffer();
    EventSourceListener listener = new EventSourceListener() {
        @Override
        public void onOpen(EventSource eventSource, Response response) {
            log.info("onOpen");
        }

        @Override
        public void onEvent(EventSource eventSource, String id, String type, String data) {
            log.info(data); // raw upstream data (OkHttp already removed the "data:" prefix)
            // Check the end marker first: it is not JSON.
            if ("[DONE]".equals(data)) {
                log.info("result={}", sb);
                finish(data);
                return;
            }
            String content;
            try {
                ChatResultBO chatResultBO = JsonUtils.toObject(data, ChatResultBO.class);
                content = chatResultBO.getChoices().get(0).getDelta().getContent();
            } catch (Exception e) {
                log.warn("skip unparsable chunk: {}", data, e);
                return;
            }
            if (content == null) {
                // Chunks without text (presumably role/finish chunks); appending would add "null".
                return;
            }
            sb.append(content);
            try {
                emitter.send(SseEmitter.event().data(JsonUtils.toJson(ChatContentBO.builder().content(content).build())));
            } catch (java.io.IOException | IllegalStateException e) {
                // The client went away or the emitter is already done: stop the upstream call.
                log.info("cannot push to client, cancelling upstream: {}", e.getMessage());
                eventSource.cancel();
            }
        }

        @Override
        public void onClosed(EventSource eventSource) {
            log.info("onClosed,eventSource={}", eventSource);
        }

        @Override
        public void onFailure(EventSource eventSource, Throwable t, Response response) {
            // Passing t last logs its stack trace (t is null for an unsuccessful HTTP response).
            log.error("onFailure,response={}", response, t);
            // Without this the browser request would stay open with no further data.
            finish("[DONE]");
        }

        /** Sends {@code last} as the final event (best effort) and always completes the emitter. */
        private void finish(String last) {
            try {
                emitter.send(SseEmitter.event().data(last));
            } catch (java.io.IOException | IllegalStateException e) {
                log.info("could not deliver final event: {}", e.getMessage());
            }
            emitter.complete();
        }
    };
    // Creating the event source starts the asynchronous request. EventSources is the public
    // okhttp-sse entry point; RealEventSource lives in an internal package.
    EventSource upstream = okhttp3.sse.EventSources.createFactory(okHttpClient)
            .newEventSource(request, listener);
    // When the SSE response finishes for any reason (including client disconnect), stop the upstream call.
    emitter.onCompletion(upstream::cancel);
    return emitter;
}