SPARK中的wholeStageCodegen全代码生成--以aggregate代码生成为例说起(5)

简介: SPARK中的wholeStageCodegen全代码生成--以aggregate代码生成为例说起(5)

背景


本文基于 SPARK 3.3.0

从一个unit test来探究SPARK Codegen的逻辑,

  test("SortAggregate should be included in WholeStageCodegen") {
    val df = spark.range(10).agg(max(col("id")), avg(col("id")))
    withSQLConf("spark.sql.test.forceApplySortAggregate" -> "true") {
      val plan = df.queryExecution.executedPlan
      assert(plan.exists(p =>
        p.isInstanceOf[WholeStageCodegenExec] &&
          p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[SortAggregateExec]))
      assert(df.collect() === Array(Row(9, 4.5)))
    }
  }
该sql形成的执行计划第一部分的全代码生成部分如下:

WholeStageCodegen

± *(1) SortAggregate(key=[], functions=[partial_max(id#0L), partial_avg(id#0L)], output=[max#12L, sum#13, count#14L])

± *(1) Range (0, 10, step=1, splits=2)

分析


第一阶段wholeStageCodegen

第一阶段的代码生成涉及到SortAggregateExec和RangeExec的produce和consume方法,这里一一来分析:

第一阶段wholeStageCodegen数据流如下:

 WholeStageCodegenExec      SortAggregateExec(partial)     RangeExec        
  =========================================================================
  -> execute()
      |
   doExecute() --------->   inputRDDs() -----------------> inputRDDs() 
      |
   doCodeGen()
      |
      +----------------->   produce()
                              |
                           doProduce() 
                              |
                           doProduceWithoutKeys() -------> produce()
                                                              |
                                                          doProduce()
                                                              |
                           doConsume()<------------------- consume()
                              |
                           doConsumeWithoutKeys()
                              |并不是doConsumeWithoutKeys调用consume,而是由doProduceWithoutKeys调用
   doConsume()  <--------  consume()

SortAggregateExec(Partial)的doConsume方法

override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
    if (groupingExpressions.isEmpty) {
      doConsumeWithoutKeys(ctx, input)
    } else {
      doConsumeWithKeys(ctx, input)
    }
  }

注意这里虽然把ExprCode类型变量row传递进来了,但是在这个方法中却没有用到,因为对于大部分情况来说,该变量是对外部传递InteralRow的作用。

而input则是sortAgg_expr_0_0,由rang_value_0赋值而来.

doConsumeWithoutKeys对应的方法如下:


  private def doConsumeWithoutKeys(ctx: CodegenContext, input: Seq[ExprCode]): String = {
    // only have DeclarativeAggregate
    val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate])
    val inputAttrs = functions.flatMap(_.aggBufferAttributes) ++ inputAttributes
    // To individually generate code for each aggregate function, an element in `updateExprs` holds
    // all the expressions for the buffer of an aggregation function.
    val updateExprs = aggregateExpressions.map { e =>
      e.mode match {
        case Partial | Complete =>
          e.aggregateFunction.asInstanceOf[DeclarativeAggregate].updateExpressions
        case PartialMerge | Final =>
          e.aggregateFunction.asInstanceOf[DeclarativeAggregate].mergeExpressions
      }
    }
    ctx.currentVars = bufVars.flatten ++ input
    println(s"updateExprs: $updateExprs")
    val boundUpdateExprs = updateExprs.map { updateExprsForOneFunc =>
      bindReferences(updateExprsForOneFunc, inputAttrs)
    }
    val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExprs.flatten)
    val effectiveCodes = ctx.evaluateSubExprEliminationState(subExprs.states.values)
    val bufferEvals = boundUpdateExprs.map { boundUpdateExprsForOneFunc =>
      ctx.withSubExprEliminationExprs(subExprs.states) {
        boundUpdateExprsForOneFunc.map(_.genCode(ctx))
      }
    }
    val aggNames = functions.map(_.prettyName)
    val aggCodeBlocks = bufferEvals.zipWithIndex.map { case (bufferEvalsForOneFunc, i) =>
      val bufVarsForOneFunc = bufVars(i)
      // All the update code for aggregation buffers should be placed in the end
      // of each aggregation function code.
      println(s"bufVarsForOneFunc: $bufVarsForOneFunc")
      val updates = bufferEvalsForOneFunc.zip(bufVarsForOneFunc).map { case (ev, bufVar) =>
        s"""
           |${bufVar.isNull} = ${ev.isNull};
           |${bufVar.value} = ${ev.value};
         """.stripMargin
      }
      code"""
            |${ctx.registerComment(s"do aggregate for ${aggNames(i)}")}
            |${ctx.registerComment("evaluate aggregate function")}
            |${evaluateVariables(bufferEvalsForOneFunc)}
            |${ctx.registerComment("update aggregation buffers")}
            |${updates.mkString("\n").trim}
       """.stripMargin
    }
    val codeToEvalAggFuncs = generateEvalCodeForAggFuncs(
      ctx, input, inputAttrs, boundUpdateExprs, aggNames, aggCodeBlocks, subExprs)
    s"""
       |// do aggregate
       |// common sub-expressions
       |$effectiveCodes
       |// evaluate aggregate functions and update aggregation buffers
       |$codeToEvalAggFuncs
     """.stripMargin
  }

