Spring WebSocket服务端使用记录

记录spring4中websocket的使用方式

pom jar包配置

<dependency>  
    <groupId>org.springframework</groupId>  
    <artifactId>spring-websocket</artifactId>  
    <version>${spring.version}</version>  
</dependency> 
<dependency>  
    <groupId>org.springframework</groupId>  
    <artifactId>spring-messaging</artifactId>  
    <version>${spring.version}</version>  
</dependency>  

其中spring.version的配置是:


   <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
   <spring.version>4.0.0.RELEASE</spring.version>
   .version>1.8.version>
   .version>1.1.6.version>
 

涉及到json消息的支持jar用的是alibaba提供的:

   
 <dependency>  
     <groupId>com.alibaba</groupId>  
     <artifactId>fastjson</artifactId>  
     <version>1.2.28</version>  
 </dependency>  

配置websocket服务

在Spring WebSocket中有两种方式配置WebSocket服务:一种是在XML中配置,另一种是在代码中实现WebSocketConfigurer接口。这里使用第二种:

import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.servlet.config.annotation.EnableWebMvc;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.config.annotation.EnableWebSocket;
import org.springframework.web.socket.config.annotation.WebSocketConfigurer;
import org.springframework.web.socket.config.annotation.WebSocketHandlerRegistry;

/**
 * Spring WebSocket configuration.
 * <p>
 * Declares the message handler and the handshake interceptor as beans and maps both onto the
 * {@code /ws} endpoint.
 *
 * @author ThatWay
 * 2018-5-8
 */
@Configuration
@EnableWebMvc
@EnableWebSocket
public class WebSocketConfig implements WebSocketConfigurer {

    /** Path the WebSocket endpoint is registered under. */
    private static final String ENDPOINT_PATH = "/ws";

    /**
     * Wires the endpoint: the handler receives the traffic and the interceptor runs around the handshake.
     *
     * @param registry registry supplied by Spring
     */
    @Override
    public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) {
        WebSocketHandler handler = webSocketHandler();
        registry.addHandler(handler, ENDPOINT_PATH)
                .addInterceptors(webSocketInterceptor());
    }

    /**
     * Message-processing bean for the endpoint.
     * <p>
     * NOTE(review): {@code WebScoketHandler} is also annotated {@code @Service}. If component
     * scanning covers its package, a second instance will exist. Confirm this is intended; the
     * handler's session registry is static, so both instances would share it.
     *
     * @return a new {@code WebScoketHandler}
     */
    @Bean
    public WebSocketHandler webSocketHandler() {
        WebSocketHandler handler = new WebScoketHandler();
        return handler;
    }

    /**
     * Handshake interceptor bean that validates the request parameters.
     *
     * @return a new {@code WebSocketInterceptor}
     */
    @Bean
    public WebSocketInterceptor webSocketInterceptor() {
        WebSocketInterceptor interceptor = new WebSocketInterceptor();
        return interceptor;
    }
}

WebSocket请求过滤

在上一步的服务配置中,使用的webSocketInterceptor是实现了HandshakeInterceptor接口的过滤处理类。它会拦截所有到达服务端的WebSocket连接请求,可以在WebSocket握手前和握手后插入处理逻辑(注意:它拦截的是握手请求,而不是之后的每一条消息)。
这里面主要做的事是,客户端创建连接时传递的参数可以取出来,放入到创建连接后产生的session中,在服务端下发消息时可以通过参数来区分session,下面代码中作为session标识的是pageFlag参数。客户端请求的地址是这样的:ws://localhost:8080/integrate_pipe/ws?pageFlag=p1&actionFlag=simple

import java.util.Map;

import javax.servlet.http.HttpServletRequest;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.http.server.ServerHttpRequest;
import org.springframework.http.server.ServerHttpResponse;
import org.springframework.http.server.ServletServerHttpRequest;
import org.springframework.util.StringUtils;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.server.HandshakeInterceptor;

