Spring Boot 防重复提交注解

防重注解

使用示例

/**
 * Demo endpoint for the duplicate-submission check.
 *
 * <p>The key is built only from the {@code info.name} value inside the request body
 * ({@code PARAMS} check type), so two calls carrying the same {@code info.name} within the
 * annotation's time window are treated as duplicates.
 *
 * @param testPerson request body, echoed to stdout
 * @return a success result
 */
@NoRepeatSubmit(checkType = NoRepeatSumbitCheckType.PARAMS, checkRequestParams = {"info.name"})
@PostMapping("/testNoRepectSumbit")
public ResultDto testNoRepectSumbit(@RequestBody TestPerson testPerson) {
    System.out.println("====testNoRepectSumbit");
    String echo = "【map】 " + testPerson;
    System.out.println(echo);
    return ResultDto.createSuccess("接口成功");
}

NoRepeatSubmit

注解定义

package cn.com.yusys.config;

import java.lang.annotation.*;

/**
 * Marks a method whose repeated invocation should be rejected for a short time window.
 *
 * <p>Every element has a default, so a bare {@code @NoRepeatSubmit} is valid.
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
public @interface NoRepeatSubmit {

    /** HTTP method the annotated endpoint must be called with; defaults to {@code POST}. */
    RequestType requestType() default RequestType.POST;

    /**
     * Which request parts make up the duplicate-check key; defaults to {@code IP_PATH}.
     * When a type that includes {@code params} is chosen, list the parameter names in
     * {@link #checkRequestParams()}.
     */
    NoRepeatSumbitCheckType checkType() default NoRepeatSumbitCheckType.IP_PATH;

    /**
     * Request parameters used for the duplicate check.
     *
     * <p>A value nested inside objects is addressed with dots. For
     * {@code paramOne: {paramSecond: secondValue}} write {@code "paramOne.paramSecond"}.
     * Segment names must match the JSON property names exactly (case-sensitive).
     */
    String[] checkRequestParams() default {};

    /** Length of the duplicate-check window in seconds; defaults to {@code 1}. */
    long timeOut() default 1L;
}


NoRepeatSumbitCheckType

校验类型枚举(决定防重 key 由哪些请求信息组成)

package cn.com.yusys.config;

/**
 * Selects which parts of a request form the duplicate-submission key.
 *
 * <p>{@link #getName()} returns a {@code '#'}-separated list of part names
 * ({@code ip}, {@code path}, {@code params}, {@code token}). The aspect appends the parts to
 * the key in the listed order. Note that the constant names do not reflect that order:
 * {@code TOKEN_IP_PATH_PARAMS} lists {@code token} last.
 */
public enum NoRepeatSumbitCheckType {
    /** Client IP, servlet path, selected parameters and the {@code token} header. */
    TOKEN_IP_PATH_PARAMS("ip#path#params#token"),

    /** Client IP, servlet path and selected parameters. */
    IP_PATH_PARAMS("ip#path#params"),

    /** Client IP and servlet path. */
    IP_PATH("ip#path"),

    /** Selected request parameters only. */
    PARAMS("params");

    /** '#'-separated part names; final because enum constants are shared singletons. */
    private final String name;

    NoRepeatSumbitCheckType(String name) {
        this.name = name;
    }

    /**
     * Returns the {@code '#'}-separated list of key parts for this check type.
     *
     * @return e.g. {@code "ip#path"} for {@link #IP_PATH}
     */
    public String getName() {
        return this.name;
    }
}


RequestType

请求方式枚举

package cn.com.yusys.config;

/**
 * HTTP method that an endpoint annotated with {@code @NoRepeatSubmit} must be invoked with.
 *
 * <p>The aspect compares the constant's {@link #name()} with the servlet request's method.
 */
public enum RequestType { POST, GET }

AspectConfig

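切面实现(注意:切点表达式 `@annotation(NoRepeatSubmit)`、`NoRepeatSubmit.class` 以及 `NoRepeatSumbitCheckType` 在 AspectConfig 中都以简单类名引用且没有 import,因此注解和枚举必须与切面放在同一个包下;文中各片段的 package 声明并不一致,请按实际项目统一。)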

package com.tianqiauto.tis.api.xdd.noRepeatSubmit;

import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONObject;
import com.tianqiauto.base.core.ErrorCode;
import com.tianqiauto.base.core.Result;
import com.tianqiauto.base.core.ResultGenerator;
import com.tianqiauto.base.model.AjaxResult;
import lombok.extern.slf4j.Slf4j;
import org.aspectj.lang.ProceedingJoinPoint;
import org.aspectj.lang.annotation.Around;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.reflect.MethodSignature;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.stereotype.Component;
import org.springframework.util.ObjectUtils;
import org.springframework.util.StringUtils;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;

import javax.servlet.http.HttpServletRequest;
import java.lang.reflect.Method;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.TimeUnit;


/**
 * @Description 防止重复提交
 * @Author wjx
 * @Date 2023/6/16 15:00
 **/
@Slf4j
@Aspect
@Component
public class AspectConfig {


    private static final String NO_REPEAT_HEADER = "NO_REPEAT:";

    @Autowired
    private RedisTemplate<String, String> redisTemplate;

    /**
     * @param point
     * @方法描述:通过检测注解 NoRepeatSubmit防止重复提交
     */
    @Around("@annotation(NoRepeatSubmit)")
    public Object NoRepeatSubmit(ProceedingJoinPoint point) throws Throwable {
        ServletRequestAttributes attributes = (ServletRequestAttributes) RequestContextHolder.getRequestAttributes();
        HttpServletRequest request = attributes.getRequest();
        //获取注解中的防重间隔时间
        MethodSignature signature = (MethodSignature) point.getSignature();
        Method method = signature.getMethod();
        NoRepeatSubmit annotation = method.getAnnotation(NoRepeatSubmit.class);
        long timeOut = annotation.timeOut();
        String checkTypeName = annotation.requestType().name();
        log.debug("【防重间隔时间 单位:秒】" + timeOut);
        //非checkTypeName(例如:POST)请求提示错误
        String requestTypeName = request.getMethod();
        if (!checkTypeName.equals(requestTypeName)) {
            throw new RuntimeException("【请求方式不正确】【requestTypeName:" + requestTypeName + "】【checkTypeName:" + checkTypeName + "】");
        }
        Class<?> returnType = method.getReturnType();
        //获取key
        String key = null;
        try {
            key = getKey(point, annotation, request);
        } catch (Exception e) {
            log.error("【防重注解获取key解析失败】失败原因:{}",e);
            return getFailResult(returnType);
        }
        log.debug("【防重主键】" + key);
        //判断redis中是否存在key
        Boolean notRepeat = redisTemplate.opsForValue().setIfAbsent(key, "1", timeOut, TimeUnit.SECONDS);
        if (!notRepeat) {
            log.warn("【重复提交】key:{}", key);
            return getFailResult(returnType);
        }
        return point.proceed();
    }

    private Object getFailResult(Class<?> returnType) {
        Object obj;
        if (returnType == Result.class) {
            obj = ResultGenerator.genFailResult(ErrorCode.NO_REPEAT_SUBMIT);
        } else {
            obj = AjaxResult.getInstance().setSuccessChain(false).setMessageChain(ErrorCode.NO_REPEAT_SUBMIT.getMessage());
        }
        return obj;
    }

    /**
     * @param point
     * @param annotation
     * @param request
     * @return
     * @方法描述:解析注解中的枚举配置信息
     */
    private String getKey(ProceedingJoinPoint point, NoRepeatSubmit annotation, HttpServletRequest request) {
        MethodSignature signature = (MethodSignature) point.getSignature();
        NoRepeatSumbitCheckType noRepeatSumbitCheckType = annotation.checkType();
        String name = noRepeatSumbitCheckType.getName();
        String[] splitStr = name.split("#");
        StringBuilder key = new StringBuilder();
        key.append(NO_REPEAT_HEADER);
        for (String type : splitStr) {
            switch (type.trim()) {
                case "ip":
                    String ipAddress = getIPAddress(request);
                    if (!StringUtils.isEmpty(ipAddress)) {
                        ipAddress = ipAddress.replaceAll(":","");
                        appendKey(key, ipAddress);
                        appendKey(key, ":");
                    }else {
                        log.warn("【防重注解】访问{}路径时,ip地址未获取到",request.getServletPath());
                        appendKey(key, "null:");
                    }
                    log.debug("【防重访问ip ipAddress】" + ipAddress);
                    break;
                case "path":
                    String servletPath = request.getServletPath();
                    log.debug("【防重访问路径 servletPath】" + servletPath);
                    appendKey(key, servletPath);
                    break;
                case "params":
                    Object[] args = point.getArgs();
                    String[] parameterNames = signature.getParameterNames();
                    pushParamsToKey(args, key, annotation, parameterNames);

                    break;
                case "token":
                    String accessToken = request.getHeader("token");
                    log.debug("【防重访问token currentUserToken】" + accessToken);
                    appendKey(key, accessToken);
                    break;
                default:
                    throw new RuntimeException("【未在 NoRepeatSumbitCheckType 中配置匹配类型】" + type.trim());
            }
        }
        if (NO_REPEAT_HEADER.equals(key)) {
            throw new RuntimeException("【防重主键为空】");
        }
        return key.substring(0, key.length() - 1);


    }

    /**
     * @param args
     * @param key
     * @param annotation
     * @param parameterNames
     * @方法描述:通过 a.b.c 的结构在请求参数中拿到 c的value值
     */
    private void pushParamsToKey(Object[] args, StringBuilder key, NoRepeatSubmit annotation, String[] parameterNames) {
        String debugStrNeed = "";

        //防重需要的请求参数names
        String[] strings = annotation.checkRequestParams();
        //判断传入参数
        if (args == null || args.length <= 0) {
            throw new RuntimeException("请求参数为空" + Arrays.toString(args));
        }
        //遍历checkRequestParams
        for (String string : strings) {
            Object value = null;
            for (int i = 0; i < args.length; i++) {
                //如果arg类型为基础数据类型,直接push到key
                Object argsOne = args[i];
                if (isBaseType(argsOne)) {
                    if (string.equals(parameterNames[i])) {
                        appendKey(key, argsOne);
                        break;
                    }
                    continue;
                }
                // 如果参数类型为数组,只取第一个对象的键值对信息
                Object transformObj = null;
                if (argsOne instanceof Object[]){
                    Object[] arg = (Object[]) argsOne;
                    if (arg.length <= 0) continue;
                    transformObj = arg[0];

                }else if (argsOne instanceof List){
                    List arg = (List) argsOne;
                    if (arg.size() <=0) continue;
                    transformObj = arg.get(0);
                }else {
                    transformObj = argsOne;
                }
                JSONObject argJson = JSON.parseObject(JSON.toJSONString(transformObj));
                //不存在多级
                if (string.indexOf(".") == -1) {
                    value = argJson.get(string);
                } else {
                    //存在多级
                    String[] split = string.split("\\.");
                    int length = split.length;
                    Object remainTemp = argJson.get(split[0]);
                    JSONObject transJson;
                    try {
                        for (int j = 1; j < length; j++) {
                            if (remainTemp instanceof JSONObject) {
                                transJson = (JSONObject) remainTemp;
                                remainTemp = transJson.get(split[j]);
                            }
                        }
                        value = remainTemp;
                    } catch (Exception e) {
                        log.error("【防重注解 params结构错误】");
                        log.error("【错误信息】" + e.getMessage());
                        return;
                    }

                }
            }
            appendKey(key, value);
            if (log.isDebugEnabled()){
                debugStrNeed += value + "#";
            }
        }
        if (log.isDebugEnabled()){
            log.debug("【防重访问params】{}",debugStrNeed);
        }
    }

    private boolean isBaseType(Object arg) {
        if (arg instanceof String) {
            return true;
        } else if (arg instanceof Integer) {
            return true;
        } else if (arg instanceof Boolean) {
            return true;
        } else if (arg instanceof Float) {
            return true;
        } else if (arg instanceof Long) {
            return true;
        } else if (arg instanceof Double) {
            return true;
        } else if (arg instanceof Short) {
            return true;
        } else if (arg instanceof Byte) {
            return true;
        } else {
            return false;
        }
    }


    private void appendKey(StringBuilder key, Object value) {
        if (!ObjectUtils.isEmpty(value)) {
            key.append(value + "#");

        }
    }


    /**
     * @param request
     * @return
     * @方法描述:获取用户的ip地址
     */
    private String getIPAddress(HttpServletRequest request) {
        String ip = null;

        //X-Forwarded-For:Squid 服务代理
        String ipAddresses = request.getHeader("X-Forwarded-For");

        if (ipAddresses == null || ipAddresses.length() == 0 || "unknown".equalsIgnoreCase(ipAddresses)) {
            //Proxy-Client-IP:apache 服务代理
//            logger.info("【Proxy-Client-IP】");
            ipAddresses = request.getHeader("Proxy-Client-IP");
        }

        if (ipAddresses == null || ipAddresses.length() == 0 || "unknown".equalsIgnoreCase(ipAddresses)) {
            //WL-Proxy-Client-IP:weblogic 服务代理
//            logger.info("【WL-Proxy-Client-IP】");
            ipAddresses = request.getHeader("WL-Proxy-Client-IP");
        }

        if (ipAddresses == null || ipAddresses.length() == 0 || "unknown".equalsIgnoreCase(ipAddresses)) {
            //HTTP_CLIENT_IP:有些代理服务器
//            logger.info("【HTTP_CLIENT_IP】");
            ipAddresses = request.getHeader("HTTP_CLIENT_IP");
        }

        if (ipAddresses == null || ipAddresses.length() == 0 || "unknown".equalsIgnoreCase(ipAddresses)) {
            //X-Real-IP:nginx服务代理
//            logger.info("【X-Real-IP】");
            ipAddresses = request.getHeader("X-Real-IP");
        }

        //有些网络通过多层代理,那么获取到的ip就会有多个,一般都是通过逗号(,)分割开来,并且第一个ip为客户端的真实IP
        if (ipAddresses != null && ipAddresses.length() != 0) {
            ip = ipAddresses.split(",")[0];
        }

        //还是不能获取到,最后再通过request.getRemoteAddr();获取
        if (ip == null || ip.length() == 0 || "unknown".equalsIgnoreCase(ipAddresses)) {
//            logger.info("【request.getRemoteAddr】");
            ip = request.getRemoteAddr();
        }
        return ip;
    }
}



有用的话,别忘记点赞呦!!!

你可能感兴趣的:(spring,boot,java,spring)