springboot搭建websocket环境

1.pom文件依赖

<!-- netty: version is expected to come from the Spring Boot parent/BOM dependency management -->
<dependency>
   <groupId>io.netty</groupId>
   <artifactId>netty-codec-http</artifactId>
</dependency>
<!-- JSON utility library (fastjson).
     1.2.50 is affected by the publicly disclosed autoType deserialization vulnerabilities;
     1.2.83 is the patched 1.x release. -->
<dependency>
   <groupId>com.alibaba</groupId>
   <artifactId>fastjson</artifactId>
   <version>1.2.83</version>
</dependency>

2.工具类CommonUtil.java

package com.ruoyi.common.utils;

import org.apache.commons.lang3.time.DateFormatUtils;

import javax.servlet.http.HttpServletRequest;
import java.security.SecureRandom;
import java.text.ParseException;
import java.text.SimpleDateFormat;
import java.time.LocalDate;
import java.time.LocalDateTime;
import java.time.ZoneId;
import java.time.temporal.ChronoUnit;
import java.util.*;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

public class CommonUtil {

    /** Alphabet for {@link #getRandStr}: 10 digits first, then 26 lowercase and 26 uppercase letters. */
    private static final String RAND_CHARS = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ";

    /** Alphabet for {@link #getRandUpperStr}: 10 digits first, then 26 uppercase letters. */
    private static final String RAND_UPPER_CHARS = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ";

    /** Alphabet for {@link #getRandomStringByLength}. */
    private static final String LOWER_ALNUM_CHARS = "abcdefghijklmnopqrstuvwxyz0123456789";

    /** Prefix letters used by {@link #takeCode}, in order. */
    private static final String TAKE_CODE_LETTERS = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";

    private static final Random RANDOM = new Random();

    /** Used where the value must not be predictable (verification codes). */
    private static final SecureRandom SECURE_RANDOM = new SecureRandom();

    /**
     * Email format.
     * <ul>
     *   <li>Local part: letters/digits, optionally joined by single '-' or '.' characters,
     *       never starting or ending with one.</li>
     *   <li>Domain: one or more labels ending in '.', each letters/digits with at most one '-'-joined segment.</li>
     *   <li>Top-level domain: two or more letters.</li>
     * </ul>
     * The local part is written as {@code X([-.]?X)*} instead of the old {@code (X+[-|.]?)+X}. The old
     * form had nested quantifiers that backtrack exponentially on long non-matching input. It also
     * required at least two characters and accidentally allowed '|'. The old TLD quantifier
     * {@code {ueditor,}} was not valid regex and made every call throw PatternSyntaxException.
     */
    private static final Pattern EMAIL_PATTERN = Pattern.compile(
        "^[a-z0-9A-Z]([-.]?[a-z0-9A-Z])*@([a-z0-9A-Z]+(-[a-z0-9A-Z]+)?\\.)+[a-zA-Z]{2,}$");

    /**
     * Mobile number: 11 digits, starting with '1' followed by 3-9.
     * The old extra alternatives (15[^4,\D], 18[0,5-9]) were already covered by 1[3-9][0-9]. The only
     * exception was the literal ',' inside [0,5-9], which wrongly accepted strings such as "18,12345678".
     */
    private static final Pattern MOBILE_PATTERN = Pattern.compile("^1[3-9]\\d{9}$");

    /** Request headers that may carry the client address behind a proxy, checked in this order. */
    private static final String[] CLIENT_IP_HEADERS = {
        "x-forwarded-for", "Proxy-Client-IP", "WL-Proxy-Client-IP", "HTTP_CLIENT_IP", "HTTP_X_FORWARDED_FOR"
    };

    private static final AtomicInteger orderIdCount = new AtomicInteger();

    /**
     * Tests whether a value counts as "empty".
     *
     * @param obj any value
     * @return {@code true} for {@code null}; for "", "null", "NULL", "Null", "undefined", "UNDEFINED";
     *         for {@code Boolean.FALSE}; and for an empty Collection, Map or Object[].
     *         Returns {@code false} otherwise.
     */
    public static boolean isEmpty(Object obj) {
        if (obj == null) {
            return true;
        }
        if (obj instanceof String) {
            String s = (String) obj;
            return s.isEmpty()
                || "null".equals(s) || "NULL".equals(s) || "Null".equals(s)
                || "undefined".equals(s) || "UNDEFINED".equals(s);
        }
        if (obj instanceof Boolean) {
            return !((Boolean) obj);
        }
        if (obj instanceof Collection) {
            return ((Collection<?>) obj).isEmpty();
        }
        if (obj instanceof Map) {
            return ((Map<?, ?>) obj).isEmpty();
        }
        if (obj instanceof Object[]) {
            return ((Object[]) obj).length == 0;
        }
        return false;
    }

    /**
     * Generates a 6-digit SMS verification code in the range 100000-999999.
     * Uses SecureRandom because the code is a secret; Math.random() is predictable.
     *
     * @return the code as a decimal string
     */
    public static String getMsgCode() {
        return Integer.toString(SECURE_RANDOM.nextInt(900000) + 100000);
    }

    /**
     * Builds a random string.
     *
     * @param length    number of characters (a negative length yields "")
     * @param isOnlyNum {@code true} to use digits only; otherwise digits plus lower/upper-case letters
     * @return the random string
     */
    public static String getRandStr(int length, boolean isOnlyNum) {
        return randomString(RAND_CHARS, isOnlyNum ? 10 : RAND_CHARS.length(), length);
    }

    /**
     * Builds a random string of digits and upper-case letters.
     *
     * @param length    number of characters (a negative length yields "")
     * @param isOnlyNum {@code true} to use digits only
     * @return the random string
     */
    public static String getRandUpperStr(int length, boolean isOnlyNum) {
        return randomString(RAND_UPPER_CHARS, isOnlyNum ? 10 : RAND_UPPER_CHARS.length(), length);
    }

    /** Draws {@code length} characters uniformly from the first {@code bound} characters of {@code alphabet}. */
    private static String randomString(String alphabet, int bound, int length) {
        StringBuilder sb = new StringBuilder(Math.max(length, 0));
        for (int i = 0; i < length; i++) {
            sb.append(alphabet.charAt(RANDOM.nextInt(bound)));
        }
        return sb.toString();
    }

    /**
     * Joins the elements' {@code String.valueOf} forms with a single-character separator.
     *
     * @param list      elements to join (must not be null)
     * @param separator separator placed between elements
     * @return the joined string; "" for an empty list
     */
    public static String listToString(List<?> list, char separator) {
        return listToString(list, String.valueOf(separator));
    }

    /**
     * Generates a 16-digit fee number: the machine id followed by a zero-padded 15-digit UUID hash.
     *
     * @return the fee number
     */
    public static String getFeeNoByUUId() {
        int machineId = 1; // cluster node id prefix; a single digit (1-9) keeps the length at 16
        // %015d: left-pad the non-negative hash with zeros to 15 digits
        return machineId + String.format("%015d", positiveUuidHash());
    }

    /**
     * Returns |hashCode| of a random UUID string.
     * The value is widened to long first, because negating Integer.MIN_VALUE as an int overflows
     * back to a negative number.
     */
    private static long positiveUuidHash() {
        return Math.abs((long) UUID.randomUUID().toString().hashCode());
    }

    /**
     * Whole hours elapsed from {@code time} until now, truncated toward zero.
     *
     * @param time the earlier instant
     * @return hours between {@code time} and now (negative if {@code time} is in the future)
     */
    public static Long getHourByTimeAndNow(Date time) {
        return (System.currentTimeMillis() - time.getTime()) / (1000 * 60 * 60);
    }

    /**
     * Whole hours from {@code startTime} to {@code endTime}, truncated toward zero.
     *
     * @param startTime start instant
     * @param endTime   end instant
     * @return hours between the two (negative if {@code endTime} is earlier)
     */
    public static Long getHourByTime(Date startTime, Date endTime) {
        return (endTime.getTime() - startTime.getTime()) / (1000 * 60 * 60);
    }

    /**
     * Checks whether a time such as "14:33:00" lies strictly inside a range such as "09:30:00 - 12:00:00".
     *
     * @param str1  the HH:mm:ss time to test
     * @param round the range as "HH:mm:ss - HH:mm:ss"
     * @return {@code true} if start &lt; str1 &lt; end
     * @throws IllegalArgumentException if {@code round} has no " - " separator or a time cannot be parsed
     */
    public static boolean timeIsInRound(String str1, String round) {
        String[] roundTime = round.split(" - ");
        if (roundTime.length < 2) {
            throw new IllegalArgumentException("Expected a range like 'HH:mm:ss - HH:mm:ss' but got: " + round);
        }
        return timeIsInRound(str1, roundTime[0], roundTime[1]);
    }

    /**
     * Checks whether an HH:mm:ss time lies strictly between two HH:mm:ss times,
     * e.g. whether 14:33:00 is between 09:30:00 and 12:00:00.
     *
     * @throws IllegalArgumentException if any argument is not parseable as HH:mm:ss
     *         (previously the parse error was printed and a NullPointerException followed)
     */
    public static boolean timeIsInRound(String str1, String start, String end) {
        SimpleDateFormat df = new SimpleDateFormat("HH:mm:ss");
        try {
            return belongCalendar(df.parse(str1), df.parse(start), df.parse(end));
        } catch (ParseException e) {
            throw new IllegalArgumentException(
                "Expected HH:mm:ss times but got: " + str1 + ", " + start + ", " + end, e);
        }
    }

    /**
     * Checks whether {@code nowTime} lies strictly between {@code beginTime} and {@code endTime}.
     * Both bounds are exclusive.
     */
    public static boolean belongCalendar(Date nowTime, Date beginTime, Date endTime) {
        return nowTime.after(beginTime) && nowTime.before(endTime);
    }

    /**
     * Formats the span between two instants as "X天Y小时Z分" (days, hours, minutes).
     *
     * @param startTime start instant
     * @param endTime   end instant
     * @return the formatted duration; remaining seconds are dropped
     */
    public static String getDayHourMin(Date startTime, Date endTime) {
        long totalMillis = endTime.getTime() - startTime.getTime();
        long day = totalMillis / (24 * 60 * 60 * 1000);
        long hour = totalMillis / (60 * 60 * 1000) - day * 24;
        long min = totalMillis / (60 * 1000) - day * 24 * 60 - hour * 60;
        return day + "天" + hour + "小时" + min + "分";
    }

    /**
     * Whole minutes elapsed from {@code time} until now, truncated toward zero.
     *
     * @param time the earlier instant
     * @return minutes between {@code time} and now
     */
    public static Long getMinByTimeAndNow(Date time) {
        return (System.currentTimeMillis() - time.getTime()) / (1000 * 60);
    }

    /**
     * Validates an email address against {@link #EMAIL_PATTERN}.
     *
     * @param email string to check
     * @return {@code true} if it is a well-formed email; {@code false} otherwise, including for "empty" values
     */
    public static boolean isEmail(String email) {
        return match(EMAIL_PATTERN, email);
    }

    /**
     * Validates a mobile number against {@link #MOBILE_PATTERN}.
     *
     * @param mobiles string to check
     * @return {@code true} if it is an 11-digit number starting 13-19; {@code false} otherwise
     */
    public static boolean isMobile(String mobiles) {
        return match(MOBILE_PATTERN, mobiles);
    }

    /**
     * Matches the whole of {@code str} against a precompiled pattern.
     * Values that {@link #isEmpty} treats as empty never match.
     */
    private static boolean match(Pattern pattern, String str) {
        if (isEmpty(str)) {
            return false;
        }
        return pattern.matcher(str).matches();
    }

    /**
     * Best-effort client IP for a request.
     * <p>
     * Checks the proxy headers in {@link #CLIENT_IP_HEADERS} and falls back to the socket address. For a
     * comma-separated forwarding chain, the first (trimmed) entry is used. The IPv6 loopback is reported
     * as "本地" (local).
     * <p>
     * NOTE(review): these headers are client-controlled unless a trusted proxy overwrites them. Do not use
     * the result for access control.
     *
     * @param request the HTTP request
     * @return the client address, or {@code null} if none is known
     */
    public static String getIP4(HttpServletRequest request) {
        String ip = null;
        for (String header : CLIENT_IP_HEADERS) {
            ip = request.getHeader(header);
            if (!isUnknownIp(ip)) {
                break;
            }
        }
        if (isUnknownIp(ip)) {
            ip = request.getRemoteAddr();
        }
        if (ip == null) {
            return null;
        }
        // Split first, so a loopback first entry in a forwarded list is recognised too.
        int comma = ip.indexOf(',');
        if (comma >= 0) {
            ip = ip.substring(0, comma).trim();
        }
        if ("0:0:0:0:0:0:0:1".equals(ip)) {
            ip = "本地";
        }
        return ip;
    }

    /** Treats a missing, empty, or "unknown" header value as absent. */
    private static boolean isUnknownIp(String ip) {
        return ip == null || ip.isEmpty() || "unknown".equalsIgnoreCase(ip);
    }

    /**
     * Joins the elements' {@code String.valueOf} forms with {@code exp} between them.
     * Iterates once, so it is linear for any List implementation (including LinkedList).
     *
     * @param list elements to join (must not be null)
     * @param exp  separator
     * @return the joined string; "" for an empty list
     */
    public static String listToString(List<?> list, String exp) {
        StringBuilder sb = new StringBuilder();
        boolean first = true;
        for (Object item : list) {
            if (!first) {
                sb.append(exp);
            }
            sb.append(item);
            first = false;
        }
        return sb.toString();
    }

    /**
     * Order id: the current epoch milliseconds followed by a random 6-digit number (100000-999999).
     * Uniqueness is probabilistic, not guaranteed.
     */
    public static String createOrderID() {
        int number = ThreadLocalRandom.current().nextInt(100000, 1000000);
        return System.currentTimeMillis() + String.valueOf(number);
    }

    /**
     * Appends a random 4-digit number (1000-9999) to {@code code}.
     */
    public static String createCode(String code) {
        int number = ThreadLocalRandom.current().nextInt(1000, 10000);
        return code + String.valueOf(number);
    }

    /** Parses {@code obj.toString()} as an Integer; {@code null} for null or "". */
    public final static Integer psInt(Object obj) {
        if (obj == null || "".equals(obj.toString())) {
            return null;
        }
        return Integer.parseInt(obj.toString());
    }

    /** Returns {@code obj.toString()}, or {@code null} for null. */
    public final static String psString(Object obj) {
        return obj == null ? null : obj.toString();
    }

    /** Parses {@code obj.toString()} as a Long; {@code null} for null or "". */
    public final static Long psLong(Object obj) {
        if (obj == null || "".equals(obj.toString())) {
            return null;
        }
        return Long.parseLong(obj.toString());
    }

    /** Parses {@code obj.toString()} as a Float; {@code null} for null or "". */
    public final static Float psFloat(Object obj) {
        if (obj == null || "".equals(obj.toString())) {
            return null;
        }
        return Float.parseFloat(obj.toString());
    }

    /** Parses {@code obj.toString()} as a Short; {@code null} for null or "". */
    public final static Short psShort(Object obj) {
        if (obj == null || "".equals(obj.toString())) {
            return null;
        }
        return Short.parseShort(obj.toString());
    }

    /**
     * Random string of lower-case letters and digits.
     *
     * @param length number of characters
     * @return the random string ("" for a non-positive length)
     */
    public static String getRandomStringByLength(int length) {
        return randomString(LOWER_ALNUM_CHARS, LOWER_ALNUM_CHARS.length(), length);
    }

    /**
     * Computes the next take-number after {@code no} (format: one letter plus two digits, e.g. "A07").
     * <ul>
     *   <li>Below 99 the number is incremented ("A07" → "A08").</li>
     *   <li>At 99 the letter advances ("A99" → "B01").</li>
     *   <li>"Z99", or an unknown prefix, restarts at "A01". "Z99" previously threw
     *       ArrayIndexOutOfBoundsException.</li>
     * </ul>
     *
     * @param no the current number; an "empty" value (see {@link #isEmpty}) yields "A01"
     * @return the next number
     * @throws NumberFormatException if the part after the first character is not an integer
     */
    public static String takeCode(String no) {
        if (isEmpty(no)) {
            return "A01";
        }
        String letter = no.substring(0, 1);
        int num = Integer.parseInt(no.substring(1));
        if (num < 99) {
            return letter + String.format("%02d", num + 1);
        }
        int index = TAKE_CODE_LETTERS.indexOf(letter);
        if (index < 0 || index == TAKE_CODE_LETTERS.length() - 1) {
            return "A01";
        }
        return TAKE_CODE_LETTERS.charAt(index + 1) + "01";
    }

    /** Formats a date as "yyyy-MM-dd HH:mm:ss" in the default time zone. */
    public static String dateToStr(Date datetime) {
        return new SimpleDateFormat("yyyy-MM-dd HH:mm:ss").format(datetime);
    }

    /** Formats a date as "yyyyMMddHHmmss" in the default time zone. */
    public static String dateToCode(Date datetime) {
        return new SimpleDateFormat("yyyyMMddHHmmss").format(datetime);
    }

    /** Parses a "yyyy-MM-dd HH:mm:ss" string in the default time zone. */
    public static Date strToDate(String strDate) throws Exception {
        return new SimpleDateFormat("yyyy-MM-dd HH:mm:ss").parse(strDate);
    }

    /**
     * Generates a 16-digit numeric code: a random leading digit 1-9 followed by a zero-padded
     * 15-digit UUID hash.
     */
    public static String getSixteenByUUId() {
        int leadingDigit = RANDOM.nextInt(9) + 1;
        return leadingDigit + String.format("%015d", positiveUuidHash());
    }

    /**
     * Merchant order id: {@code msgId} + current "yyyyMMddHHmmss" + a 4-digit sequence suffix (1000-1999).
     * floorMod keeps the suffix at four digits even after the counter overflows to negative values.
     */
    public static String genMerOrderId(String msgId) {
        String date = DateFormatUtils.format(new Date(), "yyyyMMddHHmmss");
        int suffix = Math.floorMod(orderIdCount.incrementAndGet(), 1000) + 1000;
        return msgId + date + suffix;
    }

    /**
     * Checks that {@code startTime} is not after {@code endTime}.
     *
     * @return {@code false} if either argument is null or start &gt; end; {@code true} otherwise
     */
    public static boolean compareTime(Date startTime, Date endTime) {
        if (startTime == null || endTime == null) {
            return false;
        }
        return startTime.getTime() <= endTime.getTime();
    }

    /**
     * Returns the date {@code month} months after {@code startDate}, computed in the system default zone.
     */
    public static Date getMonthDate(Date startDate, Long month) {
        return toDate(toLocalDateTime(startDate).plusMonths(month));
    }

    /**
     * Returns the date {@code day} days after {@code startDate}, computed in the system default zone.
     */
    public static Date getDayDate(Date startDate, Long day) {
        return toDate(toLocalDateTime(startDate).plusDays(day));
    }

    /**
     * Returns the time {@code minute} minutes from now, computed in the system default zone.
     */
    public static Date getMinuteDate(Long minute) {
        return toDate(toLocalDateTime(new Date()).plusMinutes(minute));
    }

    /** Converts a Date to a LocalDateTime in the system default zone. */
    private static LocalDateTime toLocalDateTime(Date date) {
        return date.toInstant().atZone(ZoneId.systemDefault()).toLocalDateTime();
    }

    /** Converts a LocalDateTime in the system default zone back to a Date. */
    private static Date toDate(LocalDateTime localDateTime) {
        return Date.from(localDateTime.atZone(ZoneId.systemDefault()).toInstant());
    }

    /**
     * Whole 24-hour periods in {@code beginDate - endDate}, truncated toward zero.
     * <p>
     * NOTE(review): the result is negative when {@code beginDate} is earlier than {@code endDate}. That is
     * the opposite sign convention to {@link #daysBetween}; confirm what callers expect.
     */
    public static Long getDaysByTwoTime(Date beginDate, Date endDate) {
        return (beginDate.getTime() - endDate.getTime()) / (1000 * 60 * 60 * 24);
    }

    /**
     * Number of calendar days from {@code smdate} to {@code bdate}, ignoring time of day, in the system
     * default time zone.
     * <p>
     * Returns {@code bdate - smdate}, which is positive when {@code bdate} is the later date, as documented.
     * The previous implementation subtracted the other way and so returned the negated value. It also
     * divided a millisecond difference by 24h, which is off by one across DST transitions.
     *
     * @param smdate the earlier date
     * @param bdate  the later date
     * @return the day difference
     * @throws ParseException never; kept in the signature so existing callers' catch blocks still compile
     * @throws ArithmeticException if the difference does not fit in an int
     */
    public static int daysBetween(Date smdate, Date bdate) throws ParseException {
        ZoneId zone = ZoneId.systemDefault();
        LocalDate start = smdate.toInstant().atZone(zone).toLocalDate();
        LocalDate end = bdate.toInstant().atZone(zone).toLocalDate();
        return Math.toIntExact(ChronoUnit.DAYS.between(start, end));
    }

    public static void main(String[] args) throws ParseException {
        SimpleDateFormat format = new SimpleDateFormat("yyyy-MM-dd");
        System.out.println(CommonUtil.daysBetween(format.parse("2022-02-25"), new Date()));
    }

}

3.工具类SpringContextUtil.java

package com.ruoyi.common.utils;

import org.springframework.beans.BeansException;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextAware;
import org.springframework.stereotype.Component;

@Component
public class SpringContextUtil implements ApplicationContextAware {

// Spring应用上下文环境

	private static ApplicationContext applicationContext;

	/** * 实现ApplicationContextAware接口的回调方法。设置上下文环境

	 *

	 * @param applicationContext

	 */

	public void setApplicationContext(ApplicationContext applicationContext) {

		SpringContextUtil.applicationContext = applicationContext;

	}

	/** *

	 @return ApplicationContext

	 */

	public static ApplicationContext getApplicationContext() {

		return applicationContext;

	}

	/** * 通过name获取对象

	 *  @param name

	 * @return Object

	 * @throws BeansException

	 */

	public static Object getBeanObj(String name) throws BeansException {

		return applicationContext.getBean(name);

	}

	/** * 通过class获取对象

	 *@param

	 * @return T

	 * @throws BeansException

	 */

	public  <T> T getBean(Class<T> clazz) throws BeansException {

		return applicationContext.getBean(clazz);

	}

}

4.添加websocket包,包下面文件结构为:

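(图:websocket 包的文件结构,包含下文①~⑤介绍的 ChannelSupervise、NioWebSocketChannelInitializer、NioWebSocketHandler、WebsocketMsg、NioWebSocketServer 五个类)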

①.ChannelSupervise类
package com.ruoyi.common.websocket;

import io.netty.channel.Channel;
import io.netty.channel.ChannelId;
import io.netty.channel.group.ChannelGroup;
import io.netty.channel.group.DefaultChannelGroup;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.websocketx.TextWebSocketFrame;
import io.netty.util.concurrent.GlobalEventExecutor;
import org.springframework.stereotype.Component;

import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;

@Component
public class ChannelSupervise {
    /** Every registered channel. */
    private static final ChannelGroup GlobalGroup = new DefaultChannelGroup(GlobalEventExecutor.INSTANCE);
    /** Short channel id -> ChannelId, so a channel can be found by its id text. */
    private static final ConcurrentMap<String, ChannelId> ChannelMap = new ConcurrentHashMap<>();
    /**
     * Request URI -> channels connected on that URI.
     * The map and its sets are concurrent because they are updated from Netty handler callbacks, and
     * different channels run on different event-loop threads. Plain HashMap/HashSet could lose updates
     * or throw ConcurrentModificationException while a broadcast iterates.
     */
    private static final Map<String, Set<Channel>> map = new ConcurrentHashMap<>();
    /** Short channel id -> URI the channel connected on. */
    private static final Map<String, String> mapUri = new ConcurrentHashMap<>();
    /** NOTE(review): never read or written in this class; appears unused. */
    FullHttpRequest req;

    /** Registers a newly connected channel. */
    public static void addChannel(Channel channel) {
        GlobalGroup.add(channel);
        ChannelMap.put(channel.id().asShortText(), channel.id());
    }

    /** Unregisters a channel everywhere, including its URI subscription. */
    public static void removeChannel(Channel channel) {
        GlobalGroup.remove(channel);
        ChannelMap.remove(channel.id().asShortText());
        remoUri(channel);
    }

    /**
     * Finds a registered channel by its short id text.
     *
     * @return the channel, or {@code null} if the id is unknown or null
     */
    public static Channel findChannel(String id) {
        if (id == null) {
            return null;
        }
        ChannelId channelId = ChannelMap.get(id);
        return channelId == null ? null : GlobalGroup.find(channelId);
    }

    /** Sends {@code msg} to every channel connected on the same URI as {@code channel} (including itself). */
    public static void sendChannelAll(String msg, Channel channel) {
        String uri = mapUri.get(channel.id().asShortText());
        if (uri != null) {
            uriSend(uri, msg);
        }
    }

    /** Sends {@code msg} to every registered channel. */
    public static void sysSendAll(String msg) {
        GlobalGroup.writeAndFlush(new TextWebSocketFrame(msg));
    }

    /** Sends {@code msg} to every channel connected on {@code uri}. */
    public static void uriSend(String uri, String msg) {
        if (uri == null) {
            return;
        }
        Set<Channel> channels = map.get(uri);
        if (channels == null) {
            return;
        }
        for (Channel target : channels) {
            // A fresh frame per channel: each write consumes (releases) its frame's buffer.
            target.writeAndFlush(new TextWebSocketFrame(msg));
        }
    }

    /**
     * Subscribes {@code channel} to broadcasts for {@code uri}.
     * If the channel was registered under a different URI, it is moved.
     *
     * @return the number of channels on {@code uri} after adding this one
     */
    public static int addUri(Channel channel, String uri) {
        String previousUri = mapUri.put(channel.id().asShortText(), uri);
        if (previousUri != null && !previousUri.equals(uri)) {
            detach(previousUri, channel);
        }
        // compute() is atomic per key, so concurrent adds and removes on the same URI cannot lose a channel.
        Set<Channel> channels = map.compute(uri, (key, existing) -> {
            Set<Channel> set = existing;
            if (set == null) {
                set = ConcurrentHashMap.newKeySet();
            }
            set.add(channel);
            return set;
        });
        return channels.size();
    }

    /** Unsubscribes {@code channel} from the URI it connected on, if any. */
    public static void remoUri(Channel channel) {
        String uri = mapUri.remove(channel.id().asShortText());
        if (uri != null) {
            detach(uri, channel);
        }
    }

    /** Removes {@code channel} from {@code uri}'s set, dropping the set once it is empty so URIs do not leak. */
    private static void detach(String uri, Channel channel) {
        map.computeIfPresent(uri, (key, set) -> {
            set.remove(channel);
            return set.isEmpty() ? null : set;
        });
    }
}

②.NioWebSocketChannelInitializer类
package com.ruoyi.common.websocket;

import io.netty.channel.ChannelInitializer;
import io.netty.channel.socket.SocketChannel;
import io.netty.handler.codec.http.HttpObjectAggregator;
import io.netty.handler.codec.http.HttpServerCodec;
import io.netty.handler.logging.LoggingHandler;
import io.netty.handler.stream.ChunkedWriteHandler;

public class NioWebSocketChannelInitializer extends ChannelInitializer<SocketChannel> {

    /** Largest aggregated HTTP message (the upgrade request), in bytes. */
    private static final int MAX_HTTP_CONTENT_LENGTH = 65536;

    /**
     * Builds the pipeline for each accepted connection: logging, HTTP codec, request aggregation
     * (needed for the WebSocket handshake), chunked writes, and the business handler.
     */
    @Override
    protected void initChannel(SocketChannel ch) {
        ch.pipeline()
            // The old `new LoggingHandler("DEBUG")` only named the logger "DEBUG"; it did not set a level.
            // Log under this class's name instead, at LoggingHandler's default level.
            .addLast("logging", new LoggingHandler(NioWebSocketChannelInitializer.class))
            .addLast("http-codec", new HttpServerCodec())
            .addLast("aggregator", new HttpObjectAggregator(MAX_HTTP_CONTENT_LENGTH))
            .addLast("http-chunked", new ChunkedWriteHandler())
            // One handler instance per connection: it holds per-connection handshake state.
            .addLast("handler", new NioWebSocketHandler());
    }
}

③.NioWebSocketHandler类
package com.ruoyi.common.websocket;

import com.alibaba.fastjson.JSON;
import com.ruoyi.common.cache.CacheComponent;
import com.ruoyi.common.utils.CommonUtil;
import com.ruoyi.common.utils.SpringContextUtil;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.handler.codec.http.DefaultFullHttpResponse;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.HttpResponseStatus;
import io.netty.handler.codec.http.HttpVersion;
import io.netty.handler.codec.http.websocketx.*;
import io.netty.util.CharsetUtil;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.springframework.util.StringUtils;

import java.lang.reflect.Method;
import java.util.Date;

import static io.netty.handler.codec.http.HttpUtil.isKeepAlive;

public class NioWebSocketHandler extends SimpleChannelInboundHandler<Object> {

    private final Logger logger = LogManager.getLogger();
    /** Handshaker for this connection; created when the HTTP upgrade request arrives. */
    private WebSocketServerHandshaker handshaker;
    private final SpringContextUtil springContextUtil = new SpringContextUtil();
    /** Redis cache; host replies (queryType 2) are stored under their message uuid. */
    private final CacheComponent cacheComponent = springContextUtil.getBean(CacheComponent.class);

    @Override
    protected void channelRead0(ChannelHandlerContext ctx, Object msg) throws Exception {
        if (msg instanceof FullHttpRequest) {
            // The opening HTTP request asking to upgrade the connection to WebSocket.
            handleHttpRequest(ctx, (FullHttpRequest) msg);
        } else if (msg instanceof WebSocketFrame) {
            handlerWebSocketFrame(ctx, (WebSocketFrame) msg);
        }
    }

    @Override
    public void channelActive(ChannelHandlerContext ctx) throws Exception {
        logger.info("客户端加入连接:{}", ctx.channel());
        ChannelSupervise.addChannel(ctx.channel());
    }

    @Override
    public void channelInactive(ChannelHandlerContext ctx) throws Exception {
        logger.info("客户端断开连接:{}", ctx.channel());
        ChannelSupervise.removeChannel(ctx.channel());
    }

    @Override
    public void channelReadComplete(ChannelHandlerContext ctx) throws Exception {
        // Flushes writes queued with write(), e.g. Pong replies.
        ctx.flush();
    }

    /**
     * Logs and closes the connection on any unhandled error. This includes unsupported frame types.
     * Without this override the exception reached the pipeline tail and the connection stayed open.
     */
    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
        logger.error("websocket channel error, closing " + ctx.channel(), cause);
        ctx.close();
    }

    /**
     * Handles one WebSocket frame.
     * <ul>
     *   <li>Close, ping and pong frames are handled as control frames.</li>
     *   <li>Text frames are processed and echoed back with a timestamp and the channel id.</li>
     *   <li>Other frame types raise UnsupportedOperationException, which {@link #exceptionCaught} handles.</li>
     * </ul>
     */
    private void handlerWebSocketFrame(ChannelHandlerContext ctx, WebSocketFrame frame) {
        if (frame instanceof CloseWebSocketFrame) {
            handshaker.close(ctx.channel(), (CloseWebSocketFrame) frame.retain());
            return;
        }
        if (frame instanceof PingWebSocketFrame) {
            // retain(): the inbound frame is released after this method, but the pong reuses its content.
            ctx.channel().write(new PongWebSocketFrame(frame.content().retain()));
            return;
        }
        // Pong frames (heartbeats / answers to pings) need no action; previously they raised an exception.
        if (frame instanceof PongWebSocketFrame) {
            return;
        }
        if (!(frame instanceof TextWebSocketFrame)) {
            logger.info("本例程仅支持文本消息,不支持二进制消息");
            throw new UnsupportedOperationException(String.format(
                    "%s frame types not supported", frame.getClass().getName()));
        }
        String request = ((TextWebSocketFrame) frame).text();
        if (request == null || request.isEmpty()) {
            return;
        }
        if (request.contains("hostCode")) {
            logger.info("收到host消息:{}", request);
            request = dispatchHostMessage(request);
            if (request == null) {
                // Online push (msgType 0): nothing is echoed back.
                return;
            }
        } else {
            logger.info("收到其他消息:{}", request);
        }
        String reply = CommonUtil.dateToStr(new Date()) + " " + ctx.channel().id() + ":" + request;
        // Reply only to the sender.
        ctx.channel().writeAndFlush(new TextWebSocketFrame(reply));
    }

    /**
     * Processes a host message (JSON-encoded {@link WebsocketMsg}).
     *
     * @param request the raw message text
     * @return the text to echo back. This is {@code null} when nothing should be sent (msgType 0). It is
     *         the original {@code request} when the message is a reply (queryType 2), cannot be routed, or
     *         fails; failures are logged.
     */
    private String dispatchHostMessage(String request) {
        try {
            WebsocketMsg<?> msg = JSON.parseObject(request, WebsocketMsg.class);
            Integer msgType = msg.getMsgType();
            if (Integer.valueOf(0).equals(msgType)) {
                return null;
            }
            Class<?> clazz = resolveHostClass(msgType);
            if (clazz == null) {
                return request;
            }
            Integer queryType = msg.getQueryType();
            if (Integer.valueOf(1).equals(queryType)) {
                // Push: run the command, then acknowledge with a data-less reply (queryType 2).
                invokeHostCommand(clazz, msg);
                msg.setData(null);
                msg.setQueryType(2);
                return JSON.toJSONString(msg);
            } else if (Integer.valueOf(2).equals(queryType)) {
                // Reply to an earlier request: park it in the cache for 30 (unit set by putRaw) under its uuid.
                cacheComponent.putRaw(msg.getUuid(), request, 30);
            } else if (Integer.valueOf(3).equals(queryType)) {
                // Request: run the command and answer with its return value.
                return JSON.toJSONString(invokeHostCommand(clazz, msg));
            }
        } catch (Exception e) {
            logger.error("failed to handle host message: " + request, e);
        }
        return request;
    }

    /** Maps msgType to its service class: 1 door access (门禁), 2 parking barrier (车禁); otherwise null. */
    private static Class<?> resolveHostClass(Integer msgType) throws ClassNotFoundException {
        if (Integer.valueOf(1).equals(msgType)) {
            return Class.forName("com.admin.host.ServerDoor");
        }
        if (Integer.valueOf(2).equals(msgType)) {
            return Class.forName("com.admin.host.ServerPark");
        }
        return null;
    }

    /**
     * Calls the public method named by {@code msg.getCmd()} (taking one WebsocketMsg) on the Spring bean of
     * {@code clazz}.
     * <p>
     * SECURITY NOTE(review): {@code cmd} comes straight from the client. Any connected client can therefore
     * invoke any public WebsocketMsg-taking method of these classes. Consider an explicit whitelist and
     * authenticating connections.
     */
    private Object invokeHostCommand(Class<?> clazz, WebsocketMsg<?> msg) throws ReflectiveOperationException {
        Method method = clazz.getMethod(msg.getCmd(), WebsocketMsg.class);
        return method.invoke(springContextUtil.getBean(clazz), msg);
    }

    /**
     * Handles the single HTTP request that opens the WebSocket.
     * Non-upgrade or undecodable requests get 400 Bad Request.
     */
    private void handleHttpRequest(ChannelHandlerContext ctx, FullHttpRequest req) {
        // The Upgrade token is case-insensitive, so "WebSocket" must be accepted too.
        if (!req.decoderResult().isSuccess()
                || !"websocket".equalsIgnoreCase(req.headers().get("Upgrade"))) {
            sendHttpResponse(ctx, req, new DefaultFullHttpResponse(
                    HttpVersion.HTTP_1_1, HttpResponseStatus.BAD_REQUEST));
            return;
        }
        String uri = req.uri();
        // Build the location from the Host the client used rather than a hard-coded localhost.
        String host = req.headers().get("Host");
        String location = "ws://" + (host != null ? host : "localhost:8081") + uri;
        WebSocketServerHandshakerFactory wsFactory = new WebSocketServerHandshakerFactory(location, null, false);
        handshaker = wsFactory.newHandshaker(req);
        if (handshaker == null) {
            WebSocketServerHandshakerFactory.sendUnsupportedVersionResponse(ctx.channel());
        } else {
            handshaker.handshake(ctx.channel(), req);
            // Subscribe for URI broadcasts only when a handshake was attempted (not on unsupported versions).
            ChannelSupervise.addUri(ctx.channel(), uri);
        }
    }

    /**
     * Sends an HTTP response to a rejected request.
     * Non-200 responses get the status text as their body and always close the connection.
     */
    private static void sendHttpResponse(ChannelHandlerContext ctx,
                                         FullHttpRequest req, DefaultFullHttpResponse res) {
        if (res.status().code() != 200) {
            ByteBuf buf = Unpooled.copiedBuffer(res.status().toString(), CharsetUtil.UTF_8);
            res.content().writeBytes(buf);
            buf.release();
        }
        // Without Content-Length, the client cannot tell where the body ends on a kept-alive connection.
        res.headers().set("Content-Length", res.content().readableBytes());
        ChannelFuture f = ctx.channel().writeAndFlush(res);
        if (!isKeepAlive(req) || res.status().code() != 200) {
            f.addListener(ChannelFutureListener.CLOSE);
        }
    }
}

④.WebsocketMsg类
package com.ruoyi.common.websocket;

import lombok.Data;

@Data
public class WebsocketMsg<T> {
	// Local (host) code of the sender.
	private String hostCode;
	// Message category, read by NioWebSocketHandler via getMsgType():
	// 0 online push, 1 door access (门禁), 2 parking barrier (车禁).
	// This field was missing, so the handler's getMsgType() calls did not compile.
	private Integer msgType;
	// Result code: 0 success, 1 failure.
	private Integer code = 0;
	// Human-readable result description.
	private String msg = "成功";
	// Unique message id; replies are cached under it.
	private String uuid;
	// Name of the service method to invoke.
	private String cmd;
	// Business payload.
	private T data;
	// Operation type: 1 push, 2 reply, 3 request.
	private Integer queryType;

	// Operation command.
	private String action;
}

⑤.NioWebSocketServer类
package com.ruoyi.common.websocket;

import io.netty.bootstrap.ServerBootstrap;
import io.netty.channel.Channel;
import io.netty.channel.ChannelOption;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

//@Component
public class NioWebSocketServer {
    private final Logger logger = LogManager.getLogger();

    /**
     * Starts the WebSocket server on {@code port}.
     * <p>
     * NOTE: this constructor blocks the calling thread until the server channel closes. Run it on a
     * dedicated thread, not on the Spring startup thread.
     *
     * @param port TCP port to listen on
     */
    public NioWebSocketServer(int port) {
        init(port);
    }

    /** Binds the server, waits until it is closed, then releases both event-loop groups. */
    private void init(int port) {
        logger.info("正在启动websocket服务器");
        // The boss group only accepts connections on one server socket, so one thread is enough.
        NioEventLoopGroup boss = new NioEventLoopGroup(1);
        NioEventLoopGroup work = new NioEventLoopGroup();
        try {
            Channel channel = new ServerBootstrap()
                .group(boss, work)
                .channel(NioServerSocketChannel.class)
                .childHandler(new NioWebSocketChannelInitializer())
                .childOption(ChannelOption.SO_KEEPALIVE, true)
                .bind(port).sync().channel();
            logger.info("webSocket服务器启动成功:" + channel);
            channel.closeFuture().sync();
        } catch (InterruptedException e) {
            // Restore the interrupt flag so the owning thread can see it was asked to stop.
            Thread.currentThread().interrupt();
            logger.error("运行出错:" + e, e);
        } finally {
            boss.shutdownGracefully();
            work.shutdownGracefully();
            logger.info("websocket服务器已关闭");
        }
    }
}

5.缓存相关CacheComponent.java

package com.ruoyi.common.cache;

import com.alibaba.fastjson.JSONObject;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.data.redis.core.StringRedisTemplate;
import org.springframework.stereotype.Component;
import org.springframework.util.StringUtils;

import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.TimeUnit;

@Component
public class CacheComponent {

	@Autowired
	private StringRedisTemplate stringRedisTemplate;


	/**
	 * Replaces the string stored at {@code key} with {@code obj}, keeping the key's remaining TTL.
	 * <p>
	 * The previous {@code set(key, obj, 0)} was Redis SETRANGE at offset 0. It overwrote from the first
	 * byte but kept the tail of a longer old value, so "abcdef" updated with "xy" became "xycdef".
	 * NOTE(review): reading the TTL and writing the value are two commands, not one atomic step.
	 *
	 * @param key Redis key
	 * @param obj new value
	 */
	public void updateStr(String key, String obj) {
		Long ttlMillis = stringRedisTemplate.getExpire(key, TimeUnit.MILLISECONDS);
		if (ttlMillis != null && ttlMillis > 0) {
			stringRedisTemplate.opsForValue().set(key, obj, ttlMillis, TimeUnit.MILLISECONDS);
		} else {
			// No TTL or missing key: plain SET. Like SETRANGE, this leaves the key without an expiry.
			stringRedisTemplate.opsForValue().set(key, obj);
		}
	}



	/**
	 * Serializes {@code obj} to JSON (fastjson) and stores it at {@code key}.
	 * <p>
	 * NOTE(review): despite its name, {@code expireSec} is applied in DAYS, not seconds.
	 * putHashAll and putHashRaw use SECONDS. Check what callers pass before changing the unit.
	 *
	 * @param key       Redis key
	 * @param obj       value to serialize
	 * @param expireSec time to live, in days; {@code null} stores the value with no expiry
	 */
	public void putObj(String key, Object obj, Integer expireSec) {
		String json = JSONObject.toJSONString(obj);
		if (expireSec == null) {
			stringRedisTemplate.opsForValue().set(key, json);
			return;
		}
		stringRedisTemplate.opsForValue().set(key, json, expireSec, TimeUnit.DAYS);
	}

//    public Long incRaw(String key) {
//        return stringRedisTemplate.opsForValue().increment(key);
//    }

	public  <T> T getObj(String key, Class<T> clazz) {
		String json = stringRedisTemplate.opsForValue().get(key);
		if (StringUtils.isEmpty(json)) {
			return null;
		}
		return JSONObject.parseObject(json, clazz);
	}

	public <T> List<T> getObjList(String key, Class<T> clazz) {
		String json = stringRedisTemplate.opsForValue().get(key);
		if (StringUtils.isEmpty(json)) {
			return null;
		}
		return JSONObject.parseArray(json, clazz);
	}

	public void putHashAll(String key, Map<String, String> map, Integer expireSec) {
		stringRedisTemplate.opsForHash().putAll(key, map);
		stringRedisTemplate.expire(key, expireSec, TimeUnit.SECONDS);
	}

	public Map<String,String> getHashAll(String key) {
		if (!stringRedisTemplate.hasKey(key)) {
			return null;
		}
		return (Map)stringRedisTemplate.opsForHash().entries(key);
	}

	public <T> T getHashObj(String hashName, String key, Class<T> clazz) {
		String o = (String) stringRedisTemplate.opsForHash().get(hashName, key);
		if (StringUtils.isEmpty(o)) {
			return null;
		}
		return JSONObject.parseObject(o, clazz);
	}

	public String getHashRaw(String hashName, String key) {
		String o = (String) stringRedisTemplate.opsForHash().get(hashName, key);
		if (StringUtils.isEmpty(o)) {
			return null;
		}
		return o;
	}

	public <T> List<T> getHashArray(String hashName, String key, Class<T> clazz) {
		String o = (String) stringRedisTemplate.opsForHash().get(hashName, key);
		if (StringUtils.isEmpty(o)) {
			return null;
		}
		return JSONObject.parseArray(o, clazz);
	}

	public Long incHashRaw(String hashName, String key, long delta) {
		return stringRedisTemplate.opsForHash().increment(hashName, key, delta);
	}

	public void putHashRaw(String hashName, String key, String str, Integer expireSec) {
		boolean hasKey = stringRedisTemplate.hasKey(key);
		stringRedisTemplate.opsForHash().put(hashName, key, str);
		if (!hasKey) {
			stringRedisTemplate.expire(key, expireSec, TimeUnit.SECONDS);
		}
	}

	public void putHashRaw(String hashName, String key, String str) {
		stringRedisTemplate.opsForHash().put(hashName, key, str);
	}

	public void putHashObj(String hashName, String key, Object obj, Integer expireSec) {
		boolean hasKey = stringRedisTemplate.hasKey(key);
		stringRedisTemplate.opsForHash().put(hashName, key, JSONObject.toJSONString(obj));
		if (!hasKey) {
			stringRedisTemplate.expire(key, expireSec, TimeUnit.SECONDS);
		}
	}

	public void delHashObj(String hashName, String key) {
		stringRedisTemplate.opsForHash().delete(hashName, key);
	}


	public void putRaw(String key, String value) {
		putRaw(key, value, null);
	}

	public void putRaw(String key, String value, Integer expireSec) {
		if (expireSec != null) {
			stringRedisTemplate.opsForValue().set(key, value, expireSec, TimeUnit.SECONDS);
		} else {
			stringRedisTemplate.opsForValue().set(key, value);
		}
	}

	public String getRaw(String key) {
		return stringRedisTemplate.opsForValue().get(key);
	}

	public void del(String key) {
		stringRedisTemplate.delete(key);
	}

	public boolean hasKey(String key) {
		return stringRedisTemplate.hasKey(key);
	}

	public void putSetRaw(String key, String member, Integer expireSec) {
		stringRedisTemplate.opsForSet().add(key, member);
		stringRedisTemplate.expire(key, expireSec, TimeUnit.SECONDS);
	}

	public void putSetRawAll(String key, String[] set, Integer expireSec) {
		stringRedisTemplate.opsForSet().add(key, set);
		stringRedisTemplate.expire(key, expireSec, TimeUnit.SECONDS);
	}

	public void removeSetRaw(String key, String member) {
		stringRedisTemplate.opsForSet().remove(key, member);
	}

	public boolean isSetMember(String key, String member) {
		return stringRedisTemplate.opsForSet().isMember(key, member);
	}

	/**
	 * 获取指定前缀的Key
	 * @param prefix
	 * @return
	 */
	public Set<String> getPrefixKeySet(String prefix) {
		return stringRedisTemplate.keys(prefix + "*");
	}

	public void delPrefixKey(String prefix) {
		Set<String> prefixKeySet = getPrefixKeySet(prefix);
		for (String key : prefixKeySet) {
			stringRedisTemplate.delete(key);
		}
	}

    /**
     * 获取并设置key,如果key存在就返回为false 不覆盖,如果不存在就设置key
     * 用于分布式锁的情况下true代表锁没人使用,false代表锁已被人使用
     * @param key
     * @param value
     * @param time 过期时间,单位秒
     * @return true ,key为空 设置成功, false key存在设置失败
     */
    public boolean setIfAbsent(String key,String value,long time){
        return stringRedisTemplate.opsForValue().setIfAbsent(key,value,time,TimeUnit.SECONDS);
    }
}

6.配置类:StartSocketConfig(Spring Boot 项目启动时会自动启动 WebSocket 服务器)

package com.ruoyi.web.controller.config;

import com.ruoyi.common.websocket.NioWebSocketServer;
import org.springframework.boot.CommandLineRunner;
import org.springframework.stereotype.Component;

/**
 * Starts the Netty WebSocket server on port 8031 when the Spring Boot application starts.
 */
@Component
public class StartSocketConfig implements CommandLineRunner {

	/**
	 * Launches the WebSocket server on a dedicated thread and returns immediately.
	 *
	 * <p>The {@code NioWebSocketServer(int)} constructor appears to call {@code init(port)}, which
	 * waits on {@code closeFuture().sync()} until the server channel closes. Running it on the
	 * runner thread would keep {@code SpringApplication.run} from ever returning, so it runs on
	 * its own named thread. As a consequence, a startup failure (e.g. port already in use) is
	 * reported on that thread instead of aborting application startup.
	 *
	 * @param args application arguments (unused)
	 */
	@Override
	public void run(String... args) throws Exception {
		Thread serverThread = new Thread(() -> new NioWebSocketServer(8031), "websocket-server");
		serverThread.start();
	}
}

7.接口测试:NettyController(注:本文使用的 Spring Boot 项目为 RuoYi-Vue-Plus 4.5.0 版本,代码中的 @SaIgnore 是 Sa-Token 提供的注解,用于跳过登录鉴权)

package com.ruoyi.web.controller.MQTT;

import cn.dev33.satoken.annotation.SaIgnore;
import com.alibaba.fastjson.JSON;
import com.ruoyi.common.core.domain.R;
import com.ruoyi.common.websocket.ChannelSupervise;
import io.swagger.v3.oas.annotations.responses.ApiResponse;
import org.springframework.stereotype.Controller;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.ResponseBody;

import java.util.Map;

/**
 * Test endpoint that pushes a text message to every connected WebSocket client.
 *
 * <p>NOTE(review): {@code @SaIgnore} skips Sa-Token authentication for this controller. The
 * {@code @RequestMapping} below also accepts any HTTP method. Together they let any caller
 * broadcast to all clients; fine for a local demo, restrict it before exposing it.
 */
@SaIgnore
@Controller
@RequestMapping("/api/msg")
public class NettyController {

	/**
	 * Broadcasts {@code data} to all WebSocket channels through {@link ChannelSupervise#sysSendAll}.
	 *
	 * @param data message text to send, taken from the request parameter {@code data}
	 * @return {@code R.ok()} after {@code sysSendAll} returns
	 */
	@RequestMapping("/sendWebsocket")
	@ResponseBody
	public R sendWebsockt(String data) {
		ChannelSupervise.sysSendAll(data);
		R result = R.ok();
		return result;
	}
}

你可能感兴趣的:(springboot,spring,boot,websocket,java)