val functions =和val inputAttrs =

val inputAttrs = functions.flatMap(_.aggBufferAttributes) ++ inputAttributes,对于AVG聚合函数来说,聚合的缓冲属性(aggBufferAttributes)为AttributeReference("sum", sumDataType)()和AttributeReference("count", LongType)().

对于当前的计划来说,SortAggregateExec的inputAttributes 为AttributeReference("id", LongType, nullable = false)()


val updateExprs = aggregateExpressions.

对于目前的物理计划来说,当前的mode为Partial,所以该值为updateExpressions,也就是局部更新,即

    Add(
    sum,
    coalesce(child.cast(sumDataType), Literal.default(sumDataType)),
    failOnError = useAnsiAdd),
  /* count = */ If(child.isNull, count, count + 1L)

ctx.currentVars = bufVars.flatten ++ input

这里的bufVars是在SortAggregateExec的produce方法进行赋值的,也就是对应“SUM”和“COUNT”初始值的ExprCode

这里的input 是名为sortAgg_expr_0_0的ExprCode变量


val boundUpdateExprs =

把当前的输入变量绑定到updataExprs中去(很明显inputAttrs和currentVars是一一对应的)


val subExprs = 和val effectiveCodes =

进行公共子表达式的消除,并提前计算出在计算子表达式计算之前的自表达式。

