1. Introduction: The Cognitive Leap from Correlation to Causation
Traditional machine learning models excel at finding correlational patterns in data, but they cannot answer causal questions such as "why" and "what if". Causal inference enables AI systems to:
Understand mechanisms: uncover true causal relationships between variables rather than spurious correlations
Predict interventions: forecast the effects of actions such as policy interventions and treatment plans
Reason counterfactually: estimate potential outcomes under scenarios that never occurred
Explain decisions: provide transparent explanations grounded in causal mechanisms
Java's strengths in scientific computing, enterprise systems, and data engineering make it an ideal platform for building production-grade causal inference systems. This article demonstrates how to build a reliable, explainable causal AI system based on Apache Commons Math, Weka, and a custom causal algorithm library.
2. Causal Inference Architecture Design
- System Architecture Overview
text
Data Layer → Causal Discovery → Causal Modeling → Effect Estimation → Decision Support
     ↓             ↓                 ↓                  ↓                   ↓
Observational Data → PC Algorithm → Causal Graph → Difference-in-Differences → Policy Recommendations
     ↓             ↓                 ↓                  ↓                   ↓
Experimental Data → FCI Algorithm → Structural Equations → Propensity Scores → Effect Evaluation
- Core Component Selection
Numerical computing: Apache Commons Math, Colt
Statistical learning: Weka, Smile
Graph computation: JGraphT, GraphStream
Distributed computing: Apache Spark MLlib
Visualization: JFreeChart, JavaFX
- Project Dependency Configuration
xml
<properties>
    <commons-math.version>3.6.1</commons-math.version>
    <weka.version>3.8.6</weka.version>
    <smile.version>3.0.1</smile.version>
    <jgrapht.version>1.5.2</jgrapht.version>
    <spring-boot.version>3.2.0</spring-boot.version>
</properties>
<dependencies>
    <!-- Web framework -->
    <dependency>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-web</artifactId>
        <version>${spring-boot.version}</version>
    </dependency>
    <!-- Numerical computation -->
    <dependency>
        <groupId>org.apache.commons</groupId>
        <artifactId>commons-math3</artifactId>
        <version>${commons-math.version}</version>
    </dependency>
    <!-- Machine learning -->
    <dependency>
        <groupId>nz.ac.waikato.cms.weka</groupId>
        <artifactId>weka-stable</artifactId>
        <version>${weka.version}</version>
    </dependency>
    <dependency>
        <groupId>com.github.haifengl</groupId>
        <artifactId>smile-core</artifactId>
        <version>${smile.version}</version>
    </dependency>
    <!-- Graph computation -->
    <dependency>
        <groupId>org.jgrapht</groupId>
        <artifactId>jgrapht-core</artifactId>
        <version>${jgrapht.version}</version>
    </dependency>
    <!-- Distributed computing -->
    <dependency>
        <groupId>org.apache.spark</groupId>
        <artifactId>spark-mllib_2.13</artifactId>
        <version>3.4.0</version>
    </dependency>
    <!-- Data visualization -->
    <dependency>
        <groupId>org.jfree</groupId>
        <artifactId>jfreechart</artifactId>
        <version>1.5.4</version>
    </dependency>
    <!-- Configuration management -->
    <dependency>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-configuration-processor</artifactId>
        <version>${spring-boot.version}</version>
    </dependency>
</dependencies>
3. Causal Discovery and Graph Model Construction
- Causal Graph Data Structure
java
// CausalGraph.java
import java.util.*;
import lombok.AllArgsConstructor;
import lombok.Data;
import org.jgrapht.Graph;
import org.jgrapht.GraphPath;
import org.jgrapht.alg.cycle.CycleDetector;
import org.jgrapht.alg.shortestpath.DijkstraShortestPath;
import org.jgrapht.graph.DefaultDirectedWeightedGraph;
import org.jgrapht.graph.DefaultWeightedEdge;

@Data
public class CausalGraph {
    // A weighted directed graph, so edge weights can carry causal strengths
    private final Graph<String, DefaultWeightedEdge> graph;
    private final Map<String, Variable> variables;
    private final List<CausalPath> causalPaths;
    public CausalGraph() {
        this.graph = new DefaultDirectedWeightedGraph<>(DefaultWeightedEdge.class);
        this.variables = new HashMap<>();
        this.causalPaths = new ArrayList<>();
    }
    /**
     * Add a variable to the causal graph
     */
    public void addVariable(String variableName, VariableType type, DataType dataType) {
        if (!graph.containsVertex(variableName)) {
            graph.addVertex(variableName);
            variables.put(variableName, new Variable(variableName, type, dataType));
        }
    }
    /**
     * Add a causal edge
     */
    public void addCausalEdge(String cause, String effect, double strength) {
        if (!graph.containsVertex(cause) || !graph.containsVertex(effect)) {
            throw new IllegalArgumentException("Variable does not exist in the graph");
        }
        DefaultWeightedEdge edge = graph.addEdge(cause, effect);
        if (edge != null) {
            graph.setEdgeWeight(edge, strength);
        }
    }
    /**
     * Check whether a direct causal edge exists
     */
    public boolean hasCausalEdge(String cause, String effect) {
        return graph.containsEdge(cause, effect);
    }
    /**
     * Get the direct causes (parents) of a variable
     */
    public Set<String> getDirectCauses(String variable) {
        Set<String> causes = new HashSet<>();
        for (DefaultWeightedEdge edge : graph.incomingEdgesOf(variable)) {
            causes.add(graph.getEdgeSource(edge));
        }
        return causes;
    }
    /**
     * Get the direct effects (children) of a variable
     */
    public Set<String> getDirectEffects(String variable) {
        Set<String> effects = new HashSet<>();
        for (DefaultWeightedEdge edge : graph.outgoingEdgesOf(variable)) {
            effects.add(graph.getEdgeTarget(edge));
        }
        return effects;
    }
    /**
     * Discover all causal paths from treatment to outcome
     */
    public List<CausalPath> discoverCausalPaths(String treatment, String outcome) {
        List<CausalPath> paths = new ArrayList<>();
        findAllPaths(treatment, outcome, new LinkedList<>(), paths, new HashSet<>());
        return paths;
    }
    /**
     * Depth-first search over all directed paths
     */
    private void findAllPaths(String current, String target,
                              LinkedList<String> currentPath,
                              List<CausalPath> paths, Set<String> visited) {
        visited.add(current);
        currentPath.add(current);
        if (current.equals(target)) {
            // Found a complete path
            paths.add(new CausalPath(new ArrayList<>(currentPath)));
        } else {
            // Keep searching downstream
            for (String neighbor : getDirectEffects(current)) {
                if (!visited.contains(neighbor)) {
                    findAllPaths(neighbor, target, currentPath, paths, visited);
                }
            }
        }
        currentPath.removeLast();
        visited.remove(current);
    }
    /**
     * Identify confounders: common causes of treatment and outcome
     */
    public Set<String> identifyConfounders(String treatment, String outcome) {
        Set<String> confounders = new HashSet<>();
        // Look for shared ancestors of treatment and outcome
        Set<String> treatmentCauses = getAllCauses(treatment);
        Set<String> outcomeCauses = getAllCauses(outcome);
        for (String cause : treatmentCauses) {
            if (outcomeCauses.contains(cause) && !cause.equals(treatment)) {
                confounders.add(cause);
            }
        }
        return confounders;
    }
    /**
     * Identify mediators: variables lying on a directed path from treatment to outcome
     */
    public Set<String> identifyMediators(String treatment, String outcome) {
        Set<String> mediators = new HashSet<>();
        List<CausalPath> paths = discoverCausalPaths(treatment, outcome);
        for (CausalPath path : paths) {
            if (path.getPath().size() > 2) {
                // Every variable strictly between the endpoints is a mediator
                for (int i = 1; i < path.getPath().size() - 1; i++) {
                    mediators.add(path.getPath().get(i));
                }
            }
        }
        return mediators;
    }
    /**
     * Get all causes of a variable, direct and indirect
     */
    private Set<String> getAllCauses(String variable) {
        Set<String> allCauses = new HashSet<>();
        collectAllCauses(variable, allCauses, new HashSet<>());
        return allCauses;
    }
    private void collectAllCauses(String variable, Set<String> causes, Set<String> visited) {
        visited.add(variable);
        for (String cause : getDirectCauses(variable)) {
            if (!visited.contains(cause)) {
                causes.add(cause);
                collectAllCauses(cause, causes, visited);
            }
        }
    }
    /**
     * Verify that the causal graph is acyclic
     */
    public boolean isAcyclic() {
        CycleDetector<String, DefaultWeightedEdge> cycleDetector = new CycleDetector<>(graph);
        return !cycleDetector.detectCycles();
    }
    /**
     * Causal distance: length of the shortest directed path, or -1 if none exists
     */
    public int getCausalDistance(String cause, String effect) {
        DijkstraShortestPath<String, DefaultWeightedEdge> shortestPath =
                new DijkstraShortestPath<>(graph);
        GraphPath<String, DefaultWeightedEdge> path = shortestPath.getPath(cause, effect);
        return path == null ? -1 : path.getLength();
    }
    // Thin delegates used by the discovery algorithms in the next section
    public void removeEdge(String cause, String effect) {
        graph.removeEdge(cause, effect);
    }
    public Set<DefaultWeightedEdge> edgeSet() {
        return graph.edgeSet();
    }
    public String getEdgeSource(DefaultWeightedEdge edge) {
        return graph.getEdgeSource(edge);
    }
    public String getEdgeTarget(DefaultWeightedEdge edge) {
        return graph.getEdgeTarget(edge);
    }
    /**
     * Deep copy of the graph (variables and weighted edges)
     */
    public CausalGraph copy() {
        CausalGraph clone = new CausalGraph();
        for (Variable v : variables.values()) {
            clone.addVariable(v.getName(), v.getType(), v.getDataType());
        }
        for (DefaultWeightedEdge e : graph.edgeSet()) {
            clone.addCausalEdge(graph.getEdgeSource(e), graph.getEdgeTarget(e), graph.getEdgeWeight(e));
        }
        return clone;
    }
    // Data classes
    @Data
    @AllArgsConstructor
    public static class Variable {
        private String name;
        private VariableType type;
        private DataType dataType;
    }
    // Non-static, so that path strength can be derived from the enclosing graph's edge weights
    @Data
    public class CausalPath {
        private List<String> path;
        private double pathStrength;
        public CausalPath(List<String> path) {
            this.path = path;
            this.pathStrength = calculatePathStrength();
        }
        private double calculatePathStrength() {
            // Multiply the edge weights along the path
            double strength = 1.0;
            for (int i = 0; i < path.size() - 1; i++) {
                DefaultWeightedEdge edge = graph.getEdge(path.get(i), path.get(i + 1));
                strength *= graph.getEdgeWeight(edge);
            }
            return strength;
        }
    }
    public enum VariableType {
        TREATMENT, OUTCOME, CONFOUNDER, MEDIATOR, INSTRUMENT, COVARIATE
    }
    public enum DataType {
        CONTINUOUS, CATEGORICAL, BINARY, ORDINAL
    }
}
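To make the data structure concrete, here is a small usage sketch built around the classic smoking example; the variable names and edge strengths are made up for illustration.
java
public class CausalGraphDemo {
    public static void main(String[] args) {
        CausalGraph g = new CausalGraph();
        g.addVariable("genotype", CausalGraph.VariableType.CONFOUNDER, CausalGraph.DataType.BINARY);
        g.addVariable("smoking", CausalGraph.VariableType.TREATMENT, CausalGraph.DataType.BINARY);
        g.addVariable("tar", CausalGraph.VariableType.MEDIATOR, CausalGraph.DataType.CONTINUOUS);
        g.addVariable("cancer", CausalGraph.VariableType.OUTCOME, CausalGraph.DataType.BINARY);
        // genotype confounds both smoking and cancer; smoking acts through tar deposits
        g.addCausalEdge("genotype", "smoking", 0.4);
        g.addCausalEdge("genotype", "cancer", 0.2);
        g.addCausalEdge("smoking", "tar", 0.9);
        g.addCausalEdge("tar", "cancer", 0.7);
        System.out.println("Acyclic: " + g.isAcyclic());                                    // true
        System.out.println("Confounders: " + g.identifyConfounders("smoking", "cancer"));   // [genotype]
        System.out.println("Mediators: " + g.identifyMediators("smoking", "cancer"));       // [tar]
        System.out.println("Causal distance: " + g.getCausalDistance("smoking", "cancer")); // 2
    }
}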
- Causal Discovery with the PC Algorithm
java
// CausalDiscoveryService.java
@Service
@Slf4j
public class CausalDiscoveryService {
    private final StatisticalTestService statisticalTest;
    private final ConditionalIndependenceTest ciTest;
    public CausalDiscoveryService(StatisticalTestService statisticalTest) {
        this.statisticalTest = statisticalTest;
        this.ciTest = new ConditionalIndependenceTest(statisticalTest);
    }
    /**
     * PC algorithm: discover causal structure from data
     */
    public CausalGraph discoverCausalStructure(Dataset dataset, double alpha) {
        log.info("Starting PC-algorithm causal discovery, dataset size: {}, variables: {}",
                dataset.size(), dataset.getVariableNames().size());
        CausalGraph graph = initializeFullyConnectedGraph(dataset);
        int depth = 0;
        // Main PC loop: remove edges between conditionally independent pairs
        while (true) {
            boolean changed = false;
            for (String x : dataset.getVariableNames()) {
                for (String y : dataset.getVariableNames()) {
                    if (x.equals(y) || !graph.hasCausalEdge(x, y)) continue;
                    Set<String> adjX = graph.getDirectEffects(x);
                    adjX.addAll(graph.getDirectCauses(x));
                    adjX.remove(y);
                    // Test every conditioning subset of the current size
                    if (adjX.size() >= depth) {
                        for (Set<String> subset : generateSubsets(adjX, depth)) {
                            if (ciTest.isConditionallyIndependent(dataset, x, y, subset, alpha)) {
                                // Remove the skeleton edge in both directions
                                graph.removeEdge(x, y);
                                graph.removeEdge(y, x);
                                changed = true;
                                log.debug("Removed edge: {} -> {} | {}", x, y, subset);
                                break;
                            }
                        }
                    }
                }
            }
            if (!changed) break;
            depth++;
        }
        // Orientation phase
        orientEdges(graph, dataset, alpha);
        log.info("PC algorithm finished, {} edges discovered", graph.edgeSet().size());
        return graph;
    }
    /**
     * FCI algorithm: handles latent confounders
     */
    public CausalGraph discoverCausalStructureWithLatents(Dataset dataset, double alpha) {
        log.info("Starting FCI causal discovery (latent variables)");
        // Simplified FCI:
        // 1. Skeleton discovery, as in the PC algorithm
        CausalGraph graph = discoverCausalStructure(dataset, alpha);
        // 2. Extra orientation rules that account for latent confounders
        applyFCIRules(graph, dataset, alpha);
        return graph;
    }
    /**
     * Constraint-based causal discovery dispatcher
     */
    public CausalGraph constraintBasedDiscovery(Dataset dataset,
                                                DiscoveryMethod method,
                                                double alpha) {
        switch (method) {
            case PC:
                return discoverCausalStructure(dataset, alpha);
            case FCI:
                return discoverCausalStructureWithLatents(dataset, alpha);
            case GES:
                return gesDiscovery(dataset);
            case LINGAM:
                return lingamDiscovery(dataset);
            default:
                throw new IllegalArgumentException("Unsupported discovery method: " + method);
        }
    }
    /**
     * GES: score-based causal discovery (Greedy Equivalence Search)
     */
    private CausalGraph gesDiscovery(Dataset dataset) {
        log.info("Running GES");
        CausalGraph currentGraph = initializeEmptyGraph(dataset);
        double currentScore = calculateBICScore(currentGraph, dataset);
        boolean improved;
        do {
            improved = false;
            // Forward phase: greedily add score-improving edges
            for (String x : dataset.getVariableNames()) {
                for (String y : dataset.getVariableNames()) {
                    if (x.equals(y) || currentGraph.hasCausalEdge(x, y)) continue;
                    CausalGraph candidate = currentGraph.copy();
                    candidate.addCausalEdge(x, y, 0.5); // initial strength
                    double candidateScore = calculateBICScore(candidate, dataset);
                    if (candidateScore > currentScore) {
                        currentGraph = candidate;
                        currentScore = candidateScore;
                        improved = true;
                    }
                }
            }
            // Backward phase: greedily remove score-improving edges
            for (DefaultWeightedEdge edge : new ArrayList<>(currentGraph.edgeSet())) {
                String source = currentGraph.getEdgeSource(edge);
                String target = currentGraph.getEdgeTarget(edge);
                CausalGraph candidate = currentGraph.copy();
                candidate.removeEdge(source, target);
                double candidateScore = calculateBICScore(candidate, dataset);
                if (candidateScore > currentScore) {
                    currentGraph = candidate;
                    currentScore = candidateScore;
                    improved = true;
                }
            }
        } while (improved);
        return currentGraph;
    }
    /**
     * LiNGAM: Linear Non-Gaussian Acyclic Model
     */
    private CausalGraph lingamDiscovery(Dataset dataset) {
        log.info("Running LiNGAM");
        double[][] data = dataset.toDoubleMatrix();
        int p = data[0].length;
        // Center the data
        double[][] centered = centerData(data);
        // ICA decomposition
        ICAResult icaResult = fastICA(centered);
        // Estimate a causal ordering of the variables
        int[] causalOrder = estimateCausalOrder(icaResult, centered);
        // Build the causal graph from the estimated effect matrix B
        CausalGraph graph = initializeEmptyGraph(dataset);
        double[][] B = estimateCausalEffects(icaResult, causalOrder, centered);
        for (int i = 0; i < p; i++) {
            for (int j = 0; j < p; j++) {
                if (Math.abs(B[i][j]) > 1e-6 && i != j) {
                    String cause = dataset.getVariableNames().get(causalOrder[j]);
                    String effect = dataset.getVariableNames().get(causalOrder[i]);
                    graph.addCausalEdge(cause, effect, Math.abs(B[i][j]));
                }
            }
        }
        return graph;
    }
    /**
     * Initialize a fully connected graph (the starting skeleton for PC)
     */
    private CausalGraph initializeFullyConnectedGraph(Dataset dataset) {
        CausalGraph graph = new CausalGraph();
        // Add every variable
        for (String variable : dataset.getVariableNames()) {
            graph.addVariable(variable, CausalGraph.VariableType.COVARIATE, CausalGraph.DataType.CONTINUOUS);
        }
        // Add every possible edge
        for (String x : dataset.getVariableNames()) {
            for (String y : dataset.getVariableNames()) {
                if (!x.equals(y)) {
                    graph.addCausalEdge(x, y, 1.0);
                }
            }
        }
        return graph;
    }
    /**
     * Initialize an empty graph
     */
    private CausalGraph initializeEmptyGraph(Dataset dataset) {
        CausalGraph graph = new CausalGraph();
        for (String variable : dataset.getVariableNames()) {
            graph.addVariable(variable, CausalGraph.VariableType.COVARIATE, CausalGraph.DataType.CONTINUOUS);
        }
        return graph;
    }
    /**
     * Edge orientation rules
     */
    private void orientEdges(CausalGraph graph, Dataset dataset, double alpha) {
        // PC orientation rule 1: collider (v-structure) identification
        for (String b : dataset.getVariableNames()) {
            Set<String> adjB = new HashSet<>();
            adjB.addAll(graph.getDirectCauses(b));
            adjB.addAll(graph.getDirectEffects(b));
            for (String a : adjB) {
                for (String c : adjB) {
                    if (!a.equals(c) && !graph.hasCausalEdge(a, c) && !graph.hasCausalEdge(c, a)) {
                        // Does a-b-c form a collider?
                        if (isCollider(graph, dataset, a, b, c, alpha)) {
                            // Orient as a -> b <- c
                            if (graph.hasCausalEdge(b, a)) {
                                graph.removeEdge(b, a);
                                graph.addCausalEdge(a, b, 0.5);
                            }
                            if (graph.hasCausalEdge(b, c)) {
                                graph.removeEdge(b, c);
                                graph.addCausalEdge(c, b, 0.5);
                            }
                        }
                    }
                }
            }
        }
    }
    /**
     * Collider check
     */
    private boolean isCollider(CausalGraph graph, Dataset dataset,
                               String a, String b, String c, double alpha) {
        Set<String> conditioningSet = new HashSet<>(graph.getDirectCauses(b));
        conditioningSet.remove(a);
        conditioningSet.remove(c);
        return !ciTest.isConditionallyIndependent(dataset, a, c, conditioningSet, alpha);
    }
    /**
     * Apply FCI orientation rules
     */
    private void applyFCIRules(CausalGraph graph, Dataset dataset, double alpha) {
        // Simplified placeholder; a full FCI implementation applies a much richer rule set
        log.info("Applying FCI rules for latent variables");
    }
    /**
     * BIC score
     */
    private double calculateBICScore(CausalGraph graph, Dataset dataset) {
        // Bayesian Information Criterion: logL - (k/2) * log(n)
        double logLikelihood = calculateLogLikelihood(graph, dataset);
        int parameters = countParameters(graph);
        int sampleSize = dataset.size();
        return logLikelihood - 0.5 * parameters * Math.log(sampleSize);
    }
    private double calculateLogLikelihood(CausalGraph graph, Dataset dataset) {
        // Log-likelihood of the graph structure (simplified placeholder)
        return -1000.0;
    }
    private int countParameters(CausalGraph graph) {
        // Rough estimate of the number of free parameters
        return graph.edgeSet().size() * 2;
    }
    /**
     * Generate all subsets of size k
     */
    private List<Set<String>> generateSubsets(Set<String> set, int k) {
        List<Set<String>> subsets = new ArrayList<>();
        generateSubsetsHelper(new ArrayList<>(set), k, 0, new HashSet<>(), subsets);
        return subsets;
    }
    private void generateSubsetsHelper(List<String> elements, int k, int start,
                                       Set<String> current, List<Set<String>> subsets) {
        if (current.size() == k) {
            subsets.add(new HashSet<>(current));
            return;
        }
        for (int i = start; i < elements.size(); i++) {
            current.add(elements.get(i));
            generateSubsetsHelper(elements, k, i + 1, current, subsets);
            current.remove(elements.get(i));
        }
    }
    // Data preprocessing
    private double[][] centerData(double[][] data) {
        int n = data.length;
        int p = data[0].length;
        double[][] centered = new double[n][p];
        for (int j = 0; j < p; j++) {
            double mean = 0.0;
            for (int i = 0; i < n; i++) {
                mean += data[i][j];
            }
            mean /= n;
            for (int i = 0; i < n; i++) {
                centered[i][j] = data[i][j] - mean;
            }
        }
        return centered;
    }
    // ICA and related methods (simplified stubs)
    private ICAResult fastICA(double[][] data) {
        // Placeholder for a FastICA implementation
        return new ICAResult();
    }
    private int[] estimateCausalOrder(ICAResult icaResult, double[][] data) {
        // Estimate the causal ordering (identity ordering as a placeholder)
        int p = data[0].length;
        int[] order = new int[p];
        for (int i = 0; i < p; i++) order[i] = i;
        return order;
    }
    private double[][] estimateCausalEffects(ICAResult icaResult, int[] causalOrder, double[][] data) {
        // Estimate the causal effect matrix (zero matrix as a placeholder)
        int p = data[0].length;
        return new double[p][p];
    }
    // Internal classes
    private static class ICAResult {
        // Placeholder for ICA results
    }
    public enum DiscoveryMethod {
        PC, FCI, GES, LINGAM
    }
}
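The ConditionalIndependenceTest used by the PC loop is not shown above. A minimal sketch of one common choice, a partial-correlation test with Fisher's z-transform on top of Apache Commons Math, might look as follows; note the simplified, index-based signature instead of the Dataset-based one assumed in the service.
java
import org.apache.commons.math3.distribution.NormalDistribution;
import org.apache.commons.math3.linear.LUDecomposition;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.stat.correlation.PearsonsCorrelation;

public class ConditionalIndependenceTest {
    /**
     * Tests X _||_ Y | Z via the partial correlation r(X,Y|Z) and Fisher's z-transform.
     * data is an n x p matrix; xIdx, yIdx, and zIdxs select its columns.
     */
    public boolean isConditionallyIndependent(double[][] data, int xIdx, int yIdx,
                                              int[] zIdxs, double alpha) {
        // Build the sub-matrix [X, Y, Z...] column by column
        int n = data.length;
        int k = 2 + zIdxs.length;
        double[][] sub = new double[n][k];
        for (int i = 0; i < n; i++) {
            sub[i][0] = data[i][xIdx];
            sub[i][1] = data[i][yIdx];
            for (int j = 0; j < zIdxs.length; j++) sub[i][2 + j] = data[i][zIdxs[j]];
        }
        // Partial correlation from the inverse correlation (precision) matrix:
        // r(X,Y|Z) = -P[0][1] / sqrt(P[0][0] * P[1][1])
        RealMatrix corr = new PearsonsCorrelation(sub).getCorrelationMatrix();
        RealMatrix precision = new LUDecomposition(corr).getSolver().getInverse();
        double r = -precision.getEntry(0, 1)
                / Math.sqrt(precision.getEntry(0, 0) * precision.getEntry(1, 1));
        // Fisher z statistic: sqrt(n - |Z| - 3) * atanh(r) is ~N(0,1) under independence
        double z = 0.5 * Math.log((1 + r) / (1 - r)) * Math.sqrt(n - zIdxs.length - 3.0);
        double pValue = 2 * (1 - new NormalDistribution().cumulativeProbability(Math.abs(z)));
        return pValue > alpha; // failing to reject H0 is treated as conditional independence
    }
}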
4. Causal Effect Estimation Methods
- Difference-in-Differences (DID)
java
// DifferenceInDifferences.java
@Service
@Slf4j
public class DifferenceInDifferences {
    private final StatisticalTestService statisticalTest;
    public DifferenceInDifferences(StatisticalTestService statisticalTest) {
        this.statisticalTest = statisticalTest;
    }
    /**
     * Standard two-period difference-in-differences estimate
     */
    public DIDResult estimate(DIDDataset dataset) {
        log.info("Starting DID analysis, treatment group: {}, control group: {}",
                dataset.getTreatmentGroup().size(), dataset.getControlGroup().size());
        // Before/after change in the treatment group
        double treatmentBefore = calculateMeanOutcome(dataset.getTreatmentGroup(), true);
        double treatmentAfter = calculateMeanOutcome(dataset.getTreatmentGroup(), false);
        double treatmentDiff = treatmentAfter - treatmentBefore;
        // Before/after change in the control group
        double controlBefore = calculateMeanOutcome(dataset.getControlGroup(), true);
        double controlAfter = calculateMeanOutcome(dataset.getControlGroup(), false);
        double controlDiff = controlAfter - controlBefore;
        // DID estimate: difference of the two differences
        double didEstimate = treatmentDiff - controlDiff;
        // Standard error and confidence interval
        double se = calculateStandardError(dataset, treatmentDiff, controlDiff);
        double zScore = 1.96; // 95% confidence level
        double ciLower = didEstimate - zScore * se;
        double ciUpper = didEstimate + zScore * se;
        // Hypothesis test
        double tStatistic = didEstimate / se;
        double pValue = calculatePValue(tStatistic, dataset.getTotalObservations());
        return new DIDResult(didEstimate, se, ciLower, ciUpper, tStatistic, pValue);
    }
    /**
     * Multi-period DID estimation
     */
    public MultiPeriodDIDResult estimateMultiPeriod(MultiPeriodDIDDataset dataset) {
        log.info("Starting multi-period DID analysis");
        List<Double> didEstimates = new ArrayList<>();
        List<Double> timePeriods = new ArrayList<>();
        // Compute a DID estimate for each period
        for (int period : dataset.getTimePeriods()) {
            DIDDataset periodData = dataset.getDataForPeriod(period);
            DIDResult periodResult = estimate(periodData);
            didEstimates.add(periodResult.getEstimate());
            timePeriods.add((double) period);
        }
        // Parallel-trends test
        boolean parallelTrends = testParallelTrends(didEstimates);
        // Dynamic-effect analysis
        DynamicEffects dynamicEffects = analyzeDynamicEffects(didEstimates, timePeriods);
        return new MultiPeriodDIDResult(didEstimates, parallelTrends, dynamicEffects);
    }
    /**
     * Event-study analysis
     */
    public EventStudyResult eventStudyAnalysis(EventStudyDataset dataset) {
        log.info("Starting event-study analysis");
        Map<Integer, Double> effectsByTime = new HashMap<>();
        Map<Integer, Double> pValuesByTime = new HashMap<>();
        // Estimate the effect at each period relative to the event
        for (int relativeTime = dataset.getMinTime(); relativeTime <= dataset.getMaxTime(); relativeTime++) {
            if (relativeTime == 0) continue; // skip the event period itself
            DIDDataset timeData = dataset.getDataForRelativeTime(relativeTime);
            DIDResult result = estimate(timeData);
            effectsByTime.put(relativeTime, result.getEstimate());
            pValuesByTime.put(relativeTime, result.getPValue());
        }
        // Pre-trend test
        boolean preTrendSignificant = testPreTrend(effectsByTime, pValuesByTime);
        return new EventStudyResult(effectsByTime, pValuesByTime, preTrendSignificant);
    }
    /**
     * Mean outcome for the before/after subsample
     */
    private double calculateMeanOutcome(List<Observation> observations, boolean isBefore) {
        return observations.stream()
                .filter(obs -> obs.isBefore() == isBefore)
                .mapToDouble(Observation::getOutcome)
                .average()
                .orElse(0.0);
    }
    /**
     * Standard error of the DID estimate
     */
    private double calculateStandardError(DIDDataset dataset, double treatmentDiff, double controlDiff) {
        int nTreatment = dataset.getTreatmentGroup().size() / 2; // half before, half after
        int nControl = dataset.getControlGroup().size() / 2;
        double varTreatment = calculateVariance(dataset.getTreatmentGroup(), true) +
                calculateVariance(dataset.getTreatmentGroup(), false);
        double varControl = calculateVariance(dataset.getControlGroup(), true) +
                calculateVariance(dataset.getControlGroup(), false);
        return Math.sqrt(varTreatment / nTreatment + varControl / nControl);
    }
    /**
     * Variance of the before/after subsample
     */
    private double calculateVariance(List<Observation> observations, boolean isBefore) {
        double mean = calculateMeanOutcome(observations, isBefore);
        return observations.stream()
                .filter(obs -> obs.isBefore() == isBefore)
                .mapToDouble(obs -> Math.pow(obs.getOutcome() - mean, 2))
                .average()
                .orElse(0.0);
    }
    /**
     * Two-sided p-value from the t distribution
     */
    private double calculatePValue(double tStatistic, int df) {
        TDistribution tDistribution = new TDistribution(df);
        return 2 * (1 - tDistribution.cumulativeProbability(Math.abs(tStatistic)));
    }
    /**
     * Parallel-trends test
     */
    private boolean testParallelTrends(List<Double> didEstimates) {
        if (didEstimates.size() < 3) return true;
        // Check whether pre-period effects are close to zero
        List<Double> prePeriodEffects = didEstimates.subList(0, didEstimates.size() / 2);
        double prePeriodMean = prePeriodEffects.stream().mapToDouble(Double::doubleValue).average().orElse(0.0);
        // Simplified test: parallel trends are plausible when the mean pre-period effect is near zero
        return Math.abs(prePeriodMean) < 0.1; // threshold is application-specific
    }
    /**
     * Dynamic-effect analysis
     */
    private DynamicEffects analyzeDynamicEffects(List<Double> effects, List<Double> timePeriods) {
        // Characterize how the effect evolves over time
        boolean isPersistent = testEffectPersistence(effects);
        boolean isIncreasing = testEffectTrend(effects);
        return new DynamicEffects(isPersistent, isIncreasing, effects);
    }
    /**
     * Effect persistence test
     */
    private boolean testEffectPersistence(List<Double> effects) {
        if (effects.size() < 2) return false;
        // Check whether later-period effects remain non-negligible
        List<Double> lateEffects = effects.subList(effects.size() / 2, effects.size());
        double lateMean = lateEffects.stream().mapToDouble(Double::doubleValue).average().orElse(0.0);
        return Math.abs(lateMean) > 0.05; // threshold is application-specific
    }
    /**
     * Effect trend test
     */
    private boolean testEffectTrend(List<Double> effects) {
        if (effects.size() < 2) return false;
        // Simple linear trend via OLS on (period index, effect) pairs
        SimpleRegression regression = new SimpleRegression();
        for (int i = 0; i < effects.size(); i++) {
            regression.addData(i, effects.get(i));
        }
        return regression.getSlope() > 0; // positive trend
    }
    /**
     * Pre-trend test
     */
    private boolean testPreTrend(Map<Integer, Double> effects, Map<Integer, Double> pValues) {
        // Average p-value of pre-event effects; large values indicate no significant pre-trend
        double preTrendSignificance = effects.entrySet().stream()
                .filter(entry -> entry.getKey() < 0) // pre-event periods
                .mapToDouble(entry -> pValues.get(entry.getKey()))
                .average()
                .orElse(1.0);
        return preTrendSignificance > 0.1; // pre-event effects are insignificant
    }
    // Data classes
    @Data
    @AllArgsConstructor
    public static class DIDResult {
        private double estimate;
        private double standardError;
        private double confidenceIntervalLower;
        private double confidenceIntervalUpper;
        private double tStatistic;
        private double pValue;
        public boolean isStatisticallySignificant() {
            return pValue < 0.05;
        }
    }
    @Data
    @AllArgsConstructor
    public static class MultiPeriodDIDResult {
        private List<Double> estimatesByPeriod;
        private boolean parallelTrendsHolds;
        private DynamicEffects dynamicEffects;
    }
    @Data
    @AllArgsConstructor
    public static class EventStudyResult {
        private Map<Integer, Double> effectsByRelativeTime;
        private Map<Integer, Double> pValuesByRelativeTime;
        private boolean preTrendInsignificant;
    }
    @Data
    @AllArgsConstructor
    public static class DynamicEffects {
        private boolean persistent;
        private boolean increasing;
        private List<Double> effectPattern;
    }
}
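To see the two-by-two DID logic in isolation, here is a self-contained sketch with hypothetical group means:
java
public class DIDToyExample {
    public static void main(String[] args) {
        // Mean outcomes (e.g., monthly sales) per group and period; numbers are made up
        double treatmentBefore = 100.0, treatmentAfter = 130.0;
        double controlBefore   =  90.0, controlAfter   = 105.0;
        // Change within each group
        double treatmentDiff = treatmentAfter - treatmentBefore; // 30: treatment effect + time trend
        double controlDiff   = controlAfter - controlBefore;     // 15: time trend only
        // Differencing the differences removes the shared time trend
        double did = treatmentDiff - controlDiff;                // 15
        System.out.printf("DID estimate of the treatment effect: %.1f%n", did);
    }
}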
- Propensity Score Matching (PSM)
java
// PropensityScoreMatching.java
@Service
@Slf4j
public class PropensityScoreMatching {
    private final WekaMLService wekaMLService;
    private final DistanceMetric distanceMetric;
    public PropensityScoreMatching(WekaMLService wekaMLService) {
        this.wekaMLService = wekaMLService;
        this.distanceMetric = new EuclideanDistance();
    }
    /**
     * Estimate propensity scores
     */
    public PropensityScoreResult estimatePropensityScores(PSMDataset dataset) {
        log.info("Estimating propensity scores, samples: {}, covariates: {}",
                dataset.size(), dataset.getCovariateNames().size());
        // Logistic regression of treatment on covariates
        LogisticRegressionModel model = wekaMLService.trainLogisticRegression(
                dataset.getFeatures(), dataset.getTreatments());
        double[] propensityScores = wekaMLService.predictProbabilities(model, dataset.getFeatures());
        // Check common support (overlap) of the scores
        OverlapResult overlap = checkOverlap(propensityScores, dataset.getTreatments());
        return new PropensityScoreResult(propensityScores, model, overlap);
    }
    /**
     * Perform propensity score matching
     */
    public PSMResult performMatching(PSMDataset dataset, MatchingMethod method, double caliper) {
        PropensityScoreResult psResult = estimatePropensityScores(dataset);
        double[] propensityScores = psResult.getPropensityScores();
        log.info("Performing propensity score matching, method: {}, caliper: {}", method, caliper);
        List<Match> matches = new ArrayList<>();
        switch (method) {
            case NEAREST_NEIGHBOR:
                matches = nearestNeighborMatching(propensityScores, dataset.getTreatments(), caliper);
                break;
            case KERNEL:
                matches = kernelMatching(propensityScores, dataset.getTreatments(), dataset.getOutcomes());
                break;
            case STRATIFICATION:
                matches = stratificationMatching(propensityScores, dataset.getTreatments(), dataset.getOutcomes());
                break;
            case MAHALANOBIS:
                matches = mahalanobisMatching(dataset.getFeatures(), dataset.getTreatments(), caliper);
                break;
        }
        // Average treatment effect on the matched sample
        double ate = calculateAverageTreatmentEffect(matches, dataset.getOutcomes());
        // Matching quality assessment
        MatchingQuality quality = assessMatchingQuality(matches, dataset, propensityScores);
        return new PSMResult(matches, ate, quality, psResult);
    }
    /**
     * Nearest-neighbor matching
     */
    private List<Match> nearestNeighborMatching(double[] propensityScores, int[] treatments, double caliper) {
        List<Match> matches = new ArrayList<>();
        List<Integer> treatedIndices = getTreatedIndices(treatments);
        List<Integer> controlIndices = getControlIndices(treatments);
        for (int treatedIdx : treatedIndices) {
            double treatedPS = propensityScores[treatedIdx];
            int bestMatch = -1;
            double bestDistance = Double.MAX_VALUE;
            for (int controlIdx : controlIndices) {
                double controlPS = propensityScores[controlIdx];
                double distance = Math.abs(treatedPS - controlPS);
                if (distance <= caliper && distance < bestDistance) {
                    bestDistance = distance;
                    bestMatch = controlIdx;
                }
            }
            if (bestMatch != -1) {
                matches.add(new Match(treatedIdx, bestMatch, bestDistance));
                controlIndices.remove((Integer) bestMatch); // matching without replacement
            }
        }
        return matches;
    }
    /**
     * Kernel matching
     */
    private List<Match> kernelMatching(double[] propensityScores, int[] treatments, double[] outcomes) {
        List<Match> matches = new ArrayList<>();
        List<Integer> treatedIndices = getTreatedIndices(treatments);
        List<Integer> controlIndices = getControlIndices(treatments);
        double bandwidth = calculateBandwidth(propensityScores);
        for (int treatedIdx : treatedIndices) {
            double treatedPS = propensityScores[treatedIdx];
            double totalWeight = 0.0;
            double weightedOutcome = 0.0;
            for (int controlIdx : controlIndices) {
                double controlPS = propensityScores[controlIdx];
                double distance = Math.abs(treatedPS - controlPS);
                double weight = kernelFunction(distance / bandwidth);
                totalWeight += weight;
                weightedOutcome += weight * outcomes[controlIdx];
            }
            if (totalWeight > 0) {
                double counterfactual = weightedOutcome / totalWeight;
                matches.add(new Match(treatedIdx, -1, counterfactual)); // synthetic match
            }
        }
        return matches;
    }
    /**
     * Stratification matching
     */
    private List<Match> stratificationMatching(double[] propensityScores, int[] treatments, double[] outcomes) {
        List<Match> matches = new ArrayList<>();
        // Create strata over the propensity score range
        int numStrata = 5;
        double[] strataBounds = calculateStrataBounds(propensityScores, numStrata);
        for (int stratum = 0; stratum < numStrata; stratum++) {
            List<Integer> treatedInStratum = new ArrayList<>();
            List<Integer> controlInStratum = new ArrayList<>();
            for (int i = 0; i < propensityScores.length; i++) {
                if (propensityScores[i] >= strataBounds[stratum] &&
                        propensityScores[i] < strataBounds[stratum + 1]) {
                    if (treatments[i] == 1) {
                        treatedInStratum.add(i);
                    } else {
                        controlInStratum.add(i);
                    }
                }
            }
            // Store the stratum's control-group mean as each treated unit's counterfactual,
            // so the ATE computed later equals the within-stratum mean difference
            if (!treatedInStratum.isEmpty() && !controlInStratum.isEmpty()) {
                double controlMean = calculateMean(outcomes, controlInStratum);
                for (int treatedIdx : treatedInStratum) {
                    matches.add(new Match(treatedIdx, -1, controlMean));
                }
            }
        }
        return matches;
    }
    /**
     * Mahalanobis-distance matching
     */
    private List<Match> mahalanobisMatching(double[][] features, int[] treatments, double caliper) {
        List<Match> matches = new ArrayList<>();
        List<Integer> treatedIndices = getTreatedIndices(treatments);
        List<Integer> controlIndices = getControlIndices(treatments);
        // Inverse of the covariate covariance matrix
        RealMatrix covariance = calculateCovarianceMatrix(features);
        RealMatrix covarianceInverse = new LUDecomposition(covariance).getSolver().getInverse();
        for (int treatedIdx : treatedIndices) {
            double[] treatedFeatures = features[treatedIdx];
            int bestMatch = -1;
            double bestDistance = Double.MAX_VALUE;
            for (int controlIdx : controlIndices) {
                double[] controlFeatures = features[controlIdx];
                double distance = mahalanobisDistance(treatedFeatures, controlFeatures, covarianceInverse);
                if (distance <= caliper && distance < bestDistance) {
                    bestDistance = distance;
                    bestMatch = controlIdx;
                }
            }
            if (bestMatch != -1) {
                matches.add(new Match(treatedIdx, bestMatch, bestDistance));
                controlIndices.remove((Integer) bestMatch);
            }
        }
        return matches;
    }
    /**
     * Average treatment effect over the matched pairs
     */
    private double calculateAverageTreatmentEffect(List<Match> matches, double[] outcomes) {
        if (matches.isEmpty()) return 0.0;
        double totalEffect = 0.0;
        for (Match match : matches) {
            double treatedOutcome = outcomes[match.getTreatedIndex()];
            double controlOutcome = match.getControlIndex() >= 0 ?
                    outcomes[match.getControlIndex()] : match.getDistanceOrCounterfactual();
            totalEffect += (treatedOutcome - controlOutcome);
        }
        return totalEffect / matches.size();
    }
    /**
     * Assess matching quality
     */
    private MatchingQuality assessMatchingQuality(List<Match> matches, PSMDataset dataset, double[] propensityScores) {
        // Standardized bias of the covariates after matching
        double standardizedBias = calculateStandardizedBias(matches, dataset);
        // Similarity of the matched propensity score distributions
        double distributionSimilarity = calculateDistributionSimilarity(matches, propensityScores, dataset.getTreatments());
        // Share of treated units that found a match
        double matchingRate = (double) matches.size() / getTreatedIndices(dataset.getTreatments()).size();
        return new MatchingQuality(standardizedBias, distributionSimilarity, matchingRate);
    }
    /**
     * Check common support (overlap) of the propensity scores
     */
    private OverlapResult checkOverlap(double[] propensityScores, int[] treatments) {
        List<Double> treatedScores = new ArrayList<>();
        List<Double> controlScores = new ArrayList<>();
        for (int i = 0; i < propensityScores.length; i++) {
            if (treatments[i] == 1) {
                treatedScores.add(propensityScores[i]);
            } else {
                controlScores.add(propensityScores[i]);
            }
        }
        double treatedMin = Collections.min(treatedScores);
        double treatedMax = Collections.max(treatedScores);
        double controlMin = Collections.min(controlScores);
        double controlMax = Collections.max(controlScores);
        boolean goodOverlap = (treatedMin <= controlMax) && (controlMin <= treatedMax);
        double overlapArea = calculateOverlapArea(treatedScores, controlScores);
        return new OverlapResult(goodOverlap, overlapArea, treatedMin, treatedMax, controlMin, controlMax);
    }
    // Helper methods
    private List<Integer> getTreatedIndices(int[] treatments) {
        List<Integer> indices = new ArrayList<>();
        for (int i = 0; i < treatments.length; i++) {
            if (treatments[i] == 1) indices.add(i);
        }
        return indices;
    }
    private List<Integer> getControlIndices(int[] treatments) {
        List<Integer> indices = new ArrayList<>();
        for (int i = 0; i < treatments.length; i++) {
            if (treatments[i] == 0) indices.add(i);
        }
        return indices;
    }
    private double calculateBandwidth(double[] scores) {
        // Silverman's rule of thumb: 0.9 * min(std, IQR/1.34) * n^(-1/5)
        double std = new StandardDeviation().evaluate(scores);
        double iqr = calculateIQR(scores);
        double n = scores.length;
        double A = Math.min(std, iqr / 1.34);
        return 0.9 * A * Math.pow(n, -0.2);
    }
    private double kernelFunction(double u) {
        // Epanechnikov kernel
        return Math.abs(u) <= 1 ? 0.75 * (1 - u * u) : 0;
    }
    private double[] calculateStrataBounds(double[] scores, int numStrata) {
        // Equal-width strata over [0, 1]
        double[] bounds = new double[numStrata + 1];
        for (int i = 0; i <= numStrata; i++) {
            bounds[i] = (double) i / numStrata;
        }
        return bounds;
    }
    private double calculateMean(double[] values, List<Integer> indices) {
        return indices.stream().mapToDouble(i -> values[i]).average().orElse(0.0);
    }
    private RealMatrix calculateCovarianceMatrix(double[][] features) {
        // Covariance matrix of the covariates
        Covariance covariance = new Covariance(features);
        return covariance.getCovarianceMatrix();
    }
    private double mahalanobisDistance(double[] x, double[] y, RealMatrix covarianceInverse) {
        RealVector diff = new ArrayRealVector(x).subtract(new ArrayRealVector(y));
        return Math.sqrt(diff.dotProduct(covarianceInverse.operate(diff)));
    }
    private double calculateStandardizedBias(List<Match> matches, PSMDataset dataset) {
        // Standardized covariate difference (simplified placeholder)
        return 0.0;
    }
    private double calculateDistributionSimilarity(List<Match> matches, double[] propensityScores, int[] treatments) {
        // Distribution similarity (simplified placeholder)
        return 1.0;
    }
    private double calculateOverlapArea(List<Double> treated, List<Double> control) {
        // Overlap area of the two score distributions (simplified placeholder)
        return 0.8;
    }
    private double calculateIQR(double[] values) {
        // Interquartile range
        DescriptiveStatistics stats = new DescriptiveStatistics(values);
        return stats.getPercentile(75) - stats.getPercentile(25);
    }
    // Data classes
    @Data
    @AllArgsConstructor
    public static class PropensityScoreResult {
        private double[] propensityScores;
        private LogisticRegressionModel model;
        private OverlapResult overlap;
    }
    @Data
    @AllArgsConstructor
    public static class PSMResult {
        private List<Match> matches;
        private double averageTreatmentEffect;
        private MatchingQuality quality;
        private PropensityScoreResult propensityScoreResult;
    }
    @Data
    @AllArgsConstructor
    public static class Match {
        private int treatedIndex;
        private int controlIndex; // -1 marks a synthetic (kernel/stratum) match
        private double distanceOrCounterfactual;
    }
    @Data
    @AllArgsConstructor
    public static class OverlapResult {
        private boolean goodOverlap;
        private double overlapArea;
        private double treatedMin;
        private double treatedMax;
        private double controlMin;
        private double controlMax;
    }
    @Data
    @AllArgsConstructor
    public static class MatchingQuality {
        private double standardizedBias;       // standardized covariate difference
        private double distributionSimilarity; // distribution similarity
        private double matchingRate;           // share of treated units matched
    }
    public enum MatchingMethod {
        NEAREST_NEIGHBOR, KERNEL, STRATIFICATION, MAHALANOBIS
    }
}
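The WekaMLService that estimates the propensity scores is likewise not shown. A minimal sketch of that step, assuming Weka's Logistic classifier, could look like this:
java
import java.util.ArrayList;
import weka.classifiers.functions.Logistic;
import weka.core.Attribute;
import weka.core.DenseInstance;
import weka.core.Instances;

public class PropensityScoreSketch {
    /** Fits P(T=1 | X) by logistic regression and returns the predicted treatment probabilities. */
    public static double[] propensityScores(double[][] features, int[] treatments) throws Exception {
        // Declare numeric covariates plus a nominal class attribute {0, 1}
        ArrayList<Attribute> attrs = new ArrayList<>();
        for (int j = 0; j < features[0].length; j++) attrs.add(new Attribute("x" + j));
        ArrayList<String> classValues = new ArrayList<>();
        classValues.add("0");
        classValues.add("1");
        attrs.add(new Attribute("treated", classValues));
        Instances data = new Instances("psm", attrs, features.length);
        data.setClassIndex(attrs.size() - 1);
        for (int i = 0; i < features.length; i++) {
            double[] row = new double[attrs.size()];
            System.arraycopy(features[i], 0, row, 0, features[i].length);
            row[attrs.size() - 1] = treatments[i]; // nominal index: 0 or 1
            data.add(new DenseInstance(1.0, row));
        }
        // Logistic regression of treatment on covariates
        Logistic logistic = new Logistic();
        logistic.buildClassifier(data);
        double[] scores = new double[features.length];
        for (int i = 0; i < features.length; i++) {
            scores[i] = logistic.distributionForInstance(data.instance(i))[1]; // P(T=1 | X)
        }
        return scores;
    }
}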
5. Counterfactual Reasoning and Causal Forests
- Causal Forest Implementation
java
// CausalForest.java
@Service
@Slf4j
public class CausalForest {
    private final RandomForestFactory forestFactory;
    private final int numTrees;
    private final int maxDepth;
    private final int minSamplesSplit;
    public CausalForest(int numTrees, int maxDepth, int minSamplesSplit) {
        this.numTrees = numTrees;
        this.maxDepth = maxDepth;
        this.minSamplesSplit = minSamplesSplit;
        this.forestFactory = new RandomForestFactory();
    }
    /**
     * Train the causal forest
     */
    public CausalForestModel train(CausalForestDataset dataset) {
        log.info("Training causal forest, trees: {}, max depth: {}, min samples per split: {}",
                numTrees, maxDepth, minSamplesSplit);
        // Train the trees in parallel
        List<CausalTree> trees = IntStream.range(0, numTrees)
                .parallel()
                .mapToObj(i -> trainTree(dataset, i))
                .collect(Collectors.toList());
        // Variable importance across the forest
        Map<String, Double> variableImportance = calculateVariableImportance(trees, dataset);
        return new CausalForestModel(trees, variableImportance);
    }
    /**
     * Train a single causal tree
     */
    private CausalTree trainTree(CausalForestDataset dataset, int treeId) {
        log.debug("Training causal tree {}", treeId);
        // Bootstrap sampling
        int[] bootstrapIndices = bootstrapSample(dataset.size());
        CausalForestDataset bootstrapData = dataset.subset(bootstrapIndices);
        // Build the tree recursively
        TreeNode root = buildTree(bootstrapData, 0);
        return new CausalTree(root, treeId);
    }
    /**
     * Recursively build a tree node
     */
    private TreeNode buildTree(CausalForestDataset data, int depth) {
        // Stopping criteria
        if (data.size() < minSamplesSplit || depth >= maxDepth) {
            return createLeafNode(data);
        }
        // Find the best split
        SplitResult bestSplit = findBestSplit(data);
        if (bestSplit == null || bestSplit.getImprovement() < 1e-6) {
            return createLeafNode(data);
        }
        // Partition the data
        CausalForestDataset leftData = data.subset(bestSplit.getLeftIndices());
        CausalForestDataset rightData = data.subset(bestSplit.getRightIndices());
        // Recurse into the children
        TreeNode leftChild = buildTree(leftData, depth + 1);
        TreeNode rightChild = buildTree(rightData, depth + 1);
        return new TreeNode(bestSplit, leftChild, rightChild, depth);
    }
    /**
     * Find the best split
     */
    private SplitResult findBestSplit(CausalForestDataset data) {
        SplitResult bestSplit = null;
        double bestImprovement = -Double.MAX_VALUE;
        // Random feature subset, as in random forests
        List<String> featureSubset = selectFeatureSubset(data.getFeatureNames());
        for (String feature : featureSubset) {
            if (data.isFeatureContinuous(feature)) {
                // Continuous features: try several candidate split points
                double[] featureValues = data.getFeatureValues(feature);
                double[] splitCandidates = generateSplitCandidates(featureValues);
                for (double splitValue : splitCandidates) {
                    SplitResult split = evaluateSplit(data, feature, splitValue);
                    if (split != null && split.getImprovement() > bestImprovement) {
                        bestImprovement = split.getImprovement();
                        bestSplit = split;
                    }
                }
            } else {
                // Categorical features: try each category as a split
                Set<Object> categories = data.getFeatureCategories(feature);
                for (Object category : categories) {
                    SplitResult split = evaluateSplit(data, feature, category);
                    if (split != null && split.getImprovement() > bestImprovement) {
                        bestImprovement = split.getImprovement();
                        bestSplit = split;
                    }
                }
            }
        }
        return bestSplit;
    }
    /**
     * Evaluate split quality
     */
    private SplitResult evaluateSplit(CausalForestDataset data, String feature, Object splitValue) {
        // Partition the indices by the split
        int[] leftIndices = data.getIndicesWhere(feature, splitValue, true);
        int[] rightIndices = data.getIndicesWhere(feature, splitValue, false);
        if (leftIndices.length < minSamplesSplit || rightIndices.length < minSamplesSplit) {
            return null;
        }
        // Heterogeneity of the causal effect before and after the split
        double parentHeterogeneity = calculateHeterogeneity(data);
        double leftHeterogeneity = calculateHeterogeneity(data.subset(leftIndices));
        double rightHeterogeneity = calculateHeterogeneity(data.subset(rightIndices));
        // Improvement: reduction in sample-weighted heterogeneity
        double improvement = parentHeterogeneity -
                (leftHeterogeneity * leftIndices.length / data.size() +
                        rightHeterogeneity * rightIndices.length / data.size());
        return new SplitResult(feature, splitValue, leftIndices, rightIndices, improvement);
    }
    /**
     * Heterogeneity of the causal effect
     */
    private double calculateHeterogeneity(CausalForestDataset data) {
        // Variance of the conditional average treatment effects
        double[] treatmentEffects = data.getConditionalTreatmentEffects();
        return new Variance().evaluate(treatmentEffects);
    }
    /**
     * Create a leaf node
     */
    private TreeNode createLeafNode(CausalForestDataset data) {
        double treatmentEffect = data.getAverageTreatmentEffect();
        double[] featureVector = data.getAverageFeatures();
        return new TreeNode(treatmentEffect, featureVector, data.size());
    }
    /**
     * Predict the individual treatment effect
     */
    public double predictIndividualTreatmentEffect(CausalForestModel model, double[] features) {
        // Average the predictions of all trees
        return model.getTrees().stream()
                .mapToDouble(tree -> predictTree(tree, features))
                .average()
                .orElse(0.0);
    }
    /**
     * Single-tree prediction
     */
    private double predictTree(CausalTree tree, double[] features) {
        TreeNode node = tree.getRoot();
        while (!node.isLeaf()) {
            SplitResult split = node.getSplit();
            double featureValue = features[getFeatureIndex(split.getFeatureName())];
            if (split.isLeft(featureValue)) {
                node = node.getLeftChild();
            } else {
                node = node.getRightChild();
            }
        }
        return node.getTreatmentEffect();
    }
    /**
     * Variable importance
     */
    private Map<String, Double> calculateVariableImportance(List<CausalTree> trees, CausalForestDataset dataset) {
        Map<String, Double> importance = new HashMap<>();
        for (String feature : dataset.getFeatureNames()) {
            double totalImprovement = trees.stream()
                    .mapToDouble(tree -> getFeatureImprovement(tree, feature))
                    .sum();
            importance.put(feature, totalImprovement / trees.size());
        }
        return importance;
    }
    /**
     * Total split improvement contributed by a feature within one tree
     */
    private double getFeatureImprovement(CausalTree tree, String feature) {
        return tree.getRoot().getTotalImprovementForFeature(feature);
    }
    // Helper methods
    private int[] bootstrapSample(int size) {
        Random random = new Random();
        int[] indices = new int[size];
        for (int i = 0; i < size; i++) {
            indices[i] = random.nextInt(size);
        }
        return indices;
    }
    private List<String> selectFeatureSubset(List<String> allFeatures) {
        // Random feature subset of size sqrt(p); shuffle a copy so the caller's list is untouched
        int subsetSize = (int) Math.sqrt(allFeatures.size());
        List<String> shuffled = new ArrayList<>(allFeatures);
        Collections.shuffle(shuffled);
        return shuffled.subList(0, Math.min(subsetSize, shuffled.size()));
    }
    private double[] generateSplitCandidates(double[] values) {
        // Candidate split points at the quantiles
        int numCandidates = 10;
        double[] candidates = new double[numCandidates];
        for (int i = 0; i < numCandidates; i++) {
            double quantile = (i + 1.0) / (numCandidates + 1);
            candidates[i] = calculateQuantile(values, quantile);
        }
        return candidates;
    }
    private double calculateQuantile(double[] values, double quantile) {
        double[] sorted = values.clone(); // avoid sorting the caller's array in place
        Arrays.sort(sorted);
        int index = (int) (quantile * sorted.length);
        return sorted[Math.min(index, sorted.length - 1)];
    }
    private int getFeatureIndex(String featureName) {
        // Map a feature name to its column index ("feature_i" naming convention)
        return Integer.parseInt(featureName.replace("feature_", ""));
    }
    // Data classes
    @Data
    @AllArgsConstructor
    public static class CausalForestModel {
        private List<CausalTree> trees;
        private Map<String, Double> variableImportance;
        public List<String> getTopFeatures(int k) {
            return variableImportance.entrySet().stream()
                    .sorted(Map.Entry.<String, Double>comparingByValue().reversed())
                    .limit(k)
                    .map(Map.Entry::getKey)
                    .collect(Collectors.toList());
        }
    }
    @Data
    @AllArgsConstructor
    public static class CausalTree {
        private TreeNode root;
        private int treeId;
    }
    @Data
    public static class TreeNode {
        private SplitResult split;
        private TreeNode leftChild;
        private TreeNode rightChild;
        private double treatmentEffect; // set on leaf nodes
        private double[] averageFeatures;
        private int sampleSize;
        private int depth;
        private boolean isLeaf;
        public TreeNode(SplitResult split, TreeNode leftChild, TreeNode rightChild, int depth) {
            this.split = split;
            this.leftChild = leftChild;
            this.rightChild = rightChild;
            this.depth = depth;
            this.isLeaf = false;
        }
        public TreeNode(double treatmentEffect, double[] averageFeatures, int sampleSize) {
            this.treatmentEffect = treatmentEffect;
            this.averageFeatures = averageFeatures;
            this.sampleSize = sampleSize;
            this.isLeaf = true;
        }
        public double getTotalImprovementForFeature(String feature) {
            if (isLeaf) return 0.0;
            double improvement = split.getFeatureName().equals(feature) ? split.getImprovement() : 0.0;
            improvement += leftChild.getTotalImprovementForFeature(feature);
            improvement += rightChild.getTotalImprovementForFeature(feature);
            return improvement;
        }
    }
    @Data
    @AllArgsConstructor
    public static class SplitResult {
        private String featureName;
        private Object splitValue;
        private int[] leftIndices;
        private int[] rightIndices;
        private double improvement;
        public boolean isLeft(double featureValue) {
            if (splitValue instanceof Double) {
                return featureValue <= (Double) splitValue;
            }
            // Categorical split on a numeric encoding
            return Double.valueOf(featureValue).equals(splitValue);
        }
    }
}
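A hypothetical usage sketch; assembling the CausalForestDataset from treatment, outcome, and covariate columns is assumed to happen elsewhere in the pipeline.
java
public class CausalForestUsage {
    public static void demo(CausalForestDataset dataset) {
        CausalForest forest = new CausalForest(100, 10, 20); // trees, max depth, min split size
        CausalForest.CausalForestModel model = forest.train(dataset);
        // Individual treatment effect for one unit's covariate vector
        // (columns follow the "feature_i" index convention used by getFeatureIndex)
        double[] x = {52.0, 1.0, 27.5}; // hypothetical values for feature_0, feature_1, feature_2
        double ite = forest.predictIndividualTreatmentEffect(model, x);
        System.out.printf("Estimated individual treatment effect: %.3f%n", ite);
        // Which covariates drive the effect heterogeneity?
        model.getTopFeatures(3).forEach(System.out::println);
    }
}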
6. Application Scenarios and Decision Support
- Medical Treatment Effect Evaluation
java
// MedicalTreatmentService.java
@Service
@Slf4j
public class MedicalTreatmentService {
    private final CausalDiscoveryService causalDiscovery;
    private final DifferenceInDifferences did;
    private final PropensityScoreMatching psm;
    private final CausalForest causalForest;
    public MedicalTreatmentService(CausalDiscoveryService causalDiscovery,
                                   DifferenceInDifferences did,
                                   PropensityScoreMatching psm,
                                   CausalForest causalForest) {
        this.causalDiscovery = causalDiscovery;
        this.did = did;
        this.psm = psm;
        this.causalForest = causalForest;
    }
    /**
     * Evaluate drug treatment effectiveness
     */
    public TreatmentEvaluationResult evaluateDrugEffectiveness(MedicalDataset dataset) {
        log.info("Evaluating drug effectiveness, patients: {}", dataset.size());
        // Causal discovery: identify the factors that influence the treatment outcome
        CausalGraph causalGraph = causalDiscovery.discoverCausalStructure(dataset, 0.05);
        // Difference-in-differences (when pre/post data are available)
        DifferenceInDifferences.DIDResult didResult = null;
        if (dataset.hasPrePostData()) {
            didResult = did.estimate(dataset.toDIDDataset());
        }
        // Propensity score matching
        PropensityScoreMatching.PSMResult psmResult = psm.performMatching(
                dataset.toPSMDataset(),
                PropensityScoreMatching.MatchingMethod.NEAREST_NEIGHBOR,
                0.1
        );
        // Causal forest: individualized treatment effects
        CausalForest.CausalForestModel forestModel = causalForest.train(dataset.toCausalForestDataset());
        // Combined assessment
        return new TreatmentEvaluationResult(
                causalGraph,
                didResult,
                psmResult,
                forestModel,
                calculateOverallEffectiveness(didResult, psmResult)
        );
    }
    /**
     * Predict the treatment effect for an individual patient
     */
    public IndividualTreatmentPrediction predictIndividualEffect(MedicalDataset dataset,
                                                                 String patientId,
                                                                 CausalForest.CausalForestModel model) {
        double[] patientFeatures = dataset.getPatientFeatures(patientId);
        double predictedEffect = causalForest.predictIndividualTreatmentEffect(model, patientFeatures);
        // Confidence interval from the spread of the per-tree predictions
        double confidence = calculatePredictionConfidence(patientFeatures, model);
        double lowerBound = predictedEffect - 1.96 * confidence;
        double upperBound = predictedEffect + 1.96 * confidence;
        // Influential features
        List<String> importantFeatures = model.getTopFeatures(5);
        Map<String, Double> featureContributions = calculateFeatureContributions(patientFeatures, model);
        return new IndividualTreatmentPrediction(
                patientId,
                predictedEffect,
                lowerBound,
                upperBound,
                confidence,
                importantFeatures,
                featureContributions
        );
    }
    /**
     * Analyze treatment effect heterogeneity
     */
    public TreatmentHeterogeneityResult analyzeHeterogeneity(MedicalDataset dataset,
                                                             CausalForest.CausalForestModel model) {
        // Treatment effects by subgroup
        Map<String, Double> subgroupEffects = new HashMap<>();
        // By age
        subgroupEffects.put("young", calculateSubgroupEffect(dataset, model, "age", 0, 40));
        subgroupEffects.put("middle", calculateSubgroupEffect(dataset, model, "age", 40, 60));
        subgroupEffects.put("old", calculateSubgroupEffect(dataset, model, "age", 60, 100));
        // By gender
        subgroupEffects.put("male", calculateSubgroupEffect(dataset, model, "gender", "M"));
        subgroupEffects.put("female", calculateSubgroupEffect(dataset, model, "gender", "F"));
        // Identify the subgroup that benefits most
        String bestResponderGroup = identifyBestResponders(subgroupEffects);
        return new TreatmentHeterogeneityResult(subgroupEffects, bestResponderGroup);
    }
    /**
     * Counterfactual analysis: what would happen under a different treatment plan?
     */
    public CounterfactualAnalysisResult counterfactualAnalysis(MedicalDataset dataset,
                                                               String patientId,
                                                               String alternativeTreatment) {
        double[] patientFeatures = dataset.getPatientFeatures(patientId);
        double currentOutcome = dataset.getPatientOutcome(patientId);
        // Estimate the counterfactual outcome under the alternative treatment
        double counterfactualOutcome = estimateCounterfactualOutcome(
                patientFeatures, alternativeTreatment, dataset);
        // Effect of switching treatments
        double treatmentEffect = counterfactualOutcome - currentOutcome;
        // Key influencing factors
        List<String> keyFactors = identifyKeyFactors(patientFeatures, alternativeTreatment, dataset);
        return new CounterfactualAnalysisResult(
                patientId,
                currentOutcome,
                counterfactualOutcome,
                treatmentEffect,
                keyFactors
        );
    }
    // Helper methods
    private double calculateOverallEffectiveness(DifferenceInDifferences.DIDResult didResult,
                                                 PropensityScoreMatching.PSMResult psmResult) {
        // Combine the DID and PSM estimates
        double didEffect = didResult != null ? didResult.getEstimate() : 0.0;
        double psmEffect = psmResult.getAverageTreatmentEffect();
        return (didEffect + psmEffect) / 2.0;
    }
    private double calculatePredictionConfidence(double[] features, CausalForest.CausalForestModel model) {
        // Prediction uncertainty from the variance across trees
        double[] predictions = model.getTrees().stream()
                .mapToDouble(tree -> predictTree(tree, features))
                .toArray();
        return new StandardDeviation().evaluate(predictions);
    }
    private double predictTree(CausalForest.CausalTree tree, double[] features) {
        // Single-tree prediction (simplified placeholder)
        return 0.0;
    }
    private Map<String, Double> calculateFeatureContributions(double[] features,
                                                              CausalForest.CausalForestModel model) {
        // Per-feature contributions (simplified placeholder)
        return new HashMap<>();
    }
    private double calculateSubgroupEffect(MedicalDataset dataset, CausalForest.CausalForestModel model,
                                           String feature, Object value) {
        // Average treatment effect within a categorical subgroup
        List<String> subgroupPatients = dataset.getPatientsByFeature(feature, value);
        return subgroupPatients.stream()
                .mapToDouble(patientId -> {
                    double[] patientFeatures = dataset.getPatientFeatures(patientId);
                    return causalForest.predictIndividualTreatmentEffect(model, patientFeatures);
                })
                .average()
                .orElse(0.0);
    }
    private double calculateSubgroupEffect(MedicalDataset dataset, CausalForest.CausalForestModel model,
                                           String feature, double min, double max) {
        // Average treatment effect within a numeric-range subgroup
        List<String> subgroupPatients = dataset.getPatientsByFeatureRange(feature, min, max);
        return subgroupPatients.stream()
                .mapToDouble(patientId -> {
                    double[] patientFeatures = dataset.getPatientFeatures(patientId);
                    return causalForest.predictIndividualTreatmentEffect(model, patientFeatures);
                })
                .average()
                .orElse(0.0);
    }
    private String identifyBestResponders(Map<String, Double> subgroupEffects) {
        return subgroupEffects.entrySet().stream()
                .max(Map.Entry.comparingByValue())
                .map(Map.Entry::getKey)
                .orElse("unknown");
    }
    private double estimateCounterfactualOutcome(double[] features, String treatment, MedicalDataset dataset) {
        // Estimate the counterfactual outcome from the causal model (simplified placeholder)
        return 0.0;
    }
    private List<String> identifyKeyFactors(double[] features, String treatment, MedicalDataset dataset) {
        // Identify the factors that drive the treatment effect (simplified placeholder)
        return new ArrayList<>();
    }
    // Data classes
    @Data
    @AllArgsConstructor
    public static class TreatmentEvaluationResult {
        private CausalGraph causalGraph;
        private DifferenceInDifferences.DIDResult didResult;
        private PropensityScoreMatching.PSMResult psmResult;
        private CausalForest.CausalForestModel forestModel;
        private double overallEffectiveness;
    }
    @Data
    @AllArgsConstructor
    public static class IndividualTreatmentPrediction {
        private String patientId;
        private double predictedEffect;
        private double confidenceLower;
        private double confidenceUpper;
        private double confidenceLevel;
        private List<String> importantFeatures;
        private Map<String, Double> featureContributions;
    }
    @Data
    @AllArgsConstructor
    public static class TreatmentHeterogeneityResult {
        private Map<String, Double> subgroupEffects;
        private String bestResponderGroup;
    }
    @Data
    @AllArgsConstructor
    public static class CounterfactualAnalysisResult {
        private String patientId;
        private double currentOutcome;
        private double counterfactualOutcome;
        private double treatmentEffect;
        private List<String> keyFactors;
    }
}
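A hypothetical end-to-end invocation, assuming Spring has wired the services and that the dataset names its treatment and outcome columns "treatment" and "outcome":
java
public class TreatmentEvaluationDemo {
    public static void run(MedicalTreatmentService service, MedicalDataset patients) {
        MedicalTreatmentService.TreatmentEvaluationResult result =
                service.evaluateDrugEffectiveness(patients);
        System.out.printf("Overall effectiveness: %.3f%n", result.getOverallEffectiveness());
        // Confounders found by causal discovery indicate which covariates PSM must balance
        System.out.println("Confounders: "
                + result.getCausalGraph().identifyConfounders("treatment", "outcome"));
        if (result.getDidResult() != null && result.getDidResult().isStatisticallySignificant()) {
            System.out.println("DID estimate: " + result.getDidResult().getEstimate());
        }
    }
}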
7. Production Configuration and API Service
- Spring Boot Configuration
yaml
# application.yml
spring:
application:
name: causal-inference-service
causal:
discovery:
alpha: 0.05
method: PC
max-depth: 10
did:
confidence-level: 0.95
parallel-trend-test: true
psm:
matching-method: NEAREST_NEIGHBOR
caliper: 0.1
with-replacement: false
forest:
num-trees: 100
max-depth: 10
min-samples-split: 20
server:
port: 8080
logging:
level:
com.example.causal: INFO
file:
name: /var/log/causal-service.log
management:
endpoints:
web:
exposure:
include: health,metrics,info
endpoint:
health:
show-details: always
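Spring does not bind the custom causal.* block automatically. One way to expose it to the services is a @ConfigurationProperties class such as the following sketch; the field names mirror the YAML keys via relaxed binding, and the did/psm sub-blocks would bind analogously.
java
import lombok.Data;
import org.springframework.boot.context.properties.ConfigurationProperties;

@ConfigurationProperties(prefix = "causal")
@Data
public class CausalProperties {
    private Discovery discovery = new Discovery();
    private Forest forest = new Forest();

    @Data
    public static class Discovery {
        private double alpha = 0.05;  // significance level for the CI tests
        private String method = "PC"; // PC, FCI, GES, LINGAM
        private int maxDepth = 10;    // maximum conditioning-set size
    }
    @Data
    public static class Forest {
        private int numTrees = 100;
        private int maxDepth = 10;
        private int minSamplesSplit = 20;
    }
}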
- REST API Controller
java
// CausalInferenceController.java
@RestController
@RequestMapping("/api/causal")
@Slf4j
public class CausalInferenceController {
    private final CausalDiscoveryService causalDiscovery;
    private final DifferenceInDifferences did;
    private final PropensityScoreMatching psm;
    private final MedicalTreatmentService medicalService;
    public CausalInferenceController(CausalDiscoveryService causalDiscovery,
                                     DifferenceInDifferences did,
                                     PropensityScoreMatching psm,
                                     MedicalTreatmentService medicalService) {
        this.causalDiscovery = causalDiscovery;
        this.did = did;
        this.psm = psm;
        this.medicalService = medicalService;
    }
    @PostMapping("/discover")
    public ResponseEntity<CausalDiscoveryResponse> discoverCausalStructure(
            @RequestBody CausalDiscoveryRequest request) {
        try {
            CausalGraph graph = causalDiscovery.discoverCausalStructure(
                    request.getDataset(), request.getAlpha());
            return ResponseEntity.ok(CausalDiscoveryResponse.success(graph));
        } catch (Exception e) {
            log.error("Causal discovery failed", e);
            return ResponseEntity.status(HttpStatus.INTERNAL_SERVER_ERROR)
                    .body(CausalDiscoveryResponse.error(e.getMessage()));
        }
    }
    @PostMapping("/did/estimate")
    public ResponseEntity<DIDResponse> estimateDID(@RequestBody DIDRequest request) {
        try {
            DifferenceInDifferences.DIDResult result = did.estimate(request.getDataset());
            return ResponseEntity.ok(DIDResponse.success(result));
        } catch (Exception e) {
            log.error("DID estimation failed", e);
            return ResponseEntity.status(HttpStatus.INTERNAL_SERVER_ERROR)
                    .body(DIDResponse.error(e.getMessage()));
        }
    }
    @PostMapping("/psm/match")
    public ResponseEntity<PSMResponse> performPSM(@RequestBody PSMRequest request) {
        try {
            PropensityScoreMatching.PSMResult result = psm.performMatching(
                    request.getDataset(), request.getMethod(), request.getCaliper());
            return ResponseEntity.ok(PSMResponse.success(result));
        } catch (Exception e) {
            log.error("PSM matching failed", e);
            return ResponseEntity.status(HttpStatus.INTERNAL_SERVER_ERROR)
                    .body(PSMResponse.error(e.getMessage()));
        }
    }
    @PostMapping("/medical/evaluate")
    public ResponseEntity<MedicalEvaluationResponse> evaluateMedicalTreatment(
            @RequestBody MedicalEvaluationRequest request) {
        try {
            MedicalTreatmentService.TreatmentEvaluationResult result =
                    medicalService.evaluateDrugEffectiveness(request.getDataset());
            return ResponseEntity.ok(MedicalEvaluationResponse.success(result));
        } catch (Exception e) {
            log.error("Medical treatment evaluation failed", e);
            return ResponseEntity.status(HttpStatus.INTERNAL_SERVER_ERROR)
                    .body(MedicalEvaluationResponse.error(e.getMessage()));
        }
    }
    @PostMapping("/medical/predict")
    public ResponseEntity<IndividualPredictionResponse> predictIndividualEffect(
            @RequestBody IndividualPredictionRequest request) {
        try {
            MedicalTreatmentService.IndividualTreatmentPrediction prediction =
                    medicalService.predictIndividualEffect(
                            request.getDataset(), request.getPatientId(), request.getModel());
            return ResponseEntity.ok(IndividualPredictionResponse.success(prediction));
        } catch (Exception e) {
            log.error("Individual effect prediction failed", e);
            return ResponseEntity.status(HttpStatus.INTERNAL_SERVER_ERROR)
                    .body(IndividualPredictionResponse.error(e.getMessage()));
        }
    }
    @GetMapping("/health")
    public ResponseEntity<HealthResponse> healthCheck() {
        Map<String, Object> details = new HashMap<>();
        details.put("status", "healthy");
        details.put("timestamp", System.currentTimeMillis());
        details.put("services", List.of("causal-discovery", "did", "psm", "medical"));
        return ResponseEntity.ok(new HealthResponse("success", "Service is running normally", details));
    }
    // DTO classes
    @Data
    public static class CausalDiscoveryRequest {
        private Dataset dataset;
        private double alpha = 0.05;
        private CausalDiscoveryService.DiscoveryMethod method =
                CausalDiscoveryService.DiscoveryMethod.PC;
    }
    @Data
    @AllArgsConstructor
    public static class CausalDiscoveryResponse {
        private String status;
        private String message;
        private CausalGraph graph;
        public static CausalDiscoveryResponse success(CausalGraph graph) {
            return new CausalDiscoveryResponse("success", "Causal discovery completed", graph);
        }
        public static CausalDiscoveryResponse error(String message) {
            return new CausalDiscoveryResponse("error", message, null);
        }
    }
    @Data
    public static class DIDRequest {
        private DIDDataset dataset;
    }
    @Data
    @AllArgsConstructor
    public static class DIDResponse {
        private String status;
        private String message;
        private DifferenceInDifferences.DIDResult result;
        public static DIDResponse success(DifferenceInDifferences.DIDResult result) {
            return new DIDResponse("success", "DID estimation completed", result);
        }
        public static DIDResponse error(String message) {
            return new DIDResponse("error", message, null);
        }
    }
    @Data
    public static class PSMRequest {
        private PSMDataset dataset;
        private PropensityScoreMatching.MatchingMethod method;
        private double caliper;
    }
    @Data
    @AllArgsConstructor
    public static class PSMResponse {
        private String status;
        private String message;
        private PropensityScoreMatching.PSMResult result;
        public static PSMResponse success(PropensityScoreMatching.PSMResult result) {
            return new PSMResponse("success", "PSM matching completed", result);
        }
        public static PSMResponse error(String message) {
            return new PSMResponse("error", message, null);
        }
    }
    @Data
    public static class MedicalEvaluationRequest {
        private MedicalDataset dataset;
    }
    @Data
    @AllArgsConstructor
    public static class MedicalEvaluationResponse {
        private String status;
        private String message;
        private MedicalTreatmentService.TreatmentEvaluationResult result;
        public static MedicalEvaluationResponse success(MedicalTreatmentService.TreatmentEvaluationResult result) {
            return new MedicalEvaluationResponse("success", "Medical treatment evaluation completed", result);
        }
        public static MedicalEvaluationResponse error(String message) {
            return new MedicalEvaluationResponse("error", message, null);
        }
    }
    @Data
    public static class IndividualPredictionRequest {
        private MedicalDataset dataset;
        private String patientId;
        private CausalForest.CausalForestModel model;
    }
    @Data
    @AllArgsConstructor
    public static class IndividualPredictionResponse {
        private String status;
        private String message;
        private MedicalTreatmentService.IndividualTreatmentPrediction prediction;
        public static IndividualPredictionResponse success(
                MedicalTreatmentService.IndividualTreatmentPrediction prediction) {
            return new IndividualPredictionResponse("success", "Individual prediction completed", prediction);
        }
        public static IndividualPredictionResponse error(String message) {
            return new IndividualPredictionResponse("error", message, null);
        }
    }
    @Data
    @AllArgsConstructor
    public static class HealthResponse {
        private String status;
        private String message;
        private Map<String, Object> details;
    }
}
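A quick hypothetical smoke test of the discovery endpoint with the JDK's built-in HttpClient; the exact JSON layout of the dataset field depends on how the Dataset class is serialized:
java
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;

public class CausalApiSmokeTest {
    public static void main(String[] args) throws Exception {
        String body = """
                {"alpha": 0.05, "method": "PC",
                 "dataset": {"variableNames": ["smoking", "tar", "cancer"], "rows": []}}
                """; // hypothetical dataset JSON; adjust to your Dataset class
        HttpRequest request = HttpRequest.newBuilder()
                .uri(URI.create("http://localhost:8080/api/causal/discover"))
                .header("Content-Type", "application/json")
                .POST(HttpRequest.BodyPublishers.ofString(body))
                .build();
        HttpResponse<String> response = HttpClient.newHttpClient()
                .send(request, HttpResponse.BodyHandlers.ofString());
        System.out.println(response.statusCode() + " " + response.body());
    }
}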
8. Application Scenarios and Summary
- Typical Application Scenarios
Medical decision-making: evaluate drug treatment effects and identify the patients who benefit most
Policy evaluation: analyze the economic and social impact of policy interventions
Marketing: measure the true effect of advertising campaigns and promotions
Financial risk control: identify genuine causal relationships among risk factors
Industrial optimization: analyze the causal influence of process parameters on product quality
- System Strengths
Causation, not correlation: uncovers true causal relationships between variables
Explainability: provides explanations and insights grounded in causal mechanisms
Counterfactual reasoning: estimates potential outcomes under scenarios that never occurred
Heterogeneity analysis: identifies differential effects across subgroups
Policy simulation: predicts the likely consequences of interventions
- Technical Challenges and Mitigations
Confounding variables: controlled through causal discovery and propensity score matching
Selection bias: corrected with difference-in-differences and matching methods
Effect heterogeneity: subgroup effects identified with causal forests
Identifiability: addressed with instrumental variables and natural experiments (see the 2SLS sketch below)
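As a sketch of the instrumental-variable idea mentioned in the last item, a minimal two-stage least squares (2SLS) estimator can be written with Apache Commons Math's OLS regression; the variable roles (z instrument, t treatment, y outcome) are illustrative:
java
import org.apache.commons.math3.stat.regression.OLSMultipleLinearRegression;

public class TwoStageLeastSquares {
    /** Returns the IV estimate of the causal effect of t on y, using instrument z. */
    public static double estimate(double[] z, double[] t, double[] y) {
        int n = z.length;
        // Stage 1: regress treatment on the instrument, keep the fitted values t_hat
        OLSMultipleLinearRegression stage1 = new OLSMultipleLinearRegression();
        double[][] zMatrix = new double[n][1];
        for (int i = 0; i < n; i++) zMatrix[i][0] = z[i];
        stage1.newSampleData(t, zMatrix);
        double[] b1 = stage1.estimateRegressionParameters(); // [intercept, slope]
        double[][] tHat = new double[n][1];
        for (int i = 0; i < n; i++) tHat[i][0] = b1[0] + b1[1] * z[i];
        // Stage 2: regress the outcome on the fitted treatment; the slope is the IV estimate
        OLSMultipleLinearRegression stage2 = new OLSMultipleLinearRegression();
        stage2.newSampleData(y, tHat);
        return stage2.estimateRegressionParameters()[1];
    }
}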
- Summary
Through this walkthrough we have assembled a complete Java causal inference system with the following core capabilities:
Causal discovery: automatically discover causal structure from observational data
Effect estimation: estimate the magnitude of causal effects with multiple methods
Counterfactual reasoning: predict the potential outcomes of interventions
Heterogeneity analysis: identify differential treatment effects
Decision support: inform policy-making and personalized decisions
Causal inference marks a major step for AI systems from pattern recognition to mechanistic understanding. Java's reliability and scalability in enterprise systems, together with its rich scientific ecosystem, pair well with the theoretical rigor of causal inference, providing a solid technical foundation for trustworthy, explainable AI decision systems. As causal science matures and application demand grows, a Java-based causal inference architecture like this one will play an increasingly important role in healthcare, public policy, and business decision-making.