netty-websocket: token authentication and unified request/response headers (authentication handler)


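These are my own ideas and implementation, mainly for authenticating during the handshake. If anything here is wrong, or you know a simpler way to do it, feel free to send me a private message.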

Requirements

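  1. Handshake authentication is based on the Sec-WebSocket-Protocol request header sent by the front end.
  2. The WebSocket API itself does not let the client set custom request headers. The only one it can customize is Sec-WebSocket-Protocol, the subprotocol header.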
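In the browser, the token is passed as the second argument of the WebSocket constructor (the protocols list). It arrives on the server in Sec-WebSocket-Protocol as one entry of a comma-separated list.

The plain-Java sketch below (no Netty; the class name `SubprotocolHeader` is hypothetical) splits that header. It then extracts the token, assuming the client sends the token as its first or only subprotocol.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical helper: splits a Sec-WebSocket-Protocol header value into its entries. */
final class SubprotocolHeader {
    private SubprotocolHeader() {}

    /** Returns the comma-separated entries, trimmed, with empty entries dropped. */
    static List<String> parse(String headerValue) {
        List<String> entries = new ArrayList<>();
        if (headerValue == null) {
            return entries;
        }
        for (String part : headerValue.split(",")) {
            String trimmed = part.trim();
            if (!trimmed.isEmpty()) {
                entries.add(trimmed);
            }
        }
        return entries;
    }

    /** Assumes the token is the first subprotocol; returns null if the header is absent or empty. */
    static String tokenOf(String headerValue) {
        List<String> entries = parse(headerValue);
        return entries.isEmpty() ? null : entries.get(0);
    }
}
```

A subprotocol value may only contain HTTP token characters, so the token must not contain spaces or commas. A typical JWT (base64url segments joined by dots) satisfies this.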

Problem description

The WebSocket handshake is an ordinary HTTP request. Once the handshake succeeds, the connection is upgraded to the ws protocol.

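The front end sent the token as the value of Sec-WebSocket-Protocol, and the server always dropped the connection right after receiving it. The many blog posts I found online all said roughly the same thing, so I stepped through the source code myself (stubborn, I know).

Eventually I found the cause: the Sec-WebSocket-Protocol value in the handshake response did not match the one in the request, so the connection was closed. Browsers enforce this rule: when the page requests a subprotocol, the 101 response must return one of the requested values in Sec-WebSocket-Protocol, otherwise the browser fails the connection.

I had assumed a WebSocket server never needs to set headers itself. From reading Netty's source, however, its default handshake handler does not seem to leave any way to set headers on the WebSocket handshake response (this is only my own understanding).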
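The browser-side rule described above can be modeled in a few lines. `ClientSubprotocolCheck` is a hypothetical name, and this is a simplified model rather than real browser code. It shows why a 101 response without the echoed token makes the client drop the connection.

```java
import java.util.List;

/** Hypothetical model of the client-side subprotocol check applied to a 101 response. */
final class ClientSubprotocolCheck {
    private ClientSubprotocolCheck() {}

    /**
     * @param requested      subprotocols the client sent (empty if none)
     * @param responseHeader Sec-WebSocket-Protocol from the 101 response, or null if absent
     * @return true if the client keeps the connection open
     */
    static boolean accepts(List<String> requested, String responseHeader) {
        boolean empty = responseHeader == null || responseHeader.isEmpty();
        if (requested.isEmpty()) {
            // Nothing was requested, so the server must not select a subprotocol.
            return empty;
        }
        // Something was requested: the response must echo exactly one requested value.
        return !empty && requested.contains(responseHeader);
    }
}
```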

具体实现

CustomWebSocketProtocolHandler

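Explanation: this is a custom replacement for WebSocketProtocolHandler; simply copy the contents of Netty's WebSocketProtocolHandler. The copy is needed because the customized WebSocketServerProtocolHandler below extends it, and Netty's own class is package-private, so it cannot be extended from your own package.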

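import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.MessageToMessageDecoder;
import io.netty.handler.codec.http.websocketx.PingWebSocketFrame;
import io.netty.handler.codec.http.websocketx.PongWebSocketFrame;
import io.netty.handler.codec.http.websocketx.WebSocketFrame;

import java.util.List;

abstract class CustomWebSocketProtocolHandler extends MessageToMessageDecoder<WebSocketFrame> {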
    @Override
    protected void decode(ChannelHandlerContext ctx, WebSocketFrame frame, List<Object> out) throws Exception {
        if (frame instanceof PingWebSocketFrame) {
            frame.content().retain();
            ctx.channel().writeAndFlush(new PongWebSocketFrame(frame.content()));
            return;
        }
        if (frame instanceof PongWebSocketFrame) {
            // Pong frames need to get ignored
            return;
        }

        out.add(frame.retain());
    }

    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
        ctx.fireExceptionCaught(cause);
        ctx.close();
    }
}
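The decode logic above boils down to a three-way decision per frame. The model below is dependency-free; the enum and class names are hypothetical, and the real handler works on Netty WebSocketFrame objects and their reference counts.

```java
/** Hypothetical frame kinds standing in for Netty's WebSocketFrame subclasses. */
enum FrameKind { PING, PONG, TEXT, BINARY, CLOSE, CONTINUATION }

/** What the base protocol handler does with each kind of frame. */
enum FrameAction { REPLY_WITH_PONG, DROP, PASS_DOWNSTREAM }

final class ProtocolDecodeModel {
    private ProtocolDecodeModel() {}

    static FrameAction decide(FrameKind kind) {
        switch (kind) {
            case PING:
                // Answer a ping with a pong carrying the same payload.
                return FrameAction.REPLY_WITH_PONG;
            case PONG:
                // Pongs are only acknowledgements; swallow them.
                return FrameAction.DROP;
            default:
                // Everything else (CLOSE included, at this level) goes to the next handler.
                return FrameAction.PASS_DOWNSTREAM;
        }
    }
}
```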

CustomWebSocketServerProtocolHandler

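Explanation: this is a custom WebSocketServerProtocolHandler that extends the custom WebSocketProtocolHandler above. Its body is the same as Netty's WebSocketServerProtocolHandler; the only change is to replace the handler class used in handlerAdded with your own.

Note: the custom business handler placed after it (the one that handles reads and writes) must implement the corresponding callbacks, namely exception handling and event listening. Take exceptions as an example: once one is thrown, no other handler will deal with it, because the business handler is now the last one in the pipeline. The default implementation has been replaced with our own; the other built-in handlers are built around the default handler, so once you swap it out and register your own, the business handler becomes the last layer. You therefore have to close the channel manually.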

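import io.netty.buffer.Unpooled;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelPipeline;
import io.netty.handler.codec.http.DefaultFullHttpResponse;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.HttpHeaders;
import io.netty.handler.codec.http.HttpResponseStatus;
import io.netty.handler.codec.http.websocketx.*;
import io.netty.util.AttributeKey;

import java.util.List;

import static io.netty.handler.codec.http.HttpVersion.HTTP_1_1;

public class CustomWebSocketServerProtocolHandler extends CustomWebSocketProtocolHandler {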

    /**
     * Events that are fired to notify about handshake status
     */
    public enum ServerHandshakeStateEvent {
        /**
         * The Handshake was completed successfully and the channel was upgraded to websockets.
         *
         * @deprecated in favor of {@link WebSocketServerProtocolHandler.HandshakeComplete} class,
         * it provides extra information about the handshake
         */
        @Deprecated
        HANDSHAKE_COMPLETE
    }

    /**
     * The Handshake was completed successfully and the channel was upgraded to websockets.
     */
    public static final class HandshakeComplete {
        private final String requestUri;
        private final HttpHeaders requestHeaders;
        private final String selectedSubprotocol;

        public HandshakeComplete(String requestUri, HttpHeaders requestHeaders, String selectedSubprotocol) {
            this.requestUri = requestUri;
            this.requestHeaders = requestHeaders;
            this.selectedSubprotocol = selectedSubprotocol;
        }

        public String requestUri() {
            return requestUri;
        }

        public HttpHeaders requestHeaders() {
            return requestHeaders;
        }

        public String selectedSubprotocol() {
            return selectedSubprotocol;
        }
    }

    private static final AttributeKey<WebSocketServerHandshaker> HANDSHAKER_ATTR_KEY =
            AttributeKey.valueOf(WebSocketServerHandshaker.class, "HANDSHAKER");

    private final String websocketPath;
    private final String subprotocols;
    private final boolean allowExtensions;
    private final int maxFramePayloadLength;
    private final boolean allowMaskMismatch;
    private final boolean checkStartsWith;

    public CustomWebSocketServerProtocolHandler(String websocketPath) {
        this(websocketPath, null, false);
    }

    public CustomWebSocketServerProtocolHandler(String websocketPath, boolean checkStartsWith) {
        this(websocketPath, null, false, 65536, false, checkStartsWith);
    }

    public CustomWebSocketServerProtocolHandler(String websocketPath, String subprotocols) {
        this(websocketPath, subprotocols, false);
    }

    public CustomWebSocketServerProtocolHandler(String websocketPath, String subprotocols, boolean allowExtensions) {
        this(websocketPath, subprotocols, allowExtensions, 65536);
    }

    public CustomWebSocketServerProtocolHandler(String websocketPath, String subprotocols,
                                          boolean allowExtensions, int maxFrameSize) {
        this(websocketPath, subprotocols, allowExtensions, maxFrameSize, false);
    }

    public CustomWebSocketServerProtocolHandler(String websocketPath, String subprotocols,
                                          boolean allowExtensions, int maxFrameSize, boolean allowMaskMismatch) {
        this(websocketPath, subprotocols, allowExtensions, maxFrameSize, allowMaskMismatch, false);
    }

    public CustomWebSocketServerProtocolHandler(String websocketPath, String subprotocols,
                                          boolean allowExtensions, int maxFrameSize, boolean allowMaskMismatch, boolean checkStartsWith) {
        this.websocketPath = websocketPath;
        this.subprotocols = subprotocols;
        this.allowExtensions = allowExtensions;
        maxFramePayloadLength = maxFrameSize;
        this.allowMaskMismatch = allowMaskMismatch;
        this.checkStartsWith = checkStartsWith;
    }

    @Override
    public void handlerAdded(ChannelHandlerContext ctx) {
        ChannelPipeline cp = ctx.pipeline();
        // Netty's original adds WebSocketServerProtocolHandshakeHandler here. In this setup the
        // handshake is performed by SecurityServerHandler, which initChannel adds explicitly
        // before this handler, so no handshake handler needs to be added. (A check for this
        // class itself would always find this very handler, because handlerAdded runs after
        // the handler is already in the pipeline.)
        if (cp.get(Utf8FrameValidator.class) == null) {
            // Add the UTF-8 validation before this one.
            ctx.pipeline().addBefore(ctx.name(), Utf8FrameValidator.class.getName(),
                    new Utf8FrameValidator());
        }
    }

    @Override
    protected void decode(ChannelHandlerContext ctx, WebSocketFrame frame, List<Object> out) throws Exception {
        if (frame instanceof CloseWebSocketFrame) {
            WebSocketServerHandshaker handshaker = getHandshaker(ctx.channel());
            if (handshaker != null) {
                frame.retain();
                handshaker.close(ctx.channel(), (CloseWebSocketFrame) frame);
            } else {
                ctx.writeAndFlush(Unpooled.EMPTY_BUFFER).addListener(ChannelFutureListener.CLOSE);
            }
            return;
        }
        super.decode(ctx, frame, out);
    }

    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
        if (cause instanceof WebSocketHandshakeException) {
            FullHttpResponse response = new DefaultFullHttpResponse(
                    HTTP_1_1, HttpResponseStatus.BAD_REQUEST, Unpooled.wrappedBuffer(cause.getMessage().getBytes()));
            ctx.channel().writeAndFlush(response).addListener(ChannelFutureListener.CLOSE);
        } else {
            ctx.fireExceptionCaught(cause);
            ctx.close();
        }
    }

    static WebSocketServerHandshaker getHandshaker(Channel channel) {
        return channel.attr(HANDSHAKER_ATTR_KEY).get();
    }

    public static void setHandshaker(Channel channel, WebSocketServerHandshaker handshaker) {
        channel.attr(HANDSHAKER_ATTR_KEY).set(handshaker);
    }

    public static ChannelHandler forbiddenHttpRequestResponder() {
        return new ChannelInboundHandlerAdapter() {
            @Override
            public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
                if (msg instanceof FullHttpRequest) {
                    ((FullHttpRequest) msg).release();
                    FullHttpResponse response =
                            new DefaultFullHttpResponse(HTTP_1_1, HttpResponseStatus.FORBIDDEN);
                    ctx.channel().writeAndFlush(response);
                } else {
                    ctx.fireChannelRead(msg);
                }
            }
        };
    }
}
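One subtlety of handlerAdded: Netty calls it after the handler has already been inserted into the pipeline, so looking up the handler's own class always succeeds. The sketch below models a pipeline as an ordered list of names. `PipelineModel` is hypothetical and is not Netty's API.

It illustrates the "add before me if absent" pattern used for Utf8FrameValidator, and shows why a guard that looks up the handler itself never adds anything.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical stand-in for a ChannelPipeline: an ordered list of handler names. */
final class PipelineModel {
    private final List<String> names = new ArrayList<>();

    void addLast(String name) {
        names.add(name);
    }

    boolean contains(String name) {
        return names.contains(name);
    }

    /** Inserts {@code name} right before {@code anchor}, like ChannelPipeline.addBefore. */
    void addBefore(String anchor, String name) {
        int index = names.indexOf(anchor);
        if (index < 0) {
            throw new IllegalArgumentException("no such handler: " + anchor);
        }
        names.add(index, name);
    }

    /** The handlerAdded pattern: only add the helper if it is not present yet. */
    void addBeforeIfAbsent(String self, String helper) {
        if (!contains(helper)) {
            addBefore(self, helper);
        }
    }

    List<String> names() {
        return names;
    }
}
```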

SecurityServerHandler

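Use the custom inbound handler SecurityServerHandler to replace Netty's default handshake handler, WebSocketServerProtocolHandshakeHandler.

This is the most important step, because this is where the header gets set; the previous two steps only lay the groundwork for it. The Netty classes involved cannot be referenced from outside their package and offer no way to modify them, which is why the custom classes above exist.

In this class, adjust the handshake logic to add the handshake response header. Then change references to WebSocketServerProtocolHandler into CustomWebSocketServerProtocolHandler, and change any other implementation classes the same way.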
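Independent of Netty, the essence of this step is simple: take the token from the request's Sec-WebSocket-Protocol header and return it unchanged as a response header. In this minimal sketch, maps stand in for HttpHeaders, and the class and method names are hypothetical. Header names are compared case-insensitively, as HTTP requires.

```java
import java.util.Map;
import java.util.TreeMap;

/** Hypothetical sketch of the extra response headers handed to the handshaker. */
final class EchoTokenHeaders {
    private EchoTokenHeaders() {}

    /**
     * @param requestHeaders request headers, in any key casing
     * @param tokenHeader    header carrying the token, e.g. "Sec-WebSocket-Protocol"
     * @return headers to add to the 101 response; empty if the request carried no token
     */
    static Map<String, String> responseHeaders(Map<String, String> requestHeaders, String tokenHeader) {
        // HTTP header names are case-insensitive, so look them up that way.
        Map<String, String> byName = new TreeMap<>(String.CASE_INSENSITIVE_ORDER);
        byName.putAll(requestHeaders);
        Map<String, String> response = new TreeMap<>(String.CASE_INSENSITIVE_ORDER);
        String token = byName.get(tokenHeader);
        if (token != null) {
            // Echo the exact value so the browser sees the subprotocol it asked for.
            response.put(tokenHeader, token);
        }
        return response;
    }
}
```

This assumes the client sends the token as its only subprotocol. If it sends several, the response must contain just one of them, not the whole list.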

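import com.edu.common.utils.SecurityUtils;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelPipeline;
import io.netty.handler.codec.http.*;
import io.netty.handler.codec.http.websocketx.WebSocketServerHandshaker;
import io.netty.handler.codec.http.websocketx.WebSocketServerHandshakerFactory;
import io.netty.handler.ssl.SslHandler;
import lombok.AllArgsConstructor;
import lombok.Getter;

import java.util.Objects;

import static com.edu.message.handler.attributeKey.AttributeKeyUtils.SECURITY_CHECK_COMPLETE_ATTRIBUTE_KEY;
import static io.netty.handler.codec.http.HttpMethod.GET;
import static io.netty.handler.codec.http.HttpResponseStatus.FORBIDDEN;
import static io.netty.handler.codec.http.HttpUtil.isKeepAlive;
import static io.netty.handler.codec.http.HttpVersion.HTTP_1_1;
// LoginUser is the project's own login-user class; import it from wherever it lives in your project

public class SecurityServerHandler extends ChannelInboundHandlerAdapter {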

    private final String websocketPath;
    private final String subprotocols;
    private final boolean allowExtensions;
    private final int maxFramePayloadSize;
    private final boolean allowMaskMismatch;
    private final boolean checkStartsWith;
	
    /**
     * Custom property: name of the header that carries the token
     */
    private final String tokenHeader;
    /**
     * Custom property: whether token authentication is enabled
     */
    private final boolean hasToken;


    public SecurityServerHandler(String websocketPath, String subprotocols,
                                 boolean allowExtensions, int maxFrameSize, boolean allowMaskMismatch,
                                 String tokenHeader, boolean hasToken) {
        this(websocketPath, subprotocols, allowExtensions, maxFrameSize, allowMaskMismatch, false, tokenHeader, hasToken);
    }

    SecurityServerHandler(String websocketPath, String subprotocols,
                          boolean allowExtensions, int maxFrameSize,
                          boolean allowMaskMismatch,
                          boolean checkStartsWith,
                          String tokenHeader,
                          boolean hasToken) {
        this.websocketPath = websocketPath;
        this.subprotocols = subprotocols;
        this.allowExtensions = allowExtensions;
        maxFramePayloadSize = maxFrameSize;
        this.allowMaskMismatch = allowMaskMismatch;
        this.checkStartsWith = checkStartsWith;
        this.tokenHeader = tokenHeader;
        this.hasToken = hasToken;
    }

    @Override
    public void channelRead(final ChannelHandlerContext ctx, Object msg) throws Exception {
        final FullHttpRequest req = (FullHttpRequest) msg;
        if (isNotWebSocketPath(req)) {
            ctx.fireChannelRead(msg);
            return;
        }
        try {
            // Actual authentication logic
            HttpHeaders headers = req.headers();
            String token = headers.get(tokenHeader);
            if (hasToken) {
                // Authentication enabled: a missing or unknown token is rejected with 401
                if (token == null) {
                    refuseChannel(ctx);
                    return;
                }
                //extracts device information headers
                LoginUser loginUser = SecurityUtils.getLoginUser(token);
                if (null == loginUser) {
                    refuseChannel(ctx);
                    return;
                }
                Long userId = loginUser.getUserId();
                //check ......
                SecurityCheckComplete complete = new SecurityCheckComplete(String.valueOf(userId), tokenHeader, hasToken);
                ctx.channel().attr(SECURITY_CHECK_COMPLETE_ATTRIBUTE_KEY).set(complete);
                ctx.fireUserEventTriggered(complete);
            } else {
                // Authentication disabled
                SecurityCheckComplete complete = new SecurityCheckComplete(null, tokenHeader, hasToken);
                ctx.channel().attr(SECURITY_CHECK_COMPLETE_ATTRIBUTE_KEY).set(complete);
            }
            if (!GET.equals(req.method())) {
                sendHttpResponse(ctx, req, new DefaultFullHttpResponse(HTTP_1_1, FORBIDDEN));
                return;
            }
            // The last constructor argument is allowMaskMismatch
            final WebSocketServerHandshakerFactory wsFactory = new WebSocketServerHandshakerFactory(
                    getWebSocketLocation(ctx.pipeline(), req, websocketPath), subprotocols,
                    allowExtensions, maxFramePayloadSize, allowMaskMismatch);
            final WebSocketServerHandshaker handshaker = wsFactory.newHandshaker(req);
            if (handshaker == null) {
                WebSocketServerHandshakerFactory.sendUnsupportedVersionResponse(ctx.channel());
            } else {
                // Put our header into the HttpHeaders handed to the handshaker; they end up in the
                // handshake response. Netty's default handshake handler passes null here.
                HttpHeaders httpHeaders = new DefaultHttpHeaders();
                if (token != null) {
                    httpHeaders.add(tokenHeader, token);
                }
                // This is the key step that builds the handshake response headers
                final ChannelFuture handshakeFuture = handshaker.handshake(ctx.channel(), req, httpHeaders, ctx.channel().newPromise());
                handshakeFuture.addListener((ChannelFutureListener) future -> {
                    if (!future.isSuccess()) {
                        ctx.fireExceptionCaught(future.cause());
                    } else {
                        // Kept for compatibility
                        ctx.fireUserEventTriggered(
                                CustomWebSocketServerProtocolHandler.ServerHandshakeStateEvent.HANDSHAKE_COMPLETE);
                        ctx.fireUserEventTriggered(
                                new CustomWebSocketServerProtocolHandler.HandshakeComplete(
                                        req.uri(), req.headers(), handshaker.selectedSubprotocol()));
                    }
                });
                CustomWebSocketServerProtocolHandler.setHandshaker(ctx.channel(), handshaker);
                ctx.pipeline().replace(this, "WS403Responder",
                        CustomWebSocketServerProtocolHandler.forbiddenHttpRequestResponder());
            }
        } catch (Exception e) {
            // Pass the error on so a later handler can respond and close, instead of leaving the connection hanging
            ctx.fireExceptionCaught(e);
        } finally {
            req.release();
        }
    }

    public static final class HandshakeComplete {
        private final String requestUri;
        private final HttpHeaders requestHeaders;
        private final String selectedSubprotocol;

        HandshakeComplete(String requestUri, HttpHeaders requestHeaders, String selectedSubprotocol) {
            this.requestUri = requestUri;
            this.requestHeaders = requestHeaders;
            this.selectedSubprotocol = selectedSubprotocol;
        }

        public String requestUri() {
            return requestUri;
        }

        public HttpHeaders requestHeaders() {
            return requestHeaders;
        }

        public String selectedSubprotocol() {
            return selectedSubprotocol;
        }
    }



    private boolean isNotWebSocketPath(FullHttpRequest req) {
        return checkStartsWith ? !req.uri().startsWith(websocketPath) : !req.uri().equals(websocketPath);
    }


    private static void sendHttpResponse(ChannelHandlerContext ctx, HttpRequest req, HttpResponse res) {
        ChannelFuture f = ctx.channel().writeAndFlush(res);
        if (!isKeepAlive(req) || res.status().code() != 200) {
            f.addListener(ChannelFutureListener.CLOSE);
        }
    }

    private static String getWebSocketLocation(ChannelPipeline cp, HttpRequest req, String path) {
        String protocol = "ws";
        if (cp.get(SslHandler.class) != null) {
            // SSL in use so use Secure WebSockets
            protocol = "wss";
        }
        String host = req.headers().get(HttpHeaderNames.HOST);
        return protocol + "://" + host + path;
    }

    private void refuseChannel(ChannelHandlerContext ctx) {
        // Close only after the 401 response has actually been written
        ctx.channel().writeAndFlush(new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.UNAUTHORIZED))
                .addListener(ChannelFutureListener.CLOSE);
    }

    private static void send100Continue(ChannelHandlerContext ctx,String tokenHeader,String token) {
        DefaultFullHttpResponse response = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.CONTINUE);
        response.headers().set(tokenHeader,token);
        ctx.writeAndFlush(response);
    }

    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
        System.out.println("Channel caught an exception; closing");
        super.exceptionCaught(ctx, cause);
    }
    @Getter
    @AllArgsConstructor
    public static final class SecurityCheckComplete {

        private String userId;

        private String tokenHeader;

        private Boolean hasToken;

    }
}
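Stripped of Netty types, the decision channelRead makes before the handshake looks roughly like the sketch below. `HandshakeAuth` is a hypothetical class, and `lookupUser` is a hypothetical function parameter standing in for the project's SecurityUtils.getLoginUser. The result tells the caller whether to answer 401 or to continue with the handshake.

```java
import java.util.Optional;
import java.util.function.Function;

/** Hypothetical, Netty-free model of the authentication decision in SecurityServerHandler. */
final class HandshakeAuth {
    private HandshakeAuth() {}

    /** Outcome: refuse with 401, or continue (userId is null when authentication is off). */
    static final class Decision {
        final boolean refused;
        final String userId;

        private Decision(boolean refused, String userId) {
            this.refused = refused;
            this.userId = userId;
        }

        static Decision refuse() { return new Decision(true, null); }

        static Decision proceed(String userId) { return new Decision(false, userId); }
    }

    static Decision decide(boolean hasToken, String token, Function<String, Optional<Long>> lookupUser) {
        if (!hasToken) {
            // Authentication disabled: the user id is bound later from a chat message.
            return Decision.proceed(null);
        }
        if (token == null || token.isEmpty()) {
            return Decision.refuse();
        }
        return lookupUser.apply(token)
                .map(id -> Decision.proceed(String.valueOf(id)))
                .orElseGet(Decision::refuse);
    }
}
```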

Register your own implementation classes in initChannel

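The remaining classes have to be implemented in, or referenced from, your own project. They are incidental and need no special handling.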


    @Override
    protected void initChannel(SocketChannel ch){
        log.info("New connection");
        // Get the pipeline of this channel (each channel has exactly one pipeline)
        ChannelPipeline pipeline = ch.pipeline();
        // Add the handlers to the pipeline in order
        // 1. Heartbeat: fire an idle event if nothing is read for readerIdleTime seconds
        pipeline.addLast("idle-state",new IdleStateHandler(
                nettyWebSocketProperties.getReaderIdleTime(),
                0,
                0,
                TimeUnit.SECONDS));
        // 2. Inbound/outbound (duplex) handler, mostly for the heartbeat
        pipeline.addLast("change-duple",new WsChannelDupleHandler(nettyWebSocketProperties.getReaderIdleTime()));
        // 3. HTTP encoding/decoding
        pipeline.addLast("http-codec",new HttpServerCodec());
        // 4. Logging handler, to make what the pipeline does visible
        pipeline.addLast("logging", new LoggingHandler(LogLevel.INFO));
        // 5. Aggregate HTTP messages into a FullHttpRequest
        pipeline.addLast("aggregator",new HttpObjectAggregator(8192));
        // 6. Our authentication handler replaces Netty's default handshake handler
        pipeline.addLast("auth",new SecurityServerHandler(
                // I bind these from YAML; substitute your own values
                nettyWebSocketProperties.getWebsocketPath(),
                nettyWebSocketProperties.getSubProtocols(),
                nettyWebSocketProperties.getAllowExtensions(),
                nettyWebSocketProperties.getMaxFrameSize(),
                //todo allowMaskMismatch
                false,
                nettyWebSocketProperties.getTokenHeader(),
                nettyWebSocketProperties.getHasToken()
        ));
        pipeline.addLast("http-chunked",new ChunkedWriteHandler());
        // 7. Our protocol handler replaces Netty's default protocol handler
        pipeline.addLast("websocket",
                new CustomWebSocketServerProtocolHandler(
                nettyWebSocketProperties.getWebsocketPath(),
                nettyWebSocketProperties.getSubProtocols(),
                nettyWebSocketProperties.getAllowExtensions(),
                nettyWebSocketProperties.getMaxFrameSize())
        );
        // 8. Custom business handler
        pipeline.addLast("chat-handler",new ChatHandler());
    }
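The nettyWebSocketProperties object comes from the author's YAML configuration and is not shown. Below is a hypothetical holder exposing exactly the getters used above. The field names are inferred from those getters, and every default value is an assumption, except that 65536 matches the default frame size used by the handler constructors. In a Spring Boot project such a class would typically be bound with @ConfigurationProperties.

```java
/** Hypothetical holder for the settings read in initChannel; all defaults are assumptions. */
class NettyWebSocketProperties {
    private int readerIdleTime = 60;                       // seconds without a read before IdleStateHandler fires
    private String websocketPath = "/ws";                  // path the handshake must target
    private String subProtocols = null;                    // subprotocols the handshaker may select
    private boolean allowExtensions = true;                // allow WebSocket extensions
    private int maxFrameSize = 65536;                      // max frame payload in bytes
    private String tokenHeader = "Sec-WebSocket-Protocol"; // header that carries the token
    private boolean hasToken = true;                       // whether token authentication is on

    public int getReaderIdleTime() { return readerIdleTime; }
    public void setReaderIdleTime(int readerIdleTime) { this.readerIdleTime = readerIdleTime; }
    public String getWebsocketPath() { return websocketPath; }
    public void setWebsocketPath(String websocketPath) { this.websocketPath = websocketPath; }
    public String getSubProtocols() { return subProtocols; }
    public void setSubProtocols(String subProtocols) { this.subProtocols = subProtocols; }
    public boolean getAllowExtensions() { return allowExtensions; }
    public void setAllowExtensions(boolean allowExtensions) { this.allowExtensions = allowExtensions; }
    public int getMaxFrameSize() { return maxFrameSize; }
    public void setMaxFrameSize(int maxFrameSize) { this.maxFrameSize = maxFrameSize; }
    public String getTokenHeader() { return tokenHeader; }
    public void setTokenHeader(String tokenHeader) { this.tokenHeader = tokenHeader; }
    public boolean getHasToken() { return hasToken; }
    public void setHasToken(boolean hasToken) { this.hasToken = hasToken; }
}
```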

Appendix: my own business handler (ChatHandler) and related utility classes

ChatHandler

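import com.alibaba.fastjson2.JSON;
import com.alibaba.fastjson2.JSONObject;
import com.edu.message.domain.vo.Message;
import com.edu.message.handler.pool.ChannelHandlerPool;
import com.edu.message.handler.security.SecurityServerHandler;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelId;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.handler.codec.http.websocketx.TextWebSocketFrame;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;

import static com.edu.message.handler.attributeKey.AttributeKeyUtils.SECURITY_CHECK_COMPLETE_ATTRIBUTE_KEY;
// Also import CustomWebSocketServerProtocolHandler from the package you placed it in

/**
 * @author qb
 * @version 1.0
 * @since 2023/3/7 11:56
 */
@Slf4j
public class ChatHandler extends SimpleChannelInboundHandler<TextWebSocketFrame> {

    /**
     * Called when the connection becomes active
     * @param ctx context
     * @throws Exception /
     */
    @Override
    public void channelActive(ChannelHandlerContext ctx) throws Exception {
        log.info("Connection established with client, channel open!");
        // Add to the channel group
        ChannelHandlerPool.pool().addChannel(ctx.channel());
    }

    /**
     * Called when the connection is closed
     * @param ctx context
     * @throws Exception /
     */
    @Override
    public void channelInactive(ChannelHandlerContext ctx) throws Exception {
        log.info("Disconnected from client, channel closed!");
        // Remove from the channel group
        ChannelHandlerPool.pool().removeChannel(ctx.channel());
        // The attribute is missing if the channel closed before the security check ran
        SecurityServerHandler.SecurityCheckComplete complete = ctx.channel().attr(SECURITY_CHECK_COMPLETE_ATTRIBUTE_KEY).get();
        String useridQuit = complete == null ? null : complete.getUserId();
        ChannelHandlerPool.pool().removeChannelId(useridQuit);
        log.info("Disconnected user id: {}",useridQuit);
    }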

    /**
     * Called when a message is received
     * @param ctx context
     * @param msg message
     */
    @Override
    protected void channelRead0(ChannelHandlerContext ctx, TextWebSocketFrame msg) {
        log.info("msg.text():{}",msg.text());
        Message message = JSON.parseObject(msg.text(),Message.class);
        if("0".equals(message.getType())){
            log.info("Message type is 'bind channel', userId:{}",message.getFromUserId());
            Boolean hasToken = ctx.channel().attr(SECURITY_CHECK_COMPLETE_ATTRIBUTE_KEY).get().getHasToken();
            log.info("hasToken: {}",hasToken);
            if (Boolean.FALSE.equals(hasToken)){
                log.info("binding channel...");
                // Without authentication, bind the channel from the message instead
                binding(ctx,message);
            }
        }else{
            if(StringUtils.isNotBlank(message.getToUserId())){
                // Private message
                sendMsg(message);
            }else{
                // Send to everyone except the sender
                sendOther(ctx,message);
            }
        }
    }

    /**
     * Callback when the handler is added to a channel
     * @param ctx /
     */
    @Override
    public void handlerAdded(ChannelHandlerContext ctx) {
        // Print the channel's unique id; asLongText() returns the full form of the channel id
        log.info("handlerAdded :{}",ctx.channel().id().asLongText());
    }

    /**
     * Callback when the handler is removed from a channel
     * @param ctx /
     */
    @Override
    public void handlerRemoved(ChannelHandlerContext ctx) {
        log.info("handlerRemoved :{}",ctx.channel().id().asLongText());
    }

    /**
     * User event listener
     * @param ctx       /
     * @param evt       /
     * @throws Exception    /
     */
    @Override
    public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
        if (evt instanceof SecurityServerHandler.SecurityCheckComplete){
            log.info("Security check has passed");
            // Logic after successful authentication: none for now
        }
        else if (evt instanceof CustomWebSocketServerProtocolHandler.HandshakeComplete) {
            log.info("Handshake has completed");
            SecurityServerHandler.SecurityCheckComplete securityCheckComplete = ctx.channel().attr(SECURITY_CHECK_COMPLETE_ATTRIBUTE_KEY).get();
            Boolean hasToken = securityCheckComplete.getHasToken();
            log.info("Handshake has completed after check hasToken:{}",hasToken);
            // After a successful handshake: if authentication is on, bind the channel to the user
            if(hasToken){
                log.info("Handshake has completed after binding channel");
                binding(ctx,securityCheckComplete.getUserId());
            }

        }
        super.userEventTriggered(ctx, evt);
    }

    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
        log.error("exceptionCaught: {}",cause.getMessage(),cause);
        Channel channel = ctx.channel();
        //……
        if(channel.isActive()){
            log.info("Closing the channel manually");
            ctx.close();
        }
    }

    /**
     * Broadcast to everyone
     * @param message message
     */
    private void sendAllMessage(Message message){
        // On receiving a message, broadcast it to every channel
        ChannelHandlerPool.pool().writeAndFlush(message);
    }

    /**
     * Send a message
     * @param ctx       context
     * @param message   message
     */
    private void sendMsg(ChannelHandlerContext ctx, Message message){
        // Send the message back to the sender's own channel
        ChannelHandlerPool.pool().writeAndFlush(ctx.channel().id(),message);
    }

    /**
     * Bind the channel to a user id
     * @param ctx       context
     * @param message   message
     */
    public void binding(ChannelHandlerContext ctx,Message message){
        ChannelId channelId = ctx.channel().id();
        Channel channel = ChannelHandlerPool.pool().getChannel(channelId);
        // Check whether the channel is registered; add it again if not
        if(null == channel){
            ChannelHandlerPool.pool().addChannel(ctx.channel());
        }
        try {
            // Bind the user id to the channel
            ChannelHandlerPool.pool().putChannelId(message.getFromUserId(), channelId);
        }catch (Exception e){
            log.info("Closing the connection from the server side");
            e.printStackTrace();
            // Close the connection on error
            ctx.close();
        }
    }

    /**
     * Bind the channel to a user id
     * @param ctx       context
     * @param userId    user id
     */
    public void binding(ChannelHandlerContext ctx,String userId){
        ChannelId channelId = ctx.channel().id();
        Channel channel =  ChannelHandlerPool.pool().getChannel(channelId);
        // Check whether the channel is registered; add it again if not
        if(null == channel){
            ChannelHandlerPool.pool().addChannel(ctx.channel());
        }
        try {
            SecurityServerHandler.SecurityCheckComplete complete = ctx.channel().attr(SECURITY_CHECK_COMPLETE_ATTRIBUTE_KEY).get();
            // Bind the user id to the channel
            ChannelHandlerPool.pool().putChannelId(complete.getHasToken() ? complete.getUserId() : userId, channelId);
        }catch (Exception e){
            log.info("Closing the connection from the server side");
            e.printStackTrace();
            // Close the connection on error
            ctx.close();
        }
    }

    public void sendMsg(Message message){
        // Private message
        ChannelId channelId = ChannelHandlerPool.pool().getChannelId(message.getToUserId());
        if(null == channelId){
            log.info("User {} is offline!",message.getToUserId());
            // Offline: store the message in the database
            return;
        }
        Channel channel = ChannelHandlerPool.pool().getChannel(channelId);
        if(null == channel){
            log.info("Removing the channelId stored for user {}",message.getToUserId());
            // The user went offline abnormally and the two static maps are out of sync: drop the recipient's stale entry
            ChannelHandlerPool.pool().removeChannelId(message.getToUserId());
            return;
        }
        ChannelHandlerPool.pool().writeAndFlush(channel,message);
        log.info("userId in channel:{}",channel.attr(SECURITY_CHECK_COMPLETE_ATTRIBUTE_KEY).get().getUserId());
    }

    /**
     * Send to everyone except the sender
     * @param ctx       context
     * @param message   message
     */
    public void sendOther(ChannelHandlerContext ctx, Message message){
        for (Channel channel : ChannelHandlerPool.pool().getChannelGroup()) {
            // Send to everyone except yourself
            if(channel != ctx.channel()){
                log.info("Sending message:{}",message);
                channel.writeAndFlush(new TextWebSocketFrame(JSONObject.toJSONString(message)));
            }
        }
    }
}
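The binding and private-message logic in ChatHandler is essentially a two-level lookup: userId → ChannelId → live Channel. The stdlib model below uses a hypothetical `UserChannelRegistry`, with Strings standing in for Netty's ChannelId and a Set standing in for the ChannelGroup. It shows how a stale entry is cleaned up when the recipient's channel has disappeared.

```java
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;

/** Hypothetical model of the userId -> channelId -> channel lookup used for private messages. */
final class UserChannelRegistry {
    private final Map<String, String> channelIdByUser = new ConcurrentHashMap<>();
    private final Set<String> liveChannels = ConcurrentHashMap.newKeySet();

    void channelActive(String channelId) { liveChannels.add(channelId); }

    void channelInactive(String channelId) { liveChannels.remove(channelId); }

    void bind(String userId, String channelId) {
        liveChannels.add(channelId);            // re-add if missing, as binding() does
        channelIdByUser.put(userId, channelId);
    }

    /** Resolves the recipient's live channel and drops the recipient's mapping if it is stale. */
    Optional<String> resolve(String toUserId) {
        String channelId = channelIdByUser.get(toUserId);
        if (channelId == null) {
            return Optional.empty();            // user offline
        }
        if (!liveChannels.contains(channelId)) {
            channelIdByUser.remove(toUserId);   // clean up the recipient's stale entry
            return Optional.empty();
        }
        return Optional.of(channelId);
    }

    boolean isBound(String userId) { return channelIdByUser.containsKey(userId); }
}
```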

AttributeKeyUtils

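import com.edu.message.handler.security.SecurityServerHandler;
import io.netty.util.AttributeKey;

public class AttributeKeyUtils {

    /**
     * Channel attribute holding the user id, so the id is still available when the client disconnects unexpectedly
     */
    public static final AttributeKey<String> USER_ID = AttributeKey.valueOf("userid");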

    public static final AttributeKey<SecurityServerHandler.SecurityCheckComplete> SECURITY_CHECK_COMPLETE_ATTRIBUTE_KEY =
            AttributeKey.valueOf("SECURITY_CHECK_COMPLETE_ATTRIBUTE_KEY");


}

ChannelHandlerPool

package com.edu.message.handler.pool;

import com.alibaba.fastjson2.JSONObject;
import com.edu.common.utils.spring.SpringUtils;
import com.edu.message.converter.MessageConverter;
import com.edu.message.domain.vo.Message;
import com.edu.message.service.IMessageSocketService;
import io.netty.channel.Channel;
import io.netty.channel.ChannelId;
import io.netty.channel.group.ChannelGroup;
import io.netty.channel.group.DefaultChannelGroup;
import io.netty.handler.codec.http.websocketx.TextWebSocketFrame;
import io.netty.util.concurrent.GlobalEventExecutor;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;

import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

import static com.edu.common.constant.ThreadPoolConstants.POOL_NAME;

/**
 * Holds the online users and related information
 * @author qb
 * @version 1.0
 * @since 2023/3/7 11:53
 */
@Slf4j
public class ChannelHandlerPool {

    private final IMessageSocketService iSocketService = SpringUtils.getBean(IMessageSocketService.class);

    private final MessageConverter messageConverter = SpringUtils.getBean(MessageConverter.class);
    private ChannelHandlerPool(){}

    private ThreadPoolTaskExecutor executor = SpringUtils.getBean(POOL_NAME);

    private static final ChannelHandlerPool POOL = new ChannelHandlerPool();

    public static ChannelHandlerPool pool(){
        return POOL;
    }

    public void execute(Message message){
        executor.submit(() -> {
            String name = Thread.currentThread().getName();
            Long id = Thread.currentThread().getId();
            log.info("message:{}, thread name:{}, id:{}",message,name,id);
            iSocketService.save(messageConverter.messageToMsgSocket(message));
        });
    }

    /**
     *
     * map: userId,ChannelId
     */
    private static final Map<String, ChannelId> CHANNEL_ID_MAP = new ConcurrentHashMap<>();

    /**
     * Channel group
     */
    private static final ChannelGroup CHANNEL_GROUP = new DefaultChannelGroup(GlobalEventExecutor.INSTANCE);

    public Map<String, ChannelId> getChannelIdMap(){
        return CHANNEL_ID_MAP;
    }

    public ChannelGroup getChannelGroup(){
        return CHANNEL_GROUP;
    }


    /**
     * Get the channelId
     * @param userId user id
     * @return  /
     */
    public ChannelId getChannelId(String userId){
        return CHANNEL_ID_MAP.get(userId);
    }


    /**
     * Get the channel
     * @param channelId /
     * @return /
     */
    public Channel getChannel(ChannelId channelId){
        return CHANNEL_GROUP.find(channelId);
    }

    public Channel getChannel(String userId){
        ChannelId channelId = getChannelId(userId);
        // find(null) would throw, so return null for users without a mapping
        return channelId == null ? null : CHANNEL_GROUP.find(channelId);
    }

    public void removeChannelId(String userid){
        if(StringUtils.isNotBlank(userid)){
            CHANNEL_ID_MAP.remove(userid);
        }
    }

    public void removeChannel(Channel channel){
        CHANNEL_GROUP.remove(channel);
    }

    public void addChannel(Channel channel){
        CHANNEL_GROUP.add(channel);
    }

    /**
     * Broadcast to all channels
     * @param message message content
     */
    public void writeAndFlush(Message message){
        saveBase(message);
        CHANNEL_GROUP.writeAndFlush( new TextWebSocketFrame(JSONObject.toJSONString(message)));
    }

    /**
     * Send to a single channel
     * @param channel   channel
     * @param message   content
     */
    public void writeAndFlush(Channel channel,Message message){
        saveBase(message);
        channel.writeAndFlush( new TextWebSocketFrame(JSONObject.toJSONString(message)));
    }

    /**
     * Send to a single channel, looked up by its id
     * @param channelId channelId
     * @param message   message
     */
    public void writeAndFlush(ChannelId channelId,Message message){
        saveBase(message);
        Channel channel = findChannel(channelId);
        // The channel may already have been removed from the group
        if (channel == null) {
            log.info("channel {} not found, message not delivered", channelId);
            return;
        }
        channel.writeAndFlush( new TextWebSocketFrame(JSONObject.toJSONString(message)));
    }

    public Channel findChannel(ChannelId channelId){
        return CHANNEL_GROUP.find(channelId);
    }

    public void putChannelId(String userid,ChannelId channelId){
        CHANNEL_ID_MAP.put(userid,channelId);
    }

    private void saveBase(Message message){
        ChannelHandlerPool.pool().execute(message);
    }
}
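ChannelHandlerPool hands every save to a Spring ThreadPoolTaskExecutor, so the Netty event loop never blocks on the database. The same pattern is sketched below with only java.util.concurrent. `AsyncMessageSaver` is hypothetical, and the `store` callback stands in for IMessageSocketService.save.

```java
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.function.Consumer;

/** Hypothetical sketch: hand persistence off to a worker pool instead of the I/O thread. */
final class AsyncMessageSaver implements AutoCloseable {
    private final ExecutorService executor = Executors.newFixedThreadPool(2);
    private final Consumer<String> store;

    AsyncMessageSaver(Consumer<String> store) {
        this.store = store;
    }

    /** Returns immediately; the store call runs on a pool thread. */
    void execute(String message) {
        executor.submit(() -> store.accept(message));
    }

    /** Lets queued saves finish (up to 5 seconds), then stops the pool. */
    @Override
    public void close() {
        executor.shutdown();
        try {
            executor.awaitTermination(5, TimeUnit.SECONDS);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
    }
}
```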

Result screenshots

Source code walkthrough

Revised SecurityServerHandler

Changed so that it only parses the custom request header and does not replace any other handler.

package com.edu.message.handler.security;

import com.edu.common.utils.SecurityUtils;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.handler.codec.http.FullHttpMessage;
import io.netty.handler.codec.http.HttpHeaders;
import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.extern.slf4j.Slf4j;

import java.util.Objects;

import static com.edu.message.handler.attributeKey.AttributeKeyUtils.SECURITY_CHECK_COMPLETE_ATTRIBUTE_KEY;

/**
 * @author Administrator
 */
@Slf4j
public class SecurityServerHandler extends ChannelInboundHandlerAdapter {

    private String tokenHeader;

    private Boolean hasToken;

    public SecurityServerHandler(String tokenHeader,Boolean hasToken){
        this.tokenHeader = tokenHeader;
        this.hasToken = hasToken;
    }

    private SecurityServerHandler(){}

    @Override
    public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
        if(msg instanceof FullHttpMessage){
            FullHttpMessage httpMessage = (FullHttpMessage) msg;
            HttpHeaders headers = httpMessage.headers();
            String token = Objects.requireNonNull(headers.get(tokenHeader));
            if(hasToken){
                // Authentication enabled
                //extracts device information headers
                // Hard-coded user id for debugging; the real lookup is commented out
                Long userId = 12345L;//SecurityUtils.getLoginUser(token).getUserId();
                //check ......
                SecurityCheckComplete complete = new SecurityCheckComplete(String.valueOf(userId),tokenHeader,hasToken);
                ctx.channel().attr(SECURITY_CHECK_COMPLETE_ATTRIBUTE_KEY).set(complete);
                ctx.fireUserEventTriggered(complete);
            }else {
                // Authentication disabled
                SecurityCheckComplete complete = new SecurityCheckComplete(null,tokenHeader,hasToken);
                ctx.channel().attr(SECURITY_CHECK_COMPLETE_ATTRIBUTE_KEY).set(complete);
            }
        }
        //other protocols
        super.channelRead(ctx, msg);
    }

    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
        System.out.println("Channel caught an exception; closing");
        super.exceptionCaught(ctx, cause);
    }
    @Getter
    @AllArgsConstructor
    public static final class SecurityCheckComplete {

        private String userId;

        private String tokenHeader;

        private Boolean hasToken;

    }
}

Revised initChannel

Switched back to Netty's default WebSocketServerProtocolHandler.

    @Override
    protected void initChannel(SocketChannel ch){
        log.info("New connection");
        // Get the pipeline of this channel (each channel has exactly one pipeline)
        ChannelPipeline pipeline = ch.pipeline();
        // Add the handlers to the pipeline in order
        // 1. Heartbeat: fire an idle event if nothing is read for readerIdleTime seconds
        pipeline.addLast("idle-state",new IdleStateHandler(
                nettyWebSocketProperties.getReaderIdleTime(),
                0,
                0,
                TimeUnit.SECONDS));
        // 2. Inbound/outbound (duplex) handler, mostly for the heartbeat
        pipeline.addLast("change-duple",new WsChannelDupleHandler(nettyWebSocketProperties.getReaderIdleTime()));
        // 3. HTTP encoding/decoding
        pipeline.addLast("http-codec",new HttpServerCodec());
        // 4. Logging handler, to make what the pipeline does visible
        pipeline.addLast("logging", new LoggingHandler(LogLevel.INFO));
        // 5. Aggregate HTTP messages into a FullHttpRequest
        pipeline.addLast("aggregator",new HttpObjectAggregator(8192));
        // 6. Authentication handler: only parses the token header and passes the request on
        pipeline.addLast("auth",new SecurityServerHandler(
                nettyWebSocketProperties.getTokenHeader(),
                nettyWebSocketProperties.getHasToken()
        ));
        pipeline.addLast("http-chunked",new ChunkedWriteHandler());
//        pipeline.addLast("websocket",
//                new CustomWebSocketServerProtocolHandler(
//                nettyWebSocketProperties.getWebsocketPath(),
//                nettyWebSocketProperties.getSubProtocols(),
//                nettyWebSocketProperties.getAllowExtensions(),
//                nettyWebSocketProperties.getMaxFrameSize())
//        );
        // 7. Netty's default protocol handler, which performs the handshake itself
        pipeline.addLast("websocket",
                new WebSocketServerProtocolHandler(
                nettyWebSocketProperties.getWebsocketPath(),
                nettyWebSocketProperties.getSubProtocols(),
                nettyWebSocketProperties.getAllowExtensions(),
                nettyWebSocketProperties.getMaxFrameSize())
        );
        // 8. Custom business handler
        pipeline.addLast("chat-handler",new ChatHandler());
    }

Starting the project: execution flow screenshots

Breakpoint screenshots


1. SecurityServerHandler

Execution first reaches the custom authentication handler (an inbound handler) and runs its channelRead method.

2. userEventTriggered

Next comes the event method in the custom business handler.

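3. WebSocketServerProtocolHandshakeHandler

Execution then reaches channelRead in Netty's default handshake handler. Pay attention to handshaker.handshake(ctx.channel(), req): this is the method that performs the handshake, so set a breakpoint and step into it.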

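4. WebSocketServerHandshaker

Here you can see that handshake passes null as the HttpHeaders. This is the core handshake logic, and on this path it provides no way to supply your own response headers.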

5. WebSocketServerHandshaker

newHandshakeResponse(req, responseHeaders) builds the handshake response, and the headers passed in are null.
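For comparison, here is a sketch of what a server-side handshake response contains according to RFC 6455: status 101 with Upgrade and Connection headers, plus Sec-WebSocket-Accept. That last header is base64(SHA-1(Sec-WebSocket-Key + a fixed GUID)).

`HandshakeResponseSketch` is hypothetical and written with the plain JDK; it is not taken from Netty. Its `extraHeaders` parameter plays the role of the responseHeaders argument, which is null on the default path.

```java
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.Base64;
import java.util.LinkedHashMap;
import java.util.Map;

/** Hypothetical sketch of the headers a 101 Switching Protocols response carries. */
final class HandshakeResponseSketch {
    /** Fixed GUID defined by RFC 6455. */
    private static final String WS_GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";

    private HandshakeResponseSketch() {}

    static String acceptFor(String secWebSocketKey) {
        try {
            MessageDigest sha1 = MessageDigest.getInstance("SHA-1");
            byte[] digest = sha1.digest((secWebSocketKey + WS_GUID).getBytes(StandardCharsets.US_ASCII));
            return Base64.getEncoder().encodeToString(digest);
        } catch (NoSuchAlgorithmException e) {
            throw new IllegalStateException("SHA-1 is required on every Java platform", e);
        }
    }

    /** extraHeaders may be null, as on the default handshake path. */
    static Map<String, String> responseHeaders(String secWebSocketKey, Map<String, String> extraHeaders) {
        Map<String, String> headers = new LinkedHashMap<>();
        if (extraHeaders != null) {
            headers.putAll(extraHeaders);   // caller-supplied headers, e.g. the echoed token
        }
        headers.put("Upgrade", "websocket");
        headers.put("Connection", "Upgrade");
        headers.put("Sec-WebSocket-Accept", acceptFor(secWebSocketKey));
        return headers;
    }
}
```

With the key from the RFC's own example, dGhlIHNhbXBsZSBub25jZQ==, acceptFor returns s3pPLMBiTxaQ9kYGzzhZRbK+xOo=.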

6. Building and returning the response

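Execution then returns to the event listener method (userEventTriggered) of the custom business handler.

Step past this point and the response headers are printed to the console. Our own header has not been set; it is still null.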
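Finally the response is returned and the connection is dropped, because the subprotocol header in the response does not match the one in the request.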
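The header goes missing because the handshaker only writes Sec-WebSocket-Protocol when a requested value matches one of its configured subprotocols, and an arbitrary token never does. Below is a simplified, Netty-free model of that selection. `SubprotocolSelection` is a hypothetical name, and this is a simplification rather than a copy of Netty's source.

```java
/** Hypothetical, simplified model of server-side subprotocol selection. */
final class SubprotocolSelection {
    private SubprotocolSelection() {}

    /**
     * @param requestedHeader Sec-WebSocket-Protocol from the request (comma-separated)
     * @param supported       the server's configured subprotocols (comma-separated), may be null
     * @return the first requested value the server supports, or null (the header is then omitted)
     */
    static String select(String requestedHeader, String supported) {
        if (requestedHeader == null || supported == null) {
            return null;
        }
        for (String requested : requestedHeader.split(",")) {
            String candidate = requested.trim();
            for (String offered : supported.split(",")) {
                if (candidate.equals(offered.trim())) {
                    return candidate;
                }
            }
        }
        return null;
    }
}
```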
