【校验请求参数】
我们在校验请求参数的实现中使用了策略模式，目前只支持 GET、POST 请求，代码如下：
// ===== MethodSecurityStrategy.java =====
import javax.servlet.http.HttpServletRequest;

/**
 * Strategy that verifies the signature of an incoming open-API request. Each implementation
 * handles one HTTP method, selected via {@link #isTest(String)}.
 *
 * @author zouwei
 * @className MethodSecurityStrategy
 * @date: 2020/11/25 11:45 AM
 * @description:
 */
public interface MethodSecurityStrategy {

    /** Separator placed between {@code name=value} pairs in the string that gets signed. */
    String JOIN_STR = "&";

    /**
     * Verifies the signature carried by {@code request}.
     *
     * @param request the incoming request
     * @param securityKey the secret used to compute the expected signature
     * @return {@code true} if the signature sent with the request matches the expected one
     */
    boolean test(HttpServletRequest request, String securityKey);

    /**
     * Tells whether this strategy handles the given HTTP method.
     *
     * @param requestMethod HTTP method name, e.g. {@code "GET"}
     * @return {@code true} if {@link #test} of this strategy should be used for the method
     */
    boolean isTest(String requestMethod);
}

// ===== AbstractMethodSecurityStrategy.java =====
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.netflix.zuul.context.RequestContext;
import com.zx.silverfox.common.config.api.ApiSecurityConst;
import com.zx.silverfox.common.util.CastUtil;
import com.zx.silverfox.common.util.JsonUtil;
import com.zx.silverfox.common.util.MD5Util;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.util.StreamUtils;

import javax.servlet.ReadListener;
import javax.servlet.ServletInputStream;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import java.io.BufferedReader;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.util.*;

/**
 * Base class for signature strategies. It builds the string to sign from the request and compares
 * its signature with the one the client sent in the {@code ApiSecurityConst.SIGN_KEY} header.
 *
 * <p>The string to sign is {@code name=value} pairs sorted by name, joined with
 * {@link #JOIN_STR}, followed by {@code appKey=<value of the ApiSecurityConst.API_KEY header>}.
 *
 * @author zouwei
 * @className AbstractMethodSecurityStrategy
 * @date: 2020/11/25 4:07 PM
 * @description:
 */
@Slf4j
public abstract class AbstractMethodSecurityStrategy implements MethodSecurityStrategy {

    /**
     * Default verification: signs the request parameters (query string / form parameters).
     * Subclasses handling request bodies override this.
     */
    @Override
    public boolean test(HttpServletRequest request, String securityKey) {
        // Parameter names are sorted so that client and gateway build the same string.
        String str = joinGetRequestParams(request);
        return checkSign(request, str, securityKey);
    }

    /** Signature sent by the client, or {@code null} if the header is absent. */
    private String requestSign(HttpServletRequest request) {
        return request.getHeader(ApiSecurityConst.SIGN_KEY);
    }

    /** App key sent by the client, or {@code null} if the header is absent. */
    private String appKey(HttpServletRequest request) {
        return request.getHeader(ApiSecurityConst.API_KEY);
    }

    /**
     * Verifies the signature.
     *
     * @param request the request carrying the client's signature header
     * @param str the string to sign
     * @param securityKey the secret used as the signing key
     * @return {@code true} only if {@code str} is non-blank and the computed signature equals the
     *     one sent by the client; a missing signature on either side never matches
     */
    protected boolean checkSign(HttpServletRequest request, String str, String securityKey) {
        if (StringUtils.isBlank(str)) {
            return false;
        }
        String expected = MD5Util.macSha2Base64(str, securityKey);
        String actual = requestSign(request);
        // Previously two nulls compared equal; an absent signature must never authenticate.
        if (expected == null || actual == null) {
            return false;
        }
        // Constant-time comparison so response timing does not reveal how much of the
        // signature an attacker has guessed correctly.
        return MessageDigest.isEqual(
                expected.getBytes(StandardCharsets.UTF_8), actual.getBytes(StandardCharsets.UTF_8));
    }

    /**
     * Builds the string to sign for a GET request: sorted {@code name=value} pairs of the request
     * parameters (first value only for multi-valued parameters), then {@code appKey=...}.
     *
     * @param request the incoming request
     * @return the string to sign
     */
    protected String joinGetRequestParams(HttpServletRequest request) {
        List<String> names = Collections.list(request.getParameterNames());
        Collections.sort(names);
        StringJoiner sj = new StringJoiner(JOIN_STR);
        for (String name : names) {
            sj.add(name + "=" + request.getParameter(name));
        }
        sj.add("appKey=" + appKey(request));
        return sj.toString();
    }

