1、配置入口:
import com.mti.handler.MessageHandler;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.core.Ordered;
import org.springframework.web.reactive.HandlerMapping;
import org.springframework.web.reactive.handler.SimpleUrlHandlerMapping;
import org.springframework.web.reactive.socket.WebSocketHandler;
import org.springframework.web.reactive.socket.server.support.WebSocketHandlerAdapter;
import java.util.HashMap;
import java.util.Map;
/**
 * WebFlux WebSocket configuration: maps the {@code /echo} endpoint to the
 * {@link MessageHandler} and registers the adapter that lets the dispatcher
 * invoke {@link WebSocketHandler} beans.
 *
 * @author zhaoyj
 * @date 2019/3/12
 */
@Configuration
public class WebSocketConfiguration {

    /** URL path on which the WebSocket endpoint is exposed. */
    private static final String ECHO_PATH = "/echo";

    /**
     * Builds the URL-to-handler mapping for WebSocket endpoints.
     *
     * <p>Parameters of a {@code @Bean} method are resolved by the container, so no
     * additional {@code @Autowired} annotation is required on this method.
     *
     * @param echoHandler handler that serves the {@code /echo} endpoint
     * @return a handler mapping ordered with {@link Ordered#HIGHEST_PRECEDENCE}
     */
    @Bean
    public HandlerMapping webSocketMapping(final MessageHandler echoHandler) {
        final Map<String, WebSocketHandler> urlMap = new HashMap<>();
        urlMap.put(ECHO_PATH, echoHandler);

        final SimpleUrlHandlerMapping mapping = new SimpleUrlHandlerMapping();
        mapping.setOrder(Ordered.HIGHEST_PRECEDENCE);
        mapping.setUrlMap(urlMap);
        return mapping;
    }

