Spark2.4.0源码分析之WorldCount Stage提交(DAGScheduler)(六)
- 理解ShuffleMapStage是如何转化为ShuffleMapTask并作为TaskSet提交
- 理解ResultStage是如何转化为ResultTask并作为TaskSet提交
package com.opensource.bigdata.spark.standalone.base
import org.apache.spark.sql.SparkSession
* 得到SparkSession
* 首先 extends BaseSparkSession
* 本地: val spark = sparkSession(true)
* 集群: val spark = sparkSession()
class BaseSparkSession {
var appName = "sparkSession"
var master = "spark://" //本地模式:local standalone:spark://master:7077
def sparkSession(): SparkSession = {
val spark = SparkSession.builder
//import spark.implicits._
* @param isLocal
* @param isHiveSupport
* @param remoteDebug
* @param maxPartitionBytes -1 不设置,否则设置分片大小
* @return
def sparkSession(isLocal:Boolean = false, isHiveSupport:Boolean = false, remoteDebug:Boolean=false,maxPartitionBytes:Int = -1): SparkSession = {
val warehouseLocation = new File("spark-warehouse").getAbsolutePath
master = "local[1]"
var builder = SparkSession.builder
builder = builder.enableHiveSupport()
if(maxPartitionBytes != -1){
builder.config("spark.sql.files.maxPartitionBytes",maxPartitionBytes) //32
val spark = builder.getOrCreate()
//import spark.implicits._
var builder = SparkSession.builder
if(maxPartitionBytes != -1){
builder.config("spark.sql.files.maxPartitionBytes",maxPartitionBytes) //32
// .config("spark.sql.shuffle.partitions",2)
//executor debug,该配置是在提交作业的地方读取
builder.config("spark.executor.extraJavaOptions","-Xdebug -Xrunjdwp:transport=dt_socket,server=y,suspend=y,address=10002")
builder = builder.enableHiveSupport()
val spark = builder.getOrCreate()
* 得到当前工程的路径
* @return
def getProjectPath:String=System.getProperty("user.dir")
package com.opensource.bigdata.spark.standalone.wordcount.spark.session
import com.opensource.bigdata.spark.standalone.base.BaseSparkSession
object WorldCount extends BaseSparkSession{
def main(args: Array[String]): Unit = {
appName = "WorldCount"
val spark = sparkSession(false,false,false,7)
import spark.implicits._
val distFile ="data/text/worldCount.txt")
val dataset = distFile.flatMap( line => line.split(" ")).groupByKey(x => x ).count()
源码分析 ShuffleMapStage提交
- 第一次提交的是ShuffleMapStage
- stage.findMissingPartitions()得到当前Stage还有哪些partition需要计算。当前Stage为ShuffleMapStage,会对HDFS上的文件进行逻辑分区;我这里把spark.sql.files.maxPartitionBytes的值设置为7 byte,所以每个文件分区大小为7 byte,而总文件大小为14 byte,所以文件被划分为2个分区:
PartitionedFile(0)=hdfs://, range: 0-7 PartitionedFile(1)=hdfs://, range: 7-14
- partitionsToCompute 得到的值为 Vector(0,1)
- 把ShuffleMapStage加到runningStages中,表示此Stage已开始处理;之后提交Stage时会用runningStages来验证该Stage是否已在运行
- 对partitionsToCompute中的每个partition计算优选位置(preferred locations),也就是任务在哪台机器上运行性能更高、效率更高
val taskIdToLocations: Map[Int, Seq[TaskLocation]] = try {
stage match {
case s: ShuffleMapStage => { id => (id, getPreferredLocs(stage.rdd, id))}.toMap
case s: ResultStage => { id =>
val p = s.partitions(id)
(id, getPreferredLocs(stage.rdd, p))
} catch {
case NonFatal(e) =>
stage.makeNewStageAttempt(partitionsToCompute.size), properties))
abortStage(stage, s"Task creation failed: $e\n${Utils.exceptionString(e)}", Some(e))
runningStages -= stage
- 把ShuffleMapStage中的RDD、ShuffleDependency进行序列化,并作为广播变量广播出去,这样在Task计算时就可以拿到该变量
var taskBinary: Broadcast[Array[Byte]] = null
var partitions: Array[Partition] = null
try {
// For ShuffleMapTask, serialize and broadcast (rdd, shuffleDep).
// For ResultTask, serialize and broadcast (rdd, func).
var taskBinaryBytes: Array[Byte] = null
// taskBinaryBytes and partitions are both effected by the checkpoint status. We need
// this synchronization in case another concurrent job is checkpointing this RDD, so we get a
// consistent view of both variables.
RDDCheckpointData.synchronized {
taskBinaryBytes = stage match {
case stage: ShuffleMapStage =>
closureSerializer.serialize((stage.rdd, stage.shuffleDep): AnyRef))
case stage: ResultStage =>
JavaUtils.bufferToArray(closureSerializer.serialize((stage.rdd, stage.func): AnyRef))
partitions = stage.rdd.partitions
taskBinary = sc.broadcast(taskBinaryBytes)
- 把ShuffleMapStage拆分成ShuffleMapTask,ShuffleMapStage.rdd.partitions有多少个,就分成多少个ShuffleMapTask;此处我们拆分成2个ShuffleMapTask
val tasks: Seq[Task[_]] = try {
val serializedTaskMetrics = closureSerializer.serialize(stage.latestInfo.taskMetrics).array()
stage match {
case stage: ShuffleMapStage =>
stage.pendingPartitions.clear() { id =>
val locs = taskIdToLocations(id)
val part = partitions(id)
stage.pendingPartitions += id
new ShuffleMapTask(, stage.latestInfo.attemptNumber,
taskBinary, part, locs, properties, serializedTaskMetrics, Option(jobId),
Option(sc.applicationId), sc.applicationAttemptId, stage.rdd.isBarrier())
case stage: ResultStage => { id =>
val p: Int = stage.partitions(id)
val part = partitions(p)
val locs = taskIdToLocations(id)
new ResultTask(, stage.latestInfo.attemptNumber,
taskBinary, part, locs, id, properties, serializedTaskMetrics,
Option(jobId), Option(sc.applicationId), sc.applicationAttemptId,
- 把所有拆分的ShuffleMapTask转化成TaskSet作为参数,进行任务提交
if (tasks.size > 0) {
logInfo(s"Submitting ${tasks.size} missing tasks from $stage (${stage.rdd}) (first 15 " +
s"tasks are for partitions ${tasks.take(15).map(_.partitionId)})")
taskScheduler.submitTasks(new TaskSet(
tasks.toArray,, stage.latestInfo.attemptNumber, jobId, properties))
/** Called when stage's parents are available and we can now do its task. */
private def submitMissingTasks(stage: Stage, jobId: Int) {
logDebug("submitMissingTasks(" + stage + ")")
// First figure out the indexes of partition ids to compute.
val partitionsToCompute: Seq[Int] = stage.findMissingPartitions()
// Use the scheduling pool, job group, description, etc. from an ActiveJob associated
// with this Stage
val properties = jobIdToActiveJob(jobId).properties
runningStages += stage
// SparkListenerStageSubmitted should be posted before testing whether tasks are
// serializable. If tasks are not serializable, a SparkListenerStageCompleted event
// will be posted, which should always come after a corresponding SparkListenerStageSubmitted
// event.
stage match {
case s: ShuffleMapStage =>
outputCommitCoordinator.stageStart(stage =, maxPartitionId = s.numPartitions - 1)
case s: ResultStage =>
stage =, maxPartitionId = s.rdd.partitions.length - 1)
val taskIdToLocations: Map[Int, Seq[TaskLocation]] = try {
stage match {
case s: ShuffleMapStage => { id => (id, getPreferredLocs(stage.rdd, id))}.toMap
case s: ResultStage => { id =>
val p = s.partitions(id)
(id, getPreferredLocs(stage.rdd, p))
} catch {
case NonFatal(e) =>
stage.makeNewStageAttempt(partitionsToCompute.size), properties))
abortStage(stage, s"Task creation failed: $e\n${Utils.exceptionString(e)}", Some(e))
runningStages -= stage
stage.makeNewStageAttempt(partitionsToCompute.size, taskIdToLocations.values.toSeq)
// If there are tasks to execute, record the submission time of the stage. Otherwise,
// post the even without the submission time, which indicates that this stage was
// skipped.
if (partitionsToCompute.nonEmpty) {
stage.latestInfo.submissionTime = Some(clock.getTimeMillis())
}, properties))
// TODO: Maybe we can keep the taskBinary in Stage to avoid serializing it multiple times.
// Broadcasted binary for the task, used to dispatch tasks to executors. Note that we broadcast
// the serialized copy of the RDD and for each task we will deserialize it, which means each
// task gets a different copy of the RDD. This provides stronger isolation between tasks that
// might modify state of objects referenced in their closures. This is necessary in Hadoop
// where the JobConf/Configuration object is not thread-safe.
var taskBinary: Broadcast[Array[Byte]] = null
var partitions: Array[Partition] = null
try {
// For ShuffleMapTask, serialize and broadcast (rdd, shuffleDep).
// For ResultTask, serialize and broadcast (rdd, func).
var taskBinaryBytes: Array[Byte] = null
// taskBinaryBytes and partitions are both effected by the checkpoint status. We need
// this synchronization in case another concurrent job is checkpointing this RDD, so we get a
// consistent view of both variables.
RDDCheckpointData.synchronized {
taskBinaryBytes = stage match {
case stage: ShuffleMapStage =>
closureSerializer.serialize((stage.rdd, stage.shuffleDep): AnyRef))
case stage: ResultStage =>
JavaUtils.bufferToArray(closureSerializer.serialize((stage.rdd, stage.func): AnyRef))
partitions = stage.rdd.partitions
taskBinary = sc.broadcast(taskBinaryBytes)
} catch {
// In the case of a failure during serialization, abort the stage.
case e: NotSerializableException =>
abortStage(stage, "Task not serializable: " + e.toString, Some(e))
runningStages -= stage
// Abort execution
case NonFatal(e) =>
abortStage(stage, s"Task serialization failed: $e\n${Utils.exceptionString(e)}", Some(e))
runningStages -= stage
val tasks: Seq[Task[_]] = try {
val serializedTaskMetrics = closureSerializer.serialize(stage.latestInfo.taskMetrics).array()
stage match {
case stage: ShuffleMapStage =>
stage.pendingPartitions.clear() { id =>
val locs = taskIdToLocations(id)
val part = partitions(id)
stage.pendingPartitions += id
new ShuffleMapTask(, stage.latestInfo.attemptNumber,
taskBinary, part, locs, properties, serializedTaskMetrics, Option(jobId),
Option(sc.applicationId), sc.applicationAttemptId, stage.rdd.isBarrier())
case stage: ResultStage => { id =>
val p: Int = stage.partitions(id)
val part = partitions(p)
val locs = taskIdToLocations(id)
new ResultTask(, stage.latestInfo.attemptNumber,
taskBinary, part, locs, id, properties, serializedTaskMetrics,
Option(jobId), Option(sc.applicationId), sc.applicationAttemptId,
} catch {
case NonFatal(e) =>
abortStage(stage, s"Task creation failed: $e\n${Utils.exceptionString(e)}", Some(e))
runningStages -= stage
if (tasks.size > 0) {
logInfo(s"Submitting ${tasks.size} missing tasks from $stage (${stage.rdd}) (first 15 " +
s"tasks are for partitions ${tasks.take(15).map(_.partitionId)})")
taskScheduler.submitTasks(new TaskSet(
tasks.toArray,, stage.latestInfo.attemptNumber, jobId, properties))
} else {
// Because we posted SparkListenerStageSubmitted earlier, we should mark
// the stage as completed here in case there are no tasks to run
markStageAsFinished(stage, None)
stage match {
case stage: ShuffleMapStage =>
logDebug(s"Stage ${stage} is actually done; " +
s"(available: ${stage.isAvailable}," +
s"available outputs: ${stage.numAvailableOutputs}," +
s"partitions: ${stage.numPartitions})")
case stage : ResultStage =>
logDebug(s"Stage ${stage} is actually done; (partitions: ${stage.numPartitions})")
源码分析 ResultStage提交
- 第二次提交的是ResultStage
- stage.findMissingPartitions()得到当前Stage还有哪些partition需要计算。当前Stage为ResultStage,它的分区由ShuffleExchangeExec.prepareShuffleDependency决定:默认的分区方式为HashPartitioning,默认的分区个数为200(即spark.sql.shuffle.partitions的默认值)
- partitionsToCompute 得到的值为 Vector(0,1,2,3,......198,199)
- 把ResultStage加到runningStages中,表示此Stage已开始处理;之后提交Stage时会用runningStages来验证该Stage是否已在运行
- 对partitionsToCompute中的每个partition计算优选位置(preferred locations),也就是任务在哪台机器上运行性能更高、效率更高
val taskIdToLocations: Map[Int, Seq[TaskLocation]] = try {
stage match {
case s: ShuffleMapStage => { id => (id, getPreferredLocs(stage.rdd, id))}.toMap
case s: ResultStage => { id =>
val p = s.partitions(id)
(id, getPreferredLocs(stage.rdd, p))
} catch {
case NonFatal(e) =>
stage.makeNewStageAttempt(partitionsToCompute.size), properties))
abortStage(stage, s"Task creation failed: $e\n${Utils.exceptionString(e)}", Some(e))
runningStages -= stage
- 把ResultStage中的RDD、function进行序列化,并作为广播变量广播出去,这样在Task计算时就可以拿到该变量
var taskBinary: Broadcast[Array[Byte]] = null
var partitions: Array[Partition] = null
try {
// For ShuffleMapTask, serialize and broadcast (rdd, shuffleDep).
// For ResultTask, serialize and broadcast (rdd, func).
var taskBinaryBytes: Array[Byte] = null
// taskBinaryBytes and partitions are both effected by the checkpoint status. We need
// this synchronization in case another concurrent job is checkpointing this RDD, so we get a
// consistent view of both variables.
RDDCheckpointData.synchronized {
taskBinaryBytes = stage match {
case stage: ShuffleMapStage =>
closureSerializer.serialize((stage.rdd, stage.shuffleDep): AnyRef))
case stage: ResultStage =>
JavaUtils.bufferToArray(closureSerializer.serialize((stage.rdd, stage.func): AnyRef))
partitions = stage.rdd.partitions
taskBinary = sc.broadcast(taskBinaryBytes)
- 把ResultStage拆分成ResultTask,根据partitionsToCompute有多少个,就分成多少个ResultTask,此处,我们是拆分成200个ResultTask
val tasks: Seq[Task[_]] = try {
val serializedTaskMetrics = closureSerializer.serialize(stage.latestInfo.taskMetrics).array()
stage match {
case stage: ShuffleMapStage =>
stage.pendingPartitions.clear() { id =>
val locs = taskIdToLocations(id)
val part = partitions(id)
stage.pendingPartitions += id
new ShuffleMapTask(, stage.latestInfo.attemptNumber,
taskBinary, part, locs, properties, serializedTaskMetrics, Option(jobId),
Option(sc.applicationId), sc.applicationAttemptId, stage.rdd.isBarrier())
case stage: ResultStage => { id =>
val p: Int = stage.partitions(id)
val part = partitions(p)
val locs = taskIdToLocations(id)
new ResultTask(, stage.latestInfo.attemptNumber,
taskBinary, part, locs, id, properties, serializedTaskMetrics,
Option(jobId), Option(sc.applicationId), sc.applicationAttemptId,
- 把所有拆分的ResultTask转化成TaskSet作为参数,进行任务提交
if (tasks.size > 0) {
logInfo(s"Submitting ${tasks.size} missing tasks from $stage (${stage.rdd}) (first 15 " +
s"tasks are for partitions ${tasks.take(15).map(_.partitionId)})")
taskScheduler.submitTasks(new TaskSet(
tasks.toArray,, stage.latestInfo.attemptNumber, jobId, properties))
/** Called when stage's parents are available and we can now do its task. */
private def submitMissingTasks(stage: Stage, jobId: Int) {
logDebug("submitMissingTasks(" + stage + ")")
// First figure out the indexes of partition ids to compute.
val partitionsToCompute: Seq[Int] = stage.findMissingPartitions()
// Use the scheduling pool, job group, description, etc. from an ActiveJob associated
// with this Stage
val properties = jobIdToActiveJob(jobId).properties
runningStages += stage
// SparkListenerStageSubmitted should be posted before testing whether tasks are
// serializable. If tasks are not serializable, a SparkListenerStageCompleted event
// will be posted, which should always come after a corresponding SparkListenerStageSubmitted
// event.
stage match {
case s: ShuffleMapStage =>
outputCommitCoordinator.stageStart(stage =, maxPartitionId = s.numPartitions - 1)
case s: ResultStage =>
stage =, maxPartitionId = s.rdd.partitions.length - 1)
val taskIdToLocations: Map[Int, Seq[TaskLocation]] = try {
stage match {
case s: ShuffleMapStage => { id => (id, getPreferredLocs(stage.rdd, id))}.toMap
case s: ResultStage => { id =>
val p = s.partitions(id)
(id, getPreferredLocs(stage.rdd, p))
} catch {
case NonFatal(e) =>
stage.makeNewStageAttempt(partitionsToCompute.size), properties))
abortStage(stage, s"Task creation failed: $e\n${Utils.exceptionString(e)}", Some(e))
runningStages -= stage
stage.makeNewStageAttempt(partitionsToCompute.size, taskIdToLocations.values.toSeq)
// If there are tasks to execute, record the submission time of the stage. Otherwise,
// post the even without the submission time, which indicates that this stage was
// skipped.
if (partitionsToCompute.nonEmpty) {
stage.latestInfo.submissionTime = Some(clock.getTimeMillis())
}, properties))
// TODO: Maybe we can keep the taskBinary in Stage to avoid serializing it multiple times.
// Broadcasted binary for the task, used to dispatch tasks to executors. Note that we broadcast
// the serialized copy of the RDD and for each task we will deserialize it, which means each
// task gets a different copy of the RDD. This provides stronger isolation between tasks that
// might modify state of objects referenced in their closures. This is necessary in Hadoop
// where the JobConf/Configuration object is not thread-safe.
var taskBinary: Broadcast[Array[Byte]] = null
var partitions: Array[Partition] = null
try {
// For ShuffleMapTask, serialize and broadcast (rdd, shuffleDep).
// For ResultTask, serialize and broadcast (rdd, func).
var taskBinaryBytes: Array[Byte] = null
// taskBinaryBytes and partitions are both effected by the checkpoint status. We need
// this synchronization in case another concurrent job is checkpointing this RDD, so we get a
// consistent view of both variables.
RDDCheckpointData.synchronized {
taskBinaryBytes = stage match {
case stage: ShuffleMapStage =>
closureSerializer.serialize((stage.rdd, stage.shuffleDep): AnyRef))
case stage: ResultStage =>
JavaUtils.bufferToArray(closureSerializer.serialize((stage.rdd, stage.func): AnyRef))
partitions = stage.rdd.partitions
taskBinary = sc.broadcast(taskBinaryBytes)
} catch {
// In the case of a failure during serialization, abort the stage.
case e: NotSerializableException =>
abortStage(stage, "Task not serializable: " + e.toString, Some(e))
runningStages -= stage
// Abort execution
case NonFatal(e) =>
abortStage(stage, s"Task serialization failed: $e\n${Utils.exceptionString(e)}", Some(e))
runningStages -= stage
val tasks: Seq[Task[_]] = try {
val serializedTaskMetrics = closureSerializer.serialize(stage.latestInfo.taskMetrics).array()
stage match {
case stage: ShuffleMapStage =>
stage.pendingPartitions.clear() { id =>
val locs = taskIdToLocations(id)
val part = partitions(id)
stage.pendingPartitions += id
new ShuffleMapTask(, stage.latestInfo.attemptNumber,
taskBinary, part, locs, properties, serializedTaskMetrics, Option(jobId),
Option(sc.applicationId), sc.applicationAttemptId, stage.rdd.isBarrier())
case stage: ResultStage => { id =>
val p: Int = stage.partitions(id)
val part = partitions(p)
val locs = taskIdToLocations(id)
new ResultTask(, stage.latestInfo.attemptNumber,
taskBinary, part, locs, id, properties, serializedTaskMetrics,
Option(jobId), Option(sc.applicationId), sc.applicationAttemptId,
} catch {
case NonFatal(e) =>
abortStage(stage, s"Task creation failed: $e\n${Utils.exceptionString(e)}", Some(e))
runningStages -= stage
if (tasks.size > 0) {
logInfo(s"Submitting ${tasks.size} missing tasks from $stage (${stage.rdd}) (first 15 " +
s"tasks are for partitions ${tasks.take(15).map(_.partitionId)})")
taskScheduler.submitTasks(new TaskSet(
tasks.toArray,, stage.latestInfo.attemptNumber, jobId, properties))
} else {
// Because we posted SparkListenerStageSubmitted earlier, we should mark
// the stage as completed here in case there are no tasks to run
markStageAsFinished(stage, None)
stage match {
case stage: ShuffleMapStage =>
logDebug(s"Stage ${stage} is actually done; " +
s"(available: ${stage.isAvailable}," +
s"available outputs: ${stage.numAvailableOutputs}," +
s"partitions: ${stage.numPartitions})")
case stage : ResultStage =>
logDebug(s"Stage ${stage} is actually done; (partitions: ${stage.numPartitions})")