Spring Boot 中对实体类参数的属性进行校验,一般都是使用 javax.validation 中提供的注解。
我这次的项目需要所有接口参数加密。参数解密是通过自定义参数解析器(实现 HandlerMethodArgumentResolver 接口)来完成的:获取请求体中的加密字符串,解密后封装到接口参数中。所以就不用 @RequestBody 注解了,并且那些参数校验的注解也不会起作用。
如果在接口里面写 if 校验就有点……不优雅,于是就想到在参数解析的时候自己根据这些注解进行校验。
package com.gt.gxjhpt.configuration; import cn.hutool.core.convert.Convert; import cn.hutool.core.util.ArrayUtil; import cn.hutool.core.util.CharsetUtil; import cn.hutool.core.util.ReUtil; import cn.hutool.core.util.StrUtil; import cn.hutool.crypto.symmetric.AES; import cn.hutool.json.JSONArray; import cn.hutool.json.JSONObject; import cn.hutool.json.JSONUtil; import com.gt.gxjhpt.annotation.ParamsAES; import com.gt.gxjhpt.entity.dto.BaseReq; import com.gt.gxjhpt.utils.AESUtil; import lombok.extern.log4j.Log4j2; import org.jetbrains.annotations.Nullable; import org.springframework.core.MethodParameter; import org.springframework.validation.annotation.Validated; import org.springframework.web.bind.support.WebDataBinderFactory; import org.springframework.web.context.request.NativeWebRequest; import org.springframework.web.method.support.HandlerMethodArgumentResolver; import org.springframework.web.method.support.ModelAndViewContainer; import javax.servlet.http.HttpServletRequest; import javax.validation.ConstraintViolationException; import javax.validation.constraints.*; import java.io.IOException; import java.lang.annotation.Annotation; import java.lang.reflect.Field; import java.lang.reflect.ParameterizedType; import java.lang.reflect.Type; import java.util.Collection; import java.util.stream.Collectors; /** * 解析加密注解 * * @author vhukze * @date 2021/9/8 11:14 */ @Log4j2 public class AESDecodeResolver implements HandlerMethodArgumentResolver { /* json参数的key */ private static final String NAME = "str"; /** * 如果接口或者接口参数有解密注解,就解析 */ @Override public boolean supportsParameter(MethodParameter parameter) { return parameter.hasMethodAnnotation(ParamsAES.class) || parameter.hasParameterAnnotation(ParamsAES.class); } @Override public Object resolveArgument(MethodParameter parameter, ModelAndViewContainer modelAndViewContainer, NativeWebRequest webRequest, WebDataBinderFactory webDataBinderFactory) throws IOException, InstantiationException, IllegalAccessException { AES aes = 
AESUtil.aes; // 获取post请求的json字符串 String postStr = getPostStr(webRequest); // 接口参数的字节码对象 Class<?> parameterType = parameter.getParameterType(); //如果是实体类参数,把请求参数封装 if (BaseReq.class.isAssignableFrom(parameterType)) { //获取加密的请求数据并解密 // String beforeParam = webRequest.getParameter(NAME); String afterParam = aes.decryptStr(JSONUtil.parseObj(postStr).get(NAME).toString(), CharsetUtil.CHARSET_UTF_8); // 校验参数 if (parameter.hasParameterAnnotation(Validated.class)) { Validated validated = parameter.getParameterAnnotation(Validated.class); this.verifyObjField(afterParam, parameterType, validated.value()); } //json转对象 // 这里的return就会把转化过的参数赋给控制器的方法参数 return JSONUtil.toBean(afterParam, parameterType); // 如果是非集合类,就直接解码返回 } else if (!Iterable.class.isAssignableFrom(parameterType)) { // String decryptStr = aes.decryptStr(webRequest.getParameter(parameter.getParameterName()), CharsetUtil.CHARSET_UTF_8); // return Integer.class.isAssignableFrom(parameter.getParameterType()) ? Integer.parseInt(decryptStr) : decryptStr; Object value = JSONUtil.parseObj(aes.decryptStr(JSONUtil.parseObj(postStr).get(NAME).toString(), CharsetUtil.CHARSET_UTF_8)).get(parameter.getParameterName()); this.verifyOneField(parameter, value); return value; //如果是集合类 } else if (Iterable.class.isAssignableFrom(parameterType)) { //获取加密的请求数据并解密 // String beforeParam = webRequest.getParameter(NAME); String afterParam = aes.decryptStr(JSONUtil.parseObj(postStr).get(NAME).toString(), CharsetUtil.CHARSET_UTF_8); //转成对象数组 JSONArray jsonArray = JSONUtil.parseArray(afterParam); this.verifyCollField(parameter, jsonArray); return jsonArray.toList(Object.class); } return null; } /** * 校验单个参数 */ private void verifyOneField(MethodParameter parameter, Object value) { for (Annotation annotation : parameter.getParameterAnnotations()) { if (annotation instanceof NotBlank) { if (value == null || StrUtil.isBlank(value.toString())) { log.info("参数为空"); throw new ConstraintViolationException(null); } } if (annotation instanceof NotNull) { 
if (value == null) { log.info("参数为空"); throw new ConstraintViolationException(null); } } // 只能是字符串类型 if (annotation instanceof Size) { Size size = (Size) annotation; if (value != null && (value.toString().length() < size.min() || value.toString().length() > size.max())) { log.info("参数长度不对"); throw new ConstraintViolationException(null); } } } } /** * 校验集合类型 */ private void verifyCollField(MethodParameter parameter, JSONArray jsonArray) { for (Annotation annotation : parameter.getParameterAnnotations()) { if (annotation instanceof NotEmpty) { if (jsonArray == null || jsonArray.size() == 0) { log.info("集合参数值为空"); throw new ConstraintViolationException(null); } } if (annotation instanceof Size) { Size size = (Size) annotation; if (jsonArray.size() < size.min() || jsonArray.size() > size.max()) { log.info("集合参数值大小不对"); throw new ConstraintViolationException(null); } } } } /** * 校验实体类参数 * * @param param 前端传的参数(json字符串) * @param clazz 接口实体类参数的字节码对象 * @param groups 校验那些组 */ private void verifyObjField(String param, Class<?> clazz, Class<?>[] groups) { // 前端传的参数 JSONObject jsonObj = JSONUtil.parseObj(param); // 实体类所有字段 Field[] fields = clazz.getDeclaredFields(); for (Field field : fields) { // 字段如果不可访问,设置可访问 if (!field.isAccessible()) { field.setAccessible(true); } Annotation[] annotations = field.getDeclaredAnnotations(); for (Annotation annotation : annotations) { if (annotation instanceof NotNull) { NotNull notNull = (NotNull) annotation; if ((ArrayUtil.isEmpty(groups) && ArrayUtil.isEmpty(notNull.groups())) || ArrayUtil.containsAny(groups, notNull.groups())) { if (jsonObj.get(field.getName()) == null) { log.info("字段>>>>>>" + field.getName() + "值有问题"); throw new ConstraintViolationException(null); } } } if (annotation instanceof NotBlank) { NotBlank notBlank = (NotBlank) annotation; if ((ArrayUtil.isEmpty(groups) && ArrayUtil.isEmpty(notBlank.groups())) || ArrayUtil.containsAny(groups, notBlank.groups())) { Object val = jsonObj.get(field.getName()); if (val == null || 
StrUtil.isBlank(val.toString())) { log.info("字段>>>>>>" + field.getName() + "值有问题"); throw new ConstraintViolationException(null); } } } if (annotation instanceof Size) { Size size = (Size) annotation; if ((ArrayUtil.isEmpty(groups) && ArrayUtil.isEmpty(size.groups())) || ArrayUtil.containsAny(groups, size.groups())) { Object val = jsonObj.get(field.getName()); if (val instanceof String) { if (val.toString().length() < size.min() || val.toString().length() > size.max()) { log.info("字段>>>>>>" + field.getName() + "值有问题"); throw new ConstraintViolationException(null); } } if (val instanceof Collection && Convert.toList(val).size() == 0) { log.info("字段>>>>>>" + field.getName() + "值有问题"); throw new ConstraintViolationException(null); } } } if (annotation instanceof Pattern) { Pattern pattern = (Pattern) annotation; if ((ArrayUtil.isEmpty(groups) && ArrayUtil.isEmpty(pattern.groups())) || ArrayUtil.containsAny(groups, pattern.groups())) { Object val = jsonObj.get(field.getName()); if (val != null && !ReUtil.isMatch(pattern.regexp(), val.toString())) { log.info("字段>>>>>>" + field.getName() + "值正则校验失败"); throw new ConstraintViolationException(null); } } } if (annotation instanceof Max) { Max max = (Max) annotation; if ((ArrayUtil.isEmpty(groups) && ArrayUtil.isEmpty(max.groups())) || ArrayUtil.containsAny(groups, max.groups())) { Object val = jsonObj.get(field.getName()); if (val != null && Convert.toInt(val) > max.value()) { log.info("字段>>>>>>" + field.getName() + "值太大"); throw new ConstraintViolationException(null); } } } } } } @Nullable private String getPostStr(NativeWebRequest webRequest) throws IOException { //获取post请求的json数据 HttpServletRequest request = (HttpServletRequest) webRequest.getNativeRequest(); int contentLength = request.getContentLength(); if (contentLength < 0) { return null; } byte[] buffer = new byte[contentLength]; for (int i = 0; i < contentLength; ) { int readlen = request.getInputStream().read(buffer, i, contentLength - i); if (readlen == -1) { 
break; } i += readlen; } String str = new String(buffer, CharsetUtil.CHARSET_UTF_8); StringBuilder sb = new StringBuilder(); for (char c : str.toCharArray()) { //去掉json中的空格 换行符 制表符 if (c != 32 && c != 13 && c != 10) { sb.append(c); } } return sb.toString(); } }