package com.ph.sp.gateway.filter;
import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.lang.Pair;
import cn.hutool.core.util.StrUtil;
import lombok.extern.slf4j.Slf4j;
import org.reactivestreams.Publisher;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.core.annotation.Order;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.core.io.buffer.DataBufferFactory;
import org.springframework.core.io.buffer.DataBufferUtils;
import org.springframework.core.io.buffer.DefaultDataBufferFactory;
import org.springframework.data.redis.core.HashOperations;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.http.HttpHeaders;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.http.server.reactive.ServerHttpRequestDecorator;
import org.springframework.http.server.reactive.ServerHttpResponse;
import org.springframework.http.server.reactive.ServerHttpResponseDecorator;
import org.springframework.stereotype.Component;
import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.server.WebFilter;
import org.springframework.web.server.WebFilterChain;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import javax.annotation.Resource;
import java.nio.charset.StandardCharsets;
import java.util.Collections;
import java.util.Map;
import java.util.concurrent.atomic.AtomicReference;
@Component
@Order(3)
@Slf4j
@SuppressWarnings("all")
public class MockFilter implements WebFilter {
@Value("${mockClose:true}")
private boolean mockClose;
@Value("${cacheClose:true}")
private boolean cacheClose;
@Value("${isPrd:true}")
private boolean prd;
private final String cache = "A_CACHE_";
private final String mock = "A_MOCK_";
@Resource
private RedisTemplate<String, Map<String, String>> hashTemplate;
@Override
public Mono<Void> filter(ServerWebExchange exchange, WebFilterChain chain) {
if (prd || mockClose && cacheClose) {
return chain.filter(exchange);
}
ServerHttpRequest request = exchange.getRequest();
String uri = request.getURI().getPath();
Map<String, String> mockConfig = getRedisConfig(exchange, mock);
Map<String, String> cacheConfig = getRedisConfig(exchange, cache);
if (CollUtil.isEmpty(mockConfig) && CollUtil.isEmpty(cacheConfig)) {
return chain.filter(exchange);
}
AtomicReference<String> requestBodyContent = new AtomicReference<>("");
Flux<DataBuffer> body = exchange.getRequest().getBody();
return body.doOnNext(buffer -> {
byte[] bytes = new byte[buffer.readableByteCount()];
buffer.read(bytes);
DataBufferUtils.release(buffer);
requestBodyContent.set(new String(bytes, StandardCharsets.UTF_8));
}).then(Mono.defer(() -> diyFilter(requestBodyContent.get(), exchange, chain, mockConfig, cacheConfig)
)).then();
}
private Mono<Void> diyFilter(String body, ServerWebExchange exchange, WebFilterChain chain,
Map<String, String> mockConfig, Map<String, String> cacheConfig) {
String uri = exchange.getRequest().getURI().getPath();
ServerHttpResponse response = exchange.getResponse();
Pair<String, String> mockResult = check(mockConfig, body);
if (StrUtil.isNotBlank(mockResult.getValue())) {
return hit(mock, mockResult.getKey(), Base64Utils.decode(mockResult.getValue()), uri, response);
}
Pair<String, String> cacheResult = check(cacheConfig, body);
if (StrUtil.isNotBlank(cacheResult.getValue())) {
return hit(cache, cacheResult.getKey(), cacheResult.getValue(), uri, response);
}
Flux<DataBuffer> cachedFlux = Flux.defer(() -> Mono.just(exchange.getResponse().bufferFactory().wrap(body.getBytes())));
ServerHttpRequest mutatedRequest = new ServerHttpRequestDecorator(exchange.getRequest()) {
@Override
public HttpHeaders getHeaders() {
HttpHeaders headers = new HttpHeaders();
headers.putAll(exchange.getRequest().getHeaders());
headers.remove(HttpHeaders.CONTENT_LENGTH);
headers.setContentLength(body.getBytes().length);
return headers;
}
@Override
public Flux<DataBuffer> getBody() {
return cachedFlux;
}
};
if (StrUtil.isNotBlank(cacheResult.getKey())) {
ServerHttpResponseDecorator mutatedResponse = decoratedResponse(exchange, cacheResult.getKey());
return chain.filter(exchange.mutate().request(mutatedRequest).response(mutatedResponse).build());
}
return chain.filter(exchange.mutate().request(mutatedRequest).build());
}
private Mono<Void> hit(String type, String key, String val, String uri, ServerHttpResponse resp) {
log.info("命中缓存数据:(key: {}, hashKey:{},用完请手动删除该条hash,谢谢)", type, type + uri, key);
resp.getHeaders().setContentType(HeaderConstants.APPLICATION_JSON_UTF8);
return resp.writeWith(Mono.fromSupplier(() -> resp.bufferFactory().wrap(val.getBytes(StandardCharsets.UTF_8))));
}
private ServerHttpResponseDecorator decoratedResponse(ServerWebExchange exchange, String key) {
String path = exchange.getRequest().getURI().getPath();
ServerHttpResponse originalResponse = exchange.getResponse();
DataBufferFactory bufferFactory = originalResponse.bufferFactory();
return new ServerHttpResponseDecorator(originalResponse) {
@Override
public Mono<Void> writeWith(Publisher<? extends DataBuffer> body) {
if (body instanceof Mono) {
Mono<? extends DataBuffer> mono = (Mono<? extends DataBuffer>) body;
body = mono.flux();
}
if (body instanceof Flux) {
Flux<? extends DataBuffer> fluxBody = (Flux<? extends DataBuffer>) body;
return super.writeWith(fluxBody.buffer().map(dataBuffer -> {
DataBufferFactory dataBufferFactory = new DefaultDataBufferFactory();
DataBuffer join = dataBufferFactory.join(dataBuffer);
byte[] content = new byte[join.readableByteCount()];
join.read(content);
DataBufferUtils.release(join);
HashOperations<String, String, String> operation = hashTemplate.opsForHash();
operation.put(cache + path, key, new String(content, StandardCharsets.UTF_8));
originalResponse.getHeaders().setContentLength(content.length);
return bufferFactory.wrap(content);
}));
}
return super.writeWith(body);
}
};
}
private Map<String, String> getRedisConfig(ServerWebExchange exchange, String pre) {
String uri = exchange.getRequest().getURI().getPath();
HashOperations<String, String, String> operation = hashTemplate.opsForHash();
return operation.entries(pre + uri);
}
private Pair<String, String> check(Map<String, String> config, String body) {
return config.entrySet().stream().filter(e -> StrUtil.contains(body, e.getKey())).findFirst()
.map(e -> Pair.of(e.getKey(), e.getValue())).orElse(Pair.of(null, null));
}
}