Whole-stage code generation (wholeStageCodegen) in Spark -- taking aggregate code generation as an example (9)


Background


This article is based on Spark 3.3.0.

We explore the logic of Spark codegen starting from a unit test:

  test("SortAggregate should be included in WholeStageCodegen") {
    val df = spark.range(10).agg(max(col("id")), avg(col("id")))
    withSQLConf("spark.sql.test.forceApplySortAggregate" -> "true") {
      val plan = df.queryExecution.executedPlan
      assert(plan.exists(p =>
        p.isInstanceOf[WholeStageCodegenExec] &&
          p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[SortAggregateExec]))
      assert(df.collect() === Array(Row(9, 4.5)))
    }
  }

The whole-stage-codegen part of the second stage of the execution plan produced by this query is as follows:


WholeStageCodegen
+- *(2) SortAggregate(key=[], functions=[max(id#0L), avg(id#0L)], output=[max(id)#5L, avg(id)#6])
   +- InputAdapter
      +- Exchange SinglePartition, ENSURE_REQUIREMENTS, [id=#13]
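
The plan above comes from df.queryExecution.executedPlan inside the test. It can also be printed interactively; a short sketch, assuming the same spark session and imports as the unit test (note that outside the test's withSQLConf("spark.sql.test.forceApplySortAggregate" -> "true") block Spark would normally pick HashAggregate instead of SortAggregate):

import org.apache.spark.sql.functions.{avg, col, max}

val df = spark.range(10).agg(max(col("id")), avg(col("id")))
df.explain("formatted")                   // physical plan, split per codegen stage, e.g. *(2)
println(df.queryExecution.executedPlan)   // the same plan as a SparkPlan tree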


Analysis


Stage-2 wholeStageCodegen


Code generation for the second stage involves the produce and consume methods of SortAggregateExec, ShuffleExchangeExec, and InputAdapter; we analyze them one by one.

The data flow of stage-2 wholeStageCodegen is as follows (a minimal Scala sketch of the produce/consume hand-shake appears right after the diagram):

 WholeStageCodegenExec      SortAggregateExec(Final)      InputAdapter       ShuffleExchangeExec        
  ====================================================================================
  -> execute()
      |
   doExecute() --------->   inputRDDs() -----------------> inputRDDs() -------> execute()
      |                                                                            |
   doCodeGen()                                                                  doExecute()     
      |                                                                            |
      +----------------->   produce()                                           ShuffledRowRDD
                              |
                           doProduce() 
                              |
                           doProduceWithoutKeys() -------> produce()
                                                              |
                                                          doProduce()
                                                              |
                           doConsume() <------------------- consume()
                              |
                           doConsumeWithoutKeys()
                              | note: consume is not called by doConsumeWithoutKeys, but by doProduceWithoutKeys
   doConsume()  <--------  consume()
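
To make this hand-shake concrete, here is a minimal, self-contained Scala sketch. It is NOT Spark's real CodegenSupport trait (which passes around fragments of generated Java source); it only mimics who calls whom, with plain rows standing in for code fragments:

object ProduceConsumeSketch {

  trait MiniCodegenSupport {
    // the operator that asked us to produce rows, remembered by produce()
    protected var parentOp: MiniCodegenSupport = null

    // the parent calls produce(); we remember the caller, then run our own loop
    final def produce(parent: MiniCodegenSupport): Unit = {
      parentOp = parent
      doProduce()
    }

    // operator-specific "drive the rows" logic
    protected def doProduce(): Unit

    // called once per emitted row; forwards it upward to the parent's doConsume
    final def consume(row: String): Unit = parentOp.doConsume(row)

    // operator-specific per-row logic
    def doConsume(row: String): Unit = ()
  }

  // stands in for InputAdapter: iterate the (shuffled) input, push each row up
  class MiniInputAdapter(rows: Seq[String]) extends MiniCodegenSupport {
    protected def doProduce(): Unit = rows.foreach(consume)
  }

  // stands in for SortAggregateExec(Final) without grouping keys: drive the child,
  // accumulate, then call consume() exactly once at the end -- mirroring the note
  // in the diagram that doProduceWithoutKeys (not doConsumeWithoutKeys) calls consume
  class MiniAggregate(child: MiniCodegenSupport) extends MiniCodegenSupport {
    private var count = 0
    protected def doProduce(): Unit = {
      child.produce(this)
      consume(s"aggregated $count input rows")
    }
    override def doConsume(row: String): Unit = count += 1
  }

  // stands in for WholeStageCodegenExec: kicks everything off and appends results
  class MiniWholeStage(child: MiniCodegenSupport) extends MiniCodegenSupport {
    protected def doProduce(): Unit = child.produce(this)
    override def doConsume(row: String): Unit = println(s"append($row)")
    def run(): Unit = doProduce()
  }

  def main(args: Array[String]): Unit =
    new MiniWholeStage(new MiniAggregate(new MiniInputAdapter(Seq("r0", "r1", "r2")))).run()
    // prints: append(aggregated 3 input rows)
}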

doConsume of SortAggregateExec(Final)


SortAggregateExec's doConsume method ultimately calls doConsumeWithoutKeys(ctx, input), where input is ArrayBuffer(ExprCode(,sortAgg_exprIsNull_0_0,sortAgg_expr_0_0), ExprCode(,sortAgg_exprIsNull_1_0,sortAgg_expr_1_0), ExprCode(,sortAgg_exprIsNull_2_0,sortAgg_expr_2_0)).

Differences from SortAggregateExec(Partial):


updateExprs differs

val updateExprs = aggregateExpressions.map { e =>
  e.mode match {
    case Partial | Complete =>
      e.aggregateFunction.asInstanceOf[DeclarativeAggregate].updateExpressions
    case PartialMerge | Final =>
      e.aggregateFunction.asInstanceOf[DeclarativeAggregate].mergeExpressions
  }
}

Because this is the Final stage, the update expressions are the mergeExpressions, i.e.


protected def getMergeExpressions = Seq(
  /* sum = */ Add(sum.left, sum.right, useAnsiAdd),
  /* count = */ count.left + count.right
)

The corresponding generated ExprCode instances are:

List(ExprCode(boolean sortAgg_isNull_13 = true;
     double sortAgg_value_13 = -1.0;
     if (!sortAgg_bufIsNull_1) {
     if (!sortAgg_exprIsNull_1_0) {
           sortAgg_isNull_13 = false; // resultCode could change nullability.
           sortAgg_value_13 = sortAgg_bufValue_1 + sortAgg_expr_1_0;
     }
     },sortAgg_isNull_13,sortAgg_value_13),
 ExprCode(boolean sortAgg_isNull_16 = true;
     long sortAgg_value_16 = -1L;
     if (!sortAgg_bufIsNull_2) {
     if (!sortAgg_exprIsNull_2_0) {
           sortAgg_isNull_16 = false; // resultCode could change nullability.
           sortAgg_value_16 = sortAgg_bufValue_2 + sortAgg_expr_2_0;
     }
     },sortAgg_isNull_16,sortAgg_value_16))


As you can see, the update for sortAgg_value_16 (the COUNT value) is "+ sortAgg_expr_2_0" here, whereas in the Partial stage it was "+ 1", because the input of this stage is the result buffer produced by each task's local aggregation.
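
As a tiny standalone illustration of this difference (plain Scala, independent of Spark; all names here are made up): the Partial stage builds one (sum, count) buffer per partition with "count + 1", and the Final stage merges those buffers with "count.left + count.right", which is exactly the "+ sortAgg_expr_2_0" above.

object AvgTwoPhase {
  // aggregation buffer for AVG: (sum, count), like Average's aggBufferAttributes
  type Buf = (Double, Long)

  // Partial stage: per-row update, count goes up by 1
  def update(buf: Buf, id: Long): Buf = (buf._1 + id, buf._2 + 1L)

  // Final stage: merge two partial buffers, the partial counts are added together
  def merge(a: Buf, b: Buf): Buf = (a._1 + b._1, a._2 + b._2)

  def main(args: Array[String]): Unit = {
    // pretend the 10 rows of spark.range(10) arrived in two shuffle partitions
    val partial1 = (0L until 5L).foldLeft((0.0, 0L))(update)   // (10.0, 5)
    val partial2 = (5L until 10L).foldLeft((0.0, 0L))(update)  // (35.0, 5)
    val (sum, count) = merge(partial1, partial2)               // (45.0, 10)
    println(sum / count)                                       // 4.5, matching Row(9, 4.5) in the test
  }
}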

The final results sortAgg_isNull_13/sortAgg_value_13 and sortAgg_isNull_16/sortAgg_value_16 are still assigned back to the global buffer variables:


  sortAgg_bufIsNull_1 = sortAgg_isNull_13;
  sortAgg_bufValue_1 = sortAgg_value_13;
  sortAgg_bufIsNull_2 = sortAgg_isNull_16;
  sortAgg_bufValue_2 = sortAgg_value_16;

consume of SortAggregateExec(Final)

consume(ctx, resultVars)

where resultVars is ExprCode(,sortAgg_isNull_4,sortAgg_value_4), which holds the final AVG result value.

The rest of the data flow is the same as before:


val rowVar = prepareRowVar(ctx, row, outputVars)

which ultimately produces the following code (consistent with before):

     sortAgg_mutableStateArray_0[0].reset();
     sortAgg_mutableStateArray_0[0].zeroOutNullBytes();
     if (sortAgg_bufIsNull_0) {
       sortAgg_mutableStateArray_0[0].setNullAt(0);
     } else {
       sortAgg_mutableStateArray_0[0].write(0, sortAgg_bufValue_0);
     }
     if (sortAgg_isNull_4) {
       sortAgg_mutableStateArray_0[0].setNullAt(1);
     } else {
       sortAgg_mutableStateArray_0[0].write(1, sortAgg_value_4);
     }
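
The sortAgg_mutableStateArray_0[0] above is an UnsafeRowWriter. Its calls can also be exercised on their own; a minimal sketch (assuming spark-catalyst is on the classpath), building the same two-column result row by hand:

import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter

// same constructor arguments as the generated code: 2 fields, 0 extra variable-length bytes
val writer = new UnsafeRowWriter(2, 0)
writer.reset()
writer.zeroOutNullBytes()
writer.write(0, 9L)     // max(id)
writer.write(1, 4.5d)   // avg(id)
val row = writer.getRow // an UnsafeRow holding (9, 4.5)
println((row.getLong(0), row.getDouble(1)))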

doConsume of WholeStageCodegenExec

The data flow is the same as before; the final generated code is:

append((sortAgg_mutableStateArray_0[0].getRow()));


The complete code generated for the stage-2 wholeStageCodegen is as follows:

/* 001 */ public Object generate(Object[] references) {
/* 002 */   return new GeneratedIteratorForCodegenStage2(references);
/* 003 */ }
/* 004 */
/* 005 */ // codegenStageId=2
/* 006 */ final class GeneratedIteratorForCodegenStage2 extends org.apache.spark.sql.execution.BufferedRowIterator {
/* 007 */   private Object[] references;
/* 008 */   private scala.collection.Iterator[] inputs;
/* 009 */   private boolean sortAgg_initAgg_0;
/* 010 */   private boolean sortAgg_bufIsNull_0;
/* 011 */   private long sortAgg_bufValue_0;
/* 012 */   private boolean sortAgg_bufIsNull_1;
/* 013 */   private double sortAgg_bufValue_1;
/* 014 */   private boolean sortAgg_bufIsNull_2;
/* 015 */   private long sortAgg_bufValue_2;
/* 016 */   private scala.collection.Iterator inputadapter_input_0;
/* 017 */   private boolean sortAgg_sortAgg_isNull_10_0;
/* 018 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[] sortAgg_mutableStateArray_0 = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[1];
/* 019 */
/* 020 */   public GeneratedIteratorForCodegenStage2(Object[] references) {
/* 021 */     this.references = references;
/* 022 */   }
/* 023 */
/* 024 */   public void init(int index, scala.collection.Iterator[] inputs) {
/* 025 */     partitionIndex = index;
/* 026 */     this.inputs = inputs;
/* 027 */
/* 028 */     inputadapter_input_0 = inputs[0];
/* 029 */     sortAgg_mutableStateArray_0[0] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(2, 0);
/* 030 */
/* 031 */   }
/* 032 */
/* 033 */   private void sortAgg_doAggregate_max_0(long sortAgg_expr_0_0, boolean sortAgg_exprIsNull_0_0) throws java.io.IOException {
/* 034 */     sortAgg_sortAgg_isNull_10_0 = true;
/* 035 */     long sortAgg_value_10 = -1L;
/* 036 */
/* 037 */     if (!sortAgg_bufIsNull_0 && (sortAgg_sortAgg_isNull_10_0 ||
/* 038 */         sortAgg_bufValue_0 > sortAgg_value_10)) {
/* 039 */       sortAgg_sortAgg_isNull_10_0 = false;
/* 040 */       sortAgg_value_10 = sortAgg_bufValue_0;
/* 041 */     }
/* 042 */
/* 043 */     if (!sortAgg_exprIsNull_0_0 && (sortAgg_sortAgg_isNull_10_0 ||
/* 044 */         sortAgg_expr_0_0 > sortAgg_value_10)) {
/* 045 */       sortAgg_sortAgg_isNull_10_0 = false;
/* 046 */       sortAgg_value_10 = sortAgg_expr_0_0;
/* 047 */     }
/* 048 */
/* 049 */     sortAgg_bufIsNull_0 = sortAgg_sortAgg_isNull_10_0;
/* 050 */     sortAgg_bufValue_0 = sortAgg_value_10;
/* 051 */   }
/* 052 */
/* 053 */   private void sortAgg_doAggregateWithoutKey_0() throws java.io.IOException {
/* 054 */     // initialize aggregation buffer
/* 055 */     sortAgg_bufIsNull_0 = true;
/* 056 */     sortAgg_bufValue_0 = -1L;
/* 057 */     sortAgg_bufIsNull_1 = false;
/* 058 */     sortAgg_bufValue_1 = 0.0D;
/* 059 */     sortAgg_bufIsNull_2 = false;
/* 060 */     sortAgg_bufValue_2 = 0L;
/* 061 */
/* 062 */     while ( inputadapter_input_0.hasNext()) {
/* 063 */       InternalRow inputadapter_row_0 = (InternalRow) inputadapter_input_0.next();
/* 064 */
/* 065 */       boolean inputadapter_isNull_0 = inputadapter_row_0.isNullAt(0);
/* 066 */       long inputadapter_value_0 = inputadapter_isNull_0 ?
/* 067 */       -1L : (inputadapter_row_0.getLong(0));
/* 068 */       boolean inputadapter_isNull_1 = inputadapter_row_0.isNullAt(1);
/* 069 */       double inputadapter_value_1 = inputadapter_isNull_1 ?
/* 070 */       -1.0 : (inputadapter_row_0.getDouble(1));
/* 071 */       boolean inputadapter_isNull_2 = inputadapter_row_0.isNullAt(2);
/* 072 */       long inputadapter_value_2 = inputadapter_isNull_2 ?
/* 073 */       -1L : (inputadapter_row_0.getLong(2));
/* 074 */
/* 075 */       sortAgg_doConsume_0(inputadapter_row_0, inputadapter_value_0, inputadapter_isNull_0, inputadapter_value_1, inputadapter_isNull_1, inputadapter_value_2, inputadapter_isNull_2);
/* 076 */       // shouldStop check is eliminated
/* 077 */     }
/* 078 */
/* 079 */   }
/* 080 */
/* 081 */   protected void processNext() throws java.io.IOException {
/* 082 */     while (!sortAgg_initAgg_0) {
/* 083 */       sortAgg_initAgg_0 = true;
/* 084 */       sortAgg_doAggregateWithoutKey_0();
/* 085 */
/* 086 */       // output the result
/* 087 */       boolean sortAgg_isNull_6 = sortAgg_bufIsNull_2;
/* 088 */       double sortAgg_value_6 = -1.0;
/* 089 */       if (!sortAgg_bufIsNull_2) {
/* 090 */         sortAgg_value_6 = (double) sortAgg_bufValue_2;
/* 091 */       }
/* 092 */       boolean sortAgg_isNull_4 = false;
/* 093 */       double sortAgg_value_4 = -1.0;
/* 094 */       if (sortAgg_isNull_6 || sortAgg_value_6 == 0) {
/* 095 */         sortAgg_isNull_4 = true;
/* 096 */       } else {
/* 097 */         if (sortAgg_bufIsNull_1) {
/* 098 */           sortAgg_isNull_4 = true;
/* 099 */         } else {
/* 100 */           sortAgg_value_4 = (double)(sortAgg_bufValue_1 / sortAgg_value_6);
/* 101 */         }
/* 102 */       }
/* 103 */
/* 104 */       ((org.apache.spark.sql.execution.metric.SQLMetric) references[0] /* numOutputRows */).add(1);
/* 105 */       sortAgg_mutableStateArray_0[0].reset();
/* 106 */
/* 107 */       sortAgg_mutableStateArray_0[0].zeroOutNullBytes();
/* 108 */
/* 109 */       if (sortAgg_bufIsNull_0) {
/* 110 */         sortAgg_mutableStateArray_0[0].setNullAt(0);
/* 111 */       } else {
/* 112 */         sortAgg_mutableStateArray_0[0].write(0, sortAgg_bufValue_0);
/* 113 */       }
/* 114 */
/* 115 */       if (sortAgg_isNull_4) {
/* 116 */         sortAgg_mutableStateArray_0[0].setNullAt(1);
/* 117 */       } else {
/* 118 */         sortAgg_mutableStateArray_0[0].write(1, sortAgg_value_4);
/* 119 */       }
/* 120 */       append((sortAgg_mutableStateArray_0[0].getRow()));
/* 121 */     }
/* 122 */   }
/* 123 */
/* 124 */   private void sortAgg_doConsume_0(InternalRow inputadapter_row_0, long sortAgg_expr_0_0, boolean sortAgg_exprIsNull_0_0, double sortAgg_expr_1_0, boolean sortAgg_exprIsNull_1_0, long sortAgg_expr_2_0, boolean sortAgg_exprIsNull_2_0) throws java.io.IOException {
/* 125 */     // do aggregate
/* 126 */     // common sub-expressions
/* 127 */
/* 128 */     // evaluate aggregate functions and update aggregation buffers
/* 129 */     sortAgg_doAggregate_max_0(sortAgg_expr_0_0, sortAgg_exprIsNull_0_0);
/* 130 */     sortAgg_doAggregate_avg_0(sortAgg_exprIsNull_1_0, sortAgg_expr_1_0, sortAgg_exprIsNull_2_0, sortAgg_expr_2_0);
/* 131 */
/* 132 */   }
/* 133 */
/* 134 */   private void sortAgg_doAggregate_avg_0(boolean sortAgg_exprIsNull_1_0, double sortAgg_expr_1_0, boolean sortAgg_exprIsNull_2_0, long sortAgg_expr_2_0) throws java.io.IOException {
/* 135 */     boolean sortAgg_isNull_13 = true;
/* 136 */     double sortAgg_value_13 = -1.0;
/* 137 */
/* 138 */     if (!sortAgg_bufIsNull_1) {
/* 139 */       if (!sortAgg_exprIsNull_1_0) {
/* 140 */         sortAgg_isNull_13 = false; // resultCode could change nullability.
/* 141 */
/* 142 */         sortAgg_value_13 = sortAgg_bufValue_1 + sortAgg_expr_1_0;
/* 143 */
/* 144 */       }
/* 145 */
/* 146 */     }
/* 147 */     boolean sortAgg_isNull_16 = true;
/* 148 */     long sortAgg_value_16 = -1L;
/* 149 */
/* 150 */     if (!sortAgg_bufIsNull_2) {
/* 151 */       if (!sortAgg_exprIsNull_2_0) {
/* 152 */         sortAgg_isNull_16 = false; // resultCode could change nullability.
/* 153 */
/* 154 */         sortAgg_value_16 = sortAgg_bufValue_2 + sortAgg_expr_2_0;
/* 155 */
/* 156 */       }
/* 157 */
/* 158 */     }
/* 159 */
/* 160 */     sortAgg_bufIsNull_1 = sortAgg_isNull_13;
/* 161 */     sortAgg_bufValue_1 = sortAgg_value_13;
/* 162 */
/* 163 */     sortAgg_bufIsNull_2 = sortAgg_isNull_16;
/* 164 */     sortAgg_bufValue_2 = sortAgg_value_16;
/* 165 */   }
/* 166 */
/* 167 */ }
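
To reproduce a dump like the one above outside the unit test, Spark's debug helpers can print the generated source for every whole-stage subtree (a short sketch, reusing the df from the test):

import org.apache.spark.sql.execution.debug._   // adds debugCodegen() to Dataset

df.debugCodegen()       // prints each WholeStageCodegen subtree together with its generated Java source
df.explain("codegen")   // similar output via Dataset.explain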

