Java并发编程系列之CountDownLatch用法及详解

背景

前几天一个同事问我,对这个CountDownLatch有没有了解想问一些问题,当时我一脸懵逼,不知道如何回答。今天赶紧抽空好好补补。不得不说Doug Lea大师真的很牛,设计出如此好的类。

1、回顾旧知识

volatile关键字:当一个共享变量被volatile修饰时,它会保证修改的值会立即被更新到主存,当有其他线程需要读取时,它会去内存中读取新值。(这涉及到java内存模型了,有兴趣了解java内存模型的可以先找资料看看)。

2、CountDownLatch简介

CountDownLatch 可以理解就是个计数器,只能减不能加,同时它还有个门闩的作用,当计数器不为0时,门闩是锁着的;当计数器减到0时,门闩就打开了。
如果还不是很理解的话,举个简单的例子就是,你去超市买东西,虽然已经到了关门时间但是只有顾客都走了超市才能关门,至于你买不买东西,超市不关心。只要顾客都走完了,我就可以关门了。

2、CountDownLatch具体使用场景

有A和B两个任务,只有当A任务执行完之后,才能执行B任务。A和B都可以拆分小任务。比如下载一个大文件,可以使用多线程下载,等下载完之后,在统一处理。

2、CountDownLatch实现原理(jdk1.8)

CountDownLatch源码

public class CountDownLatch {
    /**
     * 内部类继承AQS
     * Uses AQS state to represent count.
     */
    private static final class Sync extends AbstractQueuedSynchronizer {
        private static final long serialVersionUID = 4982264981922014374L;

        Sync(int count) {
            setState(count);
        }

        int getCount() {
            return getState();
        }

        protected int tryAcquireShared(int acquires) {
            return (getState() == 0) ? 1 : -1;
        }

        protected boolean tryReleaseShared(int releases) {
            // Decrement count; signal when transition to zero
            for (;;) {
                int c = getState();
                if (c == 0)
                    return false;
                int nextc = c-1;
                if (compareAndSetState(c, nextc))
                    return nextc == 0;
            }
        }
    }

    private final Sync sync;

    /**
     * 构造方法一般传线程总数
     */
    public CountDownLatch(int count) {
        if (count < 0) throw new IllegalArgumentException("count < 0");
        this.sync = new Sync(count);
    }

    /**
     *  等待方法
     */
    public void await() throws InterruptedException {
        sync.acquireSharedInterruptibly(1);
    }

   /**
    * 等待重载超时等待
    */
    public boolean await(long timeout, TimeUnit unit)
        throws InterruptedException {
        return sync.tryAcquireSharedNanos(1, unit.toNanos(timeout));
    }

    /**
     * 计数减1
     */
    public void countDown() {
        sync.releaseShared(1);
    }

    /**
     *  当前计数
     */
    public long getCount() {
        return sync.getCount();
    }