/**
 * Handshake interceptor for the WebSocket endpoint.
 * <p>
 * Before the handshake, it requires the {@code pageFlag} and {@code actionFlag} request
 * parameters. It copies their trimmed values into the handshake attributes, where the message
 * handler reads them to identify and initialise the session. Requests without both parameters,
 * and non-servlet requests, are refused.
 * <p>
 * Example request: {@code ws://localhost:8080/integrate_pipe/ws?pageFlag=p1&actionFlag=simple}
 *
 * @author ThatWay
 * 2018-5-8
 */
public class WebSocketInterceptor implements HandshakeInterceptor {

    private static final Logger logger = LoggerFactory.getLogger(WebSocketInterceptor.class);

    /**
     * Called after the handshake completes.
     *
     * @param exception the handshake failure, or {@code null} on success
     */
    @Override
    public void afterHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler, Exception exception) {
        logger.info("webscoket处理后过滤回调触发");
        if (exception != null) {
            // Previously the failure cause was dropped silently.
            logger.error("webscoket握手失败", exception);
        }
    }

    /**
     * Validates the request parameters and stores them as handshake attributes.
     *
     * @param attributes handshake attributes that the session will expose; receives
     *                   {@code pageFlag} and {@code actionFlag}
     * @return {@code true} to proceed with the handshake, {@code false} to refuse it
     */
    @Override
    public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler, Map<String, Object> attributes) throws Exception {
        logger.info("webscoket处理前过滤回调触发");

        // Only servlet requests give access to the query parameters we need.
        if (!(request instanceof ServletServerHttpRequest)) {
            return false;
        }
        HttpServletRequest req = ((ServletServerHttpRequest) request).getServletRequest();
        // Page identifier; the handler uses it to group sessions.
        String pageFlag = req.getParameter("pageFlag");
        // Action describing which initial data to send.
        String actionFlag = req.getParameter("actionFlag");

        // hasText also rejects whitespace-only values, which would otherwise become "" after trim().
        if (!StringUtils.hasText(pageFlag) || !StringUtils.hasText(actionFlag)) {
            logger.info("webscoket连接请求,页面标志pageFlag:"+pageFlag+",动作标志:"+actionFlag+",参数不正确,请求拒绝");
            return false;
        }

        logger.info("webscoket连接请求,页面标志pageFlag:"+pageFlag+",动作标志:"+actionFlag);
        // The session is later identified by these values.
        attributes.put("pageFlag", pageFlag.trim());
        attributes.put("actionFlag", actionFlag.trim());
        return true;
    }
}

消息处理

在服务配置中,使用的WebSocketHandler是继承了TextWebSocketHandler的消息处理类,消息由这个类来处理。Spring将WebSocket相关的生命周期回调也封装在了这里。另外,通过@Service将此类注解为服务,在其他业务controller中就可以注入此类,并调用其方法触发消息下发。

import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.stereotype.Service;
import org.springframework.util.StringUtils;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.handler.TextWebSocketHandler;

import com.alibaba.fastjson.JSONArray;
import com.alibaba.fastjson.JSONObject;

import cn.qingk.entity.User;

/**
 * WebSocket message handler.
 * <p>
 * On connect, it registers the session under a key built from the {@code pageFlag} handshake
 * attribute and answers the initial request named by the {@code actionFlag} attribute. Afterwards
 * it answers every incoming text message the same way. The public {@code sendMessageTo*} methods
 * let other beans (for example controllers) push messages to registered sessions.
 * <p>
 * The session registry and counters are {@code static}, so every instance of this class shares them.
 *
 * @author ThatWay
 * 2018-5-5
 */
@Service
public class WebScoketHandler extends TextWebSocketHandler {

    private static final Logger logger = LoggerFactory.getLogger(WebScoketHandler.class);
    // Handshake attribute holding the page identifier.
    private static final String CLIENT_ID = "pageFlag";
    // Handshake attribute holding the initial action.
    private static final String ACTION_INIT = "actionFlag";
    // Separator between the page flag and the sequence number inside a client key.
    private static final String KEY_SEPARATOR = "_";
    // Registered sessions, keyed by pageFlag + KEY_SEPARATOR + sequence number.
    private static final ConcurrentHashMap<String, WebSocketSession> clients =
            new ConcurrentHashMap<String, WebSocketSession>();
    // Number of currently registered sessions.
    private static final AtomicInteger connectCount = new AtomicInteger(0);
    // Ever-increasing sequence used only to build client keys. Keys used to be derived from the
    // online count. That count goes down on disconnect, so a new connection could reuse a key still
    // held by a live session and overwrite it in the map.
    private static final AtomicInteger connectSequence = new AtomicInteger(0);

