集成WebSocket+SpringMVC+dubbo的过程中遇到的问题

集成WebSocket的过程中遇到的问题

项目中使用了 Dubbo 和 Spring MVC，并使用 Spring Session 管理 session，但在集成 WebSocket 时遇到了以下两个问题：
1、无法获取到登录用户
2、无法获取到dubbo的提供者
下面分别介绍这两个问题的解决思路，文章最后附上最终代码。

无法获取到登录用户的问题

解决方法

  1. 编写 GetHttpSessionConfigurator 类，继承 ServerEndpointConfig.Configurator，并重写 modifyHandshake 方法。
    在 modifyHandshake 方法中获取登录用户，然后放入 ServerEndpointConfig 的 userProperties（config.getUserProperties()）中。

//import com.test.WebUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import javax.websocket.HandshakeResponse;
import javax.websocket.server.HandshakeRequest;
import javax.websocket.server.ServerEndpointConfig;

/**
 * Configurator that passes the logged-in user into the WebSocket endpoint.
 *
 * <p>The user id is stored under the {@code "userId"} key of
 * {@link ServerEndpointConfig#getUserProperties()}; the endpoint's {@code @OnOpen}
 * method reads it back from its {@code EndpointConfig}.
 *
 * @author frank
 * @date 2018/11/14 12:36
 */
public class GetHttpSessionConfigurator extends ServerEndpointConfig.Configurator {
    protected final Logger logger = LoggerFactory.getLogger(this.getClass());

    /**
     * Resolves the current user during the handshake and publishes it as the
     * {@code "userId"} user property.
     *
     * <p>If resolving the user fails, the {@code "userId"} property is removed, so the
     * endpoint sees {@code null} instead of a value left over from an earlier handshake.
     *
     * @param config   endpoint configuration whose user properties receive the user id
     * @param request  the handshake request (not read by the current placeholder code)
     * @param response the handshake response (not used)
     */
    @Override
    public void modifyHandshake(ServerEndpointConfig config, HandshakeRequest request, HandshakeResponse response) {
        String userId = "";
        try {
            // TODO: obtain the userId from the session yourself, e.g.
            // userId = WebUtil.getCurrentUser();
            logger.info("GetHttpSessionConfigurator:userId{}", userId);
        } catch (Exception e) {
            logger.warn("没有权限", e);
            // NOTE(review): JSR 356 does not guarantee this map is per-connection in every
            // container, so clear any stale user instead of silently keeping it.
            config.getUserProperties().remove("userId");
            return;
        }
        config.getUserProperties().put("userId", userId);
    }
}
  2. 在 @ServerEndpoint 注解中配置 configurator = GetHttpSessionConfigurator.class：
@ServerEndpoint(value="/web-socket/get-message-count", configurator = GetHttpSessionConfigurator.class)
  3. 在 WebSocketServer 的 @OnOpen 方法中增加参数 EndpointConfig config，通过 config.getUserProperties().get("userId") 读取登录用户信息：
/**
 * Registers this connection for the user published by GetHttpSessionConfigurator
 * and immediately sends the current unread count as {@code {"count": n}}.
 */
@OnOpen
public void onOpen(Session session, EndpointConfig config) throws IOException {
    System.out.println("onOpen======================");
    this.currentSession = session;
    // Read the userId that the configurator stored during the handshake.
    String userId = (String) config.getUserProperties().get("userId");
    System.out.println("userId======================"+userId);
    if (userId == null) {
        // No user was published (e.g. the configurator's "no permission" path).
        // serverMap is a ConcurrentHashMap, which rejects null keys, so reject the
        // connection explicitly instead of failing with a NullPointerException.
        session.close(new CloseReason(CloseReason.CloseCodes.VIOLATED_POLICY, "no login user"));
        return;
    }
    // Cache this connection so the count can be pushed to it later.
    serverMap.put(userId, this);
    Map<String,Object> map = new HashMap<>();
    map.put("count",countUnreadMessageByUserId(userId));
    session.getBasicRemote().sendText(JSON.toJSONString(map));
}

无法获取到dubbo的提供者

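经过上面的配置后，发现 WebSocketServer 类中无法注入 Dubbo 的 Provider。
参考其他博客后发现，配置 configurator = SpringConfigurator.class 时能够正常注入 Provider：SpringConfigurator 会从 Spring 容器中获取端点实例，而默认的 Configurator 由 WebSocket 容器直接实例化端点类，Spring 无法对其进行依赖注入。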

鱼和熊掌不可兼得？打开 SpringConfigurator 的源码一看，它继承的同样是 ServerEndpointConfig.Configurator。

public class SpringConfigurator extends Configurator {
	...
	...
}

这样问题就好解决了：只需让 GetHttpSessionConfigurator 改为继承 SpringConfigurator 即可。
解决方法

public class GetHttpSessionConfigurator extends SpringConfigurator {
	...
	...
}

最终代码

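最终代码如下。注意 WebSocketServer 标注了 @Scope("prototype")：SpringConfigurator 从 Spring 容器获取端点实例，如果是单例，所有连接会共享同一个实例（以及其中的 currentSession 字段）。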
GetHttpSessionConfigurator.java


//import com.test.WebUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.web.socket.server.standard.SpringConfigurator;
import javax.websocket.HandshakeResponse;
import javax.websocket.server.HandshakeRequest;
import javax.websocket.server.ServerEndpointConfig;

/**
 * Configurator that passes the logged-in user into the WebSocket endpoint.
 *
 * <p>It extends {@link SpringConfigurator} so that endpoint instances are obtained
 * through Spring, which lets the endpoint receive injected beans such as Dubbo providers.
 * The user id is stored under the {@code "userId"} key of
 * {@link ServerEndpointConfig#getUserProperties()}; the endpoint's {@code @OnOpen}
 * method reads it back from its {@code EndpointConfig}.
 *
 * @author frank
 * @date 2018/11/14 12:36
 */
public class GetHttpSessionConfigurator extends SpringConfigurator {
    protected final Logger logger = LoggerFactory.getLogger(this.getClass());

    /**
     * Resolves the current user during the handshake and publishes it as the
     * {@code "userId"} user property.
     *
     * <p>If resolving the user fails, the {@code "userId"} property is removed, so the
     * endpoint sees {@code null} instead of a value left over from an earlier handshake.
     *
     * @param config   endpoint configuration whose user properties receive the user id
     * @param request  the handshake request (not read by the current placeholder code)
     * @param response the handshake response (not used)
     */
    @Override
    public void modifyHandshake(ServerEndpointConfig config, HandshakeRequest request, HandshakeResponse response) {
        String userId = "";
        try {
            // TODO: obtain the userId from the session yourself, e.g.
            // userId = WebUtil.getCurrentUser();
            logger.info("GetHttpSessionConfigurator:userId{}", userId);
        } catch (Exception e) {
            logger.warn("没有权限", e);
            // NOTE(review): JSR 356 does not guarantee this map is per-connection in every
            // container, so clear any stale user instead of silently keeping it.
            config.getUserProperties().remove("userId");
            return;
        }
        config.getUserProperties().put("userId", userId);
    }
}

WebSocketServer.java


import java.io.IOException;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.ThreadLocalRandom;

import javax.websocket.*;
import javax.websocket.server.ServerEndpoint;

import com.alibaba.fastjson.JSON;
import com.google.common.collect.Maps;
import com.test.GetHttpSessionConfigurator;
import com.test.ISysProvider;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.context.annotation.Scope;
import org.springframework.stereotype.Component;
import org.springframework.stereotype.Controller;
/**
 * WebSocket服务
 *
 * @author frank
 * @date 2018/8/6 18:18
 */
@Component
@Scope("prototype")
@ServerEndpoint(value="/web-socket/get-message-count", configurator = GetHttpSessionConfigurator.class)
public class WebSocketServer  {

    // dubbo的Provider
    @Autowired
    @Qualifier("sysProvider")
    public ISysProvider sysProvider;

    private static final Logger logger = LoggerFactory.getLogger(WebSocketServer.class);
    /**已经建立链接的对象缓存起来*/
    private static ConcurrentMap<String, WebSocketServer> serverMap = new ConcurrentHashMap<>();
    /**当前session*/
    private Session currentSession;

    @OnOpen
    public void onOpen(Session session, EndpointConfig config) throws IOException {
        System.out.println("onOpen======================");
        this.currentSession = session;
        String userId = (String) config.getUserProperties().get("userId");
        // 能够获取到userId
        System.out.println("userId======================"+userId);
        //建立链接时,缓存对象
        serverMap.put(userId, this);
        Map<String,Object> map = new HashMap<>();
        map.put("count",countUnreadMessageByUserId(userId));
        session.getBasicRemote().sendText(JSON.toJSONString(map));
    }

    @OnClose
    public void onClose(Session session) {
        System.out.println("onClose======================");
        // 通过getLoginUserId()获取userId
        System.out.println("userId======================"+getLoginUserId());
        String userId = getLoginUserId();
        serverMap.remove(userId, this);
        this.currentSession = null;
        try {
            session.close();
        } catch (IOException e) {
            logger.error(e.getMessage());
        }
    }

    @OnMessage()
    public void onMessage(Session session,String msg) throws IOException {
        System.out.println("onMessage======================"+msg);
        // 可以获取到sysProvider和userId
        System.out.println("sysProvider======================"+(null == sysProvider));
        System.out.println("userId======================"+getLoginUserId());
        Map<String,Object> map = new HashMap<>();
        map.put("count",countUnreadMessageByUserId(getLoginUserId()));
        this.currentSession.getBasicRemote().sendText(JSON.toJSONString(map));
    }
    @OnError
    public void onError(Throwable t) {
        logger.error(t.getMessage());
    }

    /**
     * 根据传入的userId得到最新的未读消息数量
     * @param userId
     * @author frank
     */
    public void sendUnreadMessageCountByUserId(String userId){
        try {
            //如果连接开启,则发送最新数据
            if ( serverMap.containsKey(userId) && serverMap.get(userId).currentSession.isOpen() ) {
                Map<String,Object> map = new HashMap<>();
                map.put("count",countUnreadMessageByUserId(userId));
                serverMap.get(userId).currentSession.getBasicRemote().sendText(JSON.toJSONString(map));
            }
        } catch (IOException e) {
            logger.error(e.getMessage());
        }
    }

    /**
     * 根据websocket的session获取userId
     * @author frank
     * @param
     * @date 2018/11/14 14:01
     * @return java.lang.String
     */
    private String getLoginUserId(){
        if (serverMap.containsValue(this)) {
            Iterator<String> keys = serverMap.keySet().iterator();
            String userId = "";
            while(keys.hasNext()) {
                userId = keys.next();
                if (serverMap.get(userId) == this) {
                    return userId;
                }
            }
        }
        return null;
    }

    /**
     * 根据用户id统计未读消息数量
     * @author frank
     * @param userId
     * @date 2018/8/6 21:40
     * @return int
     */
    public int countUnreadMessageByUserId(String userId){
    	/*
    	 * TODO 获取最新的数据
    	 */
        return RandomUtils.nextInt()/100000000;
    }


}

你可能感兴趣的:(WebSocket,java,IT人生)