前面介绍了使用netty实现websocket通信,有些时候,如果我们的服务并不复杂或者连接数并不高,单独搭建一个websocket服务端有些浪费资源,这时候我们就可以在web服务内提供简单的websocket连接支持。其实springboot已经支持了websocket通信协议,只需要几步简单的配置就可以实现。
老规矩,首先需要引入相关的依赖:
<dependency>
<groupId>org.springframework.bootgroupId>
<artifactId>spring-boot-starter-webartifactId>
dependency>
<dependency>
<groupId>org.springframework.bootgroupId>
<artifactId>spring-boot-starter-websocketartifactId>
dependency>
<dependency>
<groupId>org.projectlombokgroupId>
<artifactId>lombokartifactId>
<version>1.18.12version>
<scope>providedscope>
dependency>
<dependency>
<groupId>org.apache.commonsgroupId>
<artifactId>commons-lang3artifactId>
<version>3.12.0version>
dependency>
springboot的配置文件application.yaml
不需要额外内容,简单指定一下端口号和服务名称就可以了:
server:
port: 8081
shutdown: graceful
spring:
application:
name: test-ws
由于我这里使用了日志,简单配置一下日志文件logback-spring.xml
输出内容:
<configuration scan="true" scanPeriod="60 seconds" debug="false">
<contextName>api-logger-servercontextName>
<appender name="console" class="ch.qos.logback.core.ConsoleAppender">
<encoder>
<pattern>%d{yyyy-MM-dd HH:mm:ss.SSS}|%thread|[%-5level]|%logger{36}.%method|%msg%npattern>
<charset>UTF-8charset>
encoder>
appender>
<appender name="msg" class="ch.qos.logback.core.rolling.RollingFileAppender">
<file>${user.dir}/logs/msg.logfile>
<encoder>
<pattern>%d{yyyy-MM-dd HH:mm:ss.SSS}|%thread|[%-5level]|%logger{36}.%method|%msg%npattern>
<charset>UTF-8charset>
encoder>
<rollingPolicy class="ch.qos.logback.core.rolling.TimeBasedRollingPolicy">
<FileNamePattern>${user.dir}/logs/msg.log.%d{yyyy-MM-dd}FileNamePattern>
rollingPolicy>
appender>
<logger name="msg" level="ERROR" additivity="false">
<appender-ref ref="msg"/>
logger>
<appender name="INFO" class="ch.qos.logback.core.rolling.RollingFileAppender">
<filter class="ch.qos.logback.classic.filter.LevelFilter">
<level>ERRORlevel>
<onMatch>DENYonMatch>
<onMismatch>ACCEPTonMismatch>
filter>
<encoder>
<pattern>%d|%t|%-5p|%c|%m%npattern>
<charset>UTF-8charset>
encoder>
<rollingPolicy class="ch.qos.logback.core.rolling.SizeAndTimeBasedRollingPolicy">
<fileNamePattern>${user.dir}/logs/info/%d.%i.logfileNamePattern>
<maxFileSize>100MBmaxFileSize>
<maxHistory>15maxHistory>
<totalSizeCap>10GBtotalSizeCap>
rollingPolicy>
appender>
<appender name="ERROR" class="ch.qos.logback.core.rolling.RollingFileAppender">
<filter class="ch.qos.logback.classic.filter.ThresholdFilter">
<level>ERRORlevel>
filter>
<encoder>
<pattern>%d|%t|%-5p|%c|%m%npattern>
<charset>UTF-8charset>
encoder>
<rollingPolicy class="ch.qos.logback.core.rolling.SizeAndTimeBasedRollingPolicy">
<fileNamePattern>${user.dir}/logs/error/%d.%i.logfileNamePattern>
<maxFileSize>100MBmaxFileSize>
<maxHistory>15maxHistory>
<totalSizeCap>10GBtotalSizeCap>
rollingPolicy>
appender>
<root level="INFO">
<appender-ref ref="console"/>
<appender-ref ref="INFO"/>
<appender-ref ref="ERROR"/>
root>
configuration>
本项目只是简单演示在springboot中使用websocket功能,所以没有涉及到复杂的业务逻辑,但还是需要定义一个用户服务类,用来存储用户身份信息和登录时的身份校验。
import lombok.Builder;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.RandomStringUtils;
import org.springframework.stereotype.Service;
import javax.annotation.PostConstruct;
import java.util.concurrent.ConcurrentHashMap;
/**
* 用户服务类
*
* @Author xingo
* @Date 2023/11/22
*/
@Slf4j
@Service
public class UserService {
static final ConcurrentHashMap<String, User> USER_MAP = new ConcurrentHashMap<>();
static final ConcurrentHashMap<String, String> TOKEN_MAP = new ConcurrentHashMap<>();
/**
* 启动时存入信息
*/
@PostConstruct
public void run() {
User user1 = User.builder().userName("zhangsan").nickName("张三").build();
User user2 = User.builder().userName("lisi").nickName("李四").build();
// 用户信息集合
USER_MAP.put(user1.getUserName(), user1);
USER_MAP.put(user2.getUserName(), user2);
// 模拟用户登录成功,将身份认证的token放入集合
String random1 = "token_" + RandomStringUtils.random(18, "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890");
String random2 = "token_" + RandomStringUtils.random(18, "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890");
log.info("用户身份信息|{}|{}", user1.getUserName(), random1);
log.info("用户身份信息|{}|{}", user2.getUserName(), random2);
TOKEN_MAP.put(random1, user1.getUserName());
TOKEN_MAP.put(random2, user2.getUserName());
}
/**
* 根据用户名获取用户信息
*/
public User getUserByUserName(String userName) {
return USER_MAP.get(userName);
}
/**
* 校验token和用户是否匹配
*/
public boolean checkToken(String token, String userName) {
return userName.equals(TOKEN_MAP.get(token));
}
/**
* 用户信息实体类
*/
@Data
@Builder
public static final class User {
private String userName;
private String nickName;
}
}
接下来就是websocket相关注入到容器中,首先需要注入的是ServerEndpointExporter
,这个类用来扫描ServerEndpoint
相关内容:
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.socket.server.standard.ServerEndpointExporter;
/**
* 注入ServerEndpointExporter,用来扫描ServerEndpoint相关注解
*
* @author xingo
* @Date 2023/11/22
*/
@Configuration
public class WebsocketConfig {
@Bean
public ServerEndpointExporter serverEndpointExporter() {
return new ServerEndpointExporter();
}
}
接下来需要再注入一个Bean,这个Bean需要添加ServerEndpoint
注解,主要用来处理websocket连接。注意这个Bean是多例的,每个websocket连接都会新建一个实例。
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.example.service.UserService;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import javax.annotation.PostConstruct;
import javax.websocket.*;
import javax.websocket.server.PathParam;
import javax.websocket.server.ServerEndpoint;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
/**
* websocket服务类
* 连接ws服务这里要两个参数:userName 和 token
* userName 用于用户身份标识
* token 用于用户身份认证,用户每次登录进入系统都有可能不同
*
* @author xingo
* @Date 2023/11/22
*/
@Slf4j
@Component
@ServerEndpoint("/{userName}/{token}")
public class WebSocketEndpoint {
/**
* 存放所有在线的客户端:键为用户名,值为用户的所有连接
*/
public static final Map<String, List<Session>> USER_SESSIONS = new ConcurrentHashMap<>();
/**
* 存放连接最近一次写数据的时间戳
*/
public static final Map<Session, Long> LAST_REQUEST_TIME = new ConcurrentHashMap<>();
// ServerEndpoint 是多例的,需要设置为静态的类成员,否则程序运行会出错
private static UserService userService;
// 只能通过属性的set方法注入
@Autowired
public void setUserService(UserService userService) {
WebSocketEndpoint.userService = userService;
}
/**
* 客户端连接
* @param session
*/
@OnOpen
public void onOpen(Session session, EndpointConfig config, @PathParam("userName") String userName, @PathParam("token") String token) {
System.out.println("客户端连接|" + userName + "|" + token + "|" + session);
System.out.println(this);
System.out.println(userService);
LAST_REQUEST_TIME.put(session, System.currentTimeMillis());
if(StringUtils.isNotBlank(userName) && StringUtils.isNotBlank(token)) {
boolean flag = false;
boolean check = userService.checkToken(token, userName);
if(check) {
UserService.User user = userService.getUserByUserName(userName);
if(user != null) {
if(!USER_SESSIONS.containsKey(userName)) {
USER_SESSIONS.put(userName, new ArrayList<>());
}
USER_SESSIONS.get(userName).add(session);
flag = true;
}
}
if(flag) {
session.getAsyncRemote().sendText("连接服务端成功");
} else {
session.getAsyncRemote().sendText("用户信息认证失败,连接服务端失败");
}
} else {
session.getAsyncRemote().sendText("未获取到用户身份验证信息");
}
}
/**
* 客户端关闭
* @param session session
*/
@OnClose
public void onClose(Session session, CloseReason closeReason, @PathParam("userName") String userName, @PathParam("token") String token) {
System.out.println("客户端断开|" + userName + "|" + token + "|" + session);
if(StringUtils.isNotBlank(userName)) {
USER_SESSIONS.get(userName).remove(session);
LAST_REQUEST_TIME.remove(session);
}
LAST_REQUEST_TIME.remove(session);
}
/**
* 发生错误
* @param throwable e
*/
@OnError
public void onError(Session session, Throwable throwable) {
throwable.printStackTrace();
}
/**
* 收到客户端发来消息
* @param message 消息对象
*/
@OnMessage
public void onMessage(Session session, String message, @PathParam("userName") String userName, @PathParam("token") String token) {
log.info("接收到客户端消息|{}|{}|{}|{}", userName, token, session.getId(), message);
LAST_REQUEST_TIME.put(session, System.currentTimeMillis());
String resp = null;
try {
if("PING".equals(message)) {
resp = "PONG";
} else if("PONG".equals(message)) {
log.info("客户端响应心跳|{}", session.getId());
} else {
resp = "服务端收到信息 : " + message;
}
} catch (Exception e) {
e.printStackTrace();
}
if(resp != null) {
sendMessage(userName, resp);
}
}
/**
* 发送消息
* @param userName 用户名
* @param data 数据体
*/
public static void sendMessage(String userName, String data) {
List<Session> sessions = USER_SESSIONS.get(userName);
if(sessions != null && !sessions.isEmpty()) {
sessions.forEach(session -> session.getAsyncRemote().sendText(data));
} else {
log.error("客户端未连接|{}", userName);
}
}
/**
* 初始化方法执行标识
*/
public static final AtomicBoolean INIT_RUN = new AtomicBoolean(false);
/**
* 处理长时间没有与服务器进行通信的连接
*/
@PostConstruct
public void run() {
if(INIT_RUN.compareAndSet(false, true)) {
log.info("检查连接定时任务启动");
ScheduledExecutorService service = Executors.newScheduledThreadPool(1);
service.scheduleAtFixedRate(() -> {
// 超时关闭时间:超过5分钟未更新时间
long closeTimeout = System.currentTimeMillis() - TimeUnit.MILLISECONDS.convert(5, TimeUnit.MINUTES);
// 心跳包时间:超过2分钟未更新时间
long heartbeatTimeout = System.currentTimeMillis() - TimeUnit.MICROSECONDS.convert(2, TimeUnit.MINUTES);
Iterator<Map.Entry<Session, Long>> iterator = LAST_REQUEST_TIME.entrySet().iterator();
while (iterator.hasNext()) {
Map.Entry<Session, Long> next = iterator.next();
Session session = next.getKey();
long lastTimestamp = next.getValue();
if(lastTimestamp < closeTimeout) { // 超时链接关闭
log.error("关闭超时连接|{}", session.getId());
try {
session.close();
iterator.remove();
USER_SESSIONS.entrySet().forEach(entry -> entry.getValue().remove(session));
} catch (IOException e) {
e.printStackTrace();
}
} else if(lastTimestamp < heartbeatTimeout) { // 发送心跳包
log.info("发送心跳包|{}", session.getId());
session.getAsyncRemote().sendText("PING");
}
}
}, 5, 10, TimeUnit.SECONDS);
}
}
}
对于上面的Bean需要几点说明:
客户端连接|zhangsan|token_JTrFGlBW01gHxFZHFG|org.apache.tomcat.websocket.WsSession@7ef1b79f
org.example.websocket.WebSocketEndpoint@33141901
org.example.service.UserService@46db8a12
客户端断开|zhangsan|token_JTrFGlBW01gHxFZHFG|org.apache.tomcat.websocket.WsSession@7ef1b79f
客户端连接|zhangsan|token_JTrFGlBW01gHxFZHFG|org.apache.tomcat.websocket.WsSession@7116a4f3
org.example.websocket.WebSocketEndpoint@341424b5
org.example.service.UserService@46db8a12
客户端断开|zhangsan|token_JTrFGlBW01gHxFZHFG|org.apache.tomcat.websocket.WsSession@7116a4f3
客户端连接|zhangsan|token_JTrFGlBW01gHxFZHFG|org.apache.tomcat.websocket.WsSession@737a3e9b
org.example.websocket.WebSocketEndpoint@3678be90
org.example.service.UserService@46db8a12
private static UserService userService;
@Autowired
public void setUserService(UserService userService) {
WebSocketEndpoint.userService = userService;
}
上面几个类定义好后就实现了在springboot中使用websocket,添加启动类就可以进行前后通信:
import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;
/**
* 应用启动类
*
* @Author xingo
* @Date 2023/11/22
*/
@SpringBootApplication
public class WsApplication {
public static void main(String[] args) {
SpringApplication.run(WsApplication.class, args);
}
}
为了方便测试,再添加一个controller用于接收消息并将消息转发到客户端:
import org.example.websocket.WebSocketEndpoint;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RestController;
/**
* @Author xingo
* @Date 2023/11/22
*/
@RestController
public class MessageController {
/**
* 发送信息
*/
@GetMapping("/sendmessage")
public String sendMessage(String userName, String message) {
WebSocketEndpoint.sendMessage(userName, message);
return "ok";
}
}
测试服务是否正常。我这里选择使用postman进行测试,通过postman建立连接并发送消息。
连接建立成功,并且正常的发送和接收到数据。
下面测试一下通过http发送数据到服务端,服务端根据用户名查找到对应连接将消息转发到客户端。
这种模拟了服务端主动推送数据给客户端场景,实现了双向通信。
以上就是使用springboot搭建websocket的全部内容,发现还是非常简单,最主要的是可以与现有的项目实行完全融合,不需要做太多的改变。
上面这种方式只是单体服务简单的实现,对于稍微有一点规模的应用都会采用集群化部署,用一个nginx做反向代理后端搭配几个应用服务器组成集群模式,对于集群服务就会涉及到服务间通信的问题,需要将消息转发到用户正在连接的服务上面发送给客户端。后面会讲一下如何通过redis作为中心服务实现服务发现和请求转发的功能。