开源模型应用落地-业务优化篇(三)

一、前言

    假如您跟随我的脚步,学习到上一篇的内容,到这里,相信细心的您,已经发现了,在上一篇中遗留的问题。那就是IM服务过载的时候,如何进行水平扩容?

    因为在每个IM服务中,我们用JVM缓存了用户与WS的通道的绑定关系,并且使用Redis队列进行解耦。那扩展了IM服务实例之后,如何确保Redis队列的消息能正常消费,即如何能找回对应的用户通道?别着急,接下来,我将给您做详细的解释。


二、术语

2.1.水平扩容

    是指通过增加系统中的资源实例数量来提高系统的处理能力和吞吐量。在计算机领域,水平扩容通常用于应对系统负载的增加或需要处理更多请求的情况。

2.2.无状态

    无状态(stateless)是指系统或组件在处理请求时不依赖于之前的请求或会话信息。换句话说,每个请求都是独立的,系统不会在不同的请求之间保存任何状态或上下文信息。

    在无状态系统中,每个请求被视为一个独立的事件,系统只关注当前请求所包含的信息和参数,而不依赖于之前的请求历史。这使得系统更加简单、可伸缩和易于管理。


三、前置条件

3.1. 已经完成前两篇的学习


四、技术实现

4.1、实现思路

    首先,IM服务是状态的(AI服务是无状态),每个实例中,会缓存用户与WebSocket通道之间的信息。那是否可以采用中间共享存储的方式,将状态信息保存至Redis或外部存储中?答案是:不行。WebScoket的通道信息,无法进行序列化。

    要实现IM服务水平扩容的方式有多种,但目前我们采用以下的方案:

开源模型应用落地-业务优化篇(三)_第1张图片

  1.   每个IM服务保存对应的用户和WS通道的关系;
  2.   每个IM服务对应唯一一个redis队列;
  3.   前置的SLB(App入口)能根据用户标识(哈希)将请求转发至指定的IM服务;
  4.   当某一IM服务出现故障的时候,由App端发起重连,重新建立WebSocket连接。

4.2、调整配置文件

# 每个IM服务实例指定全局唯一的ID,例如:下面指定的node:0

ws:
  server:
    node: 0

PS:具体参数可以在外部指定,作为JVM的运行参数传入

4.3、调整业务逻辑处理类

# 将原有Redis的单一队列名,改为拼接上全局唯一ID的方式

# Redis中缓存的数据如下

开源模型应用落地-业务优化篇(三)_第2张图片

4.4、调整任务处理类

# 将原有Redis的单一队列名,改为拼接上全局唯一ID的方式


五、测试

# 这次换一下测试方式,用离线页面的方式进行测试

5.1.  建立连接

开源模型应用落地-业务优化篇(三)_第3张图片

5.2.  业务初始化

5.3.  业务对话

开源模型应用落地-业务优化篇(三)_第4张图片


六、附带说明

6.1. BusinessHandler完整代码

import com.alibaba.fastjson.JSON;
import io.netty.channel.ChannelHandler;
import lombok.extern.slf4j.Slf4j;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.http.websocketx.TextWebSocketFrame;
import org.apache.commons.lang3.StringUtils;
import org.redisson.api.RLock;
import org.redisson.api.RedissonClient;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import java.util.concurrent.TimeUnit;


/**
 * @Description: 处理消息的handler
 */
@Slf4j
@ChannelHandler.Sharable
@Component
public class BusinessHandler extends AbstractBusinessLogicHandler {
    public static final String LINE_UP_QUEUE_NAME = "AI-REQ-QUEUE";
    private static final String LINE_UP_LOCK_NAME = "AI-REQ-LOCK";

    private static final int MAX_QUEUE_SIZE = 100;

//    @Autowired
//    private TaskUtils taskExecuteUtils;

    @Autowired
    private RedisUtils redisUtils;
    @Autowired
    private RedissonClient redissonClient;
    @Autowired
    private NettyConfig nettyConfig;


