分布式集群下WebSocket Session共享解决方案

接上一篇 SpringBoot集成WebSocket进行消息主动推送
分布式集群下WebSocket Session共享解决方案

在实现中需要解决的类变量有两个

private static AtomicInteger online = new AtomicInteger();

private static Map<String, Session> sessionPools = new ConcurrentHashMap<>();

其中online可以用Redis实现存储

Session无法采用Redis进行存储, 因为不能对Session进行序列化

由于session无法实现序列化,不能存储到redis这些中间存储里面,因此这里我们只能把session存储在本地的内存中,那么如果是集群的话,我们如何实现session准确的发送消息呢,其实就是session共享。在websocket中,其实是无法做到session共享的

目前通用的解决方案都是通过消息中间件,实现消息的发布与订阅

向前端推送消息时, 向消息队列中发布消息, 所有的后端服务器订阅该消息, 所有的后端服务器收到消息消费消息时, 都执行推送消息, 本地有Session的即可推送成功(前端只可能跟某一个后端建立Session)

具体实现

1. 引入Redis

此处忽略Redis 的配置及依赖, 直接上封装的服务实现类

IRedisDao

import java.util.List;
import java.util.Map;

/**
 * @Author:
 * @Date:2023/7/3 11:24
 * @Des: IRedisDao
 */
public interface IRedisDao {

    /**
     * 设置增量
     *
     * @param key   键名
     * @param value 增量值
     * @return 增加后的值
     */
    long incryBy(String key, long value);

    /**
     * 添加字符串并设置过期时间
     *
     * @param key        键名
     * @param value      字符串值
     * @param expireTime 过期时间,单位:分钟
     */
    void setString(String key, String value, int expireTime);

    /**
     * 添加字符串
     *
     * @param key   键名
     * @param value 字符串值
     */
    void setString(String key, String value);

    /**
     * 获取字符串值
     *
     * @param key 键名
     * @return 字符串值
     */
    String getString(String key);

    /**
     * 删除字符串
     *
     * @param key 键名
     */
    void delString(String key);

    /**
     * 加锁
     *
     * @param key            键名
     * @param expiredSeconds 过期时间,单位:秒
     * @param lockFlag       锁标志
     *                       锁标志的作用主要有两个方面:
     *                       唯一性检查:在使用connection.set()方法设置锁时,通过指定SET_IF_ABSENT选项,只有当该键在Redis中不存在时才会设置成功。这样可以确保只有一个实体或线程能够成功获取到锁,其他的尝试会被拒绝。
     *                       释放锁时的验证:在释放锁时,可以通过比对锁标志的值来验证是否是持有锁的实体或线程进行释放操作。只有当锁标志匹配时才执行释放操作,以防止其他实体或线程错误释放锁。
     *                       通过使用锁标志,可以实现简单的分布式锁机制,用于控制并发访问共享资源的情况,确保同一时间只有一个实体或线程能够访问该资源。
     * @return true:加锁成功,false:已经加锁
     */
    boolean addLock(String key, int expiredSeconds, String lockFlag);

    /**
     * 释放锁
     *
     * @param key      键名
     * @param lockFlag 锁标志
     * @return true:释放成功,false:释放失败
     */
    boolean releaseLock(String key, String lockFlag);

    /**
     * 向Set集合中添加元素
     *
     * @param key    键名
     * @param member 元素值
     * @return 添加成功的数量
     */
    long sAdd(String key, String member);

    /**
     * 从Set集合中移除元素
     *
     * @param key    键名
     * @param member 元素值
     */
    void sRemove(String key, String member);

    /**
     * 批量从Set集合中移除元素
     *
     * @param key    键名
     * @param member 元素值
     */
    void sRemoveBatch(String key, Object... member);

    /**
     * 判断元素是否存在于Set集合中
     *
     * @param key    键名
     * @param member 元素值
     * @return true:存在,false:不存在
     */
    boolean sIsMember(String key, String member);

    /**
     * 设置键的过期时间
     *
     * @param key            键名
     * @param expiredSeconds 过期时间,单位:秒
     */
    void expire(String key, int expiredSeconds);

    /**
     * 将Map中的键值对保存到Redis的Hash结构中
     *
     * @param key  Hash结构的键名
     * @param data Map类型的数据
     * @return true:保存成功,false:保存失败
     */
    void hmset(String key, Map<String, String> data);

