使用 Spring Boot + Netty 实现服务端单端口同时支持 Socket 和 WebSocket 协议

基于 Maven 的构建环境(在 pom.xml 中添加 netty 依赖):

<dependency>
    <groupId>io.netty</groupId>
    <artifactId>netty-all</artifactId>
    <version>4.1.32.Final</version>
</dependency>

服务构建主类:

import io.netty.bootstrap.ServerBootstrap;
import io.netty.channel.ChannelFuture;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

/**
 * Netty server that listens on a single port for both raw TCP socket clients and WebSocket
 * clients. Per-connection protocol detection is delegated to {@link NettyServerInitializer}.
 *
 * @Author: geyingke
 * @Date: 2020/7/20
 **/
public class NettyServer {

    private static final Logger logger = LogManager.getLogger(NettyServer.class);

    private final int port;

    /**
     * @param port TCP port the server binds to
     */
    public NettyServer(int port) {
        this.port = port;
    }

    /**
     * Binds the server and blocks the calling thread until the server channel is closed.
     *
     * <p>Both event loop groups are shut down exactly once when this method exits, whether it
     * returns normally or fails. A failure (for example the port is already in use) is logged and
     * rethrown instead of being silently swallowed.
     *
     * @throws InterruptedException if the calling thread is interrupted while waiting
     */
    public void start() throws InterruptedException {
        EventLoopGroup bossGroup = new NioEventLoopGroup();
        EventLoopGroup workerGroup = new NioEventLoopGroup();
        try {
            ServerBootstrap serverBootstrap = new ServerBootstrap()
                    .group(bossGroup, workerGroup)
                    .channel(NioServerSocketChannel.class)
                    .localAddress(port)
                    // The initializer inspects the first bytes of each connection and installs
                    // either the socket handler chain or the WebSocket handler chain.
                    .childHandler(new NettyServerInitializer());
            ChannelFuture channelFuture = serverBootstrap.bind().sync();
            logger.info(String.format("Netty server started!!!! port: %d", port));
            channelFuture.channel().closeFuture().sync();
        } catch (Exception e) {
            logger.error(String.format("Netty server failed on port %d", port), e);
            // Precise rethrow: only InterruptedException or unchecked exceptions can reach here.
            throw e;
        } finally {
            // Start both graceful shutdowns before waiting on either, so an interrupt while
            // waiting cannot leave the boss group running.
            workerGroup.shutdownGracefully();
            bossGroup.shutdownGracefully();
            workerGroup.terminationFuture().sync();
            bossGroup.terminationFuture().sync();
        }
    }
}

server初始化类

  • 在初始化时,如果要兼容处理 socket 请求,socket 的处理 Handler 和相应的编解码器必须在初始化阶段就添加到 pipeline 中。目前仍在研究如何在一个 Handler 中同时处理两种协议。
  • 如果 socket 和 websocket 的 Handler 不分开处理,websocket 的握手连接无法正常完成,目前正在排查原因。
  • 当前 TCP 粘包的处理没有使用 netty 提供的三种解决方案:连接方为 C++ 程序,数据包没有包头标记码,因此通过循环截取 ByteBuf 中的字节数组来拆包。
import com.galaxyeye.icservice.im.parser.SocketUtils;
import com.galaxyeye.icservice.im.socket.NettySocketHandler;
import com.galaxyeye.icservice.im.webSocket.WebSocketHandler;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.socket.SocketChannel;
import io.netty.handler.codec.ByteToMessageDecoder;
import io.netty.handler.codec.http.HttpObjectAggregator;
import io.netty.handler.codec.http.HttpServerCodec;
import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler;
import io.netty.handler.stream.ChunkedWriteHandler;
import io.netty.handler.timeout.IdleStateHandler;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.springframework.stereotype.Component;

import java.lang.invoke.MethodHandles;
import java.util.List;

/**
 * Per-connection pipeline initializer that lets one port serve both raw TCP socket clients and
 * WebSocket clients.
 *
 * <p>Every new channel starts with an idle-state handler, the {@code SocketParser} protocol
 * sniffer and the socket business handler. If the first bytes look like an HTTP {@code GET /}
 * request, the parser installs the HTTP/WebSocket handler chain and then removes itself and the
 * socket handler. Otherwise it keeps splitting the TCP byte stream into length-prefixed socket
 * frames.
 *
 * @Author: geyingke
 * @Date: 2020/7/21
 **/
@Component
public class NettyServerInitializer extends ChannelInitializer<SocketChannel> {

    private static final Logger logger = LogManager.getLogger(MethodHandles.lookup().lookupClass());

    /** Seconds without inbound data before IdleStateHandler fires a READER_IDLE event. */
    private static final int READER_IDLE_SECONDS = 60 * 2;

    @Override
    protected void initChannel(SocketChannel socketChannel) throws Exception {
        socketChannel.pipeline().addLast(new IdleStateHandler(READER_IDLE_SECONDS, 0, 0));
        /*
         * Notes:
         * 1. To serve socket and WebSocket clients on one port, the socket handler must be
         *    installed here at initialization time; otherwise later socket messages fail.
         * 2. SocketParser tells WebSocket and socket traffic apart and splits sticky socket packets.
         * 3. When both protocols are supported, socket encoding/decoding must happen inside message
         *    handling, with encoders/decoders appended to the pipeline afterwards.
         */
        socketChannel.pipeline().addLast("SocketParser", new SocketParser());
        socketChannel.pipeline().addLast(new NettySocketHandler());
    }

    /**
     * Protocol sniffer and socket frame splitter. A new instance is created per channel, so it
     * holds no cross-connection state.
     *
     * <p>Socket frame layout (as produced by the C++ peer): a 4-byte little-endian length
     * {@code n}, followed by {@code n} payload bytes. Each complete frame, including its header,
     * is emitted downstream as one ByteBuf.
     */
    private static final class SocketParser extends ByteToMessageDecoder {
        /** First bytes of a WebSocket opening handshake (an HTTP GET request). */
        private static final String WEBSOCKET_PREFIX = "GET /";
        /** Minimum number of buffered bytes before socket decoding is attempted. */
        private static final int BASE_LENGTH = 14;
        /** Size of the length header that precedes every socket frame. */
        private static final int HEADER_LENGTH = 4;
        /**
         * Largest accepted socket frame, header included. A frame that declares more than this
         * (or a negative length) is treated as a byte-stream attack or corruption, and everything
         * buffered is dropped.
         */
        private static final int MAX_FRAME_LENGTH = 2048;

        @Override
        protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception {
            if (startsWithWebSocketPrefix(in)) {
                switchToWebSocket(ctx);
            } else {
                decodeSocketFrames(in, out);
            }
        }

        /** Checks the unread bytes for the WebSocket prefix without moving the reader index. */
        private static boolean startsWithWebSocketPrefix(ByteBuf in) {
            int prefixLength = WEBSOCKET_PREFIX.length();
            if (in.readableBytes() < prefixLength) {
                return false;
            }
            int start = in.readerIndex();
            for (int i = 0; i < prefixLength; i++) {
                if (in.getByte(start + i) != (byte) WEBSOCKET_PREFIX.charAt(i)) {
                    return false;
                }
            }
            return true;
        }

