在 Spring Boot 中，对实体类参数的属性进行校验，一般都是使用 javax.validation 中提供的注解。
这次的项目要求所有接口参数都加密传输。参数解密是通过自定义参数解析器（实现 HandlerMethodArgumentResolver 接口）完成的：从请求体中取出加密字符串，解密后再封装到接口参数中。因此接口参数上不再使用 @RequestBody 注解，原本那些参数校验注解也就不会生效了。
如果在每个接口里手写 if 判断来校验，又不够优雅，于是就想到在参数解析时自己读取这些注解来完成校验。
package com.gt.gxjhpt.configuration;
import cn.hutool.core.convert.Convert;
import cn.hutool.core.util.ArrayUtil;
import cn.hutool.core.util.CharsetUtil;
import cn.hutool.core.util.ReUtil;
import cn.hutool.core.util.StrUtil;
import cn.hutool.crypto.symmetric.AES;
import cn.hutool.json.JSONArray;
import cn.hutool.json.JSONObject;
import cn.hutool.json.JSONUtil;
import com.gt.gxjhpt.annotation.ParamsAES;
import com.gt.gxjhpt.entity.dto.BaseReq;
import com.gt.gxjhpt.utils.AESUtil;
import lombok.extern.log4j.Log4j2;
import org.jetbrains.annotations.Nullable;
import org.springframework.core.MethodParameter;
import org.springframework.validation.annotation.Validated;
import org.springframework.web.bind.support.WebDataBinderFactory;
import org.springframework.web.context.request.NativeWebRequest;
import org.springframework.web.method.support.HandlerMethodArgumentResolver;
import org.springframework.web.method.support.ModelAndViewContainer;
import javax.servlet.http.HttpServletRequest;
import javax.validation.ConstraintViolationException;
import javax.validation.Valid;
import javax.validation.constraints.*;
import javax.validation.groups.Default;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.lang.annotation.Annotation;
import java.lang.reflect.Field;
import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;
import java.math.BigDecimal;
import java.util.Collection;
import java.util.Collections;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
/**
* 解析加密注解
*
* @author vhukze
* @date 2021/9/8 11:14
*/
@Log4j2
public class AESDecodeResolver implements HandlerMethodArgumentResolver {
/*
json参数的key
*/
private static final String NAME = "str";
/**
* 如果接口或者接口参数有解密注解,就解析
*/
@Override
public boolean supportsParameter(MethodParameter parameter) {
return parameter.hasMethodAnnotation(ParamsAES.class) || parameter.hasParameterAnnotation(ParamsAES.class);
}
@Override
public Object resolveArgument(MethodParameter parameter, ModelAndViewContainer modelAndViewContainer,
NativeWebRequest webRequest, WebDataBinderFactory webDataBinderFactory) throws IOException, InstantiationException, IllegalAccessException {
AES aes = AESUtil.aes;
// 获取post请求的json字符串
String postStr = getPostStr(webRequest);
// 接口参数的字节码对象
Class> parameterType = parameter.getParameterType();
//如果是实体类参数,把请求参数封装
if (BaseReq.class.isAssignableFrom(parameterType)) {
//获取加密的请求数据并解密
// String beforeParam = webRequest.getParameter(NAME);
String afterParam = aes.decryptStr(JSONUtil.parseObj(postStr).get(NAME).toString(),
CharsetUtil.CHARSET_UTF_8);
// 校验参数
if (parameter.hasParameterAnnotation(Validated.class)) {
Validated validated = parameter.getParameterAnnotation(Validated.class);
this.verifyObjField(afterParam, parameterType, validated.value());
}
//json转对象 // 这里的return就会把转化过的参数赋给控制器的方法参数
return JSONUtil.toBean(afterParam, parameterType);
// 如果是非集合类,就直接解码返回
} else if (!Iterable.class.isAssignableFrom(parameterType)) {
// String decryptStr = aes.decryptStr(webRequest.getParameter(parameter.getParameterName()), CharsetUtil.CHARSET_UTF_8);
// return Integer.class.isAssignableFrom(parameter.getParameterType()) ? Integer.parseInt(decryptStr) : decryptStr;
Object value = JSONUtil.parseObj(aes.decryptStr(JSONUtil.parseObj(postStr).get(NAME).toString(),
CharsetUtil.CHARSET_UTF_8)).get(parameter.getParameterName());
this.verifyOneField(parameter, value);
return value;
//如果是集合类
} else if (Iterable.class.isAssignableFrom(parameterType)) {
//获取加密的请求数据并解密
// String beforeParam = webRequest.getParameter(NAME);
String afterParam = aes.decryptStr(JSONUtil.parseObj(postStr).get(NAME).toString(),
CharsetUtil.CHARSET_UTF_8);
//转成对象数组
JSONArray jsonArray = JSONUtil.parseArray(afterParam);
this.verifyCollField(parameter, jsonArray);
return jsonArray.toList(Object.class);
}
return null;
}
/**
* 校验单个参数
*/
private void verifyOneField(MethodParameter parameter, Object value) {
for (Annotation annotation : parameter.getParameterAnnotations()) {
if (annotation instanceof NotBlank) {
if (value == null || StrUtil.isBlank(value.toString())) {
log.info("参数为空");
throw new ConstraintViolationException(null);
}
}
if (annotation instanceof NotNull) {
if (value == null) {
log.info("参数为空");
throw new ConstraintViolationException(null);
}
}
// 只能是字符串类型
if (annotation instanceof Size) {
Size size = (Size) annotation;
if (value != null && (value.toString().length() < size.min() || value.toString().length() > size.max())) {
log.info("参数长度不对");
throw new ConstraintViolationException(null);
}
}
}
}
/**
* 校验集合类型
*/
private void verifyCollField(MethodParameter parameter, JSONArray jsonArray) {
for (Annotation annotation : parameter.getParameterAnnotations()) {
if (annotation instanceof NotEmpty) {
if (jsonArray == null || jsonArray.size() == 0) {
log.info("集合参数值为空");
throw new ConstraintViolationException(null);
}
}
if (annotation instanceof Size) {
Size size = (Size) annotation;
if (jsonArray.size() < size.min() || jsonArray.size() > size.max()) {
log.info("集合参数值大小不对");
throw new ConstraintViolationException(null);
}
}
}
}
/**
* 校验实体类参数
*
* @param param 前端传的参数(json字符串)
* @param clazz 接口实体类参数的字节码对象
* @param groups 校验那些组
*/
private void verifyObjField(String param, Class> clazz, Class>[] groups) {
// 前端传的参数
JSONObject jsonObj = JSONUtil.parseObj(param);
// 实体类所有字段
Field[] fields = clazz.getDeclaredFields();
for (Field field : fields) {
// 字段如果不可访问,设置可访问
if (!field.isAccessible()) {
field.setAccessible(true);
}
Annotation[] annotations = field.getDeclaredAnnotations();
for (Annotation annotation : annotations) {
if (annotation instanceof NotNull) {
NotNull notNull = (NotNull) annotation;
if ((ArrayUtil.isEmpty(groups) && ArrayUtil.isEmpty(notNull.groups()))
|| ArrayUtil.containsAny(groups, notNull.groups())) {
if (jsonObj.get(field.getName()) == null) {
log.info("字段>>>>>>" + field.getName() + "值有问题");
throw new ConstraintViolationException(null);
}
}
}
if (annotation instanceof NotBlank) {
NotBlank notBlank = (NotBlank) annotation;
if ((ArrayUtil.isEmpty(groups) && ArrayUtil.isEmpty(notBlank.groups()))
|| ArrayUtil.containsAny(groups, notBlank.groups())) {
Object val = jsonObj.get(field.getName());
if (val == null || StrUtil.isBlank(val.toString())) {
log.info("字段>>>>>>" + field.getName() + "值有问题");
throw new ConstraintViolationException(null);
}
}
}
if (annotation instanceof Size) {
Size size = (Size) annotation;
if ((ArrayUtil.isEmpty(groups) && ArrayUtil.isEmpty(size.groups()))
|| ArrayUtil.containsAny(groups, size.groups())) {
Object val = jsonObj.get(field.getName());
if (val instanceof String) {
if (val.toString().length() < size.min() || val.toString().length() > size.max()) {
log.info("字段>>>>>>" + field.getName() + "值有问题");
throw new ConstraintViolationException(null);
}
}
if (val instanceof Collection && Convert.toList(val).size() == 0) {
log.info("字段>>>>>>" + field.getName() + "值有问题");
throw new ConstraintViolationException(null);
}
}
}
if (annotation instanceof Pattern) {
Pattern pattern = (Pattern) annotation;
if ((ArrayUtil.isEmpty(groups) && ArrayUtil.isEmpty(pattern.groups()))
|| ArrayUtil.containsAny(groups, pattern.groups())) {
Object val = jsonObj.get(field.getName());
if (val != null && !ReUtil.isMatch(pattern.regexp(), val.toString())) {
log.info("字段>>>>>>" + field.getName() + "值正则校验失败");
throw new ConstraintViolationException(null);
}
}
}
if (annotation instanceof Max) {
Max max = (Max) annotation;
if ((ArrayUtil.isEmpty(groups) && ArrayUtil.isEmpty(max.groups()))
|| ArrayUtil.containsAny(groups, max.groups())) {
Object val = jsonObj.get(field.getName());
if (val != null && Convert.toInt(val) > max.value()) {
log.info("字段>>>>>>" + field.getName() + "值太大");
throw new ConstraintViolationException(null);
}
}
}
}
}
}
@Nullable
private String getPostStr(NativeWebRequest webRequest) throws IOException {
//获取post请求的json数据
HttpServletRequest request = (HttpServletRequest) webRequest.getNativeRequest();
int contentLength = request.getContentLength();
if (contentLength < 0) {
return null;
}
byte[] buffer = new byte[contentLength];
for (int i = 0; i < contentLength; ) {
int readlen = request.getInputStream().read(buffer, i,
contentLength - i);
if (readlen == -1) {
break;
}
i += readlen;
}
String str = new String(buffer, CharsetUtil.CHARSET_UTF_8);
StringBuilder sb = new StringBuilder();
for (char c : str.toCharArray()) {
//去掉json中的空格 换行符 制表符
if (c != 32 && c != 13 && c != 10) {
sb.append(c);
}
}
return sb.toString();
}
}
这里只实现了几个常用的校验注解，需要其他注解的话可以自行添加。如果不需要分组校验，可以把判断分组的代码删掉。