文章源码:gitee
源码分析部分可参考上一篇文章《netty-websocket 鉴权token及统一请求和响应头（鉴权控制器）》。
最近刚好有空，看到有朋友说自定义协议不好搞，我想了想，发现上一篇文章中的那种实现方式确实麻烦，而且兼容性也不好。后来我参照 WebSocketServerProtocolHandler 尝试做了扩展：把 WebSocketServerProtocolHandler 在 handlerAdded 中添加的握手逻辑换成自己的实现，终于测通了。用 postman 测试时，请求头也可以自定义。下面上代码。
1. (userEventTriggered)：鉴权成功后可以发布（fire）自定义事件，在业务 handler 中实现事件监听方法 userEventTriggered，这样就可以在鉴权成功后、握手成功前执行某些逻辑，比如校验权限等，具体可参考 SecurityHandler 中的示例。
2. (exceptionCaught): 异常捕获
3. channel设置attr实现channel上下文的数据属性
4. …等等
WebSocketServerProtocolHandler 中有很多私有方法，外部无法引用，所以只能复制一份出来。主要是重写了 handlerAdded 方法，将原有的 WebSocketServerProtocolHandshakeHandler 替换为自己的 SecurityHandler。
package com.chat.nettywebsocket.handler.test;
import io.netty.channel.*;
import io.netty.handler.codec.http.*;
import io.netty.handler.codec.http.websocketx.Utf8FrameValidator;
import io.netty.handler.codec.http.websocketx.WebSocketServerHandshaker;
import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler;
import io.netty.util.AttributeKey;
import static io.netty.handler.codec.http.HttpVersion.HTTP_1_1;
/**
 * {@link WebSocketServerProtocolHandler} that installs {@link SecurityHandler} in place of Netty's
 * built-in handshake handler, so the upgrade request can be checked and custom response headers
 * can be added before the handshake completes.
 *
 * <p>Several helpers of the parent class are private, so the ones {@link SecurityHandler} relies on
 * (handshaker channel attribute, 403 responder, handshake events) are re-declared here.
 *
 * @author qb
 * @date 2024/2/5 8:53
 * @version 1.0
 */
public class CustomWebSocketServerProtocolHandler extends WebSocketServerProtocolHandler {

    /** Handshake state events fired by {@link SecurityHandler}. */
    public enum ServerHandshakeStateEvent {
        /**
         * The handshake was completed successfully and the channel was upgraded to websockets.
         *
         * @deprecated in favor of the {@link HandshakeComplete} event declared in this class,
         *     which provides extra information about the handshake
         */
        @Deprecated
        HANDSHAKE_COMPLETE
    }

    /**
     * User event fired after the handshake completed successfully and the channel was upgraded
     * to websockets.
     */
    public static final class HandshakeComplete {
        private final String requestUri;
        private final HttpHeaders requestHeaders;
        private final String selectedSubprotocol;

        HandshakeComplete(String requestUri, HttpHeaders requestHeaders, String selectedSubprotocol) {
            this.requestUri = requestUri;
            this.requestHeaders = requestHeaders;
            this.selectedSubprotocol = selectedSubprotocol;
        }

        /** @return the URI of the upgrade request */
        public String requestUri() {
            return requestUri;
        }

        /** @return the headers of the upgrade request */
        public HttpHeaders requestHeaders() {
            return requestHeaders;
        }

        /** @return the subprotocol chosen by the handshaker (may be {@code null}) */
        public String selectedSubprotocol() {
            return selectedSubprotocol;
        }
    }

    private final String websocketPath;
    private final String subprotocols;
    private final boolean allowExtensions;
    private final int maxFramePayloadLength;
    private final boolean allowMaskMismatch;
    private final boolean checkStartsWith;

    /**
     * Creates the handler. The same settings are handed both to the parent class and to the
     * {@link SecurityHandler} installed in {@link #handlerAdded(ChannelHandlerContext)}.
     *
     * @param websocketPath     request path that triggers the handshake (a prefix when
     *                          {@code checkStartsWith} is true)
     * @param subprotocols      subprotocols supported by the server, passed to the handshaker factory
     * @param allowExtensions   whether WebSocket extensions are allowed
     * @param maxFrameSize      maximum frame payload length, in bytes
     * @param allowMaskMismatch passed through to the handshaker factory
     * @param checkStartsWith   whether the request path only needs to start with {@code websocketPath}
     */
    public CustomWebSocketServerProtocolHandler(String websocketPath, String subprotocols, boolean allowExtensions, int maxFrameSize, boolean allowMaskMismatch, boolean checkStartsWith) {
        super(websocketPath, subprotocols, allowExtensions, maxFrameSize, allowMaskMismatch, checkStartsWith);
        this.websocketPath = websocketPath;
        this.subprotocols = subprotocols;
        this.allowExtensions = allowExtensions;
        this.maxFramePayloadLength = maxFrameSize;
        this.allowMaskMismatch = allowMaskMismatch;
        this.checkStartsWith = checkStartsWith;
    }

