不就是CountDownLatch

介绍CountDownLatch之前,我相信很多人在学习的时候是不清楚这个CountDownLatch的使用场景是啥。为了回答这个问题,简单说个小段子。
老李家有两个熊孩子小A和小B,小A和小B每天放学后自己回家,到家后都需要老李来开门,不要问我为啥不给小A和小B一把钥匙。由于不是一个年级的,放学的时间不同,每天都需要老李开两次门,有一天老李怒了,告诉两个熊孩子,以后到家了必须敲下门,在门口喊一声,老李听到两个孩子的敲门声再去敲门,不要问我小A和小B是亲生的不。
其实,上面这个例子就是CountDownLatch的使用场景,小A和小B到家时间不同相当于两个线程的执行时间不同,小A和小B每次回家必须喊一次相当于线程间的通信,老李只有听到两个孩子的敲门声才会去敲门相当于主线程不再阻塞,向下进行。

再举个最近项目中的使用场景。
最近在做图像识别的一个项目,需要上传图片到华为云的modelart服务来获取图片的识别信息,然后对返回信息进行处理,分析出想要的信息。

由于有些产品是需要同时上传两张图片,然后再根据返回的信息进行处理。上传一张图片等待返回信息这个过程的时间大概是3-5秒,上传两张图片,需要访问两次华为云modelart服务,如果使用串行方式的话,那么需要花费10s左右的时间,这里就想到了可以使用CountDownLatch,等待这两个上传操作的线程结束拿到返回信息后,再调用后面的接口来分析这两个图片的信息。

这里,就简单介绍完了CountDownLatch的使用场景,下面简单说下CountDownLatch的使用,直接给出CountDownLatch源码中的例子。

* class Driver2 { // ...
 *   void main() throws InterruptedException {
 *     CountDownLatch doneSignal = new CountDownLatch(N);
 *     Executor e = ...
 *
 *     for (int i = 0; i < N; ++i) // create and start threads
 *       e.execute(new WorkerRunnable(doneSignal, i));
 *
 *     doneSignal.await();           // wait for all to finish
 *   }
 * }
 *
 * class WorkerRunnable implements Runnable {
 *   private final CountDownLatch doneSignal;
 *   private final int i;
 *   WorkerRunnable(CountDownLatch doneSignal, int i) {
 *     this.doneSignal = doneSignal;
 *     this.i = i;
 *   }
 *   public void run() {
 *     try {
 *       doWork(i);
 *       doneSignal.countDown();
 *     } catch (InterruptedException ex) {} // return;
 *   }
 *
 *   void doWork() { ... }
 * }}

首先看main方法,一开始根据需要等待的线程数,初始化CountDownLatch,然后启动线程,线程结束后调用CountDownLatch的countDown方法,当调用countDownLatch的counDown次数和初始化CountDownLatch的线程数相同时,主线程中的CountDownLatch的await方法不再阻塞,往下进行。

使用很简单,主要看源码实现。
CountDownLatch的底层实现是使用AQS队列实现,对AQS的不熟悉的同学可以看下方腾飞的《java并发编程的艺术》这本书或者看下这个AQS。
首先看下await方法。

    public void await() throws InterruptedException {
        sync.acquireSharedInterruptibly(1);
    }

sync这个实例是什么类型的呢


public class CountDownLatch {
    /**
     * Synchronization control For CountDownLatch.
     * Uses AQS state to represent count.
     */
    private static final class Sync extends AbstractQueuedSynchronizer {
        private static final long serialVersionUID = 4982264981922014374L;

        Sync(int count) {
            setState(count);
        }

        int getCount() {
            return getState();
        }

        protected int tryAcquireShared(int acquires) {
            return (getState() == 0) ? 1 : -1;
        }

        protected boolean tryReleaseShared(int releases) {
            // Decrement count; signal when transition to zero
            for (;;) {
                int c = getState();
                if (c == 0)
                    return false;
                int nextc = c-1;
                if (compareAndSetState(c, nextc))
                    return nextc == 0;
            }
        }
    }

    private final Sync sync;

}

Sync类继承了AbstractQueuedSynchronizer(AQS), 通过state值的大小来控制锁的获取。下面根据CountDownLatch的使用来分析下源码。
(1)创建CountDownLatch实例时。

    public CountDownLatch(int count) {
        if (count < 0) throw new IllegalArgumentException("count < 0");
        this.sync = new Sync(count);
    }

这里就可以很清楚的看到,这里会初始化AQS队列的state值的大小,state值其实就是需要等待线程数的大小。

(2)主线程调用CountDownLatch的await方法,阻塞主线程,等待其他线程执行结束。

    public final void acquireSharedInterruptibly(int arg)
            throws InterruptedException {
        if (Thread.interrupted())
            throw new InterruptedException();
        if (tryAcquireShared(arg) < 0)
            doAcquireSharedInterruptibly(arg);
    }

        protected int tryAcquireShared(int acquires) {
            return (getState() == 0) ? 1 : -1;
        }

