Spark Codegen浅析

本文涉及的产品
EMR Serverless StarRocks,5000CU*H 48000GB*H
简介: Codegen是Spark Runtime优化性能的关键技术,核心在于动态生成java代码、即时compile和加载,把解释执行转化为编译执行。Spark Codegen分为Expression级别和WholeStage级别,分别针对表达式计算和全Stage计算做代码生成,都取得了数量级的性能提升。本文浅析Spark Codegen技术原理。

作者:周克勇,花名一锤,阿里巴巴计算平台事业部EMR团队技术专家,大数据领域技术爱好者,对Spark有浓厚兴趣和一定的了解,目前主要专注于EMR产品中开源计算引擎的优化工作。


背景介绍

SparkSQL的优越性能背后有两大技术支柱:Optimizer和Runtime。前者致力于寻找最优的执行计划,后者则致力于把既定的执行计划尽可能快地执行出来。Runtime的多种优化可概括为两个层面:
1. 全局优化。从提升全局资源利用率、消除数据倾斜、降低IO等角度做优化,包括自适应执行(Adaptive Execution), Shuffle Removal等。
2. 局部优化。优化具体的Task的执行效率,主要依赖Codegen技术,具体包括Expression级别和WholeStage级别的Codegen。
本文介绍Spark Codegen的技术原理。

Case Study

本节通过两个具体case介绍Codegen的做法。

Expression级别

考虑下面的表达式计算:x + (1 + 2),用scala代码表达如下:

Add(Attribute(x), Add(Literal(1), Literal(2)))

语法树如下:


Expression

递归求值这棵语法树的常规代码如下:

tree.transformUp {
  case Attribute(idx) => Literal(row.getValue(idx))
  case Add(Literal(c1),Literal(c2)) => Literal(c1+c2)
  case Literal(c) => Literal(c)
}

执行上述代码需要做很多类型匹配、虚函数调用、对象创建等额外逻辑,这些overhead远超对表达式求值本身。
为了消除这些overhead,Spark Codegen直接拼成求值表达式的java代码并进行即时编译。具体分为三个步骤:
1. 代码生成。根据语法树生成java代码,封装在wrapper类中:

... // class wrapper
row.getValue(idx) + (1 + 2)
... // class wrapper

2. 即时编译。使用Janino框架把生成代码编译成class文件。
3. 加载执行。最后加载并执行。
优化前后性能有数量级的提升。


Expression

WholeStage级别

考虑如下的sql语句:

select count(*) from store_sales
where ss_item_sk=1000;

生成的物理执行计划如下:


Expression

执行该计划的常规做法是使用火山模型(vocano model),每个Operator都继承了Iterator接口,其next()方法首先驱动上游执行拿到输入,然后执行自己的逻辑。代码示例如下:

class Agg extends Iterator[Row] {
  def doAgg() {
    while (child.hasNext()) {
      val row = child.next();
      // do aggregation
      ...
    }
  }
  def next(): Row {
    if (!doneAgg) {
      doAgg();
    }
    return aggIter.next();
  }
}


class Filter extends Iterator[Row] {
  def next(): Row {
    var current = child.next()
    while (current != null && !predicate(current)) {
      current = child.next()
    }
    return current;
  }
}

从上述代码可知,火山模型会有大量类型转换和虚函数调用。虚函数调用会导致CPU分支预测失败,从而导致严重的性能回退。
为了消除这些overhead,Spark WholestageCodegen会为该物理计划生成类型确定的java代码,然后类似Expression的做法即时编译和加载执行。本例生成的java代码示例如下(非真实代码,真实代码片段见后文):

var count = 0
for (ss_item_sk in store_sales) {
  if (ss_item_sk == 1000) {
    count += 1
  }
}

优化前后性能提升数据如下:


Expression

Spark Codegen框架

Spark Codegen框架有三个核心组成部分
1. 核心接口/类
2. CodegenContext
3. Produce-Consume Pattern

接下来详细介绍。

接口/类

四个核心接口:
1. CodegenSupport(接口)
实现该接口的Operator可以将自己的逻辑拼成java代码。重要方法:

produce() // 输出本节点产出Row的java代码
consume() // 输出本节点消费上游节点输入的Row的java代码

