基于SpringMVC架构对Request进行参数重写

环境:SpringBoot 2.0 以上的版本

假设目前有一种场景,我需要把所有请求都统一用一种方法来接收参数,同时还能加塞Request的参数。

例如:不管请求的参数是否是JSON,所有的参数接收处不添加 @RequestBody 注解。

LoginController.java

    @PostMapping("/login")
    public Result login(LoginDTO dto) {
        log.info("");
        return loginService.login(dto);
    }

这里,我们想要做到上述的效果,需要在过滤器上针对Request对象进行处理。

这里简单讲解下思路:

        首先,我是使用了SpringSecurity框架来维护系统的用户权限的,但用户登陆后的信息是统一存放Redis中。每次用户请求时如果Token校验通过,就会从Redis获取用户登陆的信息,加塞进Security的会话中。

针对用户校验及权限这一块,如果不需要的,可以删掉,从48-60行。

主要是通过继承了HttpServletRequestWrapper类解决重写Reqeust参数的问题。

因为Request内置的params参数是final声明,所以我们无法通过直接更改该Map,所以重新定义了一个Map,然后是为了兼容JSON格式的数据,

核心是构造方法 :

初始化Map,然后根据对应的请求类型进行处理:

        1.如果是JSON或XML。我们就通过读取数据流的saveInputStreamData方法,然后解析到的数据加塞到Map中。

        2.如果是普通的表单类型数据,我们就直接加塞到Map中。

关于saveInputStreamData方法:JSON请求只能通过流来获取其请求参数。而请求数据流只能读取一次。所以我们在读取了数据流后将读取到的数据保存下来,通过重写getInputStream方法来保证后续也能获取到请求数据流。

然后额外添加了addAllParameters、addParameters方法,方便加塞参数。

ParameterRequestWrapper方法:从Redis中解析出来的用户数据,我们加塞到请求参数中。

package com.mrlv.rua.auth.filter;

import cn.hutool.core.util.StrUtil;
import cn.hutool.http.ContentType;
import cn.hutool.json.JSONObject;
import cn.hutool.json.JSONUtil;
import com.auth0.jwt.exceptions.JWTVerificationException;
import com.mrlv.rua.auth.consts.RedisPreConst;
import com.mrlv.rua.auth.entity.LoginUser;
import com.mrlv.rua.auth.utils.JwtUtil;
import com.mrlv.rua.common.redis.utils.RedisUtil;
import lombok.extern.slf4j.Slf4j;
import org.springframework.security.authentication.UsernamePasswordAuthenticationToken;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.stereotype.Component;
import org.springframework.web.filter.OncePerRequestFilter;

import javax.servlet.*;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import javax.servlet.http.HttpServletResponse;
import java.io.BufferedReader;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.util.*;

/**
 * @author lvshiyu
 * @description: 动态权限过滤器
 * @date 2022年07月05日 17:32
 */
@Component
@Slf4j
public class AuthTokenFilter extends OncePerRequestFilter {


    /**
     * 过滤
     * @param request
     * @param response
     * @param filterChain
     */
    @Override
    protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException {
        //获取请求头,判断是否已经登陆
        String token = request.getHeader(JwtUtil.HEADER_STRING);
        LoginUser userDetails = null;
        try {
            if (StrUtil.isNotBlank(token)) {
                String userId = JwtUtil.getUserId(token);
                String key = RedisPreConst.AUTH_ONLINE_USER + userId;
                if (RedisUtil.hHasKey(key, "PC")) {
                    //这里使用Redis来存储用户登陆信息,同时通过SpringScurity来对用户登陆状态进行维护,如果不是使用SpringScurity的话,可以无视这一部分。
                    userDetails = (LoginUser)RedisUtil.hget(key, "PC");
                    UsernamePasswordAuthenticationToken user = new UsernamePasswordAuthenticationToken(userDetails, null, userDetails.getAuthorities());
                    SecurityContextHolder.getContext().setAuthentication(user);
                }
            }
            request = new ParameterRequestWrapper(request, userDetails);
        } catch (JWTVerificationException e) {
            log.info("Token异常 error:{}", e.getMessage(), e);
        } catch (Exception e) {
            e.printStackTrace();
        }
        filterChain.doFilter(request, response);
    }

    /**
     * @author lvshiyu
     * @description: 重写 HttpServletRequestWrapper
     * @date 2022年8月24日 14:50
     */
    class ParameterRequestWrapper extends HttpServletRequestWrapper {
        /**
         * 请求参数
         */
        private Map params;

        /**
         * 用于保存读取body中数据
         */
        private byte[] body;

        /**
         * 用于保存读取body中数据
         */
        private String bodyMessage;

