Spring Boot 中添加原生 WebSocket 支持

1、添加配置

/**
 * Registers the application's WebSocket endpoint with Spring's WebSocket support.
 */
@Configuration
@EnableWebSocket
public class WebSocketConfig implements WebSocketConfigurer {

    /**
     * The container-managed handler bean. Registering this bean (instead of a
     * hand-made {@code new WebSocketHandler()}) means the instance that actually
     * serves connections has its injected fields populated.
     */
    private final WebSocketHandler webSocketHandler;

    /**
     * @param webSocketHandler the handler bean to expose on the WebSocket endpoint
     */
    public WebSocketConfig(WebSocketHandler webSocketHandler) {
        this.webSocketHandler = webSocketHandler;
    }

    /**
     * Maps the handler to its endpoint path and attaches the token-checking
     * handshake interceptor.
     *
     * @param registry registry supplied by Spring for adding handler mappings
     */
    @Override
    public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) {
        // All origins are allowed here; restrict the allowed origins in production.
        // The interceptor is still created directly, exactly as before.
        registry.addHandler(webSocketHandler, "/ws/[请求的地址]")
                .setAllowedOrigins("*")
                .addInterceptors(new WebSocketSecurityTokenInterceptor());
    }
}

2、添加Handler对请求进行处理

@Component
@Slf4j
public class WebSocketHandler extends TextWebSocketHandler {

    private static final CopyOnWriteArrayList sessions = new CopyOnWriteArrayList<>();

    private static final ScheduledExecutorService scheduledThreadPool = Executors.newScheduledThreadPool(Runtime.getRuntime().availableProcessors());

    private static final Map lastPongTimes = new ConcurrentHashMap<>();

    private static final String PING = "ping";

    private static final String PONG = "pong";

    private static final String GET_DATA = "getData";
    /**
     * 检测心跳是否正常的周期时间
     */
    private static final Integer heartbeatInterval = 30_000;

    /**
     * 检测客户端连接心跳保持时间是否超时的时间
     */
    private static final Integer heartbeatTimeout = 60_000;

    @Resource
    private ReportDashboardService reportDashboardService;
    // 使用Guava 弱引用缓存数据
    private static final Cache CACHE = CacheBuilder.newBuilder().softValues().expireAfterWrite(3, TimeUnit.SECONDS).build();

    @PostConstruct
    public void init() {
        scheduledThreadPool.scheduleWithFixedDelay(() -> {
            try {
                checkHeartbeats();
            } catch (IOException e) {
                throw new RuntimeException(e);
            }
        },heartbeatInterval,heartbeatInterval,TimeUnit.SECONDS);
    }

    @Override
    protected void handleTextMessage(WebSocketSession session, TextMessage message) throws Exception {
        // 广播消息给所有已连接的客户端
        String payload = message.getPayload();
        if (PING.equals(payload)) {
            session.sendMessage(new TextMessage(PONG));
        } else if (GET_DATA.equals(payload)) {
            sendData(session);
        } else {
            sendDataByPayload(session,payload);
            //broadcast(payload);
        }
        recordPong(session.getId());
    }

    @Override
    public void afterConnectionEstablished(WebSocketSession session) throws Exception {
        // 当新连接建立时添加到列表
        sessions.add(session);
        //session.sendMessage(new TextMessage(PONG));
        recordPong(session.getId());
        sendData(session);
    }

    @Override
    public void afterConnectionClosed(WebSocketSession session, CloseStatus status) throws Exception {
        // 当连接关闭时从列表中移除
        sessions.remove(session);
        log.info("Connection closed.sessionId={},status={}",session.getId(),status);
    }

    private void sendData(WebSocketSession sess) {
        try {
            if (sess.isOpen()) {
                String query = sess.getUri().getQuery();
                String[] split = query.split("&");
                ReqObj req = new ReqObj();
                String path = "";
                for (String s : split) {
                    String[] arr = s.split("=");
                    if ("path".equals(arr[0])) {
                        path = arr[1];
                    } else if ("deviceCode".equals(arr[0])) {
                        req.setDeviceCode(arr[1]);
                    } else if ("pointType".equals(arr[0])) {
                        req.setPointType(arr[1]);
                    } else if ("gap".equals(arr[0])) {
                        if (arr[1] != null) {
                            req.setGap(Integer.parseInt(arr[1]));
                        }
                    }
                }
                sess.sendMessage(new TextMessage(sendDataByPath(path,req)));
            }
        } catch (Exception e) {
            System.err.println("Failed to send message: " + e.getMessage());
        }
    }

