正文
一、Fork/Join简介
简单的说,Fork/Join是一个并行任务执行框架,能够把一个大的任务拆分成若干个小任务,并行地进行执行,最终还可以汇总各个小任务的执行结果。比如我们想计算1+2+…+100的结果,我们可以把这个大的任务拆分为10个小的任务,这10个小任务分别是1+…+10、11+…+20、…91+…+100,然后最终把这10个小任务的结果再加起来得到大任务的结果。
工作窃取算法,是指大任务被分成多个小任务的时候,这些小任务会被分到不通的任务队列中,每个任务队列会有一个工作线程来执行任务。但是在有些时候,有的线程会执行的比较快,提前完成了所有任务的执行,那么它就会去别的线程的任务队列中窃取任务进行执行,从而加快整体任务的完成。
但是窃取线程如何保证不和被窃取任务的线程冲突呢?这里用到了双端队列,即任务队列都是这种数据结构,队列绑定的工作线程都从队列头部取任务进行执行,而窃取线程会从别的队列尾部获取任务进行执行。
工作窃取算法充分利用多线程进行并行计算,提高了执行效率,同时使用双端队列减少了线程间的冲突竞争;然后,不能完全避免冲突,比如某个任务队列中仅有一个任务的时候,两个线程同时竞争。还有,该算法会创建多个线程和多个双端队列,对系统资源的消耗会增加。
二、Fork/Join使用
在使用之前,我们需要记住两个概念:
ForkJoinTask,我们需要自己定义一个ForkJoinTask,用来定义我们需要执行的任务,它可以继承RecursiveAction或者RecursiveTask,从而获取fork和join操作的能力,前者表示不需要返回值,后者则是需要返回值,比如我们上面的1+2+…+100的累加操作则是需要返回值的,所以应该继承RecursiveTask
ForkJoinPool,类似于线程池,连接池的概念,ForkJoinTask都需要交给ForkJoinPool才能执行。
@Slf4j public class CountTask extends RecursiveTask<Integer> { // 设置最小子任务的阈值 private static final int taskLimit = 10; private int start; private int end; public CountTask(int start,int end){ this.start = start; this.end = end; } @Override protected Integer compute() { int sumResult = 0; boolean taskFlag = (end - start) <= taskLimit; if(taskFlag){ log.info("进行计算,当前start为{},end为{}",start,end); // 子任务足够小了,可以开始执行 for (int i=start;i<=end;i++){ sumResult += i; } } else { log.info("需要拆分:当前start为{},end为{}",start,end); // 子任务还不是足够小,需要进一步分割 int middle = (start + end)/2; CountTask leftTask = new CountTask(start,middle); CountTask rightTask = new CountTask(middle+1, end); // 分配任务 leftTask.fork(); rightTask.fork(); // 等待子任务执行完成,进行结果的合并 int leftResult = leftTask.join(); int rightResult = rightTask.join(); sumResult = leftResult + rightResult; } log.info("当前计算结果为:{}",sumResult); return sumResult; } }
我们在compute中定义任务的拆分粒度和最小任务的执行逻辑,并通过fork和join能力来实现多线程并发。当子任务调用fork的时候,会继续执行子任务的compute;当子任务调用join的时候,会等待其所有子孙任务的执行结果。
public static void main(String[] args) throws ExecutionException, InterruptedException { ForkJoinPool forkJoinPool = new ForkJoinPool(); // 计算从1累加到100,应该获取5050 CountTask myTask = new CountTask(1,100); ForkJoinTask<Integer> result = forkJoinPool.submit(myTask); if(myTask.isCompletedAbnormally()){ log.warn("发生了异常,{}", myTask.getException()); } log.info("1+2+...+100={}",result.get()); }