java模拟GPT流式问答

流式请求GPT,并将答案流式推送到前端页面

1)java流式获取gpt答案

1、读取HTTP响应流的方式

使用POST请求数据。由于GPT以EventSource(SSE)格式返回数据,每一行形如 `data: {...}`,需要手动去掉行首的 `data:` 前缀后再解析JSON;流的最后一行是结束标记 `data: [DONE]`。

/**
 * Streams a chat completion from the OpenAI API with Apache HttpClient
 * (classes from org.apache.http.client.methods) and logs every content fragment as it arrives.
 *
 * The response is an event stream: each payload line has the form {@code data: <json>}
 * and the stream ends with {@code data: [DONE]}.
 *
 * @param messagesBOList conversation messages sent as the request's {@code messages}
 */
@SneakyThrows
    private void chatStream(List messagesBOList) {
        ChatParamBO build = ChatParamBO.builder()
                .temperature(0.7)
                .model("gpt-3.5-turbo")
                .messages(messagesBOList)
                .stream(true)
                .build();
        String requestBody = JsonUtils.toJson(build);
        log.debug("chat request: {}", requestBody);

        HttpPost httpPost = new HttpPost("https://api.openai.com/v1/chat/completions");
        // NOTE(review): the OpenAI API expects the value "Bearer <api key>" here
        httpPost.setHeader("Authorization","xxxxxxxxxxxx");
        httpPost.setHeader("Content-Type","application/json; charset=UTF-8");
        httpPost.setEntity(new StringEntity(requestBody, "utf-8"));

        // try-with-resources closes the client as well as the response (the client used to leak)
        try (CloseableHttpClient httpclient = HttpClients.createDefault();
             CloseableHttpResponse response = httpclient.execute(httpPost)) {
            HttpEntity entity = response.getEntity();
            if (entity == null) {
                return;
            }
            int status = response.getStatusLine().getStatusCode();
            if (status < 200 || status >= 300) {
                // an error answer is a plain JSON body, not an event stream: report it instead of skipping it
                log.error("chat request failed, status={}, body={}", status,
                        org.apache.http.util.EntityUtils.toString(entity, java.nio.charset.StandardCharsets.UTF_8));
                return;
            }
            // decode explicitly as UTF-8; the platform default charset would garble non-ASCII answers
            try (BufferedReader reader = new BufferedReader(
                    new InputStreamReader(entity.getContent(), java.nio.charset.StandardCharsets.UTF_8))) {
                String line;
                while ((line = reader.readLine()) != null) {
                    // only "data:" lines carry a payload; blank event separators are skipped
                    if (!line.startsWith("data:")) {
                        continue;
                    }
                    // strip only the leading prefix: replace("data:", "") would also remove
                    // occurrences of "data:" inside the generated text
                    String payload = line.substring("data:".length()).trim();
                    if ("[DONE]".equals(payload)) {
                        break;
                    }
                    try {
                        ChatResultBO chatResultBO = JsonUtils.toObject(payload, ChatResultBO.class);
                        String content = chatResultBO.getChoices().get(0).getDelta().getContent();
                        // some chunks may carry no text; don't log them as "null"
                        if (content != null) {
                            log.info(content);
                        }
                    } catch (Exception e) {
                        // best effort: one malformed chunk must not abort the whole stream
                        log.warn("failed to parse stream chunk: {}", payload, e);
                    }
                }
            }
        }
    }

2、通过SSE连接的方式获取数据

用到了okhttp

需要先引入相关的Maven依赖:

        
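        <dependency>
            <groupId>com.squareup.okhttp3</groupId>
            <artifactId>okhttp</artifactId>
        </dependency>
        <dependency>
            <groupId>com.squareup.okhttp3</groupId>
            <artifactId>okhttp-sse</artifactId>
        </dependency>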
        
       
        // 定义see接口
        Request request = new Request.Builder().url("https://api.openai.com/v1/chat/completions")
                .header("Authorization","xxx")
                .post(okhttp3.RequestBody.create(okhttp3.MediaType.parse("application/json; charset=utf-8"),param.toJSONString()))
                .build();
        OkHttpClient okHttpClient = new OkHttpClient.Builder()
                .connectTimeout(10, TimeUnit.MINUTES)
                .readTimeout(10, TimeUnit.MINUTES)//这边需要将超时显示设置长一点,不然刚连上就断开,之前以为调用方式错误被坑了半天
                .build();

        // 实例化EventSource,注册EventSource监听器
        RealEventSource realEventSource = new RealEventSource(request, new EventSourceListener() {

            @Override
            public void onOpen(EventSource eventSource, Response response) {
                log.info("onOpen");
            }

            @SneakyThrows
            @Override
            public void onEvent(EventSource eventSource, String id, String type, String data) {
//                log.info("onEvent");
                log.info(data);//请求到的数据
            
            }

            @Override
            public void onClosed(EventSource eventSource) {
                log.info("onClosed");
//                emitter.complete();
            }

            @Override
            public void onFailure(EventSource eventSource, Throwable t, Response response) {
                log.info("onFailure,t={},response={}",t,response);//这边可以监听并重新打开
//                emitter.complete();
            }
        });
        realEventSource.connect(okHttpClient);//真正开始请求的一步

2)流式推送答案

方法一:通过订阅式SSE/WebSocket

原理是先建立连接,之后服务端通过这条连接不断推送消息即可。

1、websocket

创建相关配置:


import javax.websocket.Session;

import lombok.Data;

/**
 * @description WebSocket client connection: holds the session and request URI of one connected client.
 * Lombok {@code @Data} generates the getters/setters (and equals/hashCode/toString) for both fields.
 */
@Data
public class WebSocketClient {

    // Session of the connection to one client; data is sent to that client through it
    private Session session;

    // Request URI the client connected with
    private String uri;

}


import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.socket.server.standard.ServerEndpointExporter;

@Configuration
public class WebSocketConfig {

    /**
     * Exposes a {@link ServerEndpointExporter} bean so that classes annotated with
     * {@code @ServerEndpoint} get registered with the container's standard WebSocket runtime.
     *
     * @return a new exporter instance
     */
    @Bean
    public ServerEndpointExporter serverEndpointExporter() {
        ServerEndpointExporter exporter = new ServerEndpointExporter();
        return exporter;
    }
}
配置相关service


@Slf4j
@Component
@ServerEndpoint("/websocket/chat/{chatId}")
public class ChatWebsocketService {

    static final ConcurrentHashMap> webSocketClientMap= new ConcurrentHashMap<>();

    private String chatId;

    /**
     * 连接建立成功时触发,绑定参数
     * @param session 与某个客户端的连接会话,需要通过它来给客户端发送数据
     * @param chatId 商户ID
     */
    @OnOpen
    public void onOpen(Session session, @PathParam("chatId") String chatId){

        WebSocketClient client = new WebSocketClient();
        client.setSession(session);
        client.setUri(session.getRequestURI().toString());

        List webSocketClientList = webSocketClientMap.get(chatId);
        if(webSocketClientList == null){
            webSocketClientList = new ArrayList<>();
        }
        webSocketClientList.add(client);
        webSocketClientMap.put(chatId, webSocketClientList);
        this.chatId = chatId;
    }

    /**
     * 收到客户端消息后调用的方法
     *
     * @param message 客户端发送过来的消息
     */
    @OnMessage
    public void onMessage(String message) {
        log.info("chatId = {},message = {}",chatId,message);
        // 回复消息
        this.chatStream(BaseUtil.newList(ChatParamMessagesBO.builder().content(message).role("user").build()));
//        this.sendMessage(chatId,message+"233");
    }

    /**
     * 连接关闭时触发,注意不能向客户端发送消息了
     * @param chatId
     */
    @OnClose
    public void onClose(@PathParam("chatId") String chatId){
        webSocketClientMap.remove(chatId);
    }

    /**
     * 通信发生错误时触发
     * @param session
     * @param error
     */
    @OnError
    public void onError(Session session, Throwable error) {
        System.out.println("发生错误");
        error.printStackTrace();
    }

    /**
     * 向客户端发送消息
     * @param chatId
     * @param message
     */
    public void sendMessage(String chatId,String message){
        try {
            List webSocketClientList = webSocketClientMap.get(chatId);
            if(webSocketClientList!=null){
                for(WebSocketClient webSocketServer:webSocketClientList){
                    webSocketServer.getSession().getBasicRemote().sendText(message);
                }
            }
        } catch (IOException e) {
            e.printStackTrace();
            throw new RuntimeException(e.getMessage());
        }
    }

    /**
     * 流式调用查询gpt
     * @param messagesBOList
     * @throws IOException
     */
    @SneakyThrows
    private void chatStream(List messagesBOList) {
       // TODO 和GPT的访问请求
    }
}
测试:使用Postman建立连接

(图1:使用Postman建立WebSocket连接的测试截图)

2、SSE

本质也是基于订阅推送方式

前端:




    
    SseEmitter




后端:
controller


import org.springframework.cloud.context.config.annotation.RefreshScope;
import org.springframework.validation.annotation.Validated;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;

import java.util.Set;
import java.util.function.Consumer;

import javax.annotation.Resource;

import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;

@Validated
@RestController
@RequestMapping("/api/sse")
@Slf4j
@RefreshScope  // Spring Cloud: the bean is re-created when the configuration is refreshed
public class SseController {

    @Resource
    private SseBizService sseBizService;


    /**
     * Creates an SSE connection for a conversation and returns its SseEmitter.
     *
     * NOTE(review): connections are keyed by conversationId only. If the same id connects again, the
     * cached emitter is replaced and the old emitter's completion callback later removes the new one.
     * Confirm ids are unique per client.
     *
     * @param conversationId conversation (user) id the connection is registered under
     * @return the emitter Spring MVC streams to the client
     */
    @SneakyThrows
    @GetMapping(value = "/connect", produces = "text/event-stream; charset=utf-8")
    public SseEmitter connect(String conversationId) {
        // Timeout 0 means the emitter never expires. The default is 30 seconds; when it elapses
        // without completion an AsyncRequestTimeoutException is raised.
        SseEmitter sseEmitter = new SseEmitter(0L);
        // remove the connection from the cache however it ends
        sseEmitter.onCompletion(completionCallBack(conversationId));
        sseEmitter.onError(errorCallBack(conversationId));
        sseEmitter.onTimeout(timeoutCallBack(conversationId));
        log.info("创建新的sse连接,当前用户:{}", conversationId);
        sseBizService.addConnect(conversationId,sseEmitter);
        sseBizService.sendMsg(conversationId,"链接成功");
//        sseCache.get(conversationId).send(SseEmitter.event().reconnectTime(10000).data("链接成功"),MediaType.TEXT_EVENT_STREAM);
        return sseEmitter;
    }

    /**
     * Sends a message to one conversation -- unicast.
     *
     * @param conversationId target conversation id
     * @param msg            message text
     */
    @GetMapping(value = "/send", produces = "text/event-stream; charset=utf-8")
    public void sendMessage(String conversationId, String msg) {
        sseBizService.sendMsg(conversationId,msg);
    }

    /**
     * Removes a conversation's connection from the cache.
     * NOTE(review): the emitter itself is not completed here, so the HTTP stream stays open until
     * the client closes it.
     *
     * @param conversationId conversation id to remove
     */
    @GetMapping(value = "/disconnection", produces = "text/event-stream; charset=utf-8")
    public void removeUser(String conversationId) {
        log.info("移除用户:{}", conversationId);
        sseBizService.deleteConnect(conversationId);
    }

    /**
     * Publishes a message to every conversation whose id starts with a prefix -- multicast.
     * Currently a no-op: the implementation below is commented out.
     *
     * @param groupId id prefix
     * @param message message content
     */
    public void groupSendMessage(String groupId, String message) {
       /* if (!BaseUtil.isNullOrEmpty(sseCache)) {
            *//*Set ids = sseEmitterMap.keySet().stream().filter(m -> m.startsWith(groupId)).collect(Collectors.toSet());
            batchSendMessage(message, ids);*//*
            sseCache.forEach((k, v) -> {
                try {
                    if (k.startsWith(groupId)) {
                        v.send(message, MediaType.APPLICATION_JSON);
                    }
                } catch (IOException e) {
                    log.error("用户[{}]推送异常:{}", k, e.getMessage());
                    removeUser(k);
                }
            });
        }*/
    }

    /**
     * Sends a message to everybody -- broadcast.
     * Currently a no-op: the implementation below is commented out.
     *
     * @param message message content
     */
    public void batchSendMessage(String message) {
        /*sseCache.forEach((k, v) -> {
            try {
                v.send(message, MediaType.APPLICATION_JSON);
            } catch (IOException e) {
                log.error("用户[{}]推送异常:{}", k, e.getMessage());
                removeUser(k);
            }
        });*/
    }

    /**
     * Sends a message to each of the given conversations.
     *
     * @param message message content
     * @param ids     target conversation ids (typed so the lambda sees Strings; a raw Set did not compile)
     */
    public void batchSendMessage(String message, Set<String> ids) {
        ids.forEach(userId -> sendMessage(userId, message));
    }


    /**
     * Returns the current connection ids.
     */
//    public List getIds() {
//        return new ArrayList<>(sseCache.keySet());
//    }

    /**
     * Returns the current connection count.
     */
//    public int getUserCount() {
//        return count.intValue();
//    }

    /** Callback run when the emitter completes: drops the connection from the cache. */
    private Runnable completionCallBack(String userId) {
        return () -> {
            log.info("结束连接:{}", userId);
            removeUser(userId);
        };
    }

    /** Callback run when the emitter times out: drops the connection from the cache. */
    private Runnable timeoutCallBack(String userId) {
        return () -> {
            log.info("连接超时:{}", userId);
            removeUser(userId);
        };
    }

    /** Callback run when the emitter fails: logs the cause and drops the connection from the cache. */
    private Consumer<Throwable> errorCallBack(String userId) {
        return throwable -> {
            // pass the throwable so the cause is not lost
            log.warn("连接异常:{}", userId, throwable);
            removeUser(userId);
        };
    }
}
service


import org.springframework.cloud.context.config.annotation.RefreshScope;
import org.springframework.http.MediaType;
import org.springframework.stereotype.Component;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;

import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;

import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;

@Component
@Slf4j
// Deliberately not @RefreshScope: a refresh-scoped bean is re-created after a configuration
// refresh, which would discard every cached connection held below.
public class SseBizService {
    /**
     * Current number of cached connections.
     */
    private final AtomicInteger count = new AtomicInteger(0);

    /**
     * Emitters keyed by user/conversation id, so the SseEmitter of an id can be looked up
     * (could also be moved to Redis).
     */
    private final Map<String, SseEmitter> sseCache = new ConcurrentHashMap<>();


    /**
     * Adds a connection. Replacing an existing id does not change the count.
     * @author pengbin
     * @date 2023/9/11 11:37
     * @param id         user/conversation id
     * @param sseEmitter emitter of the connection
     */
    public void addConnect(String id,SseEmitter sseEmitter){
        // put() returns the previous emitter; only a brand-new id increases the count
        if (sseCache.put(id, sseEmitter) == null) {
            count.getAndIncrement();
        }
    }

    /**
     * Removes a connection. Removing an unknown id is a no-op.
     * @author pengbin
     * @date 2023/9/11 11:37
     * @param id user/conversation id
     */
    public void deleteConnect(String id){
        // count -1 only when something was removed: an explicit disconnect followed by the
        // emitter's completion callback would otherwise decrement twice
        if (sseCache.remove(id) != null) {
            count.getAndDecrement();
        }
    }

    /**
     * Returns the number of cached connections.
     * @return current connection count
     */
    public int getConnectCount() {
        return count.get();
    }

    /**
     * Sends a message to the connection of an id, if it exists.
     * @author pengbin
     * @date 2023/9/11 11:38
     * @param id  user/conversation id
     * @param msg message text
     */
    @SneakyThrows
    public void sendMsg(String id, String msg){
        // single lookup: containsKey() followed by get() could see a concurrent removal and NPE
        SseEmitter emitter = sseCache.get(id);
        if (emitter != null) {
            emitter.send(msg, MediaType.TEXT_EVENT_STREAM);
        }
    }

}

方法二:每次问答建立一个SSE EventSource,使用完成后立即销毁

前端:在接收到结束标识后立即销毁

/**
 * Handles each message pushed by the server.
 * Equivalent shorthand: source.onmessage = function (event) {}
 * The server's last message is the end marker "[DONE]": close the EventSource then (otherwise
 * the browser reconnects automatically) and do not render the marker itself.
 */
source.addEventListener('message', function (e) {
    if (e.data === '[DONE]') {
        source.close();
        return;
    }
    setMessageInnerHTML(e.data);
});

后端:
 

@SneakyThrows
    @GetMapping(value = "/stream/sse", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
    /**
     * Streams one GPT answer to the browser over SSE and completes the emitter when it ends.
     * Each content fragment is sent as a JSON {@code ChatContentBO}. The final event is the
     * literal "[DONE]", after which the emitter is completed.
     *
     * @param conversationId conversation id
     * @return the emitter Spring MVC streams to the browser
     */
    public SseEmitter completionsStream(@RequestParam String conversationId){
        // TODO(review): conversationId is unused; the message list presumably has to be filled
        // with this conversation's history before calling the API
        List messagesBOList = new ArrayList();

        // request parameters
        ChatParamBO build = ChatParamBO.builder()
                .temperature(0.7)
                .stream(true)
                .model("xxxx")
                .messages(messagesBOList)
                .build();

        // explicit timeout matching the OkHttp read timeout; the default async request timeout
        // could cut a long answer short
        SseEmitter emitter = new SseEmitter(TimeUnit.MINUTES.toMillis(10));

        // Define the SSE request
        Request request = new Request.Builder().url("xxx")
                .header("Authorization","xxxx")
                .post(okhttp3.RequestBody.create(okhttp3.MediaType.parse("application/json; charset=utf-8"),JsonUtils.toJson(build)))
                .build();
        OkHttpClient okHttpClient = new OkHttpClient.Builder()
                .connectTimeout(10, TimeUnit.MINUTES)
                // keep the read timeout long, otherwise the stream is dropped right after connecting
                .readTimeout(10, TimeUnit.MINUTES)
                .build();

        EventSourceListener listener = new EventSourceListener() {
            /** Full answer, accumulated for logging. */
            private final StringBuilder answer = new StringBuilder();
            /** Set once the browser-side emitter has been completed or has failed. */
            private volatile boolean finished;

            @Override
            public void onOpen(EventSource eventSource, Response response) {
                log.info("onOpen");
            }

            @SneakyThrows
            @Override
            public void onEvent(EventSource eventSource, String id, String type, String data) {
                log.info(data);// received data
                // OkHttp has already removed the "data:" field name, so data is JSON or the end marker;
                // check the marker first instead of letting the JSON parser fail on it
                if ("[DONE]".equals(data)) {
                    forward(eventSource, data);
                    finish(null);
                    log.info("result={}", answer);
                    return;
                }
                String content;
                try {
                    ChatResultBO chatResultBO = JsonUtils.toObject(data, ChatResultBO.class);
                    content = chatResultBO.getChoices().get(0).getDelta().getContent();
                } catch (Exception e) {
                    // best effort: skip a malformed chunk but keep streaming
                    log.warn("failed to parse stream chunk: {}", data, e);
                    return;
                }
                // chunks without text would otherwise be sent (and appended) as "null"
                if (content == null) {
                    return;
                }
                answer.append(content);
                forward(eventSource, JsonUtils.toJson(ChatContentBO.builder().content(content).build()));
            }

            @Override
            public void onClosed(EventSource eventSource) {
                log.info("onClosed,eventSource={}",eventSource);
                // upstream ended (possibly without [DONE]): release the browser connection
                finish(null);
            }

            @Override
            public void onFailure(EventSource eventSource, Throwable t, Response response) {
                log.error("onFailure,response={}", response, t);
                // without this the browser would wait until the emitter times out
                finish(t);
            }

            /** Sends one event to the browser; if the browser is gone, stops the upstream stream. */
            private void forward(EventSource eventSource, String payload) {
                if (finished) {
                    return;
                }
                try {
                    emitter.send(SseEmitter.event().data(payload));
                } catch (java.io.IOException | IllegalStateException e) {
                    log.warn("client disconnected, cancelling upstream stream", e);
                    finished = true;
                    eventSource.cancel();
                }
            }

            /** Completes the browser-side emitter once, with an error when one is given. */
            private void finish(Throwable error) {
                if (finished) {
                    return;
                }
                finished = true;
                if (error == null) {
                    emitter.complete();
                } else {
                    emitter.completeWithError(error);
                }
            }
        };
        // EventSources is OkHttp's public factory (RealEventSource is internal API);
        // newEventSource() starts the request immediately
        EventSource eventSource = okhttp3.sse.EventSources.createFactory(okHttpClient).newEventSource(request, listener);
        // stop the upstream request when the browser side ends first
        emitter.onTimeout(() -> {
            eventSource.cancel();
            emitter.complete();
        });
        emitter.onError(error -> eventSource.cancel());
        emitter.onCompletion(eventSource::cancel);
        return emitter;
    }

你可能感兴趣的:(gpt,java,sse,eventSource,流式请求,stream)