AIGC: 关于ChatGPT中基于API实现一个StreamClient流式客户端

Java版GPT的StreamClient

  • 可作为其他编程语言的参考
  • 注意: 下面包名中的 xxx 可以换成自己的
  • 代码基于java,来源于网络,可修改成其他编程语言实现
  • 参考前文: https://blog.csdn.net/Tyro_java/article/details/134748994

1 )核心代码结构设计

  • src
    • main
      • java
        • com.xxx.gpt.client
          • listener
            • AbstractStreamListener.java
            • ConsoleStreamListener.java
          • ChatGPTStreamClient.java
    • test
      • java
        • com.xxx.gpt.client.test
          • StreamClientTest.java

2 )相关程序如下

  • 在前文中,我们开发的 Client 已经能够正常地与 OpenAI 交互,调用 GPT 的 API
  • 通过API将我们的 message 请求发送给GPT并且正常的接收到了GPT对我们的返回
  • 在前面浏览 GPT 的 API 文档时,我们发现它支持流式访问(请求参数 stream=true)
  • 我们可以开发一个Stream的Client,能够支持流式的接收GPT的响应
  • 流式的Client在很多场景下也是非常有必要的
  • 首先需要去先创建一个listener,去流式的接收GPT的返回, 我们实现一个 AbstractStreamListener 类 和 ConsoleStreamListener 类
  • AbstractStreamListener 需要继承 OkHttp 的 EventSourceListener,并重写以下几个回调方法
    • onOpen
    • onEvent, 可以逐条获取到流式返回的数据,这个是重点
    • onClosed
    • onFailure

AbstractStreamListener.java

package com.xxx.gpt.client.listener;

import cn.hutool.core.util.StrUtil;
import com.alibaba.fastjson.JSON;
import com.xxx.gpt.client.entity.ChatChoice;
import com.xxx.gpt.client.entity.ChatCompletionResponse;
import com.xxx.gpt.client.entity.Message;
import lombok.Getter;
import lombok.Setter;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;
import okhttp3.Response;
import okhttp3.sse.EventSource;
import okhttp3.sse.EventSourceListener;

import java.util.List;
import java.util.Objects;
import java.util.function.Consumer;

@Slf4j
public abstract class AbstractStreamListener extends EventSourceListener {
    protected String lastMessage = "";

    /**
     * Called when all new message are received.
     *
     * @param message the new message
     */
    @Setter
    @Getter
    protected Consumer<String> onComplate = s -> {};

    /**
     * Called when a new message is received.
     * 收到消息 单个字
     *
     * @param message the new message
     */
    public abstract void onMsg(String message);
    /**
     * Called when an error occurs.
     * 出错时调用
     *
     * @param throwable the throwable that caused the error
     * @param response  the response associated with the error, if any
     */
    public abstract void onError(Throwable throwable, String response);

    @Override
    public void onOpen(EventSource eventSource, Response response) {
        // do nothing
    }

    @Override
    public void onClosed(EventSource eventSource) {
        // do nothing
    }

    @Override
    public void onEvent(EventSource eventSource, String id, String type, String data) {
        if (data.equals("[DONE]")) {
            onComplate.accept(lastMessage);
            return;
        }
        // 将数据反序列化为 GPT的 response
        ChatCompletionResponse response = JSON.parseObject(data, ChatCompletionResponse.class);
        // 获取GPT的返回,读取Json
        List<ChatChoice> choices = response.getChoices();
        // 为空则 return
        if (choices == null || choices.isEmpty()) {
            return;
        }
        // 获取流式信息
        Message delta = choices.get(0).getDelta();
        String text = delta.getContent();
        if (text != null) {
            lastMessage += text;
            onMsg(text);
        }
    }

    @SneakyThrows
    @Override
    public void onFailure(EventSource eventSource, Throwable throwable, Response response) {
        try {
            log.error("Stream connection error: {}", throwable);
            String responseText = "";
            if (Objects.nonNull(response)) {
                responseText = response.body().string();
            }
            log.error("response:{}", responseText);
            String forbiddenText = "Your access was terminated due to violation of our policies";
            if (StrUtil.contains(responseText, forbiddenText)) {
                log.error("Chat session has been terminated due to policy violation");
                log.error("检测到号被封了");
            }
            String overloadedText = "That model is currently overloaded with other requests.";
            if (StrUtil.contains(responseText, overloadedText)) {
                log.error("检测到官方超载了,赶紧优化你的代码,做重试吧");
            }
            this.onError(throwable, responseText);
        } catch (Exception e) {
            log.warn("onFailure error:{}", e);
            // do nothing
        } finally {
            eventSource.cancel();
        }
    }
}

ConsoleStreamListener.java

package com.xxx.gpt.client.listener;

import lombok.extern.slf4j.Slf4j;

@Slf4j
public class ConsoleStreamListener extends AbstractStreamListener {

    /**
     * Echoes each streamed fragment straight to standard output with no separator,
     * so the answer appears on the console as it is produced.
     *
     * @param message the fragment just received
     */
    @Override
    public void onMsg(String message) {
        System.out.append(message);
    }