首先调用tryAcquireShared获取当前state的值,如果值为0返回1,说明其他线程执行结束,不再阻塞。如果值不为0,则返回-1,说明其他线程还未执行结束,需要调用doAcquireSharedInterruptibly方法阻塞等待。
下面看下这个方法的实现。

    private void doAcquireSharedInterruptibly(int arg)
        throws InterruptedException {
        final Node node = addWaiter(Node.SHARED);  ##队列中插入node节点,保存线程信息
        boolean failed = true;
        try {
            for (;;) {
                final Node p = node.predecessor();    ##获取node节点的前一个节点
                if (p == head) {      ## 判断p节点是否是头结点
                    int r = tryAcquireShared(arg);  ##获取state值得大小
                    if (r >= 0) {                      ## r>=0 说明state值为0
                        setHeadAndPropagate(node, r);  ##设置头结点并且触发队列中头结点的下一个节点是否是共享节点,如果是的话,下个节点对应的线程也不再阻塞,具有传播特性。
                        p.next = null; // help GC
                        failed = false;
                        return;
                    }
                }
                if (shouldParkAfterFailedAcquire(p, node) &&
                    parkAndCheckInterrupt())  ## 阻塞调用此方法的线程
                    throw new InterruptedException();
            }
        } finally {
            if (failed)
                cancelAcquire(node);
        }
    }

上面的注释已经说明上面方法中整个的处理过程,其中setHeadAndPropagate和shouldParkAfterFailedAcquire还需要详细分析一下,首先看下setHeadAndPropagate方法。

  private void setHeadAndPropagate(Node node, int propagate) {
        Node h = head; // Record old head for check below
        setHead(node);
        if (propagate > 0 || h == null || h.waitStatus < 0 ||
            (h = head) == null || h.waitStatus < 0) {
            Node s = node.next;
            if (s == null || s.isShared())
                doReleaseShared();
        }
    }

执行此方法的前提是node的前一个节点是head节点,并且state值为0。在这个方法里,首先将当前的node节点设置为head节点,然后根据propagate这个值的大小,判断是否获取node节点的下一个节点,然后根据下一个节点是否是共享式类型的节点,来释放下个节点对应的线程,使下个节点的线程也不再阻塞,propagate使线程的释放具有了传播性,从队列的头结点开始,只要头结点不再阻塞,也可以使队列中的其他共享节点也不再阻塞,具有了传播性。
然后看下shouldParkAfterFailedAcquire方法的实现。

    private static boolean shouldParkAfterFailedAcquire(Node pred, Node node) {
        int ws = pred.waitStatus;
        if (ws == Node.SIGNAL)
            return true;
        if (ws > 0) {
            do {
                node.prev = pred = pred.prev;
            } while (pred.waitStatus > 0);
            pred.next = node;
        } else {
            compareAndSetWaitStatus(pred, ws, Node.SIGNAL);
        }
        return false;
    }

这个方法的目的主要是获取state值不为0时,是否阻塞此线程。如果此方法返回true则会调用parkAndCheckInterrupt这个方法,在这个方法里调用LockSupport的park方法阻塞此线程。那么阻塞后,什么时候唤醒这个线程呢,想要解决这个疑问就需要看下CountDownLatch的countDown方法的处理逻辑了。
(3) 线程执行完,调用CountDownLatch的countDown方法。

    public void countDown() {
        sync.releaseShared(1);
    }

    public final boolean releaseShared(int arg) {
        if (tryReleaseShared(arg)) {
            doReleaseShared();
            return true;
        }
        return false;
    }

        protected boolean tryReleaseShared(int releases) {
            // Decrement count; signal when transition to zero
            for (;;) {
                int c = getState();
                if (c == 0)
                    return false;
                int nextc = c-1;
                if (compareAndSetState(c, nextc))
                    return nextc == 0;
            }
        }

首先,在tryReleaseShared方法中将state值的大小减一,然后执行doReleaseShared方法,

    private void doReleaseShared() {
      
        for (;;) {
            Node h = head;
            if (h != null && h != tail) {
                int ws = h.waitStatus;
                if (ws == Node.SIGNAL) {
                    if (!compareAndSetWaitStatus(h, Node.SIGNAL, 0))
                        continue;            // loop to recheck cases
                    unparkSuccessor(h);
                }
                else if (ws == 0 &&
                         !compareAndSetWaitStatus(h, 0, Node.PROPAGATE))
                    continue;                // loop on failed CAS
            }
            if (h == head)                   // loop if head changed
                break;
        }
    }

在doReleaseShared方法中通过unparkSuccessor获取head节点的下一个节点的thread信息,然后执行LockSupport的unpark方法,这样的话之前await方法中阻塞的线程就不再阻塞,继续往下执行。

通过研究CountDownLatch的这三个方法,基本理解了底层实现,另外,如果能看懂这几个方法的源码,其实对AQS的源码也已经了解的差不多了,后面可以去看下Lock的源码,也是基于AQS实现的。

你可能感兴趣的:(不就是CountDownLatch)