webflux的websocket连接与生命周期

1、配置入口:

import com.mti.handler.MessageHandler;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.core.Ordered;
import org.springframework.web.reactive.HandlerMapping;
import org.springframework.web.reactive.handler.SimpleUrlHandlerMapping;
import org.springframework.web.reactive.socket.WebSocketHandler;
import org.springframework.web.reactive.socket.server.support.WebSocketHandlerAdapter;

import java.util.HashMap;
import java.util.Map;
/**
 * WebFlux WebSocket configuration: maps the {@code /echo} path to the
 * {@link MessageHandler} and registers the adapter that lets the dispatcher
 * invoke {@link WebSocketHandler} beans.
 *
 * @author zhaoyj
 * @date 2019/3/12
 */
@Configuration
public class WebSocketConfiguration {

    /**
     * Builds the URL-to-handler mapping for WebSocket endpoints.
     *
     * <p>The handler is supplied as a {@code @Bean} method parameter, so no
     * explicit {@code @Autowired} is needed on the method.
     *
     * @param echoHandler handler that serves the {@code /echo} endpoint
     * @return a handler mapping registered with the highest precedence
     */
    @Bean
    public HandlerMapping webSocketMapping(final MessageHandler echoHandler) {
        final Map<String, WebSocketHandler> urlMap = new HashMap<>();
        urlMap.put("/echo", echoHandler);

        final SimpleUrlHandlerMapping mapping = new SimpleUrlHandlerMapping();
        // Highest precedence so this mapping is consulted before other handler mappings.
        mapping.setOrder(Ordered.HIGHEST_PRECEDENCE);
        mapping.setUrlMap(urlMap);
        return mapping;
    }

    /**
     * Registers the adapter required for handling {@link WebSocketHandler} beans.
     *
     * @return a new {@link WebSocketHandlerAdapter}
     */
    @Bean
    public WebSocketHandlerAdapter handlerAdapter() {
        return new WebSocketHandlerAdapter();
    }
}

2、配置Handler

import com.alibaba.fastjson.JSONObject;
import com.google.protobuf.InvalidProtocolBufferException;
import com.mti.configuration.Systemconfig;
import com.mti.enums.ReferenceMsgType;
import com.mti.exception.BusinessException;
import com.mti.handler.up.StreamReferenceReq;
import com.mti.proto.Linkproto;
import com.mti.vo.Message;
import com.mti.websocket.SocketSessionRegistry;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.core.io.buffer.DataBufferFactory;
import org.springframework.core.task.TaskExecutor;
import org.springframework.stereotype.Component;
import org.springframework.web.reactive.socket.WebSocketHandler;
import org.springframework.web.reactive.socket.WebSocketMessage;
import org.springframework.web.reactive.socket.WebSocketSession;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

import java.util.Optional;
import java.util.concurrent.ScheduledThreadPoolExecutor;
import java.util.concurrent.TimeUnit;

/**
 * MessageHandler class
 *
 * @author zhaoyj
 * @date 2019/3/12
 */
@Component
@Slf4j
public class MessageHandler implements WebSocketHandler {

    @Autowired
    private SocketSessionRegistry sessionRegistry;
    @Autowired
    private ScheduledThreadPoolExecutor executor;
    @Autowired
    private Systemconfig systemconfig;

    @Autowired
    DispatchFactory dispatchFactory;

    @Autowired
    TaskExecutor taskExecutor;

