SPARK中的wholeStageCodegen全代码生成--以aggregate代码生成为例说起(3)

简介: SPARK中的wholeStageCodegen全代码生成--以aggregate代码生成为例说起(3)

背景


本文基于 SPARK 3.3.0

从一个unit test来探究SPARK Codegen的逻辑,

  test("SortAggregate should be included in WholeStageCodegen") {
    val df = spark.range(10).agg(max(col("id")), avg(col("id")))
    withSQLConf("spark.sql.test.forceApplySortAggregate" -> "true") {
      val plan = df.queryExecution.executedPlan
      assert(plan.exists(p =>
        p.isInstanceOf[WholeStageCodegenExec] &&
          p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[SortAggregateExec]))
      assert(df.collect() === Array(Row(9, 4.5)))
    }
  }
该sql形成的执行计划第一部分的全代码生成部分如下:
WholeStageCodegen

± *(1) SortAggregate(key=[], functions=[partial_max(id#0L), partial_avg(id#0L)], output=[max#12L, sum#13, count#14L])

± *(1) Range (0, 10, step=1, splits=2)

分析


第一阶段wholeStageCodegen


第一阶段的代码生成涉及到SortAggregateExec和RangeExec的produce和consume方法,这里一一来分析:

第一阶段wholeStageCodegen数据流如下:


 WholeStageCodegenExec      SortAggregateExec(partial)     RangeExec        
  =========================================================================
  -> execute()
      |
   doExecute() --------->   inputRDDs() -----------------> inputRDDs() 
      |
   doCodeGen()
      |
      +----------------->   produce()
                              |
                           doProduce() 
                              |
                           doProduceWithoutKeys() -------> produce()
                                                              |
                                                          doProduce()
                                                              |
                           doConsume()<------------------- consume()
                              |
                           doConsumeWithoutKeys()
                              |并不是doConsumeWithoutKeys调用consume,而是由doProduceWithoutKeys调用
   doConsume()  <--------  consume()

RangeExec的produce

final def produce(ctx: CodegenContext, parent: CodegenSupport): String = executeQuery {
    this.parent = parent
    ctx.freshNamePrefix = variablePrefix
    s"""
       |${ctx.registerComment(s"PRODUCE: ${this.simpleString(conf.maxToStringFields)}")}
       |${doProduce(ctx)}
     """.stripMargin
  }

this.parent = parent以及ctx.freshNamePrefix = variablePrefix

设置parent 以便在做consume方法的时候能够获取到父节点的引用,这样才能调用到父节点的consume方法以便代码生成。

freshNamePrefix的设置是为了在生成对应的方法的时候,区分不同物理计划的方法,这样能防止方法名重复,避免编译代码时出错。

ctx.registerComment

这块是给java代码加上对应的注释,默认情况下是不会加上的,因为默认spark.sql.codegen.comments 是False

protected override def doProduce(ctx: CodegenContext): String = {
    val numOutput = metricTerm(ctx, "numOutputRows")
    val initTerm = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "initRange")
    val nextIndex = ctx.addMutableState(CodeGenerator.JAVA_LONG, "nextIndex")
    val value = ctx.freshName("value")
    val ev = ExprCode.forNonNullValue(JavaCode.variable(value, LongType))
    val BigInt = classOf[java.math.BigInteger].getName
    // Inline mutable state since not many Range operations in a task
    val taskContext = ctx.addMutableState("TaskContext", "taskContext",
      v => s"$v = TaskContext.get();", forceInline = true)
    val inputMetrics = ctx.addMutableState("InputMetrics", "inputMetrics",
      v => s"$v = $taskContext.taskMetrics().inputMetrics();", forceInline = true)
    // In order to periodically update the metrics without inflicting performance penalty, this
    // operator produces elements in batches. After a batch is complete, the metrics are updated
    // and a new batch is started.
    // In the implementation below, the code in the inner loop is producing all the values
    // within a batch, while the code in the outer loop is setting batch parameters and updating
    // the metrics.
    // Once nextIndex == batchEnd, it's time to progress to the next batch.
    val batchEnd = ctx.addMutableState(CodeGenerator.JAVA_LONG, "batchEnd")
    // How many values should still be generated by this range operator.
    val numElementsTodo = ctx.addMutableState(CodeGenerator.JAVA_LONG, "numElementsTodo")
    // How many values should be generated in the next batch.
    val nextBatchTodo = ctx.freshName("nextBatchTodo")
    // The default size of a batch, which must be positive integer
    val batchSize = 1000
    val initRangeFuncName = ctx.addNewFunction("initRange",
      s"""
        | private void initRange(int idx) {
        |   $BigInt index = $BigInt.valueOf(idx);
        |   $BigInt numSlice = $BigInt.valueOf(${numSlices}L);
        |   $BigInt numElement = $BigInt.valueOf(${numElements.toLong}L);
        |   $BigInt step = $BigInt.valueOf(${step}L);
        |   $BigInt start = $BigInt.valueOf(${start}L);
        |   long partitionEnd;
        |
        |   $BigInt st = index.multiply(numElement).divide(numSlice).multiply(step).add(start);
        |   if (st.compareTo($BigInt.valueOf(Long.MAX_VALUE)) > 0) {
        |     $nextIndex = Long.MAX_VALUE;
        |   } else if (st.compareTo($BigInt.valueOf(Long.MIN_VALUE)) < 0) {
        |     $nextIndex = Long.MIN_VALUE;
        |   } else {
        |     $nextIndex = st.longValue();
        |   }
        |   $batchEnd = $nextIndex;
        |
        |   $BigInt end = index.add($BigInt.ONE).multiply(numElement).divide(numSlice)
        |     .multiply(step).add(start);
        |   if (end.compareTo($BigInt.valueOf(Long.MAX_VALUE)) > 0) {
        |     partitionEnd = Long.MAX_VALUE;
        |   } else if (end.compareTo($BigInt.valueOf(Long.MIN_VALUE)) < 0) {
        |     partitionEnd = Long.MIN_VALUE;
        |   } else {
        |     partitionEnd = end.longValue();
        |   }
        |
        |   $BigInt startToEnd = $BigInt.valueOf(partitionEnd).subtract(
        |     $BigInt.valueOf($nextIndex));
        |   $numElementsTodo  = startToEnd.divide(step).longValue();
        |   if ($numElementsTodo < 0) {
        |     $numElementsTodo = 0;
        |   } else if (startToEnd.remainder(step).compareTo($BigInt.valueOf(0L)) != 0) {
        |     $numElementsTodo++;
        |   }
        | }
       """.stripMargin)
    val localIdx = ctx.freshName("localIdx")
    val localEnd = ctx.freshName("localEnd")
    val stopCheck = if (parent.needStopCheck) {
      s"""
         |if (shouldStop()) {
         |  $nextIndex = $value + ${step}L;
         |  $numOutput.add($localIdx + 1);
         |  $inputMetrics.incRecordsRead($localIdx + 1);
         |  return;
         |}
       """.stripMargin
    } else {
      "// shouldStop check is eliminated"
    }
    val loopCondition = if (limitNotReachedChecks.isEmpty) {
      "true"
    } else {
      limitNotReachedChecks.mkString(" && ")
    }
    s"""
      | // initialize Range
      | if (!$initTerm) {
      |   $initTerm = true;
      |   $initRangeFuncName(partitionIndex);
      | }
      |
      | while ($loopCondition) {
      |   if ($nextIndex == $batchEnd) {
      |     long $nextBatchTodo;
      |     if ($numElementsTodo > ${batchSize}L) {
      |       $nextBatchTodo = ${batchSize}L;
      |       $numElementsTodo -= ${batchSize}L;
      |     } else {
      |       $nextBatchTodo = $numElementsTodo;
      |       $numElementsTodo = 0;
      |       if ($nextBatchTodo == 0) break;
      |     }
      |     $batchEnd += $nextBatchTodo * ${step}L;
      |   }
      |
      |   int $localEnd = (int)(($batchEnd - $nextIndex) / ${step}L);
      |   for (int $localIdx = 0; $localIdx < $localEnd; $localIdx++) {
      |     long $value = ((long)$localIdx * ${step}L) + $nextIndex;
      |     ${consume(ctx, Seq(ev))}
      |     $stopCheck
      |   }
      |   $nextIndex = $batchEnd;
      |   $numOutput.add($localEnd);
      |   $inputMetrics.incRecordsRead($localEnd);
      |   $taskContext.killTaskIfInterrupted();
      | }
     """.stripMargin
  }

val numOutput = metricTerm(ctx, “numOutputRows”)

numOutput指标,用于记录输出的记录条数


val initTerm =以及val nextIndex =

initTerm用于标识该物理计划是够已经生成了代码,

nextIndex是用来产生rangeExec数据的逻辑索引,遍历数据

这两个参数也是类的成员变量,即全局变量


val value =和val ev =

这个ev值是用来表示rangExec生成的数据的,最终会被consume(ctx, Seq(ev))方法所调用

而其中的value变量则是会在long $value = ((long)$localIdx * ${step}L) + $nextIndex;被赋值,这样父节点才能进行消费


val taskContext =和val inputMetrics =

taskContext和inputMetrics也是全部变量,而且还有初始化变量,这种初始化方法将会在生成的类方法init中进行初始化,会形成一下代码:

range_taskContext_0 = TaskContext.get();
range_inputMetrics_0 = range_taskContext_0.taskMetrics().inputMetrics();

之所以会在init方法进行初始化是因为该初始化方法会被放入到mutableStateInitCodeArray类型的变量中,而mutableStateInitCode里的

数据,将会在WholeStageCodegenExec的ctx.initMutableStates()会被组装调用,被调用的代码如下:

 public void init(int index, scala.collection.Iterator[] inputs) {
       partitionIndex = index;
       this.inputs = inputs;
       ${ctx.initMutableStates()}
       ${ctx.initPartition()}
     }

val batchEnd =和val numElementsTodo

这两个变量也是生成类的成员变量,即全局变量


val nextBatchTodo =

这个变量是临时变量,会在遍历生成数据的时候用到


val initRangeFuncName =

就是RangeExec生成数据的逻辑了,每个物理计划都是不一样。这里忽略


最后的while ($loopCondition)

这部分就是根据每个分区的index不一样,生成不同的数据。

值得一提的是initRangeFuncName(partitionIndex)这部分中的partitionIndex变量,这个变量是生成的类的父类BufferedRowIterator中,

而partitionIndex变量的赋值也在init方法中,具体代码如下:

public void init(int index, scala.collection.Iterator[] inputs) {
        partitionIndex = index;
        this.inputs = inputs; 

consume(ctx, Seq(ev))

父节点进行消费rangeExec产生的数据,接下来会继续讲解


numOutput和inputMetrics和taskContext

numOutput 进行输出数据的增加

inputMetrics 在taskMetrics级别数据的增加

taskContext.killTaskIfInterrupted 用来判断当前任务是不是被kill了,如果被kill了直接抛出异常

相关文章
|
SQL 分布式计算 Spark
SPARK中的wholeStageCodegen全代码生成--以aggregate代码生成为例说起(10)
SPARK中的wholeStageCodegen全代码生成--以aggregate代码生成为例说起(10)
186 0
|
分布式计算 Spark
SPARK中的wholeStageCodegen全代码生成--GenerateUnsafeProjection.createCode说明
SPARK中的wholeStageCodegen全代码生成--GenerateUnsafeProjection.createCode说明
134 0
|
SQL 分布式计算 Spark
Spark中的WholeStageCodegenExec(全代码生成)
Spark中的WholeStageCodegenExec(全代码生成)
570 0
Spark中的WholeStageCodegenExec(全代码生成)
|
1月前
|
分布式计算 大数据 Apache
ClickHouse与大数据生态集成:Spark & Flink 实战
【10月更文挑战第26天】在当今这个数据爆炸的时代,能够高效地处理和分析海量数据成为了企业和组织提升竞争力的关键。作为一款高性能的列式数据库系统,ClickHouse 在大数据分析领域展现出了卓越的能力。然而,为了充分利用ClickHouse的优势,将其与现有的大数据处理框架(如Apache Spark和Apache Flink)进行集成变得尤为重要。本文将从我个人的角度出发,探讨如何通过这些技术的结合,实现对大规模数据的实时处理和分析。
124 2
ClickHouse与大数据生态集成:Spark & Flink 实战
|
2月前
|
存储 分布式计算 算法
大数据-106 Spark Graph X 计算学习 案例:1图的基本计算、2连通图算法、3寻找相同的用户
大数据-106 Spark Graph X 计算学习 案例:1图的基本计算、2连通图算法、3寻找相同的用户
70 0
|
2月前
|
消息中间件 分布式计算 NoSQL
大数据-104 Spark Streaming Kafka Offset Scala实现Redis管理Offset并更新
大数据-104 Spark Streaming Kafka Offset Scala实现Redis管理Offset并更新
44 0
|
2月前
|
消息中间件 存储 分布式计算
大数据-103 Spark Streaming Kafka Offset管理详解 Scala自定义Offset
大数据-103 Spark Streaming Kafka Offset管理详解 Scala自定义Offset
100 0
|
1月前
|
SQL 机器学习/深度学习 分布式计算
Spark快速上手:揭秘大数据处理的高效秘密,让你轻松应对海量数据
【10月更文挑战第25天】本文全面介绍了大数据处理框架 Spark,涵盖其基本概念、安装配置、编程模型及实际应用。Spark 是一个高效的分布式计算平台,支持批处理、实时流处理、SQL 查询和机器学习等任务。通过详细的技术综述和示例代码,帮助读者快速掌握 Spark 的核心技能。
77 6
|
1月前
|
存储 分布式计算 Hadoop
数据湖技术:Hadoop与Spark在大数据处理中的协同作用
【10月更文挑战第27天】在大数据时代,数据湖技术凭借其灵活性和成本效益成为企业存储和分析大规模异构数据的首选。Hadoop和Spark作为数据湖技术的核心组件,通过HDFS存储数据和Spark进行高效计算,实现了数据处理的优化。本文探讨了Hadoop与Spark的最佳实践,包括数据存储、处理、安全和可视化等方面,展示了它们在实际应用中的协同效应。
104 2