    /**
     * 存hash结构中的键值对
     *
     * @param redisKey Hash结构的键名
     * @param hashKey
     * @param value
     */
    void hput(String redisKey, String hashKey, String value);

    /**
     * hash结构删除键值对
     * @param redisKey
     * @param hashKey
     */
    void hdel(String redisKey, String hashKey);


    String hget(String redisKey, String hashKey);

    /**
     * 判断键是否存在
     *
     * @param key 键名
     * @return true:存在,false:不存在
     */
    boolean hasKey(String key);

    /**
     * 将指定的field和value添加到Redis哈希结构中的key中,仅当该field在哈希结构中不存在时才执行添加操作。
     *
     * @param key   键名
     * @param field 字段名
     * @param value 字段值
     * @return true:添加成功,false:字段已存在,添加失败
     */
    public boolean putIfAbsentDB(String key, String field, String value);

    /**
     * 将值从左侧压入列表
     *
     * @param key   键名
     * @param value 值
     * @return 列表的长度
     */
    public Long lPushDB(String key, String value);

    /**
     * 返回列表中指定范围的元素
     *
     * @param key   键名
     * @param start 起始索引
     * @param end   结束索引
     * @return 指定范围内的元素列表
     */
    public List<String> rRangeDB(String key, long start, long end);

    /**
     * 修剪列表,只保留指定范围内的元素
     *
     * @param key   键名
     * @param start 起始索引
     * @param end   结束索引
     */
    public void lTrimDB(String key, long start, long end);

    /**
     * 获取列表的长度
     *
     * @param key 键名
     * @return 列表的长度
     */
    public Long lSizeDB(String key);
}
RedisDaoImpl

import com.sinotrans.gtp.exception.AppCodeMsg;
import com.sinotrans.gtp.exception.AppException;
import io.lettuce.core.api.sync.RedisCommands;
import lombok.extern.slf4j.Slf4j;
import org.osgi.framework.ServiceException;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.data.redis.connection.RedisStringCommands;
import org.springframework.data.redis.core.RedisCallback;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.core.script.DefaultRedisScript;
import org.springframework.data.redis.core.types.Expiration;
import org.springframework.stereotype.Repository;

import java.nio.charset.StandardCharsets;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.concurrent.TimeUnit;

/**
 * @Author:
 * @Date:2023/7/3 11:25
 * @Des: RedisDaoImpl Redis操作封装类
 */
@Repository("redisDao")
@Slf4j
public class RedisDaoImpl implements IRedisDao {

    private static final int SECOND = 60;

    private static final int MINUTE = SECOND * 60;

    private static final int HOUR = MINUTE * 60;

    private static final int DAY = HOUR * 24;

    private static final int WEEK = DAY * 7;


    @Autowired
    private RedisTemplate<String, String> redisTemplate;

    @Override
    public long incryBy(String key, long increment) {
        Long value = redisTemplate.opsForValue().increment(key, increment);
        return value == null ? 0L : value.longValue();
    }

    @Override
    public void setString(String key, String value, int expireTime) {
        redisTemplate.opsForValue()
                .set(key, value, expireTime, TimeUnit.MINUTES);
    }

    @Override
    public void setString(String key, String value) {
        redisTemplate.opsForValue()
                .set(key, value);
    }

    @Override
    public String getString(String key) {
        String value = redisTemplate.opsForValue().get(key);
        return value;
    }

    @Override
    public void delString(String key) {
        redisTemplate.delete(key);
    }

    @Override
    public boolean addLock(String key, int expiredSeconds, String lockFlag) {

        return (Boolean) redisTemplate.execute((RedisCallback) connection -> {
            Boolean set = connection.set(key.getBytes(StandardCharsets.UTF_8), lockFlag.getBytes(StandardCharsets.UTF_8), Expiration.seconds(expiredSeconds), RedisStringCommands.SetOption.SET_IF_ABSENT);
            if (set == null) {
                return false;
            }
            return set;
        });

    }

