Integrating WebSocket into a Spring Boot Project

I. Why use WebSocket?

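Suppose the client needs to track changes in server-side state, such as real-time stock quotes or the number of train tickets left. Under the traditional HTTP model, the client has to ask the server at regular intervals whether anything new is available, and the server returns any changed data in its response.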

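The earliest approach was polling (short polling): a timer sends an HTTP request to the server at a fixed interval, say every 3 s. This has two drawbacks. First, it wastes resources: server state may change at one moment and then stay the same for a long time, so nearly every request sent in between comes back empty. Second, latency is poor: with a 3 s polling interval, a change may be noticed up to 3 s after it happens rather than immediately. You can shorten the interval, but that costs even more resources.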

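Long polling later improved on this. With short polling, the server answers a request at once and returns null if there is no data. With long polling, the server does not answer right away when there is no data; it holds the request open. If data arrives while it waits, the server responds with that data; otherwise the request times out and ends. This reduces the number of requests compared with short polling, but two drawbacks remain. First, every update still needs a full round trip (client request, then server response), so latency is still not great. Second, when each update is small, the network is used inefficiently, because the HTTP headers (general headers plus request headers) make up most of every packet.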
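To make the difference concrete, here is a minimal, framework-free sketch of the server side of both styles. `PollingDemo` and its methods are hypothetical names, and an in-memory queue stands in for the changing server state: `shortPoll` answers immediately (often with nothing), while `longPoll` holds the request until data arrives or the timeout expires.

```java
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;

// Hypothetical in-memory stand-in for changing server state (not from the original project)
final class PollingDemo {
    private final BlockingQueue<String> updates = new LinkedBlockingQueue<>();

    // Server side: record a state change, e.g. a new remaining-ticket count
    void publish(String update) {
        updates.offer(update);
    }

    // Short polling: answer at once; null means "nothing new", i.e. a wasted request
    String shortPoll() {
        return updates.poll();
    }

    // Long polling: hold the request until data arrives or timeoutMillis passes (null = timed out)
    String longPoll(long timeoutMillis) {
        try {
            return updates.poll(timeoutMillis, TimeUnit.MILLISECONDS);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt(); // preserve the interrupt status for the caller
            return null;
        }
    }
}
```

In a real server the long-poll wait would be handled asynchronously (for example with servlet async support) instead of blocking a request thread.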

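To solve these problems, the WebSocket protocol was introduced. It runs on top of TCP and supports two-way communication between server and client. The handshake uses HTTP, after which the connection is upgraded to the WebSocket protocol. The figure below is a Wireshark capture of a WebSocket connection being established, carrying data, and being closed.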


WebSocket packet analysis

[Figure 1: WebSocket handshake headers]
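During the upgrade, the client sends a random `Sec-WebSocket-Key` header and the server answers `101 Switching Protocols` with a `Sec-WebSocket-Accept` header computed as Base64(SHA-1(key + fixed GUID)), as specified in RFC 6455. Spring and the servlet container do this for you; the sketch below (the class name `HandshakeAccept` is hypothetical) only illustrates the calculation.

```java
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.Base64;

// Hypothetical helper: computes Sec-WebSocket-Accept as described in RFC 6455
final class HandshakeAccept {
    // Fixed GUID defined by RFC 6455
    static final String WS_GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";

    static String acceptFor(String secWebSocketKey) {
        try {
            MessageDigest sha1 = MessageDigest.getInstance("SHA-1");
            byte[] digest = sha1.digest((secWebSocketKey + WS_GUID).getBytes(StandardCharsets.US_ASCII));
            return Base64.getEncoder().encodeToString(digest);
        } catch (NoSuchAlgorithmException e) {
            throw new IllegalStateException("SHA-1 is not available", e);
        }
    }
}
```

With the sample key `dGhlIHNhbXBsZSBub25jZQ==` from RFC 6455, the expected accept value is `s3pPLMBiTxaQ9kYGzzhZRbK+xOo=`.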

The handshake uses HTTP. Once the connection is established, data is exchanged in frames with the following format:


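[Figure 2: WebSocket frame format, see RFC 6455]
The control fields in each frame are very small: a 2-byte base header, plus 2 or 8 bytes of extended length for larger payloads and a 4-byte masking key on client-to-server frames, so 2 to 14 bytes in total. The payload therefore makes up a much larger share of the transmitted data than it does with HTTP polling.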
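To see how small that overhead is, the sketch below encodes only the header of a single unfragmented frame following the RFC 6455 layout (FIN bit, opcode, mask bit, 7/16/64-bit payload length, optional 4-byte masking key). `FrameHeader` is a hypothetical helper, not part of Spring, and masking of the payload itself is left out.

```java
import java.io.ByteArrayOutputStream;

// Hypothetical helper: encodes the header of one unfragmented frame per RFC 6455, section 5.2
final class FrameHeader {
    static final int OPCODE_TEXT = 0x1;
    static final int OPCODE_PING = 0x9;

    // maskingKey: null for server-to-client frames, 4 bytes for client-to-server frames
    static byte[] encode(int opcode, long payloadLength, byte[] maskingKey) {
        if (payloadLength < 0) {
            throw new IllegalArgumentException("negative payload length");
        }
        if (maskingKey != null && maskingKey.length != 4) {
            throw new IllegalArgumentException("masking key must be 4 bytes");
        }
        ByteArrayOutputStream out = new ByteArrayOutputStream();
        out.write(0x80 | (opcode & 0x0F));              // FIN=1, RSV1-3=0, 4-bit opcode
        int maskBit = (maskingKey != null) ? 0x80 : 0;
        if (payloadLength <= 125) {
            out.write(maskBit | (int) payloadLength);   // length fits in 7 bits
        } else if (payloadLength <= 0xFFFF) {
            out.write(maskBit | 126);                   // 16-bit extended length follows
            out.write((int) (payloadLength >>> 8) & 0xFF);
            out.write((int) payloadLength & 0xFF);
        } else {
            out.write(maskBit | 127);                   // 64-bit extended length follows
            for (int shift = 56; shift >= 0; shift -= 8) {
                out.write((int) (payloadLength >>> shift) & 0xFF);
            }
        }
        if (maskingKey != null) {
            out.write(maskingKey, 0, 4);                // masking key (payload masking not shown)
        }
        return out.toByteArray();
    }
}
```

For example, a 5-byte text frame from the server needs a 2-byte header, the same frame from a client needs 6 bytes because of the masking key, and the worst case is 14 bytes.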

II. Using WebSocket in a Spring Boot project

1. Maven dependency

        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-websocket</artifactId>
        </dependency>

2. Create the handshake interceptor

package com.leaf.app.user.service.websocket;

import java.util.HashMap;
import java.util.Map;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.http.HttpStatus;
import org.springframework.http.server.ServerHttpRequest;
import org.springframework.http.server.ServerHttpResponse;
import org.springframework.stereotype.Component;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.server.HandshakeInterceptor;

import com.leaf.app.common.passport.dto.TokenDTO;
import com.leaf.app.user.service.passport.service.TokenService;

@Component
public class HandShake implements HandshakeInterceptor {

    @Autowired
    private TokenService tokenService;

    private static final Logger logger = LoggerFactory.getLogger(HandShake.class);

    // Authenticate the token before the connection is established
    @Override
    public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler,
                                   Map<String, Object> attributes) throws Exception {
        logger.debug("begin handshake, url: " + request.getURI());
        Map<String, String> parameterMap = parseParameterMap(request.getURI().getQuery());
        String token = parameterMap.get("token");
        if (token != null) {
            // Authenticate the token
            TokenDTO tokenDTO = tokenService.resolveToken(token);
            if (tokenService.checkToken(tokenDTO)) {
                // Authenticated: store the current uid in the session attributes
                attributes.put("uid", tokenDTO.getUid());
                return true;
            }
        }
        logger.debug("handshake auth failed, url: " + request.getURI());
        // Returning false aborts the upgrade; 401 tells the client why
        response.setStatusCode(HttpStatus.UNAUTHORIZED);
        return false;
    }

    @Override
    public void afterHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler,
                               Exception exception) {
        logger.debug("**********afterHandshake************");
    }

    private Map<String, String> parseParameterMap(String queryString) {
        Map<String, String> parameterMap = new HashMap<>();
        // getQuery() returns null when the URL has no query string
        if (queryString == null || queryString.isEmpty()) {
            return parameterMap;
        }
        for (String parameter : queryString.split("&")) {
            // Limit 2 keeps any '=' inside the value (e.g. Base64 padding)
            String[] paramPair = parameter.split("=", 2);
            if (paramPair.length == 2) {
                parameterMap.put(paramPair[0], paramPair[1]);
            }
        }
        return parameterMap;
    }
}

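Note: the handshake arrives as an ordinary HTTP request, so the token is read from the URL query parameters and authenticated before the connection is upgraded.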
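Because the token travels in the query string, the client should URL-encode it, and the server should decode each name and value only after splitting on `&` and `=`. A minimal sketch with hypothetical names (`TokenQuery`, `handshakeUrl`, `parse`); note that `parse` expects the raw query from `java.net.URI.getRawQuery()`, since `URI.getQuery()` already percent-decodes and decoding twice would, for example, turn a literal `+` into a space.

```java
import java.io.UnsupportedEncodingException;
import java.net.URLDecoder;
import java.net.URLEncoder;
import java.util.HashMap;
import java.util.Map;

// Hypothetical helper: builds the handshake URL on the client and parses it on the server
final class TokenQuery {
    // Client side: append the URL-encoded token to the WebSocket endpoint
    static String handshakeUrl(String endpoint, String token) {
        return endpoint + "?token=" + encode(token);
    }

    // Server side: parse a RAW query string (URI.getRawQuery()); null or empty yields an empty map
    static Map<String, String> parse(String rawQuery) {
        Map<String, String> params = new HashMap<>();
        if (rawQuery == null || rawQuery.isEmpty()) {
            return params;
        }
        for (String pair : rawQuery.split("&")) {
            String[] kv = pair.split("=", 2);    // split first, then decode
            if (kv.length == 2) {
                params.put(decode(kv[0]), decode(kv[1]));
            }
        }
        return params;
    }

    private static String encode(String s) {
        try {
            return URLEncoder.encode(s, "UTF-8");
        } catch (UnsupportedEncodingException e) {
            throw new IllegalStateException(e);  // UTF-8 is always supported
        }
    }

    private static String decode(String s) {
        try {
            return URLDecoder.decode(s, "UTF-8");
        } catch (UnsupportedEncodingException e) {
            throw new IllegalStateException(e);
        }
    }
}
```

For typical tokens made only of letters, digits, `-`, `_` and `.`, encoding changes nothing, which is why reading `getQuery()` directly usually works in practice.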

3. WebSocket configuration class
package com.leaf.app.user.service.websocket.config;

import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.socket.config.annotation.EnableWebSocket;
import org.springframework.web.socket.config.annotation.WebSocketConfigurer;
import org.springframework.web.socket.config.annotation.WebSocketHandlerRegistry;

import com.leaf.app.user.service.websocket.HandShake;
import com.leaf.app.user.service.websocket.MyWebSocketHandler;

@EnableWebSocket
@Configuration
public class WebSocketConfig implements WebSocketConfigurer {

    @Autowired
    private MyWebSocketHandler handler;

    @Autowired
    private HandShake handShake;

    @Override
    public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) {
        // Native WebSocket endpoint
        registry.addHandler(handler, "/v1/ws").addInterceptors(handShake).setAllowedOrigins("*");
        // Same handler behind SockJS, for clients that cannot open a native WebSocket
        registry.addHandler(handler, "/v1/ws/sockjs").addInterceptors(handShake).setAllowedOrigins("*").withSockJS();
    }
}

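Note: setAllowedOrigins is mainly there to fix 403 responses to the WebSocket handshake: by default, Spring accepts only same-origin handshake requests. "*" allows every origin, so in production list the real front-end origins instead.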
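The 403 comes from an origin check performed during the handshake. Below is a simplified, hypothetical model of such a check (`OriginCheck` is not Spring's class): same-origin requests pass, and other origins pass only when they are listed or when "*" is configured. Letting requests without an `Origin` header through is a choice made by this sketch.

```java
import java.util.List;

// Hypothetical, simplified origin check; not Spring's actual implementation
final class OriginCheck {
    static boolean isAllowed(String origin, String serverOrigin, List<String> allowedOrigins) {
        if (origin == null) {
            return true;   // no Origin header (typically a non-browser client): allowed in this sketch
        }
        if (origin.equalsIgnoreCase(serverOrigin)) {
            return true;   // same-origin requests are always accepted
        }
        // Cross-origin: only "*" or an explicitly listed origin passes; otherwise the handshake gets 403
        return allowedOrigins.contains("*") || allowedOrigins.contains(origin);
    }
}
```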

4. WebSocket handler class
package com.leaf.app.user.service.websocket;

import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.PingMessage;
import org.springframework.web.socket.PongMessage;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.WebSocketMessage;
import org.springframework.web.socket.WebSocketSession;

import com.leaf.app.common.constant.VersionConstant;
import com.leaf.app.common.user.activity.UserActivityMessage;
import com.leaf.app.user.service.activity.service.ChatService;
import com.leaf.app.user.service.sync.dto.UnreadReminderDTO;
import com.leaf.app.user.service.sync.dto.UnreadReminderListDTO;
import com.leaf.app.user.service.sync.service.MessageService;
import com.leaf.app.user.service.utils.MixUtils;
import com.leaf.app.user.service.websocket.constant.MessageType;
import com.leaf.app.user.service.websocket.entity.ChatMessage;
import com.leaf.app.user.service.websocket.entity.Message;
import com.leaf.app.user.service.websocket.util.SpringContextUtil;
import com.leaf.shared.util.JSONUtils;

@Component
public class MyWebSocketHandler implements WebSocketHandler {

    @Autowired
    private SpringContextUtil springContextUtil;

    private static final Logger logger = LoggerFactory.getLogger(MyWebSocketHandler.class);

    // Users connected to THIS server: uid -> session
    public static final Map<Long, WebSocketSession> userSocketSessionMap = new ConcurrentHashMap<>();

    // Shared pool for broadcasts instead of one new thread per session (pool size is arbitrary)
    private static final ExecutorService broadcastExecutor = Executors.newFixedThreadPool(4);

    public static Map<Long, WebSocketSession> getUsersocketsessionmap() {
        return userSocketSessionMap;
    }

    /**
     * Called after the connection is established.
     */
    @Override
    public void afterConnectionEstablished(WebSocketSession session) throws Exception {
        Long uid = (Long) session.getAttributes().get("uid");
        logger.debug("user id: " + uid + " established websocket session.");
        // Always register the newest session, so a reconnect replaces a stale one
        userSocketSessionMap.put(uid, session);

        // Push unread messages right after the connection is established
        MessageService messageService = springContextUtil.getBean(MessageService.class);
        UnreadReminderListDTO unreadReminderListDTO = messageService.getUserUnreadReminders(uid);
        for (UnreadReminderDTO unreadReminderDTO : unreadReminderListDTO.getData()) {
            Message message = new Message();
            message.setMessageId(unreadReminderDTO.getId());
            message.setData(JSONUtils.toJSONString(unreadReminderDTO));
            message.setTimestamp(unreadReminderDTO.getTimestamp());
            message.setType(MessageType.UNREAD_REMINDER.getCode() + "");
            message.setVersion(VersionConstant.V1);
            send(session, new TextMessage(JSONUtils.toJSONString(message)));
        }
    }

    /**
     * Message handling: messages the client sends through the WebSocket API arrive here.
     */
    @Override
    public void handleMessage(WebSocketSession session, WebSocketMessage<?> message) throws Exception {
        Long uid = (Long) session.getAttributes().get("uid");
        if (message instanceof PingMessage) {
            // Answer a ping frame with a pong carrying the same payload, as RFC 6455 requires.
            // With a JSR-356 container, protocol-level pings are usually answered by the container itself.
            logger.debug("user :" + uid + " is keep alive");
            send(session, new PongMessage(((PingMessage) message).getPayload()));
            return;
        }
        String clientMessage = message.getPayload().toString();
        if (clientMessage.isEmpty()) {
            return;
        }
        Message bizMessage = JSONUtils.parseObject(clientMessage, Message.class);
        logger.debug("receive client data:" + clientMessage + ", message type:"
                + MessageType.getCHName(Integer.parseInt(bizMessage.getType())));

        // TODO: handle business logic according to the message type
        if (bizMessage.getType().equals(MessageType.CLIENT_ACK.getCode() + "")) {
            // Client ACK: the reminder was delivered and no longer needs to be re-pushed
            MessageService messageService = springContextUtil.getBean(MessageService.class);
            messageService.clearUnreadReminder(bizMessage.getMessageId());
        }
        if (bizMessage.getType().equals(MessageType.CHAT_MESSAGE.getCode() + "")) {
            ChatMessage chatMessage = JSONUtils.parseObject(bizMessage.getData(), ChatMessage.class);
            ChatService chatService = springContextUtil.getBean(ChatService.class);
            UserActivityMessage userActivityMessage = new UserActivityMessage();
            userActivityMessage.setFromUserId(Long.parseLong(chatMessage.getFromUserId()));
            userActivityMessage.setToUserId(Long.parseLong(chatMessage.getToUserId()));
            userActivityMessage.setOpType(10);
            chatService.sendChatAPNSMessage(userActivityMessage);
            // Reply with a server-side ACK
            Message ackMessage = new Message();
            ackMessage.setMessageId(bizMessage.getMessageId());
            ackMessage.setType(MessageType.SERVER_ACK.getCode() + "");
            Map<String, String> dataMap = new HashMap<>();
            dataMap.put("originType", bizMessage.getType());
            ackMessage.setData(JSONUtils.toJSONString(dataMap));
            send(session, new TextMessage(JSONUtils.toJSONString(ackMessage)));
        }
    }

    /**
     * Transport error handling.
     */
    @Override
    public void handleTransportError(WebSocketSession session, Throwable exception) throws Exception {
        logger.debug("websocket transport error, uid: " + session.getAttributes().get("uid"), exception);
        if (session.isOpen()) {
            session.close();
        }
        removeSession(session);
    }

    /**
     * Called after the connection is closed.
     */
    @Override
    public void afterConnectionClosed(WebSocketSession session, CloseStatus closeStatus) throws Exception {
        removeSession(session);
    }

    // Remove the map entry only if it still points at this session (a newer one may have replaced it)
    private void removeSession(WebSocketSession session) {
        Long uid = (Long) session.getAttributes().get("uid");
        if (uid != null && userSocketSessionMap.remove(uid, session)) {
            logger.debug("user : " + uid + " has closed websocket!");
        }
    }

    @Override
    public boolean supportsPartialMessages() {
        return false;
    }

    /**
     * Send a message to every user online on this server.
     *
     * @param message
     * @throws IOException
     */
    public void broadcast(final TextMessage message) throws IOException {
        for (final WebSocketSession session : userSocketSessionMap.values()) {
            if (!session.isOpen()) {
                continue;
            }
            // Send in parallel on the shared pool
            broadcastExecutor.execute(() -> {
                try {
                    if (session.isOpen()) {
                        send(session, message);
                    }
                } catch (IOException e) {
                    logger.error("broadcast to session " + session.getId() + " failed", e);
                }
            });
        }
    }

    /**
     * Send a message to one user, if that user's session is on this server.
     *
     * @param uid
     * @param message
     * @throws IOException
     */
    public void sendMessageToUser(Long uid, Message message) throws IOException {
        WebSocketSession session = userSocketSessionMap.get(uid);
        String hostName = MixUtils.getServerInfo().getHostName();
        if (session != null && session.isOpen()) {
            logger.debug("Found user :" + uid + " websocket session, server info:" + hostName + ", message:" + message);
            send(session, new TextMessage(JSONUtils.toJSONString(message)));
        } else {
            logger.debug("Not found user :" + uid + " websocket session, server info:" + hostName + ", message:" + message);
        }
    }

    // A WebSocketSession must not be written to by several threads at once (pushes, ACKs and
    // broadcasts can overlap), so sends are serialized per session
    private static void send(WebSocketSession session, WebSocketMessage<?> message) throws IOException {
        synchronized (session) {
            session.sendMessage(message);
        }
    }
}
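The uid-to-session map has a subtle race: a user who reconnects may open the new session before the old session's close callback runs. One way to keep the map consistent is to always register the newest session and to remove an entry only if it still points at the closing session, which `ConcurrentHashMap` supports atomically through `remove(key, value)`. The framework-free sketch below uses hypothetical names and plain strings in place of `WebSocketSession`.

```java
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

// Hypothetical model of the uid -> session registry; session ids stand in for WebSocketSession
final class SessionRegistry {
    private final Map<Long, String> sessions = new ConcurrentHashMap<>();

    // On connect: the newest connection wins, replacing any stale entry
    void register(long uid, String sessionId) {
        sessions.put(uid, sessionId);
    }

    // On close or error: atomically remove the entry only if it still maps to this session
    void unregister(long uid, String sessionId) {
        sessions.remove(uid, sessionId);
    }

    String lookup(long uid) {
        return sessions.get(uid);
    }
}
```

A late close of the old session then leaves the new session registered, so pushes keep reaching the live connection.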
SpringContextUtil (a separate helper class):
package com.leaf.app.user.service.websocket.util;

import org.springframework.beans.BeansException;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextAware;
import org.springframework.stereotype.Component;

/**
 * Gives access to the Spring container so that other beans can be looked up on demand.
 */
@Component
public class SpringContextUtil implements ApplicationContextAware {

    // Spring application context, set by the ApplicationContextAware callback below
    private ApplicationContext applicationContext;

    /**
     * ApplicationContextAware callback: stores the application context.
     *
     * @param applicationContext
     */
    @Override
    public void setApplicationContext(ApplicationContext applicationContext) throws BeansException {
        this.applicationContext = applicationContext;
    }

    /**
     * @return ApplicationContext
     */
    public ApplicationContext getApplicationContext() {
        return this.applicationContext;
    }

    /**
     * Look up a bean by name.
     *
     * @param name bean name
     * @return Object the bean instance registered under that name
     * @throws BeansException if no such bean exists
     */
    public Object getBean(String name) throws BeansException {
        return applicationContext.getBean(name);
    }

    /**
     * Look up a bean by type.
     */
    public <T> T getBean(Class<T> clazz) {
        return applicationContext.getBean(clazz);
    }
}

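Note: SpringContextUtil looks beans up at call time instead of injecting them, mainly to work around circular dependencies between the handler and the services it uses.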

5. Consuming push messages
// Copyright 2017 www.chinaleaf.net All rights reserved.
package com.leaf.app.user.service.websocket.consumer;

import java.io.IOException;
import java.util.Objects;

import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Component;

import com.leaf.app.common.constant.VersionConstant;
import com.leaf.app.common.mq.MQReceiveMessage;
import com.leaf.app.common.mq.consumer.AbstractONSConsumer;
import com.leaf.app.common.mq.consumer.ConsumerContext;
import com.leaf.app.common.user.activity.UserActivityMessage;
import com.leaf.app.user.service.sync.dto.UnreadReminderDTO;
import com.leaf.app.user.service.sync.service.MessageService;
import com.leaf.app.user.service.websocket.MyWebSocketHandler;
import com.leaf.app.user.service.websocket.constant.MessageType;
import com.leaf.app.user.service.websocket.entity.Message;
import com.leaf.shared.util.JSONUtils;

/**
 * @author xiao.xianming
 * 2017-11-14 16:11:55
 */
@Component
public class WebSocketMessagePushConsumer extends AbstractONSConsumer {

    @Value("${ons.consumer.cid.websocket.push}")
    private String consumerId;

    @Value("${ons.consumer.subscribe.topic.websocket.push}")
    private String consumeTopic;

    @Value("${ons.consumer.subscribe.topic.tags.websocket.push}")
    private String consumeTopicTags;

    @Autowired
    private MyWebSocketHandler webSocketHandler;

    @Autowired
    private MessageService messageService;

    @Override
    protected boolean consumeMessage(MQReceiveMessage msg) {
        UserActivityMessage userActivityMessage = JSONUtils.parseObject(msg.getMessageBody(),
                UserActivityMessage.class);
        // Do not push the user's own activity (e.g. liking their own post).
        // Objects.equals compares by value even if the IDs are boxed Longs, where == compares references.
        if (Objects.equals(userActivityMessage.getFromUserId(), userActivityMessage.getToUserId())) {
            return true;
        }
        // Build and send the WebSocket push message
        UnreadReminderDTO unreadReminderDTO = messageService.convertUserActivityMessage2UnreadReminder(userActivityMessage);
        if (unreadReminderDTO != null) {
            Message message = new Message();
            message.setMessageId(unreadReminderDTO.getId());
            message.setType(MessageType.UNREAD_REMINDER.getCode() + "");
            message.setTimestamp(unreadReminderDTO.getTimestamp());
            message.setVersion(VersionConstant.V1);
            message.setData(JSONUtils.toJSONString(unreadReminderDTO));
            try {
                // Sends only if the target user's session is on this server; otherwise just logs
                this.webSocketHandler.sendMessageToUser(Long.valueOf(userActivityMessage.getToUserId()), message);
            } catch (IOException e) {
                logger.error("websocket push failed, toUserId: " + userActivityMessage.getToUserId(), e);
            }
        }
        return true;
    }

    @Override
    protected void initConsumerContext() {
        // Broadcast consumption: every app server receives every message,
        // because the target user's session may live on any one of them
        boolean flBroadcastConsume = true;
        this.consumeContext = new ConsumerContext(consumerId, consumeTopic, consumeTopicTags, flBroadcastConsume);
    }

}
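Because `flBroadcastConsume` is true, every application server consumes every push message, and only the server holding the target user's session actually sends it. The simulation below models that routing without a real message queue; `ClusterPushDemo` and `Node` are hypothetical names.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical simulation of broadcast consumption across app servers (no real MQ involved)
final class ClusterPushDemo {
    static final class Node {
        final String name;
        final Map<Long, String> localSessions = new HashMap<>(); // uid -> session id on this node
        final List<String> pushed = new ArrayList<>();           // what this node actually sent

        Node(String name) {
            this.name = name;
        }

        // Every node sees every message; it pushes only if the user is connected here
        boolean consume(long toUserId, String payload) {
            String sessionId = localSessions.get(toUserId);
            if (sessionId == null) {
                return false;                                    // not connected here: ignore
            }
            pushed.add(sessionId + ":" + payload);
            return true;
        }
    }

    // Broadcast mode: deliver the message to all nodes; returns how many actually pushed it
    static int broadcast(List<Node> nodes, long toUserId, String payload) {
        int pushes = 0;
        for (Node node : nodes) {
            if (node.consume(toUserId, payload)) {
                pushes++;
            }
        }
        return pushes;
    }
}
```

With one connection per user, exactly one node pushes; if the user is offline, nobody does and the reminder stays unread until the next connection.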

6. Nginx configuration

#user  nobody;
worker_processes  1;

#error_log  logs/error.log;
#error_log  logs/error.log  notice;
#error_log  logs/error.log  info;

#pid        logs/nginx.pid;


events {
    worker_connections  1024;
}


http {
    include       mime.types;
    default_type  application/octet-stream;

    #log_format  main  '$remote_addr - $remote_user [$time_local] "$request" '
    #                  '$status $body_bytes_sent "$http_referer" '
    #                  '"$http_user_agent" "$http_x_forwarded_for"';

    #access_log  logs/access.log  main;

    sendfile        on;
    #tcp_nopush     on;

    #keepalive_timeout  0;
    keepalive_timeout  65;

    upstream app_servers {
        server 116.62.206.5:8083;
        server 118.31.15.121:8083;
    }
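    upstream websocket_servers {
        # If SockJS HTTP fallback transports are used, enable ip_hash (or another sticky
        # method) so that all requests of one SockJS session reach the same server
        server 116.62.206.5:5555;
        server 118.31.15.121:5555;
    }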
    map $http_upgrade $connection_upgrade {
        default upgrade;
        '' close;
    }
    #gzip  on;
    server {
        listen       80;

        listen              443 ssl;
        server_name  www.chinaleaf.net;
        ssl_certificate     /usr/local/nginx/ssl/s.crt;
        ssl_certificate_key /usr/local/nginx/ssl/nginx.key;

        #access_log  logs/host.access.log  main;

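        location / {
            root   html;
            index  index.html index.htm;
        }
        location /v1 {
            proxy_pass http://app_servers;
            #proxy_http_version 1.1;
            #proxy_set_header Upgrade $http_upgrade;
            #proxy_set_header Connection $connection_upgrade;
        }
        # The longest matching prefix wins, so /v1/ws (and /v1/ws/sockjs) goes to the WebSocket servers
        location /v1/ws {
            proxy_pass http://websocket_servers;
            # Forward the upgrade handshake: HTTP/1.1 plus the Upgrade and Connection headers
            proxy_http_version 1.1;
            proxy_set_header Upgrade $http_upgrade;
            proxy_set_header Connection $connection_upgrade;
            # Idle proxied connections are closed after proxy_read_timeout (60s by default);
            # the client's 20s ping keeps the connection inside that window
        }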
        #error_page  404              /404.html;

        # redirect server error pages to the static page /50x.html
        #
        error_page   500 502 503 504  /50x.html;
        location = /50x.html {
            root   html;
        }

    }


}

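Appendix: notes on front-end/back-end interaction
1. Message acknowledgement and deduplication: the client must reply with an ACK to every message the server pushes. Any message that has not been ACKed is pushed again each time a WebSocket connection is established, and the client deduplicates messages by messageId.
2. Reconnection and keep-alive: the client sends a ping frame every 20 s and the server replies with a pong frame; based on the result, the client decides whether to reconnect (it also reconnects whenever the network comes back after an outage).
3. WebSocket authentication: during the handshake, the server authenticates the token carried by the WebSocket request.
4. WebSocket clustering: Nginx is the single entry point that proxies client connections to the cluster.
5. Server-initiated push (likes, follows and similar events): events are sent through ONS to every application server, and each server consumes the message in broadcast mode. If the user's WebSocketSession is on that server, it pushes the message; otherwise it does nothing.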
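Item 1 means the client will sometimes receive the same message twice, for example when its ACK never reached the server and the message is re-pushed on reconnect. A minimal client-side sketch, assuming messageId is a string (its real type is not shown in this article) and using hypothetical names: every message is ACKed, duplicates included, so the server stops re-pushing it, but each messageId is delivered to the app only once. A real client would also bound or persist the set of seen IDs.

```java
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

// Hypothetical client-side receiver: always ACK, deliver each messageId at most once
final class DedupingReceiver {
    private final Set<String> seenIds = new HashSet<>();
    private final List<String> delivered = new ArrayList<>();
    private final List<String> acksSent = new ArrayList<>();

    void onMessage(String messageId, String data) {
        acksSent.add(messageId);          // ACK duplicates too, in case the earlier ACK was lost
        if (seenIds.add(messageId)) {     // add() returns false when the id was already seen
            delivered.add(data);
        }
    }

    List<String> delivered() {
        return delivered;
    }

    List<String> acksSent() {
        return acksSent;
    }
}
```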