    @Override
    public Mono handle(WebSocketSession session) {

        return session.receive().doOnSubscribe(s -> {
            log.info("发起连接:{}",s);
            /**
             * 你有10秒时间登陆,不登陆就关掉连接;并且不给任何错误信息
             */
            if(systemconfig.getLoginInterval() != 0){
                executor.schedule(() -> sessionRegistry.checkAndRemove(session),systemconfig.getLoginInterval(),TimeUnit.SECONDS);
            }
            if(systemconfig.getPingInterval() != 0){
                executor.schedule(() ->  session.send(Flux.just(session.pingMessage(DataBufferFactory::allocateBuffer))).toProcessor(),systemconfig.getPingInterval(), TimeUnit.SECONDS);
            }
        }).doOnTerminate(() -> {
           sessionRegistry.unregisterSession(session);
            StreamReferenceReq req = (StreamReferenceReq) dispatchFactory.getCommand(ReferenceMsgType.SEND_VALUE);
            taskExecutor.execute(() -> Optional.ofNullable(req.removeSession(session)).ifPresent(list -> list.forEach(req::sendStopStreamConfig)));
           log.info("doOnTerminate");
        }).doOnComplete(() -> {
            log.info("doOnComplete");
        }).doOnCancel(() -> {
            log.info("doOnCancel");
        }).doOnNext(message -> {
            if(message.getType().equals(WebSocketMessage.Type.BINARY)){
                log.info("收到二进制消息");
                Linkproto.LinkCmd linkCmd = null;
                try {
                    linkCmd = Optional.ofNullable(Linkproto.LinkCmd.parseFrom(message.getPayload().asByteBuffer())).orElseThrow(() -> new BusinessException(500,"解析出错了"));
                    BaseDispatch dispatch = dispatchFactory.getCommand(linkCmd.getTypeValue());
                    log.info("处理session,{},消息实体,{},类型,{},dispatch:{}",session,linkCmd,linkCmd.getTypeValue(),dispatch);
                    dispatch.excuted(session, linkCmd);
                } catch (InvalidProtocolBufferException e) {
                    e.printStackTrace();
                }
            }else if(message.getType().equals(WebSocketMessage.Type.TEXT)){
                String content = message.getPayloadAsText();
                log.info("收到文本消息:{}",content);
                Message msg = null;
                try{
                    msg = JSONObject.parseObject(content, Message.class);
                }catch (Exception e){
                    JSONObject obj = new JSONObject();
                    obj.put("content","无法理解你发过来的消息内容,不予处理:"+content);
                    obj.put("msgType",Linkproto.LinkCmdType.LINK_CMD_ZERO_VALUE);
                    session.send(Flux.just(session.textMessage(obj.toJSONString()))).then().toProcessor();
                    log.error("解析消息内容出错");
                    return;
                }
                BaseDispatch dispatch = dispatchFactory.getCommand(msg.getMsgType());
                if(dispatch != null){
                    dispatch.executeMsg(session, msg);
                }
            }else if(message.getType().equals(WebSocketMessage.Type.PING)){
                session.send(Flux.just(session.pongMessage(s -> s.wrap(new byte[256]))));
                log.info("收到ping消息");
            }else if(message.getType().equals(WebSocketMessage.Type.PONG)){
                log.info("收到pong消息");
                if(systemconfig.getPingInterval() != 0){
                    executor.schedule(() ->  session.send(Flux.just(session.pingMessage(DataBufferFactory::allocateBuffer))).toProcessor(),systemconfig.getPingInterval(), TimeUnit.SECONDS);
                }
            }
        }).doOnError(e -> {
            e.printStackTrace();
            log.error("doOnError");
        }).doOnRequest(r -> {
            log.info("doOnRequest");
        }).then();
    }

这里展示的是从连接建立到连接断开的整个生命周期,并且可以区分二进制消息和文本消息。发送消息时一定要加上toProcessor()(或以其它方式订阅):session.send()返回的Mono是惰性的,只有被订阅之后才会真正执行发送,否则消息不会发出。

如果要发送消息到其它客户端,需要在后台将连接过来的session保存起来,根据用户名或者其它方式保存之后,获取到session进行发送:如下面这个SocketSessionRegistry类

import com.mti.enums.SocketCloseStatus;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;
import org.springframework.web.reactive.socket.WebSocketSession;

import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.CopyOnWriteArraySet;
import java.util.concurrent.CountDownLatch;

/**
 *用户session记录类
 *
 * @author zhaoyj
 * @date 2019/3/12
 */
@Service
@Slf4j
public class SocketSessionRegistry {

    /**
     * 这个集合存储session
     */
    private final ConcurrentMap> userSessionIds = new ConcurrentHashMap<>();

    private final ConcurrentMap clientInfoSessionIds = new ConcurrentHashMap<>();

    private final ConcurrentMap sessionIdUser = new ConcurrentHashMap<>();
    private ConcurrentMap cacheTimestamp = new ConcurrentHashMap<>();
    private final Object lock = new Object();


    /**
     * 获取sessionId
     *
     * @param user
     * @return
     */
    private Set getSessionIds(String user) {
        Set set = this.userSessionIds.get(user);
        return set != null ? set : Collections.emptySet();
    }

    /**
     * 获取用户session
     * @param user
     * @return
     */
    public Collection getSessionByUser(String user){
        Set sessionIds = Optional.ofNullable(getSessionIds(user)).orElse(new CopyOnWriteArraySet<>());
        List sessions = new ArrayList<>();
        for (String sessionId : sessionIds) {
            sessions.add(clientInfoSessionIds.get(sessionId));
        }
        return sessions;
    }

    /**
     * 获取用户session
     * @param users
     * @return
     */
    public Collection getSessionByUsers(Collection users){
        List sessions = new ArrayList<>();
        if(!CollectionUtils.isEmpty(users)){
            for (String user : users) {
                sessions.addAll(getSessionByUser(user));
            }
        }
        return sessions;
    }