    @Override
    public boolean releaseLock(String key, String lockFlag) {
        DefaultRedisScript<Boolean> releaseScript = new DefaultRedisScript<>(
                "if redis.call('GET', KEYS[1]) == ARGV[1] then " +
                        "   return redis.call('DEL', KEYS[1]) " +
                        "else " +
                        "   return 0 " +
                        "end",
                Boolean.class
        );
        List<String> keys = Collections.singletonList(key);
        Boolean release = redisTemplate.execute(releaseScript, keys, lockFlag);
        return release != null && release;
    }


    @Override
    public long sAdd(String key, String member) {
        Long rs = redisTemplate.opsForSet().add(key, member);
        long endTime = System.currentTimeMillis();
        return rs == null ? 0L : rs.longValue();
    }

    @Override
    public void sRemove(String key, String member) {
        redisTemplate.opsForSet().remove(key, member);
        long endTime = System.currentTimeMillis();
    }

    @Override
    public void sRemoveBatch(String key, Object... member) {
        redisTemplate.opsForSet().remove(key, member);
        long endTime = System.currentTimeMillis();
    }

    @Override
    public boolean sIsMember(String key, String member) {
        boolean isMember = redisTemplate.opsForSet().isMember(key, member);
        return isMember;
    }

    @Override
    public void expire(String key, int expiredSeconds) {
        redisTemplate.expire(key, expiredSeconds, TimeUnit.SECONDS);
    }

    @Override
    public void hmset(String redisKey, Map<String, String> data) {
        redisTemplate.opsForHash().putAll(redisKey, data);
//            redisTemplate.expire(redisKey, 52 * WEEK, TimeUnit.SECONDS);
    }

    @Override
    public void hput(String redisKey, String hashKey, String value) {
        redisTemplate.opsForHash().put(redisKey, hashKey, value);
    }

    @Override
    public void hdel(String redisKey, String hashKey) {
        redisTemplate.opsForHash().delete(redisKey, hashKey);
    }

    public String hget(String redisKey, String hashKey) {
        return (String) redisTemplate.opsForHash().get(redisKey, hashKey);
    }

    @Override
    public boolean hasKey(String key) {
        return Boolean.TRUE.equals(redisTemplate.hasKey(key));
    }


    /**
     * hash putIfAbsent
     *
     * @param key
     * @param field
     * @param value
     * @return
     */
    public boolean putIfAbsentDB(String key, String field, String value) {
        Boolean result = redisTemplate.opsForHash().putIfAbsent(key, field, value);
        return result;
    }

    public Long lPushDB(String key, String value) {
        Long result = redisTemplate.opsForList().leftPush(key, value);
        return result;
    }

    public List<String> rRangeDB(String key, long start, long end) {
        List<String> list = redisTemplate.opsForList().range(key, start, end);
        return list;
    }

    public void lTrimDB(String key, long start, long end) {
        redisTemplate.opsForList().trim(key, start, end);
        long endTime = System.currentTimeMillis();
        return;
    }

    public Long lSizeDB(String key) {
        Long size = redisTemplate.opsForList().size(key);
        return size;
    }


}

2. 改写在线人数的实现

该处人数统计并不是最终解决方案

此处的解决是, 将建立的客户端标识存储至Redis的String数据结构中, 用固定的前缀拼接

设置过期时间1天, 最终保持一致性

通过key通配符的查询方式获取人数

最终解决方案需要心跳机制(此处暂未实现)

通过后端定时任务去推送一段文本随意一段即可,

存储至Redis 的前端客户端标识中带有当前后端的host+port, 并设置10分钟超时时间

后端定时任务(5分钟一跑)业务逻辑中从Redis拿到属于当前后端的前端客户端标识, 去一一发送心跳, 判断是否发送成功, 成功则继续将该标识续命为10分钟超时时间, 如果服务运转正常, 将会成功发送, 如果发送失败删除标识即可, 如果服务挂掉, 标识10分钟将自动过期无法续命


import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Repository;

import java.util.Set;


/**
 * @Author:
 * @Date:2023/7/3 14:41
 * @Des: WebSocketRedisDao WebSocketRedis存储类
 */
@Repository
public class WebSocketRedisDao {

    
    private static final String ONLINE_CLIENT_KEY_PRE = "ONLINE_#_CLIENT_ID_";

    @Autowired
    private IRedisDao redisDao;