    /**
     * Builds the string to sign for a POST request.
     *
     * <p>The body is buffered by {@link BodyReaderHttpServletRequestWrapper}. It is parsed as a
     * JSON object and merged with the request parameters. A parameter overrides a body field
     * with the same name. The entries are sorted by name and joined, then {@code appKey=...}
     * is appended.
     *
     * <p>The buffered wrapper replaces the request in the current Zuul {@link RequestContext}.
     * This lets later filters read the body again.
     *
     * <p>NOTE(review): the body stream is consumed before {@code getParameterNames()} is called.
     * For {@code application/x-www-form-urlencoded} bodies, the container may therefore not
     * expose the form fields as parameters. Confirm that only JSON bodies are expected here.
     *
     * @param request the incoming request
     * @return the string to sign, or an empty string if the body could not be read
     */
    protected String joinPostRequestParams(HttpServletRequest request) {
        String requestBody;
        try {
            BodyReaderHttpServletRequestWrapper requestWrapper =
                    new BodyReaderHttpServletRequestWrapper(request);
            RequestContext currentContext = RequestContext.getCurrentContext();
            requestBody = requestWrapper.getBody();
            currentContext.setRequest(requestWrapper);
        } catch (Exception e) {
            // An empty string makes checkSign fail, so verification fails closed.
            log.error("Failed to read request body for signature verification", e);
            return StringUtils.EMPTY;
        }
        @SuppressWarnings("unchecked") // JSON object body deserialized into a generic map
        Map<String, Object> map = JsonUtil.string2Obj(requestBody, Map.class);
        // TreeMap keeps the entries sorted by name.
        TreeMap<String, Object> treeMap = Maps.newTreeMap();
        if (Objects.nonNull(map)) {
            treeMap.putAll(map);
        }
        Enumeration<String> enumeration = request.getParameterNames();
        while (enumeration.hasMoreElements()) {
            String key = enumeration.nextElement();
            treeMap.put(key, request.getParameter(key));
        }
        StringJoiner sj = new StringJoiner(JOIN_STR);
        for (Map.Entry<String, Object> e : treeMap.entrySet()) {
            sj.add(e.getKey() + "=" + CastUtil.castString(e.getValue()));
        }
        sj.add("appKey=" + appKey(request));
        return sj.toString();
    }

    /**
     * Request wrapper that buffers the whole body in memory so it can be read more than once.
     * It is read once for signature verification and again by whoever handles the request
     * afterwards.
     */
    public static class BodyReaderHttpServletRequestWrapper extends HttpServletRequestWrapper {

        /** Charset used for both {@link #getBody()} and {@link #getReader()}. */
        private static final Charset BODY_CHARSET = StandardCharsets.UTF_8;

        private final byte[] bodyBytes;
        private final String body;

        /**
         * Reads the wrapped request's body fully into memory.
         *
         * @param request the request whose body is buffered
         * @throws IOException if reading the body fails
         */
        public BodyReaderHttpServletRequestWrapper(HttpServletRequest request) throws IOException {
            super(request);
            this.bodyBytes = StreamUtils.copyToByteArray(request.getInputStream());
            this.body = new String(this.bodyBytes, BODY_CHARSET);
        }

        /** @return the buffered body decoded as UTF-8 */
        public String getBody() {
            return this.body;
        }

        /** Returns a reader over the buffered body, decoded with the same charset as {@link #getBody()}. */
        @Override
        public BufferedReader getReader() throws IOException {
            return new BufferedReader(new InputStreamReader(getInputStream(), BODY_CHARSET));
        }

        @Override
        public int getContentLength() {
            return this.bodyBytes.length;
        }

        @Override
        public long getContentLengthLong() {
            return this.bodyBytes.length;
        }

        /** Returns a fresh stream over the buffered body on every call. */
        @Override
        public ServletInputStream getInputStream() throws IOException {
            final ByteArrayInputStream bais = new ByteArrayInputStream(bodyBytes);
            return new ServletInputStream() {
                @Override
                public boolean isFinished() {
                    // Servlet 3.1 contract: true once every byte has been read.
                    return bais.available() == 0;
                }

                @Override
                public boolean isReady() {
                    // Data is already in memory, so a read never blocks.
                    return true;
                }

                @Override
                public void setReadListener(ReadListener listener) {
                    // Non-blocking (async) reads are not supported by this wrapper.
                }

                @Override
                public int available() throws IOException {
                    return bais.available();
                }

                @Override
                public int read() throws IOException {
                    return bais.read();
                }

                @Override
                public int read(byte[] b, int off, int len) throws IOException {
                    // Bulk read avoids the byte-at-a-time default of InputStream.
                    return bais.read(b, off, len);
                }
            };
        }
    }
}

// ===== GetSecurityStrategy.java =====
import org.apache.commons.lang3.StringUtils;
import org.springframework.http.HttpMethod;
import org.springframework.stereotype.Component;

/**
 * Signature strategy for GET requests. It uses the inherited parameter-based
 * {@link AbstractMethodSecurityStrategy#test}.
 *
 * @author zouwei
 * @className GetSecurityStrategy
 * @date: 2020/11/25 11:45 AM
 * @description:
 */
