在springboot中实现WebSocket协议通信

前面介绍了使用netty实现websocket通信,有些时候,如果我们的服务并不复杂或者连接数并不高,单独搭建一个websocket服务端有些浪费资源,这时候我们就可以在web服务内提供简单的websocket连接支持。其实springboot已经支持了websocket通信协议,只需要几步简单的配置就可以实现。
老规矩,首先需要引入相关的依赖:

<dependency>
    <groupId>org.springframework.bootgroupId>
    <artifactId>spring-boot-starter-webartifactId>
dependency>
<dependency>
    <groupId>org.springframework.bootgroupId>
    <artifactId>spring-boot-starter-websocketartifactId>
dependency>
<dependency>
    <groupId>org.projectlombokgroupId>
    <artifactId>lombokartifactId>
    <version>1.18.12version>
    <scope>providedscope>
dependency>
<dependency>
    <groupId>org.apache.commonsgroupId>
    <artifactId>commons-lang3artifactId>
    <version>3.12.0version>
dependency>

springboot的配置文件application.yaml不需要额外内容,简单指定一下端口号和服务名称就可以了:

server:
  port: 8081
  shutdown: graceful

spring:
  application:
    name: test-ws

由于我这里使用了日志,简单配置一下日志文件logback-spring.xml输出内容:


<configuration scan="true" scanPeriod="60 seconds" debug="false">
    <contextName>api-logger-servercontextName>

    
    <appender name="console" class="ch.qos.logback.core.ConsoleAppender">
        <encoder>
            <pattern>%d{yyyy-MM-dd HH:mm:ss.SSS}|%thread|[%-5level]|%logger{36}.%method|%msg%npattern>
            <charset>UTF-8charset>
        encoder>
    appender>

    
    <appender name="msg" class="ch.qos.logback.core.rolling.RollingFileAppender">
        <file>${user.dir}/logs/msg.logfile>
        <encoder>
            <pattern>%d{yyyy-MM-dd HH:mm:ss.SSS}|%thread|[%-5level]|%logger{36}.%method|%msg%npattern>
            <charset>UTF-8charset>
        encoder>
        <rollingPolicy class="ch.qos.logback.core.rolling.TimeBasedRollingPolicy">
            <FileNamePattern>${user.dir}/logs/msg.log.%d{yyyy-MM-dd}FileNamePattern>
        rollingPolicy>
    appender>

    <logger name="msg" level="ERROR" additivity="false">
        <appender-ref ref="msg"/>
    logger>

    
    <appender name="INFO" class="ch.qos.logback.core.rolling.RollingFileAppender">
        <filter class="ch.qos.logback.classic.filter.LevelFilter">
            <level>ERRORlevel>
            <onMatch>DENYonMatch>
            <onMismatch>ACCEPTonMismatch>
        filter>
        <encoder>
            <pattern>%d|%t|%-5p|%c|%m%npattern>
            <charset>UTF-8charset>
        encoder>
        <rollingPolicy class="ch.qos.logback.core.rolling.SizeAndTimeBasedRollingPolicy">
            
            <fileNamePattern>${user.dir}/logs/info/%d.%i.logfileNamePattern>
            <maxFileSize>100MBmaxFileSize>
            
            <maxHistory>15maxHistory>
            
            <totalSizeCap>10GBtotalSizeCap>
        rollingPolicy>
    appender>

    <appender name="ERROR" class="ch.qos.logback.core.rolling.RollingFileAppender">
        <filter class="ch.qos.logback.classic.filter.ThresholdFilter">
            <level>ERRORlevel>
        filter>
        <encoder>
            <pattern>%d|%t|%-5p|%c|%m%npattern>
            <charset>UTF-8charset>
        encoder>
        
        <rollingPolicy class="ch.qos.logback.core.rolling.SizeAndTimeBasedRollingPolicy">
            
            <fileNamePattern>${user.dir}/logs/error/%d.%i.logfileNamePattern>
            <maxFileSize>100MBmaxFileSize>
            
            <maxHistory>15maxHistory>
            
            <totalSizeCap>10GBtotalSizeCap>
        rollingPolicy>
    appender>

    <root level="INFO">
        <appender-ref ref="console"/>
        <appender-ref ref="INFO"/>
        <appender-ref ref="ERROR"/>
    root>
