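These are my own ideas and my own implementation. If I got something wrong, or there is a simpler way to do this, feel free to message me. (The main goal is authentication during the handshake.)
- Handshake authentication is based on the Sec-WebSocket-Protocol request header sent by the front end
- The browser WebSocket API does not let you set custom request headers; the only thing you can customize is the subprotocol carried in Sec-WebSocket-Protocol
The WebSocket handshake is an ordinary HTTP request; once the handshake succeeds, the connection is upgraded to ws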
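To make the "HTTP first, then upgrade" step concrete, here is a minimal sketch of the one value the server must compute for the upgrade response: Sec-WebSocket-Accept, derived from the client's Sec-WebSocket-Key as described in RFC 6455. The class and method names (`HandshakeAccept`, `acceptFor`) are my own, not Netty's; Netty does this internally.

```java
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.Base64;

// Hypothetical helper, not part of Netty: computes Sec-WebSocket-Accept.
final class HandshakeAccept {
    // Fixed GUID defined by RFC 6455 for the opening handshake.
    private static final String WS_GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";

    static String acceptFor(String secWebSocketKey) {
        try {
            MessageDigest sha1 = MessageDigest.getInstance("SHA-1");
            // Concatenate key and GUID, hash with SHA-1, then Base64-encode the digest.
            byte[] digest = sha1.digest(
                    (secWebSocketKey + WS_GUID).getBytes(StandardCharsets.US_ASCII));
            return Base64.getEncoder().encodeToString(digest);
        } catch (NoSuchAlgorithmException e) {
            throw new IllegalStateException("SHA-1 is not available", e);
        }
    }

    public static void main(String[] args) {
        // Sample key taken from RFC 6455, section 1.3.
        System.out.println(acceptFor("dGhlIHNhbXBsZSBub25jZQ=="));
        // prints s3pPLMBiTxaQ9kYGzzhZRbK+xOo=
    }
}
```

The server sends this value back in a 101 Switching Protocols response; only after that do both sides speak WebSocket frames.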
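The front end sent the token as the value of Sec-WebSocket-Protocol, and the backend kept dropping the connection after receiving it. The blog posts I found online all said more or less the same thing, so I stepped through the Netty source myself (stubborn, haha) and finally spotted the cause: the Sec-WebSocket-Protocol values on the front end and the backend do not match, so the connection is closed. As far as I remember, with WebSocket you normally do not have to set this header yourself, but from reading the Netty source, the default handler does not seem to leave any hook for setting headers on the WebSocket handshake response (this is only my personal understanding)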
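A minimal sketch of the matching rule behind this mismatch, written without Netty so it can be run on its own. It mirrors how I read the subprotocol selection in Netty's WebSocketServerHandshaker (including the `*` wildcard); treat that as my reading of the source, and the names `SubprotocolSelection` and `select` as hypothetical.

```java
// Hypothetical, simplified model of server-side subprotocol negotiation.
final class SubprotocolSelection {

    /**
     * @param requestedHeader raw Sec-WebSocket-Protocol request value, e.g. "chat, my-token"
     * @param supported       subprotocols configured on the server
     * @return the selected subprotocol, or null when nothing matches
     */
    static String select(String requestedHeader, String... supported) {
        if (requestedHeader == null || supported.length == 0) {
            return null;
        }
        for (String part : requestedHeader.split(",")) {
            String requested = part.trim();
            for (String candidate : supported) {
                // "*" is treated as "accept whatever the client asked for".
                if ("*".equals(candidate) || requested.equals(candidate)) {
                    return requested;
                }
            }
        }
        // No match: the response carries no Sec-WebSocket-Protocol header.
        return null;
    }

    public static void main(String[] args) {
        // A token sent as the subprotocol never equals a fixed server-side name.
        System.out.println(select("my-token", "chat"));
        // With the wildcard, the requested value itself is selected.
        System.out.println(select("my-token", "*"));
    }
}
```

If this reading is right, a token sent as the subprotocol can only match when the server is configured with the wildcard or with the token itself, which is why the response came back without the header in my case.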
Explanation: write a custom replacement for WebSocketProtocolHandler by simply copying its contents. It is needed mainly because the custom WebSocketServerProtocolHandler below extends it
abstract class CustomWebSocketProtocolHandler extends MessageToMessageDecoder<WebSocketFrame> {
@Override
protected void decode(ChannelHandlerContext ctx, WebSocketFrame frame, List<Object> out) throws Exception {
if (frame instanceof PingWebSocketFrame) {
frame.content().retain();
ctx.channel().writeAndFlush(new PongWebSocketFrame(frame.content()));
return;
}
if (frame instanceof PongWebSocketFrame) {
// Pong frames need to get ignored
return;
}
out.add(frame.retain());
}
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
ctx.fireExceptionCaught(cause);
ctx.close();
}
}
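Explanation: write a custom WebSocketServerProtocolHandler that extends the custom WebSocketProtocolHandler above. Keep the body identical to WebSocketServerProtocolHandler; the only change needed is in handlerAdded, where the handshake handler class is swapped for your own
Note: the custom business handler that comes later and handles reads and writes must implement the relevant methods itself, such as exception handling and event listening. Take exceptions: if one is thrown, no other handler will deal with it, because the business handler is now the last one in the pipeline. The default implementation has been replaced with our own (the other handlers are built around the default handler; once you install your modified handler, yours is the last layer), so you have to close the channel manually
public class CustomWebSocketServerProtocolHandler extends CustomWebSocketProtocolHandler {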
/**
* Events that are fired to notify about handshake status
*/
public enum ServerHandshakeStateEvent {
/**
* The Handshake was completed successfully and the channel was upgraded to websockets.
*
* @deprecated in favor of {@link WebSocketServerProtocolHandler.HandshakeComplete} class,
* it provides extra information about the handshake
*/
@Deprecated
HANDSHAKE_COMPLETE
}
/**
* The Handshake was completed successfully and the channel was upgraded to websockets.
*/
public static final class HandshakeComplete {
private final String requestUri;
private final HttpHeaders requestHeaders;
private final String selectedSubprotocol;
public HandshakeComplete(String requestUri, HttpHeaders requestHeaders, String selectedSubprotocol) {
this.requestUri = requestUri;
this.requestHeaders = requestHeaders;
this.selectedSubprotocol = selectedSubprotocol;
}
public String requestUri() {
return requestUri;
}
public HttpHeaders requestHeaders() {
return requestHeaders;
}
public String selectedSubprotocol() {
return selectedSubprotocol;
}
}
private static final AttributeKey<WebSocketServerHandshaker> HANDSHAKER_ATTR_KEY =
AttributeKey.valueOf(WebSocketServerHandshaker.class, "HANDSHAKER");
private final String websocketPath;
private final String subprotocols;
private final boolean allowExtensions;
private final int maxFramePayloadLength;
private final boolean allowMaskMismatch;
private final boolean checkStartsWith;
public CustomWebSocketServerProtocolHandler(String websocketPath) {
this(websocketPath, null, false);
}
public CustomWebSocketServerProtocolHandler(String websocketPath, boolean checkStartsWith) {
this(websocketPath, null, false, 65536, false, checkStartsWith);
}
public CustomWebSocketServerProtocolHandler(String websocketPath, String subprotocols) {
this(websocketPath, subprotocols, false);
}
public CustomWebSocketServerProtocolHandler(String websocketPath, String subprotocols, boolean allowExtensions) {
this(websocketPath, subprotocols, allowExtensions, 65536);
}
public CustomWebSocketServerProtocolHandler(String websocketPath, String subprotocols,
boolean allowExtensions, int maxFrameSize) {
this(websocketPath, subprotocols, allowExtensions, maxFrameSize, false);
}
public CustomWebSocketServerProtocolHandler(String websocketPath, String subprotocols,
boolean allowExtensions, int maxFrameSize, boolean allowMaskMismatch) {
this(websocketPath, subprotocols, allowExtensions, maxFrameSize, allowMaskMismatch, false);
}
public CustomWebSocketServerProtocolHandler(String websocketPath, String subprotocols,
boolean allowExtensions, int maxFrameSize, boolean allowMaskMismatch, boolean checkStartsWith) {
this.websocketPath = websocketPath;
this.subprotocols = subprotocols;
this.allowExtensions = allowExtensions;
maxFramePayloadLength = maxFrameSize;
this.allowMaskMismatch = allowMaskMismatch;
this.checkStartsWith = checkStartsWith;
}
@Override
public void handlerAdded(ChannelHandlerContext ctx) {
ChannelPipeline cp = ctx.pipeline();
// The handshake handler (SecurityServerHandler) is added explicitly in initChannel,
// so nothing is added for it here. The copied check looked up this handler's own class,
// which is already in the pipeline at this point, so it could never be true.
if (cp.get(Utf8FrameValidator.class) == null) {
// Add the UTF8 checking before this one.
ctx.pipeline().addBefore(ctx.name(), Utf8FrameValidator.class.getName(),
new Utf8FrameValidator());
}
}
@Override
protected void decode(ChannelHandlerContext ctx, WebSocketFrame frame, List<Object> out) throws Exception {
if (frame instanceof CloseWebSocketFrame) {
WebSocketServerHandshaker handshaker = getHandshaker(ctx.channel());
if (handshaker != null) {
frame.retain();
handshaker.close(ctx.channel(), (CloseWebSocketFrame) frame);
} else {
ctx.writeAndFlush(Unpooled.EMPTY_BUFFER).addListener(ChannelFutureListener.CLOSE);
}
return;
}
super.decode(ctx, frame, out);
}
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
if (cause instanceof WebSocketHandshakeException) {
FullHttpResponse response = new DefaultFullHttpResponse(
HTTP_1_1, HttpResponseStatus.BAD_REQUEST, Unpooled.wrappedBuffer(cause.getMessage().getBytes()));
ctx.channel().writeAndFlush(response).addListener(ChannelFutureListener.CLOSE);
} else {
ctx.fireExceptionCaught(cause);
ctx.close();
}
}
static WebSocketServerHandshaker getHandshaker(Channel channel) {
return channel.attr(HANDSHAKER_ATTR_KEY).get();
}
public static void setHandshaker(Channel channel, WebSocketServerHandshaker handshaker) {
channel.attr(HANDSHAKER_ATTR_KEY).set(handshaker);
}
public static ChannelHandler forbiddenHttpRequestResponder() {
return new ChannelInboundHandlerAdapter() {
@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
if (msg instanceof FullHttpRequest) {
((FullHttpRequest) msg).release();
FullHttpResponse response =
new DefaultFullHttpResponse(HTTP_1_1, HttpResponseStatus.FORBIDDEN);
ctx.channel().writeAndFlush(response);
} else {
ctx.fireChannelRead(msg);
}
}
};
}
}
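Replace the default handler WebSocketServerProtocolHandshakeHandler with the custom inbound handler SecurityServerHandler
This step is the key one, because this is where the header gets set; the previous two steps only lay the groundwork. The Netty classes involved cannot be referenced from outside their package and offer no way to change this behavior, which is why the custom classes above exist. This class adjusts the handshake logic to add the handshake response header, and every reference to WebSocketServerProtocolHandler is changed to CustomWebSocketServerProtocolHandler. Any other implementation classes are changed the same way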
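Before the full class, the order of checks it performs can be sketched without Netty. This is only an illustration of the decision flow (path check, then authentication, then HTTP method, then handshake); `HandshakeGate`, `Outcome` and `decide` are names I made up for this sketch, and the token validator is passed in as a plain predicate.

```java
import java.util.function.Predicate;

// Hypothetical model of the checks done before the WebSocket handshake.
final class HandshakeGate {

    enum Outcome { PASS_THROUGH, UNAUTHORIZED, FORBIDDEN, HANDSHAKE }

    static Outcome decide(String uri, String method, String token,
                          String websocketPath, boolean authEnabled,
                          Predicate<String> tokenIsValid) {
        // Not the WebSocket endpoint: hand the request to the next handler.
        if (!uri.equals(websocketPath)) {
            return Outcome.PASS_THROUGH;
        }
        // Authentication enabled: a missing or invalid token is refused with 401.
        if (authEnabled && (token == null || !tokenIsValid.test(token))) {
            return Outcome.UNAUTHORIZED;
        }
        // The opening handshake must be a GET request.
        if (!"GET".equals(method)) {
            return Outcome.FORBIDDEN;
        }
        // Everything checks out: perform the handshake and echo the header back.
        return Outcome.HANDSHAKE;
    }

    public static void main(String[] args) {
        Predicate<String> validator = t -> t.equals("good-token");
        System.out.println(decide("/ws", "GET", "good-token", "/ws", true, validator));
        System.out.println(decide("/ws", "GET", null, "/ws", true, validator));
    }
}
```

Doing the token check before the handshake means an unauthenticated client never gets an upgraded connection at all.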
public class SecurityServerHandler extends ChannelInboundHandlerAdapter {
private final String websocketPath;
private final String subprotocols;
private final boolean allowExtensions;
private final int maxFramePayloadSize;
private final boolean allowMaskMismatch;
private final boolean checkStartsWith;
/**
* Custom property: name of the header that carries the token
*/
private final String tokenHeader;
/**
* Custom property: whether token authentication is enabled
*/
private final boolean hasToken;
public SecurityServerHandler(String websocketPath, String subprotocols,
boolean allowExtensions, int maxFrameSize, boolean allowMaskMismatch, String tokenHeader, boolean hasToken) {
this(websocketPath, subprotocols, allowExtensions, maxFrameSize, allowMaskMismatch, false,tokenHeader,hasToken);
}
SecurityServerHandler(String websocketPath, String subprotocols,
boolean allowExtensions, int maxFrameSize,
boolean allowMaskMismatch,
boolean checkStartsWith,
String tokenHeader,
boolean hasToken) {
this.websocketPath = websocketPath;
this.subprotocols = subprotocols;
this.allowExtensions = allowExtensions;
maxFramePayloadSize = maxFrameSize;
this.allowMaskMismatch = allowMaskMismatch;
this.checkStartsWith = checkStartsWith;
this.tokenHeader = tokenHeader;
this.hasToken = hasToken;
}
@Override
public void channelRead(final ChannelHandlerContext ctx, Object msg) throws Exception {
final FullHttpRequest req = (FullHttpRequest) msg;
if (isNotWebSocketPath(req)) {
ctx.fireChannelRead(msg);
return;
}
try {
// Authentication logic
HttpHeaders headers = req.headers();
// May be null when the client did not send the header
String token = headers.get(tokenHeader);
if(hasToken){
// Authentication enabled
//extracts device information headers
LoginUser loginUser = token == null ? null : SecurityUtils.getLoginUser(token);
if(null == loginUser){
refuseChannel(ctx);
return;
}
Long userId = loginUser.getUserId();
//check ......
SecurityCheckComplete complete = new SecurityCheckComplete(String.valueOf(userId),tokenHeader,hasToken);
ctx.channel().attr(SECURITY_CHECK_COMPLETE_ATTRIBUTE_KEY).set(complete);
ctx.fireUserEventTriggered(complete);
}else {
// Authentication disabled
SecurityCheckComplete complete = new SecurityCheckComplete(null,tokenHeader,hasToken);
ctx.channel().attr(SECURITY_CHECK_COMPLETE_ATTRIBUTE_KEY).set(complete);
}
if (req.method() != GET) {
sendHttpResponse(ctx, req, new DefaultFullHttpResponse(HTTP_1_1, FORBIDDEN));
return;
}
final WebSocketServerHandshakerFactory wsFactory = new WebSocketServerHandshakerFactory(
getWebSocketLocation(ctx.pipeline(), req, websocketPath), subprotocols,
allowExtensions, maxFramePayloadSize, allowMaskMismatch);
final WebSocketServerHandshaker handshaker = wsFactory.newHandshaker(req);
if (handshaker == null) {
WebSocketServerHandshakerFactory.sendUnsupportedVersionResponse(ctx.channel());
} else {
// Put the header into an HttpHeaders object; it is passed down to the Netty method
// that sets the response headers. The default implementation passes null here.
HttpHeaders httpHeaders = new DefaultHttpHeaders();
if (token != null) {
httpHeaders.add(tokenHeader, token);
}
// This is the key step that builds the handshake response headers
final ChannelFuture handshakeFuture = handshaker.handshake(ctx.channel(), req,httpHeaders,ctx.channel().newPromise());
handshakeFuture.addListener((ChannelFutureListener) future -> {
if (!future.isSuccess()) {
ctx.fireExceptionCaught(future.cause());
} else {
// Kept for compatibility
ctx.fireUserEventTriggered(
CustomWebSocketServerProtocolHandler.ServerHandshakeStateEvent.HANDSHAKE_COMPLETE);
ctx.fireUserEventTriggered(
new CustomWebSocketServerProtocolHandler.HandshakeComplete(
req.uri(), req.headers(), handshaker.selectedSubprotocol()));
}
});
CustomWebSocketServerProtocolHandler.setHandshaker(ctx.channel(), handshaker);
ctx.pipeline().replace(this, "WS403Responder",
CustomWebSocketServerProtocolHandler.forbiddenHttpRequestResponder());
}
}catch (Exception e){
// Do not swallow the failure: pass it down the pipeline so the channel gets closed
ctx.fireExceptionCaught(e);
}finally {
req.release();
}
}
public static final class HandshakeComplete {
private final String requestUri;
private final HttpHeaders requestHeaders;
private final String selectedSubprotocol;
HandshakeComplete(String requestUri, HttpHeaders requestHeaders, String selectedSubprotocol) {
this.requestUri = requestUri;
this.requestHeaders = requestHeaders;
this.selectedSubprotocol = selectedSubprotocol;
}
public String requestUri() {
return requestUri;
}
public HttpHeaders requestHeaders() {
return requestHeaders;
}
public String selectedSubprotocol() {
return selectedSubprotocol;
}
}
private boolean isNotWebSocketPath(FullHttpRequest req) {
return checkStartsWith ? !req.uri().startsWith(websocketPath) : !req.uri().equals(websocketPath);
}
private static void sendHttpResponse(ChannelHandlerContext ctx, HttpRequest req, HttpResponse res) {
ChannelFuture f = ctx.channel().writeAndFlush(res);
if (!isKeepAlive(req) || res.status().code() != 200) {
f.addListener(ChannelFutureListener.CLOSE);
}
}
private static String getWebSocketLocation(ChannelPipeline cp, HttpRequest req, String path) {
String protocol = "ws";
if (cp.get(SslHandler.class) != null) {
// SSL in use so use Secure WebSockets
protocol = "wss";
}
String host = req.headers().get(HttpHeaderNames.HOST);
return protocol + "://" + host + path;
}
private void refuseChannel(ChannelHandlerContext ctx) {
// Close only after the 401 response has been written
ctx.channel().writeAndFlush(new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.UNAUTHORIZED))
.addListener(ChannelFutureListener.CLOSE);
}
private static void send100Continue(ChannelHandlerContext ctx,String tokenHeader,String token) {
DefaultFullHttpResponse response = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.CONTINUE);
response.headers().set(tokenHeader,token);
ctx.writeAndFlush(response);
}
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
System.out.println("channel caught an exception and was closed");
super.exceptionCaught(ctx, cause);
}
@Getter
@AllArgsConstructor
public static final class SecurityCheckComplete {
private String userId;
private String tokenHeader;
private Boolean hasToken;
}
}
Any other classes have to be implemented or imported yourself; the rest are unimportant and need no changes
@Override
protected void initChannel(SocketChannel ch){
log.info("New connection");
// Get the work the worker has to do (the pipeline that belongs to this channel)
ChannelPipeline pipeline = ch.pipeline();
// Add the steps to the worker's job in order (i.e. register the handlers on the pipeline)
//1. Heartbeat (idle detection)
pipeline.addLast("idle-state",new IdleStateHandler(
nettyWebSocketProperties.getReaderIdleTime(),
0,
0,
TimeUnit.SECONDS));
//2. Inbound/outbound handler, mostly used for the heartbeat
pipeline.addLast("change-duple",new WsChannelDupleHandler(nettyWebSocketProperties.getReaderIdleTime()));
//3. HTTP encoding and decoding
pipeline.addLast("http-codec",new HttpServerCodec());
//4. Logging handler, so the result of each operation is clearly visible
pipeline.addLast("logging", new LoggingHandler(LogLevel.INFO));
pipeline.addLast("aggregator",new HttpObjectAggregator(8192));
// Use our own auth handler in place of the default handshake handler
pipeline.addLast("auth",new SecurityServerHandler(
// These values come from my YAML configuration; substitute your own
nettyWebSocketProperties.getWebsocketPath(),
nettyWebSocketProperties.getSubProtocols(),
nettyWebSocketProperties.getAllowExtensions(),
nettyWebSocketProperties.getMaxFrameSize(),
//todo
false,
nettyWebSocketProperties.getTokenHeader(),
nettyWebSocketProperties.getHasToken()
));
pipeline.addLast("http-chunked",new ChunkedWriteHandler());
// Use our own protocol handler in place of the default protocol handler
pipeline.addLast("websocket",
new CustomWebSocketServerProtocolHandler(
nettyWebSocketProperties.getWebsocketPath(),
nettyWebSocketProperties.getSubProtocols(),
nettyWebSocketProperties.getAllowExtensions(),
nettyWebSocketProperties.getMaxFrameSize())
);
//5. Custom business handler
pipeline.addLast("chat-handler",new ChatHandler());
}
/**
* @author qb
* @version 1.0
* @since 2023/3/7 11:56
*/
@Slf4j
public class ChatHandler extends SimpleChannelInboundHandler<TextWebSocketFrame> {
/**
* On connect
* @param ctx context
* @throws Exception /
*/
@Override
public void channelActive(ChannelHandlerContext ctx) throws Exception {
log.info("Client connected, channel opened!");
// Add to the channel group
ChannelHandlerPool.pool().addChannel(ctx.channel());
}
/**
* On disconnect
* @param ctx context
* @throws Exception /
*/
@Override
public void channelInactive(ChannelHandlerContext ctx) throws Exception {
log.info("Client disconnected, channel closed!");
// Remove from the channel group
ChannelHandlerPool.pool().removeChannel(ctx.channel());
// The attribute is only set once the security check has run, so guard against null
SecurityServerHandler.SecurityCheckComplete complete = ctx.channel().attr(SECURITY_CHECK_COMPLETE_ATTRIBUTE_KEY).get();
String useridQuit = complete == null ? null : complete.getUserId();
ChannelHandlerPool.pool().removeChannelId(useridQuit);
log.info("Disconnected user id: {}",useridQuit);
}
/**
* On message received
* @param ctx context
* @param msg message
*/
@Override
protected void channelRead0(ChannelHandlerContext ctx, TextWebSocketFrame msg) {
log.info("msg.text():{}",msg.text());
Message message = JSON.parseObject(msg.text(),Message.class);
if("0".equals(message.getType())){
log.info("Message type is channel binding, userId:{}",message.getFromUserId());
Boolean hasToken = ctx.channel().attr(SECURITY_CHECK_COMPLETE_ATTRIBUTE_KEY).get().getHasToken();
log.info("hasToken: {}",hasToken);
if (!hasToken){
log.info("binding channel...");
// Without authentication, bind the channel through a message instead
binding(ctx,message);
}
}else{
if(StringUtils.isNotBlank(message.getToUserId())){
// Private message
sendMsg(message);
}else{
// Send to everyone except the sender
sendOther(ctx,message);
}
}
}
/**
* Callback when the handler is added to a channel
* @param ctx /
*/
@Override
public void handlerAdded(ChannelHandlerContext ctx) {
// Print the channel's unique id; asLongText returns the full form of the channel id
log.info("handlerAdded :{}",ctx.channel().id().asLongText());
}
/**
* Callback when the handler is removed from a channel
* @param ctx /
*/
@Override
public void handlerRemoved(ChannelHandlerContext ctx) {
log.info("handlerRemoved :{}",ctx.channel().id().asLongText());
}
/**
* Event listener
* @param ctx /
* @param evt /
* @throws Exception /
*/
@Override
public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
if (evt instanceof SecurityServerHandler.SecurityCheckComplete){
log.info("Security check has passed");
// Logic after successful authentication; nothing added for now
}
else if (evt instanceof CustomWebSocketServerProtocolHandler.HandshakeComplete) {
log.info("Handshake has completed");
SecurityServerHandler.SecurityCheckComplete securityCheckComplete = ctx.channel().attr(SECURITY_CHECK_COMPLETE_ATTRIBUTE_KEY).get();
Boolean hasToken = securityCheckComplete.getHasToken();
log.info("Handshake has completed after check hasToken:{}",hasToken);
// Logic after a successful handshake: if authentication ran, bind the channel
if(hasToken){
log.info("Handshake has completed after binding channel");
binding(ctx,securityCheckComplete.getUserId());
}
}
super.userEventTriggered(ctx, evt);
}
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
log.info("exceptionCaught: {}",cause.getMessage());
cause.printStackTrace();
Channel channel = ctx.channel();
//……
if(channel.isActive()){
log.info("Closing the channel manually");
ctx.close();
}
}
/**
* Broadcast to everyone
* @param message message
*/
private void sendAllMessage(Message message){
// After receiving a message, broadcast it to every channel
ChannelHandlerPool.pool().writeAndFlush(message);
}
/**
* Send a message
* @param ctx context
* @param message message
*/
private void sendMsg(ChannelHandlerContext ctx, Message message){
// Send the message back to the sender's own channel
ChannelHandlerPool.pool().writeAndFlush(ctx.channel().id(),message);
}
/**
* Bind a channel to a userid
* @param ctx context
* @param message message
*/
public void binding(ChannelHandlerContext ctx,Message message){
ChannelId channelId = ctx.channel().id();
Channel channel = ChannelHandlerPool.pool().getChannel(channelId);
// Check whether this channel is already registered; if not, add it again
if(null == channel){
ChannelHandlerPool.pool().addChannel(ctx.channel());
}
try {
// Bind the userid to the channel
ChannelHandlerPool.pool().putChannelId(message.getFromUserId(), channelId);
}catch (Exception e){
log.info("Closing the connection");
e.printStackTrace();
// Close the connection when an exception occurs
ctx.close();
}
}
/**
* Bind a channel to a userid
* @param ctx context
* @param userId user id
*/
public void binding(ChannelHandlerContext ctx,String userId){
ChannelId channelId = ctx.channel().id();
Channel channel = ChannelHandlerPool.pool().getChannel(channelId);
// Check whether this channel is already registered; if not, add it again
if(null == channel){
ChannelHandlerPool.pool().addChannel(ctx.channel());
}
try {
SecurityServerHandler.SecurityCheckComplete complete = ctx.channel().attr(SECURITY_CHECK_COMPLETE_ATTRIBUTE_KEY).get();
// Bind the userid to the channel
ChannelHandlerPool.pool().putChannelId(complete.getHasToken() ? complete.getUserId() : userId, channelId);
}catch (Exception e){
log.info("Closing the connection");
e.printStackTrace();
// Close the connection when an exception occurs
ctx.close();
}
}
public void sendMsg(Message message){
// Private message
ChannelId channelId = ChannelHandlerPool.pool().getChannelId(message.getToUserId());
if(null == channelId){
log.info("User {} is offline!",message.getToUserId());
// Offline handling: persist the message
return;
}
Channel channel = ChannelHandlerPool.pool().getChannel(channelId);
if(null == channel){
log.info("Clearing the channelId stored in the map for user {}",message.getToUserId());
// Abnormal disconnect: the two static collections are out of sync, so drop the stale mapping of the recipient
ChannelHandlerPool.pool().removeChannelId(message.getToUserId());
return;
}
ChannelHandlerPool.pool().writeAndFlush(channel,message);
log.info("userId stored on the channel: {}",channel.attr(SECURITY_CHECK_COMPLETE_ATTRIBUTE_KEY).get().getUserId());
}
/**
* Send to everyone except the sender
* @param ctx context
* @param message message
*/
public void sendOther(ChannelHandlerContext ctx, Message message){
for (Channel channel : ChannelHandlerPool.pool().getChannelGroup()) {
// Send to everyone but the sender
if(channel != ctx.channel()){
log.info("Sending message: {}",message);
channel.writeAndFlush(new TextWebSocketFrame(JSONObject.toJSONString(message)));
}
}
}
}
public class AttributeKeyUtils {
/**
* Attribute added to the channel: the userid is stored as an attribute so it is still available when the client exits abnormally
*/
public static final AttributeKey<String> USER_ID = AttributeKey.valueOf("userid");
public static final AttributeKey<SecurityServerHandler.SecurityCheckComplete> SECURITY_CHECK_COMPLETE_ATTRIBUTE_KEY =
AttributeKey.valueOf("SECURITY_CHECK_COMPLETE_ATTRIBUTE_KEY");
}
package com.edu.message.handler.pool;
import com.alibaba.fastjson2.JSONObject;
import com.edu.common.utils.spring.SpringUtils;
import com.edu.message.converter.MessageConverter;
import com.edu.message.domain.vo.Message;
import com.edu.message.service.IMessageSocketService;
import io.netty.channel.Channel;
import io.netty.channel.ChannelId;
import io.netty.channel.group.ChannelGroup;
import io.netty.channel.group.DefaultChannelGroup;
import io.netty.handler.codec.http.websocketx.TextWebSocketFrame;
import io.netty.util.concurrent.GlobalEventExecutor;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import static com.edu.common.constant.ThreadPoolConstants.POOL_NAME;
/**
* Stores users and their related information
* @author qb
* @version 1.0
* @since 2023/3/7 11:53
*/
@Slf4j
public class ChannelHandlerPool {
private final IMessageSocketService iSocketService = SpringUtils.getBean(IMessageSocketService.class);
private final MessageConverter messageConverter = SpringUtils.getBean(MessageConverter.class);
private ChannelHandlerPool(){}
private ThreadPoolTaskExecutor executor = SpringUtils.getBean(POOL_NAME);
private static final ChannelHandlerPool POOL = new ChannelHandlerPool();
public static ChannelHandlerPool pool(){
return POOL;
}
public void execute(Message message){
executor.submit(() -> {
String name = Thread.currentThread().getName();
Long id = Thread.currentThread().getId();
log.info("message:{},线程名称:{},id:{}",message,name,id);
iSocketService.save(messageConverter.messageToMsgSocket(message));
});
}
/**
*
* map: userId,ChannelId
*/
private static final Map<String, ChannelId> CHANNEL_ID_MAP = new ConcurrentHashMap<>();
/**
* Channel group
*/
private static final ChannelGroup CHANNEL_GROUP = new DefaultChannelGroup(GlobalEventExecutor.INSTANCE);
public Map<String, ChannelId> getChannelIdMap(){
return CHANNEL_ID_MAP;
}
public ChannelGroup getChannelGroup(){
return CHANNEL_GROUP;
}
/**
* Get the channelId
* @param userId user id
* @return /
*/
public ChannelId getChannelId(String userId){
return CHANNEL_ID_MAP.get(userId);
}
/**
* Get the channel
* @param channelId /
* @return /
*/
public Channel getChannel(ChannelId channelId){
return CHANNEL_GROUP.find(channelId);
}
public Channel getChannel(String userId){
// No mapping for this user: return null instead of looking up a null id
ChannelId channelId = getChannelId(userId);
return channelId == null ? null : CHANNEL_GROUP.find(channelId);
}
public void removeChannelId(String userid){
if(StringUtils.isNotBlank(userid)){
CHANNEL_ID_MAP.remove(userid);
}
}
public void removeChannel(Channel channel){
CHANNEL_GROUP.remove(channel);
}
public void addChannel(Channel channel){
CHANNEL_GROUP.add(channel);
}
/**
* Broadcast
* @param message message content
*/
public void writeAndFlush(Message message){
saveBase(message);
CHANNEL_GROUP.writeAndFlush( new TextWebSocketFrame(JSONObject.toJSONString(message)));
}
/**
* Private message
* @param channel channel
* @param message content
*/
public void writeAndFlush(Channel channel,Message message){
saveBase(message);
channel.writeAndFlush( new TextWebSocketFrame(JSONObject.toJSONString(message)));
}
/**
* Private message
* @param channelId channelId
* @param message message
*/
public void writeAndFlush(ChannelId channelId,Message message){
saveBase(message);
Channel channel = findChannel(channelId);
// The channel may already have left the group
if(channel != null){
channel.writeAndFlush( new TextWebSocketFrame(JSONObject.toJSONString(message)));
}
}
public Channel findChannel(ChannelId channelId){
return CHANNEL_GROUP.find(channelId);
}
public void putChannelId(String userid,ChannelId channelId){
CHANNEL_ID_MAP.put(userid,channelId);
}
private void saveBase(Message message){
ChannelHandlerPool.pool().execute(message);
}
}
For comparison, change the auth handler so that it only parses the custom request header, without replacing any other handler
package com.edu.message.handler.security;
import com.edu.common.utils.SecurityUtils;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.handler.codec.http.FullHttpMessage;
import io.netty.handler.codec.http.HttpHeaders;
import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.extern.slf4j.Slf4j;
import java.util.Objects;
import static com.edu.message.handler.attributeKey.AttributeKeyUtils.SECURITY_CHECK_COMPLETE_ATTRIBUTE_KEY;
/**
* @author Administrator
*/
@Slf4j
public class SecurityServerHandler extends ChannelInboundHandlerAdapter {
private String tokenHeader;
private Boolean hasToken;
public SecurityServerHandler(String tokenHeader,Boolean hasToken){
this.tokenHeader = tokenHeader;
this.hasToken = hasToken;
}
private SecurityServerHandler(){}
@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
if(msg instanceof FullHttpMessage){
FullHttpMessage httpMessage = (FullHttpMessage) msg;
HttpHeaders headers = httpMessage.headers();
// May be null when the client did not send the header
String token = headers.get(tokenHeader);
if(hasToken){
// Authentication enabled
if(token == null){
// No token: release the request and close instead of throwing a NullPointerException
httpMessage.release();
ctx.close();
return;
}
//extracts device information headers
Long userId = 12345L;//SecurityUtils.getLoginUser(token).getUserId();
//check ......
SecurityCheckComplete complete = new SecurityCheckComplete(String.valueOf(userId),tokenHeader,hasToken);
ctx.channel().attr(SECURITY_CHECK_COMPLETE_ATTRIBUTE_KEY).set(complete);
ctx.fireUserEventTriggered(complete);
}else {
// Authentication disabled
SecurityCheckComplete complete = new SecurityCheckComplete(null,tokenHeader,hasToken);
ctx.channel().attr(SECURITY_CHECK_COMPLETE_ATTRIBUTE_KEY).set(complete);
}
}
//other protocols
super.channelRead(ctx, msg);
}
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
System.out.println("channel caught an exception and was closed");
super.exceptionCaught(ctx, cause);
}
@Getter
@AllArgsConstructor
public static final class SecurityCheckComplete {
private String userId;
private String tokenHeader;
private Boolean hasToken;
}
}
Switch the pipeline back to the default implementation
@Override
protected void initChannel(SocketChannel ch){
log.info("New connection");
// Get the work the worker has to do (the pipeline that belongs to this channel)
ChannelPipeline pipeline = ch.pipeline();
// Add the steps to the worker's job in order (i.e. register the handlers on the pipeline)
//1. Heartbeat (idle detection)
pipeline.addLast("idle-state",new IdleStateHandler(
nettyWebSocketProperties.getReaderIdleTime(),
0,
0,
TimeUnit.SECONDS));
//2. Inbound/outbound handler, mostly used for the heartbeat
pipeline.addLast("change-duple",new WsChannelDupleHandler(nettyWebSocketProperties.getReaderIdleTime()));
//3. HTTP encoding and decoding
pipeline.addLast("http-codec",new HttpServerCodec());
//4. Logging handler, so the result of each operation is clearly visible
pipeline.addLast("logging", new LoggingHandler(LogLevel.INFO));
pipeline.addLast("aggregator",new HttpObjectAggregator(8192));
pipeline.addLast("auth",new SecurityServerHandler(
nettyWebSocketProperties.getTokenHeader(),
nettyWebSocketProperties.getHasToken()
));
pipeline.addLast("http-chunked",new ChunkedWriteHandler());
// pipeline.addLast("websocket",
// new CustomWebSocketServerProtocolHandler(
// nettyWebSocketProperties.getWebsocketPath(),
// nettyWebSocketProperties.getSubProtocols(),
// nettyWebSocketProperties.getAllowExtensions(),
// nettyWebSocketProperties.getMaxFrameSize())
// );
pipeline.addLast("websocket",
new WebSocketServerProtocolHandler(
nettyWebSocketProperties.getWebsocketPath(),
nettyWebSocketProperties.getSubProtocols(),
nettyWebSocketProperties.getAllowExtensions(),
nettyWebSocketProperties.getMaxFrameSize())
);
//5. Custom business handler
pipeline.addLast("chat-handler",new ChatHandler());
}
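Step one: execution reaches my custom authentication handler (an inbound handler) and runs its channelRead method
Next it reaches the channelRead method of the default handshake handler. Pay attention to handshaker.handshake(ctx.channel(), req): this is the method that performs the handshake, so set a breakpoint and step into it
You can see that handshake passes null for the HttpHeaders argument. This is the core handshake logic, and it provides no hook for supplying response headers
newHandshakeResponse(req, responseHeaders) builds the response, and you can see that the headers argument is null
Then control comes back to the event listener method (userEventTriggered) of my custom business handler
Once you let this step continue, the response headers are printed to the console; our own response header has not been set and is still null
Finally the response is returned and the connection is dropped, because the subprotocol header does not match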
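The client side of that last step can be sketched as follows. This is my understanding of how browsers validate the handshake response (a requested subprotocol must be echoed back, and the echoed value must be one that was requested), not code taken from any browser; `ClientSubprotocolCheck` and `accepts` are hypothetical names.

```java
import java.util.List;

// Hypothetical model of the client-side check on the handshake response.
final class ClientSubprotocolCheck {

    /**
     * @param requested        subprotocols the client sent in Sec-WebSocket-Protocol
     * @param responseProtocol value of Sec-WebSocket-Protocol in the response, or null if absent
     * @return true if the client keeps the connection open
     */
    static boolean accepts(List<String> requested, String responseProtocol) {
        if (responseProtocol == null) {
            // Nothing echoed back: only fine when the client asked for nothing.
            return requested.isEmpty();
        }
        // The server must pick one of the values the client actually offered.
        return requested.contains(responseProtocol);
    }

    public static void main(String[] args) {
        // Default handshake: token sent, header missing in the response.
        System.out.println(accepts(List.of("my-token"), null));       // prints false
        // Custom handshake: the same value is echoed back.
        System.out.println(accepts(List.of("my-token"), "my-token")); // prints true
    }
}
```

This is why echoing the token back in the handshake response, as the custom SecurityServerHandler does, keeps the connection open.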