    /**
     * Installs {@link SecurityHandler} (instead of Netty's WebSocketServerProtocolHandshakeHandler)
     * and a {@link Utf8FrameValidator} in front of this handler, unless they are already present.
     */
    @Override
    public void handlerAdded(ChannelHandlerContext ctx) {
        ChannelPipeline cp = ctx.pipeline();
        if (cp.get(SecurityHandler.class) == null) {
            // Custom handshake handler: runs the security check and adds custom response headers.
            cp.addBefore(ctx.name(), SecurityHandler.class.getName(),
                    new SecurityHandler(websocketPath, subprotocols,
                            allowExtensions, maxFramePayloadLength, allowMaskMismatch, checkStartsWith));
        }
        if (cp.get(Utf8FrameValidator.class) == null) {
            // Validate UTF-8 in frames before they reach this handler.
            cp.addBefore(ctx.name(), Utf8FrameValidator.class.getName(), new Utf8FrameValidator());
        }
    }

    /**
     * Channel attribute holding the handshaker. It is created with the same class/name pair as
     * Netty's WebSocketServerProtocolHandler, and AttributeKey.valueOf pools keys by name.
     * NOTE(review): the parent therefore presumably sees the handshaker stored here (e.g. when
     * handling close frames); verify against the Netty version in use.
     */
    private static final AttributeKey<WebSocketServerHandshaker> HANDSHAKER_ATTR_KEY =
            AttributeKey.valueOf(WebSocketServerHandshaker.class, "HANDSHAKER");

    /** Returns the handshaker stored on {@code channel}, or {@code null} if none was set. */
    static WebSocketServerHandshaker getHandshaker(Channel channel) {
        return channel.attr(HANDSHAKER_ATTR_KEY).get();
    }

    /** Stores {@code handshaker} on {@code channel}. */
    static void setHandshaker(Channel channel, WebSocketServerHandshaker handshaker) {
        channel.attr(HANDSHAKER_ATTR_KEY).set(handshaker);
    }

    /**
     * Returns a handler that releases every further {@link FullHttpRequest} and answers it with
     * 403 Forbidden, passing all other messages on unchanged.
     */
    static ChannelHandler forbiddenHttpRequestResponder() {
        return new ChannelInboundHandlerAdapter() {
            @Override
            public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
                if (msg instanceof FullHttpRequest) {
                    ((FullHttpRequest) msg).release();
                    FullHttpResponse response =
                            new DefaultFullHttpResponse(HTTP_1_1, HttpResponseStatus.FORBIDDEN);
                    ctx.channel().writeAndFlush(response);
                } else {
                    ctx.fireChannelRead(msg);
                }
            }
        };
    }
}
下面的 SecurityHandler 复制自 WebSocketServerProtocolHandshakeHandler，只修改了请求头的处理逻辑（将请求头中的 Sec-token 回写到握手响应头中），并把发布事件时使用的类调整为自定义的类。
package com.chat.nettywebsocket.handler.test;
import com.chat.nettywebsocket.handler.test.CustomWebSocketServerProtocolHandler;
import io.netty.channel.*;
import io.netty.handler.codec.http.*;
import io.netty.handler.codec.http.websocketx.WebSocketServerHandshaker;
import io.netty.handler.codec.http.websocketx.WebSocketServerHandshakerFactory;
import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler;
import io.netty.handler.ssl.SslHandler;
import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.extern.slf4j.Slf4j;
import org.springframework.util.StringUtils;
import static com.chat.nettywebsocket.handler.AttributeKeyUtils.SECURITY_CHECK_COMPLETE_ATTRIBUTE_KEY;
import static io.netty.handler.codec.http.HttpMethod.GET;
import static io.netty.handler.codec.http.HttpResponseStatus.FORBIDDEN;
import static io.netty.handler.codec.http.HttpUtil.isKeepAlive;
import static io.netty.handler.codec.http.HttpVersion.HTTP_1_1;
/**
 * Replacement for Netty's WebSocketServerProtocolHandshakeHandler, installed by
 * {@link CustomWebSocketServerProtocolHandler}.
 *
 * <p>It intercepts the HTTP upgrade request on the configured path and fires a
 * {@link SecurityCheckComplete} user event before the handshake. It copies the {@code Sec-token}
 * request header into the handshake response. Afterwards it replaces itself with a responder that
 * answers any further HTTP request with 403.
 *
 * @author qb
 * @date 2024/2/5 8:37
 * @version 1.0
 */