    /*人数相关操作---弃用 : 会有一致性问题, 通过下面存客户端标识实现获取人数*/
    /*
    public void addOnlineCount() {
        redisDao.incryBy(ONLINE_COUNT_KEY, 1);
    }

    public void subOnlineCount() {
        redisDao.incryBy(ONLINE_COUNT_KEY, -1);
    }

    public int getOnlineCount() {
        Integer count = Integer.valueOf(redisDao.getString(ONLINE_COUNT_KEY));
        return count != null ? count : 0;
    }*/


    /*前端客户端标识相关操作*/
    public void addClientId(String clientId) {
        // 设置过期时间1天(在线人数 假如SpringBoot程序被不正常打断, 会导致Redis没有删除活跃的客户端标识, 最终通过过期删除, 最终保持一致性)
        redisDao.setString(ONLINE_CLIENT_KEY_PRE + clientId, "0", 60 * 24);
    }

    public void removeClientId(String clientId) {
        redisDao.delString(ONLINE_CLIENT_KEY_PRE + clientId);
    }

    /**
     * 存储的clientId 非用户名
     *
     * @return
     */
    public Set<String> getClientsSet() {
        return redisDao.getKeysByPattern(ONLINE_CLIENT_KEY_PRE + "*");
    }


}

websocket无法通过注解方式注入bean的解决办法
引入ApplicationContextUtils类

import org.springframework.beans.BeansException;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextAware;
import org.springframework.stereotype.Component;

/**
 * @Author:
 * @Date:2023/7/4 10:57
 * @Des: ApplicationContextUtils 用来获取SpringBoot创建好的工厂
 */
@Component
public class ApplicationContextUtils implements ApplicationContextAware {

    // 保留下来工厂
    private static ApplicationContext applicationContext;

    // 将创建好的工厂以参数的形式传递给这个类
    @Override
    public void setApplicationContext(ApplicationContext applicationContext) throws BeansException {
        this.applicationContext = applicationContext;
    }

    // 提供在工厂中获取对象的方法 // RedisTemplate redisTemplate
    public static Object getBeanByName(String beanName) {
        return applicationContext.getBean(beanName);
    }

    public static <T> T getBeanByClazz(Class<T> clazz) {

        return applicationContext.getBean(clazz);
    }

}

注入方式
private WebSocketRedisDao webSocketRedisDao = ApplicationContextUtils.getBeanByClazz(WebSocketRedisDao.class);
3. 使用Redis实现发布订阅模型

参考Redis.22

4. session共享

此时调用发送消息时, 由之前的直接调用WebSocket的消息发送方法, 改为往消息队列中发布消息

在消息队列的订阅方法中, 再进行调用WebSocket的消息发送方法即可

5. 其他
5.1. OnClose方法中存在问题

当SpringBoot程序关闭时, 主动触发OnClose注解所在方法执行人数扣减操作

此处需要手动在方法里面获取webSocketRedisDao, 防止已经结束生命周期无法操作

/**
     * 关闭连接时调用
     *
     * @param userName 关闭连接的客户端的姓名
     */
@OnClose
public void onClose(@PathParam(value = "name") String userName) {
    WebSocketRedisDao webSocketRedisDao = ApplicationContextUtils.getBeanByClazz(WebSocketRedisDao.class);
    sessionPools.remove(userName);
    webSocketRedisDao.removeClientId(userName);
    subOnlineCount();
    log.info(userName + "断开webSocket连接!当前SpringBoot实例活跃前端客户端数为" + currentSpringBootOnline.get());
    log.info(userName + "断开webSocket连接!当前整个服务集群总活跃前端客户端数为" + webSocketRedisDao.getClientsSet().size());
}
5.2. 暂未实现判断是否在线
6. WebSocketServer

升级后的WebSocketServer


import com.wd.gtp.component.ApplicationContextUtils;
import com.wd.gtp.dao.redis.WebSocketRedisDao;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component;

import javax.websocket.*;
import javax.websocket.server.PathParam;
import javax.websocket.server.ServerEndpoint;
import java.io.IOException;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;

/**
 * @Author:
 * @Date:2023/6/26 10:11
 * @Des: WebSocketServer WebSocket服务端代码,包含接收消息,推送消息等接口
 */
@Component
@Slf4j
@ServerEndpoint(value = "/socket/{name}")
public class WebSocketServer {

    private WebSocketRedisDao webSocketRedisDao = ApplicationContextUtils.getBeanByClazz(WebSocketRedisDao.class);

