1.导入依赖
<dependency>
    <groupId>org.springframework.boot</groupId>
    <artifactId>spring-boot-starter-websocket</artifactId>
</dependency>
2.配置Bean
@Configuration
public class WebSocketConfig {

    /**
     * Exposes a {@link ServerEndpointExporter} bean. Once it is in the context, every class
     * annotated with {@code @ServerEndpoint} is registered automatically as a WebSocket endpoint.
     *
     * @return a freshly created exporter
     */
    @Bean
    public ServerEndpointExporter serverEndpointExporter() {
        final ServerEndpointExporter exporter = new ServerEndpointExporter();
        return exporter;
    }
}
3.核心代码配置
@Slf4j
@ServerEndpoint(value = "/websocket/{token}")//通讯地址:ws://localhost:8080//websocket/token
@Component
public class WebSocket {
/**
* 存储session集合
*/
private static ConcurrentHashMap sessionMap = new ConcurrentHashMap<>();
/**
* 存储session集合
*/
private static ConcurrentHashMap userMap = new ConcurrentHashMap<>();
private final static Logger logger = LogManager.getLogger(WebSocket.class);
/**
* 静态变量,用来记录当前在线连接数。应该把它设计成线程安全的
*/
private static int onlineCount = 0;
/**
* concurrent包的线程安全Map,用来存放每个客户端对应的MyWebSocket对象
*/
private static ConcurrentHashMap webSocketMap = new ConcurrentHashMap<>();
/**
* 与某个客户端的连接会话,需要通过它来给客户端发送数据
*/
private Session session;
private String userId;
/**
* @description:连接建立成功调用的方法
* @param session
* @param token
* @author: Mr.Kai
* @time: 2023/4/3 15:48
*/
@OnOpen
public void onOpen(Session session, @PathParam("token") String token) {
System.out.println("★webSocket连接成功★,token为:"+token);
LoginUser loginUser = LoginHelper.getLoginUser(token);
if (ObjectUtil.isNull(loginUser)){
log.error("token失效或无法解析");
}
setMap(session,loginUser);
}
private void setMap(Session session, LoginUser loginUser) {
//获取用户id
Long userId = loginUser.getUserId();
//存储会话到会话集合
sessionMap.put(userId,session);
//存储用户信息到用户集合
userMap.put(userId,loginUser);
//获取会话长度(就是在线人数)
int size = sessionMap.size();
log.warn("用户连接:{},昵称:{},当前在线人数:{}",userId,loginUser.getUsername(),size);
}
/**
* 连接关闭调用的方法
*/
@OnClose
public void onClose() {
// //从map中删除
// webSocketMap.remove(userId);
// subOnlineCount(); //在线数减1
// logger.info("用户{}关闭连接!当前在线人数为{}", userId, getOnlineCount());
System.out.println("★webSocket退出成功★");
removeMap(session);
}
private void removeMap(Session session){
Long userId = getUserIdBySession(session);
if (ObjectUtil.isNull(userId)){
return;
}
sessionMap.remove(userId);
userMap.remove(userId);
}
//根据session拿到用户id
private Long getUserIdBySession(Session session) {
for (Long userId : sessionMap.keySet()){
if(sessionMap.get(userId).getId().equals(session.getId())){
return userId;
}
}
return null;
}
/**
* 收到客户端消息后调用的方法
*
* @param message 客户端发送过来的消息
*/
@OnMessage
public void onMessage(String message, Session session) {
//logger.info("来自客户端用户:{} 消息:{}",userId, message);
//群发消息
/*for (String item : webSocketMap.keySet()) {
try {
webSocketMap.get(item).sendMessage(message);
} catch (IOException e) {
e.printStackTrace();
}
}*/
System.out.println("★webSocket接收成功★内容为:"+message);
LoginUser loginUser = getUserBySession(session);
if (ObjectUtil.isNull(loginUser)){
return;
}
if (UserType.SYS_USER.getUserType().equals(loginUser.getUserType())){
//系统用户
handlePCMsg(loginUser,message);
}else {
//app用户
handleAPPMsg(loginUser,message);
}
}
private void handleAPPMsg(LoginUser loginUser, String message) {
log.info("APP用户:{},消息",loginUser.getUsername(),message);
}
private void handlePCMsg(LoginUser loginUser, String message) {
log.info("系统用户:{},消息",loginUser.getUsername(),message);
}
private LoginUser getUserBySession(Session session) {
Long userId = getUserIdBySession(session);
if (ObjectUtil.isNull(userId)){
return null;
}
return userMap.get(userId);
}
/**
* 发生错误时调用
*
* @OnError
*/
@OnError
public void onError(Session session, Throwable error) {
logger.error("用户错误:" + this.userId + ",原因:" + error.getMessage());
error.printStackTrace();
}
/**
* 向客户端发送消息
*/
public void sendMessage(String message) throws IOException {
this.session.getBasicRemote().sendText(message);
//this.session.getAsyncRemote().sendText(message);
}
/**
* 通过userId向客户端发送消息
*/
public void sendMessageByUserId(String userId, String message) throws IOException {
logger.info("服务端发送消息到{},消息:{}",userId,message);
if(StrUtil.isNotBlank(userId)&&webSocketMap.containsKey(userId)){
webSocketMap.get(userId).sendMessage(message);
}else{
logger.error("用户{}不在线",userId);
}
}
/**
* 群发自定义消息
*/
public static void sendInfo(String message) throws IOException {
for (String item : webSocketMap.keySet()) {
try {
webSocketMap.get(item).sendMessage(message);
} catch (IOException e) {
continue;
}
}
}
public static synchronized int getOnlineCount() {
return onlineCount;
}
public static synchronized void addOnlineCount() {
WebSocket.onlineCount++;
}
public static synchronized void subOnlineCount() {
WebSocket.onlineCount--;
}
/**
* @description:发送自定义消息方法
* @param toUserId
* @author: Mr.Kai
* @time: 2023/4/3 16:30
*/
public static void sendInfo(String message,Long toUserId){
log.info("发送消息到:{},消息内容:{}",toUserId,message);
if (ObjectUtil.isNull(toUserId) || StringUtils.isBlank(message)){
log.error("消息体不完整");
return;
}
// if (sessionMap.contains(toUserId)){
try {
sendMessage(sessionMap.get(toUserId),message);
} catch (Exception e) {
log.error("发送给{}的消息出错",toUserId);
}
// }else {
// //用户不在线
// log.error("用户:{}不在线",toUserId);
// //后续可以处理
// }
}
public static void sendMessage(Session session, String message) throws IOException {
session.getBasicRemote().sendText(message);
}
}
4.测试发送
@RequiredArgsConstructor
@Service
public class SystemTemplateServiceImpl implements ISystemTemplateService {
@Autowired
private WebSocket webSocket;
/**
 * Pushes {@code message} over WebSocket to the user {@code userId}.
 *
 * <p>Delegates to {@link WebSocket#sendInfo(String, Long)}, which takes the message first and the
 * target user id second. That method logs incomplete input, offline users and send failures
 * instead of throwing, so there is no checked exception to handle here.
 *
 * @param userId  id of the user to notify
 * @param message text payload to push
 */
@Transactional(rollbackFor = Exception.class) // roll back on any exception
@Override
public void sendMessage(Long userId, String message) {
    WebSocket.sendInfo(message, userId);
}