浅析CountDownLatch源码

[TOC]

需要了解 AQS 知识。

CountDownLatch 能够等待一个或一组线程,直到其他线程执行完成(计数器减为 0)时,才继续执行。

其实调用线程的 join() 方法能够实现等待线程完成后再继续执行的场景。

不过 CountDownLatch 更为灵活:https://blog.csdn.net/zhutulang/article/details/48504487

CountDownLatch 实现的原理大致如下:

创建时传入计数器初始值,子任务完成时,AQS 中的 state 属性可以表示等待完成的任务数量,没完成一项计数 -1,计数器为 0 时,唤醒调用线程。

构造方法

CountDownLatch 的构造方法必须传入一个整形作为计数器的初始值,该数值用于初始化 Sync。

private final Sync sync;

public CountDownLatch(int count) {
    if (count < 0) throw new IllegalArgumentException("count < 0");
    this.sync = new Sync(count);
}

属性

Sync

CountDownLatch 内部方法实现都调用了 Sync,可见 Sync 为该类核心。

private static final class Sync extends AbstractQueuedSynchronizer {
    private static final long serialVersionUID = 4982264981922014374L;
    
    Sync(int count) {
        // Sync 继承自 AQS,从名称可以看出初始化值传入 AQS 的 state 属性
         setState(count);
    }   

    int getCount() {
        return getState();
    }
    
    // 使用了 AQS 的共享模式
    protected int tryAcquireShared(int acquires) {
        return (getState() == 0) ? 1 : -1;
    }

    protected boolean tryReleaseShared(int releases) {
        // 通过自选操作实现自减 1
        for (;;) {
            // 获取更新的 state 值
            int c = getState();
            // 若无需释放锁(=0)
            if (c == 0)
                return false;
            // 若释放锁则递减
            int nextc = c-1;
            if (compareAndSetState(c, nextc))
                return nextc == 0;
        }
    }
}

方法

// 获取还需要等待的任务数量
public long getCount() {...}

public String toString() {...}

public void await()

调用 await() 方法后,调用线程会被阻塞,直到出现下面情况之一:

  • 所有任务线程调用 countDown 方法,即计数器为 0
  • 其他线程调用当前线程的 interrupt() 方法进行中断,此时会抛出异常
// 等待任务完成,计数器为 0 时返回
public void await() throws InterruptedException  {
    sync.acquireSharedInterruptibly(1);
}

// 设置了超时时间
public boolean await(long timeout, TimeUnit unit) { ... }

acquireSharedInterruptibly 是 AQS 中定义的方法

public final void acquireSharedInterruptibly(int arg)
        throws InterruptedException {
    // 线程被中断则抛出异常
    if (Thread.interrupted())
        throw new InterruptedException();
    
    // 至少尝试一次 tryAcquireShared
    //      成功:返回
    //      失败:线程进入等待队列
    if (tryAcquireShared(arg) < 0) // (1)
        // 进入等待队列
        doAcquireSharedInterruptibly(arg);  // (2)
}

(1)CountDownLatch 中 sync 定义的方法,判断 state 是否为 0

protected int tryAcquireShared(int acquires) {
    return (getState() == 0) ? 1 : -1;
}

doAcquireSharedInterruptibly 方法创建了共享模式的 AQS 节点进入等待队列进行排队。

CountDownLatch 设置 state 后未置为 0,调用 await 的线程都会进行等待。

private void doAcquireSharedInterruptibly(int arg)
    throws InterruptedException {
    // 创建共享模式节点,并加入等待队列
    final Node node = addWaiter(Node.SHARED);
    boolean failed = true;
    try {
        // 自旋尝试获得锁
        for (;;) {
            final Node p = node.predecessor();
            if (p == head) {
                int r = tryAcquireShared(arg);
                if (r >= 0) {
                    setHeadAndPropagate(node, r);
                    p.next = null; // help GC
                    failed = false;
                    return;
                }
            }
            
            // 当前线程是否需要挂起
            if (shouldParkAfterFailedAcquire(p, node) &&
                parkAndCheckInterrupt())
                throw new InterruptedException();
        }
    } finally {
        if (failed)
            cancelAcquire(node);
    }
}

public void countDown()

// 递减计数器,计数器等于 0,则释放所有等待的线程
public void countDown() { sync.releaseShared(1); }

AQS 中的 releaseShared 实现

public final boolean releaseShared(int arg) {
// 会执行 sync 的 tryReleaseShared 方法 -1 ,然后进行共享锁的释放操作
    if (tryReleaseShared(arg)) { // (1)
        doReleaseShared();  // (2)
        return true;
    }
    return false;
}

(1)CountDownLatch 中 sync 定义的方法,判断 state 是否为 0

protected boolean tryReleaseShared(int releases) {
    // Decrement count; signal when transition to zero
    for (;;) {
        int c = getState();
        if (c == 0)
            return false;   // state 计数为 0 时,释放失败
        int nextc = c-1;
        if (compareAndSetState(c, nextc))   // 将 state 值使用 CAS 置为 state-1
            return nextc == 0;
    }
}

你可能感兴趣的:(浅析CountDownLatch源码)