    /**
     * Callback after the connection is established.
     * <p>
     * Registers the session and sends it the data requested by its {@code actionFlag}. A session
     * without a page flag receives a notice and is closed without being registered.
     */
    @Override
    public void afterConnectionEstablished(WebSocketSession session) throws Exception {
        logger.info("wescoket成功建立连接");

        String pageFlag = getAttributeFlag(session, CLIENT_ID);
        String reqAction = getAttributeFlag(session, ACTION_INIT);

        if (StringUtils.isEmpty(pageFlag)) {
            session.sendMessage(new TextMessage("无页面标识,连接关闭!"));
            session.close();
            return;
        }

        // A unique key per connection lets several clients with the same page flag coexist.
        String key = pageFlag + KEY_SEPARATOR + connectSequence.incrementAndGet();
        clients.put(key, session);
        int onlineCount = connectCount.incrementAndGet();
        logger.info("在线屏数:" + onlineCount);

        sendSynchronized(session, new TextMessage(buildResponseJson(pageFlag, reqAction)));
    }

    /**
     * Handles a text message from a client and replies with the requested data.
     * <p>
     * Expected payload:
     * <pre>
     * {
     *     "pageFlag": "p1",
     *     "actionFlag": "simple/detail"
     * }
     * </pre>
     * Malformed or incomplete payloads get a failure response instead of an empty frame.
     */
    @Override
    public void handleTextMessage(WebSocketSession session, TextMessage message) throws Exception {
        long start = System.currentTimeMillis();

        String returnJson = buildReplyForPayload(message.getPayload());
        TextMessage returnMessage = new TextMessage(returnJson);

        long pass = System.currentTimeMillis() - start;
        logger.info("接收终端请求返回:" + returnMessage.toString() + ",耗时:" + pass + "ms");

        sendSynchronized(session, returnMessage);
    }

    /**
     * Callback when a transport error occurs: closes the session if it is still open.
     */
    @Override
    public void handleTransportError(WebSocketSession session, Throwable thrwbl) throws Exception {
        if (session.isOpen()) {
            session.close();
        }
        // Log the cause; it was previously discarded.
        logger.error("websocket 连接出现异常准备关闭", thrwbl);
    }

    /**
     * Callback after the connection is closed: unregisters the session and decrements the online count.
     */
    @Override
    public void afterConnectionClosed(WebSocketSession session, CloseStatus cs) throws Exception {
        for (Entry<String, WebSocketSession> entry : clients.entrySet()) {
            if (entry.getValue() == session) {
                String clientKey = entry.getKey();
                // remove(key, value) succeeds at most once, so the count is decremented exactly once
                // even if another thread is removing the same entry.
                if (clients.remove(clientKey, session)) {
                    logger.info("移除clientKey:" + clientKey);
                    int leftOnlineCount = connectCount.decrementAndGet();
                    logger.info("剩余在线屏数:" + leftOnlineCount);
                }
                // Keys are unique per connection, so a session is registered under one key at most.
                break;
            }
        }
        logger.info("websocket 连接关闭了");
    }

    @Override
    public boolean supportsPartialMessages() {
        return false;
    }

