Fork/Join 分治编程
在 JDK 中并行执行框架 Fork-Join 使用了 “工作窃取(work-stealing)”算法,它是指某个线程从其他队列中窃取任务来执行。
比如要完成一个比较大的任务,完全可以把这个大的任务分割为若千互不依赖的子任务/小任务,为了更加方便地管理这些任务,于是把这些子任务分别放到不同的队列里,这时就会处理,完成任务的线程与其等着,不如去帮助其他线程分担要执行的任务,于是它就去其他线程的队列里窃取一一个任务来执行,这就是所谓的“工作窃取(work-stealing)” 算法。
工作窃取
ForkJoinPool与ThreadPoolExecutor有个很大的不同之处在于,ForkJoinPool存在引入了工作窃取设计,它是其性能保证的关键之一。工作窃取,就是允许空闲线程从繁忙线程的双端队列中窃取任务。默认情况下,工作线程从它自己的双端队列的头部获取任务。但是,当自己的任务为空时,线程会从其他繁忙线程双端队列的尾部中获取任务。这种方法,最大限度地减少了线程竞争任务的可能性。
ForkJoinPool的大部分操作都发生在工作窃取队列(work-stealingqueues)中,该队列由内部类WorkQueue实现。它是Deques的特殊形式,但仅支持三种操作方式:push、pop和poll(也称为窃取) 。在ForkJoinPool中,队列的读取有着严格的约束,push和pop仅能从其所属线程调用,而poll则可以从其他线程调用。
工作窃取的运行流程如下图所示:
- 工作窃取算法的优点是充分利用线程进行并行计算,并减少了线程间的竞争;
- 工作窃取算法缺点是在某些情况下还是存在竞争,比如双端队列里只有一个任务时。并且消耗了更多的系统资源,比如创建多个线程和多个双端队列。
为什么工作线程从队列头部获取,工作窃取从尾部窃取?
这样做的主要原因是为了提高性能,通过始终选择最近提交的任务,可以增加资源仍分配在CPU缓存中的机会,这样CPU处理起来要快一些。而窃取者之所以从尾部获取任务,则是为了降低线程之间的竞争可能,毕竟大家都从一个部分拿任务,竞争的可能要大很多。 此外,这样的设计还有一种考虑。由于任务是可分割的,那队列中较旧的任务最有可能粒度较大,因为它们可能还没有被分割,而空闲的线程则相对更有“精力”来完成这些粒度较大的任务。
分治算法
分治算法的基本思想是将一个规模为N的问题分解为K个规模较小的子问题,这些子问题相互独立且与原问题性质相同。求出子问题的解,就可得到原问题的解。即一种分目标完成程序算法,简单问题可用二分法完成。
分治法解题的一般步骤:
(1)分解,将要解决的问题划分成若干规模较小的同类问题;
(2)求解,当子问题划分得足够小时,用较简单的方法解决;
(3)合并,按原问题的要求,将子问题的解逐层合并构成原问题的解。
例子和代码可以参考后面的 ForkJoinPoolTest
中累计求和的例子, 注释中有每个步骤的代码开始部分。
ForkJoinPool 和 ForkJoinTask
ForkJoinPool
ForkJoinPool 是用来执行 ForJoinTask 任务的任务池,区别于线程池的 Worker + Queue 的组合,而是维护了一个队列数组 WorkQuque(WorkQuque[]) 在提交任务和线程任务的时候大幅度减少碰撞。
构造方法代码如下:
public ForkJoinPool(int parallelism, ForkJoinWorkerThreadFactory factory, UncaughtExceptionHandler handler, boolean asyncMode) { this(checkParallelism(parallelism), checkFactory(factory), handler, asyncMode ? FIFO_QUEUE : LIFO_QUEUE, "ForkJoinPool-" + nextPoolId() + "-worker-"); checkPermission(); } private ForkJoinPool(int parallelism, ForkJoinWorkerThreadFactory factory, UncaughtExceptionHandler handler, int mode, String workerNamePrefix) { this.workerNamePrefix = workerNamePrefix; this.factory = factory; this.ueh = handler; this.config = (parallelism & SMASK) | mode; long np = (long)(-parallelism); // offset ctl counts this.ctl = ((np << AC_SHIFT) & AC_MASK) | ((np << TC_SHIFT) & TC_MASK); }
参数的含义:
- parallelism 指定并行级别(parallelism level)。ForkJoin 将根据这个设定来决定工作线程的数量。如果没有设置将使用
Math.min(MAX_CAP, Runtime.getRuntime().availableProcessors())
其实也就是 cpu 核心线程数。
- factory ForkJoinPool 创建线程时,会通过factory 来创建,自定义需要实现
ForkJoinWorkerThreadFactory
接口,默认使用DefaultForkJoinWorkerThreadFactory
- handler 指定异常处理器,当任务在运行中出错,将由设定的 handler 进行处理
- mode 模式,设置队列工作模式:两种 FIFO_QUEUE, LIFO_QUEUE
- workerNamePrefix 线程名的前缀
ForkJoinTask
ForkJoinTask 是 ForkJoinPool 的核心之一,它是任务的实际载体,定义了执行时间的具体逻辑和拆分逻辑。
ForkJoinTask 继承了 Future 接口,也可以当作是一个轻量级的 Future.
ForkJoinTask 是一个抽象类,它的方法有很多,最核心的方法是 fork() 方法和 join 方法, 承载了主要的任务协调作用,一个用于任务提交,一个用于获取结果
- fork() 提交任务 : fork()方法用于向当前任务所运行的线程池中提交任务。如果当前线程是ForkJoinWorkerThread类型,将会放入该线程的工作队列,否则放入common线程池的工作队列中。
- join() 获取任务结果: join()方法用于获取任务的执行结果。调用join()时,将阻塞当前线程直到对应的子任务完成运行并返回结果。
通常情况下,我们不需要直接继承 ForkJoinTask 类, 而只需要继承它的子类,Fork/Join 框架提供了ForkJoinTask 的三个子类:
- RecursiveAction 用于递归执行且不需要返回结果的任务
- RecursiveTask 用于递归执行且返回结果的任务
- CountedCompleter:在执行完成任务后会触发一个自定义的钩子
ForkJoin 最适合纯粹的计算任务,也就是纯粹的函数计算,计算过程中都是独立运行的,没有外部数据/逻辑依赖。提交 ForkJoinPool 中的任务应该避免执行阻塞 I/O。
执行例子
通过实现 RecursiveTask
实现 int a -> b 的累加,具体的代码如下:
public class ForkJoinPoolTest { static class MyForkJoinTask extends RecursiveTask<Integer> { // 每个任务的任务量 private static final Integer MAX = 200; // 子任务开始计算的值 private Integer startValue; // 子任务结束计算的值 private Integer endValue; public MyForkJoinTask(Integer startValue, Integer endValue) { this.startValue = startValue; this.endValue = endValue; } @Override protected Integer compute() { //(2)求解,当子问题划分得足够小时,用较简单的方法解决; if (endValue - startValue < MAX) { System.out.println(Thread.currentThread().getName()+": 开始计算的部分:startValue = " + startValue + ";endValue = " + endValue); Integer totalValue = 0; for (int index = this.startValue; index <= this.endValue; index++) { totalValue += index; } return totalValue; } //(1)分解,将要解决的问题划分成若干规模较小的同类问题; else { MyForkJoinTask subTask1 = new MyForkJoinTask(startValue, (startValue + endValue) / 2); subTask1.fork(); MyForkJoinTask subTask2 = new MyForkJoinTask((startValue + endValue) / 2 + 1, endValue); subTask2.fork(); // (3)合并,按原问题的要求,将子问题的解逐层合并构成原问题的解。 return subTask1.join() + subTask2.join(); } } } public static void main(String[] args) throws ExecutionException, InterruptedException { ForkJoinPool forkJoinPool = new ForkJoinPool(); // 0 - 1000 累加 ForkJoinTask<Integer> task = forkJoinPool.submit(new MyForkJoinTask(0, 1000)); // 获取结果 Integer result = task.get(); // 结果打印 System.out.println(result); } }
输出结果如下:
ForkJoinPool-1-worker-3: 开始计算的部分:startValue = 501;endValue = 625 ForkJoinPool-1-worker-2: 开始计算的部分:startValue = 0;endValue = 125 ForkJoinPool-1-worker-0: 开始计算的部分:startValue = 751;endValue = 875 ForkJoinPool-1-worker-3: 开始计算的部分:startValue = 626;endValue = 750 ForkJoinPool-1-worker-0: 开始计算的部分:startValue = 876;endValue = 1000 ForkJoinPool-1-worker-2: 开始计算的部分:startValue = 126;endValue = 250 ForkJoinPool-1-worker-0: 开始计算的部分:startValue = 251;endValue = 375 ForkJoinPool-1-worker-2: 开始计算的部分:startValue = 376;endValue = 500 500500
结论:我们从线程名称就可以看到,我全部都是使用的默认参数,一共启动了 4 个线程。每次大概分到 125 次计算