对于当前的计划来说,该``effectiveCodes`为空字符串.


val bufferEvals =

产生进行update的ExprCode,这里具体为(这里分别为Add和IF表达式的codegen:


List(ExprCode(boolean sortAgg_isNull_7 = true;
     double sortAgg_value_7 = -1.0;
     if (!sortAgg_bufIsNull_1) {
       sortAgg_sortAgg_isNull_9_0 = true;
    double sortAgg_value_9 = -1.0;
    do {
    boolean sortAgg_isNull_10 = false;
    double sortAgg_value_10 = -1.0;
    if (!false) {
     sortAgg_value_10 = (double) sortAgg_expr_0_0;
    }
   if (!sortAgg_isNull_10) {
     sortAgg_sortAgg_isNull_9_0 = false;
     sortAgg_value_9 = sortAgg_value_10;
     continue;
   }
   if (!false) {
     sortAgg_sortAgg_isNull_9_0 = false;
     sortAgg_value_9 = 0.0D;
     continue;
   }
   } while (false);
   sortAgg_isNull_7 = false; // resultCode could change nullability.
   sortAgg_value_7 = sortAgg_bufValue_1 + sortAgg_value_9;
           },sortAgg_isNull_7,sortAgg_value_7), 
   ExprCode(boolean sortAgg_isNull_13 = false;
   long sortAgg_value_13 = -1L;
   if (!false && false) {
     sortAgg_isNull_13 = sortAgg_bufIsNull_2;
     sortAgg_value_13 = sortAgg_bufValue_2;
   } else {
     boolean sortAgg_isNull_17 = true;
     long sortAgg_value_17 = -1L;
     if (!sortAgg_bufIsNull_2) {
   sortAgg_isNull_17 = false; // resultCode could change nullability.
   sortAgg_value_17 = sortAgg_bufValue_2 + 1L;
           }
     sortAgg_isNull_13 = sortAgg_isNull_17;
     sortAgg_value_13 = sortAgg_value_17;
   },sortAgg_isNull_13,sortAgg_value_13))

val aggNames = functions.map(_.prettyName)

这里定义聚合函数的方法名字,最终会行成如下:sortAgg_doAggregate_avg_0类似这种名字的方法。


val aggCodeBlocks =

这个是对应各个聚合函数的代码块,并在进行了聚合以后,把聚合的结果赋值给全局变量,对应的sql为:

  sortAgg_bufIsNull_1 = sortAgg_isNull_7;
  sortAgg_bufValue_1 = sortAgg_value_7;
  sortAgg_bufIsNull_2 = sortAgg_isNull_13;
  sortAgg_bufValue_2 = sortAgg_value_13;
  • 其中sortAgg_bufValue_1代表了SUMsortAgg_bufValue_2代表COUNT
  • val codeToEvalAggFuncs = generateEvalCodeForAggFuncs
    生成各个聚合函数的代码,如下:
     sortAgg_doAggregate_max_0(sortAgg_expr_0_0);
     sortAgg_doAggregate_avg_0(sortAgg_expr_0_0);
  • $effectiveCodes
    组装代码
相关文章
|
5月前
|
传感器 数据可视化 知识图谱
计算轴向磁铁和环状磁铁的磁场(Matlab代码实现)
计算轴向磁铁和环状磁铁的磁场(Matlab代码实现)
196 2
|
7月前
|
安全 数据可视化 网络协议
千万别错过!这个国产开源项目彻底改变了你的域名资产管理方式,收藏它相当于多一个安全专家!
Domain Admin 是一款免费开源、专为个人与企业设计的高效域名生命周期管理工具。支持多域名集中管理、自动同步信息、过期提醒与续期预警,提供数据可视化面板及 Webhook 通知功能。采用现代化技术栈(Python+Flask、Vue3.js),界面清爽易用,特别适合中文用户。相比 CentralOps、NetBox 等工具,Domain Admin 功能更全面,安全性更高,是管理域名资产的理想选择。项目地址:https://github.com/dromara/domain-admin
686 3
|
存储 监控 负载均衡
MongoDB的水平扩展能力
MongoDB的水平扩展能力
256 3
|
存储 Java 数据库连接
Mybatis-plus@DS实现动态切换数据源应用
Mybatis-plus@DS实现动态切换数据源应用
1731 0
|
编解码 移动开发 前端开发
web canvas系列——快速入门上手绘制二维空间点、线、面
web canvas系列——快速入门上手绘制二维空间点、线、面
570 4
|
Java Maven C++
【Azure Developer】记录一次使用Java Azure Key Vault Secret示例代码生成的Jar包,单独运行出现 no main manifest attribute, in target/demo-1.0-SNAPSHOT.jar 错误消息
【Azure Developer】记录一次使用Java Azure Key Vault Secret示例代码生成的Jar包,单独运行出现 no main manifest attribute, in target/demo-1.0-SNAPSHOT.jar 错误消息
308 0
|
NoSQL 网络协议 安全
Lettuce的特性和内部实现问题之Lettuce天然地使用管道模式与Redis交互的问题如何解决
Lettuce的特性和内部实现问题之Lettuce天然地使用管道模式与Redis交互的问题如何解决
206 0
|
SQL 分布式计算 Spark
如何在Spark中实现Count Distinct重聚合
Count Distinct是SQL查询中经常使用的聚合统计方式,用于计算非重复结果的数目。由于需要去除重复结果,Count Distinct的计算通常非常耗时。本文主要介绍在Spark中如何基于重聚合实现交互式响应的COUNT DISTINCT支持。
|
SQL JSON 分布式计算
|
Java 网络安全 API
Java一分钟之-JavaMail:发送电子邮件
本文介绍了使用JavaMail API发送电子邮件的步骤,包括环境准备、依赖引入、基本配置和代码示例。通过添加Maven或Gradle依赖,设置SMTP服务器信息并实现Authenticator,可以创建和发送邮件。同时,文章列举了SMTP认证失败、连接超时等常见问题及其解决方案,并提出了安全与最佳实践建议,如启用SSL/TLS、避免硬编码密码和妥善处理异常。
3253 0