        /** Replaces the socket handling in this channel's pipeline with the HTTP/WebSocket chain. */
        private void switchToWebSocket(ChannelHandlerContext ctx) {
            // The WebSocket handshake is plain HTTP, so the HTTP codec and aggregator go first.
            ctx.pipeline().addLast(new HttpServerCodec());
            ctx.pipeline().addLast(new ChunkedWriteHandler());
            ctx.pipeline().addLast(new HttpObjectAggregator(8192));
            // WebSocketHandler runs before the protocol handler so it can rewrite the request URI.
            ctx.pipeline().addLast(new WebSocketHandler());
            ctx.pipeline().addLast(new WebSocketServerProtocolHandler("/ws", null, true, 65536 * 10));
            ctx.pipeline().remove(NettySocketHandler.class);
            // Nothing has been consumed. When this decoder is removed, ByteToMessageDecoder passes
            // the buffered handshake bytes on to the handlers added above.
            ctx.pipeline().remove(this);
        }

        /**
         * Emits every complete socket frame in {@code in}. A trailing partial frame is left unread
         * (reader index at its header) so ByteToMessageDecoder keeps it for the next read.
         */
        private void decodeSocketFrames(ByteBuf in, List<Object> out) {
            if (in.readableBytes() < BASE_LENGTH) {
                return;
            }
            while (in.readableBytes() >= HEADER_LENGTH) {
                byte[] header = new byte[HEADER_LENGTH];
                in.getBytes(in.readerIndex(), header);
                int length = SocketUtils.read_int_le(header, 0);
                logger.info("readableBytes: " + in.readableBytes() + "\t custom decode msg length: " + length);

                if (length < 0 || length > MAX_FRAME_LENGTH - HEADER_LENGTH) {
                    logger.warn(String.format("Dropping %d buffered bytes: invalid socket frame length %d",
                            in.readableBytes(), length));
                    in.skipBytes(in.readableBytes());
                    return;
                }
                int frameLength = length + HEADER_LENGTH;
                if (in.readableBytes() < frameLength) {
                    // Incomplete frame: wait for more data.
                    return;
                }
                byte[] msgBytes = new byte[frameLength];
                in.readBytes(msgBytes);
                logger.info("full bag body: " + SocketUtils.parse(msgBytes));
                out.add(Unpooled.wrappedBuffer(msgBytes));
            }
        }
    }
}
 
 

Channel 连接池实体,用于存储连接信息。Channel 不可序列化,因此无法存入 Redis 缓存,连接的分布式共享也不能基于简单的序列化(流)共享方式实现;当前实现将 Channel 存储在本地缓存中。

import com.alibaba.fastjson.JSON;
import com.galaxyeye.icservice.conf.SpringContextBean;
import com.galaxyeye.icservice.utils.RedisTempleUtil;
import io.netty.channel.Channel;
import io.netty.channel.ChannelId;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.springframework.util.Assert;

import java.lang.invoke.MethodHandles;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;

/**
 * Registry of connected Netty channels and the Redis indexes that relate users, apps and channels.
 *
 * <p>Channels cannot be serialized, so the {@link Channel} objects live only in the local
 * {@link #CHANNEL_MAP}. Redis holds these string indexes:
 * <ul>
 *   <li>{@code CHANNEL_KEY}: set of all registered channel ids</li>
 *   <li>{@code userId@appId}: set of channel ids owned by that user (one user, many channels)</li>
 *   <li>{@code <channelId>}: set of {@code userId@appId} keys bound to that channel</li>
 *   <li>{@code USER_KEY:appId:channelType}: set of {@code userId@appId} keys online for an app</li>
 *   <li>{@code WS_ONLINE_APP_ID} / {@code SOCKET_ONLINE_APP_ID}: appIds with online users</li>
 * </ul>
 * On connect, websocket channels go through channel map, user/channel relation, CHANNEL_KEY, USER_KEY
 * and WS_ONLINE_APP_ID. Socket channels only reach channel map and CHANNEL_KEY; their user/channel
 * relation is created later, when the client joins a session.
 *
 * @Author: geyingke
 * @Date: 2020/7/20
 **/
public class MyChannelHandlePool {

    /**
     * Locally connected channels keyed by {@link ChannelId#asLongText()}. Concurrent because
     * channels served by different event-loop threads register and unregister simultaneously.
     */
    public static Map<String, Channel> CHANNEL_MAP = new ConcurrentHashMap<>();

    // NOTE(review): not referenced in this class; element type and thread-safety depend on callers.
    public static List AUTH_CHANNEL = new ArrayList<>();

    public MyChannelHandlePool() {
    }

    public static RedisTempleUtil redisTempleUtil = SpringContextBean.getBean(RedisTempleUtil.class);

    private static final Logger logger = LogManager.getLogger(MethodHandles.lookup().lookupClass());

    /** Legacy Redis key for storing channels; channels are now kept in CHANNEL_MAP instead. */
    public static final String CHANNEL_GROUP = "CHANNEL_GROUP";
    /** Redis hash of offline websocket users: userId@appId -> timestamp in millis. */
    public static final String OUTLINE_CACHE = "OUTLINE_CACHE";
    public static final String USER_KEY = "USER_KEY";
    public static final String CHANNEL_KEY = "CHANNEL_KEY";
    /** Prefix of the per-app waiting queue list keys. */
    public static final String QUEUE_MARK = "QUEUE_MARK";
    /** Set of appIds that currently have a non-empty waiting queue. */
    public static final String QUEUE_APP_ID = "QUEUE_APP_ID";
    /** Channel type: websocket. */
    public static final Integer WS_CHANNEL_TYPE = 1;
    /** Channel type: socket. */
    public static final Integer SOCKET_CHANNEL_TYPE = 2;
    /** Set of appIds with online websocket users. */
    public static final String WS_ONLINE_APP_ID = "WS_ONLINE_APP_ID";
    /** Set of appIds with online socket users. */
    public static final String SOCKET_ONLINE_APP_ID = "SOCKET_ONLINE_APP_ID";
    /** Separator used when composing Redis keys. */
    public static final String KEY_SPLIT = ":";
    /** Separator inside userId@appId keys and queue values. */
    public static final String DEFAULT_MX = "@";

    /**
     * Clears all connection-related Redis data on server start. Waiting queues are kept so queuing
     * resumes after a restart.
     *
     * @return always 1
     */
    public static Long clearAllChannel() {
        logger.info("============================服务启动初始化,清空所有连接相关数据===========================");
        logger.info("============================(不清除排队队列,保证服务重启后,排队能够正常进行)===========================");
        clearOnlineUserIndexes(WS_ONLINE_APP_ID, WS_CHANNEL_TYPE, "websocket");
        clearOnlineUserIndexes(SOCKET_ONLINE_APP_ID, SOCKET_CHANNEL_TYPE, "socket");
        // Delete every channelId -> users set, then the channel index itself.
        Set<?> channelKeys = redisTempleUtil.sGet(CHANNEL_KEY);
        logger.info("待清空的通道id索引" + JSON.toJSONString(channelKeys));
        if (channelKeys != null && !channelKeys.isEmpty()) {
            redisTempleUtil.del(channelKeys.toArray(new String[0]));
        }
        redisTempleUtil.del(CHANNEL_KEY);
        logger.info("====================================连接数据初始化完成===================================");
        return 1L;
    }