    /**
     * Sends a message to every open session registered under the given page flag.
     *
     * @param pageFlag page identifier; matched exactly, so "p1" does not match "p10"
     * @param message  message to send
     * @return {@code true} if the message was delivered to at least one session
     */
    public boolean sendMessageToPage(String pageFlag, TextMessage message) {
        int allCounter = 0;
        int sendCounter = 0;
        long start = System.currentTimeMillis();

        if (!StringUtils.isEmpty(pageFlag)) {
            for (Entry<String, WebSocketSession> entry : clients.entrySet()) {
                String clientKey = entry.getKey();
                // Compare the whole page-flag part; a plain prefix match also hit "p10_…" for "p1".
                if (pageFlag.equals(pageFlagOf(clientKey))) {
                    allCounter++;
                    if (trySend(clientKey, entry.getValue(), message)) {
                        sendCounter++;
                        logger.info("sendMessageToPage:[clientKey:" + clientKey + "],flag:" + true);
                    }
                }
            }
        }

        // Previously the result reflected only the last session iterated, in arbitrary map order.
        boolean flag = sendCounter > 0;
        long pass = System.currentTimeMillis() - start;
        logger.info("sendMessageToPage:"+pageFlag+",flag:"+flag+",all_counter:"+allCounter+",send_counter:"+sendCounter+",pass:"+pass+"ms");
        return flag;
    }

    /**
     * Sends a message to every open registered session.
     *
     * @param message message to send
     * @return {@code true} if the message was delivered to at least one session
     */
    public boolean sendMessageToAll(TextMessage message) {
        int allCounter = 0;
        int sendCounter = 0;
        long start = System.currentTimeMillis();

        for (Entry<String, WebSocketSession> entry : clients.entrySet()) {
            allCounter++;
            String clientKey = entry.getKey();
            if (trySend(clientKey, entry.getValue(), message)) {
                sendCounter++;
                logger.info("sendMessageToAll:[clientKey:" + clientKey + "],flag:" + true);
            }
        }

        boolean flag = sendCounter > 0;
        long pass = System.currentTimeMillis() - start;
        logger.info("sendMessageToAll,flag:"+flag+",all_counter:"+allCounter+",send_counter:"+sendCounter+",pass:"+pass+"ms");
        return flag;
    }

    /**
     * Sends a message to the single session registered under the exact client key.
     *
     * @param clientId client key ({@code pageFlag + "_" + sequence})
     * @param message  message to send
     * @return {@code true} if the session exists, is open, and the send succeeded
     * @throws IOException never thrown at present; kept so existing callers still compile
     */
    public boolean sendMessageToId(String clientId, TextMessage message) throws IOException {
        int allCounter = 0;
        int sendCounter = 0;
        long start = System.currentTimeMillis();

        if (!StringUtils.isEmpty(clientId)) {
            allCounter++;
            // get() returns null for unknown ids; trySend handles that instead of throwing an NPE.
            if (trySend(clientId, clients.get(clientId), message)) {
                sendCounter++;
            }
        }

        boolean flag = sendCounter > 0;
        long pass = System.currentTimeMillis() - start;
        logger.info("sendMessageToId:"+clientId+",flag:"+flag+",all_counter:"+allCounter+",send_counter:"+sendCounter+",pass:"+pass+"ms");
        return flag;
    }

    /**
     * Reads a string handshake attribute.
     *
     * @param session  session whose handshake attributes are read
     * @param flagName attribute name
     * @return the attribute value, or {@code null} if it is missing or cannot be read as a String
     */
    private String getAttributeFlag(WebSocketSession session, String flagName) {
        String flag = null;
        try {
            flag = (String) session.getHandshakeAttributes().get(flagName);
        } catch (Exception e) {
            logger.error(e.getMessage(), e);
        }
        return flag;
    }

    /**
     * Parses a client payload and builds the reply.
     *
     * @param reqMsg raw text payload from the client; may be empty or malformed
     * @return a success JSON for a recognised request, otherwise a failure JSON
     */
    private String buildReplyForPayload(String reqMsg) {
        // The original checked the constant `msg` here instead of the payload, so this guard never fired.
        if (StringUtils.isEmpty(reqMsg)) {
            logger.error("客户端请求的消息为空");
            return makeFailResponseJson(null);
        }

        JSONObject terminalMsg;
        try {
            terminalMsg = JSONObject.parseObject(reqMsg);
        } catch (RuntimeException e) {
            // The payload is untrusted client input. fastjson may throw JSONException (bad syntax)
            // or ClassCastException (non-object JSON). Answer with a failure instead of letting the
            // exception tear down the session.
            logger.error("客户端请求的消息不是合法的json:" + reqMsg, e);
            return makeFailResponseJson(null);
        }

        if (terminalMsg == null || terminalMsg.isEmpty()) {
            logger.error("客户端请求的消息转换json为空");
            return makeFailResponseJson(null);
        }
        if (!terminalMsg.containsKey("pageFlag") || !terminalMsg.containsKey("actionFlag")) {
            logger.error("客户端请求的消息缺少pageFlag或actionFlag:" + reqMsg);
            return makeFailResponseJson(null);
        }

        String reqPageFlag = terminalMsg.getString("pageFlag");
        String reqAction = terminalMsg.getString("actionFlag");
        return buildResponseJson(reqPageFlag, reqAction);
    }