    /**
    *
     */
    public String toString() {
        return super.toString() + "[Count = " + sync.getCount() + "]";
    }

为什么定义一个内部类?这种结构的好处在于我们不必关心AbstractQueuedSynchronizer(以下简称AQS)的同步状态管理、线程排队、等待与唤醒等底层操作,我们只需重写我们想要的方法。可生成特定并发工具类。
CountDownLatch主要两个方法就是一是CountDownLatch.await()阻塞当前线程,二是CountDownLatch.countDown()当前线程把计数器减一
看完源码,我们可以看出实现CountDownLatch主要思想就是使用volatile和同步队列来放置这些阻塞队列。
a、CountDownLatch.await()方法
如果让我们自己实现一个await方法我们会怎么做
一、首先会想到是会使用线程间wait/notify,使用synchronized关键字,检查计数器值不为0,然后调用Object.wait();直到计数器值0则调用notifyAll()唤醒等待线程。但是大量的
synchronized代码块会存在假唤醒。
我们还是看看Doug Lea是怎么实现这个类的。

CountDownLatch构造方法

public CountDownLatch(int count) {
        if (count < 0) throw new IllegalArgumentException("count < 0");
        this.sync = new Sync(count);
    }

构造方法传入了一个int变量,这个int变量是AQS中的state,类型是volatile的,它就是用来表示计数器值的。内存共享这个变量,只有有修改,其他线程都能读取到。

await方法

public void await() throws InterruptedException {
        sync.acquireSharedInterruptibly(1);
    }

调用await()的方法后,会默认调用sync这个实例的acquireSharedInterruptibly这个方法,并且参数为1,需要注意的是,这个方法声明了一个InterruptedException异常,表示调用该方法的线程支持打断操作。

sync acquireSharedInterruptibly

public final void acquireSharedInterruptibly(int arg)
            throws InterruptedException {
        if (Thread.interrupted())
            throw new InterruptedException();
        if (tryAcquireShared(arg) < 0)
            doAcquireSharedInterruptibly(arg);
    }

acquireSharedInterruptibly这个方法是sync继承AQS而来的,这个方法的调用是响应线程的打断的,所以在前两行会检查线程是否被打断。接着调用tryAcquireShared()方法来判断返回值,根据值的大小决定是否执行doAcquireSharedInterruptibly()。

tryAcquireShared这个方法是在Sync中重写方法

protected int tryAcquireShared(int acquires) {
            return (getState() == 0) ? 1 : -1;
        }

在子类sync的tryAcquireShared中它只验证了计数器的值是否为0,如果为0则返回1,反之返回-1,根据上面代码可以看出,整数就不会执行doAcquireSharedInterruptibly(),该线程就结束方法,继续执行本身代码了。

doAcquireSharedInterruptibly

private void doAcquireSharedInterruptibly(int arg)
        throws InterruptedException {
        final Node node = addWaiter(Node.SHARED);// 往同步队列中添加节点
        boolean failed = true;
        try {
            for (;;) {//死循环
                final Node p = node.predecessor();//当前线程的前一个节点
                if (p == head) {//首节点
                    int r = tryAcquireShared(arg);
                    if (r >= 0) {
                        setHeadAndPropagate(node, r);
                        p.next = null; // help GC
                        failed = false;
                        return;
                    }
                }
                if (shouldParkAfterFailedAcquire(p, node) &&
                    parkAndCheckInterrupt())
                    throw new InterruptedException();
            }
        } finally {
            if (failed)
                cancelAcquire(node);
        }
    }

因为计数器值不为0需要阻塞线程,所以在进入方法时,将该线程包装成节点并加入到同步队列尾部(addWaiter方法),我们看到这个方法退出去的途径直有两个,一个是return,一个是throw InterruptedException。注意最后的finally的处理。return退出方法有必须满足两个条件首先是首节点,其次是计数值为0。
throw InterruptedException是响应打断操作的,线程在阻塞期间,如果你不想在等待了,可以打断线程让它继续运行后面的任务(注意异常处理)。

addWaiter添加节点

private Node addWaiter(Node mode) {
        Node node = new Node(Thread.currentThread(), mode);//包装节点
        // Try the fast path of enq; backup to full enq on failure
        Node pred = tail; //同步队列尾节点
        if (pred != null) {
            node.prev = pred;//同步队列有尾节点 将我们的节点通过cas方式添加到队列后面
            if (compareAndSetTail(pred, node)) {
                pred.next = node;
                return node;
            }
        }
        enq(node);// 两种情况执行这个代码 1.队列尾节点为null 2.队列尾节点不为null,但是我们原子添加尾节点失败
        return node;
    }

private Node enq(final Node node) {
        for (;;) {
            Node t = tail;
            if (t == null) { //  cas形式添加头节点  注意 是头节点
                if (compareAndSetHead(new Node()))
                    tail = head;
            } else {
                node.prev = t;//cas形式添加尾节点
                if (compareAndSetTail(t, node)) {
                    t.next = node;
                    return t;//结束方法必须是尾节点添加成功
                }
            }
        }
    }

b、CountDownLatch.countDown()方法
当部分线程调用await()方法后,它们在同步队列中被挂起,然后循环的检查自己能否满足醒来的条件(还记得那个条件吗?1、state为0,2、该节点为头节点),

 *      +------+  prev +-----+       +-----+
 * head |      | <---- |     | <---- |     |  tail
 *      +------+       +-----+       +-----+

同步队列

volatile Node prev;
volatile Node next;

volatile的prev指向上一个node节点,volatile的next指向下一个node节点。当然如果是头节点,那么它的prev为null,同理尾节点的next为null。

private transient volatile Node head;
private transient volatile Node tail;

用来表示同步队列的头节点和尾节点

countDown方法

