package org.springframework.cloud.gateway.handler; import java.util.ArrayList; import java.util.List; import java.util.stream.Collectors; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import reactor.core.publisher.Mono; import org.springframework.cloud.gateway.filter.GatewayFilter; import org.springframework.cloud.gateway.filter.GatewayFilterChain; import org.springframework.cloud.gateway.filter.GlobalFilter; import org.springframework.cloud.gateway.filter.OrderedGatewayFilter; import org.springframework.cloud.gateway.filter.factory.GatewayFilterFactory; import org.springframework.cloud.gateway.route.Route; import org.springframework.core.Ordered; import org.springframework.core.annotation.AnnotationAwareOrderComparator; import org.springframework.web.server.ServerWebExchange; import org.springframework.web.server.WebHandler; import static org.springframework.cloud.gateway.support.ServerWebExchangeUtils.GATEWAY_ROUTE_ATTR; /** * WebHandler that delegates to a chain of {@link GlobalFilter} instances and * {@link GatewayFilterFactory} instances then to the target {@link WebHandler}. * * @author Rossen Stoyanchev * @author Spencer Gibb * @since 0.1 */ public class FilteringWebHandler implements WebHandler { protected static final Log logger = LogFactory.getLog(FilteringWebHandler.class); private final List<GatewayFilter> globalFilters; public FilteringWebHandler(List<GlobalFilter> globalFilters) { this.globalFilters = loadFilters(globalFilters); } private static List<GatewayFilter> loadFilters(List<GlobalFilter> filters) { return filters.stream().map(filter -> { GatewayFilterAdapter gatewayFilter = new GatewayFilterAdapter(filter); if (filter instanceof Ordered) { int order = ((Ordered) filter).getOrder(); return new OrderedGatewayFilter(gatewayFilter, order); } return gatewayFilter; }).collect(Collectors.toList()); } /* * TODO: relocate @EventListener(RefreshRoutesEvent.class) void handleRefresh() { * this.combinedFiltersForRoute.clear(); */ @Override public Mono<Void> handle(ServerWebExchange exchange) { Route route = exchange.getRequiredAttribute(GATEWAY_ROUTE_ATTR); List<GatewayFilter> gatewayFilters = route.getFilters(); List<GatewayFilter> combined = new ArrayList<>(this.globalFilters); combined.addAll(gatewayFilters); // TODO: needed or cached? AnnotationAwareOrderComparator.sort(combined); if (logger.isDebugEnabled()) { logger.debug("Sorted gatewayFilterFactories: " + combined); } return new DefaultGatewayFilterChain(combined).filter(exchange); } private static class DefaultGatewayFilterChain implements GatewayFilterChain { private final int index; private final List<GatewayFilter> filters; DefaultGatewayFilterChain(List<GatewayFilter> filters) { this.filters = filters; this.index = 0; } private DefaultGatewayFilterChain(DefaultGatewayFilterChain parent, int index) { this.filters = parent.getFilters(); this.index = index; } public List<GatewayFilter> getFilters() { return filters; } @Override public Mono<Void> filter(ServerWebExchange exchange) { return Mono.defer(() -> { if (this.index < filters.size()) { GatewayFilter filter = filters.get(this.index); DefaultGatewayFilterChain chain = new DefaultGatewayFilterChain(this, this.index + 1); return filter.filter(exchange, chain); } else { return Mono.empty(); // complete } }); } } private static class GatewayFilterAdapter implements GatewayFilter { private final GlobalFilter delegate; GatewayFilterAdapter(GlobalFilter delegate) { this.delegate = delegate; } @Override public Mono<Void> filter(ServerWebExchange exchange, GatewayFilterChain chain) { return this.delegate.filter(exchange, chain); } @Override public String toString() { final StringBuilder sb = new StringBuilder("GatewayFilterAdapter{"); sb.append("delegate=").append(delegate); sb.append('}'); return sb.toString(); } } }