    /**
     * Deletes, for every appId in {@code onlineAppIdKey}, the user -> channels sets and the
     * per-app user index, then the online-app index itself.
     */
    private static void clearOnlineUserIndexes(String onlineAppIdKey, Integer channelType, String label) {
        Set<?> appIds = redisTempleUtil.sGet(onlineAppIdKey);
        if (appIds == null || appIds.isEmpty()) {
            return;
        }
        logger.info("待清空的" + label + "在线appId索引:" + JSON.toJSONString(appIds));
        for (Object o : appIds) {
            String appId = (String) o;
            String userKey = generateUserKey(appId, channelType);
            Set<?> userKeys = redisTempleUtil.sGet(userKey);
            if (userKeys != null && !userKeys.isEmpty()) {
                logger.info("待清空的" + label + "用户索引" + JSON.toJSONString(userKeys));
                redisTempleUtil.del(userKeys.toArray(new String[0]));
            }
            // The per-app user index must go too, or it keeps listing users from before the restart.
            redisTempleUtil.del(userKey);
        }
        redisTempleUtil.del(onlineAppIdKey);
    }

    /**
     * @param channelId channel id
     * @return whether the channel is registered in the local map
     */
    public static boolean hasChannel(ChannelId channelId) {
        return CHANNEL_MAP.containsKey(getStrChannelId(channelId));
    }

    /**
     * Registers a connected channel locally and in the Redis CHANNEL_KEY index.
     *
     * @param channelId id of the channel
     * @param channel   the channel
     * @return true if newly registered, false if a channel with this id was already present
     */
    public static boolean addChannel(ChannelId channelId, Channel channel) {
        String strChannelId = getStrChannelId(channelId);
        // putIfAbsent is atomic, so two racing registrations cannot both succeed.
        if (CHANNEL_MAP.putIfAbsent(strChannelId, channel) != null) {
            return false;
        }
        addToSetIfAbsent(CHANNEL_KEY, strChannelId);
        return true;
    }

    /**
     * @return number of locally registered channels
     */
    public static Long getChannelGroupSize() {
        return (long) CHANNEL_MAP.size();
    }

    /**
     * @param channelId channel id
     * @return the long-text form used as key everywhere in this registry
     */
    public static String getStrChannelId(ChannelId channelId) {
        return channelId.asLongText();
    }

    /**
     * Removes the channel from the local map if it is the one registered under its id.
     *
     * @param channel channel to remove
     * @return whether an entry was removed
     */
    public static boolean delChannel(Channel channel) {
        return CHANNEL_MAP.remove(getStrChannelId(channel.id()), channel);
    }

    /**
     * @param channelId long-text channel id
     * @return the registered channel, or null if unknown (or {@code channelId} is null)
     */
    public static Channel getChannel(String channelId) {
        return channelId == null ? null : CHANNEL_MAP.get(channelId);
    }

    /**
     * Binds a user to a channel (one user may own many channels) and maintains the online indexes.
     * Also clears the user's offline-queue entry.
     *
     * @param appId       app id
     * @param userAppInfo user key in the form userId@appId
     * @param channelId   channel to bind
     * @param channelType {@link #WS_CHANNEL_TYPE} or {@link #SOCKET_CHANNEL_TYPE}
     * @return true if the relation already existed or was created
     */
    public static boolean createUserChannelRelation(String appId, String userAppInfo, ChannelId channelId, Integer channelType) {
        delOutlineQueue(userAppInfo);
        String strChannelId = getStrChannelId(channelId);
        if (redisTempleUtil.sHasKey(userAppInfo, strChannelId)) {
            return true;
        }
        long added = redisTempleUtil.sSet(userAppInfo, strChannelId);
        String onlineAppIdKey = onlineAppIdKey(channelType);
        if (onlineAppIdKey != null) {
            addToSetIfAbsent(onlineAppIdKey, appId);
            addToSetIfAbsent(generateUserKey(appId, channelType), userAppInfo);
        }
        // Reverse index: channelId -> users.
        addToSetIfAbsent(strChannelId, userAppInfo);
        return added > 0;
    }

    private static String generateUserKey(String appId, Integer channelType) {
        return USER_KEY + KEY_SPLIT + appId + KEY_SPLIT + channelType;
    }

    /** Maps a channel type to its online-appId index key; null for unknown types. */
    private static String onlineAppIdKey(Integer channelType) {
        if (WS_CHANNEL_TYPE.equals(channelType)) {
            return WS_ONLINE_APP_ID;
        }
        if (SOCKET_CHANNEL_TYPE.equals(channelType)) {
            return SOCKET_ONLINE_APP_ID;
        }
        return null;
    }

    /** Adds {@code value} to the Redis set {@code key} unless it is already a member. */
    private static void addToSetIfAbsent(String key, String value) {
        if (!redisTempleUtil.sHasKey(key, value)) {
            redisTempleUtil.sSet(key, value);
        }
    }

    /**
     * Removes {@code userAppInfo} from the per-app user index. If that index becomes empty, also
     * removes the appId from the matching online-app index. No-op for unknown channel types.
     */
    private static void dropUserIndexes(String appId, String userAppInfo, Integer channelType) {
        String onlineAppIdKey = onlineAppIdKey(channelType);
        if (onlineAppIdKey == null) {
            return;
        }
        String userKey = generateUserKey(appId, channelType);
        redisTempleUtil.setRemove(userKey, userAppInfo);
        if (redisTempleUtil.sGetSetSize(userKey) <= 0) {
            redisTempleUtil.setRemove(onlineAppIdKey, appId);
        }
    }

    /**
     * Records a websocket user as offline.
     *
     * @param userInfoId staffId@appId
     * @param timeStamp  current time in milliseconds
     * @return result of the Redis hash write
     */
    public static boolean addOutlineQueue(String userInfoId, long timeStamp) {
        return redisTempleUtil.hset(OUTLINE_CACHE, userInfoId, timeStamp);
    }

    /**
     * @return all offline websocket entries (userId@appId -> timestamp)
     */
    public static Map getOutlineEntrys() {
        return redisTempleUtil.hEntrys(OUTLINE_CACHE);
    }

    /**
     * Removes a websocket user from the offline records.
     *
     * @param userInfoId staffId@appId
     * @return number of removed entries (0 if none)
     */
    public static Long delOutlineQueue(String userInfoId) {
        if (redisTempleUtil.hHasKey(OUTLINE_CACHE, userInfoId)) {
            return redisTempleUtil.hdel(OUTLINE_CACHE, userInfoId);
        }
        return 0L;
    }

