Spring Boot + MyBatis-Plus
// Calls the OpenAI chat-completions endpoint in streaming mode, forwards every content token
// to the conversation's WebSocket client as a JSON-serialized ResultUtils.success(...) response,
// and finally stores the concatenated answer as an "assistant" Chat row.
// Relies on names from the enclosing method/class: messagelist, host, port, ApiKey,
// conversionid, chatService.
String url = "https://api.openai.com/v1/chat/completions";
HashMap<String, Object> bodymap = new HashMap<>();
bodymap.put("model", "gpt-3.5-turbo");
bodymap.put("temperature", 0.7);
bodymap.put("messages", messagelist);
// "stream" belongs in the JSON body; the API then answers with "data: {...}" event lines.
bodymap.put("stream", true);
Gson gson = new Gson();
String s = gson.toJson(bodymap);
URL url1 = new URL(url);
HttpURLConnection conn = (HttpURLConnection) url1.openConnection(new Proxy(Proxy.Type.HTTP, new InetSocketAddress(host, port)));
StringBuilder answoer = new StringBuilder();
try {
    conn.setRequestMethod("POST");
    conn.setRequestProperty("Authorization", "Bearer " + ApiKey);
    conn.setRequestProperty("Content-Type", "application/json");
    conn.setDoOutput(true);

    // Write the request body; try-with-resources closes the writer even if writing fails.
    try (BufferedWriter writer = new BufferedWriter(
            new OutputStreamWriter(conn.getOutputStream(), Charset.forName("UTF-8")))) {
        writer.write(s);
    }

    // On an HTTP error, surface the server's error body instead of a bare IOException.
    int status = conn.getResponseCode();
    if (status >= 400) {
        StringBuilder errorBody = new StringBuilder();
        InputStream errorStream = conn.getErrorStream();
        if (errorStream != null) {
            try (BufferedReader errorReader = new BufferedReader(
                    new InputStreamReader(errorStream, Charset.forName("UTF-8")))) {
                String errorLine;
                while ((errorLine = errorReader.readLine()) != null) {
                    errorBody.append(errorLine);
                }
            }
        }
        throw new IOException("OpenAI request failed with HTTP " + status + ": " + errorBody);
    }

    // One sender for the whole stream; sendMessageByUserId looks the session up by conversation id.
    WebSocket webSocket = new WebSocket();
    // Decode as UTF-8 explicitly: the platform default charset can garble non-ASCII (e.g. Chinese) tokens.
    try (BufferedReader bufferedReader = new BufferedReader(
            new InputStreamReader(conn.getInputStream(), Charset.forName("UTF-8")))) {
        String line;
        while ((line = bufferedReader.readLine()) != null) {
            line = line.trim();
            // Only "data:" lines carry payload. Strip the prefix only, never occurrences inside the text.
            if (!line.startsWith("data:")) {
                continue;
            }
            String data = line.substring("data:".length()).trim();
            // "[DONE]" is the end-of-stream sentinel.
            if ("[DONE]".equals(data)) {
                break;
            }
            JsonElement jsonElement = JsonParser.parseString(data);
            if (!jsonElement.isJsonObject()) {
                continue;
            }
            JsonElement choicesElement = jsonElement.getAsJsonObject().get("choices");
            if (choicesElement == null || !choicesElement.isJsonArray()) {
                continue;
            }
            JsonArray choices = choicesElement.getAsJsonArray();
            if (choices.size() == 0) {
                continue;
            }
            JsonElement deltaElement = choices.get(0).getAsJsonObject().get("delta");
            if (deltaElement == null || !deltaElement.isJsonObject()) {
                continue;
            }
            // The final chunk has an empty delta; content may also be absent or JSON null.
            JsonElement contentElement = deltaElement.getAsJsonObject().get("content");
            if (contentElement == null || contentElement.isJsonNull()) {
                continue;
            }
            String content = contentElement.getAsString();
            BaseResponse<String> success = ResultUtils.success(content);
            webSocket.sendMessageByUserId(conversionid, gson.toJson(success));
            answoer.append(content);
            // Echo to the console.
            System.out.print(content);
        }
    }
} finally {
    conn.disconnect();
}
// Persist ChatGPT's full answer (reached only if the stream completed without an exception).
Chat entity = new Chat();
entity.setContext(answoer.toString());
entity.setRole("assistant");
entity.setConversionid(conversionid);
chatService.save(entity);
}
@ServerEndpoint(value = "/websocket/{ConversionId}")
@Component
public class WebSocket {
private static ChatGptUntil chatGptUntil;
private static ChatService chatService;
private static ConversionService conversionService;
@Resource
public void setConversionService(ConversionService conversionService) {
WebSocket.conversionService = conversionService;
}
@Resource
public void setChatService(ChatService chatService) {
WebSocket.chatService = chatService;
}
@Resource
public void setChatGptUntil(ChatGptUntil chatGptUntil) {
WebSocket.chatGptUntil = chatGptUntil;
}
private final static Logger logger = LogManager.getLogger(WebSocket.class);
/**
* 静态变量,用来记录当前在线连接数。应该把它设计成线程安全的
*/
private static int onlineCount = 0;
/**
* concurrent包的线程安全Map,用来存放每个客户端对应的MyWebSocket对象
*/
private static ConcurrentHashMap<String, WebSocket> webSocketMap = new ConcurrentHashMap<>();
/**
* 与某个客户端的连接会话,需要通过它来给客户端发送数据
*/
private Session session;
private Long ConversionId;
/**
* 连接建立成功调用的方法
*/
@OnOpen
public void onOpen(Session session, @PathParam("ConversionId") Long ConversionId) {
this.session = session;
this.ConversionId = ConversionId;
//加入map
webSocketMap.put(ConversionId.toString(), this);
addOnlineCount(); //在线数加1
logger.info("对话{}连接成功,当前在线人数为{}", ConversionId, getOnlineCount());
try {
sendMessage(String.valueOf(this.session.getQueryString()));
} catch (IOException e) {
logger.error("IO异常");
}
}
/**
* 连接关闭调用的方法
*/
@OnClose
public void onClose() {
//从map中删除
webSocketMap.remove(ConversionId.toString());
subOnlineCount(); //在线数减1
logger.info("对话{}关闭连接!当前在线人数为{}", ConversionId, getOnlineCount());
}
/**
* 收到客户端消息后调用的方法
*
* @param message 客户端发送过来的消息
*/
@OnMessage
public void onMessage(String message, Session session) throws IOException {
logger.info("来自客户端对话:{} 消息:{}", ConversionId, message);
Gson gson = new Gson();
// ChatMessage chatMessage = gson.fromJson(message, ChatMessage.class);
System.out.println(message);
// Long conversionid = chatMessage.getConversionid();
// if (conversionid == null) {
// BaseResponse baseResponse = ResultUtils.error(4000, "请指明是哪个对话");
// String s = gson.toJson(baseResponse);
// session.getBasicRemote().sendText(s);
// }
if (message == null) {
BaseResponse baseResponse = ResultUtils.error(4000, "请指明是该对话的用途");
String s = gson.toJson(baseResponse);
session.getBasicRemote().sendText(s);
}
// 将对话保存到数据库中
Chat entity = new Chat();
entity.setContext(message);
entity.setConversionid(this.ConversionId);
entity.setRole("user");
boolean save = chatService.save(entity);
if (!save) {
BaseResponse baseResponse = ResultUtils.error(500, "数据库出现错误");
String s = gson.toJson(baseResponse);
session.getBasicRemote().sendText(s);
}
// 查询出身份
Conversion byId = conversionService.getById(this.ConversionId);
String instructions = byId.getInstructions();// 指令
// 给予chatgot身份
ArrayList<ChatModel> chatModels = new ArrayList<>();
// ChatModel scene = new ChatModel("user", instructions);
// chatModels.add(scene);
LambdaQueryWrapper<Chat> queryWrapper = new LambdaQueryWrapper<>();
// 按照修改时间进行升序排序
queryWrapper.eq(Chat::getConversionid, byId.getId()).orderByDesc(Chat::getUpdatedtime);
List<Chat> list = chatService.list(queryWrapper);
// 查询之前的对话记录
List<ChatModel> collect = list.stream().map(chat -> {
ChatModel chatModel = new ChatModel();
chatModel.setRole(chat.getRole());
chatModel.setContent(chat.getContext());
// BeanUtils.copyProperties(chat, chatModel);
return chatModel;
}).collect(Collectors.toList());
chatModels.addAll(collect);
chatGptUntil.getRespost(this.ConversionId, chatModels);
// if (chatGptUntil==null){
// System.out.println("chatuntil是空");
// }
//
// if (stringRedisTemplate==null){
// System.out.println("缓存是空");
// }
//群发消息
/*for (String item : webSocketMap.keySet()) {
try {
webSocketMap.get(item).sendMessage(message);
} catch (IOException e) {
e.printStackTrace();
}
}*/
}
/**
* 发生错误时调用
*
* @OnError
*/
@OnError
public void onError(Session session, Throwable error) {
logger.error("对话错误:" + this.ConversionId + ",原因:" + error.getMessage());
error.printStackTrace();
}
/**
* 向客户端发送消息
*/
public void sendMessage(String message) throws IOException {
this.session.getBasicRemote().sendText(message);
//this.session.getAsyncRemote().sendText(message);
}
/**
* 通过userId向客户端发送消息
*/
public void sendMessageByUserId(Long ConversionId, String message) throws IOException {
logger.info("服务端发送消息到{},消息:{}", ConversionId, message);
if (StrUtil.isNotBlank(ConversionId.toString()) && webSocketMap.containsKey(ConversionId.toString())) {
webSocketMap.get(ConversionId.toString()).sendMessage(message);
} else {
logger.error("{}不在线", ConversionId);
}
}
/**
* 群发自定义消息
*/
public static void sendInfo(String message) {
for (String item : webSocketMap.keySet()) {
try {
webSocketMap.get(item).sendMessage(message);
} catch (IOException e) {
continue;
}
}
}
public static synchronized int getOnlineCount() {
return onlineCount;
}
public static synchronized void addOnlineCount() {
WebSocket.onlineCount++;
}
public static synchronized void subOnlineCount() {
WebSocket.onlineCount--;
}
}
https://gitee.com/li-manxiang/chatgptservice.git