Flutter+SpringBoot实现ChatGPT流式输出

Flutter+SpringBoot实现ChatGPT流式输出、上下文连续对话

最终实现Flutter的流式输出+上下文连续对话。
Flutter+SpringBoot实现ChatGPT流式输出_第1张图片

这里只提供一个简单版的工具类和使用案例,页面部分仅供参考。

服务端

这里直接封装了一个工具类,把apiKey改成自己的即可使用,支持上下文连续对话。

工具类及使用

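http依赖这里使用OkHttp。注意:下面工具类的代码使用的是OkHttp 2.x的API(包名为com.squareup.okhttp,并通过client.setConnectTimeout等方法设置超时),所以依赖也必须是2.x版本;OkHttp 3/4(com.squareup.okhttp3)的包名和API都不同,不能直接替换。

    <dependency>
      <groupId>com.squareup.okhttp</groupId>
      <artifactId>okhttp</artifactId>
      <version>2.7.5</version>
    </dependency>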
import com.alibaba.fastjson2.JSON;
import com.squareup.okhttp.Call;
import com.squareup.okhttp.MediaType;
import com.squareup.okhttp.OkHttpClient;
import com.squareup.okhttp.Request;
import com.squareup.okhttp.RequestBody;
import com.squareup.okhttp.Response;
import com.squareup.okhttp.ResponseBody;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.Getter;
import lombok.NoArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component;
import org.springframework.util.StringUtils;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import vip.ailtw.common.utils.StringUtil;

import javax.annotation.PostConstruct;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.Serializable;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.TimeUnit;
import java.util.function.Consumer;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

@Slf4j
@Component
public class ChatGptStreamUtil {

    /**
     * 修改为自己的密钥
     */
    private final String apiKey = "xxxxxxxxxxxxxx";

    public final String gptCompletionsUrl = "https://api.openai.com/v1/chat/completions";


    private static final OkHttpClient client = new OkHttpClient();
    private static MediaType mediaType;
    private static Request.Builder requestBuilder;


    public final static Pattern contentPattern = Pattern.compile("\"content\":\"(.*?)\"}");
    /**
     * 对话符号
     */
    public final static String EVENT_DATA = "d";

    /**
     * 错误结束符号
     */
    public final static String EVENT_ERROR = "e";

    /**
     * 响应结束符号
     */
    public final static String END = "<>";


    @PostConstruct
    public void init() {
        client.setConnectTimeout(60, TimeUnit.SECONDS);
        client.setReadTimeout(60, TimeUnit.SECONDS);
        mediaType = MediaType.parse("application/json; charset=utf-8");
        requestBuilder = new Request.Builder()
                .url(gptCompletionsUrl)
                .header("Content-Type", "application/json")
                .header("Authorization", "Bearer " + apiKey);
    }


    /**
     * 流式对话
     *
     * @param talkList 上下文对话,最早的对话放在首位
     * @param callable 消费者,流式对话每次响应的内容
     */
    public GptChatResultDTO chatStream(List<ChatGptDTO> talkList, Consumer<String> callable) throws Exception {
        long start = System.currentTimeMillis();
        StringBuilder resp = new StringBuilder();
        Response response = chatStream(talkList);
        //解析对话内容
        try (ResponseBody responseBody = response.body();
             InputStream inputStream = responseBody.byteStream();
             BufferedReader bufferedReader = new BufferedReader(new InputStreamReader(inputStream))) {
            String line;
            while ((line = bufferedReader.readLine()) != null) {
                if (!StringUtils.hasLength(line)) {
                    continue;
                }
                Matcher matcher = contentPattern.matcher(line);
                if (matcher.find()) {
                    String content = matcher.group(1);
                    resp.append(content);
                    callable.accept(content);
                }

            }
        }
        int wordSize = 0;
        for (ChatGptDTO dto : talkList) {
            String content = dto.getContent();
            wordSize += content.toCharArray().length;
        }
        wordSize += resp.toString().toCharArray().length;
        long end = System.currentTimeMillis();
        return GptChatResultDTO.builder().resContent(resp.toString()).time(end - start).wordSize(wordSize).build();
    }

    /**
     * 流式对话
     *
     * @param talkList 上下文对话
     * @return 接口请求响应
     */
    private Response chatStream(List<ChatGptDTO> talkList) throws Exception {
        ChatStreamDTO chatStreamDTO = new ChatStreamDTO(talkList);
        RequestBody bodyOk = RequestBody.create(mediaType, chatStreamDTO.toString());
        Request requestOk = requestBuilder.post(bodyOk).build();
        Call call = client.newCall(requestOk);
        Response response;
        try {
            response = call.execute();
        } catch (IOException e) {
            throw new IOException("请求时IO异常: " + e.getMessage());
        }
        if (response.isSuccessful()) {
            return response;
        }
        try (ResponseBody body = response.body()) {
            if (429 == response.code()) {
                String msg = "Open Api key 已过期,msg: " + body.string();
                log.error(msg);
            }
            throw new RuntimeException("chat api 请求异常, code: " + response.code() + "body: " + body.string());
        }
    }


    private boolean sendToClient(String event, String data, SseEmitter emitter) {
        try {
            emitter.send(SseEmitter.event().name(event).data("{" + data + "}"));
            return true;
        } catch (IOException e) {
            log.error("向客户端发送消息时出现异常", e);
        }
        return false;
    }

    /**
     * 发送事件给客户端
     */
    public boolean sendData(String data, SseEmitter emitter) {
        if (StringUtil.isBlank(data)) {
            return true;
        }
        return sendToClient(EVENT_DATA, data, emitter);
    }

    /**
     * 发送结束事件,会关闭emitter
     */
    public void sendEnd(SseEmitter emitter) {
        try {
            sendToClient(EVENT_DATA, END, emitter);
        } finally {
            emitter.complete();
        }
    }


    /**
     * 发送异常事件,会关闭emitter
     */
    public void sendError(SseEmitter emitter) {
        try {
            sendToClient(EVENT_ERROR, "我累垮了", emitter);
        } finally {
            emitter.complete();
        }
    }


    /**
     * gpt请求结果
     */
    @Data
    @NoArgsConstructor
    @AllArgsConstructor
    @Builder
    public static class GptChatResultDTO implements Serializable {
        /**
         * gpt请求返回的全部内容
         */
        private String resContent;

        /**
         * 上下文消耗的字数
         */
        private int wordSize;

        /**
         * 耗时
         */
        private long time;
    }


    /**
     * 连续对话DTO
     */
    @Data
    @Builder
    @NoArgsConstructor
    @AllArgsConstructor
    public static class ChatGptDTO implements Serializable {
        /**
         * 对话内容
         */
        private String content;

        /**
         * 角色 {@link GptRoleEnum}
         */
        private String role;
    }


    /**
     * gpt连续对话角色
     */
    @Getter
    public static enum GptRoleEnum {
        USER_ROLE("user", "用户"),
        GPT_ROLE("assistant", "ChatGPT本身"),

        /**
         * message里role为system,是为了让ChatGPT在对话过程中设定自己的行为
         * 可以理解为对话的设定,如你是谁,要什么语气、等级
         */
        SYSTEM_ROLE("system", "对话设定"),

        ;

        private final String value;
        private final String desc;

        GptRoleEnum(String value, String desc) {
            this.value = value;
            this.desc = desc;
        }
    }


    /**
     * gpt请求body
     */
    @Data
    public static class ChatStreamDTO {
        private static final String model = "gpt-3.5-turbo";
        private static final boolean stream = true;
        private List<ChatGptDTO> messages;


        public ChatStreamDTO(List<ChatGptDTO> messages) {
            this.messages = messages;
        }

        @Override
        public String toString() {
            return "{\"model\":\"" + model + "\"," +
                    "\"messages\":" + JSON.toJSONString(messages) + "," +
                    "\"stream\":" + stream + "}";
        }
    }


}

使用案例:

    /**
     * Demo: sends a two-message conversation and prints every streamed chunk to stdout.
     * Needs a real {@code apiKey} and network access to the OpenAI API.
     */
    public static void main(String[] args) throws Exception {
        ChatGptStreamUtil chatGptStreamUtil = new ChatGptStreamUtil();
        // Outside Spring the @PostConstruct hook is not run, so configure the client by hand.
        chatGptStreamUtil.init();

        // Build the conversation context, earliest message first.
        List<ChatGptDTO> talkList = new ArrayList<>();
        // The behaviour-setting prompt belongs to the "system" role, not "assistant".
        talkList.add(ChatGptDTO.builder().content("你是chatgpt助手,能过帮助我查阅资料,编写教学报告。").role(GptRoleEnum.SYSTEM_ROLE.getValue()).build());
        // The user's question.
        talkList.add(ChatGptDTO.builder().content("请帮我写一篇小学数学加法运算教案").role(GptRoleEnum.USER_ROLE.getValue()).build());
        chatGptStreamUtil.chatStream(talkList, (respContent) -> {
            // Invoked once per streamed chunk of the reply.
            System.out.println("gpt返回:" + respContent);
        });
    }

SpringBoot接口

基于SpringBoot工程,提供接口,供Flutter端使用。

从上面工具类的使用可以看出,gpt返回的内容是一段一段的。如果服务端也要向Flutter端提供类似的流式效果,有以下两种思路:

  • WebSocket:服务端每收到一段gpt返回的内容,就推送给Flutter端。
  • 使用HTTP长连接,即SSE(Server-Sent Events,Spring中对应SseEmitter),这里采用的就是这种方式。

代码:

@RestController
@RequestMapping("/chat")
@Slf4j
public class ChatController {
    @Autowired
    private ChatGptStreamUtil chatGptStreamUtil;
  
    @PostMapping(value = "/chatStream")
    @ApiOperation("流式对话")
    public SseEmitter chatStream() {
        SseEmitter emitter = new SseEmitter(80000L);
      
        //构建一个上下文对话情景
        List<ChatGptDTO> talkList = new ArrayList<>();
        //设定gpt
        talkList.add(ChatGptDTO.builder().content("你是chatgpt助手,能过帮助我查阅资料,编写教学报告。").role(GptRoleEnum.GPT_ROLE.getValue()).build());
        //开始提问
        talkList.add(ChatGptDTO.builder().content("请帮我写一篇小学数学加法运算教案").role(GptRoleEnum.USER_ROLE.getValue()).build());
        GptChatResultDTO gptChatResultDTO = chatGptStreamUtil.chatStream(talkList, (content) -> {
          //这里服务端接收到消息就发送给Flutter
               chatGptStreamUtil.sendData(content, emitter);
            });
        return emitter;
    }

}

Flutter端

这里使用dio作为网络请求的工具

依赖

	dio: ^5.2.1+1

工具类

import 'dart:async';
import 'dart:convert';

import 'package:dio/dio.dart';
import 'package:flutter/cupertino.dart';
import 'package:flutter/foundation.dart';
import 'package:get/get.dart' hide Response;

/// HTTP helper built on Dio.
class HttpUtil {
  Dio? client;

  static HttpUtil of() {
    return HttpUtil.init();
  }

  /// Creates a Dio client with the app's base URL, 100-second connect and
  /// receive timeouts, and the shared request/response/error interceptor.
  HttpUtil.init() {
    client = Dio(BaseOptions(
        baseUrl: Config.baseUrl,
        connectTimeout: const Duration(seconds: 100),
        receiveTimeout: const Duration(seconds: 100)));
    client!.interceptors.add(OnReqResInterceptors());
  }

  /// POSTs [params] to [path] and returns the response body as a stream of
  /// UTF-8 decoded text lines, suitable for reading Server-Sent Events.
  ///
  /// Returns `null` when the response carries no body.
  Future<Stream<String>?> postStream(String path,
      [Map<String, dynamic>? params]) async {
    // Use the configured client so the base URL, timeouts and the
    // Authorization interceptor also apply to the streaming request.
    final Response<ResponseBody> rs = await client!.post<ResponseBody>(
      path,
      data: params,
      options: Options(
        headers: {"Accept": "text/event-stream", "Cache-Control": "no-cache"},
        responseType: ResponseType.stream,
      ),
    );
    // Each Uint8List chunk already is a List<int>; cast instead of copying it.
    return rs.data?.stream
        .cast<List<int>>()
        .transform(const Utf8Decoder())
        .transform(const LineSplitter());
  }
}

/// Dio interceptor for requests, responses and errors.
class OnReqResInterceptors extends InterceptorsWrapper {
  /// Adds the auth token header to every request.
  @override
  void onRequest(RequestOptions options, RequestInterceptorHandler handler) {
    options.headers['Authorization'] = '请求头token';
    super.onRequest(options, handler);
  }

  /// Passes errors on unchanged; hook for user-facing error handling.
  @override
  void onError(DioException err, ErrorInterceptorHandler handler) {
    if (err.type == DioExceptionType.unknown) {
      // TODO: network unavailable - show a "please try again later" hint here.
    }
    super.onError(err, handler);
  }

  /// Passes responses on unchanged.
  @override
  void onResponse(
      Response<dynamic> response, ResponseInterceptorHandler handler) {
    super.onResponse(response, handler);
  }
}



使用

  /// Builds the article via the streaming chat endpoint, accumulating the reply
  /// in `respContent`.
  ///
  /// The server sends SSE lines `event:<name>` followed by `data:{<chunk>}`.
  /// Event `d` carries a reply chunk, where the chunk `<>` marks the end of the
  /// reply; event `e` carries an error message.
  chatStream() async {
    final stream = await HttpUtil.of().postStream("/api/chat/chatStream");
    String respContent = "";
    String event = "";
    stream?.listen((line) {
      debugPrint(line);
      if (line.startsWith("event:")) {
        event = line.substring("event:".length).trim();
        return;
      }
      if (!line.startsWith("data:")) {
        return;
      }
      // The payload is wrapped in braces; use the LAST '}' so chunks that
      // themselves contain '}' are not cut short.
      final start = line.indexOf("{");
      final end = line.lastIndexOf("}");
      if (start < 0 || end <= start) {
        return;
      }
      final content = line.substring(start + 1, end);
      if (event == "e") {
        debugPrint("服务端返回错误:$content");
        return;
      }
      if (content == "<>") {
        // End-of-reply marker: respContent now holds the whole answer.
        return;
      }
      respContent += content;
      print("返回的内容:$content");
    }, onError: (Object e) {
      debugPrint("流式对话出错:$e");
    });
  }

你可能感兴趣的:(Java,flutter,spring,boot,chatgpt)