    /**
     * Unbinds a user from a channel. Once the user has no channels left, it is removed from the
     * per-app user index; once that index is empty, the appId is removed from the online index of
     * the same channel type.
     *
     * @param appId       app id
     * @param userAppInfo user key userId@appId
     * @param channelId   channel to unbind
     * @param channelType {@link #WS_CHANNEL_TYPE} or {@link #SOCKET_CHANNEL_TYPE}
     * @return true if the user/channel relation existed and was removed
     */
    public static boolean removeUserChannelRelation(String appId, String userAppInfo, ChannelId channelId, Integer channelType) {
        if (userAppInfo == null || channelId == null) {
            return false;
        }
        long removed = redisTempleUtil.setRemove(userAppInfo, getStrChannelId(channelId));
        if (redisTempleUtil.sGetSetSize(userAppInfo) <= 0) {
            dropUserIndexes(appId, userAppInfo, channelType);
        }
        return removed > 0;
    }

    /**
     * @param key   Redis set key
     * @param value member to look for
     * @return whether {@code value} is a member of the set {@code key}
     */
    public static boolean hasSetIndex(String key, String value) {
        return redisTempleUtil.sHasKey(key, value);
    }

    /**
     * Removes {@code values} from the Redis set {@code key}. When {@code key} is
     * {@link #CHANNEL_KEY}, it also tears down every user relation of those channels.
     *
     * @param key        Redis set key
     * @param channelKey channel type ({@link #WS_CHANNEL_TYPE} / {@link #SOCKET_CHANNEL_TYPE})
     * @param values     members to remove (channel ids when key is CHANNEL_KEY)
     * @return the userId@appId keys that were bound to the removed channels, or null when
     *         {@code key} is not CHANNEL_KEY
     */
    public static Set removeChannelRelation(String key, Integer channelKey, String... values) {
        redisTempleUtil.setRemove(key, values);
        if (CHANNEL_KEY.equals(key)) {
            return removeChannelKeyAssociation(values, channelKey);
        }
        return null;
    }

    /**
     * @param key Redis set key
     * @return all members of the set
     */
    public static Set getSetValue(String key) {
        return redisTempleUtil.sGet(key);
    }

    public static Set getUserKeySet(String appId, Integer channelType) {
        return redisTempleUtil.sGet(generateUserKey(appId, channelType));
    }

    public static Set getUserKeyChannelIdSet(String appId, String userId) {
        return redisTempleUtil.sGet(generateSetKey(userId, appId));
    }

    /**
     * For each channel id: unbinds every user from it and, as in
     * {@link #removeUserChannelRelation}, drops a user from the online indexes only once it has
     * no channels left. Then deletes the channel's reverse index and its local entry.
     */
    private static Set<Object> removeChannelKeyAssociation(String[] channelKeys, Integer channelType) {
        Assert.notNull(channelKeys, "channelKeys must not be null");
        Set<Object> result = new HashSet<>();
        for (String channelKey : channelKeys) {
            if (channelKey == null) {
                continue;
            }
            Set<?> userAppKeySet = redisTempleUtil.sGet(channelKey);
            if (userAppKeySet != null && !userAppKeySet.isEmpty()) {
                for (Object element : userAppKeySet) {
                    String userAppInfo = (String) element;
                    redisTempleUtil.setRemove(userAppInfo, channelKey);
                    String[] parts = decodeUserAppSetKey(userAppInfo);
                    if (parts == null || parts.length < 2) {
                        logger.warn("无法解析的用户通道索引,跳过: " + userAppInfo);
                        continue;
                    }
                    if (redisTempleUtil.sGetSetSize(userAppInfo) <= 0) {
                        dropUserIndexes(parts[1], userAppInfo, channelType);
                    }
                }
                result.addAll(userAppKeySet);
            }
            // The channel is gone: drop its channelId -> users reverse index and local entry.
            redisTempleUtil.del(channelKey);
            CHANNEL_MAP.remove(channelKey);
        }
        return result;
    }

    /**
     * Builds the user -> channels set key.
     *
     * @param userInfoId user id
     * @param appId      app id
     * @return userInfoId@appId
     */
    public static String generateSetKey(String userInfoId, String appId) {
        return userInfoId + DEFAULT_MX + appId;
    }

    /**
     * Splits a userId@appId key.
     *
     * @param userAppKey key to split
     * @return the parts split on {@link #DEFAULT_MX}, or null if the separator is absent
     */
    public static String[] decodeUserAppSetKey(String userAppKey) {
        return userAppKey.contains(DEFAULT_MX) ? userAppKey.split(DEFAULT_MX) : null;
    }

    /**
     * @return set of appIds that currently have a waiting queue
     */
    public static Set getQueueIndexSet() {
        return redisTempleUtil.sGet(QUEUE_APP_ID);
    }

    /** Redis list key of the waiting queue for {@code appId}. */
    private static String queueKey(String appId) {
        return QUEUE_MARK + KEY_SPLIT + appId;
    }

    /**
     * Appends an entry to the right of the app's waiting queue unless it is already queued.
     *
     * @param appId       app id
     * @param uId         user id
     * @param chatPackSeq chat package sequence
     * @return the existing index if already queued, otherwise the result of the right push
     */
    public static long rightPushQueue(String appId, String uId, String chatPackSeq) {
        String key = queueKey(appId);
        String value = encodeQueueValue(chatPackSeq, appId, uId);
        long valueIndex = redisTempleUtil.lhasKeyAndValue(key, value);
        if (valueIndex > -1) {
            return valueIndex;
        }
        addToSetIfAbsent(QUEUE_APP_ID, appId);
        return redisTempleUtil.lrightSet(key, value);
    }

    /**
     * @param appId       app id
     * @param uId         user id
     * @param chatPackSeq chat package sequence
     * @return index of the entry in the app's waiting queue (as reported by lhasKeyAndValue)
     */
    public static long checkAndReturnIndex(String appId, String uId, String chatPackSeq) {
        return redisTempleUtil.lhasKeyAndValue(queueKey(appId), encodeQueueValue(chatPackSeq, appId, uId));
    }

    /**
     * @param appId app id
     * @return length of the app's waiting queue
     */
    public static long getQueueSizeByAppId(String appId) {
        return redisTempleUtil.lGetListSize(queueKey(appId));
    }

    /**
     * Pops the leftmost entry of the app's waiting queue.
     *
     * @param appId app id
     * @return the popped value
     */
    public static Object leftPopQueue(String appId) {
        return redisTempleUtil.leftPopListValue(queueKey(appId));
    }

    /**
     * Removes one matching entry from the app's waiting queue. When the queue becomes empty, the
     * appId is removed from {@link #QUEUE_APP_ID}.
     *
     * @param appId       app id
     * @param uId         user id
     * @param chatPackSeq chat package sequence
     * @return number of removed entries
     */
    public static long removeOneInQueue(String appId, String uId, String chatPackSeq) {
        String key = queueKey(appId);
        long removed = redisTempleUtil.lRemove(key, 1, encodeQueueValue(chatPackSeq, appId, uId));
        if (redisTempleUtil.lGetListSize(key) == 0) {
            redisTempleUtil.setRemove(QUEUE_APP_ID, appId);
        }
        return removed;
    }