    @Override
    public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
        String channelId = ctx.channel().id().asShortText();
        log.info("add client,channelId:{}", channelId);
    }

    @Override
    public void handlerRemoved(ChannelHandlerContext ctx) throws Exception {
        String channelId = ctx.channel().id().asShortText();
        log.info("remove client,channelId:{}", channelId);
    }


    @Override
    protected void channelRead0(ChannelHandlerContext channelHandlerContext, TextWebSocketFrame textWebSocketFrame)
            throws Exception {
        // 获取客户端传输过来的消息
        String content = textWebSocketFrame.text();
        // 兼容在线测试
        if (StringUtils.equals(content, "PING")) {
            buildResponse(channelHandlerContext, ApiRespMessage.builder().code(String.valueOf(StatusCode.SUCCESS.getCode()))
                    .respTime(String.valueOf(System.currentTimeMillis()))
                    .msgType(String.valueOf(MsgType.HEARTBEAT.getCode()))
                    .contents("心跳测试,很高兴收到你的心跳包")
                    .build());

            return;
        }
        log.info("接收到客户端发送的信息: {}", content);

        Long userIdForReq;
        String msgType = "";
        String contents = "";

        try {
            ApiReqMessage apiReqMessage = JSON.parseObject(content, ApiReqMessage.class);
            msgType = apiReqMessage.getMsgType();
            contents = apiReqMessage.getContents();


            userIdForReq = apiReqMessage.getUserId();
            // 用户身份标识校验
            if (null == userIdForReq || (long) userIdForReq <= 10000) {
                ApiRespMessage apiRespMessage = ApiRespMessage.builder().code(String.valueOf(StatusCode.SYSTEM_ERROR.getCode()))
                        .respTime(String.valueOf(System.currentTimeMillis()))
                        .contents("用户身份标识有误!")
                        .msgType(String.valueOf(MsgType.SYSTEM.getCode()))
                        .build();
                buildResponseAndClose(channelHandlerContext, apiRespMessage);
                return;
            }


            if (StringUtils.equals(msgType, String.valueOf(MsgType.CHAT.getCode()))) {
                // 对用户输入的内容进行自定义违规词检测
                // 对用户输入的内容进行第三方在线违规词检测
                // 对用户输入的内容进行组装成Prompt
                // 对Prompt根据业务进行增强(完善prompt的内容)
                // 对history进行裁剪或总结(检测history是否操作模型支持的上下文长度,例如qwen-7b支持的上下文长度为8192)
                // ...

//                通过线程池来处理
//                String messageId = apiReqMessage.getMessageId();
//                List history = apiReqMessage.getHistory();
//                AITaskReqMessage aiTaskReqMessage = AITaskReqMessage.builder().messageId(messageId).userId(userIdForReq).contents(contents).history(history).build();
//                taskExecuteUtils.execute(aiTaskReqMessage);

//                通过队列来缓冲
                boolean flag = true;

                RLock lock = redissonClient.getLock(LINE_UP_LOCK_NAME);
                String queueName = LINE_UP_QUEUE_NAME+"-"+nettyConfig.getNode();

                //尝试获取锁,最多等待3秒,锁的自动释放时间为10秒
                if (lock.tryLock(3, 10, TimeUnit.SECONDS)) {
                    try {
                        if (redisUtils.queueSize(queueName) < MAX_QUEUE_SIZE) {
                            redisUtils.queueAdd(queueName, content);
                            log.info("当前线程为:{}, 添加请求至redis队列",Thread.currentThread().getName());
                        } else {
                            flag = false;
                        }
                    } catch (Throwable e) {
                        log.error("系统处理异常", e);
                    } finally {
                        lock.unlock();
                    }
                } else {
                    flag = false;
                }

                if (!flag) {
                    buildResponse(channelHandlerContext, ApiRespMessage.builder().code(String.valueOf(StatusCode.SUCCESS.getCode()))
                            .respTime(String.valueOf(System.currentTimeMillis()))
                            .msgType(String.valueOf(MsgType.SYSTEM.getCode()))
                            .contents("当前排队人数较多,请稍后再重试!")
                            .build());
                }


            } else if (StringUtils.equals(msgType, String.valueOf(MsgType.INIT.getCode()))) {
                //一、业务黑名单检测(多次违规,永久锁定)

                //二、账户锁定检测(临时锁定)

                //三、多设备登录检测

                //四、剩余对话次数检测

                //检测通过,绑定用户与channel之间关系
                addChannel(channelHandlerContext, userIdForReq);
                String respMessage = "用户标识: " + userIdForReq + " 登录成功";

                buildResponse(channelHandlerContext, ApiRespMessage.builder().code(String.valueOf(StatusCode.SUCCESS.getCode()))
                        .respTime(String.valueOf(System.currentTimeMillis()))
                        .msgType(String.valueOf(MsgType.INIT.getCode()))
                        .contents(respMessage)
                        .build());

            } else if (StringUtils.equals(msgType, String.valueOf(MsgType.HEARTBEAT.getCode()))) {

                buildResponse(channelHandlerContext, ApiRespMessage.builder().code(String.valueOf(StatusCode.SUCCESS.getCode()))
                        .respTime(String.valueOf(System.currentTimeMillis()))
                        .msgType(String.valueOf(MsgType.HEARTBEAT.getCode()))
                        .contents("心跳测试,很高兴收到你的心跳包")
                        .build());
            } else {
                log.info("用户标识: {}, 消息类型有误,不支持类型: {}", userIdForReq, msgType);
            }


        } catch (Exception e) {
            log.warn("【BusinessHandler】接收到请求内容:{},异常信息:{}", content, e.getMessage(), e);
            // 异常返回
            return;
        }

    }


}