    //静态变量,用来记录当前在线连接数。应该把它设计成线程安全的。
    private static AtomicInteger currentSpringBootOnline = new AtomicInteger(); // 当前实例链接人数

    //concurrent包的线程安全Set,用来存放每个客户端对应的WebSocketServer对象。
    private static Map<String, Session> sessionPools = new ConcurrentHashMap<>();

    /**
     * 发送消息方法
     *
     * @param session 客户端与socket建立的会话
     * @param message 消息
     * @throws IOException
     */
    public void sendMessage(Session session, String message) throws IOException {
        if (session != null) {
            log.info("Session获取非空, 消息推送成功");
            session.getBasicRemote().sendText(message);
        } else {
            log.info("Session获取为空, 消息未推送");
        }
    }

    /**
     * 连接建立成功调用
     *
     * @param session  客户端与socket建立的会话
     * @param userName 客户端的userName
     */
    @OnOpen
    public void onOpen(Session session, @PathParam(value = "name") String userName) {
        sessionPools.put(userName, session);
        webSocketRedisDao.addClientId(userName);
        addOnlineCount();
        log.info(userName + "打开webSocket连接!当前SpringBoot实例活跃前端客户端数为" + currentSpringBootOnline.get());
        log.info(userName + "打开webSocket连接!当前整个服务集群总活跃前端客户端数为" + webSocketRedisDao.getClientsSet().size());
        try {
            sendMessage(session, "欢迎" + userName + "加入连接!");
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    /**
     * 关闭连接时调用
     *
     * @param userName 关闭连接的客户端的姓名
     */
    @OnClose
    public void onClose(@PathParam(value = "name") String userName) {
        WebSocketRedisDao webSocketRedisDao = ApplicationContextUtils.getBeanByClazz(WebSocketRedisDao.class);
        sessionPools.remove(userName);
        webSocketRedisDao.removeClientId(userName);
        subOnlineCount();
        log.info(userName + "断开webSocket连接!当前SpringBoot实例活跃前端客户端数为" + currentSpringBootOnline.get());
        log.info(userName + "断开webSocket连接!当前整个服务集群总活跃前端客户端数为" + webSocketRedisDao.getClientsSet().size());
    }

    /**
     * 发生错误时候
     *
     * @param session
     * @param throwable
     */
    @OnError
    public void onError(Session session, Throwable throwable) {
        log.error("异常", throwable);
    }

    /**
     * 给指定前端客户端发送消息
     *
     * @param clientId 前端客户端标识
     * @param message  消息
     */
    public void sendInfoClient(String clientId, String message) {
        Session session = sessionPools.get(clientId);
        try {
            sendMessage(session, message);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * 给指定用户发送消息
     * 该情况下考虑到了 同一用户多次登录, 导致只会向一个发送
     * map中存的key:  (用户名:浏览器标识)  [GTM.ADMIN]:[浏览器标识]  其中浏览器标识最好客户端唯一 每次发起请求都是一样的每个客户端都是唯一的
     *
     * @param id      用户名
     * @param message 消息
     */
    public void sendInfoUser(String id, String message) {
        Set<String> keySet = sessionPools.keySet();
        try {
            for (String key : keySet) {
                if (prefix(key).equals(id)) {
                    sendMessage(sessionPools.get(key), message);
                }
            }
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * 收到客户端消息时触发(群发)
     *
     * @param message
     * @throws IOException
     */
    @OnMessage
    public void onMessage(String message) {
        for (Session session : sessionPools.values()) {
            try {
                sendMessage(session, message);
            } catch (Exception e) {
                e.printStackTrace();
                continue;
            }
        }
    }

    private String prefix(String key) {
        return key.substring(0, key.indexOf(":"));
    }

    // 判断用户是否在线
    public boolean isOnline(String userId) {
        Set<String> keySet = webSocketRedisDao.getClientsSet();
        for (String key : keySet) {
            if (prefix(key).equals(userId)) {
                return true;
            }
        }
        return false;
    }


    public static void addOnlineCount() {
        currentSpringBootOnline.incrementAndGet();
    }

    public static void subOnlineCount() {
        currentSpringBootOnline.decrementAndGet();
    }

}


你可能感兴趣的:(分布式,websocket,网络协议)