Deploying Open-Source Models in Applications - Business Integration (Part 3)

1. Introduction

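    In the previous two articles, we built basic instant messaging (IM) functionality. This time we go one step further and connect the IM module to the AI service: the user asks a question, the model answers, and the answer is displayed in the user interface.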


2. Terminology

2.1. Spring Boot

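    An open-source framework for quickly building Java applications on top of the Spring framework. It simplifies the initialization and configuration of Spring applications so that developers can concentrate on business logic.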

2.2. Read Timeout

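    The maximum time a receive operation may wait during network communication. If a request has been sent and no response data arrives from the server within this time, a read timeout error is raised. It bounds how long the client waits for the server and prevents indefinite blocking. In OkHttp, the read timeout applies to each individual read on the connection, i.e. the maximum gap between two consecutive pieces of data, not the total duration of the response.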

2.3. Write Timeout

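    The maximum time a send operation may wait during network communication. If the request data cannot be completely sent to the server within this time, a write timeout error is raised. It bounds how long the client spends sending data and prevents indefinite blocking. In OkHttp it likewise applies to each individual write operation.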

2.4. Connection Timeout

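    The maximum time the client waits while trying to establish a network connection to the server. If the connection cannot be established within this time, a connection timeout error is raised. It prevents the client from waiting indefinitely for a connection.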
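To make the connection and read timeouts concrete, here is a minimal sketch using plain JDK sockets rather than OkHttp (the `TimeoutDemo` class is hypothetical). A local server accepts the TCP connection but never writes anything, so the connect succeeds within the connection timeout, while the first `read()` fails with `SocketTimeoutException` once the read timeout elapses.

```java
import java.io.InputStream;
import java.net.InetSocketAddress;
import java.net.ServerSocket;
import java.net.Socket;
import java.net.SocketTimeoutException;

public class TimeoutDemo {

    /**
     * Connects to a local server that never sends data and reports whether
     * the first read() ran into the read timeout.
     */
    public static boolean readTimesOut(int connectTimeoutMs, int readTimeoutMs) throws Exception {
        // Port 0 lets the OS pick a free port
        try (ServerSocket server = new ServerSocket(0);
             Socket client = new Socket()) {
            // Connection timeout: upper bound for establishing the TCP connection
            client.connect(new InetSocketAddress("127.0.0.1", server.getLocalPort()), connectTimeoutMs);
            // Read timeout: upper bound for each blocking read on the socket
            client.setSoTimeout(readTimeoutMs);
            try (Socket accepted = server.accept()) { // accepted, but nothing is ever written to it
                InputStream in = client.getInputStream();
                try {
                    in.read();
                    return false; // data (or EOF) arrived before the timeout
                } catch (SocketTimeoutException e) {
                    return true;  // no data within readTimeoutMs
                }
            }
        }
    }

    public static void main(String[] args) throws Exception {
        long start = System.nanoTime();
        boolean timedOut = readTimesOut(3000, 500);
        long elapsedMs = (System.nanoTime() - start) / 1_000_000;
        System.out.println("read timed out: " + timedOut + " after about " + elapsedMs + " ms");
    }
}
```

For the streaming call in this article, per-read semantics mean a long answer does not trip the 30-second `read_timeout` as long as chunks keep arriving; `OkHttpClient.Builder.callTimeout` can bound the whole call if that is needed.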


3. Prerequisites

3.1. The basic WebSocket flow works end to end (see Deploying Open-Source Models in Applications - Business Integration (Part 2))

3.2. The AI service is deployed on at least one node


4. Technical Implementation

# Connecting the IM module to the AI service

4.1. Add a utility class for calling the AI service

import com.alibaba.fastjson.JSON;
import lombok.extern.slf4j.Slf4j;
import okhttp3.MediaType;
import okhttp3.Request;
import okhttp3.RequestBody;
import okhttp3.Response;
import okhttp3.ResponseBody;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;

import java.io.InputStreamReader;
import java.io.Reader;
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;

@Slf4j
@Component
public class AIChatUtils {
    @Autowired
    private AIConfig aiConfig;

    private Request buildRequest(Long userId, String prompt) throws Exception {
        // Create the request body
        MediaType mediaType = MediaType.parse("application/json");
        RequestBody requestBody = RequestBody.create(mediaType, prompt);

        return buildHeader(userId, new Request.Builder().post(requestBody))
                .url(aiConfig.getUrl()).build();
    }

    private Request.Builder buildHeader(Long userId, Request.Builder builder) throws Exception {
        return builder
                .addHeader("Content-Type", "application/json")
                .addHeader("userId", String.valueOf(userId))
                .addHeader("secret", generateSecret(userId));
    }

    /**
     * Generates the request secret: hex(SHA-256(serverKey + userId + serverKey))
     *
     * @param userId user ID
     * @return 64-character lowercase hex string
     */
    private String generateSecret(Long userId) throws Exception {
        String key = aiConfig.getServerKey();
        String content = key + userId + key;

        MessageDigest digest = MessageDigest.getInstance("SHA-256");
        byte[] hash = digest.digest(content.getBytes(StandardCharsets.UTF_8));

        StringBuilder hexString = new StringBuilder();
        for (byte b : hash) {
            String hex = Integer.toHexString(0xff & b);
            if (hex.length() == 1) {
                hexString.append('0');
            }
            hexString.append(hex);
        }
        return hexString.toString();
    }

    /**
     * Sends the user's question to the AI service and forwards the streamed answer
     * to the user's WebSocket channel chunk by chunk.
     *
     * @return the complete answer, or null if the call failed
     */
    public String chatStream(ApiReqMessage apiReqMessage) throws Exception {

        // Build the request parameters
        String prompt = JSON.toJSONString(AIChatReqVO.init(apiReqMessage.getContents(), apiReqMessage.getHistory()));
        log.info("[AIChatUtils] calling AI chat, user ({}), prompt: {}", apiReqMessage.getUserId(), prompt);

        // Create the request object
        Request request = buildRequest(apiReqMessage.getUserId(), prompt);

        // Execute the call with the shared client (connection pool); try-with-resources closes the response
        try (Response response = OkHttpUtils.getInstance(aiConfig).getOkHttpClient().newCall(request).execute()) {
            ResponseBody body = response.body();
            if (response.code() != 200 || body == null) {
                log.warn("AI service returned an error, code: {}, body: {}", response.code(), body == null ? null : body.string());
                return null;
            }

            StringBuilder resultBuff = new StringBuilder();
            // Decode through a Reader so UTF-8 characters split across network chunks stay intact
            Reader reader = new InputStreamReader(body.byteStream(), StandardCharsets.UTF_8);
            char[] buffer = new char[1024];
            int len;
            while ((len = reader.read(buffer)) != -1) {
                // Data read in this round
                String result = new String(buffer, 0, len);
                resultBuff.append(result);

                // Push each chunk to the user as soon as it arrives
                AbstractBusinessLogicHandler.pushChatMessageForUser(apiReqMessage.getUserId(), result);
            }

            // Normal response: return the complete answer
            return resultBuff.toString();
        } catch (Exception e) {
            log.error("[AIChatUtils] message ({}) chatStream call to AI chat failed: {}", apiReqMessage.getMessageId(), e.getMessage(), e);
        }
        return null;
    }


}
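Streaming Chinese text needs care when decoding: a Chinese character takes three bytes in UTF-8, so a fixed-size byte chunk can end in the middle of a character, and calling `new String(bytes, 0, len, UTF_8)` on each chunk separately turns both halves into replacement characters (U+FFFD). A `Reader` avoids this because its decoder holds incomplete byte sequences until the rest arrives. The JDK-only sketch below (class and method names are hypothetical) simulates a network stream that delivers four bytes per read and compares both approaches.

```java
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.Reader;
import java.nio.charset.StandardCharsets;

public class StreamDecodeDemo {

    /** Simulates a network stream that delivers at most maxPerRead bytes per read call. */
    public static InputStream trickle(byte[] data, int maxPerRead) {
        return new ByteArrayInputStream(data) {
            @Override
            public synchronized int read(byte[] b, int off, int len) {
                return super.read(b, off, Math.min(len, maxPerRead));
            }
        };
    }

    /** Naive: decode each byte chunk on its own (breaks characters split across chunks). */
    public static String decodePerChunk(InputStream in, int chunkSize) throws IOException {
        StringBuilder sb = new StringBuilder();
        byte[] buf = new byte[chunkSize];
        int len;
        while ((len = in.read(buf)) != -1) {
            sb.append(new String(buf, 0, len, StandardCharsets.UTF_8));
        }
        return sb.toString();
    }

    /** Safe: the Reader's decoder keeps incomplete byte sequences until the rest arrives. */
    public static String decodeWithReader(InputStream in, int chunkSize) throws IOException {
        StringBuilder sb = new StringBuilder();
        Reader reader = new InputStreamReader(in, StandardCharsets.UTF_8);
        char[] buf = new char[chunkSize];
        int len;
        while ((len = reader.read(buf)) != -1) {
            sb.append(buf, 0, len); // each piece here is safe to push to the client
        }
        return sb.toString();
    }

    public static void main(String[] args) throws IOException {
        // A short Chinese sentence written with Unicode escapes; each Chinese character is 3 bytes in UTF-8
        String answer = "\u4f60\u597d\uff0c\u6211\u662fAI\u52a9\u624b";
        byte[] bytes = answer.getBytes(StandardCharsets.UTF_8);
        String naive = decodePerChunk(trickle(bytes, 4), 4);
        String safe = decodeWithReader(trickle(bytes, 4), 4);
        System.out.println(naive.equals(answer)); // false: contains U+FFFD replacement characters
        System.out.println(safe.equals(answer));  // true
    }
}
```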

4.2. Add an OkHttp client utility class

import lombok.Getter;
import okhttp3.ConnectionPool;
import okhttp3.OkHttpClient;
import java.util.concurrent.TimeUnit;

/**
 * Holder for a shared OkHttp client (its connection pool is reused across calls)
 **/
public class OkHttpUtils {

    // volatile is required for double-checked locking to be safe
    private static volatile OkHttpUtils okHttpUtils;
    @Getter
    private final OkHttpClient okHttpClient;

    private OkHttpUtils(AIConfig aiConfig){
        this.okHttpClient = new OkHttpClient.Builder()
                .readTimeout(aiConfig.getReadTimeout(), TimeUnit.SECONDS)
                .connectTimeout(aiConfig.getConnectionTimeout(), TimeUnit.SECONDS)
                .writeTimeout(aiConfig.getWriteTimeout(), TimeUnit.SECONDS)
                // keep_alive_connections = max idle connections, keep_alive_duration in seconds
                .connectionPool(new ConnectionPool(aiConfig.getKeepAliveConnections(), aiConfig.getKeepAliveDuration(), TimeUnit.SECONDS))
                .build();
    }

    public static OkHttpUtils getInstance(AIConfig aiConfig){
        if (null == okHttpUtils){
            synchronized (OkHttpUtils.class){
                if (null == okHttpUtils){
                    // Store the instance so later calls reuse the same client and connection pool
                    okHttpUtils = new OkHttpUtils(aiConfig);
                }
            }
        }
        return okHttpUtils;
    }

}
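A lazily created shared instance like this needs double-checked locking on a `volatile` field, and the new instance must actually be stored in that field; otherwise every call builds a fresh `OkHttpClient` with its own connection pool, and `keep_alive_connections` never takes effect. The JDK-only sketch below shows the idiom in generic form; `Lazy` is a hypothetical helper, not part of OkHttp, and `main` starts several threads at once to show the factory runs a single time.

```java
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Supplier;

/** Thread-safe lazy holder using double-checked locking (hypothetical helper). */
public class Lazy<T> {
    private final Supplier<T> factory;
    // volatile guarantees other threads see a fully constructed instance
    private volatile T instance;

    public Lazy(Supplier<T> factory) {
        this.factory = factory;
    }

    public T get() {
        T result = instance;           // first check, without locking
        if (result == null) {
            synchronized (this) {
                result = instance;     // second check, under the lock
                if (result == null) {
                    result = factory.get();
                    instance = result; // store it so later calls reuse it
                }
            }
        }
        return result;
    }

    public static void main(String[] args) throws InterruptedException {
        AtomicInteger created = new AtomicInteger();
        Lazy<Object> client = new Lazy<>(() -> {
            created.incrementAndGet();
            return new Object(); // stands in for new OkHttpClient.Builder()...build()
        });
        Thread[] threads = new Thread[8];
        for (int i = 0; i < threads.length; i++) {
            threads[i] = new Thread(client::get);
            threads[i].start();
        }
        for (Thread t : threads) {
            t.join();
        }
        System.out.println("instances created: " + created.get()); // 1
    }
}
```

In `OkHttpUtils` the equivalent step is assigning `okHttpUtils = new OkHttpUtils(aiConfig)` inside the synchronized block, with the field declared `volatile`.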

4.3. Modify the business handler defined in Part 2

import com.alibaba.fastjson.JSON;
import io.netty.channel.ChannelHandler;
import lombok.extern.slf4j.Slf4j;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.http.websocketx.TextWebSocketFrame;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;


/**
 * @Description: handler that processes incoming messages
 */
@Slf4j
@ChannelHandler.Sharable
@Component
public class BusinessHandler extends AbstractBusinessLogicHandler {
    @Autowired
    private AIChatUtils aiChatUtils;

    @Override
    public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
        String channelId = ctx.channel().id().asShortText();
        log.info("add client,channelId:{}", channelId);
    }

    @Override
    public void handlerRemoved(ChannelHandlerContext ctx) throws Exception {
        String channelId = ctx.channel().id().asShortText();
        log.info("remove client,channelId:{}", channelId);
    }


    @Override
    protected void channelRead0(ChannelHandlerContext channelHandlerContext, TextWebSocketFrame textWebSocketFrame)
            throws Exception {
        // Get the message sent by the client
        String content = textWebSocketFrame.text();
        log.info("Received message from client: {}", content);

        Long userIdForReq;
        String msgType = "";
        String contents = "";

        try {
            ApiReqMessage apiReqMessage = JSON.parseObject(content, ApiReqMessage.class);
            msgType = apiReqMessage.getMsgType();
            contents = apiReqMessage.getContents();

            userIdForReq = apiReqMessage.getUserId();

            // Register the user's channel on first contact
            if(!isExists(userIdForReq)){
                addChannel(channelHandlerContext, userIdForReq);
            }

            log.info("userId: {}, msgType: {}, contents: {}", userIdForReq, msgType, contents);

            if(StringUtils.equals(msgType, String.valueOf(MsgType.CHAT.getCode()))){
                // Fixed test reply from Part 2, now replaced by the AI call below
//                ApiRespMessage apiRespMessage = ApiRespMessage.builder().code(String.valueOf(StatusCode.SUCCESS.getCode()))
//                        .respTime(String.valueOf(System.currentTimeMillis()))
//                        .contents("Test passed, glad to receive your message")
//                        .msgType(String.valueOf(MsgType.CHAT.getCode()))
//                        .build();
//                String response = JSON.toJSONString(apiRespMessage);
//                channelHandlerContext.writeAndFlush(new TextWebSocketFrame(response));

                // chatStream blocks until the whole answer has been streamed; run this handler
                // on a separate business executor rather than on the Netty I/O event loop
                aiChatUtils.chatStream(apiReqMessage);

            }else{
                log.info("userId: {}, unsupported message type: {}", userIdForReq, msgType);
            }

        } catch (Exception e) {
            log.warn("[BusinessHandler] request content: {}, exception: {}", content, e.getMessage(), e);
        }

    }

}

PS:

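1. The handler previously extended SimpleChannelInboundHandler; it now extends the custom AbstractBusinessLogicHandler.

2. Once a user connects to the WebSocket server, the mapping between the user and the channel must be saved. Here the globally unique userId is used as the key for the channel; see AbstractBusinessLogicHandler for details.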
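Keeping this mapping correct across reconnects involves two details the handler above does not cover yet: a reconnecting user should overwrite the stale entry, and a channel that closes should remove the entry only if it is still the registered one, so it cannot evict the user's newer connection. A minimal JDK-only sketch built on `ConcurrentHashMap` (`UserChannelRegistry` and its methods are hypothetical; the type parameter `C` stands in for Netty's `ChannelHandlerContext`):

```java
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Consumer;

/** Hypothetical userId -> channel registry; C stands in for ChannelHandlerContext. */
public class UserChannelRegistry<C> {
    private final Map<Long, C> userToChannel = new ConcurrentHashMap<>();

    /** Called on every login/message: a reconnect replaces the stale channel. */
    public void register(Long userId, C channel) {
        userToChannel.put(userId, channel);
    }

    /** Called when a channel closes: removes the entry only if it still maps to this channel. */
    public boolean unregister(Long userId, C channel) {
        return userToChannel.remove(userId, channel);
    }

    /** Hands the user's current channel to the sender; returns false if the user is offline. */
    public boolean push(Long userId, Consumer<C> sender) {
        C channel = userToChannel.get(userId);
        if (channel == null) {
            return false;
        }
        sender.accept(channel);
        return true;
    }

    public static void main(String[] args) {
        UserChannelRegistry<String> registry = new UserChannelRegistry<>();
        registry.register(1001L, "channel-A");
        registry.register(1001L, "channel-B");                       // user reconnected
        System.out.println(registry.unregister(1001L, "channel-A")); // false: old channel closing late
        System.out.println(registry.push(1001L, ch -> System.out.println("send via " + ch)));
    }
}
```

In Netty, `handlerRemoved` only receives the context, so the userId would have to be remembered per channel, for example as a channel attribute created with `AttributeKey.valueOf(...)`, before `unregister` can be called from there.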

4.4. Add AbstractBusinessLogicHandler

import com.alibaba.fastjson.JSON;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.handler.codec.http.websocketx.TextWebSocketFrame;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.DisposableBean;
import java.util.concurrent.ConcurrentHashMap;


@Slf4j
public abstract class AbstractBusinessLogicHandler extends SimpleChannelInboundHandler<TextWebSocketFrame> implements DisposableBean {

    // userId -> context of that user's WebSocket channel, shared by all handler instances
    protected static final ConcurrentHashMap<Long, ChannelHandlerContext> USER_ID_TO_CHANNEL = new ConcurrentHashMap<>();


    /**
     * Registers the user's socket channel
     *
     * @param channelHandlerContext socket channel context
     * @param userId user ID
     */
    protected void addChannel(ChannelHandlerContext channelHandlerContext, Long userId) {
        // Store the current channel
        USER_ID_TO_CHANNEL.put(userId, channelHandlerContext);
    }

    /**
     * Checks whether the user already has a registered channel
     * @param userId user ID
     * @return true if a channel is registered for the user
     */
    protected boolean isExists(Long userId){
        return USER_ID_TO_CHANNEL.containsKey(userId);
    }

    protected static void buildResponse(ChannelHandlerContext channelHandlerContext, int code, long respTime, int msgType, String msg) {
        buildResponse(channelHandlerContext, ApiRespMessage.builder().code(String.valueOf(code))
                .respTime(String.valueOf(respTime))
                .msgType(String.valueOf(msgType))
                .contents(msg).build());
    }

    protected static void buildResponseIncludeOperateId(ChannelHandlerContext channelHandlerContext, int code, long respTime, int msgType, String msg, String operateId) {
        buildResponse(channelHandlerContext, ApiRespMessage.builder().code(String.valueOf(code))
                .respTime(String.valueOf(respTime))
                .msgType(String.valueOf(msgType))
                .operateId(operateId)
                .contents(msg).build());
    }


    protected static void buildResponse(ChannelHandlerContext channelHandlerContext, ApiRespMessage apiRespMessage) {
        String response = JSON.toJSONString(apiRespMessage);
        channelHandlerContext.writeAndFlush(new TextWebSocketFrame(response));
    }


    @Override
    public void destroy() throws Exception {
        // Drop all user-channel mappings when the Spring context shuts down
        USER_ID_TO_CHANNEL.clear();
    }


    /**
     * Pushes one chunk of the AI answer to the user's WebSocket channel
     */
    public static void pushChatMessageForUser(Long userId, String chatRespMessage) {
        ChannelHandlerContext channelHandlerContext = USER_ID_TO_CHANNEL.get(userId);

        if (channelHandlerContext == null) {
            log.warn("No channel registered for user {}, chat message dropped", userId);
            return;
        }
        buildResponse(channelHandlerContext, ApiRespMessage.builder().code("200")
                .respTime(String.valueOf(System.currentTimeMillis()))
                .msgType(String.valueOf(MsgType.CHAT.getCode()))
                .contents(chatRespMessage)
                .build());
    }

}

4.5. AI configuration class

import lombok.Getter;
import lombok.Setter;
import lombok.ToString;
import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.stereotype.Component;

@ConfigurationProperties(prefix="ai.server")
@Component("aiConfig")
@Setter
@Getter
@ToString
public class AIConfig {
    private String url;
    private Integer connectionTimeout;
    private Integer writeTimeout;
    private Integer readTimeout;
    private String serverKey;
    private Integer keepAliveConnections;
    private Integer keepAliveDuration;
}

4.6. Configuration values for the AI configuration class

ai:
  server:
    url: http://127.0.0.1:9999/api/chat
    connection_timeout: 3
    write_timeout: 30
    read_timeout: 30
    server_key: 88888888
    keep_alive_connections: 30
    keep_alive_duration: 60

PS:

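1. Adjust url and server_key to your environment. server_key is used to sign the secret request header, so it must match the key configured on the AI service.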
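Because `generateSecret` in `AIChatUtils` sends `hex(SHA-256(serverKey + userId + serverKey))` in the `secret` header, the AI service has to compute the same value from the same key. The article does not show the AI service side; the class below is a hypothetical JDK-only sketch of the check it could perform, comparing the values with `MessageDigest.isEqual` to avoid a timing side channel.

```java
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;

/** Hypothetical server-side check for the "userId" and "secret" headers. */
public class SecretVerifier {
    private final String serverKey;

    public SecretVerifier(String serverKey) {
        this.serverKey = serverKey;
    }

    /** Same construction as the client: hex(SHA-256(key + userId + key)). */
    public String expectedSecret(long userId) throws NoSuchAlgorithmException {
        String content = serverKey + userId + serverKey;
        byte[] hash = MessageDigest.getInstance("SHA-256")
                .digest(content.getBytes(StandardCharsets.UTF_8));
        StringBuilder hex = new StringBuilder();
        for (byte b : hash) {
            hex.append(String.format("%02x", b)); // two lowercase hex digits per byte
        }
        return hex.toString();
    }

    /** Constant-time comparison of the received secret with the expected one. */
    public boolean verify(long userId, String receivedSecret) throws NoSuchAlgorithmException {
        if (receivedSecret == null) {
            return false;
        }
        byte[] expected = expectedSecret(userId).getBytes(StandardCharsets.US_ASCII);
        byte[] actual = receivedSecret.getBytes(StandardCharsets.US_ASCII);
        return MessageDigest.isEqual(expected, actual);
    }

    public static void main(String[] args) throws NoSuchAlgorithmException {
        SecretVerifier verifier = new SecretVerifier("88888888");
        String secret = verifier.expectedSecret(1001L);
        System.out.println(verifier.verify(1001L, secret)); // true
        System.out.println(verifier.verify(1002L, secret)); // false
    }
}
```

This scheme has no timestamp or nonce, so a captured `secret` header stays valid for that user until `server_key` changes.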

4.7. Netty configuration class

package com.zwzt.communication.config;

import lombok.Getter;
import lombok.Setter;
import lombok.ToString;
import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.stereotype.Component;

@ConfigurationProperties(prefix="ws.server")
@Component
@Setter
@Getter
@ToString
public class NettyConfig {
    private String path;
    private int port;
    private int backlog;
    private int bossThread;
    private int workThread;
    private int businessThread;
    private int idleRead;
    private int idleWrite;
    private int idleAll;
    private int aggregator;
}

4.8. Configuration values for the Netty configuration class

ws:
  server:
    path: /ws
    port: 7778
    backlog: 1024
    boss_thread: 1
    work_thread: 8
    business_thread: 16
    idle_read: 30
    idle_write: 30
    idle_all: 60
    aggregator: 65536

4.9. VO class

import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;

import java.util.List;

@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
public class AIChatReqVO {

    // The user's question
    private String prompt;

    // Conversation history
    private List<ChatContext> history;

    // Model sampling parameters
    private Double top_p;
    private Double temperature;
    private Double repetition_penalty;
    private Long max_new_tokens;

    public static AIChatReqVO init(String prompt, List<ChatContext> history) {
        return AIChatReqVO.builder()
                .prompt(prompt)
                .history(history)
                .top_p(0.9)
                .temperature(0.45)
                .repetition_penalty(1.1)
                .max_new_tokens(8192L)
                .build();
    }


}

4.10. Entity class

import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;

@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
public class ChatContext {
    // Sender
    private String from;
    // Message content
    private String value;
}

# Integrating Netty into the Spring Boot project

4.11. Add the Spring Boot startup class

package com.zwzt.communication;

import com.zwzt.communication.netty.server.Server;
import com.zwzt.communication.utils.SpringContextUtils;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.BeansException;
import org.springframework.beans.factory.DisposableBean;
import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextAware;
import org.springframework.context.ApplicationListener;
import org.springframework.context.event.ContextRefreshedEvent;

@SpringBootApplication
@Slf4j
public class Application implements ApplicationListener<ContextRefreshedEvent>, ApplicationContextAware, DisposableBean {

    public static void main(String[] args) {
        SpringApplication.run(Application.class, args);
    }

    @Override
    public void onApplicationEvent(ContextRefreshedEvent contextRefreshedEvent) {
        // Only react to the root context, not to child contexts
        if (contextRefreshedEvent.getApplicationContext().getParent() == null) {
            // Start the WebSocket server on its own thread so it does not block Spring startup
            new Thread(() -> {
                try {
                    Server.getInstance().start();
                } catch (Exception e) {
                    // Exceptions thrown in this thread never reach the caller, so handle them here
                    log.error("webSocket server startup exception!", e);
                    System.exit(-1);
                }
            }, "websocket-server").start();
        }
    }

    @Override
    public void setApplicationContext(ApplicationContext applicationContext) throws BeansException {
        SpringContextUtils.setApplicationContext(applicationContext);
    }

    @Override
    public void destroy() throws Exception {
        try {
            Server.getInstance().close();
        } catch (Throwable e) {
            log.warn("webSocket server shutdown exception", e);
        }
    }
}

4.12. Spring Boot configuration

application.yml

server:
  port: 7777
  tomcat:
    uri-encoding: UTF-8
spring:
  application:
    name: ai_business_project
  main:
    banner-mode: "off"
  profiles:
    active: ai-dev
# Logging configuration
logging:
  config: classpath:logback-spring.xml

application-ai-dev.yml

ai:
  server:
    url: http://127.0.0.1:9999/api/chat
    connection_timeout: 3
    write_timeout: 30
    read_timeout: 30
    server_key: 88888888
    keep_alive_connections: 30
    keep_alive_duration: 60
ws:
  server:
    path: /ws
    port: 7778
    backlog: 1024
    boss_thread: 1
    work_thread: 8
    business_thread: 16
    idle_read: 30
    idle_write: 30
    idle_all: 60
    aggregator: 65536

4.13. Spring context utility class

import org.springframework.context.ApplicationContext;

public class SpringContextUtils {
    private static ApplicationContext applicationContext;

    public static void setApplicationContext(ApplicationContext applicationContext){
        SpringContextUtils.applicationContext = applicationContext;
    }

    public static ApplicationContext getApplicationContext(){
        return applicationContext;
    }
}

4.14. Start the service by running the Application class

[Figure 1]

After a successful start, Spring Boot listens on port 7777 and the WebSocket server listens on port 7778.
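To confirm both listeners from a script or a health check, a plain TCP connect is enough. A minimal JDK-only sketch; the `PortCheck` class is hypothetical, and the ports are the ones configured above (`server.port` 7777 and `ws.server.port` 7778).

```java
import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.Socket;

public class PortCheck {

    /** Returns true if a TCP connection to host:port succeeds within timeoutMs. */
    public static boolean isListening(String host, int port, int timeoutMs) {
        try (Socket socket = new Socket()) {
            socket.connect(new InetSocketAddress(host, port), timeoutMs);
            return true;
        } catch (IOException e) {
            return false; // refused, unreachable or timed out
        }
    }

    public static void main(String[] args) {
        int[] ports = {7777, 7778}; // Spring Boot HTTP port, Netty WebSocket port
        for (int port : ports) {
            System.out.println("127.0.0.1:" + port + " listening: " + isListening("127.0.0.1", port, 1000));
        }
    }
}
```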


5. Testing

# The client code from the previous article is reused without changes

5.1. Page test

[Figure 2]

5.2. Online test

[Figure 3]

At this point, the entire interaction chain between the IM module and the AI service is working.
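Besides the page from Part 2, the chain can also be exercised from plain Java with the JDK 11+ `java.net.http.WebSocket` client. This is a sketch under assumptions: the endpoint `ws://127.0.0.1:7778/ws` comes from the `ws.server` settings, the JSON field names `userId`, `msgType` and `contents` follow the `ApiReqMessage` getters used in `BusinessHandler`, and the `msgType` value `"1"` is a placeholder, since the real code of `MsgType.CHAT` is defined in Part 2 and not shown here.

```java
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.WebSocket;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;

public class ChatClientDemo {

    // Placeholder: replace with the real value of MsgType.CHAT.getCode() from Part 2
    static final String CHAT_MSG_TYPE = "1";

    /** Builds the request JSON; only escapes backslashes and quotes (sketch only). */
    public static String buildChatPayload(long userId, String question) {
        String escaped = question.replace("\\", "\\\\").replace("\"", "\\\"");
        return "{\"userId\":" + userId
                + ",\"msgType\":\"" + CHAT_MSG_TYPE + "\""
                + ",\"contents\":\"" + escaped + "\"}";
    }

    public static void main(String[] args) throws Exception {
        CountDownLatch closed = new CountDownLatch(1);
        WebSocket.Listener listener = new WebSocket.Listener() {
            @Override
            public CompletionStage<?> onText(WebSocket webSocket, CharSequence data, boolean last) {
                System.out.println("chunk: " + data); // one response message per streamed chunk
                webSocket.request(1);                 // ask for the next message
                return null;
            }

            @Override
            public CompletionStage<?> onClose(WebSocket webSocket, int statusCode, String reason) {
                closed.countDown();
                return null;
            }
        };

        WebSocket ws = HttpClient.newHttpClient()
                .newWebSocketBuilder()
                .buildAsync(URI.create("ws://127.0.0.1:7778/ws"), listener)
                .join();
        ws.sendText(buildChatPayload(1001L, "Hello, who are you?"), true).join();

        // The server code shown here sends no end-of-answer marker, so wait up to 60 s for chunks
        boolean closedByServer = closed.await(60, TimeUnit.SECONDS);
        if (!closedByServer) {
            ws.sendClose(WebSocket.NORMAL_CLOSURE, "bye").join();
        }
    }
}
```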


6. Additional Notes

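6.1. The code above still has plenty of room for improvement; in particular, non-functional requirements have not been addressed. The main goal here is to get the whole program running end to end first, then improve and refine it step by step.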

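6.2. There are already many mature tutorials online on setting up a Spring Boot project; since this article is already long, that is not covered in detail here.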
