1、首先来看 Filter 的实现基础
/**
 * Pass-through filter step: hands the request and the response, untouched,
 * to the next element of the filter chain.
 */
@Override
public void doFilter(final ServletRequest servletRequest, final ServletResponse servletResponse,
        final FilterChain filterChain) throws IOException, ServletException {
    // No work before or after the call; the chain simply continues.
    filterChain.doFilter(servletRequest, servletResponse);
}
用户请求到达后,先经过 Filter 再到达后台;后台处理完成后,响应再经过 Filter 返回给用户。
doFilter(servletRequest, servletResponse) 方法沿着过滤器链一直向下传递
servletRequest 和 servletResponse。
那么,Controller 是怎么获取请求参数、又是怎么返回结果的呢?
/**
 * Excerpt of the servlet request API: the three accessors controllers use
 * to read request parameters.
 */
public interface ServletRequest {

    /** Returns all request parameters, keyed by parameter name. */
    Map<String, String[]> getParameterMap();

    /** Returns a single value of the named request parameter. */
    String getParameter(String name);

    /** Returns every value of the named request parameter. */
    String[] getParameterValues(String name);
}
Controller主要通过这三个方法获取参数
/**
 * Excerpt of the servlet response API: the two channels controllers use to
 * write the response body back to the client.
 */
public interface ServletResponse {

    /** Character-oriented channel for the response body. */
    PrintWriter getWriter() throws IOException;

    /** Byte-oriented channel for the response body. */
    ServletOutputStream getOutputStream() throws IOException;
}
Controller 主要通过这两个流把结果输出到前端。
因此,可以重写 ServletRequest 的方法,让 Controller 在取参数时得到的是我们修改过的参数;
重写 ServletResponse 的方法,让 Controller 在往前端写结果时写到我们的重写类里面,我们处理完这些数据后再重新写到前端。
ServletRequest 重写
package com.spring.demo.filter;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import java.io.UnsupportedEncodingException;
import java.net.URLDecoder;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.regex.Pattern;
/**
 * Request wrapper that passes parameter values and header values through
 * {@link #replaceXSS(String)} before they reach downstream code.
 *
 * <p>Parameter and header <em>names</em> are looked up verbatim. They come from
 * application code, not from the client, and sanitizing them breaks lookups.
 * For example, {@code replaceXSS("alertType")} is {@code "Type"}.
 */
public class RequestWrapper extends HttpServletRequestWrapper {

    /** Flags shared by patterns whose match may span several lines. */
    private static final int MULTILINE_FLAGS =
            Pattern.CASE_INSENSITIVE | Pattern.MULTILINE | Pattern.DOTALL;

    /**
     * Fragments removed from every value, applied in this exact order.
     * Compiled once: {@link Pattern} instances are immutable and reusable.
     */
    private static final Pattern[] XSS_PATTERNS = {
            // Anything between <script> tags (no DOTALL: single-line bodies only).
            Pattern.compile("<script>(.*?)</script>", Pattern.CASE_INSENSITIVE),
            // src='...' and src="..." expressions.
            Pattern.compile("src[\r\n]*=[\r\n]*\\\'(.*?)\\\'", MULTILINE_FLAGS),
            Pattern.compile("src[\r\n]*=[\r\n]*\\\"(.*?)\\\"", MULTILINE_FLAGS),
            // Lonesome </script> and <script ...> tags.
            Pattern.compile("</script>", Pattern.CASE_INSENSITIVE),
            Pattern.compile("<script(.*?)>", MULTILINE_FLAGS),
            // eval(...) and expression(...) calls.
            Pattern.compile("eval\\((.*?)\\)", MULTILINE_FLAGS),
            Pattern.compile("expression\\((.*?)\\)", MULTILINE_FLAGS),
            // javascript: prefixes and the word "alert".
            Pattern.compile("javascript:", Pattern.CASE_INSENSITIVE),
            Pattern.compile("alert", Pattern.CASE_INSENSITIVE),
            // onload...= handlers.
            Pattern.compile("onload(.*?)=", MULTILINE_FLAGS),
            // vbscript: prefixes.
            Pattern.compile("vbscript[\r\n| | ]*:[\r\n| | ]*", Pattern.CASE_INSENSITIVE),
    };

    public RequestWrapper(HttpServletRequest request) {
        super(request);
    }

    /** Returns the sanitized value of the named parameter, or null if absent. */
    @Override
    public String getParameter(String name) {
        return replaceXSS(super.getParameter(name));
    }

    /**
     * Returns sanitized copies of all values of the named parameter, or null if
     * absent. The wrapped request's array is never modified.
     */
    @Override
    public String[] getParameterValues(String name) {
        return sanitizeAll(super.getParameterValues(name));
    }

