JAVA AI SDK 项目重构有感

之前折腾了一个 java 版本的 AI SDK

仓库地址
原介绍文章-JAVA SDK 整合 AI 大语言模型

一开始只整合了 GeminiOpenAI 两个版本,上周又给增加了 Ollama 版本,写着写着就发现项目中有好多重复的代码,而且结构很不清晰,实体创建也很混乱,于是最近几天就开始着手想重构一下整个项目,为日后新增新的模型的时候,就简单多了。

重复逻辑代码

REST API 调用一般过程:

将外部的用户对话内容,先整理为厂商模型 REST API 的请求数据格式,然后通过 HTTP 的方式请求其 API 接口,最后对返回的数据进行解析组装。

例如下面这段 OpenAI 调用代码:

public OpenAiTextResponse chat(String message, MaterialData materialData, OpenAiGenerationConfig generationConfig) throws IOException {

        if (this.openaiAccount == null || this.openaiAccount.getApiKey() == null || this.openaiAccount.getApiKey().isEmpty()) {
            throw new RuntimeException("gemini api key is empty");
        }

        if (this.openaiAccount.getBaseUrl() != null && !this.openaiAccount.getBaseUrl().isEmpty()) {
            this.BASE_URL = this.openaiAccount.getBaseUrl();
        }

        OpenAiTextRequest questParams = this.buildOpenAiTextRequest(message, materialData, history);

        if (generationConfig != null) {
            questParams.setTemperature(generationConfig.getTemperature());
            questParams.setMaxTokens(generationConfig.getMaxTokens());
            questParams.setTopP(generationConfig.getTopP());
            questParams.setN(generationConfig.getN());
            questParams.setStop(generationConfig.getStop());
        }

        MediaType json = MediaType.parse("application/json; charset=utf-8");
        RequestBody requestBody = RequestBody.create(json, JSON.toJSONString(questParams));

        String url = "{base_url}/v1/chat/completions";
        url = url.replace("{base_url}", this.BASE_URL);

        Request request = new Request.Builder()
                .url(url)
                .addHeader("Authorization", "Bearer " + this.openaiAccount.getApiKey())
                .post(requestBody)
                .build();

        Response response = this.okHttpClient.newCall(request).execute();
        if (response.isSuccessful()) {
            String responseBody = response.body().string();

            OpenAiTextResponse textResponse = JSON.parseObject(responseBody, OpenAiTextResponse.class);
            textResponse.setHistory(this.buildChatHistory(message, materialData, textResponse.getChoices(), history));
            return textResponse;
        }

        return null;
    }

如果现在需要再接入 Gemini 或者 Ollama,那么上面这段代码的大部分内容估计还要写一次。

所以,这个 chat 方法其实可以拆分出几个部分:

  • 整理用户对话内容为该厂商模型的数据结构
  • 提供厂商模型 REST API 接口地址和相关授权配置
  • 对请求后的返回内容进行解析处理
  • 对历史对话记录的处理

从伪代码来看,上面的 chat 方法可以调整为以下这样:

public AiChatResponse chat(String message, MediaData mediaData, GenerationConfig generationConfig, List<ChatHistory> history) throws Exception {
	
    //整理用户对话内容为该厂商模型的数据结构
    TextRequestParams questParams = buildTextRequest(message, mediaData, generationConfig, history);
    //提供厂商模型 REST API 接口地址和相关授权配置
    Request request = buildHttpRquest(apiUrl, account, model);
    
    Response response = this.okHttpClient.newCall(request).execute();
    if (!response.isSuccessful()) {
		throw new RuntimeException("接口请求异常");
    }
    
    String responseBody = response.body().string();
    
    //对请求后的返回内容进行解析处理
    AiChatResponse chatResponse = buildChatResponse(message, mediaData, responseBody);
    //对历史对话记录的处理
    buildChatHistory(message, mediaData, responseBody);
    
    return chatResponse;
}

而这些部分,不同的厂商有着不同的实现方式,所以这种场景比较适合使用 抽象类 来处理,其他厂商去继承这个 抽象类 即可,而拆出来的部分则为 抽象方法,这样继承的子类就必须去实现它。

