【JAVA并发包源码分析】循环栅栏:CyclicBarrier

一、认识CyclicBarrier

对于CyclicBarrier大多数人感到陌生,其实CyclicBarrier是一种多线程并发控制使用工具,和CountDownLatch非常类似,实现线程之间的计数等待,也就是说一个线程或者多个线程等待其他线程完成任务,不过比CountDowwnLatch复杂。

CyclicBarrier是循环栅栏的意思,所谓栅栏就是障碍物,阻止其他人进入,在多线程中,使用该工具类就是阻止线程执行,那么它是怎么阻止的呢?下面会详细介绍。前面的Cyclic意为循环,也就是说可以循环使用该计数器。举个简单例子,比如有5个线程,那么该工具类就要等待这五个线程都到达指定的障碍点,执行完相应的动作后,计数器才会清零,等待下一批线程的到达。

下面我们来看看CyclicBarrier内部的构造以及类之间的依赖关系:
【JAVA并发包源码分析】循环栅栏:CyclicBarrier_第1张图片
上图是CyclicBarrier内部的部分代码,由上图可以画出该工具类的构造图如下:
【JAVA并发包源码分析】循环栅栏:CyclicBarrier_第2张图片

二、使用场景

对于该工具类使用场景也很丰富,这里用一个简单的实例来说明。比如,这里有10个士兵司令下达命令,要求这10个士兵先全部集合来报道,报道完成之后再一起去执行任务,当每一个士兵的任务完成之后然后才会向司令报告任务执行完毕。

  public CyclicBarrier(int parties, Runnable barrierAction) {
        if (parties <= 0) throw new IllegalArgumentException();
        this.parties = parties;
        this.count = parties;
        this.barrierCommand = barrierAction;
    }

对于上述的CyclicBarrier构造方法,它接收两个参数,第一个参数就是计数器总数,参与计数的线程总数,第二个参数barrierAction是一个Runnable接口,它是当一次计数完成之后要做的动作。
对于上述案例,我们来用代码演示该场景:

package cn.just.thread.concurrent;

import java.util.Random;
import java.util.concurrent.BrokenBarrierException;
import java.util.concurrent.CyclicBarrier;

/**
 * 测试循环栅栏:CycleBarrier(int parties,Runnable barrierAction);
 * 第一个参数表示计数的总数,即参与的线程总数
 * 第二个参数表示当一次计数完成后,系统会执行的动作
 * @author Shinelon
 *
 */
public class CycleBarrierDemo {
    public static class Soldier implements Runnable{
        private String soldier;
        private final CyclicBarrier cyclic;

        public Soldier(String soldier, CyclicBarrier cyclic) {
            super();
            this.soldier = soldier;
            this.cyclic = cyclic;
        }

        @Override
        public void run() {
            try{
                //等待所有士兵到齐
                cyclic.await();
                doWork();
                //等待所有士兵去工作
                cyclic.await();
            }catch (InterruptedException e) {
                e.printStackTrace();
            }catch (BrokenBarrierException e) {
                e.printStackTrace();
            }
        }

        private void doWork() {
            try{
                Thread.sleep(Math.abs(new Random().nextInt()%10000));
            }catch (InterruptedException e) {
                e.printStackTrace();
            }
            System.out.println(soldier+":任务完成!");
        }
    }

    public static class BarrierRun implements Runnable{
        boolean flag;
        int N;

        public BarrierRun(boolean flag, int n) {
            super();
            this.flag = flag;
            N = n;
        }

        @Override
        public void run() {
            if(flag){
                System.out.println("司令:【士兵"+N+"个,任务完成】");
            }else{
                System.out.println("司令:【士兵"+N+"个,集合完毕】");
                flag=true;
            }

        }
    }

    public static void main(String[] args) {
        final int N=10;
        Thread[] allSoldier=new Thread[N];
        boolean flag=false;
        CyclicBarrier cyclic=new CyclicBarrier(N, new BarrierRun(flag, N));
        //设置障碍点,主要是为了执行这个方法
        System.out.println("集合队伍");
        for(int i=0;i"士兵"+i+"报道!");
            allSoldier[i]=new Thread(new Soldier("士兵"+i, cyclic));
            allSoldier[i].start();
        }
    }
}

下面是运行结果:
【JAVA并发包源码分析】循环栅栏:CyclicBarrier_第3张图片

上面的代码中,涉及到一个该工具类的内部方法:
await()等待所有的线程计数完成。该方法内部调用dowait方法,在dowait方法中用重入锁进行加锁。实现了一次计数器的等待过程。下面我们来深入源码探究。

三、深入源码

上面说道dowait方法,下面是该方法的源码:

private int dowait(boolean timed, long nanos)
        throws InterruptedException, BrokenBarrierException,
               TimeoutException {
        final ReentrantLock lock = this.lock;
        lock.lock();
        try {
            //标志着每一个线程,当一个线程到来就生成一个新生代
            final Generation g = generation;
            //当计数器被破坏,抛出BrokenBarrierException异常
            if (g.broken)
                throw new BrokenBarrierException();
            //当线程被中断。抛出中断异常
            if (Thread.interrupted()) {
                breakBarrier();
                throw new InterruptedException();
            }
            //当一个线程到来时count减1,直到count为0则计数完成
            int index = --count;
            if (index == 0) {  // tripped
                boolean ranAction = false;
                try {
                    final Runnable command = barrierCommand;
                    if (command != null)
                        command.run();
                    ranAction = true;
                    //更新标志,唤醒所有等待线程
                    nextGeneration();
                    return 0;
                } finally {
                    //如果计数完成,唤醒所有等待的线程,计数器重新开始工作
                    if (!ranAction)
                        breakBarrier();
                }
            }

            // loop until tripped, broken, interrupted, or timed out
            for (;;) {
                try {
                    //如果当前线程没有超时则继续等待
                    if (!timed)
                        trip.await();
                      //如果调用超时,调用awaitNanos方法等待
                    else if (nanos > 0L)
                        nanos = trip.awaitNanos(nanos);
                } catch (InterruptedException ie) {
                    //如果所有线程都已经到达或者被中断则计数完成,进入下一次循环
                    if (g == generation && ! g.broken) {
                        breakBarrier();
                        throw ie;
                    } else {
                        // We're about to finish waiting even if we had not
                        // been interrupted, so this interrupt is deemed to
                        // "belong" to subsequent execution.
                        Thread.currentThread().interrupt();
                    }
                }

                if (g.broken)
                    throw new BrokenBarrierException();
                //如果不是同一个线程,则返回index
                if (g != generation)
                    return index;

                if (timed && nanos <= 0L) {
                    breakBarrier();
                    throw new TimeoutException();
                }
            }
        } finally {
            //释放锁
            lock.unlock();
        }
    }

解释一下上面的源代码,对于每一个线程,它都会有一个generation进行标志用来区分不同的线程(我是这样理解的),因为generation对象中有一个属性broken标志着是否该计数器被破坏或者计数是否完成,默认是false:

 private static class Generation {
        boolean broken = false;
    }

CyclicBarrier设置了两个异常,一个是BrokenBarrierException,另一个InterruptedException,InterruptedException异常相信大家都很熟悉,如果发生中断则抛出异常,BrokenBarrierException异常是当计数器被破坏的时候抛出。当一个线程来到的时候count-1,然后判断count是否为0,如果为零则计数完成,则执行下面相应的动作进入下一次的循环计数:

  final Runnable command = barrierCommand;
                    if (command != null)
                        command.run();
                    ranAction = true;
                    //更新标志
                    nextGeneration();
  /**
     * Updates state on barrier trip and wakes up everyone.
     * Called only while holding lock.
     */
 private void nextGeneration() {
        // signal completion of last generation
        trip.signalAll();
        // set up next generation
        count = parties;
        generation = new Generation();
    }

依据上面的场景我们可以理解,10个士兵执行任务,count为10,每次到来一个士兵则count-1,当10个士兵全部到来时则count为0,然后执行BarrierRun线程执行相应的动作。接着调用nextGeneration方法更新标志并且唤醒所有等待的线程继续向下执行。
在判断计数器是否完成一次计数时它调用breakBarrier()方法:

    /**
     * Sets current barrier generation as broken and wakes up everyone.
     * Called only while holding lock.
     */
    private void breakBarrier() {
        generation.broken = true;
        count = parties;
        trip.signalAll();
    }

这个方法同样会更新标志并且唤醒所有等待的线程。

在接下来的整个for循环中,判断了当前线程是否被中断,计数器是否被破坏,等待是否超时。

  1. 如果等待超时则调用awaitNanos方法继续等待,该方法时Contition接口的实现类的一个方法,让线程在合适的时间进行等待或者在特定的时间内得到通知,继续执行,该方法内部实现复杂,笔者能力有限,这里就不进行分析了。有兴趣的话可以自己查看源码。
    【JAVA并发包源码分析】循环栅栏:CyclicBarrier_第4张图片
  2. 它会判断所有线程是否都已经到达,如果所有线程已经执行完毕到达则进行下一次循环
//如果所有线程都已经到达或者被中断则计数完成,进入下一次循环
                    if (g == generation && ! g.broken) {
                        breakBarrier();
                        throw ie;
                    }

3.同时它也会判断是否是同一个线程,并且更新标志。

//如果不是同一个线程,则返回index
                if (g != generation)
                    return index;

当该线程的所有任务都执行完毕后它就会释放锁。

至此,本文已经介绍完CyclicBarrier工具类的介绍,本人能力有限,如有不足之处还请指教。多谢!
欢迎关注微信公众号:
【JAVA并发包源码分析】循环栅栏:CyclicBarrier_第5张图片


你可能感兴趣的:(Java,Java多线程)