    /**
     * @return chatPackSeq@appId@uId
     */
    public static String encodeQueueValue(String chatPackSeq, String appId, String uId) {
        return new StringBuilder(chatPackSeq).append(DEFAULT_MX).append(appId).append(DEFAULT_MX).append(uId).toString();
    }

    /**
     * @param value queue value
     * @return the parts split on {@link #DEFAULT_MX}, or null if the separator is absent
     */
    public static String[] decodeQueueValue(String value) {
        return value.contains(DEFAULT_MX) ? value.split(DEFAULT_MX) : null;
    }
}
 
 

websocket处理类(Handler)

import com.alibaba.fastjson.JSON;
import com.galaxyeye.icservice.im.MyChannelHandlePool;
import io.netty.channel.*;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.websocketx.TextWebSocketFrame;
import io.netty.handler.timeout.IdleState;
import io.netty.handler.timeout.IdleStateEvent;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.springframework.stereotype.Component;

import java.net.InetSocketAddress;
import java.util.HashMap;
import java.util.Map;

/**
 * WebSocket request handler, installed by {@code NettyServerInitializer} once a connection has
 * been identified as HTTP/WebSocket.
 *
 * <p>It rewrites the URI of the opening-handshake request before WebSocketServerProtocolHandler
 * sees it, echoes text frames back to the sender, and keeps {@link MyChannelHandlePool} in sync
 * on connect and disconnect.
 *
 * @Author: geyingke
 * @Date: 2020/7/20
 **/
@Component
@ChannelHandler.Sharable
public class WebSocketHandler extends SimpleChannelInboundHandler<TextWebSocketFrame> {

    private static final Logger logger = LogManager.getLogger(WebSocketHandler.class);

    @Override
    protected void channelRead0(ChannelHandlerContext channelHandlerContext, TextWebSocketFrame msg) throws Exception {
        // Text frames are already handled in channelRead; this only marks that the frame arrived.
        logger.info("come here~!");
    }

    @Override
    public void channelActive(ChannelHandlerContext ctx) throws Exception {
        InetSocketAddress socketAddress = (InetSocketAddress) ctx.channel().remoteAddress();
        String clientIp = socketAddress.getAddress().getHostAddress();
        int clientPort = socketAddress.getPort();
        ChannelId channelId = ctx.channel().id();
        if (MyChannelHandlePool.hasChannel(channelId)) {
            logger.info(String.format("websocket客户端【%s】是连接状态,连接通道数量:%d", channelId, MyChannelHandlePool.getChannelGroupSize()));
        } else {
            MyChannelHandlePool.addChannel(channelId, ctx.channel());
            logger.info(String.format("websocket客户端【%s】连接netty服务器[IP:%s--->PORT:%d]", channelId, clientIp, clientPort));
        }
        super.channelActive(ctx);
    }

    @Override
    public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
        if (msg instanceof FullHttpRequest) {
            // The first message on the connection is the WebSocket opening handshake.
            FullHttpRequest request = (FullHttpRequest) msg;
            String uri = request.uri();
            Map<String, String> paramMap = getUrlParams(uri);
            logger.info("received msg ==>" + JSON.toJSONString(paramMap));
            // Presumably done so the URI equals the handshake path WebSocketServerProtocolHandler
            // expects; query parameters are captured above first.
            int queryStart = uri.indexOf('?');
            request.setUri(queryStart >= 0 ? uri.substring(0, queryStart) : "/ws");
        } else if (msg instanceof TextWebSocketFrame) {
            TextWebSocketFrame frame = (TextWebSocketFrame) msg;
            logger.info(String.format("服务端接收到的消息:%s", frame.text()));
            // TODO: real message handling; for now the text is echoed back to the sender.
            sendMessage(ctx.channel(), frame.text());
        }
        // Text frames go on to channelRead0 and are released there; anything else is forwarded.
        super.channelRead(ctx, msg);
    }

    /** Writes {@code text} to the given channel (only that channel) as one TextWebSocketFrame. */
    private void sendMessage(Channel channel, String text) {
        channel.writeAndFlush(new TextWebSocketFrame(text));
    }

    @Override
    public void channelInactive(ChannelHandlerContext ctx) throws Exception {
        InetSocketAddress socketAddress = (InetSocketAddress) ctx.channel().remoteAddress();
        String clientIp = socketAddress.getAddress().getHostAddress();
        int clientPort = socketAddress.getPort();
        Channel channel = ctx.channel();
        ChannelId channelId = channel.id();
        if (MyChannelHandlePool.hasChannel(channelId)) {
            boolean removed = MyChannelHandlePool.delChannel(channel);
            String strChannelId = MyChannelHandlePool.getStrChannelId(channelId);
            // Drop the Redis user/app indexes that still point at this closed channel.
            if (MyChannelHandlePool.hasSetIndex(MyChannelHandlePool.CHANNEL_KEY, strChannelId)) {
                MyChannelHandlePool.removeChannelRelation(MyChannelHandlePool.CHANNEL_KEY, MyChannelHandlePool.WS_CHANNEL_TYPE, strChannelId);
            }
            if (removed) {
                logger.info(String.format("websocket客户端【%s】成功下线![IP:%s--> PORT:%d]", channelId, clientIp, clientPort));
                logger.info(String.format("websocket连接通道数量:%d", MyChannelHandlePool.getChannelGroupSize()));
            } else {
                logger.error(String.format("websocket客服端【%s】下线失败![IP:%s--> PORT:%d]", channelId, clientIp, clientPort));
            }
        }
        super.channelInactive(ctx);
    }

    @Override
    public void channelReadComplete(ChannelHandlerContext ctx) throws Exception {
        logger.info("channelReadComplete");
        ctx.flush();
    }

    /** Disconnects the client on any idle timeout; other user events are passed on. */
    @Override
    public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
        if (!(evt instanceof IdleStateEvent)) {
            super.userEventTriggered(ctx, evt);
            return;
        }
        String socketString = ctx.channel().remoteAddress().toString();
        switch (((IdleStateEvent) evt).state()) {
            case READER_IDLE:
                logger.info(String.format("Client: %s READER_IDLE 读超时", socketString));
                ctx.disconnect();
                break;
            case WRITER_IDLE:
                logger.info(String.format("Client: %s WRITER_IDLE 写超时", socketString));
                ctx.disconnect();
                break;
            case ALL_IDLE:
                logger.info(String.format("Client: %s ALL_IDLE 总超时", socketString));
                ctx.disconnect();
                break;
            default:
                break;
        }
    }

    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
        logger.error("websocket消息处理异常!", cause);
        ctx.close();
    }

    /**
     * Parses the query string of {@code url} into a map. A pair without {@code '='} maps to "".
     * Only the first {@code '='} separates key from value, so values may contain {@code '='}.
     * Values are not URL-decoded.
     *
     * @param url request URI, possibly with a query string
     * @return parameter map; empty when there is no query string
     */
    private static Map<String, String> getUrlParams(String url) {
        Map<String, String> params = new HashMap<>();
        int queryStart = url.indexOf('?');
        if (queryStart < 0 || queryStart == url.length() - 1) {
            return params;
        }
        for (String pair : url.substring(queryStart + 1).split("&")) {
            if (pair.isEmpty()) {
                continue;
            }
            String[] keyValue = pair.split("=", 2);
            params.put(keyValue[0], keyValue.length > 1 ? keyValue[1] : "");
        }
        return params;
    }
}