@Slf4j
@ChannelHandler.Sharable
public class SecurityHandler extends ChannelInboundHandlerAdapter {

    /** Request header that is echoed back in the handshake response. */
    private static final String TOKEN_HEADER = "Sec-token";

    private final String websocketPath;
    private final String subprotocols;
    private final boolean allowExtensions;
    private final int maxFramePayloadSize;
    private final boolean allowMaskMismatch;
    private final boolean checkStartsWith;

    SecurityHandler(String websocketPath, String subprotocols,
                    boolean allowExtensions, int maxFrameSize, boolean allowMaskMismatch) {
        this(websocketPath, subprotocols, allowExtensions, maxFrameSize, allowMaskMismatch, false);
    }

    SecurityHandler(String websocketPath, String subprotocols,
                    boolean allowExtensions, int maxFrameSize, boolean allowMaskMismatch, boolean checkStartsWith) {
        this.websocketPath = websocketPath;
        this.subprotocols = subprotocols;
        this.allowExtensions = allowExtensions;
        this.maxFramePayloadSize = maxFrameSize;
        this.allowMaskMismatch = allowMaskMismatch;
        this.checkStartsWith = checkStartsWith;
    }

    /**
     * Handles a WebSocket upgrade request. Messages that are not a {@link FullHttpRequest} for the
     * WebSocket path are passed on unchanged.
     *
     * <p>A non-GET request on the path gets a 403. Otherwise the security-check result is stored
     * on the channel and fired as a user event, and then the handshake is performed. The request
     * is released once it has been handled here.
     */
    @Override
    public void channelRead(final ChannelHandlerContext ctx, Object msg) throws Exception {
        // Only aggregated HTTP requests can start a handshake; forward anything else instead of
        // failing with a ClassCastException.
        if (!(msg instanceof FullHttpRequest)) {
            ctx.fireChannelRead(msg);
            return;
        }
        final FullHttpRequest req = (FullHttpRequest) msg;
        if (isNotWebSocketPath(req)) {
            ctx.fireChannelRead(msg);
            return;
        }
        try {
            if (req.method() != GET) {
                sendHttpResponse(ctx, req, new DefaultFullHttpResponse(HTTP_1_1, FORBIDDEN));
                return;
            }
            // Example only: authentication is treated as successful unconditionally.
            // Real validation belongs here; the success event is fired before the handshake.
            SecurityCheckComplete complete = new SecurityCheckComplete(true);
            // Stored as a channel attribute so later handlers can read it as channel context.
            ctx.channel().attr(SECURITY_CHECK_COMPLETE_ATTRIBUTE_KEY).set(complete);
            ctx.fireUserEventTriggered(complete);

            final WebSocketServerHandshakerFactory wsFactory = new WebSocketServerHandshakerFactory(
                    getWebSocketLocation(ctx.pipeline(), req, websocketPath), subprotocols,
                    allowExtensions, maxFramePayloadSize, allowMaskMismatch);
            final WebSocketServerHandshaker handshaker = wsFactory.newHandshaker(req);
            if (handshaker == null) {
                WebSocketServerHandshakerFactory.sendUnsupportedVersionResponse(ctx.channel());
                return;
            }
            final ChannelFuture handshakeFuture = handshaker.handshake(
                    ctx.channel(), req, buildResponseHeaders(req), ctx.channel().newPromise());
            handshakeFuture.addListener(new ChannelFutureListener() {
                @Override
                public void operationComplete(ChannelFuture future) throws Exception {
                    if (!future.isSuccess()) {
                        ctx.fireExceptionCaught(future.cause());
                    } else {
                        // Kept for compatibility
                        ctx.fireUserEventTriggered(
                                CustomWebSocketServerProtocolHandler.ServerHandshakeStateEvent.HANDSHAKE_COMPLETE);
                        ctx.fireUserEventTriggered(
                                new CustomWebSocketServerProtocolHandler.HandshakeComplete(
                                        req.uri(), req.headers(), handshaker.selectedSubprotocol()));
                    }
                }
            });
            CustomWebSocketServerProtocolHandler.setHandshaker(ctx.channel(), handshaker);
            // From now on, further plain HTTP requests on this channel are answered with 403.
            ctx.pipeline().replace(this, "WS403Responder",
                    CustomWebSocketServerProtocolHandler.forbiddenHttpRequestResponder());
        } finally {
            req.release();
        }
    }

