项目中使用了 Dubbo 和 Spring MVC,并使用 Spring Session 管理 session,但在集成 WebSocket 时遇到了两个问题:
1、在 WebSocket 端点中无法获取到当前登录用户;
2、在 WebSocket 端点中无法注入 Dubbo 的服务提供者(Provider)。
下面介绍解决这两个问题的思路,并在文章最后附上最终代码。
解决方法:自定义 Configurator,在握手阶段把登录用户传入 WebSocket
//import com.test.WebUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import javax.websocket.HandshakeResponse;
import javax.websocket.server.HandshakeRequest;
import javax.websocket.server.ServerEndpointConfig;
/**
 * Endpoint configurator that passes the login user into the WebSocket endpoint.
 *
 * <p>During the handshake it resolves the current user id and stores it in the endpoint
 * config's user properties under the key "userId", where the endpoint can read it from
 * {@code EndpointConfig#getUserProperties()} in its {@code @OnOpen} method.
 *
 * @author frank
 * @date 2018/11/14 12:36
 */
public class GetHttpSessionConfigurator extends ServerEndpointConfig.Configurator {
    protected final Logger logger = LoggerFactory.getLogger(this.getClass());

    /**
     * Resolves the login user for this handshake and publishes it as the "userId" user property.
     *
     * <p>If resolving the user fails, the failure is logged with its cause and any existing
     * "userId" property is removed, so the endpoint sees "no user" rather than a stale value.
     * The handshake itself is not rejected here; the endpoint decides what to do.
     */
    @Override
    public void modifyHandshake(ServerEndpointConfig config, HandshakeRequest request, HandshakeResponse response) {
        String userId = "";
        try {
            // TODO: read the user id from the HTTP session. HandshakeRequest#getHttpSession()
            // returns the HttpSession (or null); whether it is the Spring Session wrapper
            // depends on the filter setup — verify. Original placeholder:
            // userId = WebUtil.getCurrentUser();
            logger.info("GetHttpSessionConfigurator:userId{}", userId);
        } catch (Exception e) {
            // "没有权限" = "no permission". Keep the cause so failures can be diagnosed.
            logger.warn("没有权限", e);
            // NOTE(review): some containers may reuse this map across handshakes; clearing it
            // makes the failure visible to the endpoint instead of leaking another user's id.
            config.getUserProperties().remove("userId");
            return;
        }
        config.getUserProperties().put("userId", userId);
    }
}
// Excerpt from WebSocketServer: @ServerEndpoint belongs on the class declaration (omitted here).
@ServerEndpoint(value="/web-socket/get-message-count", configurator = GetHttpSessionConfigurator.class)
@OnOpen
public void onOpen(Session session, EndpointConfig config) throws IOException {
    System.out.println("onOpen======================");
    this.currentSession = session;
    // Read the user id that GetHttpSessionConfigurator stored during the handshake.
    String userId = (String) config.getUserProperties().get("userId");
    System.out.println("userId======================"+userId);
    if (userId == null) {
        // The configurator supplied no user id. serverMap is a ConcurrentHashMap, which
        // rejects null keys, so close the connection instead of failing with an NPE.
        session.close(new CloseReason(CloseReason.CloseCodes.VIOLATED_POLICY, "no login user"));
        return;
    }
    // Cache this endpoint instance under the user id once the connection is established.
    serverMap.put(userId, this);
    Map<String,Object> map = new HashMap<>();
    map.put("count",countUnreadMessageByUserId(userId));
    session.getBasicRemote().sendText(JSON.toJSONString(map));
}
经过上面的配置后,发现 WebSocketServer 类中无法注入 Dubbo 的 Provider。
参考其他博客后发现,当 configurator = SpringConfigurator.class 时能够正常注入 Provider;但 @ServerEndpoint 的 configurator 属性只能指定一个类,鱼和熊掌不可兼得?
查看 SpringConfigurator 的源码可以发现,它继承的同样是 ServerEndpointConfig.Configurator:
// Excerpt (body elided) of Spring's
// org.springframework.web.socket.server.standard.SpringConfigurator.
// It extends the JSR 356 ServerEndpointConfig.Configurator, so a custom configurator
// can extend SpringConfigurator instead of Configurator directly.
public class SpringConfigurator extends Configurator {
...
...
}
既然如此,只需让 GetHttpSessionConfigurator 改为继承 SpringConfigurator,就能同时获得两者的能力。
解决方法:
// Excerpt (body elided): the configurator now extends SpringConfigurator instead of
// ServerEndpointConfig.Configurator, keeping its own modifyHandshake override.
// Per the article, this is what makes injection of the Dubbo provider work again.
public class GetHttpSessionConfigurator extends SpringConfigurator {
...
...
}
最终代码如下
GetHttpSessionConfigurator.java
//import com.test.WebUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.web.socket.server.standard.SpringConfigurator;
import javax.websocket.HandshakeResponse;
import javax.websocket.server.HandshakeRequest;
import javax.websocket.server.ServerEndpointConfig;
/**
 * Endpoint configurator that passes the login user into the WebSocket endpoint.
 *
 * <p>It extends Spring's {@link SpringConfigurator} rather than the plain JSR 356
 * configurator; per the accompanying article, this is what lets Spring inject the Dubbo
 * provider into the endpoint. During the handshake it resolves the current user id and
 * stores it in the endpoint config's user properties under the key "userId", where the
 * endpoint can read it from {@code EndpointConfig#getUserProperties()} in {@code @OnOpen}.
 *
 * @author frank
 * @date 2018/11/14 12:36
 */