configuration>

本项目只是简单演示在springboot中使用websocket功能,所以没有涉及到复杂的业务逻辑,但还是需要定义一个用户服务类,用来存储用户身份信息和登录时的身份校验。

import lombok.Builder;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.RandomStringUtils;
import org.springframework.stereotype.Service;

import javax.annotation.PostConstruct;
import java.util.concurrent.ConcurrentHashMap;

/**
 * 用户服务类
 *
 * @Author xingo
 * @Date 2023/11/22
 */
@Slf4j
@Service
public class UserService {

    static final ConcurrentHashMap<String, User> USER_MAP = new ConcurrentHashMap<>();
    static final ConcurrentHashMap<String, String> TOKEN_MAP = new ConcurrentHashMap<>();

    /**
     * 启动时存入信息
     */
    @PostConstruct
    public void run() {
        User user1 = User.builder().userName("zhangsan").nickName("张三").build();
        User user2 = User.builder().userName("lisi").nickName("李四").build();

        // 用户信息集合
        USER_MAP.put(user1.getUserName(), user1);
        USER_MAP.put(user2.getUserName(), user2);

        // 模拟用户登录成功,将身份认证的token放入集合
        String random1 = "token_" + RandomStringUtils.random(18, "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890");
        String random2 = "token_" + RandomStringUtils.random(18, "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890");

        log.info("用户身份信息|{}|{}", user1.getUserName(), random1);
        log.info("用户身份信息|{}|{}", user2.getUserName(), random2);
        TOKEN_MAP.put(random1, user1.getUserName());
        TOKEN_MAP.put(random2, user2.getUserName());
    }

    /**
     * 根据用户名获取用户信息
     */
    public User getUserByUserName(String userName) {
        return USER_MAP.get(userName);
    }

    /**
     * 校验token和用户是否匹配
     */
    public boolean checkToken(String token, String userName) {
        return userName.equals(TOKEN_MAP.get(token));
    }

    /**
     * 用户信息实体类
     */
    @Data
    @Builder
    public static final class User {
        private String userName;
        private String nickName;
    }
}

接下来就是websocket相关注入到容器中,首先需要注入的是ServerEndpointExporter,这个类用来扫描ServerEndpoint相关内容:

import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.socket.server.standard.ServerEndpointExporter;

/**
 * 注入ServerEndpointExporter,用来扫描ServerEndpoint相关注解
 *
 * @author xingo
 * @Date 2023/11/22
 */
@Configuration
public class WebsocketConfig {

	@Bean
	public ServerEndpointExporter serverEndpointExporter() {
		return new ServerEndpointExporter();
	}
}

接下来需要再注入一个Bean,这个Bean需要添加ServerEndpoint注解,主要用来处理websocket连接。注意这个Bean是多例的,每个websocket连接都会新建一个实例。

import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.example.service.UserService;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import javax.annotation.PostConstruct;
import javax.websocket.*;
import javax.websocket.server.PathParam;
import javax.websocket.server.ServerEndpoint;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;

/**
 * websocket服务类
 * 连接ws服务这里要两个参数:userName 和 token
 * userName 用于用户身份标识
 * token    用于用户身份认证,用户每次登录进入系统都有可能不同
 *
 * @author xingo
 * @Date 2023/11/22
 */
@Slf4j
@Component
@ServerEndpoint("/{userName}/{token}")
public class WebSocketEndpoint {

    /**
     * 存放所有在线的客户端:键为用户名,值为用户的所有连接
     */
    public static final Map<String, List<Session>> USER_SESSIONS = new ConcurrentHashMap<>();
    /**
     * 存放连接最近一次写数据的时间戳
     */
    public static final Map<Session, Long> LAST_REQUEST_TIME = new ConcurrentHashMap<>();

    // ServerEndpoint 是多例的,需要设置为静态的类成员,否则程序运行会出错
    private static UserService userService;

    // 只能通过属性的set方法注入
    @Autowired
    public void setUserService(UserService userService) {
        WebSocketEndpoint.userService = userService;
    }

