1) 核心代码结构设计
2) 相关程序如下:
AbstractStreamListener.java
package com.xxx.gpt.client.listener;
import cn.hutool.core.util.StrUtil;
import com.alibaba.fastjson.JSON;
import com.xxx.gpt.client.entity.ChatChoice;
import com.xxx.gpt.client.entity.ChatCompletionResponse;
import com.xxx.gpt.client.entity.Message;
import lombok.Getter;
import lombok.Setter;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;
import okhttp3.Response;
import okhttp3.sse.EventSource;
import okhttp3.sse.EventSourceListener;
import java.util.List;
import java.util.Objects;
import java.util.function.Consumer;
@Slf4j
public abstract class AbstractStreamListener extends EventSourceListener {
protected String lastMessage = "";
/**
* Called when all new message are received.
*
* @param message the new message
*/
@Setter
@Getter
protected Consumer<String> onComplate = s -> {};
/**
* Called when a new message is received.
* 收到消息 单个字
*
* @param message the new message
*/
public abstract void onMsg(String message);
/**
* Called when an error occurs.
* 出错时调用
*
* @param throwable the throwable that caused the error
* @param response the response associated with the error, if any
*/
public abstract void onError(Throwable throwable, String response);
@Override
public void onOpen(EventSource eventSource, Response response) {
// do nothing
}
@Override
public void onClosed(EventSource eventSource) {
// do nothing
}
@Override
public void onEvent(EventSource eventSource, String id, String type, String data) {
if (data.equals("[DONE]")) {
onComplate.accept(lastMessage);
return;
}
// 将数据反序列化为 GPT的 response
ChatCompletionResponse response = JSON.parseObject(data, ChatCompletionResponse.class);
// 获取GPT的返回,读取Json
List<ChatChoice> choices = response.getChoices();
// 为空则 return
if (choices == null || choices.isEmpty()) {
return;
}
// 获取流式信息
Message delta = choices.get(0).getDelta();
String text = delta.getContent();
if (text != null) {
lastMessage += text;
onMsg(text);
}
}
@SneakyThrows
@Override
public void onFailure(EventSource eventSource, Throwable throwable, Response response) {
try {
log.error("Stream connection error: {}", throwable);
String responseText = "";
if (Objects.nonNull(response)) {
responseText = response.body().string();
}
log.error("response:{}", responseText);
String forbiddenText = "Your access was terminated due to violation of our policies";
if (StrUtil.contains(responseText, forbiddenText)) {
log.error("Chat session has been terminated due to policy violation");
log.error("检测到号被封了");
}
String overloadedText = "That model is currently overloaded with other requests.";
if (StrUtil.contains(responseText, overloadedText)) {
log.error("检测到官方超载了,赶紧优化你的代码,做重试吧");
}
this.onError(throwable, responseText);
} catch (Exception e) {
log.warn("onFailure error:{}", e);
// do nothing
} finally {
eventSource.cancel();
}
}
}
ConsoleStreamListener.java
package com.xxx.gpt.client.listener;
import lombok.extern.slf4j.Slf4j;
/**
 * Stream listener that echoes each received fragment straight to standard output, with no
 * separator or trailing newline, so the reply appears progressively on the console.
 */
@Slf4j
public class ConsoleStreamListener extends AbstractStreamListener {

    /**
     * Prints the fragment to {@code System.out} exactly as received.
     *
     * @param message the newly received fragment
     */
    @Override
    public void onMsg(String message) {
        System.out.print(message);
    }

