ForkJoin分治编程笔记

文章目录

  • Fork-Join
    • 类结构和基本使用
      • 结构
      • RecursiveAction
      • RecursiveTask
      • Fork-Join求和
    • ForkJoinPool核心方法
      • execute
      • submit
      • 其他方法
      • 监听ForkJoinPool池的状态

Fork-Join

类结构和基本使用

结构

通过使用ForkJoinPool类创建任务池,实现分治编程

public class ForkJoinPool extends AbstractExecutorService {
...
}
支持execute和submit等ExecutorService约定的方法

ForkJoinPool提供任务池,执行具体任务由ForkJoinTask类实现

ForkJoinTask类是抽象类,该类有3个子类

  • CountedCompleter
  • RecursiveAction
  • RecursiveTask

RecursiveAction,RecusiveTask,CounterCompleter等同于Runnable,Callable 用于执行任务

RecursiveAction

RecursiveAction该类执行的任务是具有无返回值的,仅执行一次的任务

实例:执行任务

public class ForkJoinTest {
    private static ForkJoinPool pool = new ForkJoinPool();
    public static void run() {
        pool.submit(new RecursiveActionTest());
    }

    static class RecursiveActionTest extends RecursiveAction {
        @Override
        protected void compute() {
            System.out.println("action运行");
        }
    }

    public static void main(String[] args) throws InterruptedException {
        run();
        Thread.sleep(3000);
    }
}

fork分解RecursiveAction任务

ForkJoinTask的fork方法每调用一次分离一次任务,且增加系统运行负担,此时通过

public static void invokeAll(ForkJoinTask t1, ForkJoinTask t2)该方法优化fork执行优化效率

利用invokeAll分离任务 (本质上使用了fork)

public class ForkJoinTest {
    private static ForkJoinPool pool = new ForkJoinPool();
    public static void run() {
        pool.submit(new RecursiveActionTest(0, 10));
    }

    static class RecursiveActionTest extends RecursiveAction {
        private int beginValue;
        private int endValue;
        public RecursiveActionTest(int beginValue, int endValue) {
            super();
            this.beginValue = beginValue;
            this.endValue = endValue;
        }
        @Override
        protected void compute() {
            String name = Thread.currentThread().getName();
            System.out.println(name + "==========");
            if (endValue - beginValue > 2) {
                int middle = (beginValue + endValue) / 2;
                RecursiveActionTest leftAction = new RecursiveActionTest(beginValue, middle);
                RecursiveActionTest rightAction = new RecursiveActionTest(middle + 1, endValue);
                RecursiveAction.invokeAll(leftAction, rightAction);
            } else {
                System.out.println("打印相邻的组合:" + beginValue + "=" + endValue);
            }
        }
    }
    public static void main(String[] args) throws InterruptedException {
        run();
        Thread.sleep(13000);
    }
}

invokeAll源码

public static void invokeAll(ForkJoinTask... tasks) {
    Throwable ex = null;
    int last = tasks.length - 1;
    for (int i = last; i >= 0; --i) {
        ForkJoinTask t = tasks[i];
        if (t == null) {
            if (ex == null)
                ex = new NullPointerException();
        }
        else if (i != 0)
            t.fork();
        else if (t.doInvoke() < NORMAL && ex == null)
            ex = t.getException();
    }
    for (int i = 1; i <= last; ++i) {
        ForkJoinTask t = tasks[i];
        if (t != null) {
            if (ex != null)
                t.cancel(false);
            else if (t.doJoin() < NORMAL)
                ex = t.getException();
        }
    }
    if (ex != null)
        rethrow(ex);
}

RecursiveTask

类RecursiveTask执行的任务具有返回值