定义抽象类

将上面拆出来的部分定义为 抽象方法,而 chat 方法内只需要实现请求的过程。

public abstract class AiBaseClient {
    
    public AiChatResponse chat(String message, MediaData mediaData, GenerationConfig generationConfig, List<ChatHistory> history) throws IOException {
        this.stream = false;
	    //构建统一请求
        Request request = this.buildHttpRequest(message, mediaData, generationConfig, false, history);
        
        Response response = this.okHttpClient.newCall(request).execute();
        if (response.isSuccessful()) {
            String responseBody = response.body().string();
		   //通过子类处理返回的消息内容和历史记录
            return this.buildChatResponse(responseBody, message, mediaData, history);
        }

        return null;
    }
    
    //构建统一请求
    private Request buildHttpRequest(String message, MediaData mediaData, GenerationConfig generationConfig, boolean stream, List<ChatHistory> history) {
        String baseUrl = this.getDefaultBaseUrl();
        if (this.getAccount() != null && this.getAccount().getBaseUrl() != null && !this.getAccount().getBaseUrl().isEmpty()) {
            baseUrl = this.getAccount().getBaseUrl();
        }
	    //通过子类获取请求参数
        JSONObject chatRequestParams = this.buildChatRequest(message, mediaData, generationConfig, stream, history);

        MediaType json = MediaType.parse("application/json; charset=utf-8");
        RequestBody requestBody = RequestBody.create(json, JSON.toJSONString(chatRequestParams));

        String url = baseUrl + this.getApi();

        return new Request.Builder()
                .url(url)
                .addHeader("Authorization", "Bearer " + this.getAccount().getApiKey())
                .post(requestBody)
                .build();
    }
    
	/**
     * 子类构建请求参数
     *
     * @param message
     * @param mediaData
     * @param generationConfig
     * @param stream
     * @param history
     * @return
     */
    protected abstract JSONObject buildChatRequest(String message, MediaData mediaData, GenerationConfig generationConfig, boolean stream, List<ChatHistory> history);

    /**
     * 子类构建返回内容
     *
     * @param responseBody
     * @param message
     * @param mediaData
     * @param history
     * @return
     */
    protected abstract AiChatResponse buildChatResponse(String responseBody, String message, MediaData mediaData, List<ChatHistory> history);
    
    /**
     * 子类提供模型名称
     *
     * @return
     */
    protected abstract String getDefaultModelName();

    /**
     * 子类提供base url
     *
     * @return
     */
    protected abstract String getDefaultBaseUrl();

    /**
     * 子类提供api地址
     *
     * 如:/v1/chat/completions
     *
     * @return
     */
    protected abstract String getApi();
}

子类实现

所以对于 OpenAI 模型,调整后只需要继承这个 抽象类,并实现里面的抽象方法即可。

public class OpenAiClient extends AiBaseClient {

    @Override
    protected JSONObject buildChatRequest(String message, MediaData mediaData, GenerationConfig generationConfig, boolean stream, List<ChatHistory> history) {
    	//封装 OpenAI 的请求参数格式
    }
    
    @Override
    protected AiChatResponse buildChatResponse(String responseBody, String message, MediaData mediaData, List<ChatHistory> history) {
    	//对模型接口返回的 responseBody 内容进行解析,并封装为统一的 AiChatResponse 对象并返回
        
        //处理当前的对话历史记录 history,并通过 AiChatResponse 带回
    }
    
    @Override
    protected String getDefaultModelName() {

        return OpenAiModelEnum.GPT_35_TURBO.getName();
    }

    @Override
    protected String getDefaultBaseUrl() {

        return "https://api.openai.com";
    }

    @Override
    protected String getApi() {

        return "/v1/chat/completions";
    }

}

这样处理以后,后面的 Gemini 模型和 Ollama 就可以快速的实现了。
目前项目已经更新到 0.2.0 版本

<dependency>
    <groupId>org.liurb.ai.sdkgroupId>
    <artifactId>java-ai-sdkartifactId>
    <version>${version}version>
dependency>

你可能感兴趣的:(AI学习,java,人工智能,重构)