6.2. TaskUtils完整代码

import com.alibaba.fastjson.JSON;
import io.netty.channel.ChannelHandlerContext;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.ApplicationArguments;
import org.springframework.boot.ApplicationRunner;
import org.springframework.stereotype.Component;

import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;

@Component
@Slf4j
public class TaskUtils implements ApplicationRunner {
    private static ExecutorService executorService = Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors() * 2);

    @Autowired
    private AIChatUtils aiChatUtils;

    @Autowired
    private RedisUtils redisUtils;

    @Autowired
    private NettyConfig nettyConfig;

    @Override
    public void run(ApplicationArguments args) throws Exception {
        while(true){
            String queueName = BusinessHandler.LINE_UP_QUEUE_NAME+"-"+nettyConfig.getNode();
//             执行定时任务的逻辑
            String content = redisUtils.queuePoll(queueName);

            if(StringUtils.isNotEmpty(content) && StringUtils.isNoneBlank(content)){
                try{
                    ApiReqMessage apiReqMessage = JSON.parseObject(content, ApiReqMessage.class);
                    String messageId = apiReqMessage.getMessageId();
                    String contents = apiReqMessage.getContents();
                    Long userIdForReq = apiReqMessage.getUserId();
                    List history = apiReqMessage.getHistory();
                    AITaskReqMessage aiTaskReqMessage = AITaskReqMessage.builder().messageId(messageId).userId(userIdForReq).contents(contents).history(history).build();
                    execute(aiTaskReqMessage);
                }catch (Throwable e){
                    log.error("处理消息出现异常",e);

                    //将请求再次返回去队列
                    //将请求丢弃
                    //其他处理?
                }

            }else{
                TimeUnit.SECONDS.sleep(1);
            }

        }

    }

    public void execute(AITaskReqMessage aiTaskReqMessage) {

        executorService.execute(() -> {
            Long userId = aiTaskReqMessage.getUserId();

            if (null == userId || (long) userId < 10000) {
                log.warn("用户身份标识有误!");
                return;
            }

            ChannelHandlerContext channelHandlerContext = BusinessHandler.getContextByUserId(userId);

            if (channelHandlerContext != null) {
                try {
                    aiChatUtils.chatStream(aiTaskReqMessage);

                } catch (Throwable exception) {
                    exception.printStackTrace();
                }
            }
        });
    }

    public static void destory() {
        executorService.shutdownNow();
        executorService = null;
    }


}

你可能感兴趣的:(开源大语言模型-实际应用落地,深度学习)