    /**
     * Builds the data response for a page and action.
     * <p>
     * Shared by the connect and message callbacks. The data is currently hard-coded sample data.
     *
     * @param pageFlag  page identifier; only used for logging here
     * @param reqAction {@code simple} or {@code detail}, case-insensitive; may be {@code null}
     * @return success JSON for a known action, failure JSON otherwise
     */
    private String buildResponseJson(String pageFlag, String reqAction) {
        // Placeholder: the type should come from the database query.
        String type = WebSocketStatus.TYPE_BDXW;

        // equalsIgnoreCase is null-safe on the argument and locale-independent, unlike toLowerCase().
        if (WebSocketStatus.ACTION_SIMPLE.equalsIgnoreCase(reqAction)) {
            // DB basic data
            logger.info("数据库查询【" + pageFlag + "】的基本数据");
            Map<String, Object> infoMap = new HashMap<String, Object>();
            infoMap.put("type", "qwzx");
            infoMap.put("title", "全网资讯");
            return this.makeInfoResponseJson(WebSocketStatus.CODE_SUCCESS, type, reqAction, WebSocketStatus.MSG_SUCCESS, infoMap);
        }

        if (WebSocketStatus.ACTION_DETAIL.equalsIgnoreCase(reqAction)) {
            // DB data list
            logger.info("数据库查询【" + pageFlag + "】的列表数据");
            List<User> userList = new ArrayList<User>();
            User user1 = new User();
            user1.setAddress("address 1");
            user1.setAge(18);
            user1.setId(1);
            user1.setName("name 1");
            userList.add(user1);
            return this.makeListResponseJson(WebSocketStatus.CODE_SUCCESS, type, reqAction, WebSocketStatus.MSG_SUCCESS, userList.size(), userList);
        }

        logger.error("客户端请求的action为:" + reqAction);
        // Previously this path sent an empty frame even though code/msg had been set to FAIL.
        return makeFailResponseJson(reqAction);
    }

    /**
     * Builds a failure response carrying {@code CODE_FAIL} and {@code MSG_FAIL}.
     *
     * @param action the requested action, or {@code null} if unknown
     * @return failure JSON
     */
    private String makeFailResponseJson(String action) {
        return makeInfoResponseJson(WebSocketStatus.CODE_FAIL, null, action, WebSocketStatus.MSG_FAIL, null);
    }

    /**
     * Extracts the page flag from a client key of the form {@code pageFlag + "_" + sequence}.
     * The sequence never contains the separator, so the last separator marks the boundary.
     *
     * @param clientKey registered client key
     * @return the page-flag part of the key
     */
    private static String pageFlagOf(String clientKey) {
        int sep = clientKey.lastIndexOf(KEY_SEPARATOR);
        return sep < 0 ? clientKey : clientKey.substring(0, sep);
    }

    /**
     * Sends a message to one session if it is open.
     *
     * @param clientKey key the session is registered under; used for logging
     * @param session   target session; may be {@code null} (e.g. an unknown id)
     * @param message   message to send
     * @return {@code true} if sent; {@code false} if the session is missing or closed, or the send
     *         failed (the failure is logged)
     */
    private boolean trySend(String clientKey, WebSocketSession session, TextMessage message) {
        if (session == null || !session.isOpen()) {
            return false;
        }
        try {
            sendSynchronized(session, message);
            return true;
        } catch (IOException e) {
            logger.error("向clientKey:" + clientKey + "发送消息失败", e);
            return false;
        }
    }