实现类包括但不限于: ProjectExec, FilterExec, HashAggregateExec, SortMergeJoinExec。
2. WholeStageCodegenExec(类)
CodegenSupport的实现类之一,Stage内部所有相邻的实现CodegenSupport接口的Operator的融合,产出的代码把所有被融合的Operator的执行逻辑封装到一个Wrapper类中,该Wrapper类作为Janino即时compile的入参。
3. InputAdapter(类)
CodegenSupport的实现类之一,胶水类,用来连接WholeStageCodegenExec节点和未实现CodegenSupport的上游节点。
4. BufferedRowIterator(接口)
WholeStageCodegenExec生成的java代码的父类,重要方法:

public InternalRow next() // 返回下一条Row
public void append(InternalRow row) // append一条Row

CodegenContext

管理生成代码的核心类。主要涵盖以下功能:

1.命名管理。保证同一Scope内无变量名冲突。
2.变量管理。维护类变量,判断变量类型(应该声明为独立变量还是压缩到类型数组中),维护变量初始化逻辑等。
3.方法管理。维护类方法。
4.内部类管理。维护内部类。
5.相同表达式管理。维护相同子表达式,避免重复计算。
6.size管理。避免方法、类size过大,避免类变量数过多,进行比较拆分。如把表达式块拆分成多个函数;把函数、变量定义拆分到多个内部类。
7.依赖管理。维护该类依赖的外部对象,如Broadcast对象、工具对象、度量对象等。
8.通用模板管理。提供通用代码模板,如genComp, nullSafeExec等。

Produce-Consume Pattern

相邻Operator通过Produce-Consume模式生成代码。
Produce生成整体处理的框架代码,例如aggregation生成的代码框架如下:

if (!initialized) {
  # create a hash map, then build the aggregation hash map
  # call child.produce()
  initialized = true;
}
while (hashmap.hasNext()) {
  row = hashmap.next();
  # build the aggregation results
  # create variables for results
  # call consume(), which will call parent.doConsume()
   if (shouldStop()) return;
}

Consume生成当前节点处理上游输入的Row的逻辑。如Filter生成代码如下:

# code to evaluate the predicate expression, result is isNull1 and value2
if (!isNull1 && value2) {
  # call consume(), which will call parent.doConsume()
}

下图比较清晰地展示了WholestageCodegen生成java代码的call graph:


Expression

Case Study的示例,生成的真实代码如下:

== Subtree 1 / 2 ==
*(2) HashAggregate(keys=[], functions=[count(1)], output=[count(1)#326L])
+- Exchange SinglePartition
   +- *(1) HashAggregate(keys=[], functions=[partial_count(1)], output=[count#329L])
      +- *(1) Project
         +- *(1) Filter (isnotnull(ss_item_sk#13L) && (ss_item_sk#13L = 1000))
            +- *(1) FileScan parquet [ss_item_sk#13L] Batched: true, Format: Parquet, Location: InMemoryFileIndex[file:/home/admin/zhoukeyong/workspace/tpc/tpcds/data/parquet/10/store_sales/par..., PartitionFilters: [], PushedFilters: [IsNotNull(ss_item_sk), EqualTo(ss_item_sk,1000)], ReadSchema: struct<ss_item_sk:bigint>

Generated code:
/* 001 */ public Object generate(Object[] references) {
/* 002 */   return new GeneratedIteratorForCodegenStage2(references);
/* 003 */ }
/* 004 */
/* 005 */ // codegenStageId=2
/* 006 */ final class GeneratedIteratorForCodegenStage2 extends org.apache.spark.sql.execution.BufferedRowIterator {
/* 007 */   private Object[] references;
/* 008 */   private scala.collection.Iterator[] inputs;
/* 009 */   private boolean agg_initAgg_0;
/* 010 */   private boolean agg_bufIsNull_0;
/* 011 */   private long agg_bufValue_0;
/* 012 */   private scala.collection.Iterator inputadapter_input_0;
/* 013 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[] agg_mutableStateArray_0 = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[1];
/* 014 */
/* 015 */   public GeneratedIteratorForCodegenStage2(Object[] references) {
/* 016 */     this.references = references;
/* 017 */   }
/* 018 */
/* 019 */   public void init(int index, scala.collection.Iterator[] inputs) {
/* 020 */     partitionIndex = index;
/* 021 */     this.inputs = inputs;
/* 022 */
/* 023 */     inputadapter_input_0 = inputs[0];
/* 024 */     agg_mutableStateArray_0[0] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(1, 0);
/* 025 */
/* 026 */   }
/* 027 */
/* 028 */   private void agg_doAggregateWithoutKey_0() throws java.io.IOException {
/* 029 */     // initialize aggregation buffer
/* 030 */     agg_bufIsNull_0 = false;
/* 031 */     agg_bufValue_0 = 0L;
/* 032 */
/* 033 */     while (inputadapter_input_0.hasNext() && !stopEarly()) {
/* 034 */       InternalRow inputadapter_row_0 = (InternalRow) inputadapter_input_0.next();
/* 035 */       long inputadapter_value_0 = inputadapter_row_0.getLong(0);
/* 036 */
/* 037 */       agg_doConsume_0(inputadapter_row_0, inputadapter_value_0);
/* 038 */       if (shouldStop()) return;
/* 039 */     }
/* 040 */
/* 041 */   }
/* 042 */
/* 043 */   private void agg_doConsume_0(InternalRow inputadapter_row_0, long agg_expr_0_0) throws java.io.IOException {
/* 044 */     // do aggregate
/* 045 */     // common sub-expressions
/* 046 */
/* 047 */     // evaluate aggregate function
/* 048 */     long agg_value_3 = -1L;
/* 049 */     agg_value_3 = agg_bufValue_0 + agg_expr_0_0;
/* 050 */     // update aggregation buffer
/* 051 */     agg_bufIsNull_0 = false;
/* 052 */     agg_bufValue_0 = agg_value_3;
/* 053 */
/* 054 */   }
/* 055 */
/* 056 */   protected void processNext() throws java.io.IOException {
/* 057 */     while (!agg_initAgg_0) {
/* 058 */       agg_initAgg_0 = true;
/* 059 */       long agg_beforeAgg_0 = System.nanoTime();
/* 060 */       agg_doAggregateWithoutKey_0();
/* 061 */       ((org.apache.spark.sql.execution.metric.SQLMetric) references[1] /* aggTime */).add((System.nanoTime() - agg_beforeAgg_0) / 1000000);
/* 062 */
/* 063 */       // output the result
/* 064 */
/* 065 */       ((org.apache.spark.sql.execution.metric.SQLMetric) references[0] /* numOutputRows */).add(1);
/* 066 */       agg_mutableStateArray_0[0].reset();
/* 067 */
/* 068 */       agg_mutableStateArray_0[0].zeroOutNullBytes();
/* 069 */
/* 070 */       agg_mutableStateArray_0[0].write(0, agg_bufValue_0);
/* 071 */       append((agg_mutableStateArray_0[0].getRow()));
/* 072 */     }
/* 073 */   }
/* 074 */
/* 075 */ }

== Subtree 2 / 2 ==
*(1) HashAggregate(keys=[], functions=[partial_count(1)], output=[count#329L])
+- *(1) Project
   +- *(1) Filter (isnotnull(ss_item_sk#13L) && (ss_item_sk#13L = 1000))
      +- *(1) FileScan parquet [ss_item_sk#13L] Batched: true, Format: Parquet, Location: InMemoryFileIndex[file:/home/admin/zhoukeyong/workspace/tpc/tpcds/data/parquet/10/store_sales/par..., PartitionFilters: [], PushedFilters: [IsNotNull(ss_item_sk), EqualTo(ss_item_sk,1000)], ReadSchema: struct<ss_item_sk:bigint>

Generated code:
/* 001 */ public Object generate(Object[] references) {
/* 002 */   return new GeneratedIteratorForCodegenStage1(references);
/* 003 */ }
/* 004 */
/* 005 */ // codegenStageId=1
/* 006 */ final class GeneratedIteratorForCodegenStage1 extends org.apache.spark.sql.execution.BufferedRowIterator {
/* 007 */   private Object[] references;
/* 008 */   private scala.collection.Iterator[] inputs;
/* 009 */   private boolean agg_initAgg_0;
/* 010 */   private boolean agg_bufIsNull_0;
/* 011 */   private long agg_bufValue_0;
/* 012 */   private long scan_scanTime_0;
/* 013 */   private boolean outputMetaColumns;
/* 014 */   private int scan_batchIdx_0;
/* 015 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[] scan_mutableStateArray_3 = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[3];
/* 016 */   private org.apache.spark.sql.vectorized.ColumnarBatch[] scan_mutableStateArray_1 = new org.apache.spark.sql.vectorized.ColumnarBatch[1];
/* 017 */   private scala.collection.Iterator[] scan_mutableStateArray_0 = new scala.collection.Iterator[1];
/* 018 */   private org.apache.spark.sql.execution.vectorized.OffHeapColumnVector[] scan_mutableStateArray_2 = new org.apache.spark.sql.execution.vectorized.OffHeapColumnVector[1];
/* 019 */
/* 020 */   public GeneratedIteratorForCodegenStage1(Object[] references) {
/* 021 */     this.references = references;
/* 022 */   }
/* 023 */
/* 024 */   public void init(int index, scala.collection.Iterator[] inputs) {
/* 025 */     partitionIndex = index;
/* 026 */     this.inputs = inputs;
/* 027 */
/* 028 */     scan_mutableStateArray_0[0] = inputs[0];
/* 029 */     outputMetaColumns = false;
/* 030 */     scan_mutableStateArray_3[0] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(1, 0);
/* 031 */     scan_mutableStateArray_3[1] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(1, 0);
/* 032 */     scan_mutableStateArray_3[2] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(1, 0);
/* 033 */
/* 034 */   }
/* 035 */
/* 036 */   private void agg_doAggregateWithoutKey_0() throws java.io.IOException {
/* 037 */     // initialize aggregation buffer
/* 038 */     agg_bufIsNull_0 = false;
/* 039 */     agg_bufValue_0 = 0L;
/* 040 */
/* 041 */     if (scan_mutableStateArray_1[0] == null) {
/* 042 */       scan_nextBatch_0();
/* 043 */     }
/* 044 */     while (scan_mutableStateArray_1[0] != null) {
/* 045 */       int scan_numRows_0 = scan_mutableStateArray_1[0].numRows();
/* 046 */       int scan_localEnd_0 = scan_numRows_0 - scan_batchIdx_0;
/* 047 */       for (int scan_localIdx_0 = 0; scan_localIdx_0 < scan_localEnd_0; scan_localIdx_0++) {
/* 048 */         int scan_rowIdx_0 = scan_batchIdx_0 + scan_localIdx_0;
/* 049 */         if (!scan_mutableStateArray_1[0].validAt(scan_rowIdx_0)) { continue; }
/* 050 */         do {
/* 051 */           boolean scan_isNull_0 = scan_mutableStateArray_2[0].isNullAt(scan_rowIdx_0);
/* 052 */           long scan_value_0 = scan_isNull_0 ? -1L : (scan_mutableStateArray_2[0].getLong(scan_rowIdx_0));
/* 053 */
/* 054 */           if (!(!scan_isNull_0)) continue;
/* 055 */
/* 056 */           boolean filter_value_2 = false;
/* 057 */           filter_value_2 = scan_value_0 == 1000L;
/* 058 */           if (!filter_value_2) continue;
/* 059 */
/* 060 */           ((org.apache.spark.sql.execution.metric.SQLMetric) references[2] /* numOutputRows */).add(1);
/* 061 */
/* 062 */           agg_doConsume_0();
/* 063 */
/* 064 */         } while(false);
/* 065 */         // shouldStop check is eliminated
/* 066 */       }
/* 067 */       scan_batchIdx_0 = scan_numRows_0;
/* 068 */       scan_mutableStateArray_1[0] = null;
/* 069 */       scan_nextBatch_0();
/* 070 */     }
/* 071 */     ((org.apache.spark.sql.execution.metric.SQLMetric) references[1] /* scanTime */).add(scan_scanTime_0 / (1000 * 1000));
/* 072 */     scan_scanTime_0 = 0;
/* 073 */
/* 074 */   }
/* 075 */
/* 076 */   private void scan_nextBatch_0() throws java.io.IOException {
/* 077 */     long getBatchStart = System.nanoTime();
/* 078 */     if (scan_mutableStateArray_0[0].hasNext()) {
/* 079 */       scan_mutableStateArray_1[0] = (org.apache.spark.sql.vectorized.ColumnarBatch)scan_mutableStateArray_0[0].next();
/* 080 */       ((org.apache.spark.sql.execution.metric.SQLMetric) references[0] /* numOutputRows */).add(scan_mutableStateArray_1[0].numRows());
/* 081 */       scan_batchIdx_0 = 0;
/* 082 */       scan_mutableStateArray_2[0] = (org.apache.spark.sql.execution.vectorized.OffHeapColumnVector) (outputMetaColumns ?
/* 083 */         scan_mutableStateArray_1[0].column(0, true) : scan_mutableStateArray_1[0].column(0));
/* 084 */
/* 085 */     }
/* 086 */     scan_scanTime_0 += System.nanoTime() - getBatchStart;
/* 087 */   }
/* 088 */
/* 089 */   private void agg_doConsume_0() throws java.io.IOException {
/* 090 */     // do aggregate
/* 091 */     // common sub-expressions
/* 092 */
/* 093 */     // evaluate aggregate function
/* 094 */     long agg_value_1 = -1L;
/* 095 */     agg_value_1 = agg_bufValue_0 + 1L;
/* 096 */     // update aggregation buffer
/* 097 */     agg_bufIsNull_0 = false;
/* 098 */     agg_bufValue_0 = agg_value_1;
/* 099 */
/* 100 */   }
/* 101 */
/* 102 */   protected void processNext() throws java.io.IOException {
/* 103 */     while (!agg_initAgg_0) {
/* 104 */       agg_initAgg_0 = true;
/* 105 */       long agg_beforeAgg_0 = System.nanoTime();
/* 106 */       agg_doAggregateWithoutKey_0();
/* 107 */       ((org.apache.spark.sql.execution.metric.SQLMetric) references[4] /* aggTime */).add((System.nanoTime() - agg_beforeAgg_0) / 1000000);
/* 108 */
/* 109 */       // output the result
/* 110 */
/* 111 */       ((org.apache.spark.sql.execution.metric.SQLMetric) references[3] /* numOutputRows */).add(1);
/* 112 */       scan_mutableStateArray_3[2].reset();
/* 113 */
/* 114 */       scan_mutableStateArray_3[2].zeroOutNullBytes();
/* 115 */
/* 116 */       scan_mutableStateArray_3[2].write(0, agg_bufValue_0);
/* 117 */       append((scan_mutableStateArray_3[2].getRow()));
/* 118 */     }
/* 119 */   }
/* 120 */
/* 121 */ }

阿里巴巴开源大数据技术团队成立Apache Spark中国技术社区,定期推送精彩案例,技术专家直播,问答区数个Spark技术同学每日在线答疑,只为营造纯粹的Spark氛围,欢迎钉钉扫码加入!
二维码.JPG

相关实践学习
基于EMR Serverless StarRocks一键玩转世界杯
基于StarRocks构建极速统一OLAP平台
快速掌握阿里云 E-MapReduce
E-MapReduce 是构建于阿里云 ECS 弹性虚拟机之上,利用开源大数据生态系统,包括 Hadoop、Spark、HBase,为用户提供集群、作业、数据等管理的一站式大数据处理分析服务。 本课程主要介绍阿里云 E-MapReduce 的使用方法。
相关文章
|
8月前
|
设计模式 SQL 分布式计算
Spark Day06:Spark Core之Spark 内核调度和SparkSQL快速入门
Spark Day06:Spark Core之Spark 内核调度和SparkSQL快速入门
81 0
|
5月前
|
分布式计算 资源调度 Shell
如何开始使用Spark框架?
【8月更文挑战第31天】如何开始使用Spark框架?
100 2
|
5月前
|
SQL 机器学习/深度学习 分布式计算
Spark框架
【8月更文挑战第31天】Spark框架
51 2
|
8月前
|
分布式计算 关系型数据库 MySQL
Spark编程实验四:Spark Streaming编程
Spark编程实验四:Spark Streaming编程
300 2
|
机器学习/深度学习 算法 数据挖掘
PySpark ML (评估器)
PySpark 估计器使用方法介绍
PySpark ML (评估器)
|
消息中间件 分布式计算 Java
spark streaming知识总结[优化]
spark streaming知识总结[优化]
231 0
|
分布式计算 Spark
SPARK中的wholeStageCodegen全代码生成--GenerateUnsafeProjection.createCode说明
SPARK中的wholeStageCodegen全代码生成--GenerateUnsafeProjection.createCode说明
145 0
|
消息中间件 分布式计算 网络协议
Spark-Streaming的学习使用
SparkStreaming Spark Streaming可以轻松构建可扩展的容错流应用程序。 便于使用 把之前学过的处理离线数据的算子直接拿来使用 容错 sparkStreaming可以保证数据被消费一次 保证不丢失,不重复消费数据。 集成
134 0
Spark-Streaming的学习使用
|
Java 分布式计算 Spark
Spark Codegen浅析
Codegen是Spark Runtime优化性能的关键技术,核心在于动态生成java代码、即时compile和加载,把解释执行转化为编译执行。Spark Codegen分为Expression级别和WholeStage级别,分别针对表达式计算和全Stage计算做代码生成,都取得了数量级的性能提升。本文浅析Spark Codegen技术原理。
4434 0
|
存储 分布式计算 Apache
Spark Streaming 框架在 5G 中的应用
在发展 5G 和 IoT 场景的准备阶段,爱立信研究了各种可扩展和灵活的流处理框架,以解决数据流水线问题以及提升整体性能。我们通过机器学习流数据进行自适应学习和智能决策从而实现各个领域的自动化。其中使用机器学习算法从流数据中逐步学习模型和获取信息是一个巨大的挑战。