SPARK中的wholeStageCodegen全代码生成--以aggregate代码生成为例说起(4)

简介: SPARK中的wholeStageCodegen全代码生成--以aggregate代码生成为例说起(4)

背景

本文基于 SPARK 3.3.0

从一个unit test来探究SPARK Codegen的逻辑,


  test("SortAggregate should be included in WholeStageCodegen") {
    val df = spark.range(10).agg(max(col("id")), avg(col("id")))
    withSQLConf("spark.sql.test.forceApplySortAggregate" -> "true") {
      val plan = df.queryExecution.executedPlan
      assert(plan.exists(p =>
        p.isInstanceOf[WholeStageCodegenExec] &&
          p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[SortAggregateExec]))
      assert(df.collect() === Array(Row(9, 4.5)))
    }
  }

该sql形成的执行计划第一部分的全代码生成部分如下:

   WholeStageCodegen     
   +- *(1) SortAggregate(key=[], functions=[partial_max(id#0L), partial_avg(id#0L)], output=[max#12L, sum#13, count#14L])
      +- *(1) Range (0, 10, step=1, splits=2)    


分析


第一阶段wholeStageCodegen

第一阶段的代码生成涉及到SortAggregateExec和RangeExec的produce和consume方法,这里一一来分析:

第一阶段wholeStageCodegen数据流如下:

WholeStageCodegenExec      SortAggregateExec(partial)     RangeExec        
  =========================================================================
  -> execute()
      |
   doExecute() --------->   inputRDDs() -----------------> inputRDDs() 
      |
   doCodeGen()
      |
      +----------------->   produce()
                              |
                           doProduce() 
                              |
                           doProduceWithoutKeys() -------> produce()
                                                              |
                                                          doProduce()
                                                              |
                           doConsume()<------------------- consume()
                              |
                           doConsumeWithoutKeys()
                              |并不是doConsumeWithoutKeys调用consume,而是由doProduceWithoutKeys调用
   doConsume()  <--------  consume()

RangeExec的consume方法

final def consume(ctx: CodegenContext, outputVars: Seq[ExprCode], row: String = null): String = {
    val inputVarsCandidate =
      if (outputVars != null) {
        assert(outputVars.length == output.length)
        // outputVars will be used to generate the code for UnsafeRow, so we should copy them
        outputVars.map(_.copy())
      } else {
        assert(row != null, "outputVars and row cannot both be null.")
        ctx.currentVars = null
        ctx.INPUT_ROW = row
        output.zipWithIndex.map { case (attr, i) =>
          BoundReference(i, attr.dataType, attr.nullable).genCode(ctx)
        }
      }
    val inputVars = inputVarsCandidate match {
      case stream: Stream[ExprCode] => stream.force
      case other => other
    }
    val rowVar = prepareRowVar(ctx, row, outputVars)
    // Set up the `currentVars` in the codegen context, as we generate the code of `inputVars`
    // before calling `parent.doConsume`. We can't set up `INPUT_ROW`, because parent needs to
    // generate code of `rowVar` manually.
    ctx.currentVars = inputVars
    ctx.INPUT_ROW = null
    ctx.freshNamePrefix = parent.variablePrefix
    val evaluated = evaluateRequiredVariables(output, inputVars, parent.usedInputs)
    // Under certain conditions, we can put the logic to consume the rows of this operator into
    // another function. So we can prevent a generated function too long to be optimized by JIT.
    // The conditions:
    // 1. The config "spark.sql.codegen.splitConsumeFuncByOperator" is enabled.
    // 2. `inputVars` are all materialized. That is guaranteed to be true if the parent plan uses
    //    all variables in output (see `requireAllOutput`).
    // 3. The number of output variables must less than maximum number of parameters in Java method
    //    declaration.
    val confEnabled = conf.wholeStageSplitConsumeFuncByOperator
    val requireAllOutput = output.forall(parent.usedInputs.contains(_))
    val paramLength = CodeGenerator.calculateParamLength(output) + (if (row != null) 1 else 0)
    val consumeFunc = if (confEnabled && requireAllOutput
        && CodeGenerator.isValidParamLength(paramLength)) {
      constructDoConsumeFunction(ctx, inputVars, row)
    } else {
      parent.doConsume(ctx, inputVars, rowVar)
    }
    s"""
       |${ctx.registerComment(s"CONSUME: ${parent.simpleString(conf.maxToStringFields)}")}
       |$evaluated
       |$consumeFunc
     """.stripMargin
  }

其中参数outputVars为传入的rangeExc产生的value


val inputVarsCandidate =和val inputVars =

对于outputVars 不为空的情况下,直接copy复制一份outputVars值作为输入的变量

如果outputVars为空,而row不为空的情况下,则说明传入的是InteralRow类型的变量,需要调用InteralRow对应的方法获取对应的值


val rowVar = prepareRowVar(ctx, row, outputVars)

这部分在RangeExec中不会用到,这里不讲解(因为rangExec这里数据流会走向constructDoConsumeFunction这里)


ctx.currentVars = inputVars ctx.INPUT_ROW = null ctx.freshNamePrefix = parent.variablePrefix

这里是为了对evaluateRequiredVariables方法做铺垫,因为


val evaluated = evaluateRequiredVariables(output, inputVars, parent.usedInputs)

其中这里的output 为 Range.getOutputAttrs,即StructType(StructField("id", LongType, nullable = false) :: Nil).toAttributes

inputVars为range_value_0

parent.usedInputs为AttributeSet(children.flatMap(_.asInstanceOf[QueryPlan[PlanType]].output)),和output一样,也就是Range.getOutputAttrs,即StructType(StructField("id", LongType, nullable = false) :: Nil).toAttributes

因为inputVars的code为空,所以 evaluated对于该inputVars计算也为空


val confEnabled val requireAllOutput

这里的两个条件都是 TRUE


val paramLength = CodeGenerator.calculateParamLength(output) + (if (row != null) 1 else 0)

计算表达式的长度,对于LONG和DOUBLE类型长度为2,其他的为1,因为range_value_0是LONG类型,所以总的长度为3


val consumeFunc =confEnabled && requireAllOutput&& CodeGenerator.isValidParamLength(paramLength)

这里的三个条件都满足,所以数据流向constructDoConsumeFunction方法,如下:

private def constructDoConsumeFunction(
    ctx: CodegenContext,
    inputVars: Seq[ExprCode],
    row: String): String = {
  val (args, params, inputVarsInFunc) = constructConsumeParameters(ctx, output, inputVars, row)
  val rowVar = prepareRowVar(ctx, row, inputVarsInFunc)
  val doConsume = ctx.freshName("doConsume")
  ctx.currentVars = inputVarsInFunc
  ctx.INPUT_ROW = null
  val doConsumeFuncName = ctx.addNewFunction(doConsume,
    s"""
       | private void $doConsume(${params.mkString(", ")}) throws java.io.IOException {
       |   ${parent.doConsume(ctx, inputVarsInFunc, rowVar)}
       | }
     """.stripMargin)
  s"""
     | $doConsumeFuncName(${args.mkString(", ")});
   """.stripMargin
}

其中inputVars为range_value_0

row为NULL


val (args, params, inputVarsInFunc) = constructConsumeParameters(ctx, output, inputVars, row)

构造 函数实参,形参,以及形参ExprCode变量,分别为range_value_0,long sortAgg_expr_0_0,sortAgg_expr_0_0

val rowVar = prepareRowVar(ctx, row, inputVarsInFunc)

这里是构造UnsafeRow类型的变量便于传给parent进行消费 ,其中 row为NULL,inputVarsInFunc为sortAgg_expr_0_0

private def prepareRowVar(ctx: CodegenContext, row: String, colVars: Seq[ExprCode]): ExprCode = {
  if (row != null) {
    ExprCode.forNonNullValue(JavaCode.variable(row, classOf[UnsafeRow]))
  } else {
    if (colVars.nonEmpty) {
      val colExprs = output.zipWithIndex.map { case (attr, i) =>
        BoundReference(i, attr.dataType, attr.nullable)
      }
      val evaluateInputs = evaluateVariables(colVars)
      // generate the code to create a UnsafeRow
      ctx.INPUT_ROW = row
      ctx.currentVars = colVars
      val ev = GenerateUnsafeProjection.createCode(ctx, colExprs, false)
      val code = code"""
        |$evaluateInputs
        |${ev.code}
       """.stripMargin
      ExprCode(code, FalseLiteral, ev.value)
    } else {
      // There are no columns
      ExprCode.forNonNullValue(JavaCode.variable("unsafeRow", classOf[UnsafeRow]))
    }
  }
}

对于val colExprs =

这块是针对当前物理计划的输出(output)与变量值进行绑定,对于RangeExec来说output的值为Range.getOutputAttrs,即StructType(StructField(“id”, LongType, nullable = false) :: Nil).toAttributes ,而当前rangexec的对应的变量为range_value_0


val evaluateInputs = evaluateVariables(colVars)

对于不是直接赋值的变量,而是通过计算得到的变量,则需要进行提前计算,在这里不需要计算。


val ev = GenerateUnsafeProjection.createCode(ctx, colExprs, false)

这部分是产生UnsafeRow类型的变量,这个UnsafeRow类型的变量里包含了rangExec的产生的变量rang_value_0

里面具体的细节,这里先忽略,以后会有具体的文章分析。


ExprCode(code, FalseLiteral, ev.value)

这里就返回ExprCode类型的数据结构,

其中code如下:range_mutableStateArray_0[0].reset();range_mutableStateArray_0[0].write(0, sortAgg_expr_0_0);

ev.value如下:range_mutableStateArray_0[0].getRow()


val doConsume = ctx.freshName(“doConsume”)

构建函数的名字,这里为sortAgg_doConsume_0


val doConsumeFuncName =

构造函数调用,其中主要调用的是parent.doConsume(ctx, inputVarsInFunc, rowVar)方法,

注意:这里的rowVar在SortAggregateExec中不会被用到,但是在WholeStageCodeGenExec中会被用到


最后的s"""${ctx.registerComment(s"CONSUME: ${parent.simpleString(conf.maxToStringFields)}")}$evaluated 则是组装代码

相关文章
|
机器学习/深度学习 PyTorch 算法框架/工具
PyTorch并行与分布式(三)DataParallel原理、源码解析、举例实战
PyTorch并行与分布式(三)DataParallel原理、源码解析、举例实战
646 0
|
分布式计算 Hadoop 大数据
Spark 原理_总结介绍_案例编写 | 学习笔记
快速学习 Spark 原理_总结介绍_案例编写
103 0
Spark 原理_总结介绍_案例编写 | 学习笔记
|
SQL 分布式计算 Spark
SPARK中的wholeStageCodegen全代码生成--以aggregate代码生成为例说起(10)
SPARK中的wholeStageCodegen全代码生成--以aggregate代码生成为例说起(10)
164 0
|
分布式计算 Java Spark
SPARK中的wholeStageCodegen全代码生成--以aggregate代码生成为例说起(3)
SPARK中的wholeStageCodegen全代码生成--以aggregate代码生成为例说起(3)
218 0
|
分布式计算 Spark
SPARK中的wholeStageCodegen全代码生成--以aggregate代码生成为例说起(1)
SPARK中的wholeStageCodegen全代码生成--以aggregate代码生成为例说起(1)
184 0
|
SQL 分布式计算 Serverless
SPARK中的wholeStageCodegen全代码生成--以aggregate代码生成为例说起(6)
SPARK中的wholeStageCodegen全代码生成--以aggregate代码生成为例说起(6)
81 0
|
SQL 分布式计算 Spark
SPARK中的wholeStageCodegen全代码生成--以aggregate代码生成为例说起(8)
SPARK中的wholeStageCodegen全代码生成--以aggregate代码生成为例说起(8)
179 0
|
SQL 分布式计算 数据处理
SPARK中的wholeStageCodegen全代码生成--以aggregate代码生成为例说起(7)
SPARK中的wholeStageCodegen全代码生成--以aggregate代码生成为例说起(7)
108 0
|
缓存 分布式计算 Spark
SPARK中的wholeStageCodegen全代码生成--以aggregate代码生成为例说起(2)
SPARK中的wholeStageCodegen全代码生成--以aggregate代码生成为例说起(2)
128 0
|
SQL 分布式计算 Spark
SPARK中的wholeStageCodegen全代码生成--以aggregate代码生成为例说起(9)
SPARK中的wholeStageCodegen全代码生成--以aggregate代码生成为例说起(9)
106 0