    /**
     * Sends a message while holding the session's monitor.
     * <p>
     * A WebSocket session must not be written to concurrently. The handler callbacks and the public
     * {@code sendMessageTo*} methods (called from other beans) can run on different threads.
     * NOTE(review): on Spring 4.0.3+ a {@code ConcurrentWebSocketSessionDecorator} is an alternative.
     */
    private static void sendSynchronized(WebSocketSession session, TextMessage message) throws IOException {
        synchronized (session) {
            session.sendMessage(message);
        }
    }

    /**
     * Builds a list response JSON.
     * <p>
     * Only local state is used, so no synchronization is needed.
     *
     * @param code       status code
     * @param type       data type
     * @param action     action type
     * @param msg        message text
     * @param totalCount total number of items
     * @param dataList   items; {@code null} is treated as an empty list
     * @return json
     */
    public String makeListResponseJson(int code, String type, String action, String msg, int totalCount, List<?> dataList) {
        JSONObject jsonObj = new JSONObject();
        jsonObj.put("code", code);
        jsonObj.put("type", type);
        jsonObj.put("action", action);
        jsonObj.put("msg", msg);

        JSONObject contentObj = new JSONObject();
        contentObj.put("totalCount", totalCount);

        JSONArray listArray = new JSONArray();
        if (dataList != null) {
            listArray.addAll(dataList);
        }
        contentObj.put("list", listArray);

        jsonObj.put("body", contentObj);
        String json = jsonObj.toString();
        logger.info("生成list json:" + json);
        return json;
    }

    /**
     * Builds a detail response JSON.
     * <p>
     * Only local state is used, so no synchronization is needed.
     *
     * @param code   status code
     * @param type   data type
     * @param action action type
     * @param msg    message text
     * @param info   detail payload placed under {@code body}
     * @return json
     */
    public String makeInfoResponseJson(int code, String type, String action, String msg, Object info) {
        JSONObject jsonObj = new JSONObject();
        jsonObj.put("code", code);
        jsonObj.put("type", type);
        jsonObj.put("action", action);
        jsonObj.put("msg", msg);
        jsonObj.put("body", info);

        String json = jsonObj.toString();
        logger.info("生成info json:" + json);
        return json;
    }

}
  

状态辅助类

在消息处理类中用到了一些状态码、下发消息等静态变量主要是为了和客户端交互时定义好消息格式的。这个类不一定需要。

/**
 * Constants shared by the WebSocket handler and its callers.
 * <p>
 * Defines the response status codes, messages, data-type identifiers and action names that make
 * up the client/server message format.
 */
public class WebSocketStatus {

    /********************* Status codes: begin **********************/
    // Extend with business-specific codes as needed.
    // Request handled successfully.
    public static final int CODE_SUCCESS = 200;
    // Request failed. Must differ from CODE_SUCCESS so clients can tell the two apart; it was also
    // 200 before. NOTE(review): 500 chosen by analogy with HTTP; confirm with the client side.
    public static final int CODE_FAIL = 500;
    /********************* Status codes: end **********************/

    /********************* Messages: begin **********************/
    // Extend with business-specific messages as needed.
    // Request handled successfully.
    public static final String MSG_SUCCESS = "OK";
    // Request failed.
    public static final String MSG_FAIL = "FAIL";
    /********************* Messages: end **********************/