    private boolean closeTimeoutSession(String sessionId) throws IOException {
        WebSocketSession s = null;
        for (WebSocketSession sess : sessions) {
            if (sess.isOpen() && sess.getId().equals(sessionId)) {
                sess.sendMessage(new TextMessage("当前连接1分钟内未发送心跳消息,即将关闭"));
                s = sess;
            }
        }
        log.info("关闭心跳超过的连接,sessionId={}",sessionId);
        return s != null && sessions.remove(s);
    }

    private void recordPong(String sessionId) {
        lastPongTimes.put(sessionId,System.currentTimeMillis());
    }

    private boolean isClientAlive(String sessionId) {
        Long lastPongTime = lastPongTimes.get(sessionId);
        if (lastPongTime == null){
            return false;
        }
        return System.currentTimeMillis() - lastPongTime <= heartbeatTimeout;
    }

    private void checkHeartbeats() throws IOException {
        log.info("开始检查连接的心跳是否超时......");
        Set> entries = lastPongTimes.entrySet();
        for (Map.Entry entry : entries) {
            String sessionId = entry.getKey();
            log.info("sessionId = {}",sessionId);
            if (!isClientAlive(sessionId)) {
                closeTimeoutSession(sessionId);
            }
        }
    }

    private void sendDataByPayload(WebSocketSession sess,String payload){
        try {
            if (sess.isOpen()) {
                ChartDto dto = null;
                String cacheKey = null;
                try {
                     cacheKey = MD5Util.encrypt(payload);
                    dto = JSON.parseObject(payload, ChartDto.class);
                } catch (Exception e) {
                    log.error("将payload转为ChartDto对象失败");
                }
                if (dto == null) {
                    JSONObject jsonObject = JSON.parseObject(payload);
                    String path = jsonObject.getString("path");
                    ReqObj req = new ReqObj();
                    req.setGap(jsonObject.getIntValue("gap",0));
                    req.setDeviceCode(jsonObject.getString("deviceCode"));
                    req.setPointType(jsonObject.getString("pointType"));
                    String s = sendDataByPath(path, req);
                    sess.sendMessage(new TextMessage("{\"path\":\""+path+"\",\"data\":"+s+"}"));
                } else {
                    if (reportDashboardService == null) {
                        reportDashboardService = SpringUtil.getBean(ReportDashboardService.class);
                    }
                    synchronized (Thread.currentThread()) {
                        Object data = CACHE.getIfPresent(cacheKey);
                        if (data == null) {
                            data = reportDashboardService.getChartData(dto);
                            if (data != null) {
                                CACHE.put(cacheKey,data);
                            }
                        }
                        String s = JSON.toJSONString(R.success(data,"success",dto.getId()));
                        sess.sendMessage(new TextMessage(s));
                    }
                }
            }
        } catch (Exception e) {
            System.err.println("Failed to send message: " + e.getMessage());
        }
    }

    private String sendDataByPath(String path,ReqObj req) {
        return "{}";
    }

    private void broadcast(String message) {
        for (WebSocketSession sess : sessions) {
            try {
                if (sess.isOpen()) {
                    String data = "原始数据:"+message+",翻转后的数据:"+new StringBuilder(message).reverse();
                    sess.sendMessage(new TextMessage(data));
                }
            } catch (Exception e) {
                System.err.println("Failed to send message: " + e.getMessage());
            }
        }
    }
}

3、添加拦截器,在握手时校验token

/**
 * Handshake interceptor that authenticates a WebSocket upgrade request by the token
 * carried in its {@code Authorization} header or in a cookie of the same name.
 */
@Getter
@Slf4j
@Component
public class WebSocketSecurityTokenInterceptor implements HandshakeInterceptor {

    private TokenAcquireHandler tokenAcquireHandler;

    private TokenAnalysisHandler tokenAnalysisHandler;

    {
        // Collaborators are looked up through SpringUtil, with a default implementation
        // passed as the fallback, so the class also works when created with "new".
        tokenAcquireHandler = SpringUtil.getOrDefault( TokenAcquireHandler.class, new DefaultTokenAcquireHandler() );
        tokenAnalysisHandler = SpringUtil.getOrDefault( TokenAnalysisHandler.class, new DefaultTokenAnalysisHandler() );
    }

