Background
This article is based on Spark 3.3.0.
We explore the logic of Spark codegen through a unit test:
```scala
test("SortAggregate should be included in WholeStageCodegen") {
  val df = spark.range(10).agg(max(col("id")), avg(col("id")))
  withSQLConf("spark.sql.test.forceApplySortAggregate" -> "true") {
    val plan = df.queryExecution.executedPlan
    assert(plan.exists(p =>
      p.isInstanceOf[WholeStageCodegenExec] &&
        p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[SortAggregateExec]))
    assert(df.collect() === Array(Row(9, 4.5)))
  }
}
```
The whole-stage-codegen part of the first stage of the physical plan produced by this query is:
```
WholeStageCodegen
+- *(1) SortAggregate(key=[], functions=[partial_max(id#0L), partial_avg(id#0L)], output=[max#12L, sum#13, count#14L])
   +- *(1) Range (0, 10, step=1, splits=2)
```
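Analysis
Stage 1: WholeStageCodegen
Code generation in stage 1 involves the produce and consume methods of SortAggregateExec and RangeExec; we go through them one by one.
The data flow of the stage 1 WholeStageCodegen is as follows: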
```
WholeStageCodegenExec      SortAggregateExec(partial)      RangeExec
=========================================================================

-> execute()
    |
 doExecute() --------->   inputRDDs() -----------------> inputRDDs()
    |
 doCodeGen()
    |
    +----------------->   produce()
                            |
                         doProduce()
                            |
                         doProduceWithoutKeys() -------> produce()
                                                            |
                                                         doProduce()
                                                            |
                         doConsume() <------------------ consume()
                            |
                         doConsumeWithoutKeys()
                            |  (this consume() is called by doProduceWithoutKeys,
                            |   not by doConsumeWithoutKeys)
 doConsume() <--------   consume()
```
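To make the call order in the diagram concrete, here is a heavily simplified Java model of the produce/consume handshake. The classes Op, RangeOp, SortAggOp and StageOp are hypothetical stand-ins that only echo the method names of Spark's CodegenSupport; they return placeholder strings instead of real generated code, and the trace list records the order in which the methods run.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical, heavily simplified model of the produce/consume handshake.
// Names only echo Spark's CodegenSupport; this is not Spark code.
abstract class Op {
    Op parent;                 // set by the parent's constructor
    final List<String> trace;  // records the call order

    Op(List<String> trace) { this.trace = trace; }

    abstract String name();
    abstract String doProduce();
    abstract String doConsume(String inputVar);

    // produce() is a thin wrapper that dispatches to this operator's doProduce()
    String produce() {
        trace.add(name() + ".produce");
        return doProduce();
    }

    // consume() hands this operator's output variable to parent.doConsume()
    String consume(String outputVar) {
        trace.add(name() + ".consume");
        return parent.doConsume(outputVar);
    }
}

class RangeOp extends Op {
    RangeOp(List<String> trace) { super(trace); }
    String name() { return "Range"; }
    String doProduce() {
        trace.add("Range.doProduce");
        // The leaf defines its value inside its loop and pushes it to the parent.
        return "/* loop defining range_value_0 */\n" + consume("range_value_0");
    }
    String doConsume(String inputVar) { throw new UnsupportedOperationException("leaf"); }
}

class SortAggOp extends Op {
    final Op child;
    SortAggOp(List<String> trace, Op child) {
        super(trace);
        this.child = child;
        child.parent = this;
    }
    String name() { return "SortAgg"; }
    String doProduce() {
        trace.add("SortAgg.doProduceWithoutKeys");
        String loop = child.produce();            // child calls back into doConsume below
        return loop + consume("sortAgg_result");  // final consume is issued here, after the loop
    }
    String doConsume(String inputVar) {
        trace.add("SortAgg.doConsumeWithoutKeys");
        return "/* update aggregation buffer with " + inputVar + " */\n";
    }
}

class StageOp extends Op {
    final Op child;
    StageOp(List<String> trace, Op child) {
        super(trace);
        this.child = child;
        child.parent = this;
    }
    String name() { return "WholeStage"; }
    String doCodeGen() {
        trace.add("WholeStage.doCodeGen");
        return child.produce();
    }
    String doProduce() { return doCodeGen(); }
    String doConsume(String inputVar) {
        trace.add("WholeStage.doConsume");
        return "/* append row built from " + inputVar + " */\n";
    }
}
```

Calling new StageOp(trace, new SortAggOp(trace, new RangeOp(trace))).doCodeGen() fills trace in the same order as the arrows in the diagram, including the last consume() that is issued from doProduceWithoutKeys rather than from doConsumeWithoutKeys.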
The consume method of RangeExec (inherited from CodegenSupport)
```scala
final def consume(ctx: CodegenContext, outputVars: Seq[ExprCode], row: String = null): String = {
  val inputVarsCandidate =
    if (outputVars != null) {
      assert(outputVars.length == output.length)
      // outputVars will be used to generate the code for UnsafeRow, so we should copy them
      outputVars.map(_.copy())
    } else {
      assert(row != null, "outputVars and row cannot both be null.")
      ctx.currentVars = null
      ctx.INPUT_ROW = row
      output.zipWithIndex.map { case (attr, i) =>
        BoundReference(i, attr.dataType, attr.nullable).genCode(ctx)
      }
    }

  val inputVars = inputVarsCandidate match {
    case stream: Stream[ExprCode] => stream.force
    case other => other
  }

  val rowVar = prepareRowVar(ctx, row, outputVars)

  // Set up the `currentVars` in the codegen context, as we generate the code of `inputVars`
  // before calling `parent.doConsume`. We can't set up `INPUT_ROW`, because parent needs to
  // generate code of `rowVar` manually.
  ctx.currentVars = inputVars
  ctx.INPUT_ROW = null
  ctx.freshNamePrefix = parent.variablePrefix
  val evaluated = evaluateRequiredVariables(output, inputVars, parent.usedInputs)

  // Under certain conditions, we can put the logic to consume the rows of this operator into
  // another function. So we can prevent a generated function too long to be optimized by JIT.
  // The conditions:
  // 1. The config "spark.sql.codegen.splitConsumeFuncByOperator" is enabled.
  // 2. `inputVars` are all materialized. That is guaranteed to be true if the parent plan uses
  //    all variables in output (see `requireAllOutput`).
  // 3. The number of output variables must less than maximum number of parameters in Java method
  //    declaration.
  val confEnabled = conf.wholeStageSplitConsumeFuncByOperator
  val requireAllOutput = output.forall(parent.usedInputs.contains(_))
  val paramLength = CodeGenerator.calculateParamLength(output) + (if (row != null) 1 else 0)
  val consumeFunc = if (confEnabled && requireAllOutput &&
      CodeGenerator.isValidParamLength(paramLength)) {
    constructDoConsumeFunction(ctx, inputVars, row)
  } else {
    parent.doConsume(ctx, inputVars, rowVar)
  }
  s"""
     |${ctx.registerComment(s"CONSUME: ${parent.simpleString(conf.maxToStringFields)}")}
     |$evaluated
     |$consumeFunc
   """.stripMargin
}
```
Here the argument outputVars holds the value produced by RangeExec, i.e. range_value_0.
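val inputVarsCandidate and val inputVars
When outputVars is not null, a copy of outputVars is used directly as the input variables.
When outputVars is null and row is not null, the input is a variable of type InternalRow, so each column is read through the corresponding InternalRow accessor, which BoundReference.genCode generates. inputVars then forces inputVarsCandidate if it is a lazy Stream, so the code of every input variable is generated at this point.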
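The two branches of inputVarsCandidate can be sketched as follows, assuming a column value is represented only by the text of its Java expression. Attr and InputVarsSketch are hypothetical; the getter naming (getLong(0) for a long column) approximates what BoundReference.genCode emits for primitive types and is not Spark's exact output.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical stand-in for an output Attribute: a name and a primitive Java type.
final class Attr {
    final String name;
    final String javaType; // e.g. "long", "int", "double"
    Attr(String name, String javaType) { this.name = name; this.javaType = javaType; }
}

final class InputVarsSketch {
    // Mirrors the two branches of inputVarsCandidate; values are plain expression strings.
    static List<String> inputVars(List<Attr> output, List<String> outputVars, String row) {
        if (outputVars != null) {
            if (outputVars.size() != output.size()) {
                throw new IllegalArgumentException("outputVars must match output");
            }
            return new ArrayList<>(outputVars); // a copy, like outputVars.map(_.copy())
        }
        if (row == null) {
            throw new IllegalArgumentException("outputVars and row cannot both be null.");
        }
        // Otherwise read each column from the row by ordinal.
        List<String> vars = new ArrayList<>();
        for (int i = 0; i < output.size(); i++) {
            String t = output.get(i).javaType;
            String getter = "get" + Character.toUpperCase(t.charAt(0)) + t.substring(1);
            vars.add(row + "." + getter + "(" + i + ")");
        }
        return vars;
    }
}
```

For RangeExec only the first branch applies, since outputVars is [range_value_0].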
val rowVar = prepareRowVar(ctx, row, outputVars)
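This value is not used for RangeExec and is not covered here: rowVar is only passed to parent.doConsume in the non-split branch, while RangeExec's data flow goes to constructDoConsumeFunction.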
ctx.currentVars = inputVars
ctx.INPUT_ROW = null
ctx.freshNamePrefix = parent.variablePrefix
These assignments prepare the codegen context for what follows, starting with evaluateRequiredVariables:
val evaluated = evaluateRequiredVariables(output, inputVars, parent.usedInputs)
Here output is Range.getOutputAttrs, i.e. StructType(StructField("id", LongType, nullable = false) :: Nil).toAttributes.
inputVars is range_value_0.
parent.usedInputs is AttributeSet(children.flatMap(_.asInstanceOf[QueryPlan[PlanType]].output)), which is the same as output, i.e. Range.getOutputAttrs.
Because the ExprCode of range_value_0 carries no code to evaluate, evaluated is empty as well.
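The rule applied by evaluateRequiredVariables can be sketched as below, assuming an ExprCode can be reduced to a mutable code string plus a value name. Expr and EvaluateSketch are hypothetical, and attributes are matched by name here, whereas Spark's AttributeSet matches by expression ID.

```java
import java.util.List;
import java.util.Set;

// Hypothetical mutable stand-in for ExprCode: `code` computes the value, `value` names it.
final class Expr {
    String code;
    final String value;
    Expr(String code, String value) { this.code = code; this.value = value; }
}

final class EvaluateSketch {
    // Emit the code of every variable that still has code AND is used by the parent,
    // then clear it so the same code is not emitted twice.
    static String evaluateRequiredVariables(List<String> attributes, List<Expr> variables,
                                            Set<String> required) {
        StringBuilder sb = new StringBuilder();
        for (int i = 0; i < variables.size(); i++) {
            Expr ev = variables.get(i);
            if (!ev.code.isEmpty() && required.contains(attributes.get(i))) {
                sb.append(ev.code).append("\n");
                ev.code = "";
            }
        }
        return sb.toString();
    }
}
```

With range_value_0 modeled as new Expr("", "range_value_0"), the result is the empty string, matching the empty evaluated in this stage.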
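val confEnabled and val requireAllOutput
Both conditions are true here: spark.sql.codegen.splitConsumeFuncByOperator is enabled by default, and parent.usedInputs contains every attribute in output.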
val paramLength = CodeGenerator.calculateParamLength(output) + (if (row != null) 1 else 0)
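This computes how many JVM parameter slots the generated method would need: 1 for `this`, 2 for each long or double, 1 for any other type, 1 extra boolean slot for each nullable column, and 1 more if row is not null. range_value_0 is a non-nullable long and row is null, so the total is 1 + 2 = 3.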
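The slot counting can be sketched like this. ParamLengthSketch and Param are hypothetical names; the rules it encodes (1 slot for `this`, 2 for long/double, 1 otherwise, plus an isNull boolean for nullable inputs) and the 255-slot JVM limit are stated in the comments, and the code is an illustration rather than CodeGenerator itself.

```java
import java.util.List;

// Hypothetical description of one method parameter.
final class Param {
    final String javaType;
    final boolean nullable;
    Param(String javaType, boolean nullable) { this.javaType = javaType; this.nullable = nullable; }
}

final class ParamLengthSketch {
    // JVM limit on parameter slots of a method descriptor (includes `this` for instance methods).
    static final int MAX_JVM_METHOD_PARAMS_LENGTH = 255;

    static int calculateParamLength(List<Param> params) {
        int length = 1; // one slot for `this`
        for (Param p : params) {
            boolean wide = p.javaType.equals("long") || p.javaType.equals("double");
            length += wide ? 2 : 1;       // long/double occupy two slots
            if (p.nullable) length += 1;  // extra boolean isNull parameter
        }
        return length;
    }

    static boolean isValidParamLength(int length) {
        return length <= MAX_JVM_METHOD_PARAMS_LENGTH;
    }
}
```

For the single non-nullable long column id, calculateParamLength returns 3; consume would add 1 more only when row is not null.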
val consumeFunc: confEnabled && requireAllOutput && CodeGenerator.isValidParamLength(paramLength)
All three conditions hold (3 is well within the JVM limit of 255 parameter slots), so the data flow goes to constructDoConsumeFunction:
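The decision itself reduces to a conjunction of three booleans. SplitDecisionSketch is a hypothetical helper; the hard-coded 255 stands in for the default limit checked by CodeGenerator.isValidParamLength.

```java
import java.util.List;
import java.util.Set;

final class SplitDecisionSketch {
    // true: wrap parent.doConsume in its own method (constructDoConsumeFunction)
    // false: inline parent.doConsume into the current method
    static boolean splitConsumeFunc(boolean confEnabled, List<String> output,
                                    Set<String> parentUsedInputs, int paramLength) {
        boolean requireAllOutput = parentUsedInputs.containsAll(output); // parent uses every output column
        boolean validLength = paramLength <= 255;                       // fits in a JVM method signature
        return confEnabled && requireAllOutput && validLength;
    }
}
```

In this stage the inputs are (true, [id], {id}, 3), so the function is split; dropping any one condition falls back to calling parent.doConsume inline.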
```scala
private def constructDoConsumeFunction(
    ctx: CodegenContext,
    inputVars: Seq[ExprCode],
    row: String): String = {
  val (args, params, inputVarsInFunc) = constructConsumeParameters(ctx, output, inputVars, row)
  val rowVar = prepareRowVar(ctx, row, inputVarsInFunc)

  val doConsume = ctx.freshName("doConsume")
  ctx.currentVars = inputVarsInFunc
  ctx.INPUT_ROW = null

  val doConsumeFuncName = ctx.addNewFunction(doConsume,
    s"""
       | private void $doConsume(${params.mkString(", ")}) throws java.io.IOException {
       |   ${parent.doConsume(ctx, inputVarsInFunc, rowVar)}
       | }
     """.stripMargin)

  s"""
     | $doConsumeFuncName(${args.mkString(", ")});
   """.stripMargin
}
```
Here inputVars is range_value_0.
row is null.
val (args, params, inputVarsInFunc) = constructConsumeParameters(ctx, output, inputVars, row)
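This builds the call-site arguments, the formal parameters, and the ExprCode variables for those parameters, which are range_value_0, long sortAgg_expr_0_0, and sortAgg_expr_0_0 respectively.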
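A sketch of how the three lists line up, assuming one non-nullable long column and a freshName pattern of prefix_name_id. FreshNamer and ConsumeParamsSketch are hypothetical; in Spark the prefix comes from ctx.freshNamePrefix, which consume sets to parent.variablePrefix ("sortAgg" here).

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical fresh-name generator: prefix_name_id, with a counter per full name.
final class FreshNamer {
    private final Map<String, Integer> ids = new HashMap<>();
    String prefix = "";
    String freshName(String name) {
        String full = prefix.isEmpty() ? name : prefix + "_" + name;
        int id = ids.getOrDefault(full, 0);
        ids.put(full, id + 1);
        return full + "_" + id;
    }
}

final class ConsumeParamsSketch {
    final List<String> args = new ArrayList<>();      // actual arguments at the call site
    final List<String> params = new ArrayList<>();    // formal parameters in the signature
    final List<String> paramVars = new ArrayList<>(); // names the function body reads

    // Each input column becomes one parameter; a nullable column adds a boolean isNull.
    static ConsumeParamsSketch build(FreshNamer ctx, List<String> javaTypes, List<Boolean> nullables,
                                     List<String> values, List<String> isNulls, String row) {
        ConsumeParamsSketch r = new ConsumeParamsSketch();
        if (row != null) {
            r.args.add(row);
            r.params.add("InternalRow " + row);
        }
        for (int i = 0; i < values.size(); i++) {
            String paramName = ctx.freshName("expr_" + i);
            r.args.add(values.get(i));
            r.params.add(javaTypes.get(i) + " " + paramName);
            if (nullables.get(i)) {
                String isNull = ctx.freshName("exprIsNull_" + i);
                r.args.add(isNulls.get(i));
                r.params.add("boolean " + isNull);
            }
            r.paramVars.add(paramName);
        }
        return r;
    }
}
```

With prefix "sortAgg", the id column yields the argument range_value_0, the parameter long sortAgg_expr_0_0 and the body variable sortAgg_expr_0_0; a later freshName("doConsume") on the same namer yields sortAgg_doConsume_0.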
val rowVar = prepareRowVar(ctx, row, inputVarsInFunc)
This builds an UnsafeRow variable for the parent to consume; here row is null and inputVarsInFunc is sortAgg_expr_0_0.
```scala
private def prepareRowVar(ctx: CodegenContext, row: String, colVars: Seq[ExprCode]): ExprCode = {
  if (row != null) {
    ExprCode.forNonNullValue(JavaCode.variable(row, classOf[UnsafeRow]))
  } else {
    if (colVars.nonEmpty) {
      val colExprs = output.zipWithIndex.map { case (attr, i) =>
        BoundReference(i, attr.dataType, attr.nullable)
      }
      val evaluateInputs = evaluateVariables(colVars)
      // generate the code to create a UnsafeRow
      ctx.INPUT_ROW = row
      ctx.currentVars = colVars
      val ev = GenerateUnsafeProjection.createCode(ctx, colExprs, false)
      val code = code"""
                       |$evaluateInputs
                       |${ev.code}
                     """.stripMargin
      ExprCode(code, FalseLiteral, ev.value)
    } else {
      // There are no columns
      ExprCode.forNonNullValue(JavaCode.variable("unsafeRow", classOf[UnsafeRow]))
    }
  }
}
```
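val colExprs
This binds the output of the current physical plan to variable values by ordinal. For RangeExec, output is Range.getOutputAttrs, i.e. StructType(StructField("id", LongType, nullable = false) :: Nil).toAttributes, and ordinal 0 is bound to colVars(0). Inside constructDoConsumeFunction, colVars is inputVarsInFunc, so that variable is sortAgg_expr_0_0, the parameter that receives range_value_0.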
val evaluateInputs = evaluateVariables(colVars)
Variables produced by a computation rather than by direct assignment must be evaluated first; nothing needs to be evaluated here.
val ev = GenerateUnsafeProjection.createCode(ctx, colExprs, false)
This generates an UnsafeRow variable that holds the value produced by RangeExec (inside the function, sortAgg_expr_0_0).
The details are skipped here and will be analyzed in a later article.
ExprCode(code, FalseLiteral, ev.value)
This returns an ExprCode,
whose code is: range_mutableStateArray_0[0].reset();range_mutableStateArray_0[0].write(0, sortAgg_expr_0_0);
and whose ev.value is: range_mutableStateArray_0[0].getRow()
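The shape of that ExprCode can be reproduced with a string-level sketch. RowVarSketch is hypothetical and only covers fixed-width, non-null columns; the real GenerateUnsafeProjection also handles null bits and variable-length data. The writer name is passed in as an assumption rather than generated.

```java
import java.util.List;

final class RowVarSketch {
    // Returns {code, value}: the statements that fill the row and the expression naming it.
    // Simplified: no null bits, no variable-length columns.
    static String[] prepareRowVar(String row, List<String> colVars, String writer) {
        if (row != null) {
            return new String[] {"", row};   // an existing row is reused as-is
        }
        if (!colVars.isEmpty()) {
            StringBuilder code = new StringBuilder(writer + ".reset();");
            for (int i = 0; i < colVars.size(); i++) {
                code.append(writer).append(".write(").append(i).append(", ")
                    .append(colVars.get(i)).append(");");
            }
            return new String[] {code.toString(), writer + ".getRow()"};
        }
        return new String[] {"", "unsafeRow"}; // no columns at all
    }
}
```

Called with row = null, colVars = [sortAgg_expr_0_0] and writer range_mutableStateArray_0[0], it produces the code and value strings listed above.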
val doConsume = ctx.freshName("doConsume")
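This builds the function name, here sortAgg_doConsume_0; the sortAgg prefix comes from ctx.freshNamePrefix, which consume set to parent.variablePrefix.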
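val doConsumeFuncName
This registers the new function through ctx.addNewFunction; the function body is essentially the code returned by parent.doConsume(ctx, inputVarsInFunc, rowVar). constructDoConsumeFunction then returns a call to that function with args as the arguments, here sortAgg_doConsume_0(range_value_0);.
Note: rowVar is not used by SortAggregateExec, but it is used by WholeStageCodegenExec.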
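The split function and its call site can be sketched as two strings. DoConsumeFunctionSketch is hypothetical; it assumes the name returned by ctx.addNewFunction is just the fresh name, while in Spark it may be qualified when the function is placed in a nested class.

```java
import java.util.List;

final class DoConsumeFunctionSketch {
    // Returns {functionDefinition, callSite}. Only the call site is spliced into the loop;
    // the definition is registered separately (ctx.addNewFunction in Spark).
    static String[] build(String funcName, List<String> params, List<String> args,
                          String parentConsumeBody) {
        String definition =
            "private void " + funcName + "(" + String.join(", ", params)
            + ") throws java.io.IOException {\n"
            + "  " + parentConsumeBody + "\n"
            + "}\n";
        String call = funcName + "(" + String.join(", ", args) + ");";
        return new String[] {definition, call};
    }
}
```

For this stage, build("sortAgg_doConsume_0", [long sortAgg_expr_0_0], [range_value_0], body) yields the call site sortAgg_doConsume_0(range_value_0);.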
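Finally, the s"""...""" string at the end of consume assembles the code: the comment registered by ctx.registerComment(s"CONSUME: ${parent.simpleString(conf.maxToStringFields)}"), followed by $evaluated and then $consumeFunc.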