    /**
     * Returns an unmodifiable map with the same keys as the wrapped request's
     * parameter map and sanitized copies of its values. Without this override,
     * callers using the map would receive unfiltered input.
     */
    @Override
    public Map<String, String[]> getParameterMap() {
        Map<String, String[]> sanitized = new LinkedHashMap<>();
        for (Map.Entry<String, String[]> entry : super.getParameterMap().entrySet()) {
            sanitized.put(entry.getKey(), sanitizeAll(entry.getValue()));
        }
        return Collections.unmodifiableMap(sanitized);
    }

    /** Returns the sanitized value of the named header, or null if absent. */
    @Override
    public String getHeader(String name) {
        return replaceXSS(super.getHeader(name));
    }

    /** Returns a new array holding replaceXSS of each element; null stays null. */
    private static String[] sanitizeAll(String[] values) {
        if (values == null) {
            return null;
        }
        // Work on a copy: some containers hand out their internal array, and
        // escaping it in place would double-escape on every later call.
        String[] result = new String[values.length];
        for (int i = 0; i < values.length; i++) {
            result[i] = replaceXSS(values[i]);
        }
        return result;
    }

    /**
     * Removes script-like fragments (script tags, src=..., eval(...),
     * expression(...), javascript:, alert, onload=, vbscript:) and NUL
     * characters, then HTML-encodes the remainder via {@link #filter(String)}.
     *
     * @param value raw value; may be null
     * @return the sanitized value, or null when {@code value} is null
     */
    public static String replaceXSS(String value) {
        if (value == null) {
            return null;
        }
        String cleaned = urlDecodeOrKeep(value);
        // Avoid null characters.
        cleaned = cleaned.replace("\0", "");
        for (Pattern pattern : XSS_PATTERNS) {
            cleaned = pattern.matcher(cleaned).replaceAll("");
        }
        return filter(cleaned);
    }

    /**
     * URL-decodes {@code value} as UTF-8 so percent-encoded payloads such as
     * {@code %3Cscript%3E} are caught by the patterns. A literal '+' is kept
     * as '+' instead of becoming a space. If the value is not valid
     * percent-encoding (e.g. {@code "100%"}), it is returned unchanged.
     */
    private static String urlDecodeOrKeep(String value) {
        try {
            return URLDecoder.decode(value.replace("+", "%2B"), "utf-8");
        } catch (UnsupportedEncodingException | IllegalArgumentException e) {
            // Best effort: keep the raw value, without the '+' -> "%2B" rewrite
            // leaking into the result.
            return value;
        }
    }

    /**
     * HTML-encodes the characters that are significant in HTML text and
     * attribute contexts ({@code & < > " '}); all other characters are copied
     * unchanged.
     *
     * @param value text to encode; may be null
     * @return the encoded text, or null when {@code value} is null
     */
    public static String filter(String value) {
        if (value == null) {
            return null;
        }
        StringBuilder result = new StringBuilder(value.length() + 16);
        for (int i = 0; i < value.length(); i++) {
            char c = value.charAt(i);
            switch (c) {
                case '<':
                    result.append("&lt;");
                    break;
                case '>':
                    result.append("&gt;");
                    break;
                case '"':
                    result.append("&quot;");
                    break;
                case '\'':
                    result.append("&#39;");
                    break;
                case '&':
                    result.append("&amp;");
                    break;
                default:
                    result.append(c);
                    break;
            }
        }
        return result.toString();
    }
}
ServletResponse 重写
// 重定向输出流:写到 DataOutputStream
package com.spring.demo.filter;
import javax.servlet.ServletOutputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.OutputStream;
//重定向输出流写到
DataOutputStream
public class FilterServletOutputStream extends ServletOutputStream {
DataOutputStream output;
public FilterServletOutputStream(OutputStream output) {
this.output = new DataOutputStream(output);
}
@Override
public void write(int arg0) throws IOException {
output.write(arg0);
}
@Override
public void write(byte[] arg0, int arg1, int arg2) throws IOException {
output.write(arg0, arg1, arg2);
}
@Override
public void write(byte[] arg0) throws IOException {
output.write(arg0);
}
}
// 重定向输出流:写到 ByteArrayOutputStream
// ByteArrayOutputStream 接收 Controller 写入的数据,并以 byte[] 形式返回给 Filter
package com.spring.demo.filter;
import io.netty.handler.codec.http.HttpResponseStatus;
import javax.servlet.ServletOutputStream;
import javax.servlet.http.HttpServletResponse;
import javax.servlet.http.HttpServletResponseWrapper;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
/**
 * Response wrapper that captures the body written by downstream code in an
 * in-memory buffer, so a filter can post-process it before sending it on.
 *
 * <p>Both {@link #getOutputStream()} and {@link #getWriter()} write into the
 * same buffer. All other calls (headers, status, ...) still go to the wrapped
 * response.
 */
