Background
This article is based on Spark 3.3.0.
We use a unit test as the entry point for exploring how Spark's codegen works:
test("SortAggregate should be included in WholeStageCodegen") {
  val df = spark.range(10).agg(max(col("id")), avg(col("id")))
  withSQLConf("spark.sql.test.forceApplySortAggregate" -> "true") {
    val plan = df.queryExecution.executedPlan
    assert(plan.exists(p =>
      p.isInstanceOf[WholeStageCodegenExec] &&
        p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[SortAggregateExec]))
    assert(df.collect() === Array(Row(9, 4.5)))
  }
}
The first whole-stage codegen stage of the physical plan produced by this query is:
WholeStageCodegen
+- *(1) SortAggregate(key=[], functions=[partial_max(id#0L), partial_avg(id#0L)], output=[max#12L, sum#13, count#14L])
   +- *(1) Range (0, 10, step=1, splits=2)
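Analysis
Stage 1: WholeStageCodegen
Code generation for the first stage involves the produce and consume methods of SortAggregateExec and RangeExec. We go through them one by one.
The data flow of the first WholeStageCodegen stage is as follows: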
 WholeStageCodegenExec      SortAggregateExec(partial)      RangeExec
 =========================================================================

 -> execute()
     |
  doExecute() --------->   inputRDDs() -----------------> inputRDDs()
     |
  doCodeGen()
     |
     +----------------->   produce()
                             |
                          doProduce()
                             |
                          doProduceWithoutKeys() -------> produce()
                                                             |
                                                          doProduce()
                                                             |
                          doConsume() <------------------ consume()
                             |
                          doConsumeWithoutKeys()
                             |  (consume() is called from doProduceWithoutKeys,
                             |   not from doConsumeWithoutKeys)
  doConsume() <---------  consume()
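The call flow in the diagram can be mimicked with a toy model. The following is a minimal sketch in plain Scala, not Spark's real classes: ToyCodegen, ToyRange, ToyAgg and ToyWholeStage are hypothetical names, and the emitted "code" is only a string. It illustrates that produce() travels down to the leaf, consume() travels back up, and the aggregate calls consume() from its produce side, after the child's loop.

```scala
object ToyStage {
  // Toy stand-in for the produce/consume protocol (hypothetical, not Spark's CodegenSupport).
  trait ToyCodegen {
    private var parentOp: ToyCodegen = null

    // Called by the parent: remember it, then emit the code that produces rows.
    final def produce(parent: ToyCodegen): String = {
      parentOp = parent
      doProduce()
    }

    // Called by this operator once a row is available: the parent emits code that uses it.
    final def consume(row: String): String = parentOp.doConsume(row)

    def doProduce(): String
    def doConsume(row: String): String
  }

  // Leaf: emits the loop that generates rows and inlines the parent's code in its body.
  final class ToyRange(n: Int) extends ToyCodegen {
    def doProduce(): String = {
      val body = consume("i")
      s"for (long i = 0; i < $n; i++) {\n$body\n}"
    }
    def doConsume(row: String): String =
      throw new UnsupportedOperationException("a leaf has no child to consume from")
  }

  // Aggregate without keys: buffer setup and the final consume() both sit on the produce side.
  final class ToyAgg(child: ToyCodegen) extends ToyCodegen {
    def doProduce(): String = {
      val loop = child.produce(this) // goes down to the leaf
      val output = consume("max")    // goes up to the whole-stage operator
      s"long max = Long.MIN_VALUE;\n$loop\n$output"
    }
    // Per-row update, inlined into the child's loop.
    def doConsume(row: String): String = s"  max = Math.max(max, $row);"
  }

  // Root of the stage: starts produce() and turns the final row into an append call.
  final class ToyWholeStage(child: ToyCodegen) extends ToyCodegen {
    def doCodeGen(): String = child.produce(this)
    def doProduce(): String = throw new UnsupportedOperationException
    def doConsume(row: String): String = s"append($row);"
  }

  def generate(n: Int): String =
    new ToyWholeStage(new ToyAgg(new ToyRange(n))).doCodeGen()
}
```

Calling ToyStage.generate(10) yields a loop over i whose body updates max, followed by a single append(max); statement after the loop.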
doExecute in WholeStageCodegenExec
override def doExecute(): RDD[InternalRow] = {
  val (ctx, cleanedSource) = doCodeGen()
  // try to compile and fallback if it failed
  val (_, compiledCodeStats) = try {
    CodeGenerator.compile(cleanedSource)
  } catch {
    case NonFatal(_) if !Utils.isTesting && conf.codegenFallback =>
      // We should already saw the error message
      logWarning(s"Whole-stage codegen disabled for plan (id=$codegenStageId):\n $treeString")
      return child.execute()
  }

  // Check if compiled code has a too large function
  if (compiledCodeStats.maxMethodCodeSize > conf.hugeMethodLimit) {
    logInfo(s"Found too long generated codes and JIT optimization might not work: " +
      s"the bytecode size (${compiledCodeStats.maxMethodCodeSize}) is above the limit " +
      s"${conf.hugeMethodLimit}, and the whole-stage codegen was disabled " +
      s"for this plan (id=$codegenStageId). To avoid this, you can raise the limit " +
      s"`${SQLConf.WHOLESTAGE_HUGE_METHOD_LIMIT.key}`:\n$treeString")
    return child.execute()
  }

  val references = ctx.references.toArray

  val durationMs = longMetric("pipelineTime")

  // Even though rdds is an RDD[InternalRow] it may actually be an RDD[ColumnarBatch] with
  // type erasure hiding that. This allows for the input to a code gen stage to be columnar,
  // but the output must be rows.
  val rdds = child.asInstanceOf[CodegenSupport].inputRDDs()
  assert(rdds.size <= 2, "Up to two input RDDs can be supported")
  if (rdds.length == 1) {
    rdds.head.mapPartitionsWithIndex { (index, iter) =>
      val (clazz, _) = CodeGenerator.compile(cleanedSource)
      val buffer = clazz.generate(references).asInstanceOf[BufferedRowIterator]
      buffer.init(index, Array(iter))
      new Iterator[InternalRow] {
        override def hasNext: Boolean = {
          val v = buffer.hasNext
          if (!v) durationMs += buffer.durationMs()
          v
        }
        override def next: InternalRow = buffer.next()
      }
    }
  } else {
    // Right now, we support up to two input RDDs.
    rdds.head.zipPartitions(rdds(1)) { (leftIter, rightIter) =>
      Iterator((leftIter, rightIter))
      // a small hack to obtain the correct partition index
    }.mapPartitionsWithIndex { (index, zippedIter) =>
      val (leftIter, rightIter) = zippedIter.next()
      val (clazz, _) = CodeGenerator.compile(cleanedSource)
      val buffer = clazz.generate(references).asInstanceOf[BufferedRowIterator]
      buffer.init(index, Array(leftIter, rightIter))
      new Iterator[InternalRow] {
        override def hasNext: Boolean = {
          val v = buffer.hasNext
          if (!v) durationMs += buffer.durationMs()
          v
        }
        override def next: InternalRow = buffer.next()
      }
    }
  }
}
There are two main parts here:
one obtains the input RDDs;
the other generates the code and uses it to process the data.
val (ctx, cleanedSource) = doCodeGen()
Runs whole-stage code generation and returns the codegen context together with the generated source.
val (_, compiledCodeStats) =
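Compiles the generated code. If compilation fails, execution falls back to the original path, child.execute() (only when not running under test and conf.codegenFallback is enabled, as the guard in the catch clause shows). This first compilation happens on the driver, so errors in the generated code are detected early.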
if (compiledCodeStats.maxMethodCodeSize > conf.hugeMethodLimit) {
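If the bytecode size of the largest compiled method exceeds conf.hugeMethodLimit (65535 by default), execution also falls back to the original path, child.execute().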
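The two fallback checks above can be sketched in isolation. This is a minimal illustration under stated assumptions, not Spark's implementation: FallbackSketch and runStage are hypothetical names, compile stands in for CodeGenerator.compile and returns only the largest method's bytecode size, and the Utils.isTesting condition is left out.

```scala
import scala.util.control.NonFatal

object FallbackSketch {
  // compile: hypothetical stand-in that returns the largest method's bytecode size.
  // generated / interpreted: the two execution paths, evaluated lazily (by-name).
  def runStage[A](compile: () => Int, hugeMethodLimit: Int, fallbackEnabled: Boolean)
                 (generated: => A)(interpreted: => A): A = {
    // Check 1: a compilation failure is swallowed only when fallback is enabled.
    val maxMethodSize: Option[Int] =
      try Some(compile())
      catch { case NonFatal(_) if fallbackEnabled => None }

    maxMethodSize match {
      // Check 2: use the generated path only when the largest method is within the limit.
      case Some(size) if size <= hugeMethodLimit => generated
      case _ => interpreted
    }
  }
}
```

With a limit of 65535, a stage whose largest method is 100 bytes takes the generated path, while one at 70000 bytes or one whose compilation throws takes the interpreted path. With fallbackEnabled set to false, the compilation error propagates to the caller.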
val rdds = child.asInstanceOf[CodegenSupport].inputRDDs()
Gets the input RDDs to iterate over. The child here is SortAggregateExec, which delegates to its own child, so the call ends up in RangeExec.inputRDDs:
override def inputRDDs(): Seq[RDD[InternalRow]] = {
  val rdd = if (isEmptyRange) {
    new EmptyRDD[InternalRow](sparkContext)
  } else {
    sparkContext.parallelize(0 until numSlices, numSlices).map(i => InternalRow(i))
  }
  rdd :: Nil
}
rdds.head.mapPartitionsWithIndex { (index, iter) =>
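Iterates over the partitions of the RDD. In this first codegen stage the rows of the RDD are never read, because the data is produced by RangeExec itself; the RDD only supplies the partitions and their indices.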
val (clazz, _) = CodeGenerator.compile(cleanedSource)
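Compiles the generated source again, this time on the executor side. CodeGenerator.compile caches compiled classes, so identical source is not recompiled within the same JVM.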
val buffer = clazz.generate(references).asInstanceOf[BufferedRowIterator]
references comes from val references = ctx.references.toArray.
In this stage it contains only the numOutputRows metric variable.
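The role of the references array can be shown with a small model. This is a sketch in plain Scala under stated assumptions: ToyMetric, ToyContext and ToyGeneratedIterator are hypothetical stand-ins for the metric object, the codegen context and the generated class. The idea it illustrates is that objects which cannot be written into source text are registered during code generation, passed to the generated class as an array, and looked up there by index with a cast.

```scala
object ReferencesSketch {
  // Hypothetical stand-in for a metric accumulator.
  final class ToyMetric {
    var value: Long = 0L
    def add(n: Long): Unit = value += n
  }

  // Hypothetical stand-in for the codegen context: collects objects, hands out indices.
  final class ToyContext {
    private val refs = scala.collection.mutable.ArrayBuffer.empty[AnyRef]
    def addReference(obj: AnyRef): Int = { refs += obj; refs.length - 1 }
    def references: Array[AnyRef] = refs.toArray
  }

  // Plays the role of the generated class: it only sees an untyped array,
  // so it casts the slot it was told about at generation time.
  final class ToyGeneratedIterator(references: Array[AnyRef]) {
    private val numOutputRows = references(0).asInstanceOf[ToyMetric]
    def emitRow(): Unit = numOutputRows.add(1L)
  }
}
```

Registering one metric gives it index 0, and every row emitted by the toy iterator is then counted on that same metric object.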
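val buffer = ... and buffer.init(index, Array(iter))
Instantiates and initializes the generated code. In this first codegen stage, index is used to generate the data and iter is not used. In the second stage, iter is the input that gets processed.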
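To make "index is used to generate the data" concrete, here is a simplified sketch of how a partition index can be turned into the bounds of that partition's slice of the range. RangeSliceSketch and sliceBounds are hypothetical names, and the sketch assumes a positive step, end >= start, and no clamping to the Long range; the real generated code handles those cases as well.

```scala
object RangeSliceSketch {
  // Returns the [start, end) bounds of the values produced by partition `index`.
  // Assumptions: step > 0, end >= start, results fit in a Long.
  def sliceBounds(index: Int, start: Long, end: Long, step: Long, numSlices: Int): (Long, Long) = {
    require(step > 0 && end >= start && numSlices > 0, "simplified sketch: unsupported input")
    // Total number of elements in the range, rounded up for a partial last step.
    val numElements = (BigInt(end) - BigInt(start) + BigInt(step) - 1) / BigInt(step)
    // Boundary i: i * numElements / numSlices elements in, scaled by step and shifted by start.
    def bound(i: Int): Long =
      (BigInt(i) * numElements / BigInt(numSlices) * BigInt(step) + BigInt(start)).toLong
    (bound(index), bound(index + 1))
  }
}
```

For the plan in this article, Range (0, 10, step=1, splits=2), partition 0 covers [0, 5) and partition 1 covers [5, 10).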