socket处理类(Handler)

继承 SimpleChannelInboundHandler 时,泛型参数不能与 websocket Handler 的泛型相同;否则无法分辨 channel 的类型,上线和下线的处理逻辑就必须单独抽取出来。

import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONArray;
import com.alibaba.fastjson.JSONObject;
import com.galaxyeye.icservice.conf.RedisOperator;
import com.galaxyeye.icservice.conf.SpringContextBean;
import com.galaxyeye.icservice.conf.myException.CatchedReturnException;
import com.galaxyeye.icservice.conf.myException.DataBaseException;
import com.galaxyeye.icservice.constant.ReturnEnum;
import com.galaxyeye.icservice.entity.ValidatorVo;
import com.galaxyeye.icservice.im.MyChannelHandlePool;
import com.galaxyeye.icservice.im.parser.SocketUtils;
import com.galaxyeye.icservice.im.protocol.SocketProtocol;
import com.galaxyeye.icservice.im.protocol.WebSocketProtocol;
import com.galaxyeye.icservice.service.im.socket.SocketHandlerService;
import com.galaxyeye.icservice.utils.RandomUtils;
import com.galaxyeye.icservice.utils.WSMsgUtil;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelId;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.handler.codec.http.websocketx.TextWebSocketFrame;
import io.netty.handler.timeout.IdleState;
import io.netty.handler.timeout.IdleStateEvent;
import io.netty.util.ReferenceCountUtil;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

import java.lang.invoke.MethodHandles;
import java.net.InetSocketAddress;
import java.util.HashMap;
import java.util.Map;
import java.util.Set;

/**
 * Netty inbound handler for the raw TCP ("socket") protocol.
 *
 * <p>Each inbound {@link ByteBuf} is decoded by {@link SocketUtils#parseByteBuffMap(ByteBuf)} into a
 * frame command code and a JSON body. The body is then dispatched in one of two ways:
 * <ul>
 *   <li>If it contains {@code SocketProtocol.TYPE}, it is routed to a business handler. This happens
 *       only when the channel is in {@code MyChannelHandlePool.AUTH_CHANNEL}; otherwise an
 *       "unauthed channel" error is sent back.</li>
 *   <li>If it has no type field, it is handled by the command code: 201 and 202 have dedicated
 *       handlers, and any other code is echoed back as a heartbeat.</li>
 * </ul>
 *
 * <p>Channels are registered in / removed from {@link MyChannelHandlePool} when they become active /
 * inactive. Any {@link IdleStateEvent} disconnects the client.
 *
 * @Author: geyingke
 * @Date: 2020/7/21
 * @Class: NettySocketHandler
 **/
public class NettySocketHandler extends SimpleChannelInboundHandler<ByteBuf> {

    private final Logger logger = LogManager.getLogger(MethodHandles.lookup().lookupClass());

    // Beans are pulled from the Spring context via SpringContextBean rather than injected.
    private final RedisOperator redisOperator = SpringContextBean.getBean(RedisOperator.class);

    private final SocketHandlerService socketHandlerService = SpringContextBean.getBean(SocketHandlerService.class);

    /**
     * Registers a newly connected channel in {@link MyChannelHandlePool} (if not already present) and
     * logs the client's address and the current channel count.
     */
    @Override
    public void channelActive(ChannelHandlerContext ctx) throws Exception {
        InetSocketAddress socketAddress = (InetSocketAddress) ctx.channel().remoteAddress();
        String clientIp = socketAddress.getAddress().getHostAddress();
        int clientPort = socketAddress.getPort();
        ChannelId channelId = ctx.channel().id();
        if (MyChannelHandlePool.hasChannel(channelId)) {
            logger.info(String.format("socket客户端【%s】是连接状态,连接通道数量:%d", channelId, MyChannelHandlePool.getChannelGroupSize()));
        } else {
            // Register the channel in the pool.
            MyChannelHandlePool.addChannel(channelId, ctx.channel());
            logger.info(String.format("socket客户端【%s】连接netty服务器[IP:%s--->PORT:%d]", channelId, clientIp, clientPort));
            logger.info(String.format("客户端连接通道数量:%d", MyChannelHandlePool.getChannelGroupSize()));
        }
    }

    /**
     * Removes a disconnected channel from {@link MyChannelHandlePool} and drops its entry from the
     * {@code CHANNEL_KEY} relation set, logging success or failure.
     */
    @Override
    public void channelInactive(ChannelHandlerContext ctx) throws Exception {
        logger.info("---------------------socket断线检测-------------------------");
        InetSocketAddress socketAddress = (InetSocketAddress) ctx.channel().remoteAddress();
        String clientIp = socketAddress.getAddress().getHostAddress();
        int clientPort = socketAddress.getPort();
        ChannelId channelId = ctx.channel().id();
        if (!MyChannelHandlePool.hasChannel(channelId)) {
            return;
        }
        boolean removed = MyChannelHandlePool.delChannel(ctx.channel());
        // Validate the channel id and take the channel offline.
        String strChannelId = MyChannelHandlePool.getStrChannelId(channelId);
        if (MyChannelHandlePool.hasSetIndex(MyChannelHandlePool.CHANNEL_KEY, strChannelId)) {
            MyChannelHandlePool.removeChannelRelation(MyChannelHandlePool.CHANNEL_KEY, MyChannelHandlePool.SOCKET_CHANNEL_TYPE, strChannelId);
        }
        if (removed) {
            logger.info(String.format("socket客户端【%s】成功下线![IP:%s--> PORT:%d]", channelId, clientIp, clientPort));
            logger.info(String.format("连接通道数量:%d", MyChannelHandlePool.getChannelGroupSize()));
        } else {
            logger.error(String.format("socket客户端【%s】下线失败![IP:%s--> PORT:%d]", channelId, clientIp, clientPort));
        }
    }

    /**
     * Required by {@link SimpleChannelInboundHandler} but never invoked.
     *
     * <p>{@link #channelRead} is overridden and does not delegate to {@code super}, so the
     * superclass's type-matching dispatch to this method never runs.
     */
    @Override
    protected void channelRead0(ChannelHandlerContext ctx, ByteBuf msg) throws Exception {
        logger.info("come here");
    }