static class RecursiveTaskTest extends RecursiveTask {
    @Override
    protected Integer compute() {
        System.out.println("");
        return -1;
    }
}
public static void runTask() {
    try {
        RecursiveTaskTest task = new RecursiveTaskTest();
        System.out.println("task的hashCode:" + task.hashCode());
        ForkJoinTask taskResult = pool.submit(task);
        System.out.println("taskResult的hashCode:" + taskResult.hashCode() + " " + taskResult.get());
    } catch (Exception e) {
        e.printStackTrace();
    }

}

public static void main(String[] args) throws InterruptedException {
    runTask();
    Thread.sleep(13000);
}
返回值:
task的hashCode:1681433494
taskResult的hashCode:1681433494 -1

注意RecursiveTask执行任务,调用ForkJoinPool的submit返回的ForkJoinTask结果对象为传入的RecursiveTask对象

此时除了使用ForkJoinTask的get方法取得返回值也可以通过join方法取得返回值

ForkJoinTask taskResult = pool.submit(task);
System.out.println("taskResult的hashCode:" + taskResult.hashCode() + " " + taskResult.join());

get与join的区别

使用get方法执行任务时,当子任务出现异常时可以在main主线程中进行捕获
方法join遇到异常直接抛出

RecursiveTask可以执行多任务

执行多任务时,任务之间的运行是异步的,但join和get方法是同步的
public static void runMultiTask() throws InterruptedException{
    RecursiveTask task = new RecursiveTask() {
        @Override
        protected Integer compute() {
            try {

                String name = Thread.currentThread().getName();
                System.out.println(name + ":begin");
                Thread.sleep(3000);
                System.out.println(name + ":end");
            } catch (InterruptedException e) {
                e.printStackTrace();
            }
            return 100;
        }
    };
    ForkJoinTask task1 = pool.submit(task);
    ForkJoinTask task2 = pool.submit(task);
    System.out.println(task1.join() + " A");
    System.out.println(task2.join() + " B");
}
public static void main(String[] args) throws InterruptedException {
    runMultiTask();
    Thread.sleep(13000);
}
打印
ForkJoinPool-1-worker-1:begin
ForkJoinPool-1-worker-2:begin
ForkJoinPool-1-worker-2:end
ForkJoinPool-1-worker-1:end
100 A
100 B

Fork-Join求和

字符串累加:

public class MyRecursiveTask extends RecursiveTask {

    private int beginValue;
    private int endValue;

    public MyRecursiveTask(int beginValue, int endValue) {
        this.beginValue = beginValue;
        this.endValue = endValue;
    }

    @Override
    protected String compute() {
        String name = Thread.currentThread().getName();
        System.out.println(name + ":begin");
        if (endValue - beginValue > 2) {
            int middleValue = (endValue + beginValue) / 2;
            MyRecursiveTask leftTask = new MyRecursiveTask(beginValue, middleValue);
            MyRecursiveTask rightTask = new MyRecursiveTask(middleValue + 1, endValue);
            RecursiveTask.invokeAll(leftTask, rightTask);
            return leftTask.join() + rightTask.join();
        } else {
            String resultStr = "";
            for (int i = beginValue; i <= endValue; i++) {
                resultStr += (i);
            }
            System.out.println("返回:" + resultStr + " " + beginValue + " " + endValue);
            return resultStr;
        }
    }
}
public class ForkJoinTest {
    private static ForkJoinPool pool = new ForkJoinPool();
    public static void runCountStr() {
        MyRecursiveTask task = new MyRecursiveTask(1, 5);
        ForkJoinTask task1 = pool.submit(task);
        System.out.println(task1.join());
    }

    public static void main(String[] args) throws InterruptedException {
        runCountStr();
        Thread.sleep(13000);
    }
}
打印
ForkJoinPool-1-worker-1:begin
ForkJoinPool-1-worker-1:begin
返回:123 1 3
ForkJoinPool-1-worker-1:begin
返回:45 4 5
12345

求和:

public class MyRecursiveTask extends RecursiveTask {

    private int beginValue;
    private int endValue;

    public MyRecursiveTask(int beginValue, int endValue) {
        this.beginValue = beginValue;
        this.endValue = endValue;
    }