    /********************* Data types: begin **********************/
    // Trending across the web (全网热点)
    public static final String TYPE_QWRD = "qwrd";
    // Local news (本地新闻)
    public static final String TYPE_BDXW = "bdxw";
    // Hot searches (网络热搜)
    public static final String TYPE_WLRS = "wlrs";
    // Regional public opinion (地方舆论)
    public static final String TYPE_DFYL = "dfyl";
    // News topics (新闻选题)
    public static final String TYPE_XWXT = "xwxt";
    // Field-reporting dispatch (外采调度)
    public static final String TYPE_WCDD = "wcdd";
    // Productivity statistics (生产力统计)
    public static final String TYPE_SCLTJ = "scltj";
    // Influence statistics (影响力统计)
    public static final String TYPE_YXLTJ = "yxltj";
    // Task statistics (任务统计)
    public static final String TYPE_RWTJ = "rwtj";
    // News ranking (资讯热榜)
    public static final String TYPE_ZXRB = "zxrb";
    // Video ranking (视频热榜)
    public static final String TYPE_SPRB = "sprb";
    // Custom list (列表自定义)
    public static final String TYPE_LBZDY = "lbzdy";
    // Custom chart (图表自定义)
    public static final String TYPE_TBZDY = "tbzdy";
    /********************* Data types: end **********************/

    /********************* Actions: begin **********************/
    // Basic information
    public static final String ACTION_SIMPLE = "simple";
    // Detailed information
    public static final String ACTION_DETAIL = "detail";
    /********************* Actions: end **********************/

}

控制器中调用

这里主要是模拟了控制器中由于某个动作需要触发给指定的session发送消息。

/**
 * Demo controller showing how business code can push WebSocket messages.
 * <p>
 * It sends a sample "info" message and a sample "list" message to every session opened with page
 * flag {@code p1}, then writes both send results to the HTTP response.
 */
@Controller
@RequestMapping("/testController")
public class TestController {

    public static final Logger LOGGER = Logger.getLogger(TestController.class);

    @Autowired
    private TestService testService;
    @Autowired
    private WebScoketHandler handler;

    /**
     * Pushes the two sample messages to page {@code p1}.
     *
     * @param request  current request (unused)
     * @param response receives the two send results ({@code true}/{@code false}) one after the other
     */
    @RequestMapping("/test")
    public void test(HttpServletRequest request, HttpServletResponse response) {
        try {
            Map<String, Object> infoMap = new HashMap<String, Object>();
            infoMap.put("type", "qwzx");
            infoMap.put("title", "全网资讯");

            TextMessage infoMessage = new TextMessage(handler.makeInfoResponseJson(WebSocketStatus.CODE_SUCCESS, WebSocketStatus.TYPE_QWRD, WebSocketStatus.ACTION_SIMPLE, WebSocketStatus.MSG_SUCCESS, infoMap));

            // Distinct ids; the sample data previously reused id 1 for every user.
            List<User> userList = new ArrayList<User>();
            userList.add(newUser(1, "name 1", "address 1"));
            userList.add(newUser(2, "name 2", "address 2"));
            userList.add(newUser(3, "name 3", "address 3"));
            int totalCount = userList.size();
            TextMessage listMessage = new TextMessage(handler.makeListResponseJson(WebSocketStatus.CODE_SUCCESS, WebSocketStatus.TYPE_QWRD, WebSocketStatus.ACTION_DETAIL, WebSocketStatus.MSG_SUCCESS, totalCount, userList));

            String pageFlag = "p1";

            // Send to every browser that opened page p1.
            boolean sendFlag1 = this.handler.sendMessageToPage(pageFlag, infoMessage);
            LOGGER.info("sendFlag1:" + sendFlag1);
            response.getWriter().print(sendFlag1);

            boolean sendFlag2 = this.handler.sendMessageToPage(pageFlag, listMessage);
            // The label previously read "sendFlag1" for the second result.
            LOGGER.info("sendFlag2:" + sendFlag2);
            response.getWriter().print(sendFlag2);
        } catch (Exception e) {
            // Request boundary: log instead of printStackTrace so the failure reaches the app log.
            LOGGER.error("testController/test 推送测试消息失败", e);
        }
    }

    /**
     * Builds a sample user (age fixed at 18) for the demo payload.
     *
     * @param id      user id
     * @param name    user name
     * @param address user address
     * @return populated user
     */
    private static User newUser(int id, String name, String address) {
        User user = new User();
        user.setAddress(address);
        user.setAge(18);
        user.setId(id);
        user.setName(name);
        return user;
    }
}
  

                            
                        
                    
                    
                    

你可能感兴趣的:(webscoket,spring,webscoket)