    /**
     * Registers the adapter that allows {@link WebSocketHandler} beans to be invoked.
     *
     * @return a new {@link WebSocketHandlerAdapter}
     */
    @Bean
    public WebSocketHandlerAdapter handlerAdapter() {
        return new WebSocketHandlerAdapter();
    }
}
2、配置Handler
import com.alibaba.fastjson.JSONObject;
import com.google.protobuf.InvalidProtocolBufferException;
import com.mti.configuration.Systemconfig;
import com.mti.enums.ReferenceMsgType;
import com.mti.exception.BusinessException;
import com.mti.handler.up.StreamReferenceReq;
import com.mti.proto.Linkproto;
import com.mti.vo.Message;
import com.mti.websocket.SocketSessionRegistry;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.core.io.buffer.DataBufferFactory;
import org.springframework.core.task.TaskExecutor;
import org.springframework.stereotype.Component;
import org.springframework.web.reactive.socket.WebSocketHandler;
import org.springframework.web.reactive.socket.WebSocketMessage;
import org.springframework.web.reactive.socket.WebSocketSession;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import java.util.Optional;
import java.util.concurrent.ScheduledThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
/**
* MessageHandler class
*
* @author zhaoyj
* @date 2019/3/12
*/
@Component
@Slf4j
public class MessageHandler implements WebSocketHandler {
@Autowired
private SocketSessionRegistry sessionRegistry;
@Autowired
private ScheduledThreadPoolExecutor executor;
@Autowired
private Systemconfig systemconfig;
@Autowired
DispatchFactory dispatchFactory;
@Autowired
TaskExecutor taskExecutor;
@Override
public Mono handle(WebSocketSession session) {
return session.receive().doOnSubscribe(s -> {
log.info("发起连接:{}",s);
/**
* 你有10秒时间登陆,不登陆就关掉连接;并且不给任何错误信息
*/
if(systemconfig.getLoginInterval() != 0){
executor.schedule(() -> sessionRegistry.checkAndRemove(session),systemconfig.getLoginInterval(),TimeUnit.SECONDS);
}
if(systemconfig.getPingInterval() != 0){
executor.schedule(() -> session.send(Flux.just(session.pingMessage(DataBufferFactory::allocateBuffer))).toProcessor(),systemconfig.getPingInterval(), TimeUnit.SECONDS);
}
}).doOnTerminate(() -> {
sessionRegistry.unregisterSession(session);
StreamReferenceReq req = (StreamReferenceReq) dispatchFactory.getCommand(ReferenceMsgType.SEND_VALUE);
taskExecutor.execute(() -> Optional.ofNullable(req.removeSession(session)).ifPresent(list -> list.forEach(req::sendStopStreamConfig)));
log.info("doOnTerminate");
}).doOnComplete(() -> {
log.info("doOnComplete");
}).doOnCancel(() -> {
log.info("doOnCancel");
}).doOnNext(message -> {
if(message.getType().equals(WebSocketMessage.Type.BINARY)){
log.info("收到二进制消息");
Linkproto.LinkCmd linkCmd = null;
try {
linkCmd = Optional.ofNullable(Linkproto.LinkCmd.parseFrom(message.getPayload().asByteBuffer())).orElseThrow(() -> new BusinessException(500,"解析出错了"));
BaseDispatch dispatch = dispatchFactory.getCommand(linkCmd.getTypeValue());
log.info("处理session,{},消息实体,{},类型,{},dispatch:{}",session,linkCmd,linkCmd.getTypeValue(),dispatch);
dispatch.excuted(session, linkCmd);
} catch (InvalidProtocolBufferException e) {
e.printStackTrace();
}
}else if(message.getType().equals(WebSocketMessage.Type.TEXT)){
String content = message.getPayloadAsText();
log.info("收到文本消息:{}",content);
Message msg = null;
try{
msg = JSONObject.parseObject(content, Message.class);
}catch (Exception e){
JSONObject obj = new JSONObject();
obj.put("content","无法理解你发过来的消息内容,不予处理:"+content);
obj.put("msgType",Linkproto.LinkCmdType.LINK_CMD_ZERO_VALUE);
session.send(Flux.just(session.textMessage(obj.toJSONString()))).then().toProcessor();
log.error("解析消息内容出错");
return;
}
BaseDispatch dispatch = dispatchFactory.getCommand(msg.getMsgType());
if(dispatch != null){
dispatch.executeMsg(session, msg);
}
}else if(message.getType().equals(WebSocketMessage.Type.PING)){
session.send(Flux.just(session.pongMessage(s -> s.wrap(new byte[256]))));
log.info("收到ping消息");
}else if(message.getType().equals(WebSocketMessage.Type.PONG)){
log.info("收到pong消息");
if(systemconfig.getPingInterval() != 0){
executor.schedule(() -> session.send(Flux.just(session.pingMessage(DataBufferFactory::allocateBuffer))).toProcessor(),systemconfig.getPingInterval(), TimeUnit.SECONDS);
}
}
}).doOnError(e -> {
e.printStackTrace();
log.error("doOnError");
}).doOnRequest(r -> {
log.info("doOnRequest");
}).then();
}
这边显示的是整个从连接建立到连接断开的生命周期,可以区分二进制消息还是文本消息。发送消息时,一定要订阅 send() 返回的 Mono(例如调用 toProcessor() 或 subscribe()),不然不会发送。
如果要发送消息到其它客户端,需要在后台将连接过来的session保存起来,根据用户名或者其它方式保存之后,获取到session进行发送:如下面这个SocketSessionRegistry类
import com.mti.enums.SocketCloseStatus;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;
import org.springframework.web.reactive.socket.WebSocketSession;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.CopyOnWriteArraySet;
import java.util.concurrent.CountDownLatch;
/**
*用户session记录类
*
* @author zhaoyj
* @date 2019/3/12
*/
@Service
@Slf4j
public class SocketSessionRegistry {
/**
* 这个集合存储session
*/
private final ConcurrentMap> userSessionIds = new ConcurrentHashMap<>();
private final ConcurrentMap clientInfoSessionIds = new ConcurrentHashMap<>();
private final ConcurrentMap sessionIdUser = new ConcurrentHashMap<>();
private ConcurrentMap cacheTimestamp = new ConcurrentHashMap<>();
private final Object lock = new Object();
/**
* 获取sessionId
*
* @param user
* @return
*/
private Set getSessionIds(String user) {
Set set = this.userSessionIds.get(user);
return set != null ? set : Collections.emptySet();
}
/**
* 获取用户session
* @param user
* @return
*/
public Collection getSessionByUser(String user){
Set sessionIds = Optional.ofNullable(getSessionIds(user)).orElse(new CopyOnWriteArraySet<>());
List sessions = new ArrayList<>();
for (String sessionId : sessionIds) {
sessions.add(clientInfoSessionIds.get(sessionId));
}
return sessions;
}
/**
* 获取用户session
* @param users
* @return
*/
public Collection getSessionByUsers(Collection users){
List sessions = new ArrayList<>();
if(!CollectionUtils.isEmpty(users)){
for (String user : users) {
sessions.addAll(getSessionByUser(user));
}
}
return sessions;
}
/**
* 获取所有session
* @return Collection
*/
public Collection getAllSessions(){
return clientInfoSessionIds.values();
}
/**
* 获取所有session
*
* @return
*/
public ConcurrentMap> getAllSessionIds() {
return this.userSessionIds;
}
/**
* 获取所有session
*
* @return
*/
public ConcurrentMap getAllSessionWebSocketInfos() {
return this.clientInfoSessionIds;
}
/**
* register session
*
* @param user
* @param sessionId
*/
private void registerSessionId(String user, String sessionId) {
synchronized (this.lock) {
Set set = this.userSessionIds.get(user);
if (set == null) {
set = new CopyOnWriteArraySet<>();
this.userSessionIds.put(user, set);
}
set.add(sessionId);
}
}
/**
* 保存session
* @param session WebSocketSession
*/
public void registerSession(WebSocketSession session,String user){
if(StringUtils.isEmpty(user)){
user = parseUserByURI(session).get("user");
}
if(!StringUtils.isEmpty(user)){
String sessionId = session.getId();
registerSessionId(user,sessionId);
registerSessionId(session);
sessionIdUser.putIfAbsent(sessionId,user);
}
}
/**
* 从session里面解析参数
* @param session
* @return
*/
private Map parseUserByURI(WebSocketSession session){
Map map = new HashMap<>();
String[] params = Optional.ofNullable(session.getHandshakeInfo().getUri().getQuery()).orElse("").split("&");
for (String param : params) {
String[] temp = param.split("=");
if(temp.length == 2){
map.put(temp[0],temp[1]);
}
}
return map;
}
public WebSocketSession getSessionBySessionId(String sessionId){
return this.clientInfoSessionIds.get(sessionId);
}
private void registerSessionId(WebSocketSession websocketInfo) {
String sessionId = websocketInfo.getId();
CountDownLatch signal = cacheTimestamp.putIfAbsent(sessionId, new CountDownLatch(1));
if (signal == null) {
signal = cacheTimestamp.get(sessionId);
try {
if (!clientInfoSessionIds.containsKey(sessionId)) {
WebSocketSession set = this.clientInfoSessionIds.get(sessionId);
if (set == null) {
clientInfoSessionIds.putIfAbsent(sessionId, websocketInfo);
}
}
} finally {
signal.countDown();
cacheTimestamp.remove(sessionId);
}
} else {
try {
signal.await();
} catch (InterruptedException e) {
e.printStackTrace();
}
}
}
private void unregisterSessionId(String userName, String sessionId) {
synchronized (this.lock) {
Set set = this.userSessionIds.get(userName);
if (set != null && set.remove(sessionId) && set.isEmpty()) {
this.userSessionIds.remove(userName);
}
}
}
private void unregisterSessionId(String sessionId) {
synchronized (this.lock) {
WebSocketSession set = this.clientInfoSessionIds.get(sessionId);
if (set != null) {
this.clientInfoSessionIds.remove(sessionId);
}
}
}
public void unregisterSession(WebSocketSession session){
String sessionId = session.getId();
String user = sessionIdUser.get(sessionId);
if(!StringUtils.isEmpty(user)){
unregisterSessionId(sessionId);
unregisterSessionId(user,sessionId);
sessionIdUser.remove(sessionId);
}
}
public void checkAndRemove(WebSocketSession session){
String sessionId = session.getId();
if(!this.clientInfoSessionIds.containsKey(sessionId)){
log.info("sessionId:{} 10秒内没有登陆,关掉它",sessionId);
session.close(SocketCloseStatus.UN_LOGIN.getCloseStatus()).toProcessor();
}else{
log.info("sessinId:{}已经登陆,是合法的",sessionId);
}
}
}
userSessionIds是保存用户所属的sessionId列表的,因为同一个用户可能会在不同地方登陆,会有多个session
clientInfoSessionIds这个是保存session的(sessionId 到 session 的映射);sessionIdUser 则可以根据sessionId对应到用户。
这几周慢慢摸索出来的结果,网上资料很少,官网上的也不是很全,可能有不对的地方,在此做个记录!