    /** Builds extra handshake response headers: echoes a non-blank {@code Sec-token} request header. */
    private static HttpHeaders buildResponseHeaders(HttpRequest req) {
        HttpHeaders headers = new DefaultHttpHeaders();
        String token = req.headers().get(TOKEN_HEADER);
        if (StringUtils.hasText(token)) {
            headers.add(TOKEN_HEADER, token);
        }
        return headers;
    }

    /**
     * Returns {@code true} when the request URI does not target the configured WebSocket path.
     *
     * <p>In exact mode the URI must equal the path. In starts-with mode the prefix must end at the
     * end of the URI or be followed by '/' or '?', so "/ws" no longer matches "/wsfoo". A path that
     * itself ends with '/' keeps plain prefix matching.
     */
    private boolean isNotWebSocketPath(FullHttpRequest req) {
        String uri = req.uri();
        if (!checkStartsWith) {
            return !uri.equals(websocketPath);
        }
        if (!uri.startsWith(websocketPath)) {
            return true;
        }
        if (uri.length() == websocketPath.length() || websocketPath.endsWith("/")) {
            return false;
        }
        char next = uri.charAt(websocketPath.length());
        return next != '/' && next != '?';
    }

    /** Writes {@code res} and closes the connection unless it is a keep-alive 200 response. */
    private static void sendHttpResponse(ChannelHandlerContext ctx, HttpRequest req, HttpResponse res) {
        ChannelFuture f = ctx.channel().writeAndFlush(res);
        if (!isKeepAlive(req) || res.status().code() != 200) {
            f.addListener(ChannelFutureListener.CLOSE);
        }
    }

    /** Builds the absolute ws:// (or wss:// when an SslHandler is present) URL from the Host header. */
    private static String getWebSocketLocation(ChannelPipeline cp, HttpRequest req, String path) {
        String protocol = "ws";
        if (cp.get(SslHandler.class) != null) {
            // SSL in use so use Secure WebSockets
            protocol = "wss";
        }
        String host = req.headers().get(HttpHeaderNames.HOST);
        return protocol + "://" + host + path;
    }

    /** User event fired (and stored as a channel attribute) when the security check succeeds. */
    @Getter
    @AllArgsConstructor
    public static final class SecurityCheckComplete {
        /** Whether the client is considered logged in. */
        private final Boolean isLogin;
    }
}
package com.chat.nettywebsocket.handler;
import com.alibaba.fastjson.JSONObject;
import com.chat.nettywebsocket.domain.Message;
import com.chat.nettywebsocket.handler.test.CustomWebSocketServerProtocolHandler;
import com.chat.nettywebsocket.handler.test.SecurityHandler;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelId;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.handler.codec.http.websocketx.TextWebSocketFrame;
import io.netty.util.AttributeKey;
import lombok.extern.slf4j.Slf4j;
import java.nio.charset.StandardCharsets;
/**
* 自定义控制器
* @author qubing
* @date 2021/8/16 9:26
*/
@Slf4j
public class ChatHandler extends SimpleChannelInboundHandler<TextWebSocketFrame> {
/**
* 为channel添加属性 将userid设置为属性,避免客户端特殊情况退出时获取不到userid
*/
AttributeKey<Integer> userid = AttributeKey.valueOf("userid");
/**
* 连接时
* @param ctx 上下文
* @throws Exception /
*/
@Override
public void channelActive(ChannelHandlerContext ctx) throws Exception {
log.info("与客户端建立连接,通道开启!");
// 添加到channelGroup通道组
MyChannelHandlerPool.channelGroup.add(ctx.channel());
}
/**
* 断开连接时
* @param ctx /
* @throws Exception /
*/
@Override
public void channelInactive(ChannelHandlerContext ctx) throws Exception {
log.info("与客户端断开连接,通道关闭!");
// 从channelGroup通道组移除
// MyChannelHandlerPool.channelGroup.remove(ctx.channel());
// Integer useridQuit = ctx.channel().attr(userid).get();
// MyChannelHandlerPool.channelIdMap.remove(useridQuit);
log.info("断开的用户id为");
}
// 监听事件
@Override
public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
// 自定义鉴权成功事件
if (evt instanceof SecurityHandler.SecurityCheckComplete){
// 鉴权成功后的逻辑
log.info("鉴权成功 SecurityHandler.SecurityCheckComplete");
}
// 握手成功
else if (evt instanceof CustomWebSocketServerProtocolHandler.HandshakeComplete) {
log.info("Handshake has completed");
// 握手成功后的逻辑 鉴权和不鉴权模式都绑定channel
log.info("Handshake has completed after binding channel");
}
super.userEventTriggered(ctx, evt);
}
/**
* 获取消息时
* @param ctx /
* @param msg 消息
* @throws Exception /
*/
@Override
protected void channelRead0(ChannelHandlerContext ctx, TextWebSocketFrame msg) throws Exception {
String mssage = msg.content().toString(StandardCharsets.UTF_8);
ctx.channel().writeAndFlush(mssage);
System.err.println(mssage);
}
/**
* 群发所有人
*/
private void sendAllMessage(String message){
//收到信息后,群发给所有channel
MyChannelHandlerPool.channelGroup.writeAndFlush( new TextWebSocketFrame(message));
}
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
log.info("exceptionCaught 异常:{}",cause.getMessage());
cause.printStackTrace();
Channel channel = ctx.channel();
//……
if(channel.isActive()){
log.info("手动关闭通道");
ctx.close();
};
}
}
/**
 * Shared {@link AttributeKey} constants for per-channel state. Not instantiable.
 */
