SPARK中的wholeStageCodegen全代码生成--以aggregate代码生成为例说起(3)

简介: SPARK中的wholeStageCodegen全代码生成--以aggregate代码生成为例说起(3)

背景


本文基于 SPARK 3.3.0

从一个unit test来探究SPARK Codegen的逻辑,

  test("SortAggregate should be included in WholeStageCodegen") {
    val df = spark.range(10).agg(max(col("id")), avg(col("id")))
    withSQLConf("spark.sql.test.forceApplySortAggregate" -> "true") {
      val plan = df.queryExecution.executedPlan
      assert(plan.exists(p =>
        p.isInstanceOf[WholeStageCodegenExec] &&
          p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[SortAggregateExec]))
      assert(df.collect() === Array(Row(9, 4.5)))
    }
  }
该sql形成的执行计划第一部分的全代码生成部分如下:
WholeStageCodegen

± *(1) SortAggregate(key=[], functions=[partial_max(id#0L), partial_avg(id#0L)], output=[max#12L, sum#13, count#14L])

± *(1) Range (0, 10, step=1, splits=2)

分析


第一阶段wholeStageCodegen


第一阶段的代码生成涉及到SortAggregateExec和RangeExec的produce和consume方法,这里一一来分析:

第一阶段wholeStageCodegen数据流如下:


 WholeStageCodegenExec      SortAggregateExec(partial)     RangeExec        
  =========================================================================
  -> execute()
      |
   doExecute() --------->   inputRDDs() -----------------> inputRDDs() 
      |
   doCodeGen()
      |
      +----------------->   produce()
                              |
                           doProduce() 
                              |
                           doProduceWithoutKeys() -------> produce()
                                                              |
                                                          doProduce()
                                                              |
                           doConsume()<------------------- consume()
                              |
                           doConsumeWithoutKeys()
                              |并不是doConsumeWithoutKeys调用consume,而是由doProduceWithoutKeys调用
   doConsume()  <--------  consume()

RangeExec的produce

final def produce(ctx: CodegenContext, parent: CodegenSupport): String = executeQuery {
    this.parent = parent
    ctx.freshNamePrefix = variablePrefix
    s"""
       |${ctx.registerComment(s"PRODUCE: ${this.simpleString(conf.maxToStringFields)}")}
       |${doProduce(ctx)}
     """.stripMargin
  }

this.parent = parent以及ctx.freshNamePrefix = variablePrefix

设置parent 以便在做consume方法的时候能够获取到父节点的引用,这样才能调用到父节点的consume方法以便代码生成。

freshNamePrefix的设置是为了在生成对应的方法的时候,区分不同物理计划的方法,这样能防止方法名重复,避免编译代码时出错。

ctx.registerComment

这块是给java代码加上对应的注释,默认情况下是不会加上的,因为默认spark.sql.codegen.comments 是False

protected override def doProduce(ctx: CodegenContext): String = {
    val numOutput = metricTerm(ctx, "numOutputRows")
    val initTerm = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "initRange")
    val nextIndex = ctx.addMutableState(CodeGenerator.JAVA_LONG, "nextIndex")
    val value = ctx.freshName("value")
    val ev = ExprCode.forNonNullValue(JavaCode.variable(value, LongType))
    val BigInt = classOf[java.math.BigInteger].getName
    // Inline mutable state since not many Range operations in a task
    val taskContext = ctx.addMutableState("TaskContext", "taskContext",
      v => s"$v = TaskContext.get();", forceInline = true)
    val inputMetrics = ctx.addMutableState("InputMetrics", "inputMetrics",
      v => s"$v = $taskContext.taskMetrics().inputMetrics();", forceInline = true)
    // In order to periodically update the metrics without inflicting performance penalty, this
    // operator produces elements in batches. After a batch is complete, the metrics are updated
    // and a new batch is started.
    // In the implementation below, the code in the inner loop is producing all the values
    // within a batch, while the code in the outer loop is setting batch parameters and updating
    // the metrics.
    // Once nextIndex == batchEnd, it's time to progress to the next batch.
    val batchEnd = ctx.addMutableState(CodeGenerator.JAVA_LONG, "batchEnd")
    // How many values should still be generated by this range operator.
    val numElementsTodo = ctx.addMutableState(CodeGenerator.JAVA_LONG, "numElementsTodo")
    // How many values should be generated in the next batch.
    val nextBatchTodo = ctx.freshName("nextBatchTodo")
    // The default size of a batch, which must be positive integer
    val batchSize = 1000
    val initRangeFuncName = ctx.addNewFunction("initRange",
      s"""
        | private void initRange(int idx) {
        |   $BigInt index = $BigInt.valueOf(idx);
        |   $BigInt numSlice = $BigInt.valueOf(${numSlices}L);
        |   $BigInt numElement = $BigInt.valueOf(${numElements.toLong}L);
        |   $BigInt step = $BigInt.valueOf(${step}L);
        |   $BigInt start = $BigInt.valueOf(${start}L);
        |   long partitionEnd;
        |
        |   $BigInt st = index.multiply(numElement).divide(numSlice).multiply(step).add(start);
        |   if (st.compareTo($BigInt.valueOf(Long.MAX_VALUE)) > 0) {
        |     $nextIndex = Long.MAX_VALUE;
        |   } else if (st.compareTo($BigInt.valueOf(Long.MIN_VALUE)) < 0) {
        |     $nextIndex = Long.MIN_VALUE;
        |   } else {
        |     $nextIndex = st.longValue();
        |   }
        |   $batchEnd = $nextIndex;
        |
        |   $BigInt end = index.add($BigInt.ONE).multiply(numElement).divide(numSlice)
        |     .multiply(step).add(start);
        |   if (end.compareTo($BigInt.valueOf(Long.MAX_VALUE)) > 0) {
        |     partitionEnd = Long.MAX_VALUE;
        |   } else if (end.compareTo($BigInt.valueOf(Long.MIN_VALUE)) < 0) {
        |     partitionEnd = Long.MIN_VALUE;
        |   } else {
        |     partitionEnd = end.longValue();
        |   }
        |
        |   $BigInt startToEnd = $BigInt.valueOf(partitionEnd).subtract(
        |     $BigInt.valueOf($nextIndex));
        |   $numElementsTodo  = startToEnd.divide(step).longValue();
        |   if ($numElementsTodo < 0) {
        |     $numElementsTodo = 0;
        |   } else if (startToEnd.remainder(step).compareTo($BigInt.valueOf(0L)) != 0) {
        |     $numElementsTodo++;
        |   }
        | }
       """.stripMargin)
    val localIdx = ctx.freshName("localIdx")
    val localEnd = ctx.freshName("localEnd")
    val stopCheck = if (parent.needStopCheck) {
      s"""
         |if (shouldStop()) {
         |  $nextIndex = $value + ${step}L;
         |  $numOutput.add($localIdx + 1);
         |  $inputMetrics.incRecordsRead($localIdx + 1);
         |  return;
         |}
       """.stripMargin
    } else {
      "// shouldStop check is eliminated"
    }
    val loopCondition = if (limitNotReachedChecks.isEmpty) {
      "true"
    } else {
      limitNotReachedChecks.mkString(" && ")
    }
    s"""
      | // initialize Range
      | if (!$initTerm) {
      |   $initTerm = true;
      |   $initRangeFuncName(partitionIndex);
      | }
      |
      | while ($loopCondition) {
      |   if ($nextIndex == $batchEnd) {
      |     long $nextBatchTodo;
      |     if ($numElementsTodo > ${batchSize}L) {
      |       $nextBatchTodo = ${batchSize}L;
      |       $numElementsTodo -= ${batchSize}L;
      |     } else {
      |       $nextBatchTodo = $numElementsTodo;
      |       $numElementsTodo = 0;
      |       if ($nextBatchTodo == 0) break;
      |     }
      |     $batchEnd += $nextBatchTodo * ${step}L;
      |   }
      |
      |   int $localEnd = (int)(($batchEnd - $nextIndex) / ${step}L);
      |   for (int $localIdx = 0; $localIdx < $localEnd; $localIdx++) {
      |     long $value = ((long)$localIdx * ${step}L) + $nextIndex;
      |     ${consume(ctx, Seq(ev))}
      |     $stopCheck
      |   }
      |   $nextIndex = $batchEnd;
      |   $numOutput.add($localEnd);
      |   $inputMetrics.incRecordsRead($localEnd);
      |   $taskContext.killTaskIfInterrupted();
      | }
     """.stripMargin
  }

val numOutput = metricTerm(ctx, “numOutputRows”)

numOutput指标,用于记录输出的记录条数


val initTerm =以及val nextIndex =

initTerm用于标识该物理计划是够已经生成了代码,

nextIndex是用来产生rangeExec数据的逻辑索引,遍历数据

这两个参数也是类的成员变量,即全局变量


val value =和val ev =

这个ev值是用来表示rangExec生成的数据的,最终会被consume(ctx, Seq(ev))方法所调用

而其中的value变量则是会在long $value = ((long)$localIdx * ${step}L) + $nextIndex;被赋值,这样父节点才能进行消费


val taskContext =和val inputMetrics =

taskContext和inputMetrics也是全部变量,而且还有初始化变量,这种初始化方法将会在生成的类方法init中进行初始化,会形成一下代码:

range_taskContext_0 = TaskContext.get();
range_inputMetrics_0 = range_taskContext_0.taskMetrics().inputMetrics();

之所以会在init方法进行初始化是因为该初始化方法会被放入到mutableStateInitCodeArray类型的变量中,而mutableStateInitCode里的

数据,将会在WholeStageCodegenExec的ctx.initMutableStates()会被组装调用,被调用的代码如下:

 public void init(int index, scala.collection.Iterator[] inputs) {
       partitionIndex = index;
       this.inputs = inputs;
       ${ctx.initMutableStates()}
       ${ctx.initPartition()}
     }

val batchEnd =和val numElementsTodo

这两个变量也是生成类的成员变量,即全局变量


val nextBatchTodo =

这个变量是临时变量,会在遍历生成数据的时候用到


val initRangeFuncName =

就是RangeExec生成数据的逻辑了,每个物理计划都是不一样。这里忽略


最后的while ($loopCondition)

这部分就是根据每个分区的index不一样,生成不同的数据。

值得一提的是initRangeFuncName(partitionIndex)这部分中的partitionIndex变量,这个变量是生成的类的父类BufferedRowIterator中,

而partitionIndex变量的赋值也在init方法中,具体代码如下:

public void init(int index, scala.collection.Iterator[] inputs) {
        partitionIndex = index;
        this.inputs = inputs; 

consume(ctx, Seq(ev))

父节点进行消费rangeExec产生的数据,接下来会继续讲解


numOutput和inputMetrics和taskContext

numOutput 进行输出数据的增加

inputMetrics 在taskMetrics级别数据的增加

taskContext.killTaskIfInterrupted 用来判断当前任务是不是被kill了,如果被kill了直接抛出异常

相关文章
|
SQL 分布式计算 Spark
SPARK中的wholeStageCodegen全代码生成--以aggregate代码生成为例说起(10)
SPARK中的wholeStageCodegen全代码生成--以aggregate代码生成为例说起(10)
145 0
|
分布式计算 Spark
SPARK中的wholeStageCodegen全代码生成--GenerateUnsafeProjection.createCode说明
SPARK中的wholeStageCodegen全代码生成--GenerateUnsafeProjection.createCode说明
101 0
|
SQL 分布式计算 Spark
Spark中的WholeStageCodegenExec(全代码生成)
Spark中的WholeStageCodegenExec(全代码生成)
430 0
Spark中的WholeStageCodegenExec(全代码生成)
|
4月前
|
机器学习/深度学习 SQL 分布式计算
Apache Spark 的基本概念和在大数据分析中的应用
介绍 Apache Spark 的基本概念和在大数据分析中的应用
161 0
|
19天前
|
分布式计算 Hadoop 大数据
大数据技术与Python:结合Spark和Hadoop进行分布式计算
【4月更文挑战第12天】本文介绍了大数据技术及其4V特性,阐述了Hadoop和Spark在大数据处理中的作用。Hadoop提供分布式文件系统和MapReduce,Spark则为内存计算提供快速处理能力。通过Python结合Spark和Hadoop,可在分布式环境中进行数据处理和分析。文章详细讲解了如何配置Python环境、安装Spark和Hadoop,以及使用Python编写和提交代码到集群进行计算。掌握这些技能有助于应对大数据挑战。
|
4月前
|
机器学习/深度学习 SQL 分布式计算
介绍 Apache Spark 的基本概念和在大数据分析中的应用。
介绍 Apache Spark 的基本概念和在大数据分析中的应用。
|
7天前
|
分布式计算 大数据 数据处理
[AIGC大数据基础] Spark 入门
[AIGC大数据基础] Spark 入门
|
3月前
|
分布式计算 大数据 Java
Spark 大数据实战:基于 RDD 的大数据处理分析
Spark 大数据实战:基于 RDD 的大数据处理分析
129 0
|
4月前
|
分布式计算 监控 大数据
Spark RDD分区和数据分布:优化大数据处理
Spark RDD分区和数据分布:优化大数据处理