/**
 * WebSocket configuration that maps the injected {@link TextWebSocketHandler}
 * to every path below {@code /forward/} and attaches the injected
 * {@link WebSocketInterceptor} to the handshake.
 */
@Configuration
@EnableWebSocket
public class WebSocketConfig implements WebSocketConfigurer {

    /** Path pattern the handler is registered under. */
    private static final String FORWARD_PATH_PATTERN = "/forward/**";

    /**
     * Value handed to {@code setAllowedOrigins}.
     * NOTE(review): "*" places no restriction on the calling origin — confirm
     * this is intended for the deployed environment.
     */
    private static final String ALLOWED_ORIGINS = "*";

    @Autowired
    private WebSocketInterceptor webSocketInterceptor;

    @Autowired
    private TextWebSocketHandler textWebSocketHandler;

    /**
     * Registers the text handler together with its handshake interceptor.
     *
     * @param registry registry the handler mapping is added to
     */
    @Override
    public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) {
        registry
                .addHandler(textWebSocketHandler, FORWARD_PATH_PATTERN)
                .setAllowedOrigins(ALLOWED_ORIGINS)
                .addInterceptors(webSocketInterceptor);
    }
}
/**
 * Handshake interceptor that only logs the start and end of each WebSocket
 * handshake; it never rejects a handshake and stores no attributes.
 */
@Component
@Slf4j
public class WebSocketInterceptor implements HandshakeInterceptor {

    /**
     * Logs the handshake attempt and lets it proceed.
     *
     * @param request    the handshake request
     * @param response   the handshake response
     * @param wsHandler  the handler the connection is destined for
     * @param attributes handshake attributes (left untouched here)
     * @return always {@code true}, i.e. the handshake continues
     */
    @Override
    public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response,
            WebSocketHandler wsHandler, Map<String, Object> attributes) throws Exception {
        log.info("websocket before hand.");
        // The former "request instanceof ServletServerHttpRequest" check returned
        // true on both paths, so it has been removed as dead code.
        return true;
    }

    /**
     * Logs that the handshake has completed.
     *
     * @param request   the handshake request
     * @param response  the handshake response
     * @param wsHandler the handler the connection is destined for
     * @param exception exception passed in by the framework; not inspected here
     */
    @Override
    public void afterHandshake(ServerHttpRequest request, ServerHttpResponse response,
            WebSocketHandler wsHandler, Exception exception) {
        log.info("websocket after hand.");
    }
}
@Slf4j
@Component
public class TextWebSocketHandler implements WebSocketHandler {
@Autowired
private WebSocketService webSocketService;
@Override
public void afterConnectionEstablished(WebSocketSession session) throws Exception {
log.info("socket path :{}", session.getUri().getPath());
String topic = matchTopicPath(session.getUri().getPath());
webSocketService.clientOnline(topic, session);
}
/*
* 根据路径匹配topic
*/
private String matchTopicPath(String path) {
return path;
}
@Override
public void handleMessage(WebSocketSession session, WebSocketMessage> message) throws Exception {
// every msg begin a trace instance
String msg = message.getPayload().toString();
log.info("receive message:{}", message);
String topic = matchTopicPath(session.getUri().getPath());
webSocketService.messageForward(topic, msg);
try {
if ("ping".equals(msg)) {
session.sendMessage(new TextMessage("pong"));
} else {
session.sendMessage(new TextMessage("ok"));
}
} catch (IOException ex) {
handleTransportError(session, ex);
}
}
@Override
public void handleTransportError(WebSocketSession session, Throwable exception) throws Exception {
log.error("websocket handleTransportError :", exception);
String topic = matchTopicPath(session.getUri().getPath());
webSocketService.clientOffLine(topic, session);
if (session.isOpen()) {
session.close();
}
}
@Override
public void afterConnectionClosed(WebSocketSession session, CloseStatus closeStatus) throws Exception {
log.error("websocket afterConnectionClosed :");
String topic = matchTopicPath(session.getUri().getPath());
webSocketService.clientOffLine(topic, session);
}
@Override
public boolean supportsPartialMessages() {
return false;
}
}
/**
* websocket 推送消息
*/
@Slf4j
@Service
public class WebSocketServiceImpl implements WebSocketService {
//topic 注册上的client
private static final ConcurrentHashMap> TOPIC_CLIENT_MAP = new ConcurrentHashMap<>();
//topic 上注册的回调方法
private static final ConcurrentHashMap>> CALL_BACK_FUN_MAP = new ConcurrentHashMap<>();
@Override
public void pushMessage(String topic, T data) {
String msg = JSON.toJSONString(data);
ConcurrentHashMap topicClient = TOPIC_CLIENT_MAP.get(topic);
if (topicClient == null || topicClient.size() == 0) {
log.error("[pushMessage] did not find websocket client,topic{}", topic);
return;
}
for (WebSocketSession session : topicClient.values()) {
try {
session.sendMessage(new TextMessage(msg));
} catch (IOException ex) {
log.error("[pushMessage] topic:{} throw:", topic, ex);
clientOffLine(topic, session);
}
}
}
@Override
public void clientOnline(String topic, WebSocketSession session) {
log.info("[clientOnline] client online. topic:{} id:{},ip:{}", topic, session.getId(),
session.getRemoteAddress());
ConcurrentHashMap topicClient = TOPIC_CLIENT_MAP.get(topic);
if (topicClient == null) {
TOPIC_CLIENT_MAP.putIfAbsent(topic, new ConcurrentHashMap<>());
}
TOPIC_CLIENT_MAP.get(topic).put(session.getId(), session);
}
@Override
public void clientOffLine(String topic, WebSocketSession session) {
log.info("[clientOffLine] client offline.topic:{}, id:{},ip:{}", topic, session.getId(),
session.getRemoteAddress());
ConcurrentHashMap topicClient = TOPIC_CLIENT_MAP.get(topic);
if (topicClient == null) {
TOPIC_CLIENT_MAP.putIfAbsent(topic, new ConcurrentHashMap<>());
}
TOPIC_CLIENT_MAP.get(topic).remove(session.getId());
}
@Override
public void registerHandleFun(String topic, String signature, Function messageHandle) {
ConcurrentHashMap> messageHandleMap = CALL_BACK_FUN_MAP.get(topic);
if (messageHandleMap == null) {
CALL_BACK_FUN_MAP.putIfAbsent(topic, new ConcurrentHashMap<>());
}
CALL_BACK_FUN_MAP.get(topic).put(signature, messageHandle);
}
@Async
@Override
public void messageForward(String topic, String message) {
log.info("receive websocket topic:{}, message:{}", topic, message);
ConcurrentHashMap> messageHandleMap = CALL_BACK_FUN_MAP.get(topic);
if (messageHandleMap == null) {
CALL_BACK_FUN_MAP.putIfAbsent(topic, new ConcurrentHashMap<>());
}
for (Entry> stringFunctionEntry : CALL_BACK_FUN_MAP.get(topic).entrySet()) {
boolean rs = stringFunctionEntry.getValue().apply(message);
log.info("topic:{},signature:{} message:{} handle rs:{}", topic, stringFunctionEntry.getKey(), message, rs);
}
}
}