    /**
     * Validates the token before the handshake proceeds and stores the resulting
     * security context.
     *
     * <p>NOTE(review): the context set here is never cleared in this class; if
     * {@code SecurityContextHandler} is thread-bound, confirm it is reset elsewhere.
     *
     * @param request    the upgrade request
     * @param response   the upgrade response
     * @param wsHandler  the target WebSocket handler
     * @param attributes attributes to associate with the WebSocket session (unused here)
     * @return {@code true} when the handshake may proceed; rejection is signalled by an exception
     * @throws Exception {@code TokenNotFoundException} when no token is present, or one of the
     *                   exceptions raised while validating the user details
     */
    @Override
    public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler, Map<String, Object> attributes) throws Exception {
        // Excluded paths pass straight through.
        if ( FilterContextHandler.getContext().isExclude() ) {
            // Keep a context that was set manually; only fill in EMPTY when none exists.
            if ( SecurityContextHandler.getContext() == null ) {
                SecurityContextHandler.setContext( SecurityContext.EMPTY );
            }
            return true;
        }
        String token = getToken(request);
        if ( !StringUtils.hasText( token ) ) {
            throw new TokenNotFoundException( "token not found" );
        }
        UserDetails userDetails = tokenAnalysisHandler.analysisToken( token );
        checkUserDetails( token, userDetails );
        SecurityContextHandler.setContext( new SecurityContext( token, userDetails ) );
        return true;
    }

    /**
     * Validates the user details resolved from the token.
     *
     * <p>NOTE(review): the last three checks throw when the {@code is...Non...} method
     * returns {@code true}. Under the usual "non-expired / non-locked means OK"
     * convention that would be inverted; confirm against the {@code UserDetails}
     * contract of this project before changing anything.
     *
     * @param token       the raw token
     * @param userDetails the details parsed from the token; must not be {@code null}
     */
    private void checkUserDetails( String token, UserDetails userDetails ) {
        // The parsed UserDetails must not be null.
        if ( userDetails == null ) {
            throw new TokenAnalysisException( "token analysis userDetails cannot be empty" );
        }
        // User must be enabled.
        if ( !userDetails.isEnabled() ) {
            throw new TokenAnalysisException();
        }
        // Account expiry check.
        if ( userDetails.isAccountNonExpired() ) {
            throw new UserDetailsExpiredException();
        }
        // Account lock check.
        if ( userDetails.isAccountNonLocked() ) {
            throw new UserLockException();
        }
        // Token expiry check.
        if ( userDetails.isCredentialsNonExpired( token ) ) {
            throw new TokenExpiredException();
        }
    }

    /**
     * No post-handshake work is needed.
     */
    @Override
    public void afterHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler, Exception exception) {
        // Intentionally empty.
    }

    /**
     * Extracts the token from the request: the first {@code Authorization} header
     * value if it has text, otherwise the value of the first cookie named
     * {@code Authorization}.
     *
     * @param req the upgrade request
     * @return the token, or {@code null} when neither source provides one
     */
    private String getToken( ServerHttpRequest req ) {
        List< String > headerList = req.getHeaders().get( HttpHeaders.AUTHORIZATION );
        String token = CollectionUtils.isEmpty( headerList ) ? "" : headerList.get( 0 );
        if ( StringUtils.hasText( token ) ) {
            return token;
        }
        List< String > cookieHeaders = req.getHeaders().get( HttpHeaders.COOKIE );
        if ( cookieHeaders == null ) {
            return null;
        }
        for ( String cookieStr : cookieHeaders ) {
            HttpCookie cookie = parseAuthCookie( cookieStr );
            if ( cookie != null ) {
                return cookie.getValue();
            }
        }
        return null;
    }

    /**
     * Finds the cookie named {@code Authorization} in a {@code Cookie} header value.
     *
     * @param cookieStr a {@code Cookie} header value with {@code ;}-separated pairs
     * @return the matching cookie, or {@code null} if there is none
     */
    private HttpCookie parseAuthCookie(String cookieStr) {
        if (!StringUtils.hasText(cookieStr)){
            return null;
        }
        for ( String part : cookieStr.split(";") ) {
            // Pairs are separated by "; ", so strip the surrounding whitespace before parsing.
            HttpCookie cookie = parseCookie( part.trim() );
            if ( cookie != null && HttpHeaders.AUTHORIZATION.equals( cookie.getName() ) ) {
                return cookie;
            }
        }
        return null;
    }

    /**
     * Parses a single {@code name=value} pair into a cookie.
     *
     * @param cookieStr the text of one cookie pair
     * @return the first parsed cookie, or {@code null} when the text cannot be parsed
     */
    private HttpCookie parseCookie(String cookieStr) {
        try {
            List<HttpCookie> cookies = HttpCookie.parse(cookieStr);
            return CollectionUtils.isEmpty(cookies) ? null : cookies.get(0);
        } catch (Exception ignored) {
            // A malformed pair is skipped so the remaining cookies can still be examined.
            return null;
        }
    }
}

你可能感兴趣的:(spring,boot,websocket,java)