直接贴代码，不多说了。
package com.tom.vv.config;
import cn.hutool.jwt.JWT;
import com.tom.utils.JsonUtil;
import com.tom.vv.auth.MyJwtUtil;
import com.tom.vv.dto.JwtUser;
import com.tom.vv.handler.JwtUserParamResolver;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.ObjectProvider;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Configuration;
import org.springframework.http.server.ServerHttpRequest;
import org.springframework.messaging.simp.SimpMessageSendingOperations;
import org.springframework.messaging.simp.config.MessageBrokerRegistry;
import org.springframework.scheduling.concurrent.ThreadPoolTaskScheduler;
import org.springframework.util.Assert;
import org.springframework.web.method.support.HandlerMethodArgumentResolver;
import org.springframework.web.servlet.mvc.method.annotation.RequestMappingHandlerAdapter;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.config.annotation.EnableWebSocketMessageBroker;
import org.springframework.web.socket.config.annotation.StompEndpointRegistry;
import org.springframework.web.socket.config.annotation.WebSocketMessageBrokerConfigurer;
import org.springframework.web.socket.config.annotation.WebSocketTransportRegistration;
import org.springframework.web.socket.handler.WebSocketHandlerDecorator;
import org.springframework.web.socket.server.support.DefaultHandshakeHandler;
import javax.annotation.PostConstruct;
import java.security.Principal;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
/**
 * STOMP-over-WebSocket configuration.
 *
 * <p>Registers the {@code /wse} endpoint, enables the in-memory simple broker for
 * {@code /topic/} destinations (with heartbeats), and decorates the WebSocket handler so that
 * each authenticated {@link JwtUser} keeps at most one live session. A connection from a
 * different client kicks the old session out. A duplicate connection from the same client
 * (same {@code token} handshake header) is closed, and the existing session is kept. The
 * principal is resolved during the handshake from the JWT carried in the {@code token}
 * request header.
 */
@Slf4j
@EnableWebSocketMessageBroker
@Configuration
@RequiredArgsConstructor
public class WebSocketConfig implements WebSocketMessageBrokerConfigurer {

    /** Handshake request header that carries the JWT. */
    private static final String TOKEN_HEADER = "token";

    /**
     * Messaging template, resolved lazily through an {@link ObjectProvider}. The lazy lookup is
     * presumably there because the template is created by this same broker configuration, which
     * would otherwise be a circular dependency. Confirm before injecting it directly.
     */
    @Autowired
    private ObjectProvider<SimpMessageSendingOperations> simpMessageSendingOperations;

    /**
     * Runs the delayed close of kicked-out sessions. This executor is not a Spring bean and is
     * never shut down, so it uses daemon threads to avoid keeping the JVM alive after the
     * application context has closed.
     */
    private final ScheduledExecutorService es = Executors.newScheduledThreadPool(4, runnable -> {
        Thread thread = new Thread(runnable, "ws-kick-out");
        thread.setDaemon(true);
        return thread;
    });

    /** Registers the {@code /wse} STOMP endpoint with the JWT-based handshake handler. */
    @Override
    public void registerStompEndpoints(StompEndpointRegistry registry) {
        // SockJS fallback is intentionally not enabled; append .withSockJS() if it is needed.
        registry.addEndpoint("/wse").setHandshakeHandler(new MyHandleShakeHandler());
    }

    /**
     * Enables the simple broker on {@code /topic/} with heartbeats and sets the user
     * destination prefix to {@code /user/}.
     */
    @Override
    public void configureMessageBroker(MessageBrokerRegistry registry) {
        // Heartbeats require a TaskScheduler. This one is created inline rather than as a bean,
        // so nothing shuts it down; daemon threads keep it from blocking JVM exit.
        ThreadPoolTaskScheduler pool = new ThreadPoolTaskScheduler();
        pool.setPoolSize(Runtime.getRuntime().availableProcessors());
        pool.setThreadNamePrefix("WsHeart");
        pool.setDaemon(true);
        pool.initialize();
        registry.enableSimpleBroker("/topic/")
                // {how often the server sends, how often the client is expected to send}, in ms.
                .setHeartbeatValue(new long[]{1000 * 60, 1000 * 30})
                .setTaskScheduler(pool);
        registry.setUserDestinationPrefix("/user/");
    }

    /** Wraps the STOMP WebSocket handler with {@link MyWebSocketHandler}. */
    @Override
    public void configureWebSocketTransport(WebSocketTransportRegistration registry) {
        registry.addDecoratorFactory(MyWebSocketHandler::new);
    }

    /**
     * Returns the first value of a header, or {@code null} if the header is absent or empty.
     *
     * @param values the header values, may be {@code null}
     * @return the first value, or {@code null}
     */
    private static String firstToken(List<String> values) {
        return values == null || values.isEmpty() ? null : values.get(0);
    }