        /**
         * 自定义构造方法
         * @param request
         * @throws IOException
         */
        public ParameterRequestWrapper(HttpServletRequest request) throws IOException {
            super(request);
            //参数保存
            this.params = new HashMap<>();
            //初始化参数
            String contentType = request.getContentType().toLowerCase();
            //如果是application/json
            if (Objects.equals(ContentType.JSON.toString(), contentType)) {
                //解析数据流数据
                saveInputStreamData(request);
                JSONObject parameter = JSONUtil.parseObj(this.getBodyMessage());
                this.addAllParameters(parameter);
            } else if (Objects.equals(ContentType.XML.toString(), contentType)) {
                saveInputStreamData(request);
                JSONObject parameter = JSONUtil.parseFromXml(this.getBodyMessage()).getJSONObject("request");
                this.addAllParameters(parameter);
            } else {
                Enumeration headerNames = request.getParameterNames();
                while (headerNames.hasMoreElements()) {
                    String key = headerNames.nextElement();
                    this.addParameter(key, request.getParameter(key));
                }
            }
        }

        /**
         * 自定义构造方法
         * @param request
         * @throws IOException
         */
        public ParameterRequestWrapper(HttpServletRequest request, LoginUser user) throws IOException {
            this(request);
            if (user != null) {
                this.addParameter("userId", String.valueOf(user.getId()));
                this.addParameter("username", user.getUsername());
                this.addParameter("nickname", user.getNickname());
            }
        }

        /**
         * 覆盖(重写)父类的方法
         * @return
         * @throws IOException
         */
        @Override
        public BufferedReader getReader() throws IOException {
            return new BufferedReader(new InputStreamReader(getInputStream()));
        }

        /**
         * 覆盖(重写)父类的方法
         * @return
         * @throws IOException
         */
        @Override
        public ServletInputStream getInputStream() throws IOException {
            final ByteArrayInputStream inputStream = new ByteArrayInputStream(this.body);
            return new ServletInputStream() {
                @Override
                public boolean isFinished() {
                    return false;
                }

                @Override
                public boolean isReady() {
                    return false;
                }

                @Override
                public void setReadListener(ReadListener readListener) {
                }

                @Override
                public int read() throws IOException {
                    return inputStream.read();
                }
            };
        }

        @Override
        public Enumeration getParameterNames() {
            return new Vector(params.keySet()).elements();
        }

        @Override
        public String getParameter(String name) {
            String[] values = this.params.get(name);
            if (values == null || values.length == 0) {
                return null;
            }
            return values[0];
        }

        @Override
        public String[] getParameterValues(String name) {
            String[] values = this.params.get(name);
            if (values == null || values.length == 0) {
                return null;
            }
            return values;
        }


        /**
         * 获取body中的数据
         * @return
         */
        public byte[] getBody() {
            return this.body;
        }

        /**
         * 把处理后的参数放到body里面
         * @param body
         */
        public void setBody(byte[] body) {
            this.body = body;
        }

        /**
         * 获取处理过的参数数据
         * @return
         */
        public String getBodyMessage() {
            return this.bodyMessage;
        }

        /**
         * 设置参数
         * @param otherParams
         */
        private void addAllParameters(Map otherParams) {
            for (Map.Entry entry : otherParams.entrySet()) {
                addParameter(entry.getKey(), entry.getValue());
            }
        }

        /**
         * 设置参数
         * @param name
         * @param value
         */
        private void addParameter(String name, Object value) {
            if (this.params == null) {
                this.params = new HashMap<>();
            }
            if (value != null) {
                if (value instanceof String[]) {
                    this.params.put(name, (String[]) value);
                } else if (value instanceof String) {
                    this.params.put(name, new String[]{(String) value});
                } else {
                    this.params.put(name, new String[]{String.valueOf(value)});
                }
            }
        }

        /**
         * 保存请求的InputSteam的数据
         * @param request
         * @throws IOException
         */
        private void saveInputStreamData(HttpServletRequest request) throws IOException {
            int contentLength = request.getContentLength();
            ServletInputStream inputStream = request.getInputStream();
            this.body = new byte[contentLength];
            inputStream.read(this.body, 0, contentLength);
            this.bodyMessage =   new String(this.body, StandardCharsets.UTF_8);
        }
    }
}

在调试途中遇到了一个问题:

关于Request一旦使用了getInputStream后就无法通过getParameterMap获取参数的问题,通过源码我们可以发现,Request 中的 usingInputStream,标识是否使用过流来读取数据,如果使用过了,则会改为True。从而导致 getParameterMap 获取不到数据

org.apache.catalina.connector.Request

基于SpringMVC架构对Request进行参数重写_第1张图片

org.apache.catalina.connector.Request

基于SpringMVC架构对Request进行参数重写_第2张图片

org.apache.catalina.connector.Request

基于SpringMVC架构对Request进行参数重写_第3张图片

总结下来,一个过滤器就可以解决问题,如果能给诸位带来帮助,麻烦点个赞。有什么不理解的,欢迎留言。

你可能感兴趣的:(springboot,servlet)