public final class AttributeKeyUtils {

    /**
     * Channel attribute holding the user id. It is kept on the channel so the id is still
     * available if the client disconnects unexpectedly.
     *
     * <p>NOTE(review): AttributeKey.valueOf pools keys by name. ChatHandler declares the same
     * name "userid" typed as Integer, so mixing both declarations can cause a ClassCastException.
     * Consider unifying the value type.
     */
    public static final AttributeKey<String> USER_ID = AttributeKey.valueOf("userid");

    /** Channel attribute storing the security-check result that SecurityHandler sets before the handshake. */
    public static final AttributeKey<SecurityHandler.SecurityCheckComplete> SECURITY_CHECK_COMPLETE_ATTRIBUTE_KEY =
            AttributeKey.valueOf("SECURITY_CHECK_COMPLETE_ATTRIBUTE_KEY");

    private AttributeKeyUtils() {
        // constants holder; not meant to be instantiated
    }
}
@Slf4j
@ChannelHandler.Sharable
public class WsServerInitializer extends ChannelInitializer<SocketChannel> {
// @Override
// protected void initChannel(SocketChannel socketChannel) throws Exception {
// log.info("有新的连接");
// ChannelPipeline pipeline = socketChannel.pipeline();
// //netty 自带的http解码器
// pipeline.addLast(new HttpServerCodec());
// //http聚合器
// pipeline.addLast(new HttpObjectAggregator(8192));
// pipeline.addLast(new ChunkedWriteHandler());
// //压缩协议
// pipeline.addLast(new WebSocketServerCompressionHandler());
// //http处理器 用来握手和执行进一步操作
pipeline.addLast(new NettyWebsocketHttpHandler(config, listener));
//
// }
@Override
protected void initChannel(SocketChannel ch) throws Exception {
log.info("有新的连接");
//获取工人所要做的工程(管道器==管道器对应的便是管道channel)
ChannelPipeline pipeline = ch.pipeline();
//为工人的工程按顺序添加工序/材料 (为管道器设置对应的handler也就是控制器)
//1.设置心跳机制
pipeline.addLast(new IdleStateHandler(5,0,0, TimeUnit.SECONDS));
//2.出入站时的控制器,大部分用于针对心跳机制
pipeline.addLast(new WsChannelDupleHandler());
//3.加解码
pipeline.addLast(new HttpServerCodec());
//3.打印控制器,为工人提供明显可见的操作结果的样式
pipeline.addLast("logging", new LoggingHandler(LogLevel.INFO));
pipeline.addLast(new ChunkedWriteHandler());
pipeline.addLast(new HttpObjectAggregator(8192));
// 扩展的websocket协议
pipeline.addLast(new CustomWebSocketServerProtocolHandler(
"/ws","websocket",true,65536 * 10,false,true));
//7.自定义的handler针对业务
pipeline.addLast(new ChatHandler());
}
}
postman 连接成功。
从日志可以看出，连接成功，并且响应头和请求头是一致的。