SPARK中的wholeStageCodegen全代码生成--以aggregate代码生成为例说起(7)

简介: SPARK中的wholeStageCodegen全代码生成--以aggregate代码生成为例说起(7)

背景


本文基于 SPARK 3.3.0

从一个unit test来探究SPARK Codegen的逻辑,

  test("SortAggregate should be included in WholeStageCodegen") {
    val df = spark.range(10).agg(max(col("id")), avg(col("id")))
    withSQLConf("spark.sql.test.forceApplySortAggregate" -> "true") {
      val plan = df.queryExecution.executedPlan
      assert(plan.exists(p =>
        p.isInstanceOf[WholeStageCodegenExec] &&
          p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[SortAggregateExec]))
      assert(df.collect() === Array(Row(9, 4.5)))
    }
  }

该sql形成的执行计划第一部分的全代码生成部分如下:

   WholeStageCodegen     
   +- *(1) SortAggregate(key=[], functions=[partial_max(id#0L), partial_avg(id#0L)], output=[max#12L, sum#13, count#14L])
      +- *(1) Range (0, 10, step=1, splits=2)    

分析


第一阶段wholeStageCodegen

第一阶段的代码生成涉及到SortAggregateExec和RangeExec的produce和consume方法,这里一一来分析:

第一阶段wholeStageCodegen数据流如下:


 WholeStageCodegenExec      SortAggregateExec(partial)     RangeExec        
  =========================================================================
  -> execute()
      |
   doExecute() --------->   inputRDDs() -----------------> inputRDDs() 
      |
   doCodeGen()
      |
      +----------------->   produce()
                              |
                           doProduce() 
                              |
                           doProduceWithoutKeys() -------> produce()
                                                              |
                                                          doProduce()
                                                              |
                           doConsume()<------------------- consume()
                              |
                           doConsumeWithoutKeys()
                              |并不是doConsumeWithoutKeys调用consume,而是由doProduceWithoutKeys调用
   doConsume()  <--------  consume()


WholeStageCodegenExec的doExecute

override def doExecute(): RDD[InternalRow] = {
    val (ctx, cleanedSource) = doCodeGen()
    // try to compile and fallback if it failed
    val (_, compiledCodeStats) = try {
      CodeGenerator.compile(cleanedSource)
    } catch {
      case NonFatal(_) if !Utils.isTesting && conf.codegenFallback =>
        // We should already saw the error message
        logWarning(s"Whole-stage codegen disabled for plan (id=$codegenStageId):\n $treeString")
        return child.execute()
    }
    // Check if compiled code has a too large function
    if (compiledCodeStats.maxMethodCodeSize > conf.hugeMethodLimit) {
      logInfo(s"Found too long generated codes and JIT optimization might not work: " +
        s"the bytecode size (${compiledCodeStats.maxMethodCodeSize}) is above the limit " +
        s"${conf.hugeMethodLimit}, and the whole-stage codegen was disabled " +
        s"for this plan (id=$codegenStageId). To avoid this, you can raise the limit " +
        s"`${SQLConf.WHOLESTAGE_HUGE_METHOD_LIMIT.key}`:\n$treeString")
      return child.execute()
    }
    val references = ctx.references.toArray
    val durationMs = longMetric("pipelineTime")
    // Even though rdds is an RDD[InternalRow] it may actually be an RDD[ColumnarBatch] with
    // type erasure hiding that. This allows for the input to a code gen stage to be columnar,
    // but the output must be rows.
    val rdds = child.asInstanceOf[CodegenSupport].inputRDDs()
    assert(rdds.size <= 2, "Up to two input RDDs can be supported")
    if (rdds.length == 1) {
      rdds.head.mapPartitionsWithIndex { (index, iter) =>
        val (clazz, _) = CodeGenerator.compile(cleanedSource)
        val buffer = clazz.generate(references).asInstanceOf[BufferedRowIterator]
        buffer.init(index, Array(iter))
        new Iterator[InternalRow] {
          override def hasNext: Boolean = {
            val v = buffer.hasNext
            if (!v) durationMs += buffer.durationMs()
            v
          }
          override def next: InternalRow = buffer.next()
        }
      }
    } else {
      // Right now, we support up to two input RDDs.
      rdds.head.zipPartitions(rdds(1)) { (leftIter, rightIter) =>
        Iterator((leftIter, rightIter))
        // a small hack to obtain the correct partition index
      }.mapPartitionsWithIndex { (index, zippedIter) =>
        val (leftIter, rightIter) = zippedIter.next()
        val (clazz, _) = CodeGenerator.compile(cleanedSource)
        val buffer = clazz.generate(references).asInstanceOf[BufferedRowIterator]
        buffer.init(index, Array(leftIter, rightIter))
        new Iterator[InternalRow] {
          override def hasNext: Boolean = {
            val v = buffer.hasNext
            if (!v) durationMs += buffer.durationMs()
            v
          }
          override def next: InternalRow = buffer.next()
        }
      }
    }
  }

这里主要分两部分:


一部分是获取输入的RDD

一部分是进行代码生成以及处理

val (ctx, cleanedSource) = doCodeGen()

全代码生成


val (_, compiledCodeStats) =

代码进行编译,如果编译报错,则回退到原始的执行child.execute(),这里会先在driver端进行编译,如果代码生成有误能够提前发现


if (compiledCodeStats.maxMethodCodeSize > conf.hugeMethodLimit) {

如果代码生成的长度大于65535(默认值),则回退到原始的执行child.execute()


val rdds = child.asInstanceOf[CodegenSupport].inputRDDs()

获取对应的RDD,便于进行迭代,因为这里是SortAggregateExec所以最终调用到RangeExec的inputRDDs:


override def inputRDDs(): Seq[RDD[InternalRow]] = {
  val rdd = if (isEmptyRange) {
    new EmptyRDD[InternalRow](sparkContext)
  } else {
    sparkContext.parallelize(0 until numSlices, numSlices).map(i => InternalRow(i))
  }
  rdd :: Nil
}

rdds.head.mapPartitionsWithIndex { (index, iter) =>

对rdd进行迭代,对于当前的第一阶段全代码生成来说,该rdds不会被用到,因为数据是由RangExec产生的


val (clazz, _) = CodeGenerator.compile(cleanedSource)

executor端代码生成


val buffer = clazz.generate(references).asInstanceOf[BufferedRowIterator]

reference来自于val references = ctx.references.toArray,

对于当前来说,目前只是numOutputRows指标变量


val buffer = 和buffer.init(index, Array(iter))

代码初始化,对于当前第一阶段全代码生成来说,index会被用来进行产生数据,iter不会被用到,第二阶段中会把iter拿来进行数据处理

相关文章
|
3月前
|
存储 缓存 监控
GitlabCI学习笔记之五:GitLabRunner pipeline语法之cache
GitlabCI学习笔记之五:GitLabRunner pipeline语法之cache
|
3月前
LangChain 构建问题之定义zmng_query工具的具体实现函数如何解决
LangChain 构建问题之定义zmng_query工具的具体实现函数如何解决
29 0
|
机器学习/深度学习 PyTorch 算法框架/工具
PyTorch并行与分布式(三)DataParallel原理、源码解析、举例实战
PyTorch并行与分布式(三)DataParallel原理、源码解析、举例实战
774 0
|
分布式计算 Java Spark
SPARK中的wholeStageCodegen全代码生成--以aggregate代码生成为例说起(3)
SPARK中的wholeStageCodegen全代码生成--以aggregate代码生成为例说起(3)
245 0
|
SQL 分布式计算 Spark
SPARK中的wholeStageCodegen全代码生成--以aggregate代码生成为例说起(5)
SPARK中的wholeStageCodegen全代码生成--以aggregate代码生成为例说起(5)
178 0
|
SQL 分布式计算 Serverless
SPARK中的wholeStageCodegen全代码生成--以aggregate代码生成为例说起(6)
SPARK中的wholeStageCodegen全代码生成--以aggregate代码生成为例说起(6)
92 0
|
SQL 分布式计算 Spark
SPARK中的wholeStageCodegen全代码生成--以aggregate代码生成为例说起(4)
SPARK中的wholeStageCodegen全代码生成--以aggregate代码生成为例说起(4)
437 0
|
SQL 分布式计算 Spark
SPARK中的wholeStageCodegen全代码生成--以aggregate代码生成为例说起(10)
SPARK中的wholeStageCodegen全代码生成--以aggregate代码生成为例说起(10)
181 0
|
SQL 分布式计算 Spark
SPARK中的wholeStageCodegen全代码生成--以aggregate代码生成为例说起(9)
SPARK中的wholeStageCodegen全代码生成--以aggregate代码生成为例说起(9)
130 0
|
SQL 分布式计算 Spark
SPARK中的wholeStageCodegen全代码生成--以aggregate代码生成为例说起(8)
SPARK中的wholeStageCodegen全代码生成--以aggregate代码生成为例说起(8)
197 0