[Java Concurrency in Practice] CountDownLatch Source Code Analysis

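CountDownLatch is a synchronization aid that allows one or more threads to wait until a set of operations being performed in other threads completes. It is built on the shared mode of AQS (AbstractQueuedSynchronizer).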

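How it works: a CountDownLatch is initialized with a count, which must not be negative. Any thread that calls await blocks until other threads have called countDown enough times to bring the count from its initial value down to 0; at that point every thread blocked in await resumes. (If the count is 0 from the start, await returns immediately.) The count cannot be reset, so a latch is single-use; if you need something reusable, consider CyclicBarrier.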
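The rules in the paragraph above (non-negative count, countDown on a zero count is a no-op, an open latch never blocks again) can be checked directly against the public `CountDownLatch` API. The class `LatchBasics` and its method names are hypothetical, chosen only for this illustration.

```java
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;

// Hypothetical helper class illustrating CountDownLatch's one-shot rules.
final class LatchBasics {
    /** Counts down a latch of 2 three times; the surplus call is a no-op. */
    static long countAfterExtraCountDown() {
        CountDownLatch latch = new CountDownLatch(2);
        latch.countDown();
        latch.countDown();      // 2 -> 1 -> 0: any waiters are released here
        latch.countDown();      // already 0: nothing happens, the count stays 0
        return latch.getCount();
    }

    /** A latch whose count is 0 (here: from the start) never blocks. */
    static boolean awaitOnOpenLatch() {
        CountDownLatch latch = new CountDownLatch(0);       // 0 is legal: the latch starts open
        try {
            return latch.await(0, TimeUnit.MILLISECONDS);   // true: the count is already 0
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            return false;
        }
    }

    /** Only negative counts are rejected by the constructor. */
    static boolean negativeCountRejected() {
        try {
            new CountDownLatch(-1);
            return false;
        } catch (IllegalArgumentException e) {
            return true;
        }
    }
}
```

There is no way to raise the count again, which is exactly why a latch cannot be reused.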

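Sample usage 1: below are two classes in which a group of worker threads uses two countdown latches:
The first latch is a start signal that prevents any worker from proceeding until the driver is ready for them to proceed. The second is a completion signal that allows the driver to wait until all workers have completed.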

 class Driver { // ...
   void main() throws InterruptedException {
     CountDownLatch startSignal = new CountDownLatch(1);
     CountDownLatch doneSignal = new CountDownLatch(N);

     for (int i = 0; i < N; ++i) // create and start threads
       new Thread(new Worker(startSignal, doneSignal)).start();

     doSomethingElse();            // don't let run yet
     startSignal.countDown();      // let all threads proceed
     doSomethingElse();
     doneSignal.await();           // wait for all to finish
   }
 }

 class Worker implements Runnable {
   private final CountDownLatch startSignal;
   private final CountDownLatch doneSignal;
   Worker(CountDownLatch startSignal, CountDownLatch doneSignal) {
     this.startSignal = startSignal;
     this.doneSignal = doneSignal;
   }
   public void run() {
     try {
       startSignal.await();
       doWork();
       doneSignal.countDown();
     } catch (InterruptedException ex) {
       Thread.currentThread().interrupt(); // restore the interrupt flag, then return
     }
   }

   void doWork() { ... }
 }
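The start-signal/done-signal pair above is often used as a timing harness: release all workers at the same instant, then measure until the last one finishes (a similar harness appears in *Java Concurrency in Practice*). The sketch below is a runnable version under that assumption; `StartEndGates`, `timeTasks` and `runAndCount` are hypothetical names.

```java
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicInteger;

// Hypothetical harness built from a start gate and an end gate.
final class StartEndGates {
    /** Runs task on nThreads threads that all start together; returns elapsed nanoseconds. */
    static long timeTasks(int nThreads, Runnable task) throws InterruptedException {
        CountDownLatch startGate = new CountDownLatch(1);        // plays the role of startSignal
        CountDownLatch endGate = new CountDownLatch(nThreads);   // plays the role of doneSignal
        for (int i = 0; i < nThreads; i++) {
            new Thread(() -> {
                try {
                    startGate.await();        // block until the driver opens the gate
                    task.run();
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                } finally {
                    endGate.countDown();      // always count down, so the driver cannot hang
                }
            }).start();
        }
        long start = System.nanoTime();
        startGate.countDown();                // release every worker at once
        endGate.await();                      // wait until all workers have finished
        return System.nanoTime() - start;
    }

    /** Hypothetical helper: runs a counting task and reports how many times it ran. */
    static int runAndCount(int nThreads) throws InterruptedException {
        AtomicInteger ran = new AtomicInteger();
        timeTasks(nThreads, ran::incrementAndGet);
        return ran.get();
    }
}
```

Unlike the Javadoc sample, `countDown()` sits in a `finally` block here, so a worker that is interrupted or whose task throws still decrements the end gate.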

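Another typical usage is to divide a problem into N parts, describe each part with a Runnable that executes that portion and counts down on the latch, and queue all the Runnables to an Executor. When all sub-parts are complete, the coordinating thread will be able to pass through await.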

 class Driver2 { // ...
   void main() throws InterruptedException {
     CountDownLatch doneSignal = new CountDownLatch(N);
     Executor e = ...

     for (int i = 0; i < N; ++i) // create and start threads
       e.execute(new WorkerRunnable(doneSignal, i));

     doneSignal.await();           // wait for all to finish
   }
 }

 class WorkerRunnable implements Runnable {
   private final CountDownLatch doneSignal;
   private final int i;
   WorkerRunnable(CountDownLatch doneSignal, int i) {
     this.doneSignal = doneSignal;
     this.i = i;
   }
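   public void run() {
     try {
       doWork(i);
       doneSignal.countDown();
     } catch (InterruptedException ex) {
       Thread.currentThread().interrupt(); // restore the interrupt flag, then return
     }
   }

   void doWork(int i) throws InterruptedException { ... }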
 }
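To make the divide-into-N-parts pattern concrete, here is a runnable sketch that sums 1..n in `parts` slices, one Runnable per slice. The class `ParallelSum` and the choice of a fixed thread pool as the Executor are assumptions for illustration; any Executor would do.

```java
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

// Hypothetical example of the "one Runnable per part" pattern.
final class ParallelSum {
    /** Sums 1..n by splitting the range into `parts` interleaved slices. */
    static long sum(int n, int parts) throws InterruptedException {
        long[] partial = new long[parts];                        // each worker writes only its own slot
        CountDownLatch doneSignal = new CountDownLatch(parts);
        ExecutorService e = Executors.newFixedThreadPool(parts); // assumed Executor choice
        try {
            for (int i = 0; i < parts; i++) {
                final int part = i;
                e.execute(() -> {
                    try {
                        // slice `part` handles k = part+1, part+1+parts, part+1+2*parts, ...
                        for (int k = part + 1; k <= n; k += parts) {
                            partial[part] += k;
                        }
                    } finally {
                        doneSignal.countDown();                  // this part is done
                    }
                });
            }
            doneSignal.await();                                  // wait for every part
        } finally {
            e.shutdown();
        }
        long total = 0;
        for (long p : partial) total += p;
        return total;
    }
}
```

Reading `partial` after `await()` is safe without extra locking: actions in a thread before it calls `countDown()` happen-before actions following a successful return from `await()` in another thread.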

Now look at the CountDownLatch source. It contains a static inner class, Sync, that extends AQS.
How is a CountDownLatch constructed?

CountDownLatch startSignal = new CountDownLatch(1);

/**
 * Constructs a {@code CountDownLatch} initialized with the given count.
 *
 * @param count the number of times {@link #countDown} must be invoked
 *        before threads can pass through {@link #await}
 * @throws IllegalArgumentException if {@code count} is negative
 */
public CountDownLatch(int count) {
	//the initial count must not be negative
    if (count < 0) throw new IllegalArgumentException("count < 0");
    this.sync = new Sync(count);
}

public class CountDownLatch {
    /**
     * Synchronization control For CountDownLatch.
     * Uses AQS state to represent count.
     */
    private static final class Sync extends AbstractQueuedSynchronizer {
        private static final long serialVersionUID = 4982264981922014374L;

        Sync(int count) {
            setState(count);
        }
		...

    }

    private final Sync sync;

}
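The structure above (the count stored in the AQS state, with Sync overriding the shared-mode hooks) is enough to build a working latch. The sketch below is a hypothetical `SimpleLatch`, not the JDK class; its `tryAcquireShared`/`tryReleaseShared` follow the same logic as CountDownLatch.Sync analyzed later in this article.

```java
import java.util.concurrent.locks.AbstractQueuedSynchronizer;

/** Hypothetical minimal latch: the AQS state holds the remaining count. */
final class SimpleLatch {
    private static final class Sync extends AbstractQueuedSynchronizer {
        Sync(int count) { setState(count); }            // state == remaining count

        int count() { return getState(); }

        @Override
        protected int tryAcquireShared(int ignored) {
            return getState() == 0 ? 1 : -1;            // acquire succeeds only at 0
        }

        @Override
        protected boolean tryReleaseShared(int ignored) {
            for (;;) {                                   // CAS loop: countDown calls may race
                int c = getState();
                if (c == 0) return false;                // already open: nothing to release
                int next = c - 1;
                if (compareAndSetState(c, next)) {
                    return next == 0;                    // wake waiters only on the 1 -> 0 step
                }
            }
        }
    }

    private final Sync sync;

    SimpleLatch(int count) {
        if (count < 0) throw new IllegalArgumentException("count < 0");
        this.sync = new Sync(count);
    }

    void await() throws InterruptedException { sync.acquireSharedInterruptibly(1); }

    void countDown() { sync.releaseShared(1); }

    long getCount() { return sync.count(); }
}
```

All the queueing, parking and waking is inherited from AQS; the subclass only decides when acquire and release succeed.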

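Let's analyze an example: initialize a CountDownLatch with a count of 10, then start ten threads that each call await() and block; then start five threads that together call countDown() ten times (twice each), releasing the count and resuming the ten threads started earlier.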

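import java.lang.reflect.Field;
import java.util.ArrayList;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.locks.AbstractQueuedSynchronizer;

/**
 * Created by Tangwz on 2019/6/25
 */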
public class TestCountDownLatch {
    private static CountDownLatch countDownLatch = new CountDownLatch(10);

    private static class Thread1 extends Thread {
        public Thread1(int i) {
            super("Thread" + i);
        }

        @Override
        public void run() {
            try {
                countDownLatch.await();
                System.out.println(Thread.currentThread().getName() + " resumed");
            } catch (InterruptedException e) {
                e.printStackTrace();
            }
        }
    }

    private static class Thread2 extends Thread {
        @Override
        public void run() {
            countDownLatch.countDown();
            countDownLatch.countDown();
        }
    }

    public static void main(String[] args) throws InterruptedException, IllegalAccessException, 
    		NoSuchFieldException {
        for (int i = 0; i < 10; i++) {
            Thread1 thread1 = new Thread1(i);
            thread1.start();
            //make sure the threads enter the queue in order 0-9
            TimeUnit.MILLISECONDS.sleep(100);
        }
        //print the thread names of the nodes in the sync queue
        printThreads();
        TimeUnit.SECONDS.sleep(1);
        System.out.println("开始唤醒线程");
        for (int i = 0; i < 5; i++) {
            Thread2 thread2 = new Thread2();
            thread2.start();
        }
    }

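    private static void printThreads() throws NoSuchFieldException, IllegalAccessException {
        //reads the private field CountDownLatch.sync; on JDK 16 and later this reflective access
        //typically requires running with --add-opens java.base/java.util.concurrent=ALL-UNNAMED
        Field sync = CountDownLatch.class.getDeclaredField("sync");
        sync.setAccessible(true);
        AbstractQueuedSynchronizer aqs = (AbstractQueuedSynchronizer) sync.get(countDownLatch);
        //getQueuedThreads() walks the queue from tail to head, so iterate backwards to print in queue order
        ArrayList<Thread> threads = new ArrayList<>(aqs.getQueuedThreads());
        for (int i = threads.size(); i > 0; i--) {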
            System.out.println(threads.get(i - 1).getName());
        }
    }
}

[Figure: state changes as threads 0, 1 and 2 enter the sync queue in turn]
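How does a call to countDownLatch.await() end up blocking?
Note: some of the AQS methods below are also used by other concurrency utilities, and not every branch matters for CountDownLatch, so the step-by-step analysis only covers the cases this class actually exercises.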

/**
 * Causes the current thread to wait until the latch has counted down to
 * zero, unless the thread is {@linkplain Thread#interrupt interrupted}.
 */
public void await() throws InterruptedException {
	//delegate to the inner Sync: acquire in shared mode, aborting if interrupted
    sync.acquireSharedInterruptibly(1);
}

/**
 * Acquires in shared mode, aborting if interrupted.  Implemented
 * by first checking interrupt status, then invoking at least once
 * {@link #tryAcquireShared}, returning on success.  Otherwise the
 * thread is queued, possibly repeatedly blocking and unblocking,
 * invoking {@link #tryAcquireShared} until success or the thread
 * is interrupted.
 * @param arg the acquire argument.
 * This value is conveyed to {@link #tryAcquireShared} but is
 * otherwise uninterpreted and can represent anything
 * you like.
 * @throws InterruptedException if the current thread is interrupted
 */
public final void acquireSharedInterruptibly(int arg)
        throws InterruptedException {
    if (Thread.interrupted())
        throw new InterruptedException();
    //until countDown() has brought state to 0, every thread that gets here fails to acquire
    if (tryAcquireShared(arg) < 0)
    	//acquire using a shared, interruptible queue node
        doAcquireSharedInterruptibly(arg);
}

protected int tryAcquireShared(int acquires) {
	//returns 1 only when the state is 0, otherwise -1
    return (getState() == 0) ? 1 : -1;
}
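Note the order of checks in `acquireSharedInterruptibly`: the interrupt status is tested before `tryAcquireShared`, via `Thread.interrupted()`, which also clears the flag. So an already-interrupted thread gets an `InterruptedException` even from a latch that is already open. The sketch below (hypothetical class `AwaitInterruptDemo`) checks this on the calling thread; it assumes the AQS check order shown above, which current JDKs share.

```java
import java.util.concurrent.CountDownLatch;

// Hypothetical demo: await() checks the interrupt flag before looking at the count.
final class AwaitInterruptDemo {
    /**
     * Returns true if await() on an open latch threw InterruptedException
     * and the interrupt flag was cleared in the process.
     */
    static boolean interruptedBeforeOpenCheck() {
        CountDownLatch open = new CountDownLatch(0);   // count already 0
        Thread.currentThread().interrupt();            // set our own interrupt flag
        try {
            open.await();                              // interrupt status is checked first
            Thread.interrupted();                      // not expected; clear the flag anyway
            return false;
        } catch (InterruptedException e) {
            // Thread.interrupted() inside AQS cleared the flag before throwing
            return !Thread.currentThread().isInterrupted();
        }
    }
}
```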

/**
 * Acquires in shared interruptible mode.
 * @param arg the acquire argument
 */
private void doAcquireSharedInterruptibly(int arg)
    throws InterruptedException {
    //all ten threads get here; each creates a node marked SHARED and appends it to the tail
    //follow the numbered steps below in order
    final Node node = addWaiter(Node.SHARED);
    boolean failed = true;
    try {
        for (;;) {
            final Node p = node.predecessor();
            if (p == head) {
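            	//1. thread0 arrives for the first time, fails to acquire, r = -1
            	//3. thread0 still fails to acquire
            	//5. threads 1-9 likewise fail and are appended to the tail one after another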
                int r = tryAcquireShared(arg);
                if (r >= 0) {
                    setHeadAndPropagate(node, r);
                    p.next = null; // help GC
                    failed = false;
                    return;
                }
            }
            //2. thread0 sets the head node's waitStatus to -1 (SIGNAL) and loops again
            //4. thread0 then parks here, waiting to be woken by countDown
            if (shouldParkAfterFailedAcquire(p, node) &&
                parkAndCheckInterrupt())
                throw new InterruptedException();
        }
    } finally {
        if (failed)
            cancelAcquire(node);
    }
}
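Every thread that fails `tryAcquireShared` ends up as a node in the AQS queue, which the article's `printThreads()` inspects via reflection. The same effect can be observed without reflection using AQS's public `getQueueLength()` on a custom synchronizer. `QueueDemo` and its `Gate` are hypothetical; the polling loop with a 5-second deadline is an assumption to keep the sketch deterministic.

```java
import java.util.concurrent.locks.AbstractQueuedSynchronizer;

// Hypothetical demo: failed shared acquires are queued, and one release lets them all through.
final class QueueDemo {
    /** Latch-like synchronizer: closed while state == 1, open once state == 0. */
    static final class Gate extends AbstractQueuedSynchronizer {
        Gate() { setState(1); }
        @Override protected int tryAcquireShared(int a) { return getState() == 0 ? 1 : -1; }
        @Override protected boolean tryReleaseShared(int a) { setState(0); return true; }
    }

    /** Starts n waiters and returns how many were queued before the gate opened. */
    static int queuedThenRelease(int n) throws InterruptedException {
        Gate gate = new Gate();
        Thread[] ts = new Thread[n];
        for (int i = 0; i < n; i++) {
            ts[i] = new Thread(() -> {
                try {
                    gate.acquireSharedInterruptibly(1);   // fails, enqueues a SHARED node, parks
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                }
            });
            ts[i].start();
        }
        long deadline = System.nanoTime() + 5_000_000_000L;
        while (gate.getQueueLength() < n && System.nanoTime() < deadline) {
            Thread.sleep(1);                              // wait until every waiter is queued
        }
        int seen = gate.getQueueLength();
        gate.releaseShared(1);                            // the wake-up propagates through the queue
        for (Thread t : ts) {
            t.join(5000);
            if (t.isAlive()) throw new IllegalStateException("waiter was not released");
        }
        return seen;
    }
}
```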

Next, let's see step by step how the five threads calling countDownLatch.countDown() eventually release the lock and wake the shared nodes.

/**
 * Decrements the count of the latch, releasing all waiting threads if
 * the count reaches zero.
 *
 * <p>If the current count is greater than zero then it is decremented.
 * If the new count is zero then all waiting threads are re-enabled for
 * thread scheduling purposes.
 *
 * <p>If the current count equals zero then nothing happens.
 */
public void countDown() {
	//each call decrements the count by exactly one; reaching 0 means the lock is released and waiting nodes must be woken
    sync.releaseShared(1);
}

/**
 * Releases in shared mode.  Implemented by unblocking one or more
 * threads if {@link #tryReleaseShared} returns true.
 *
 * @param arg the release argument.  This value is conveyed to
 *        {@link #tryReleaseShared} but is otherwise uninterpreted
 *        and can represent anything you like.
 * @return the value returned from {@link #tryReleaseShared}
 */
public final boolean releaseShared(int arg) {
	//calls CountDownLatch.Sync.tryReleaseShared()
    if (tryReleaseShared(arg)) {
    	//exactly one thread ever gets here:
    	//the one whose compareAndSetState(1, 0) succeeds
        doReleaseShared();
        return true;
    }
    return false;
}

protected boolean tryReleaseShared(int releases) {
    // Decrement count; signal when transition to zero
    //several threads may release at the same time, so concurrency must be handled
    for (;;) {
        int c = getState();
        if (c == 0)
        	//the state must never drop below 0; if it is already 0, report that the release failed
            return false;
        int nextc = c-1;
        if (compareAndSetState(c, nextc))
        	//return true only if this decrement brought the state to 0
            return nextc == 0;
    }
}

/**
 * Release action for shared mode -- signals successor and ensures
 * propagation. (Note: For exclusive mode, release just amounts
 * to calling unparkSuccessor of head if it needs signal.)
 */
private void doReleaseShared() {
    /*
     * Ensure that a release propagates, even if there are other
     * in-progress acquires/releases.  This proceeds in the usual
     * way of trying to unparkSuccessor of head if it needs
     * signal. But if it does not, status is set to PROPAGATE to
     * ensure that upon release, propagation continues.
     * Additionally, we must loop in case a new node is added
     * while we are doing this. Also, unlike other uses of
     * unparkSuccessor, we need to know if CAS to reset status
     * fails, if so rechecking.
     */
    //this method wakes the next node and propagates the Node.PROPAGATE status
    //first follow the simplest path for this example: the last thread finishes countDown, then nodes 0-9 are woken one after another
    //in practice they are not woken strictly in sequence: running the example prints the thread names out of order rather than 0-9; the reason is analyzed later
    for (;;) {
    	//h is the head node, whose thread field is null
        Node h = head;
        //1. the queue is not empty (it holds ten nodes), so this check passes
        //6. the woken thread0 also runs this loop and wakes the next node, thread1; thread1 then wakes thread2, and so on
        if (h != null && h != tail) {
            int ws = h.waitStatus;
            //2. the head node's status is -1 (SIGNAL)
            if (ws == Node.SIGNAL) {
            	//3. set the head node's status to 0; the CAS-failure case is analyzed later
                if (!compareAndSetWaitStatus(h, Node.SIGNAL, 0))
                    continue;            // loop to recheck cases
                //4. wake the next node, thread0
                unparkSuccessor(h);
            }
            else if (ws == 0 &&
                     //for the CAS-failure case and what Node.PROPAGATE is for, see the later article on the Semaphore source
                     !compareAndSetWaitStatus(h, 0, Node.PROPAGATE))
                continue;                // loop on failed CAS
        }
        if (h == head)                   // loop if head changed
        	//5. exit the loop
            break;
    }
}
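The claim that exactly one countDown caller reaches `doReleaseShared()` follows from the CAS loop in `tryReleaseShared`: only the caller whose CAS performs the 1 to 0 step gets `true`, and calls on a zero count return `false`. The sketch below replays that loop on an `AtomicInteger` with many racing threads; `CasCountDown` and `winners` are hypothetical names, and the AtomicInteger stands in for the AQS state.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;

// Hypothetical replica of Sync.tryReleaseShared's CAS loop on an AtomicInteger.
final class CasCountDown {
    private final AtomicInteger state;

    CasCountDown(int count) { state = new AtomicInteger(count); }

    /** True only for the caller that performs the 1 -> 0 transition. */
    boolean tryRelease() {
        for (;;) {
            int c = state.get();
            if (c == 0) return false;                      // never go below zero
            int next = c - 1;
            if (state.compareAndSet(c, next)) return next == 0;
            // lost the CAS to a concurrent caller: reread and retry
        }
    }

    /** `threads` threads each call tryRelease `perThread` times; returns how many calls saw true. */
    static int winners(int count, int threads, int perThread) throws InterruptedException {
        CasCountDown latch = new CasCountDown(count);
        AtomicInteger trueResults = new AtomicInteger();
        List<Thread> ts = new ArrayList<>();
        for (int i = 0; i < threads; i++) {
            Thread t = new Thread(() -> {
                for (int k = 0; k < perThread; k++) {
                    if (latch.tryRelease()) trueResults.incrementAndGet();
                }
            });
            ts.add(t);
            t.start();
        }
        for (Thread t : ts) t.join();
        return trueResults.get();
    }
}
```

With the article's numbers, `winners(10, 5, 2)` returns 1; over-calling, e.g. `winners(10, 8, 5)` with 40 calls against a count of 10, still yields exactly one `true`.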

All threads blocked in countDownLatch.await() now need to be woken:

private void doAcquireSharedInterruptibly(int arg)
    throws InterruptedException {
    final Node node = addWaiter(Node.SHARED);
    boolean failed = true;
    try {
        for (;;) {
            final Node p = node.predecessor();
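            //2. the predecessor is the head node
            if (p == head) {
            	//3. state has already been decremented to 0, so r = 1
                int r = tryAcquireShared(arg);
                if (r >= 0) {
                	//3 (cont.). having acquired in shared mode, set the head and wake the next node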
                    setHeadAndPropagate(node, r);
                    p.next = null; // help GC
                    failed = false;
                    return;
                }
            }
            if (shouldParkAfterFailedAcquire(p, node) &&
            	//1. thread0 is woken first and was not interrupted
                parkAndCheckInterrupt())
                throw new InterruptedException();
        }
    } finally {
        if (failed)
            cancelAcquire(node);
    }
}

/**
 * Sets head of queue, and checks if successor may be waiting
 * in shared mode, if so propagating if either propagate > 0 or
 * PROPAGATE status was set.
 *
 * @param node the node
 * @param propagate the return value from a tryAcquireShared
 */
private void setHeadAndPropagate(Node node, int propagate) {
    Node h = head; // Record old head for check below
    //4. thread0 enters and sets itself as the head node
    setHead(node);
    /*
     * Try to signal next queued node if:
     *   Propagation was indicated by caller,
     *     or was recorded (as h.waitStatus either before
     *     or after setHead) by a previous operation
     *     (note: this uses sign-check of waitStatus because
     *      PROPAGATE status may transition to SIGNAL.)
     * and
     *   The next node is waiting in shared mode,
     *     or we don't know, because it appears null
     *
     * The conservatism in both of these checks may cause
     * unnecessary wake-ups, but only when there are multiple
     * racing acquires/releases, so most need signals now or soon
     * anyway.
     */
    //5. propagate = 1
    if (propagate > 0 || h == null || h.waitStatus < 0 ||
        (h = head) == null || h.waitStatus < 0) {
        //6. the next node after thread0 is thread1
        Node s = node.next;
        //7. thread1 is a shared node
        if (s == null || s.isShared())
        	//8. wake the successor, thread1
            doReleaseShared();
    }
}
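Because each woken shared node calls `setHeadAndPropagate()` and then `doReleaseShared()`, a single successful release ripples through the whole queue. The sketch below (hypothetical class `PropagationDemo`) shows the observable effect with the public API: one `countDown()` on a latch of count 1 releases every waiter. Threads that reach `await()` only after the countDown simply pass straight through; either way all n get through.

```java
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;

// Hypothetical demo: one release on a shared-mode latch lets every waiter through.
final class PropagationDemo {
    /** n threads await a latch of count 1; returns how many got through after one countDown(). */
    static int releasedByOneCountDown(int n) throws InterruptedException {
        CountDownLatch gate = new CountDownLatch(1);
        CountDownLatch released = new CountDownLatch(n);   // counts waiters that passed the gate
        for (int i = 0; i < n; i++) {
            new Thread(() -> {
                try {
                    gate.await();
                    released.countDown();
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                }
            }).start();
        }
        gate.countDown();                                  // the single release
        released.await(5, TimeUnit.SECONDS);               // bounded wait for all waiters
        return (int) (n - released.getCount());
    }
}
```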

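We noted earlier that threads 0-9 are not woken strictly in order; the reason lies in doReleaseShared().
Consider this possible interleaving: thread N successfully wakes thread0, and thread0 then calls setHeadAndPropagate(), which in turn calls doReleaseShared(). Now thread N (the one that woke thread0) and thread0 are racing in the same loop.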

private void doReleaseShared() {
    for (;;) {
    	//4. thread N gets here, and so does thread0; h is thread0's node
        Node h = head;
        if (h != null && h != tail) {
            int ws = h.waitStatus;
            if (ws == Node.SIGNAL) {
            	//5. thread0 gets there first and CASes its node's status from -1 to 0, so thread N's CAS fails
                if (!compareAndSetWaitStatus(h, Node.SIGNAL, 0))
                	//6. thread N loops again
                    continue;            // loop to recheck cases
                //1. thread N wakes thread0
                //(when thread0 itself reaches this line, thread1 is woken)
                unparkSuccessor(h);
            }
            //7. thread N sees that thread0's status is 0
            else if (ws == 0 &&
            		 //8. change thread0's status to Node.PROPAGATE
            		 //Why PROPAGATE? In my view CountDownLatch would work fine if this step were skipped:
            		 //its count cannot be reused, so no thread ever modifies the state again and there is no contention here
            		 //Semaphore, however, does rely on it
                     !compareAndSetWaitStatus(h, 0, Node.PROPAGATE))
                continue;                // loop on failed CAS
        }
        //2. thread0 calls setHead(node) inside setHeadAndPropagate and becomes the head node
        //3. thread N sees that head is no longer h, so it does not exit the loop
        //9. if the woken thread1 also calls setHead(node) inside setHeadAndPropagate,
        //9. then thread N, thread0 and thread1 will all loop through doReleaseShared() again
        if (h == head)                   // loop if head changed
            break;
    }
}
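The out-of-order output described above can be captured instead of printed. This sketch (hypothetical class `WakeOrder`) mirrors the article's test: waiters named Thread0..Thread(n-1) are started roughly in order, the latch is opened once, and each thread appends its name to a thread-safe queue as it resumes. The recorded order mixes the doReleaseShared races with ordinary OS scheduling, so it is not a direct trace of `unparkSuccessor`; the only guarantee is that every waiter appears exactly once.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.CountDownLatch;

// Hypothetical demo: record the order in which latch waiters resume.
final class WakeOrder {
    /** Returns the names of n waiters in the order they resumed after the latch opened. */
    static List<String> record(int n) throws InterruptedException {
        CountDownLatch latch = new CountDownLatch(1);
        ConcurrentLinkedQueue<String> order = new ConcurrentLinkedQueue<>();
        List<Thread> ts = new ArrayList<>();
        for (int i = 0; i < n; i++) {
            Thread t = new Thread(() -> {
                try {
                    latch.await();
                    order.add(Thread.currentThread().getName());   // thread-safe append
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                }
            }, "Thread" + i);
            ts.add(t);
            t.start();
            Thread.sleep(20);   // as in the article: try to enqueue the waiters in order
        }
        latch.countDown();
        for (Thread t : ts) t.join();
        return new ArrayList<>(order);
    }
}
```

Running it several times may give different orders; a test should only check that the set of names is complete.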

CyclicBarrier is similar to CountDownLatch; for the differences, see the CyclicBarrier source analysis.
