Background
This article is based on Spark 3.3.0.
We use a unit test to explore the logic of Spark Codegen:
test("SortAggregate should be included in WholeStageCodegen") { val df = spark.range(10).agg(max(col("id")), avg(col("id"))) withSQLConf("spark.sql.test.forceApplySortAggregate" -> "true") { val plan = df.queryExecution.executedPlan assert(plan.exists(p => p.isInstanceOf[WholeStageCodegenExec] && p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[SortAggregateExec])) assert(df.collect() === Array(Row(9, 4.5))) } }
The whole-stage-codegen portion of the first stage of the resulting execution plan is as follows:
WholeStageCodegen
+- *(1) SortAggregate(key=[], functions=[partial_max(id#0L), partial_avg(id#0L)], output=[max#12L, sum#13, count#14L])
   +- *(1) Range (0, 10, step=1, splits=2)
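If you want to dump the generated code for this pipeline yourself, Spark ships a debug helper for that. Below is a minimal sketch, assuming a local SparkSession; the object name InspectCodegen is just for illustration, and spark.sql.test.forceApplySortAggregate is a test-oriented flag, so outside Spark's own test harness you may see a HashAggregate-based plan instead.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{avg, col, max}
import org.apache.spark.sql.execution.debug._   // provides debugCodegen()

object InspectCodegen {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("codegen-demo").getOrCreate()
    // Test-oriented flag used by the unit test above; outside Spark's test harness
    // it may have no effect, in which case a HashAggregate plan is produced instead.
    spark.conf.set("spark.sql.test.forceApplySortAggregate", "true")

    val df = spark.range(10).agg(max(col("id")), avg(col("id")))
    df.explain()       // shows the physical plan, including the codegen stages
    df.debugCodegen()  // prints the generated Java source for each WholeStageCodegen subtree
    spark.stop()
  }
}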
Analysis
Stage 1 wholeStageCodegen
Code generation for stage 1 involves the produce and consume methods of SortAggregateExec and RangeExec; let's analyze them one by one.
The data flow of stage 1 wholeStageCodegen is as follows:
WholeStageCodegenExec          SortAggregateExec(partial)          RangeExec
=========================================================================
-> execute()
      |
   doExecute() ------------->  inputRDDs() -------------------->  inputRDDs()
      |
   doCodeGen()
      |
      +--------------------->  produce()
                                  |
                               doProduce()
                                  |
                               doProduceWithoutKeys() ---------->  produce()
                                                                      |
                                                                   doProduce()
                                                                      |
                               doConsume() <---------------------  consume()
                                  |
                               doConsumeWithoutKeys()
                                  |
                                  | (note: consume is not called from
                                  |  doConsumeWithoutKeys but from
                                  |  doProduceWithoutKeys)
   doConsume() <-------------  consume()
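To make the push-based flow in the diagram concrete, here is a self-contained toy model of the produce/consume protocol. It is not Spark source code; the ToyRange/ToySortAgg/ToyWholeStage names are invented for illustration. The point is that the leaf operator generates the driving loop and pushes each "row" upward by calling its parent's doConsume, which is exactly how RangeExec drives SortAggregateExec and finally WholeStageCodegenExec here.

// Self-contained toy model of the produce/consume codegen protocol (not Spark source).
trait ToyCodegenSupport {
  protected var parent: ToyCodegenSupport = _

  // produce: remember who consumes our rows, then let the operator emit its loop.
  final def produce(parent: ToyCodegenSupport): String = {
    this.parent = parent
    doProduce()
  }

  // consume: hand a produced "row" (here just a variable name) to the parent.
  final def consume(rowVar: String): String = parent.doConsume(rowVar)

  def doProduce(): String
  def doConsume(rowVar: String): String = ""
}

// Leaf operator: owns the loop and calls consume for every value, like RangeExec.
class ToyRange extends ToyCodegenSupport {
  override def doProduce(): String =
    s"""for (long v = 0; v < 10; v++) {
       |  ${consume("v")}
       |}""".stripMargin
}

// Aggregate without grouping keys: asks the child to produce, updates its buffer in
// doConsume, and emits the buffer once at the end, like SortAggregateExec(Partial).
class ToySortAgg(child: ToyCodegenSupport) extends ToyCodegenSupport {
  override def doProduce(): String =
    s"""${child.produce(this)}
       |// output the aggregation buffer
       |${consume("aggBuffer")}""".stripMargin
  override def doConsume(rowVar: String): String = s"updateAggBuffer($rowVar);"
}

// The whole-stage root: whatever reaches it is appended to the output iterator.
class ToyWholeStage(child: ToyCodegenSupport) extends ToyCodegenSupport {
  def generate(): String = child.produce(this)
  override def doProduce(): String = ""
  override def doConsume(rowVar: String): String = s"append($rowVar);"
}

object ProduceConsumeDemo {
  def main(args: Array[String]): Unit =
    println(new ToyWholeStage(new ToySortAgg(new ToyRange)).generate())
}

Running the demo prints a loop that updates the aggregation buffer per value and then a single append of the buffer, which mirrors the structure of the real generated code shown at the end of this article.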
The consume method of SortAggregateExec(Partial)
This method is invoked from doProduceWithoutKeys; the relevant code is as follows:
s""" |while (!$initAgg) { | $initAgg = true; | $doAggWithRecordMetric | | // output the result | ${genResult.trim} | | $numOutput.add(1); | ${consume(ctx, resultVars).trim} |} """.stripMargin
Here resultVars is flatBufVars, i.e. the global aggregation-buffer variables such as sortAgg_bufValue_1 and sortAgg_bufValue_2.
In part (5) of this series on whole-stage code generation in Spark (using aggregate codegen as the example), we noted that once the corresponding aggregate functions finish computing, sortAgg_bufValue_1 and sortAgg_bufValue_2 are assigned the computed results, as follows:
sortAgg_bufIsNull_1 = sortAgg_isNull_7;
sortAgg_bufValue_1 = sortAgg_value_7;

sortAgg_bufIsNull_2 = sortAgg_isNull_13;
sortAgg_bufValue_2 = sortAgg_value_13;
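In other words, for partial_avg the buffer slots hold a running sum (sortAgg_bufValue_1) and a running count (sortAgg_bufValue_2). The following is a tiny self-contained model of that update; AvgBuffer is an invented name and this is not generated code, only an illustration of the semantics:

// Minimal model of the partial_avg buffer: (sum, count), updated once per input row.
final case class AvgBuffer(sum: Double, count: Long) {
  def update(id: Long): AvgBuffer = AvgBuffer(sum + id.toDouble, count + 1L)
}

object AvgBufferDemo {
  def main(args: Array[String]): Unit = {
    val buf = (0L until 10L).foldLeft(AvgBuffer(0.0, 0L))(_ update _)
    println(buf)                   // AvgBuffer(45.0,10)
    println(buf.sum / buf.count)   // 4.5 -- the division happens in the Final aggregate stage
  }
}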
So resultVars already holds the computed results. The consume method itself has been covered before; what differs here is:
- The output of SortAggregateExec(Partial) is max, sum, count.
- val rowVar = prepareRowVar(ctx, row, outputVars) returns an UnsafeRow containing max, sum, and count (see the writer sketch after this list), as follows:
ExprCode(
  range_mutableStateArray_0[2].reset();
  range_mutableStateArray_0[2].zeroOutNullBytes();

  if (sortAgg_bufIsNull_0) {
    range_mutableStateArray_0[2].setNullAt(0);
  } else {
    range_mutableStateArray_0[2].write(0, sortAgg_bufValue_0);
  }

  if (sortAgg_bufIsNull_1) {
    range_mutableStateArray_0[2].setNullAt(1);
  } else {
    range_mutableStateArray_0[2].write(1, sortAgg_bufValue_1);
  }

  if (sortAgg_bufIsNull_2) {
    range_mutableStateArray_0[2].setNullAt(2);
  } else {
    range_mutableStateArray_0[2].write(2, sortAgg_bufValue_2);
  },
  false,
  (range_mutableStateArray_0[2].getRow()))
- val requireAllOutput = output.forall(parent.usedInputs.contains(_)) returns false, so the data flows directly to parent.doConsume(ctx, inputVars, rowVar).
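The rowVar code above drives an UnsafeRowWriter with three slots (max, sum, count). As a hedged illustration, the same writer class can be exercised directly from a small program that depends on spark-catalyst; RowVarSketch is an invented name, and the values are simply the expected partial results for range(10):

import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter

// Sketch of what the generated rowVar code does: write max/sum/count into an UnsafeRow.
object RowVarSketch {
  def main(args: Array[String]): Unit = {
    val writer = new UnsafeRowWriter(3, 0)  // 3 fields, like range_mutableStateArray_0[2]
    writer.reset()
    writer.zeroOutNullBytes()
    writer.write(0, 9L)     // max(id) -> sortAgg_bufValue_0
    writer.write(1, 45.0)   // sum     -> sortAgg_bufValue_1
    writer.write(2, 10L)    // count   -> sortAgg_bufValue_2

    val row = writer.getRow
    println((row.getLong(0), row.getDouble(1), row.getLong(2)))  // (9,45.0,10)
  }
}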
WholeStageCodegenExec's doConsume
override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
  val doCopy = if (needCopyResult) {
    ".copy()"
  } else {
    ""
  }
  s"""
     |${row.code}
     |append(${row.value}$doCopy);
   """.stripMargin.trim
}
Here input is Seq(max, sum, count), and row is the UnsafeRow containing max, sum, and count.
- val doCopy: because needCopyResult returns children.head.asInstanceOf[CodegenSupport].needCopyResult, and SortAggregateExec's needCopyResult is false, doCopy ends up as the empty string.
- ${row.code}: the assembled row-building code, which is emitted directly as follows:
range_mutableStateArray_0[2].reset();
range_mutableStateArray_0[2].zeroOutNullBytes();

if (sortAgg_bufIsNull_0) {
  range_mutableStateArray_0[2].setNullAt(0);
} else {
  range_mutableStateArray_0[2].write(0, sortAgg_bufValue_0);
}

if (sortAgg_bufIsNull_1) {
  range_mutableStateArray_0[2].setNullAt(1);
} else {
  range_mutableStateArray_0[2].write(1, sortAgg_bufValue_1);
}

if (sortAgg_bufIsNull_2) {
  range_mutableStateArray_0[2].setNullAt(2);
} else {
  range_mutableStateArray_0[2].write(2, sortAgg_bufValue_2);
}
append((range_mutableStateArray_0[2].getRow()));
WholeStageCodegenExec's doCodeGen
The code is as follows:
def doCodeGen(): (CodegenContext, CodeAndComment) = {
  val startTime = System.nanoTime()
  val ctx = new CodegenContext
  val code = child.asInstanceOf[CodegenSupport].produce(ctx, this)

  // main next function.
  ctx.addNewFunction("processNext",
    s"""
      protected void processNext() throws java.io.IOException {
        ${code.trim}
      }
     """, inlineToOuterClass = true)

  val className = generatedClassName()

  val source = s"""
    public Object generate(Object[] references) {
      return new $className(references);
    }

    ${ctx.registerComment(
      s"""Codegened pipeline for stage (id=$codegenStageId)
         |${this.treeString.trim}""".stripMargin, "wsc_codegenPipeline")}
    ${ctx.registerComment(s"codegenStageId=$codegenStageId", "wsc_codegenStageId", true)}
    final class $className extends ${classOf[BufferedRowIterator].getName} {

      private Object[] references;
      private scala.collection.Iterator[] inputs;
      ${ctx.declareMutableStates()}

      public $className(Object[] references) {
        this.references = references;
      }

      public void init(int index, scala.collection.Iterator[] inputs) {
        partitionIndex = index;
        this.inputs = inputs;
        ${ctx.initMutableStates()}
        ${ctx.initPartition()}
      }

      ${ctx.emitExtraCode()}

      ${ctx.declareAddedFunctions()}
    }
    """.trim

  // try to compile, helpful for debug
  val cleanedSource = CodeFormatter.stripOverlappingComments(
    new CodeAndComment(CodeFormatter.stripExtraNewLines(source), ctx.getPlaceHolderToComments()))

  val duration = System.nanoTime() - startTime
  WholeStageCodegenExec.increaseCodeGenTime(duration)

  logDebug(s"\n${CodeFormatter.format(cleanedSource)}")
  (ctx, cleanedSource)
}
- val code = child.asInstanceOf[CodegenSupport].produce(ctx, this): code is the generated processing logic we traced above.
- ctx.addNewFunction: the generated code is wrapped inside the processNext method.
- val className = generatedClassName(): the name of the generated class.
- val source = ...: ctx.declareMutableStates(), ctx.initMutableStates() and the like declare and initialize the variables that were registered while generating the code.
- (ctx, cleanedSource): returns the generated code.
Final generated code for stage 1 wholeStageCodegen
public Object generate(Object[] references) {
  return new GeneratedIteratorForCodegenStage1(references);
}

// codegenStageId=1
final class GeneratedIteratorForCodegenStage1 extends org.apache.spark.sql.execution.BufferedRowIterator {
  private Object[] references;
  private scala.collection.Iterator[] inputs;
  private boolean sortAgg_initAgg_0;
  private boolean sortAgg_bufIsNull_0;
  private long sortAgg_bufValue_0;
  private boolean sortAgg_bufIsNull_1;
  private double sortAgg_bufValue_1;
  private boolean sortAgg_bufIsNull_2;
  private long sortAgg_bufValue_2;
  private boolean range_initRange_0;
  private long range_nextIndex_0;
  private TaskContext range_taskContext_0;
  private InputMetrics range_inputMetrics_0;
  private long range_batchEnd_0;
  private long range_numElementsTodo_0;
  private boolean sortAgg_sortAgg_isNull_4_0;
  private boolean sortAgg_sortAgg_isNull_9_0;
  private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[] range_mutableStateArray_0 = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[3];

  public GeneratedIteratorForCodegenStage1(Object[] references) {
    this.references = references;
  }

  public void init(int index, scala.collection.Iterator[] inputs) {
    partitionIndex = index;
    this.inputs = inputs;

    range_taskContext_0 = TaskContext.get();
    range_inputMetrics_0 = range_taskContext_0.taskMetrics().inputMetrics();
    range_mutableStateArray_0[0] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(1, 0);
    range_mutableStateArray_0[1] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(1, 0);
    range_mutableStateArray_0[2] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(3, 0);
  }

  private void sortAgg_doAggregate_max_0(long sortAgg_expr_0_0) throws java.io.IOException {
    sortAgg_sortAgg_isNull_4_0 = true;
    long sortAgg_value_4 = -1L;

    if (!sortAgg_bufIsNull_0 && (sortAgg_sortAgg_isNull_4_0 ||
        sortAgg_bufValue_0 > sortAgg_value_4)) {
      sortAgg_sortAgg_isNull_4_0 = false;
      sortAgg_value_4 = sortAgg_bufValue_0;
    }

    if (!false && (sortAgg_sortAgg_isNull_4_0 ||
        sortAgg_expr_0_0 > sortAgg_value_4)) {
      sortAgg_sortAgg_isNull_4_0 = false;
      sortAgg_value_4 = sortAgg_expr_0_0;
    }

    sortAgg_bufIsNull_0 = sortAgg_sortAgg_isNull_4_0;
    sortAgg_bufValue_0 = sortAgg_value_4;
  }

  private void sortAgg_doAggregateWithoutKey_0() throws java.io.IOException {
    // initialize aggregation buffer
    sortAgg_bufIsNull_0 = true;
    sortAgg_bufValue_0 = -1L;
    sortAgg_bufIsNull_1 = false;
    sortAgg_bufValue_1 = 0.0D;
    sortAgg_bufIsNull_2 = false;
    sortAgg_bufValue_2 = 0L;

    // initialize Range
    if (!range_initRange_0) {
      range_initRange_0 = true;
      initRange(partitionIndex);
    }

    while (true) {
      if (range_nextIndex_0 == range_batchEnd_0) {
        long range_nextBatchTodo_0;
        if (range_numElementsTodo_0 > 1000L) {
          range_nextBatchTodo_0 = 1000L;
          range_numElementsTodo_0 -= 1000L;
        } else {
          range_nextBatchTodo_0 = range_numElementsTodo_0;
          range_numElementsTodo_0 = 0;
          if (range_nextBatchTodo_0 == 0) break;
        }
        range_batchEnd_0 += range_nextBatchTodo_0 * 1L;
      }

      int range_localEnd_0 = (int)((range_batchEnd_0 - range_nextIndex_0) / 1L);
      for (int range_localIdx_0 = 0; range_localIdx_0 < range_localEnd_0; range_localIdx_0++) {
        long range_value_0 = ((long)range_localIdx_0 * 1L) + range_nextIndex_0;

        sortAgg_doConsume_0(range_value_0);

        // shouldStop check is eliminated
      }
      range_nextIndex_0 = range_batchEnd_0;
      ((org.apache.spark.sql.execution.metric.SQLMetric) references[0] /* numOutputRows */).add(range_localEnd_0);
      range_inputMetrics_0.incRecordsRead(range_localEnd_0);
      range_taskContext_0.killTaskIfInterrupted();
    }
  }

  private void initRange(int idx) {
    java.math.BigInteger index = java.math.BigInteger.valueOf(idx);
    java.math.BigInteger numSlice = java.math.BigInteger.valueOf(2L);
    java.math.BigInteger numElement = java.math.BigInteger.valueOf(10L);
    java.math.BigInteger step = java.math.BigInteger.valueOf(1L);
    java.math.BigInteger start = java.math.BigInteger.valueOf(0L);
    long partitionEnd;

    java.math.BigInteger st = index.multiply(numElement).divide(numSlice).multiply(step).add(start);
    if (st.compareTo(java.math.BigInteger.valueOf(Long.MAX_VALUE)) > 0) {
      range_nextIndex_0 = Long.MAX_VALUE;
    } else if (st.compareTo(java.math.BigInteger.valueOf(Long.MIN_VALUE)) < 0) {
      range_nextIndex_0 = Long.MIN_VALUE;
    } else {
      range_nextIndex_0 = st.longValue();
    }
    range_batchEnd_0 = range_nextIndex_0;

    java.math.BigInteger end = index.add(java.math.BigInteger.ONE).multiply(numElement).divide(numSlice)
      .multiply(step).add(start);
    if (end.compareTo(java.math.BigInteger.valueOf(Long.MAX_VALUE)) > 0) {
      partitionEnd = Long.MAX_VALUE;
    } else if (end.compareTo(java.math.BigInteger.valueOf(Long.MIN_VALUE)) < 0) {
      partitionEnd = Long.MIN_VALUE;
    } else {
      partitionEnd = end.longValue();
    }

    java.math.BigInteger startToEnd = java.math.BigInteger.valueOf(partitionEnd).subtract(
      java.math.BigInteger.valueOf(range_nextIndex_0));
    range_numElementsTodo_0 = startToEnd.divide(step).longValue();
    if (range_numElementsTodo_0 < 0) {
      range_numElementsTodo_0 = 0;
    } else if (startToEnd.remainder(step).compareTo(java.math.BigInteger.valueOf(0L)) != 0) {
      range_numElementsTodo_0++;
    }
  }

  protected void processNext() throws java.io.IOException {
    while (!sortAgg_initAgg_0) {
      sortAgg_initAgg_0 = true;
      sortAgg_doAggregateWithoutKey_0();

      // output the result

      ((org.apache.spark.sql.execution.metric.SQLMetric) references[1] /* numOutputRows */).add(1);
      range_mutableStateArray_0[2].reset();

      range_mutableStateArray_0[2].zeroOutNullBytes();

      if (sortAgg_bufIsNull_0) {
        range_mutableStateArray_0[2].setNullAt(0);
      } else {
        range_mutableStateArray_0[2].write(0, sortAgg_bufValue_0);
      }

      if (sortAgg_bufIsNull_1) {
        range_mutableStateArray_0[2].setNullAt(1);
      } else {
        range_mutableStateArray_0[2].write(1, sortAgg_bufValue_1);
      }

      if (sortAgg_bufIsNull_2) {
        range_mutableStateArray_0[2].setNullAt(2);
      } else {
        range_mutableStateArray_0[2].write(2, sortAgg_bufValue_2);
      }
      append((range_mutableStateArray_0[2].getRow()));
    }
  }

  private void sortAgg_doConsume_0(long sortAgg_expr_0_0) throws java.io.IOException {
    // do aggregate
    // common sub-expressions

    // evaluate aggregate functions and update aggregation buffers
    sortAgg_doAggregate_max_0(sortAgg_expr_0_0);
    sortAgg_doAggregate_avg_0(sortAgg_expr_0_0);
  }

  private void sortAgg_doAggregate_avg_0(long sortAgg_expr_0_0) throws java.io.IOException {
    boolean sortAgg_isNull_7 = true;
    double sortAgg_value_7 = -1.0;

    if (!sortAgg_bufIsNull_1) {
      sortAgg_sortAgg_isNull_9_0 = true;
      double sortAgg_value_9 = -1.0;
      do {
        boolean sortAgg_isNull_10 = false;
        double sortAgg_value_10 = -1.0;
        if (!false) {
          sortAgg_value_10 = (double) sortAgg_expr_0_0;
        }
        if (!sortAgg_isNull_10) {
          sortAgg_sortAgg_isNull_9_0 = false;
          sortAgg_value_9 = sortAgg_value_10;
          continue;
        }

        if (!false) {
          sortAgg_sortAgg_isNull_9_0 = false;
          sortAgg_value_9 = 0.0D;
          continue;
        }

      } while (false);

      sortAgg_isNull_7 = false; // resultCode could change nullability.

      sortAgg_value_7 = sortAgg_bufValue_1 + sortAgg_value_9;

    }
    boolean sortAgg_isNull_13 = false;
    long sortAgg_value_13 = -1L;
    if (!false && false) {
      sortAgg_isNull_13 = sortAgg_bufIsNull_2;
      sortAgg_value_13 = sortAgg_bufValue_2;
    } else {
      boolean sortAgg_isNull_17 = true;
      long sortAgg_value_17 = -1L;

      if (!sortAgg_bufIsNull_2) {
        sortAgg_isNull_17 = false; // resultCode could change nullability.

        sortAgg_value_17 = sortAgg_bufValue_2 + 1L;

      }
      sortAgg_isNull_13 = sortAgg_isNull_17;
      sortAgg_value_13 = sortAgg_value_17;
    }

    sortAgg_bufIsNull_1 = sortAgg_isNull_7;
    sortAgg_bufValue_1 = sortAgg_value_7;

    sortAgg_bufIsNull_2 = sortAgg_isNull_13;
    sortAgg_bufValue_2 = sortAgg_value_13;
  }

}
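As a final sanity check on how this class is consumed, here is a self-contained toy (not Spark's BufferedRowIterator, just a sketch of its contract; the Toy* names are invented) showing why the while (!sortAgg_initAgg_0) loop in processNext produces exactly one row: hasNext triggers processNext, processNext appends the single aggregated row on its first call and nothing afterwards.

import scala.collection.mutable

// Toy stand-in for BufferedRowIterator (a sketch of the contract, not Spark source):
// processNext() pushes rows via append(); hasNext() calls it when the buffer is empty.
abstract class ToyBufferedRowIterator[T] extends Iterator[T] {
  private val currentRows = mutable.Queue.empty[T]
  protected def append(row: T): Unit = currentRows.enqueue(row)
  protected def processNext(): Unit   // the "generated" subclass fills the buffer here

  override def hasNext: Boolean = {
    if (currentRows.isEmpty) processNext()
    currentRows.nonEmpty
  }
  override def next(): T = currentRows.dequeue()
}

// Analogue of the generated processNext() above: run the aggregation once, append a
// single (max, sum, count) result, and never produce again thanks to the init flag.
class ToyGeneratedIteratorStage1 extends ToyBufferedRowIterator[(Long, Double, Long)] {
  private var initAgg = false
  override protected def processNext(): Unit = {
    if (!initAgg) {
      initAgg = true
      val ids = 0L until 10L                                  // what the Range loop feeds in
      append((ids.max, ids.sum.toDouble, ids.size.toLong))    // (max, sum, count)
    }
  }
}

object ToyIteratorDemo {
  def main(args: Array[String]): Unit =
    println(new ToyGeneratedIteratorStage1().toList)   // List((9,45.0,10))
}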