    @Override
    protected Integer compute() {
       Integer sumValue = 0;
       if (beginValue != endValue) {
           int middleValue = (beginValue + endValue) / 2;
           MyRecursiveTask left = new MyRecursiveTask(beginValue, middleValue);
           MyRecursiveTask right = new MyRecursiveTask(middleValue + 1, endValue);
           RecursiveTask.invokeAll(left, right);
           return left.join() + right.join();
       } else {
           return endValue;
       }
    }
}

public class ForkJoinTest {
    private static ForkJoinPool pool = new ForkJoinPool();
    public static void runCountStr() {
        MyRecursiveTask task = new MyRecursiveTask(1, 5);
        ForkJoinTask task1 = pool.submit(task);
        System.out.println(task1.join());
    }

    public static void main(String[] args) throws InterruptedException {
        runCountStr();
        Thread.sleep(13000);
    }
}
打印
15相当于1+2+3+4+5

ForkJoinPool核心方法

execute

在ForkJoinPool中execute方法以异步的方式执行任务且没有返回值,但可以通过RecusiveTask对象处理返回值

public void execute(ForkJoinTask task|Runnable task)

forkJoinPool.execute(recursiveTask);
// 如此获得返回值
recursiveTask.get()

submit

方法submit有返回值,返回ForkJoinTask的实现类

public ForkJoinTask submit(ForkJoinTask task|Runnable task|Callable task|Runnable task, T result)

注意参数为Runnable任务时,返回值ForkJoinTask的get方法依然阻塞,只不过任务完成返回null
针对Runnable任务的返回值可以传入Runnable task, T result两个参数,其中result是返回值的封装,如此result依赖于task任务的执行
    此时submit返回值是ForkJoinTask,封装了传入的T result
    故而使用
    Future future = forkJoinPool.submit(runnable, result);
    result = future.get();
    如此同步获取返回值

其他方法

invoke

public T invoke(ForkJoinTask task)

方法execute,submit,invoke都可以在异步队列中执行任务,但方法invoke是阻塞的,直接将返回值进行返回,无需get

invokeAll
ForkJoinPool的invokeAll不同于ForkJoinTask的invokeAll分离任务,此invokeAll用于获取所有任务的返回结果

public List invokeAll(Collection tasks)

shutdown

ForkJoinPool执行shutdown再执行任务会出现异常,ForkJoinPool的shutdown方法不具有中断的效果,遇到sleep方法不会发生中断异常

通过isShutdown方法判断池是否销毁

而对于shutdownNow,销毁池并关闭正在执行的任务

此时shutdown后调用get方法不出现异常,而shutdownNow后调用get方法异常

注意shutdownNow内部使用了interrupt方法,遇到sleep方法抛出InterruptedException

awaitTermination

public boolean awaitTermination(long timeout, TimeUnit unit)

awaitTermination等待池销毁,有阻塞特性,配合shutdown使用

监听ForkJoinPool池的状态

getParallelism: 获得并行的数量,依赖CPU的内核数
getPoolSize: 获得任务池的大小
getQueuedSubmissionCount: 取得已经提交但尚未被执行的任务数量
hasQueuedSubmissions: 判断队列中是否有未执行的任务
getActiveThreadCount: 获得活动的线程个数
getQueuedTaskCount: 获得任务的总个数
getStealCount: 获得偷取的任务个数
getRunningThreadCount: 获得正在运行并且不再阻塞状态下的线程个数
isQuiescent: 判断任务池是否是静止未执行的状态

ForkJoinTask对象
通过isCompletedAbnormally判断任务是否出现异常,isCompletedNormally判断任务是否正常执行完毕,getException返回报错异常(定位每个任务的异常)

注意Fork-Join分治编程主要掌握ForkJoinTask的两个常用子类的fork分解算法

你可能感兴趣的:(java,ForkJoin,Java并发,多线程)