@Component
public class GetSecurityStrategy extends AbstractMethodSecurityStrategy
        implements MethodSecurityStrategy {

    /** Handles GET (case-insensitive). */
    @Override
    public boolean isTest(String requestMethod) {
        return StringUtils.equalsIgnoreCase(requestMethod, HttpMethod.GET.name());
    }
}

// ===== PostSecurityStrategy.java =====
import org.apache.commons.lang3.StringUtils;
import org.springframework.http.HttpMethod;
import org.springframework.stereotype.Component;

import javax.servlet.http.HttpServletRequest;

/**
 * Signature strategy for POST requests. It signs the JSON body merged with the request
 * parameters.
 *
 * @author zouwei
 * @className PostSecurityStrategy
 * @date: 2020/11/25 11:45 AM
 * @description:
 */
@Component
public class PostSecurityStrategy extends AbstractMethodSecurityStrategy
        implements MethodSecurityStrategy {

    /** Verifies the signature computed over the request body and parameters. */
    @Override
    public boolean test(HttpServletRequest request, String securityKey) {
        // Parse the request body (and parameters) into the string to sign.
        String str = joinPostRequestParams(request);
        // Verify the signature.
        return checkSign(request, str, securityKey);
    }

    /** Handles POST (case-insensitive). */
    @Override
    public boolean isTest(String requestMethod) {
        return StringUtils.equalsIgnoreCase(requestMethod, HttpMethod.POST.name());
    }
}
【组装请求】
网关读取（解析）请求体之后，原始请求的输入流已经被消费，无法再原样向后传递。因此网关需要重新组装一个请求对象，把之前取出的数据放回去，这样用户的请求才能继续向后传递。
import com.google.common.collect.Lists; import com.google.common.collect.Maps; import com.zx.silverfox.common.util.CastUtil; import org.apache.commons.lang3.StringUtils; import org.springframework.util.CollectionUtils; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletRequestWrapper; import java.util.*; /** * @author zouwei * @className CustomHeaderServletRequest * @date: 2020/9/24 下午1:48 * @description: */ public class CustomHeaderServletRequest extends HttpServletRequestWrapper { private Map<String, String> headers = Maps.newHashMap(); public CustomHeaderServletRequest(HttpServletRequest request) { super(request); } public void setHeaders(String key, String value) { headers.put(key, value); } private HttpServletRequest _getHttpServletRequest() { return (HttpServletRequest) super.getRequest(); } @Override public String getHeader(String name) { String value = this._getHttpServletRequest().getHeader(name); if (StringUtils.isBlank(value)) { return this.headers.get(name); } return value; } @Override public Enumeration<String> getHeaders(String name) { Enumeration<String> values = this._getHttpServletRequest().getHeaders(name); String value = this.headers.get(name); if (StringUtils.isBlank(value)) { return values; } Collection<String> collection = Lists.newArrayList(); collection.add(value); while (values.hasMoreElements()) { collection.add(values.nextElement()); } return Collections.enumeration(collection); } @Override public Enumeration<String> getHeaderNames() { Enumeration<String> values = this._getHttpServletRequest().getHeaderNames(); Set<String> keys = this.headers.keySet(); if (CollectionUtils.isEmpty(keys)) { return values; } Collection<String> collection = Lists.newArrayList(); while (values.hasMoreElements()) { collection.add(values.nextElement()); } collection.addAll(keys); return Collections.enumeration(collection); } @Override public long getDateHeader(String name) { long value = 
this._getHttpServletRequest().getDateHeader(name); if (value <= -1) { return CastUtil.castInt(headers.get(name), -1); } return value; } /** * The default behavior of this method is to return getIntHeader(String name) on the wrapped * request object. */ @Override public int getIntHeader(String name) { int value = this._getHttpServletRequest().getIntHeader(name); if (value <= -1) { return CastUtil.castInt(headers.get(name), -1); } return value; } } 复制代码
我们通过继承 HttpServletRequestWrapper 类实现了一个自定义请求类。它可以基于原始请求和额外的元数据（如请求头）重新构造一个请求，从而让用户请求继续向后传递。
小结
为了保证开放平台 API 的安全性，我们需要在网关拦截特定请求并进行校验。根据以上代码演示及概述，可以做出如下小结：
1. 开放平台需要把对应的 appKey 和 appSecret 分配给调用方；
2. 调用方使用 appSecret 作为密钥，对请求参数（连同 appKey）计算签名，得到 sign，并把 sign 和 appKey 放到请求头中（appSecret 本身不随请求传输）；
3. 开放平台接收到请求后，先判断是否为需要拦截的特定请求。如果是，则取出 appKey，到数据表中查询其是否确实拥有对应权限，并同时查询出对应的 appSecret；如果数据表中查询不到，说明没有访问权限；
4. 开放平台再取出请求参数，使用 appKey 与 appSecret 执行同样的签名操作得到 sign，并与请求头中的 sign 进行比较：两者一致则权限校验通过，否则没有访问权限；
5. 权限校验完成后，需要重新组装请求并继续向后传递，这里通过自定义请求类来实现。