    /**
     * Decodes one TCP frame and dispatches it. Any failure is reported back to the client as an
     * error packet.
     *
     * <p>The inbound message is always released in {@code finally}. This is needed because the
     * superclass's auto-release path in {@code channelRead} is bypassed by this override.
     */
    @Override
    public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
        try {
            ByteBuf byteBuf = (ByteBuf) msg;
            Map<?, ?> reqMap = SocketUtils.parseByteBuffMap(byteBuf);
            String msgBody = (String) reqMap.get(SocketUtils.BODY);
            int cmd = (Integer) reqMap.get(SocketUtils.CMD);
            logger.info("msg coverted : " + msgBody);
            JSONObject receiveMsg = JSON.parseObject(msgBody);
            if (receiveMsg.containsKey(SocketProtocol.TYPE)) {
                dispatchTyped(ctx, receiveMsg);
            } else {
                dispatchByCmd(ctx, cmd, receiveMsg, msgBody);
            }
        } catch (CatchedReturnException ce) {
            logger.error("消息内容不合法!", ce);
            writeError(ctx, ce);
        } catch (Exception e) {
            logger.error("消息处理异常:", e);
            writeError(ctx, e);
        } finally {
            // Release the inbound buffer (see method Javadoc).
            ReferenceCountUtil.release(msg);
        }
    }

    /** Routes a message carrying {@code SocketProtocol.TYPE} to its handler, for authenticated channels only. */
    private void dispatchTyped(ChannelHandlerContext ctx, JSONObject receiveMsg) throws Exception {
        if (!MyChannelHandlePool.AUTH_CHANNEL.contains(ctx.channel())) {
            JSONObject res = new JSONObject();
            res.put("type", "error");
            // NOTE(review): retCode is filled with getRet_msg(), the same value as retMsg; this was
            // probably meant to be the enum's code accessor. Confirm against ReturnEnum.
            res.put("retCode", ReturnEnum.UNAUTHED_CHANNEL.getRet_msg());
            res.put("retMsg", ReturnEnum.UNAUTHED_CHANNEL.getRet_msg());
            writeBack(ctx, res.toJSONString());
            return;
        }
        String msgType = receiveMsg.getString(SocketProtocol.TYPE);
        switch (msgType) {
            case SocketProtocol.TRANS_IC:
                tranIc(ctx, receiveMsg);
                break;
            case SocketProtocol.GET_QUEUE_INFO:
                getQueueInfo(ctx, receiveMsg);
                break;
            case SocketProtocol.EXIT_QUEUE:
                exitQueue(ctx, receiveMsg);
                break;
            case SocketProtocol.CHAT_SEND:
                chatSend(ctx, receiveMsg);
                break;
            case SocketProtocol.FEEDBACK:
                feedBack(ctx, receiveMsg);
                break;
            case SocketProtocol.CHAT_OFFLINE:
                chatOffline(ctx, receiveMsg);
                break;
            case SocketProtocol.CHAT_EXIT:
                chatExit(ctx, receiveMsg);
                break;
            case SocketProtocol.CHAT_RECONNECT:
                reconnect(ctx, receiveMsg);
                break;
            default:
                // Unknown types are ignored.
                break;
        }
    }

    /** Handles a message without a type field by its frame command code; other codes are echoed as heartbeats. */
    private void dispatchByCmd(ChannelHandlerContext ctx, int cmd, JSONObject receiveMsg, String msgBody) throws Exception {
        if (cmd == 201) {
            Integer servType = receiveMsg.getInteger("servType");
            String appid = receiveMsg.getString("appid");
            handle201msg(ctx, appid, servType);
        } else if (cmd == 202) {
            String appid = receiveMsg.getString("appid");
            String sign = receiveMsg.getString("sign");
            handle202msg(ctx, appid, sign);
        } else {
            // Default heartbeat handling: echo the body straight back.
            logger.info("return msg: " + msgBody);
            writeHeartBeatBack(ctx, msgBody);
        }
    }

    /** Sends an error packet whose data is "消息内容不合法!:" followed by {@code cause.getMessage()}. */
    private void writeError(ChannelHandlerContext ctx, Exception cause) {
        Map<String, Object> error = new HashMap<>();
        error.put(SocketProtocol.TYPE, SocketProtocol.ERROR);
        error.put(SocketProtocol.DATA, "消息内容不合法!:" + cause.getMessage());
        writeBack(ctx, JSON.toJSONString(error));
    }

    /**
     * Sends {@code dispose} to the client as a frame with command code 0.
     *
     * @param ctx     channel context to write to
     * @param dispose response body; encoded as UTF-8
     */
    private void writeBack(ChannelHandlerContext ctx, String dispose) {
        writePacked(ctx, dispose, 0);
    }

    /**
     * Sends {@code dispose} to the client as a heartbeat frame (command code 9999).
     *
     * @param ctx     channel context to write to
     * @param dispose heartbeat body; encoded as UTF-8
     */
    private void writeHeartBeatBack(ChannelHandlerContext ctx, String dispose) {
        writePacked(ctx, dispose, 9999);
    }

    /** Frames {@code body} with {@link SocketUtils#pack(byte[], int)} and writes and flushes it. */
    private void writePacked(ChannelHandlerContext ctx, String body, int cmd) {
        byte[] frame = SocketUtils.pack(body.getBytes(java.nio.charset.StandardCharsets.UTF_8), cmd);
        ctx.writeAndFlush(Unpooled.copiedBuffer(frame));
    }

    /**
     * Disconnects the client on any reader, writer, or all idle event. Other events are passed to
     * the superclass.
     */
    @Override
    public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
        if (!(evt instanceof IdleStateEvent)) {
            super.userEventTriggered(ctx, evt);
            return;
        }
        String socketString = ctx.channel().remoteAddress().toString();
        IdleState state = ((IdleStateEvent) evt).state();
        if (state == IdleState.READER_IDLE) {
            logger.info(String.format("Client: %s READER_IDLE 读超时", socketString));
            ctx.disconnect();
        } else if (state == IdleState.WRITER_IDLE) {
            logger.info(String.format("Client: %s WRITER_IDLE 写超时", socketString));
            ctx.disconnect();
        } else if (state == IdleState.ALL_IDLE) {
            logger.info(String.format("Client: %s ALL_IDLE 总超时", socketString));
            ctx.disconnect();
        }
    }
}

socket消息封包、解包处理类:

import io.netty.buffer.ByteBuf;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;

/**
 * Packs and unpacks the little-endian binary frames exchanged with the raw TCP client.
 *
 * <p>Frame layout handled by this class:
 * <pre>
 *   offset  size  field
 *   0       4     pkgLen    number of bytes after this field (= frame length - 4)
 *   4       4     checkSum  written as 0, ignored when reading
 *   8       2     cmd
 *   10      2     target    written as 0, ignored when reading
 *   12      2     retCode   written as 0, ignored when reading
 *   14      n     body      n = pkgLen - HEAD_SIZE
 * </pre>
 *
 * <p>The parse methods decode exactly one frame from the start of the given bytes and return its
 * body as a UTF-8 string. Bytes beyond that frame (e.g. a second frame glued onto the first) are
 * ignored.
 *
 * @Author: geyingke
 * @Date: 2020/8/4
 * @Class: SocketUtils
 **/
public class SocketUtils {
    /** Header bytes counted by pkgLen: checkSum(4) + cmd(2) + target(2) + retCode(2). */
    static final int HEAD_SIZE = 10;
    /** Full header length, including the 4-byte pkgLen field itself. */
    static final int TOTAL_SIZE = 14;
    /**
     * Command code of the frame most recently decoded by {@link #parse(byte[])}.
     *
     * <p>NOTE(review): this is shared mutable static state that concurrent callers overwrite.
     * Prefer {@link #parseMap(byte[])}, which returns the command together with the body.
     */
    static int cmd;