public class GetHttpSessionConfigurator extends SpringConfigurator {
    protected final Logger logger = LoggerFactory.getLogger(this.getClass());

    /**
     * Resolves the login user for this handshake and publishes it as the "userId" user property.
     *
     * <p>If resolving the user fails, the failure is logged with its cause and any existing
     * "userId" property is removed, so the endpoint sees "no user" rather than a stale value.
     * The handshake itself is not rejected here; the endpoint decides what to do.
     */
    @Override
    public void modifyHandshake(ServerEndpointConfig config, HandshakeRequest request, HandshakeResponse response) {
        String userId = "";
        try {
            // TODO: read the user id from the HTTP session. HandshakeRequest#getHttpSession()
            // returns the HttpSession (or null); whether it is the Spring Session wrapper
            // depends on the filter setup — verify. Original placeholder:
            // userId = WebUtil.getCurrentUser();
            logger.info("GetHttpSessionConfigurator:userId{}", userId);
        } catch (Exception e) {
            // "没有权限" = "no permission". Keep the cause so failures can be diagnosed.
            logger.warn("没有权限", e);
            // NOTE(review): some containers may reuse this map across handshakes; clearing it
            // makes the failure visible to the endpoint instead of leaking another user's id.
            config.getUserProperties().remove("userId");
            return;
        }
        config.getUserProperties().put("userId", userId);
    }
}
WebSocketServer.java
import java.io.IOException;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.ThreadLocalRandom;
import javax.websocket.*;
import javax.websocket.server.ServerEndpoint;
import com.alibaba.fastjson.JSON;
import com.google.common.collect.Maps;
import com.test.GetHttpSessionConfigurator;
import com.test.ISysProvider;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.context.annotation.Scope;
import org.springframework.stereotype.Component;
import org.springframework.stereotype.Controller;
/**
* WebSocket服务
*
* @author frank
* @date 2018/8/6 18:18
*/
@Component
@Scope("prototype")
@ServerEndpoint(value="/web-socket/get-message-count", configurator = GetHttpSessionConfigurator.class)
public class WebSocketServer {
// dubbo的Provider
@Autowired
@Qualifier("sysProvider")
public ISysProvider sysProvider;
private static final Logger logger = LoggerFactory.getLogger(WebSocketServer.class);
/**已经建立链接的对象缓存起来*/
private static ConcurrentMap<String, WebSocketServer> serverMap = new ConcurrentHashMap<>();
/**当前session*/
private Session currentSession;
@OnOpen
public void onOpen(Session session, EndpointConfig config) throws IOException {
System.out.println("onOpen======================");
this.currentSession = session;
String userId = (String) config.getUserProperties().get("userId");
// 能够获取到userId
System.out.println("userId======================"+userId);
//建立链接时,缓存对象
serverMap.put(userId, this);
Map<String,Object> map = new HashMap<>();
map.put("count",countUnreadMessageByUserId(userId));
session.getBasicRemote().sendText(JSON.toJSONString(map));
}
@OnClose
public void onClose(Session session) {
System.out.println("onClose======================");
// 通过getLoginUserId()获取userId
System.out.println("userId======================"+getLoginUserId());
String userId = getLoginUserId();
serverMap.remove(userId, this);
this.currentSession = null;
try {
session.close();
} catch (IOException e) {
logger.error(e.getMessage());
}
}
@OnMessage()
public void onMessage(Session session,String msg) throws IOException {
System.out.println("onMessage======================"+msg);
// 可以获取到sysProvider和userId
System.out.println("sysProvider======================"+(null == sysProvider));
System.out.println("userId======================"+getLoginUserId());
Map<String,Object> map = new HashMap<>();
map.put("count",countUnreadMessageByUserId(getLoginUserId()));
this.currentSession.getBasicRemote().sendText(JSON.toJSONString(map));
}
@OnError
public void onError(Throwable t) {
logger.error(t.getMessage());
}
/**
* 根据传入的userId得到最新的未读消息数量
* @param userId
* @author frank
*/
public void sendUnreadMessageCountByUserId(String userId){
try {
//如果连接开启,则发送最新数据
if ( serverMap.containsKey(userId) && serverMap.get(userId).currentSession.isOpen() ) {
Map<String,Object> map = new HashMap<>();
map.put("count",countUnreadMessageByUserId(userId));
serverMap.get(userId).currentSession.getBasicRemote().sendText(JSON.toJSONString(map));
}
} catch (IOException e) {
logger.error(e.getMessage());
}
}
/**
* 根据websocket的session获取userId
* @author frank
* @param
* @date 2018/11/14 14:01
* @return java.lang.String
*/
private String getLoginUserId(){
if (serverMap.containsValue(this)) {
Iterator<String> keys = serverMap.keySet().iterator();
String userId = "";
while(keys.hasNext()) {
userId = keys.next();
if (serverMap.get(userId) == this) {
return userId;
}
}
}
return null;
}
/**
* 根据用户id统计未读消息数量
* @author frank
* @param userId
* @date 2018/8/6 21:40
* @return int
*/
public int countUnreadMessageByUserId(String userId){
/*
* TODO 获取最新的数据
*/
return RandomUtils.nextInt()/100000000;
}
}