源码下载
https://github.com/cbeann/Demooo/tree/master/springboot-ratelimiter
部分代码
pom.xml（引入 Guava，使用其中的 RateLimiter）
<!-- Guava supplies com.google.common.util.concurrent.RateLimiter (token bucket) used below.
     https://mvnrepository.com/artifact/com.google.guava/guava -->
<dependency>
    <groupId>com.google.guava</groupId>
    <artifactId>guava</artifactId>
    <version>28.1-jre</version>
</dependency>
自定义注解ExtRateLimiter
package com.example.annotation; import java.lang.annotation.*; /** * @author chaird * @create 2021-03-20 17:57 */ @Target(ElementType.METHOD) @Retention(RetentionPolicy.RUNTIME) @Documented public @interface ExtRateLimiter { // 以每秒为单位固定的速率值往令牌桶中添加令牌 double permitsPerSecond(); // 在规定的毫秒数中,如果没有获取到令牌的话,则直接走服务器降级处理 long timeout(); }
拦截器 RateLimiterInceptor
package com.example.Interceptor; import com.example.annotation.ExtRateLimiter; import com.google.common.util.concurrent.RateLimiter; import org.springframework.stereotype.Component; import org.springframework.web.method.HandlerMethod; import org.springframework.web.servlet.ModelAndView; import org.springframework.web.servlet.handler.HandlerInterceptorAdapter; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; import java.io.PrintWriter; import java.lang.reflect.Method; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.TimeUnit; /** * @author CBeann * @create 2020-07-04 18:06 */ @Component public class RateLimiterInceptor extends HandlerInterceptorAdapter { private Map<String, RateLimiter> rateHashMap = new ConcurrentHashMap<>(); @Override public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler) throws Exception { if (!(handler instanceof HandlerMethod)) { return true; } final HandlerMethod handlerMethod = (HandlerMethod) handler; final Method method = handlerMethod.getMethod(); // 有这个注解 boolean methodAnn = method.isAnnotationPresent(ExtRateLimiter.class); if (methodAnn) { // 获取注解 ExtRateLimiter extRateLimiter = method.getDeclaredAnnotation(ExtRateLimiter.class); //获取注解属性 double permitsPerSecond = extRateLimiter.permitsPerSecond(); long timeout = extRateLimiter.timeout(); String key = method.getDeclaringClass().getName() + method.getName(); RateLimiter rateLimiter = null; if (rateHashMap.get(key) == null) { rateLimiter = RateLimiter.create(permitsPerSecond); rateHashMap.put(key, rateLimiter); } else { rateLimiter = rateHashMap.get(key); } boolean tryAcquire = rateLimiter.tryAcquire(timeout, TimeUnit.MILLISECONDS); if (!tryAcquire) { response.setContentType("application/json; charset=utf-8"); PrintWriter writer = response.getWriter(); writer.print("限流"); writer.close(); response.flushBuffer(); return false; } return 
super.preHandle(request, response, handler); } return true; } @Override public void postHandle( HttpServletRequest request, HttpServletResponse response, Object handler, ModelAndView modelAndView) throws Exception { super.postHandle(request, response, handler, modelAndView); } }
拦截器配置
package com.example.config; import com.example.Interceptor.RateLimiterInceptor; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.Configuration; import org.springframework.web.servlet.config.annotation.InterceptorRegistry; import org.springframework.web.servlet.config.annotation.WebMvcConfigurerAdapter; /** * @author chaird * @create 2020-09-23 16:13 */ @Configuration public class MVCConfig extends WebMvcConfigurerAdapter { @Autowired private RateLimiterInceptor rateLimiterInceptor; @Override public void addInterceptors(InterceptorRegistry registry) { // 获取http请求拦截器 registry.addInterceptor(rateLimiterInceptor).addPathPatterns("/*"); } }
Controller
package com.example.controller;

import com.example.annotation.ExtRateLimiter;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RestController;

/** Demo REST endpoint used to exercise {@link ExtRateLimiter}. */
@RestController
public class StockController {

  /**
   * Returns the constant body {@code "ok"}.
   *
   * <p>Declared limit: 2 permits per second, waiting at most 500 ms for a permit.
   *
   * @return the string {@code "ok"}
   */
  @ExtRateLimiter(timeout = 500, permitsPerSecond = 2)
  @GetMapping("/getStock")
  public Object getStock() {
    return "ok";
  }
}
测试
访问下面的接口，大约每秒只允许 2 个请求通过；超出的请求若在 500 毫秒内仍未获取到令牌，将直接返回“限流”。
http://localhost:8080/getStock