 public void countDown() {
        sync.releaseShared(1);
    }

releaseShared方法

public final boolean releaseShared(int arg) {
        if (tryReleaseShared(arg)) {
            doReleaseShared();
            return true;
        }
        return false;
    }

在Sync类中并没有releaseShared()方法,所以继承与AQS,看到AQS这个方法中,退出该方法的只有两条路。tryReleaseShared(arg)条件为真执行一个doReleaseShared()退出;条件为假直接退出。

protected boolean tryReleaseShared(int releases) {
    for (;;) {//死循环
        int c = getState();// 获取主存中的state值
        if (c == 0) //state已经为0 直接退出
            return false;
        int nextc = c-1; // 减一 准备cas更新该值
        if (compareAndSetState(c, nextc)) //cas更新
            return nextc == 0; //更新成功 判断是否为0 退出;更新失败则继续for循环,直到线程并发更新成功
    }
}

doReleaseShared方法

private void doReleaseShared() {
    for (;;) {//死循环
        Node h = head;
        if (h != null && h != tail) {
            int ws = h.waitStatus;
            if (ws == Node.SIGNAL) {//如果当前节点是SIGNAL,它正在等待一个信号,或者说它在等待被唤醒,因此做两件事,1是重置waitStatus标志位,2是重置成功后,唤醒下一个节点。
                if (!compareAndSetWaitStatus(h, Node.SIGNAL, 0))
                    continue;            
                unparkSuccessor(h);
            }else if (ws == 0 && !compareAndSetWaitStatus(h, 0, Node.PROPAGATE))//如果本身头节点的waitStatus是出于重置状态(waitStatus==0)的,将其设置为“传播”状态。意味着需要将状态向后一个节点传播。
                continue;                
        }
        if (h == head)                   
            break;
    }
}
重点来了

为啥要执行这个方法呀,因为state已经为0啦,我们该将同步队列中的线程状态设置为共享状态(Node.PROPAGATE,默认状态ws == 0),并向后传播,实现状态共享。

退出死循环,只有一条,那就是h==head,即该线程是头节点,且状态为共享状态。

可能有人有疑问,state已经等于0了,我们也通过循环的方式把头节点的状态设置为共享状态,但是它怎么醒过来的呢?看上面doAcquireSharedInterruptibly方法。

在同步队列中挂起的线程,它们自旋的形式查看自己是否满足条件醒来(state==0,且为头节点),如果成立将调用setHeadAndPropagate这个方法

private void setHeadAndPropagate(Node node, int propagate) {
     Node h = head; // Record old head for check below
     setHead(node);
     if (propagate > 0 || h == null || h.waitStatus < 0 ||
         (h = head) == null || h.waitStatus < 0) {
         Node s = node.next;
        if (s == null || s.isShared())
            doReleaseShared();
    }
}

看一个例子在加深下印象

/**
 * @author shuliangzhao
 * @Title: TestCountDownLatch
 * @ProjectName design-parent
 * @Description: TODO
 * @date 2019/6/2 12:19
 */
public class CountDownLatchExc {

    private static final int i = 2;

    static class MyRunable implements Runnable {

        private int num;

        private CountDownLatch countDownLatch;

        public MyRunable(int num,CountDownLatch countDownLatch) {
            this.num = num;
            this.countDownLatch = countDownLatch;
        }

        @Override
        public void run() {
            try {
                System.out.println("第" + num + "个线程开始执行任务...");
                Thread.sleep(2000);
                System.out.println("第" + num + "个线程开始执行结束...");
            } catch (InterruptedException e) {
                e.printStackTrace();
            }
            countDownLatch.countDown();
        }
    }

    public static void main(String[] args) {
        CountDownLatch countDownLatch = new CountDownLatch(i);
        for (int i = 0;i < 5;i++) {
            Thread thread = new Thread(new MyRunable(i,countDownLatch));
            thread.start();
        }
        System.out.println("main thread wait.");
        try {
            countDownLatch.await();
        } catch (InterruptedException e) {
            e.printStackTrace();
        }
        System.out.println("main thread end...");
    }
}

运行结果


Java并发编程系列之CountDownLatch用法及详解_第1张图片
image.png

以上就是CountDownLatch两大重要方法解释,可能理解有偏差,欢迎指出。

你可能感兴趣的:(Java并发编程系列之CountDownLatch用法及详解)