    /**
     * 客户端连接
     * @param session
     */
    @OnOpen
    public void onOpen(Session session, EndpointConfig config, @PathParam("userName") String userName, @PathParam("token") String token) {
        System.out.println("客户端连接|" + userName + "|" + token + "|" + session);
        System.out.println(this);
        System.out.println(userService);
        LAST_REQUEST_TIME.put(session, System.currentTimeMillis());
        if(StringUtils.isNotBlank(userName) && StringUtils.isNotBlank(token)) {
            boolean flag = false;
            boolean check = userService.checkToken(token, userName);
            if(check) {
                UserService.User user = userService.getUserByUserName(userName);
                if(user != null) {
                    if(!USER_SESSIONS.containsKey(userName)) {
                        USER_SESSIONS.put(userName, new ArrayList<>());
                    }
                    USER_SESSIONS.get(userName).add(session);
                    flag = true;
                }
            }

            if(flag) {
                session.getAsyncRemote().sendText("连接服务端成功");
            } else {
                session.getAsyncRemote().sendText("用户信息认证失败,连接服务端失败");
            }
        } else {
            session.getAsyncRemote().sendText("未获取到用户身份验证信息");
        }
    }

    /**
     * 客户端关闭
     * @param session session
     */
    @OnClose
    public void onClose(Session session, CloseReason closeReason, @PathParam("userName") String userName, @PathParam("token") String token) {
        System.out.println("客户端断开|" + userName + "|" + token + "|" + session);
        if(StringUtils.isNotBlank(userName)) {
            USER_SESSIONS.get(userName).remove(session);
            LAST_REQUEST_TIME.remove(session);
        }
        LAST_REQUEST_TIME.remove(session);
    }

    /**
     * 发生错误
     * @param throwable e
     */
    @OnError
    public void onError(Session session, Throwable throwable) {
        throwable.printStackTrace();
    }

    /**
     * 收到客户端发来消息
     * @param message  消息对象
     */
    @OnMessage
    public void onMessage(Session session, String message, @PathParam("userName") String userName, @PathParam("token") String token) {
        log.info("接收到客户端消息|{}|{}|{}|{}", userName, token, session.getId(), message);

        LAST_REQUEST_TIME.put(session, System.currentTimeMillis());
        String resp = null;
        try {
            if("PING".equals(message)) {
                resp = "PONG";
            } else if("PONG".equals(message)) {
                log.info("客户端响应心跳|{}", session.getId());
            } else {
                resp = "服务端收到信息 : " + message;
            }
        } catch (Exception e) {
            e.printStackTrace();
        }
        if(resp != null) {
            sendMessage(userName, resp);
        }
    }

    /**
     * 发送消息
     * @param userName      用户名
     * @param data          数据体
     */
    public static void sendMessage(String userName, String data) {
        List<Session> sessions = USER_SESSIONS.get(userName);
        if(sessions != null && !sessions.isEmpty()) {
            sessions.forEach(session -> session.getAsyncRemote().sendText(data));
        } else {
            log.error("客户端未连接|{}", userName);
        }
    }

    /**
     * 初始化方法执行标识
     */
    public static final AtomicBoolean INIT_RUN = new AtomicBoolean(false);

    /**
     * 处理长时间没有与服务器进行通信的连接
     */
    @PostConstruct
    public void run() {
        if(INIT_RUN.compareAndSet(false, true)) {
            log.info("检查连接定时任务启动");

            ScheduledExecutorService service = Executors.newScheduledThreadPool(1);
            service.scheduleAtFixedRate(() -> {
                // 超时关闭时间:超过5分钟未更新时间
                long closeTimeout = System.currentTimeMillis() - TimeUnit.MILLISECONDS.convert(5, TimeUnit.MINUTES);
                // 心跳包时间:超过2分钟未更新时间
                long heartbeatTimeout = System.currentTimeMillis() - TimeUnit.MICROSECONDS.convert(2, TimeUnit.MINUTES);

                Iterator<Map.Entry<Session, Long>> iterator = LAST_REQUEST_TIME.entrySet().iterator();
                while (iterator.hasNext()) {
                    Map.Entry<Session, Long> next = iterator.next();
                    Session session = next.getKey();
                    long lastTimestamp = next.getValue();
                    if(lastTimestamp < closeTimeout) {    // 超时链接关闭
                        log.error("关闭超时连接|{}", session.getId());
                        try {
                            session.close();
                            iterator.remove();
                            USER_SESSIONS.entrySet().forEach(entry -> entry.getValue().remove(session));
                        } catch (IOException e) {
                            e.printStackTrace();
                        }
                    } else if(lastTimestamp < heartbeatTimeout) {   // 发送心跳包
                        log.info("发送心跳包|{}", session.getId());
                        session.getAsyncRemote().sendText("PING");
                    }
                }
            }, 5, 10, TimeUnit.SECONDS);
        }
    }
}