    /**
     * 获取所有session
     * @return  Collection
     */
    public Collection getAllSessions(){
        return clientInfoSessionIds.values();
    }

    /**
     * 获取所有session
     *
     * @return
     */
    public ConcurrentMap> getAllSessionIds() {
        return this.userSessionIds;
    }
    /**
     * 获取所有session
     *
     * @return
     */
    public ConcurrentMap getAllSessionWebSocketInfos() {
        return this.clientInfoSessionIds;
    }
    /**
     * register session
     *
     * @param user
     * @param sessionId
     */
    private void registerSessionId(String user, String sessionId) {

        synchronized (this.lock) {
            Set set = this.userSessionIds.get(user);
            if (set == null) {
                set = new CopyOnWriteArraySet<>();
                this.userSessionIds.put(user, set);
            }
            set.add(sessionId);
        }
    }

    /**
     * 保存session
     * @param session WebSocketSession
     */
    public  void registerSession(WebSocketSession session,String user){
        if(StringUtils.isEmpty(user)){
            user = parseUserByURI(session).get("user");
        }
        if(!StringUtils.isEmpty(user)){
            String sessionId = session.getId();
            registerSessionId(user,sessionId);
            registerSessionId(session);
            sessionIdUser.putIfAbsent(sessionId,user);
        }
    }
    /**
     * 从session里面解析参数
     * @param session
     * @return
     */
    private Map parseUserByURI(WebSocketSession session){
        Map map = new HashMap<>();
        String[] params = Optional.ofNullable(session.getHandshakeInfo().getUri().getQuery()).orElse("").split("&");
        for (String param : params) {
            String[] temp = param.split("=");
            if(temp.length == 2){
                map.put(temp[0],temp[1]);
            }
        }
        return map;
    }
    public WebSocketSession getSessionBySessionId(String sessionId){
        return this.clientInfoSessionIds.get(sessionId);
    }
    private void registerSessionId(WebSocketSession websocketInfo) {
        String sessionId = websocketInfo.getId();
        CountDownLatch signal = cacheTimestamp.putIfAbsent(sessionId, new CountDownLatch(1));
        if (signal == null) {
            signal = cacheTimestamp.get(sessionId);
            try {
                if (!clientInfoSessionIds.containsKey(sessionId)) {
                    WebSocketSession set = this.clientInfoSessionIds.get(sessionId);
                    if (set == null) {
                        clientInfoSessionIds.putIfAbsent(sessionId, websocketInfo);
                    }
                }
            } finally {
                signal.countDown();
                cacheTimestamp.remove(sessionId);
            }
        } else {
            try {
                signal.await();
            } catch (InterruptedException e) {
                e.printStackTrace();
            }
        }
    }

    private void unregisterSessionId(String userName, String sessionId) {

        synchronized (this.lock) {
            Set set = this.userSessionIds.get(userName);
            if (set != null && set.remove(sessionId) && set.isEmpty()) {
                this.userSessionIds.remove(userName);
            }
        }
    }
    private void unregisterSessionId(String sessionId) {

        synchronized (this.lock) {
            WebSocketSession set = this.clientInfoSessionIds.get(sessionId);
            if (set != null) {
                this.clientInfoSessionIds.remove(sessionId);
            }
        }
    }

    public void unregisterSession(WebSocketSession session){
        String sessionId = session.getId();
        String user = sessionIdUser.get(sessionId);
        if(!StringUtils.isEmpty(user)){
            unregisterSessionId(sessionId);
            unregisterSessionId(user,sessionId);
            sessionIdUser.remove(sessionId);
        }
    }

    public void checkAndRemove(WebSocketSession session){
        String sessionId = session.getId();
        if(!this.clientInfoSessionIds.containsKey(sessionId)){
            log.info("sessionId:{} 10秒内没有登陆,关掉它",sessionId);
            session.close(SocketCloseStatus.UN_LOGIN.getCloseStatus()).toProcessor();
        }else{
            log.info("sessinId:{}已经登陆,是合法的",sessionId);
        }
    }
}
userSessionIds保存每个用户所属的sessionId集合:同一个用户可能在不同地方登陆,因此会有多个session。
clientInfoSessionIds保存sessionId到WebSocketSession的映射;要根据sessionId查到所属用户,则使用sessionIdUser。

这几周慢慢摸索出来的结果,网上资料很少,官网上的也不是很全,可能有不对的地方,在此做个记录!

你可能感兴趣的:(webflux的websocket连接与生命周期)