    /**
     * Intentionally does nothing: this listener only renders successful output.
     *
     * @param throwable ignored
     * @param response ignored
     */
    @Override
    public void onError(Throwable throwable, String response) {
        // no-op by design
    }
}
ChatGPTStreamClient.java
package com.xxx.gpt.client;
import cn.hutool.core.util.RandomUtil;
import cn.hutool.http.ContentType;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.xxx.gpt.client.entity.ChatCompletion;
import com.xxx.gpt.client.entity.Message;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import okhttp3.MediaType;
import okhttp3.OkHttpClient;
import okhttp3.Request;
import okhttp3.RequestBody;
import okhttp3.sse.EventSource;
import okhttp3.sse.EventSourceListener;
import okhttp3.sse.EventSources;
import java.net.Proxy;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.TimeUnit;
@Slf4j
@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
public class ChatGPTStreamClient {
private String apiKey;
private List<String> apiKeyList;
private OkHttpClient okHttpClient;
/**
* 连接超时
*/
@Builder.Default
private long timeout = 90;
/**
* 网络代理
*/
@Builder.Default
private Proxy proxy = Proxy.NO_PROXY;
/**
* 反向代理
*/
@Builder.Default
private String apiHost = ChatApi.CHAT_GPT_API_HOST;
/**
* 初始化
*/
public ChatGPTStreamClient init() {
OkHttpClient.Builder client = new OkHttpClient.Builder();
client.connectTimeout(timeout, TimeUnit.SECONDS);
client.writeTimeout(timeout, TimeUnit.SECONDS);
client.readTimeout(timeout, TimeUnit.SECONDS);
if (Objects.nonNull(proxy)) {
client.proxy(proxy);
}
okHttpClient = client.build();
return this;
}
/**
* 流式输出
*/
public void streamChatCompletion(ChatCompletion chatCompletion,
EventSourceListener eventSourceListener) {
chatCompletion.setStream(true);
try {
EventSource.Factory factory = EventSources.createFactory(okHttpClient);
ObjectMapper mapper = new ObjectMapper();
String requestBody = mapper.writeValueAsString(chatCompletion);
String key = apiKey;
if (apiKeyList != null && !apiKeyList.isEmpty()) {
key = RandomUtil.randomEle(apiKeyList);
}
Request request = new Request.Builder()
.url(apiHost + "v1/chat/completions")
.post(RequestBody.create(MediaType.parse(ContentType.JSON.getValue()),
requestBody))
.header("Authorization", "Bearer " + key)
.build();
factory.newEventSource(request, eventSourceListener);
} catch (Exception e) {
log.error("请求出错:{}", e);
}
}
/**
* 流式输出
*/
public void streamChatCompletion(List<Message> messages,
EventSourceListener eventSourceListener) {
ChatCompletion chatCompletion = ChatCompletion.builder()
.messages(messages)
.stream(true)
.build();
streamChatCompletion(chatCompletion, eventSourceListener);
}
}
再添加一个测试类 StreamClientTest.java:
package com.xxx.gpt.client.test;
import static org.junit.Assert.assertTrue;

import com.xxx.gpt.client.ChatGPTStreamClient;
import com.xxx.gpt.client.entity.ChatCompletion;
import com.xxx.gpt.client.entity.Message;
import com.xxx.gpt.client.listener.ConsoleStreamListener;
import com.xxx.gpt.client.util.Proxys;
import org.junit.Before;
import org.junit.Test;
import java.net.Proxy;
import java.util.Arrays;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
public class StreamClientTest {
private ChatGPTStreamClient chatGPTStreamClient;
@Before
public void before() {
Proxy proxy = Proxys.http("127.0.0.1", 7890);
chatGPTStreamClient = ChatGPTStreamClient.builder()
.apiKey("sk-6kchadsfsfkc3aIs66ct") // 填入自己的 key
.proxy(proxy)
.timeout(600)
.apiHost("https://api.openai.com/")
.build()
.init();
}
@Test
public void chatCompletions() {
ConsoleStreamListener listener = new ConsoleStreamListener();
Message message = Message.of("写一段七言绝句诗");
ChatCompletion chatCompletion = ChatCompletion.builder()
.messages(Arrays.asList(message))
.build();
chatGPTStreamClient.streamChatCompletion(chatCompletion, listener);
try {
Thread.sleep(10000);
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
}
}