对于上面的Bean需要几点说明:

  1. 该Bean是多例的,每个websocket连接都会创建一个实例。在上面连接建立的方法里面输出当前实例对象的内容每个连接输出的内容都不同:
客户端连接|zhangsan|token_JTrFGlBW01gHxFZHFG|org.apache.tomcat.websocket.WsSession@7ef1b79f
org.example.websocket.WebSocketEndpoint@33141901
org.example.service.UserService@46db8a12
客户端断开|zhangsan|token_JTrFGlBW01gHxFZHFG|org.apache.tomcat.websocket.WsSession@7ef1b79f
客户端连接|zhangsan|token_JTrFGlBW01gHxFZHFG|org.apache.tomcat.websocket.WsSession@7116a4f3
org.example.websocket.WebSocketEndpoint@341424b5
org.example.service.UserService@46db8a12
客户端断开|zhangsan|token_JTrFGlBW01gHxFZHFG|org.apache.tomcat.websocket.WsSession@7116a4f3
客户端连接|zhangsan|token_JTrFGlBW01gHxFZHFG|org.apache.tomcat.websocket.WsSession@737a3e9b
org.example.websocket.WebSocketEndpoint@3678be90
org.example.service.UserService@46db8a12
  1. 在该类中注入其他的Bean要设置为静态属性,并且注入要通过set方法,否则注入失败,之前在项目中使用时就出现过这种问题,将属性定义为成员变量并且直接在属性上面添加@Autowired注解,导致该属性一直是null。
    比如我的UserService服务就是通过这种方式注入的:
private static UserService userService;

@Autowired
public void setUserService(UserService userService) {
    WebSocketEndpoint.userService = userService;
}

上面几个类定义好后就实现了在springboot中使用websocket,添加启动类就可以进行前后通信:

import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;

/**
 * 应用启动类
 * 
 * @Author xingo
 * @Date 2023/11/22
 */
@SpringBootApplication
public class WsApplication {

    public static void main(String[] args) {
        SpringApplication.run(WsApplication.class, args);
    }
}

为了方便测试,再添加一个controller用于接收消息并将消息转发到客户端:

import org.example.websocket.WebSocketEndpoint;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RestController;

/**
 * @Author xingo
 * @Date 2023/11/22
 */
@RestController
public class MessageController {

    /**
     * 发送信息
     */
    @GetMapping("/sendmessage")
    public String sendMessage(String userName, String message) {
        WebSocketEndpoint.sendMessage(userName, message);
        return "ok";
    }
}

测试服务是否正常。我这里选择使用postman进行测试,通过postman建立连接并发送消息。
在springboot中实现WebSocket协议通信_第1张图片
连接建立成功,并且正常的发送和接收到数据。
下面测试一下通过http发送数据到服务端,服务端根据用户名查找到对应连接将消息转发到客户端。
在springboot中实现WebSocket协议通信_第2张图片
在springboot中实现WebSocket协议通信_第3张图片
这种模拟了服务端主动推送数据给客户端场景,实现了双向通信。

以上就是使用springboot搭建websocket的全部内容,发现还是非常简单,最主要的是可以与现有的项目实行完全融合,不需要做太多的改变。

上面这种方式只是单体服务简单的实现,对于稍微有一点规模的应用都会采用集群化部署,用一个nginx做反向代理后端搭配几个应用服务器组成集群模式,对于集群服务就会涉及到服务间通信的问题,需要将消息转发到用户正在连接的服务上面发送给客户端。后面会讲一下如何通过redis作为中心服务实现服务发现和请求转发的功能。

你可能感兴趣的:(spring,boot,websocket,后端)