    /**
     * Handler decorator that tracks the active session of each {@link JwtUser}, keyed by
     * {@link Principal#getName()}, and enforces a single live session per user.
     */
    class MyWebSocketHandler extends WebSocketHandlerDecorator {

        /** The currently registered session for each user name. */
        private final Map<String, WebSocketSession> sessions = new ConcurrentHashMap<>(128);

        public MyWebSocketHandler(WebSocketHandler delegate) {
            super(delegate);
        }

        /**
         * Registers the new session for its user and resolves any conflict with a session the
         * user already has. The new session is closed when it comes from the same client
         * (same token); otherwise the old session is kicked out.
         */
        @Override
        public void afterConnectionEstablished(WebSocketSession session) throws Exception {
            Principal p = session.getPrincipal();
            if (p instanceof JwtUser) {
                // Look the old session up instead of removing it, so that it stays registered
                // when it is the one we keep.
                WebSocketSession oldSession = sessions.get(p.getName());
                if (oldSession != null && !oldSession.getId().equals(session.getId())) {
                    log.info("User: {} exist old session: {}", p, oldSession);
                    if (isSameClient(oldSession, session)) {
                        // The client did not debounce its connect action and opened another
                        // connection with the same token: close the new one, keep the old one.
                        log.info("User: {} new session from same client, close newSession: {}, use old!", p, session.getId());
                        session.close();
                        return;
                    }
                    kickOut(p, oldSession);
                }
                sessions.put(p.getName(), session);
                log.info("User: {} connected! session: {}, current online users: {}", p, session.getId(), sessions.size());
            }
            super.afterConnectionEstablished(session);
        }

        /** Returns whether both sessions were opened with the same non-null token header. */
        private boolean isSameClient(WebSocketSession oldSession, WebSocketSession newSession) {
            String oldToken = firstToken(oldSession.getHandshakeHeaders().get(TOKEN_HEADER));
            String newToken = firstToken(newSession.getHandshakeHeaders().get(TOKEN_HEADER));
            return oldToken != null && oldToken.equals(newToken);
        }

        /**
         * Sends a kick-out notice to the user, when a messaging template is available, and
         * closes {@code oldSession} after a 2-second delay.
         */
        private void kickOut(Principal p, WebSocketSession oldSession) {
            SimpMessageSendingOperations sender = WebSocketConfig.this.simpMessageSendingOperations.getIfAvailable();
            if (sender != null) {
                sender.convertAndSendToUser(p.getName(), "/topic/kick-out", "当前帐号已在其他设备登录!");
            }
            // The delay gives the notice a chance to reach the old client before its socket is
            // closed. The close is scheduled even without a sender, so that the replaced
            // session never lingers open and untracked.
            es.schedule(() -> {
                try {
                    oldSession.close();
                    log.info("handle kick out user: {} sessionId: {}", ((JwtUser) p).getId(), oldSession.getId());
                } catch (Exception e) {
                    log.error("Handle kick out error", e);
                }
            }, 2, TimeUnit.SECONDS);
        }

        /**
         * Unregisters the session if it is still the one registered for its user. Errors are
         * logged rather than propagated, so the delegate is always notified.
         */
        @Override
        public void afterConnectionClosed(WebSocketSession session, CloseStatus closeStatus) throws Exception {
            try {
                Principal p = session.getPrincipal();
                if (p instanceof JwtUser) {
                    WebSocketSession current = sessions.get(p.getName());
                    // A newer session may already have replaced this one; only remove our own.
                    if (current != null && current.getId().equals(session.getId())
                            && sessions.remove(p.getName(), current)) {
                        log.info("User: {} disconnect!", p);
                    }
                }
            } catch (Exception e) {
                log.warn("Web socket disconnect error!", e);
            }
            super.afterConnectionClosed(session, closeStatus);
        }
    }

    /**
     * Handshake handler that resolves the principal from the JWT in the {@code token} request
     * header. It falls back to the default resolution when the header is missing or the token
     * cannot be parsed.
     */
    static class MyHandleShakeHandler extends DefaultHandshakeHandler {
        @Override
        protected Principal determineUser(ServerHttpRequest request, WebSocketHandler wsHandler, Map<String, Object> attributes) {
            if (log.isDebugEnabled()) {
                // NOTE(review): this dumps all handshake headers, including the raw token.
                log.debug("request handshake: {} remote: {}, headers: {}", request.getURI(), request.getRemoteAddress(), JsonUtil.toJsonStringQuit(request.getHeaders()));
            }
            String tk = firstToken(request.getHeaders().get(TOKEN_HEADER));
            if (tk != null) {
                JWT jwtAuthToken = MyJwtUtil.parseToken(tk);
                if (jwtAuthToken != null) {
                    return MyJwtUtil.extractJwtUser(jwtAuthToken);
                }
                // Do not write the token itself to the error log: it is a bearer credential.
                log.error("handshake token not parsable, remote: {}", request.getRemoteAddress());
            }
            return super.determineUser(request, wsHandler, attributes);
        }
    }
}