    public static final String CMD = "CMD";
    public static final String BODY = "BODY";

    private static final int PKG_LEN_OFFSET = 0;
    private static final int CHECKSUM_OFFSET = 4;
    private static final int CMD_OFFSET = 8;
    private static final int TARGET_OFFSET = 10;
    private static final int RET_CODE_OFFSET = 12;

    private static final java.nio.charset.Charset CHARSET = java.nio.charset.StandardCharsets.UTF_8;

    private static final Logger logger = LogManager.getLogger(SocketUtils.class);

    /** Writes {@code value} as 2 little-endian bytes at {@code buf[offset]}. */
    static void write_short_le(byte[] buf, int offset, short value) {
        buf[offset] = (byte) (value & 0xff);
        buf[offset + 1] = (byte) ((value >> 8) & 0xff);
    }

    /** Writes {@code value} as 4 little-endian bytes at {@code buf[offset]}. */
    static void write_int_le(byte[] buf, int offset, int value) {
        buf[offset] = (byte) (value & 0xff);
        buf[offset + 1] = (byte) ((value >> 8) & 0xff);
        buf[offset + 2] = (byte) ((value >> 16) & 0xff);
        buf[offset + 3] = (byte) ((value >> 24) & 0xff);
    }

    /**
     * Copies {@code src[src_offset..]} into {@code dst} starting at {@code dst_offset}.
     *
     * <p>Copying stops when either array runs out, so surplus source bytes are dropped instead of
     * overrunning {@code dst}.
     */
    static void write_bytes(byte[] src, int src_offset, byte[] dst, int dst_offset) {
        int count = Math.min(src.length - src_offset, dst.length - dst_offset);
        if (count > 0) {
            System.arraycopy(src, src_offset, dst, dst_offset, count);
        }
    }

    /** Reads a signed 16-bit little-endian value from {@code data[offset]}. */
    static short read_short_le(byte[] data, int offset) {
        // Mask each byte before combining. Otherwise sign extension corrupts the high bits, and a
        // final "& 0xFF" would drop the high byte entirely.
        return (short) ((data[offset] & 0xFF) | ((data[offset + 1] & 0xFF) << 8));
    }

    /** Reads a 32-bit little-endian value from {@code data[offset]}. */
    public static int read_int_le(byte[] data, int offset) {
        // The parentheses around "& 0xFF" are required: << binds tighter than &.
        return (data[offset] & 0xFF)
                | ((data[offset + 1] & 0xFF) << 8)
                | ((data[offset + 2] & 0xFF) << 16)
                | ((data[offset + 3] & 0xFF) << 24);
    }

    /**
     * Drains every readable byte of {@code msg} and decodes the first frame's body.
     *
     * <p>Also updates {@link #cmd}.
     *
     * @param msg buffer holding at least one complete frame; its reader index is advanced to the end
     * @return the frame body, decoded as UTF-8
     * @throws IllegalArgumentException if the bytes do not contain a complete, well-formed frame
     */
    public static String parseByteBuff(ByteBuf msg) {
        return parse(drain(msg));
    }

    /**
     * Decodes the body of the frame at the start of {@code bytes} and stores its command in {@link #cmd}.
     *
     * @param bytes raw frame bytes
     * @return the frame body, decoded as UTF-8
     * @throws IllegalArgumentException if {@code bytes} is shorter than the header, or pkgLen is
     *                                  negative or larger than the bytes provided
     */
    public static String parse(byte[] bytes) {
        String body = decodeBody(bytes);
        cmd = read_short_le(bytes, CMD_OFFSET);
        return body;
    }

    /**
     * Drains every readable byte of {@code msg} and decodes the first frame.
     *
     * @param msg buffer holding at least one complete frame; its reader index is advanced to the end
     * @return map with {@link #CMD} (an {@code Integer}) and {@link #BODY} (a UTF-8 {@code String})
     * @throws IllegalArgumentException if the bytes do not contain a complete, well-formed frame
     */
    public static Map<String, Object> parseByteBuffMap(ByteBuf msg) {
        return parseMap(drain(msg));
    }

    /**
     * Decodes the frame at the start of {@code bytes} into its command code and body.
     *
     * @param bytes raw frame bytes
     * @return map with {@link #CMD} (an {@code Integer}) and {@link #BODY} (a UTF-8 {@code String})
     * @throws IllegalArgumentException if {@code bytes} is shorter than the header, or pkgLen is
     *                                  negative or larger than the bytes provided
     */
    public static Map<String, Object> parseMap(byte[] bytes) {
        String body = decodeBody(bytes);
        Map<String, Object> result = new HashMap<>(4);
        result.put(CMD, (int) read_short_le(bytes, CMD_OFFSET));
        result.put(BODY, body);
        return result;
    }

    /**
     * Builds a frame around {@code content}.
     *
     * <p>checkSum, target and retCode are written as 0.
     *
     * @param content body bytes, copied verbatim after the header
     * @param cmd     command code, truncated to 16 bits
     * @return the complete frame ({@code content.length + TOTAL_SIZE} bytes)
     */
    public static byte[] pack(byte[] content, int cmd) {
        byte[] msg = new byte[content.length + TOTAL_SIZE];
        // pkgLen excludes its own 4 bytes.
        write_int_le(msg, PKG_LEN_OFFSET, msg.length - 4);
        write_int_le(msg, CHECKSUM_OFFSET, 0);
        write_short_le(msg, CMD_OFFSET, (short) cmd);
        write_short_le(msg, TARGET_OFFSET, (short) 0);
        write_short_le(msg, RET_CODE_OFFSET, (short) 0);
        write_bytes(content, 0, msg, TOTAL_SIZE);
        return msg;
    }

    /** Reads all readable bytes out of {@code msg}, logging them at debug level. */
    private static byte[] drain(ByteBuf msg) {
        byte[] bytes = new byte[msg.readableBytes()];
        msg.readBytes(bytes);
        if (logger.isDebugEnabled()) {
            logger.debug("msg before covert: " + new String(bytes, CHARSET));
        }
        return bytes;
    }

    /** Validates the header at the start of {@code bytes} and returns the first frame's body as UTF-8. */
    private static String decodeBody(byte[] bytes) {
        if (bytes.length < TOTAL_SIZE) {
            throw new IllegalArgumentException(
                    "frame too short: " + bytes.length + " bytes, header needs " + TOTAL_SIZE);
        }
        int plen = read_int_le(bytes, PKG_LEN_OFFSET);
        int contentSize = plen - HEAD_SIZE;
        // Reject garbage lengths before allocating: negative sizes, or bodies longer than the bytes
        // actually received (e.g. a partial TCP read).
        if (contentSize < 0 || contentSize > bytes.length - TOTAL_SIZE) {
            throw new IllegalArgumentException(
                    "invalid pkgLen " + plen + " for " + bytes.length + " received bytes");
        }
        byte[] content = new byte[contentSize];
        write_bytes(bytes, TOTAL_SIZE, content, 0);
        return new String(content, CHARSET);
    }
}

你可能感兴趣的:(实现SpringBoot+netty服务端单端口支持Socket、webSocket协议)