public class ResponseWrapper extends HttpServletResponseWrapper {
    ByteArrayOutputStream output;
    FilterServletOutputStream filterOutput;
    // NOTE(review): not read anywhere in this class; kept in case other code in
    // the package uses it.
    HttpResponseStatus status = HttpResponseStatus.OK;
    /** Lazily created character view of {@link #output}; null until getWriter(). */
    private java.io.PrintWriter writer;

    public ResponseWrapper(HttpServletResponse response) {
        super(response);
        output = new ByteArrayOutputStream();
    }

    /** Returns a stream that writes into the capture buffer (created on first use). */
    @Override
    public ServletOutputStream getOutputStream() throws IOException {
        if (filterOutput == null) {
            filterOutput = new FilterServletOutputStream(output);
        }
        return filterOutput;
    }

    /**
     * Returns a writer that encodes into the capture buffer using the
     * response's character encoding. Without this override, writer-based
     * output would bypass the buffer and go straight to the client.
     */
    @Override
    public java.io.PrintWriter getWriter() throws IOException {
        if (writer == null) {
            String encoding = getCharacterEncoding();
            if (encoding == null) {
                // Servlet default when no encoding has been assigned.
                encoding = "ISO-8859-1";
            }
            writer = new java.io.PrintWriter(new java.io.OutputStreamWriter(output, encoding));
        }
        return writer;
    }

    /**
     * Returns a copy of everything captured so far, from both the stream and
     * the writer.
     */
    public byte[] getDataStream() {
        if (writer != null) {
            // The writer buffers characters internally; push them to the bytes.
            writer.flush();
        }
        return output.toByteArray();
    }
}
package com.spring.demo.filter;
import java.io.Serializable;
/**
 * Envelope returned to the client. It holds a numeric status, a message and an
 * arbitrary payload, each exposed through a plain getter/setter pair.
 */
public class RestResponse implements Serializable {
    private int status;
    private String message;
    private Object data;

    /**
     * Builds an envelope with every property populated.
     *
     * @param status  status code to report
     * @param message message to report
     * @param data    payload to carry; may be null
     */
    public RestResponse(int status, String message, Object data) {
        this.data = data;
        this.message = message;
        this.status = status;
    }

    /** @return the status code */
    public int getStatus() { return status; }

    /** @return the message */
    public String getMessage() { return message; }

    /** @return the payload, possibly null */
    public Object getData() { return data; }

    /** @param status new status code */
    public void setStatus(int status) { this.status = status; }

    /** @param message new message */
    public void setMessage(String message) { this.message = message; }

    /** @param data new payload; may be null */
    public void setData(Object data) { this.data = data; }
}
package com.spring.demo.filter;
import com.alibaba.fastjson.JSONObject;
import org.apache.htrace.fasterxml.jackson.databind.ObjectMapper;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import javax.servlet.*;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
public class SessionFilter implements Filter {
protected final Logger logger = LoggerFactory.getLogger(SessionFilter.class);
@Override
public void init(FilterConfig filterConfig) throws ServletException {
logger.info("SessionFilter init" );
}
@Override
public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse,
FilterChain filterChain) throws IOException, ServletException {
logger.info("doFilter start" );
// TODO Auto-generated method stub
RequestWrapper requestWrapper = new RequestWrapper((HttpServletRequest) servletRequest);
ResponseWrapper responseWrapper = new ResponseWrapper((HttpServletResponse) servletResponse);
try {
filterChain.doFilter(requestWrapper, responseWrapper);
} catch (ServletException e) {
// TODO Auto-generated catch block
e.printStackTrace();
}
String responseContent = new String(responseWrapper.getDataStream());
//此处可以处理
responseContent,然后封装成RestResponse 返回给前端
JSONObject jsonObject = JSONObject.parseObject(responseContent);
logger.info("responseContent({})",responseContent);
RestResponse fullResponse = new RestResponse(205, "OK-MESSAGE",jsonObject);
byte[] responseToSend = restResponseBytes(fullResponse);
servletResponse.getOutputStream().write(responseToSend);
logger.info("doFilter end" );
}
@Override
public void destroy() {
}
private byte[] restResponseBytes(RestResponse response) throws IOException {
String serialized = new ObjectMapper().writeValueAsString(response);
return serialized.getBytes("UTF-8");
}
}