    /**
     * Intentionally a no-op. The base class's {@code onFailure} already logs the
     * throwable and the response text before delegating here.
     *
     * @param throwable ignored
     * @param response  ignored
     */
    @Override
    public void onError(Throwable throwable, String response) {
        // nothing to do for console output
    }
}
  • 再创建一个 ChatGPTStreamClient 类
    • 添加如下相关属性
    • 添加 init 方法
    • 完成 streamChatCompletion 方法

ChatGPTStreamClient.java

package com.xxx.gpt.client;

import cn.hutool.core.util.RandomUtil;
import cn.hutool.http.ContentType;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.xxx.gpt.client.entity.ChatCompletion;
import com.xxx.gpt.client.entity.Message;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import okhttp3.MediaType;
import okhttp3.OkHttpClient;
import okhttp3.Request;
import okhttp3.RequestBody;
import okhttp3.sse.EventSource;
import okhttp3.sse.EventSourceListener;
import okhttp3.sse.EventSources;

import java.net.Proxy;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.TimeUnit;

@Slf4j
@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
public class ChatGPTStreamClient {
    private String apiKey;
    private List<String> apiKeyList;
    private OkHttpClient okHttpClient;
    /**
     * 连接超时
     */
    @Builder.Default
    private long timeout = 90;

    /**
     * 网络代理
     */
    @Builder.Default
    private Proxy proxy = Proxy.NO_PROXY;
    /**
     * 反向代理
     */
    @Builder.Default
    private String apiHost = ChatApi.CHAT_GPT_API_HOST;

    /**
     * 初始化
     */
    public ChatGPTStreamClient init() {
        OkHttpClient.Builder client = new OkHttpClient.Builder();
        client.connectTimeout(timeout, TimeUnit.SECONDS);
        client.writeTimeout(timeout, TimeUnit.SECONDS);
        client.readTimeout(timeout, TimeUnit.SECONDS);
        if (Objects.nonNull(proxy)) {
            client.proxy(proxy);
        }
        okHttpClient = client.build();
        return this;
    }

    /**
     * 流式输出
     */
    public void streamChatCompletion(ChatCompletion chatCompletion,
                                     EventSourceListener eventSourceListener) {
        chatCompletion.setStream(true);
        try {
            EventSource.Factory factory = EventSources.createFactory(okHttpClient);
            ObjectMapper mapper = new ObjectMapper();
            String requestBody = mapper.writeValueAsString(chatCompletion);
            String key = apiKey;
            if (apiKeyList != null && !apiKeyList.isEmpty()) {
                key = RandomUtil.randomEle(apiKeyList);
            }
            Request request = new Request.Builder()
                    .url(apiHost + "v1/chat/completions")
                    .post(RequestBody.create(MediaType.parse(ContentType.JSON.getValue()),
                            requestBody))
                    .header("Authorization", "Bearer " + key)
                    .build();
            factory.newEventSource(request, eventSourceListener);
        } catch (Exception e) {
            log.error("请求出错:{}", e);
        }
    }

    /**
     * 流式输出
     */
    public void streamChatCompletion(List<Message> messages,
                                     EventSourceListener eventSourceListener) {
        ChatCompletion chatCompletion = ChatCompletion.builder()
                .messages(messages)
                .stream(true)
                .build();
        streamChatCompletion(chatCompletion, eventSourceListener);
    }
}

再添加一个测试类 StreamClientTest.java

package com.xxx.gpt.client.test;

import com.xxx.gpt.client.ChatGPTStreamClient;
import com.xxx.gpt.client.entity.ChatCompletion;
import com.xxx.gpt.client.entity.Message;
import com.xxx.gpt.client.listener.ConsoleStreamListener;
import com.xxx.gpt.client.util.Proxys;
import org.junit.Before;
import org.junit.Test;

import java.net.Proxy;
import java.util.Arrays;
import java.util.concurrent.CountDownLatch;

public class StreamClientTest {
    private ChatGPTStreamClient chatGPTStreamClient;

    @Before
    public void before() {
        Proxy proxy = Proxys.http("127.0.0.1", 7890);
        chatGPTStreamClient = ChatGPTStreamClient.builder()
                .apiKey("sk-6kchadsfsfkc3aIs66ct") // 填入自己的 key
                .proxy(proxy)
                .timeout(600)
                .apiHost("https://api.openai.com/")
                .build()
                .init();
    }
    @Test
    public void chatCompletions() {
        ConsoleStreamListener listener = new ConsoleStreamListener();
        Message message = Message.of("写一段七言绝句诗");
        ChatCompletion chatCompletion = ChatCompletion.builder()
                .messages(Arrays.asList(message))
                .build();
        chatGPTStreamClient.streamChatCompletion(chatCompletion, listener);
        try {
            Thread.sleep(10000);
        } catch (InterruptedException e) {
            throw new RuntimeException(e);
        }
    }
}
  • 这样,程序基本已经完成了
  • 这里构建了一个流式访问的参数,然后去调用GPT的API实现了流式的输出
  • 可参考以上 Java 版实现, 去实现其他语言版本的 StreamClient

你可能感兴趣的:(AIGC,Java,AIGC)