背景
本文基于 Spark 3.3.0。
本文从一个单元测试(unit test)入手,探究 Spark Codegen 的逻辑:
test("SortAggregate should be included in WholeStageCodegen") { val df = spark.range(10).agg(max(col("id")), avg(col("id"))) withSQLConf("spark.sql.test.forceApplySortAggregate" -> "true") { val plan = df.queryExecution.executedPlan assert(plan.exists(p => p.isInstanceOf[WholeStageCodegenExec] && p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[SortAggregateExec])) assert(df.collect() === Array(Row(9, 4.5))) } }
该 SQL 生成的物理执行计划中,第二阶段的全阶段代码生成(WholeStageCodegen)部分如下:
WholeStageCodegen *(2) SortAggregate(key=[], functions=[max(id#0L), avg(id#0L)], output=[max(id)#5L, avg(id)#6]) InputAdapter +- Exchange SinglePartition, ENSURE_REQUIREMENTS, [id=#13]
分析
第二阶段 WholeStageCodegen
第二阶段的代码生成主要涉及 SortAggregateExec 和 InputAdapter 的 produce 与 consume 方法。ShuffleExchangeExec 本身并不参与代码生成,只通过 execute()/doExecute() 提供输入 RDD(ShuffledRowRDD)。下面逐一分析:
第二阶段 WholeStageCodegen 的数据流(方法调用关系)如下:
WholeStageCodegenExec       SortAggregateExec(Final)         InputAdapter          ShuffleExchangeExec
====================================================================================================
-> execute()
     |
   doExecute() ---------->  inputRDDs() ------------------>  inputRDDs() -------> execute()
     |                                                                              |
   doCodeGen()                                                                   doExecute()
     |                                                                              |
     +------------------->  produce()                                            ShuffledRowRDD
                              |
                            doProduce()
                              |
                            doProduceWithoutKeys() ------->  produce()
                              |                                |
                              |                              doProduce()
                              |                                |
                            doConsume()  <-----------------  consume()
                              |
                            doConsumeWithoutKeys()
                              |   (注意:并不是 doConsumeWithoutKeys 调用 consume,而是由 doProduceWithoutKeys 调用)
   doConsume() <----------  consume()
SortAggregateExec(Final) 的doConsume
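由于该聚合没有分组键(key=[]),SortAggregateExec 的 doConsume 方法最终会调用 doConsumeWithoutKeys(ctx, input) 方法。其中 input 为 ArrayBuffer(ExprCode(,sortAgg_exprIsNull_0_0,sortAgg_expr_0_0), ExprCode(,sortAgg_exprIsNull_1_0,sortAgg_expr_1_0), ExprCode(,sortAgg_exprIsNull_2_0,sortAgg_expr_2_0))。这三个值依次是从上游读入的 MAX buffer、AVG 的 sum buffer 和 AVG 的 count buffer。对照最终生成代码可以看到,sortAgg_expr_0_0 传给 sortAgg_doAggregate_max_0,sortAgg_expr_1_0 和 sortAgg_expr_2_0 传给 sortAgg_doAggregate_avg_0。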
与 SortAggregateExec(Partial) 的不同之处:
updateExprs的不同
val updateExprs = aggregateExpressions.map { e => e.mode match { case Partial | Complete => e.aggregateFunction.asInstanceOf[DeclarativeAggregate].updateExpressions case PartialMerge | Final => e.aggregateFunction.asInstanceOf[DeclarativeAggregate].mergeExpressions } }
因为这里是 Final 阶段,所以更新表达式使用的是 mergeExpressions。以 Average 为例,即:
protected def getMergeExpressions = Seq( /* sum = */ Add(sum.left, sum.right, useAnsiAdd), /* count = */ count.left + count.right
生成的对应 ExprCode 为:
List(ExprCode(boolean sortAgg_isNull_13 = true; double sortAgg_value_13 = -1.0; if (!sortAgg_bufIsNull_1) { if (!sortAgg_exprIsNull_1_0) { sortAgg_isNull_13 = false; // resultCode could change nullability. sortAgg_value_13 = sortAgg_bufValue_1 + sortAgg_expr_1_0; } },sortAgg_isNull_13,sortAgg_value_13), ExprCode(boolean sortAgg_isNull_16 = true; long sortAgg_value_16 = -1L; if (!sortAgg_bufIsNull_2) { if (!sortAgg_exprIsNull_2_0) { sortAgg_isNull_16 = false; // resultCode could change nullability. sortAgg_value_16 = sortAgg_bufValue_2 + sortAgg_expr_2_0; } },sortAgg_isNull_16,sortAgg_value_16)))
可以看到,对于 sortAgg_value_16(即 COUNT 的值),这里累加的是 sortAgg_expr_2_0,而在 Partial 阶段累加的是常量 1。这是因为 Final 阶段的输入是各个 Task 在本地完成聚合后输出的结果 buffer,所以需要把各个 buffer 中的 count 相加,而不是再逐行计数。
sortAgg_isNull_13/sortAgg_value_13 与 sortAgg_isNull_16/sortAgg_value_16 的结果,最终仍会赋值给对应的聚合 buffer 成员变量:
sortAgg_bufIsNull_1 = sortAgg_isNull_13; sortAgg_bufValue_1 = sortAgg_value_13; sortAgg_bufIsNull_2 = sortAgg_isNull_16; sortAgg_bufValue_2 = sortAgg_value_16;
SortAggregateExec(Final)的consume
consume(ctx, resultVars)
resultVars 中与 AVG 对应的是 ExprCode(,sortAgg_isNull_4,sortAgg_value_4),它保存了最终的 AVG 结果值。MAX 的结果则直接取自 sortAgg_bufIsNull_0/sortAgg_bufValue_0,即下面代码中写入第 0 列的值。
其余的数据流向与之前(第一阶段)的分析相同:
val rowVar = prepareRowVar(ctx, row, outputVars)
最终生成的代码如下(与之前一致):
sortAgg_mutableStateArray_0[0].reset(); sortAgg_mutableStateArray_0[0].zeroOutNullBytes(); if (sortAgg_bufIsNull_0) { sortAgg_mutableStateArray_0[0].setNullAt(0); } else { sortAgg_mutableStateArray_0[0].write(0, sortAgg_bufValue_0); } if (sortAgg_isNull_4) { sortAgg_mutableStateArray_0[0].setNullAt(1); } else { sortAgg_mutableStateArray_0[0].write(1, sortAgg_value_4); }
WholeStageCodegenExec 的 doConsume
数据流与之前一致,最终生成的代码如下:
append((sortAgg_mutableStateArray_0[0].getRow()));
第二阶段 WholeStageCodegen 最终生成的完整代码如下:
/* 001 */ public Object generate(Object[] references) { /* 002 */ return new GeneratedIteratorForCodegenStage2(references); /* 003 */ } /* 004 */ /* 005 */ // codegenStageId=2 /* 006 */ final class GeneratedIteratorForCodegenStage2 extends org.apache.spark.sql.execution.BufferedRowIterator { /* 007 */ private Object[] references; /* 008 */ private scala.collection.Iterator[] inputs; /* 009 */ private boolean sortAgg_initAgg_0; /* 010 */ private boolean sortAgg_bufIsNull_0; /* 011 */ private long sortAgg_bufValue_0; /* 012 */ private boolean sortAgg_bufIsNull_1; /* 013 */ private double sortAgg_bufValue_1; /* 014 */ private boolean sortAgg_bufIsNull_2; /* 015 */ private long sortAgg_bufValue_2; /* 016 */ private scala.collection.Iterator inputadapter_input_0; /* 017 */ private boolean sortAgg_sortAgg_isNull_10_0; /* 018 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[] sortAgg_mutableStateArray_0 = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[1]; /* 019 */ /* 020 */ public GeneratedIteratorForCodegenStage2(Object[] references) { /* 021 */ this.references = references; /* 022 */ } /* 023 */ /* 024 */ public void init(int index, scala.collection.Iterator[] inputs) { /* 025 */ partitionIndex = index; /* 026 */ this.inputs = inputs; /* 027 */ /* 028 */ inputadapter_input_0 = inputs[0]; /* 029 */ sortAgg_mutableStateArray_0[0] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(2, 0); /* 030 */ /* 031 */ } /* 032 */ /* 033 */ private void sortAgg_doAggregate_max_0(long sortAgg_expr_0_0, boolean sortAgg_exprIsNull_0_0) throws java.io.IOException { /* 034 */ sortAgg_sortAgg_isNull_10_0 = true; /* 035 */ long sortAgg_value_10 = -1L; /* 036 */ /* 037 */ if (!sortAgg_bufIsNull_0 && (sortAgg_sortAgg_isNull_10_0 || /* 038 */ sortAgg_bufValue_0 > sortAgg_value_10)) { /* 039 */ sortAgg_sortAgg_isNull_10_0 = false; /* 040 */ sortAgg_value_10 = sortAgg_bufValue_0; /* 041 */ } /* 042 */ /* 043 */ if 
(!sortAgg_exprIsNull_0_0 && (sortAgg_sortAgg_isNull_10_0 || /* 044 */ sortAgg_expr_0_0 > sortAgg_value_10)) { /* 045 */ sortAgg_sortAgg_isNull_10_0 = false; /* 046 */ sortAgg_value_10 = sortAgg_expr_0_0; /* 047 */ } /* 048 */ /* 049 */ sortAgg_bufIsNull_0 = sortAgg_sortAgg_isNull_10_0; /* 050 */ sortAgg_bufValue_0 = sortAgg_value_10; /* 051 */ } /* 052 */ /* 053 */ private void sortAgg_doAggregateWithoutKey_0() throws java.io.IOException { /* 054 */ // initialize aggregation buffer /* 055 */ sortAgg_bufIsNull_0 = true; /* 056 */ sortAgg_bufValue_0 = -1L; /* 057 */ sortAgg_bufIsNull_1 = false; /* 058 */ sortAgg_bufValue_1 = 0.0D; /* 059 */ sortAgg_bufIsNull_2 = false; /* 060 */ sortAgg_bufValue_2 = 0L; /* 061 */ /* 062 */ while ( inputadapter_input_0.hasNext()) { /* 063 */ InternalRow inputadapter_row_0 = (InternalRow) inputadapter_input_0.next(); /* 064 */ /* 065 */ boolean inputadapter_isNull_0 = inputadapter_row_0.isNullAt(0); /* 066 */ long inputadapter_value_0 = inputadapter_isNull_0 ? /* 067 */ -1L : (inputadapter_row_0.getLong(0)); /* 068 */ boolean inputadapter_isNull_1 = inputadapter_row_0.isNullAt(1); /* 069 */ double inputadapter_value_1 = inputadapter_isNull_1 ? /* 070 */ -1.0 : (inputadapter_row_0.getDouble(1)); /* 071 */ boolean inputadapter_isNull_2 = inputadapter_row_0.isNullAt(2); /* 072 */ long inputadapter_value_2 = inputadapter_isNull_2 ? 
/* 073 */ -1L : (inputadapter_row_0.getLong(2)); /* 074 */ /* 075 */ sortAgg_doConsume_0(inputadapter_row_0, inputadapter_value_0, inputadapter_isNull_0, inputadapter_value_1, inputadapter_isNull_1, inputadapter_value_2, inputadapter_isNull_2); /* 076 */ // shouldStop check is eliminated /* 077 */ } /* 078 */ /* 079 */ } /* 080 */ /* 081 */ protected void processNext() throws java.io.IOException { /* 082 */ while (!sortAgg_initAgg_0) { /* 083 */ sortAgg_initAgg_0 = true; /* 084 */ sortAgg_doAggregateWithoutKey_0(); /* 085 */ /* 086 */ // output the result /* 087 */ boolean sortAgg_isNull_6 = sortAgg_bufIsNull_2; /* 088 */ double sortAgg_value_6 = -1.0; /* 089 */ if (!sortAgg_bufIsNull_2) { /* 090 */ sortAgg_value_6 = (double) sortAgg_bufValue_2; /* 091 */ } /* 092 */ boolean sortAgg_isNull_4 = false; /* 093 */ double sortAgg_value_4 = -1.0; /* 094 */ if (sortAgg_isNull_6 || sortAgg_value_6 == 0) { /* 095 */ sortAgg_isNull_4 = true; /* 096 */ } else { /* 097 */ if (sortAgg_bufIsNull_1) { /* 098 */ sortAgg_isNull_4 = true; /* 099 */ } else { /* 100 */ sortAgg_value_4 = (double)(sortAgg_bufValue_1 / sortAgg_value_6); /* 101 */ } /* 102 */ } /* 103 */ /* 104 */ ((org.apache.spark.sql.execution.metric.SQLMetric) references[0] /* numOutputRows */).add(1); /* 105 */ sortAgg_mutableStateArray_0[0].reset(); /* 106 */ /* 107 */ sortAgg_mutableStateArray_0[0].zeroOutNullBytes(); /* 108 */ /* 109 */ if (sortAgg_bufIsNull_0) { /* 110 */ sortAgg_mutableStateArray_0[0].setNullAt(0); /* 111 */ } else { /* 112 */ sortAgg_mutableStateArray_0[0].write(0, sortAgg_bufValue_0); /* 113 */ } /* 114 */ /* 115 */ if (sortAgg_isNull_4) { /* 116 */ sortAgg_mutableStateArray_0[0].setNullAt(1); /* 117 */ } else { /* 118 */ sortAgg_mutableStateArray_0[0].write(1, sortAgg_value_4); /* 119 */ } /* 120 */ append((sortAgg_mutableStateArray_0[0].getRow())); /* 121 */ } /* 122 */ } /* 123 */ /* 124 */ private void sortAgg_doConsume_0(InternalRow inputadapter_row_0, long sortAgg_expr_0_0, boolean 
sortAgg_exprIsNull_0_0, double sortAgg_expr_1_0, boolean sortAgg_exprIsNull_1_0, long sortAgg_expr_2_0, boolean sortAgg_exprIsNull_2_0) throws java.io.IOException { /* 125 */ // do aggregate /* 126 */ // common sub-expressions /* 127 */ /* 128 */ // evaluate aggregate functions and update aggregation buffers /* 129 */ sortAgg_doAggregate_max_0(sortAgg_expr_0_0, sortAgg_exprIsNull_0_0); /* 130 */ sortAgg_doAggregate_avg_0(sortAgg_exprIsNull_1_0, sortAgg_expr_1_0, sortAgg_exprIsNull_2_0, sortAgg_expr_2_0); /* 131 */ /* 132 */ } /* 133 */ /* 134 */ private void sortAgg_doAggregate_avg_0(boolean sortAgg_exprIsNull_1_0, double sortAgg_expr_1_0, boolean sortAgg_exprIsNull_2_0, long sortAgg_expr_2_0) throws java.io.IOException { /* 135 */ boolean sortAgg_isNull_13 = true; /* 136 */ double sortAgg_value_13 = -1.0; /* 137 */ /* 138 */ if (!sortAgg_bufIsNull_1) { /* 139 */ if (!sortAgg_exprIsNull_1_0) { /* 140 */ sortAgg_isNull_13 = false; // resultCode could change nullability. /* 141 */ /* 142 */ sortAgg_value_13 = sortAgg_bufValue_1 + sortAgg_expr_1_0; /* 143 */ /* 144 */ } /* 145 */ /* 146 */ } /* 147 */ boolean sortAgg_isNull_16 = true; /* 148 */ long sortAgg_value_16 = -1L; /* 149 */ /* 150 */ if (!sortAgg_bufIsNull_2) { /* 151 */ if (!sortAgg_exprIsNull_2_0) { /* 152 */ sortAgg_isNull_16 = false; // resultCode could change nullability. /* 153 */ /* 154 */ sortAgg_value_16 = sortAgg_bufValue_2 + sortAgg_expr_2_0; /* 155 */ /* 156 */ } /* 157 */ /* 158 */ } /* 159 */ /* 160 */ sortAgg_bufIsNull_1 = sortAgg_isNull_13; /* 161 */ sortAgg_bufValue_1 = sortAgg_value_13; /* 162 */ /* 163 */ sortAgg_bufIsNull_2 = sortAgg_isNull_16; /* 164 */ sortAgg_bufValue_2 = sortAgg_value_16; /* 165 */ } /* 166 */ /* 167 */ }