Background
This article is based on Spark 3.3.0.
We explore the logic of Spark Codegen by walking through a unit test:
test("SortAggregate should be included in WholeStageCodegen") { val df = spark.range(10).agg(max(col("id")), avg(col("id"))) withSQLConf("spark.sql.test.forceApplySortAggregate" -> "true") { val plan = df.queryExecution.executedPlan assert(plan.exists(p => p.isInstanceOf[WholeStageCodegenExec] && p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[SortAggregateExec])) assert(df.collect() === Array(Row(9, 4.5))) } } 该sql形成的执行计划第一部分的全代码生成部分如下:
WholeStageCodegen
+- *(1) SortAggregate(key=[], functions=[partial_max(id#0L), partial_avg(id#0L)], output=[max#12L, sum#13, count#14L])
   +- *(1) Range (0, 10, step=1, splits=2)
Analysis
Stage 1: WholeStageCodegen
Code generation for the first stage involves the produce and consume methods of SortAggregateExec and RangeExec; let's analyze them one by one.
The data flow of the first stage's WholeStageCodegen is as follows:
WholeStageCodegenExec          SortAggregateExec(partial)            RangeExec
==============================================================================
-> execute()
     |
   doExecute() ---------> inputRDDs() -------------------------> inputRDDs()
     |
   doCodeGen()
     |
     +-----------------> produce()
                            |
                         doProduce()
                            |
                         doProduceWithoutKeys() ----------------> produce()
                                                                     |
                                                                  doProduce()
                                                                     |
                         doConsume() <--------------------------- consume()
                            |
                         doConsumeWithoutKeys()
                            |
   doConsume() <-------- consume()

Note: the consume() that reaches WholeStageCodegenExec's doConsume() is not issued from doConsumeWithoutKeys; it is driven by doProduceWithoutKeys.
RangeExec's produce
final def produce(ctx: CodegenContext, parent: CodegenSupport): String = executeQuery {
  this.parent = parent
  ctx.freshNamePrefix = variablePrefix
  s"""
     |${ctx.registerComment(s"PRODUCE: ${this.simpleString(conf.maxToStringFields)}")}
     |${doProduce(ctx)}
   """.stripMargin
}
this.parent = parent and ctx.freshNamePrefix = variablePrefix
Setting parent stores a reference to the parent node, so that when this node later calls consume, the call can reach the parent's doConsume method and generate the consuming code.
freshNamePrefix is set so that the names generated for different physical operators are distinguishable, which prevents duplicate method/variable names and thus compilation errors in the generated code.
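As a minimal sketch of the effect, one can drive a bare CodegenContext directly (outside of any physical plan; the exact suffixes of the generated names are an implementation detail and may differ):

import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext

val ctx = new CodegenContext
ctx.freshNamePrefix = "range"   // RangeExec's variablePrefix
ctx.freshName("value")          // e.g. "range_value_0"
ctx.freshNamePrefix = "agg"     // a hypothetical prefix for another operator
ctx.freshName("value")          // e.g. "agg_value_0" -- no clash with the Range name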
ctx.registerComment
This attaches a comment to the generated Java code. By default no comments are added, because spark.sql.codegen.comments defaults to false.
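To actually see these comments together with the generated code, one can enable the flag and dump the code; a sketch, assuming a SparkSession named spark (spark.sql.codegen.comments is an internal configuration):

import org.apache.spark.sql.execution.debug._
import org.apache.spark.sql.functions.{avg, col, max}

spark.conf.set("spark.sql.codegen.comments", "true")

val df = spark.range(10).agg(max(col("id")), avg(col("id")))
df.debugCodegen()   // prints every WholeStageCodegen subtree and its generated Java source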
protected override def doProduce(ctx: CodegenContext): String = {
  val numOutput = metricTerm(ctx, "numOutputRows")

  val initTerm = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "initRange")
  val nextIndex = ctx.addMutableState(CodeGenerator.JAVA_LONG, "nextIndex")

  val value = ctx.freshName("value")
  val ev = ExprCode.forNonNullValue(JavaCode.variable(value, LongType))
  val BigInt = classOf[java.math.BigInteger].getName

  // Inline mutable state since not many Range operations in a task
  val taskContext = ctx.addMutableState("TaskContext", "taskContext",
    v => s"$v = TaskContext.get();", forceInline = true)
  val inputMetrics = ctx.addMutableState("InputMetrics", "inputMetrics",
    v => s"$v = $taskContext.taskMetrics().inputMetrics();", forceInline = true)

  // In order to periodically update the metrics without inflicting performance penalty, this
  // operator produces elements in batches. After a batch is complete, the metrics are updated
  // and a new batch is started.
  // In the implementation below, the code in the inner loop is producing all the values
  // within a batch, while the code in the outer loop is setting batch parameters and updating
  // the metrics.

  // Once nextIndex == batchEnd, it's time to progress to the next batch.
  val batchEnd = ctx.addMutableState(CodeGenerator.JAVA_LONG, "batchEnd")

  // How many values should still be generated by this range operator.
  val numElementsTodo = ctx.addMutableState(CodeGenerator.JAVA_LONG, "numElementsTodo")

  // How many values should be generated in the next batch.
  val nextBatchTodo = ctx.freshName("nextBatchTodo")

  // The default size of a batch, which must be positive integer
  val batchSize = 1000

  val initRangeFuncName = ctx.addNewFunction("initRange",
    s"""
      | private void initRange(int idx) {
      |   $BigInt index = $BigInt.valueOf(idx);
      |   $BigInt numSlice = $BigInt.valueOf(${numSlices}L);
      |   $BigInt numElement = $BigInt.valueOf(${numElements.toLong}L);
      |   $BigInt step = $BigInt.valueOf(${step}L);
      |   $BigInt start = $BigInt.valueOf(${start}L);
      |   long partitionEnd;
      |
      |   $BigInt st = index.multiply(numElement).divide(numSlice).multiply(step).add(start);
      |   if (st.compareTo($BigInt.valueOf(Long.MAX_VALUE)) > 0) {
      |     $nextIndex = Long.MAX_VALUE;
      |   } else if (st.compareTo($BigInt.valueOf(Long.MIN_VALUE)) < 0) {
      |     $nextIndex = Long.MIN_VALUE;
      |   } else {
      |     $nextIndex = st.longValue();
      |   }
      |   $batchEnd = $nextIndex;
      |
      |   $BigInt end = index.add($BigInt.ONE).multiply(numElement).divide(numSlice)
      |     .multiply(step).add(start);
      |   if (end.compareTo($BigInt.valueOf(Long.MAX_VALUE)) > 0) {
      |     partitionEnd = Long.MAX_VALUE;
      |   } else if (end.compareTo($BigInt.valueOf(Long.MIN_VALUE)) < 0) {
      |     partitionEnd = Long.MIN_VALUE;
      |   } else {
      |     partitionEnd = end.longValue();
      |   }
      |
      |   $BigInt startToEnd = $BigInt.valueOf(partitionEnd).subtract(
      |     $BigInt.valueOf($nextIndex));
      |   $numElementsTodo = startToEnd.divide(step).longValue();
      |   if ($numElementsTodo < 0) {
      |     $numElementsTodo = 0;
      |   } else if (startToEnd.remainder(step).compareTo($BigInt.valueOf(0L)) != 0) {
      |     $numElementsTodo++;
      |   }
      | }
     """.stripMargin)

  val localIdx = ctx.freshName("localIdx")
  val localEnd = ctx.freshName("localEnd")
  val stopCheck = if (parent.needStopCheck) {
    s"""
       |if (shouldStop()) {
       |  $nextIndex = $value + ${step}L;
       |  $numOutput.add($localIdx + 1);
       |  $inputMetrics.incRecordsRead($localIdx + 1);
       |  return;
       |}
     """.stripMargin
  } else {
    "// shouldStop check is eliminated"
  }
  val loopCondition = if (limitNotReachedChecks.isEmpty) {
    "true"
  } else {
    limitNotReachedChecks.mkString(" && ")
  }
  s"""
    | // initialize Range
    | if (!$initTerm) {
    |   $initTerm = true;
    |   $initRangeFuncName(partitionIndex);
    | }
    |
    | while ($loopCondition) {
    |   if ($nextIndex == $batchEnd) {
    |     long $nextBatchTodo;
    |     if ($numElementsTodo > ${batchSize}L) {
    |       $nextBatchTodo = ${batchSize}L;
    |       $numElementsTodo -= ${batchSize}L;
    |     } else {
    |       $nextBatchTodo = $numElementsTodo;
    |       $numElementsTodo = 0;
    |       if ($nextBatchTodo == 0) break;
    |     }
    |     $batchEnd += $nextBatchTodo * ${step}L;
    |   }
    |
    |   int $localEnd = (int)(($batchEnd - $nextIndex) / ${step}L);
    |   for (int $localIdx = 0; $localIdx < $localEnd; $localIdx++) {
    |     long $value = ((long)$localIdx * ${step}L) + $nextIndex;
    |     ${consume(ctx, Seq(ev))}
    |     $stopCheck
    |   }
    |   $nextIndex = $batchEnd;
    |   $numOutput.add($localEnd);
    |   $inputMetrics.incRecordsRead($localEnd);
    |   $taskContext.killTaskIfInterrupted();
    | }
   """.stripMargin
}
val numOutput = metricTerm(ctx, "numOutputRows")
numOutput is the numOutputRows metric, used to record the number of rows this operator outputs.
val initTerm = and val nextIndex =
initTerm (initRange) marks whether the range for the current partition has already been initialized, i.e. whether the generated initRange function has been called.
nextIndex is the logical index used to produce RangeExec's data while iterating over it.
Both are member variables of the generated class, i.e. "global" variables.
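A minimal sketch of what addMutableState does, again on a bare CodegenContext (the names and declarations shown in the comments are illustrative):

import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenerator, CodegenContext}

val ctx = new CodegenContext
ctx.freshNamePrefix = "range"
val initTerm = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "initRange")
val nextIndex = ctx.addMutableState(CodeGenerator.JAVA_LONG, "nextIndex")

println(initTerm)                    // e.g. range_initRange_0
// declareMutableStates() is what WholeStageCodegenExec splices into the generated
// class body, turning the registered states into member fields
println(ctx.declareMutableStates())  // e.g. private boolean range_initRange_0; ...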
val value = and val ev =
ev represents the value produced by RangeExec; it is what eventually gets passed to consume(ctx, Seq(ev)).
The value variable itself is assigned in the generated line long $value = ((long)$localIdx * ${step}L) + $nextIndex;, which is how the parent node gets a value to consume.
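A small sketch of how ev is built, using the same calls RangeExec uses: the ExprCode carries no code of its own, it simply points the parent at the value variable that the generated loop assigns.

import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, JavaCode}
import org.apache.spark.sql.types.LongType

val ctx = new CodegenContext
ctx.freshNamePrefix = "range"
val value = ctx.freshName("value")   // e.g. "range_value_0"
val ev = ExprCode.forNonNullValue(JavaCode.variable(value, LongType))
// ev.code is empty and ev.isNull is the literal "false": the parent's generated
// doConsume code just reads the variable, which the for-loop assigns beforehand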
val taskContext = and val inputMetrics =
taskContext and inputMetrics are also global (member) variables, and each comes with initialization code. That initialization is performed in the generated class's init method, producing code like the following:
range_taskContext_0 = TaskContext.get();
range_inputMetrics_0 = range_taskContext_0.taskMetrics().inputMetrics();
The reason the initialization happens in init is that these initialization snippets are appended to mutableStateInitCode (an ArrayBuffer field of CodegenContext), and WholeStageCodegenExec assembles everything stored there via ctx.initMutableStates(), which is invoked from the following generated template:
public void init(int index, scala.collection.Iterator[] inputs) {
  partitionIndex = index;
  this.inputs = inputs;
  ${ctx.initMutableStates()}
  ${ctx.initPartition()}
}
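The same mechanism in isolation (a sketch on a bare CodegenContext; the generated name in the comment is illustrative): an init function registered through addMutableState ends up in the string returned by ctx.initMutableStates(), which is exactly what gets pasted into init above.

import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext

val ctx = new CodegenContext
ctx.freshNamePrefix = "range"
val taskContext = ctx.addMutableState("TaskContext", "taskContext",
  v => s"$v = TaskContext.get();", forceInline = true)

// The snippet built by the initFunc above was stored in mutableStateInitCode and
// is re-emitted here, e.g. "range_taskContext_0 = TaskContext.get();"
println(ctx.initMutableStates())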
val batchEnd = and val numElementsTodo =
These two are also member variables of the generated class, i.e. global variables: batchEnd marks where the current batch ends, and numElementsTodo tracks how many values this range operator still has to generate.
val nextBatchTodo =
This one is a local (temporary) variable, used while looping to produce the data: it holds how many values the next batch should generate.
val initRangeFuncName =
This is the function that computes RangeExec's per-partition range (the starting index, the partition end, and how many elements remain); every physical operator's data-producing logic is different, so we skip the details here.
The final while ($loopCondition) loop
This part generates different data depending on each partition's index.
It is worth noting the partitionIndex variable in the initRangeFuncName(partitionIndex) call: it is a field of BufferedRowIterator, the parent class of the generated class,
and it is assigned in the init method, as the following code shows:
public void init(int index, scala.collection.Iterator[] inputs) {
  partitionIndex = index;
  this.inputs = inputs;
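For the plan above, Range (0, 10, step=1, splits=2), the initRange arithmetic works out per partition as follows (a worked example, not extra generated code):

partition 0: st  = 0 * 10 / 2 * 1 + 0 = 0    -> nextIndex = 0, batchEnd = 0
             end = 1 * 10 / 2 * 1 + 0 = 5    -> partitionEnd = 5, numElementsTodo = 5
             => produces the values 0, 1, 2, 3, 4
partition 1: st  = 1 * 10 / 2 * 1 + 0 = 5    -> nextIndex = 5, batchEnd = 5
             end = 2 * 10 / 2 * 1 + 0 = 10   -> partitionEnd = 10, numElementsTodo = 5
             => produces the values 5, 6, 7, 8, 9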
consume(ctx, Seq(ev))
This is where the parent node consumes the data produced by RangeExec; it is covered in detail in the next part.
numOutput, inputMetrics and taskContext
numOutput adds the number of rows produced in this batch to the numOutputRows metric.
inputMetrics adds the same number to the records-read counter at the task-metrics level.
taskContext.killTaskIfInterrupted() checks whether the current task has been killed; if it has, it throws an exception immediately.
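A hedged way to observe these metrics from the driver after running the query (the RangeExec lookup below is a sketch; adaptive execution is disabled only so that the leaf is directly reachable in executedPlan):

import org.apache.spark.sql.execution.RangeExec
import org.apache.spark.sql.functions.{avg, col, max}

spark.conf.set("spark.sql.adaptive.enabled", "false")

val df = spark.range(10).agg(max(col("id")), avg(col("id")))
df.collect()   // run the query so that the metrics get populated

// Find the RangeExec leaf inside the executed plan and read its numOutputRows metric
val range = df.queryExecution.executedPlan.collect { case r: RangeExec => r }.head
println(range.metrics("numOutputRows").value)   // expected: 10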