之前看一个开源项目,里面用到了CountDownLatch,当时莫名其妙,不知道这个东西是干嘛的,后来查阅源码才知道这个东西这么好用,那么CountDownLatch是干嘛的呢,简单来说就是一个同步辅助工具类,使用它可以实现在某些线程执行完毕之后再执行另外一些线程,即某些线程执行的时候另外一个线程处于等待状态。我们平时需要在某一个子线程执行完毕之后再执行一些操作,当然可以直接在线程结束后调用我们想要执行的方法,或者采用接口回调的形式来实现,那么通过CountDownLatch来实现这种需求的话,形式上跟接口回调差不多。那么具体怎么使用呢,这里我们关注3个方法即可。
CountDownLatch的构造方法
/**
* Constructs a {@code CountDownLatch} initialized with the given count.
*
* @param count the number of times {@link #countDown} must be invoked
* before threads can pass through {@link #await}
* @throws IllegalArgumentException if {@code count} is negative
*/
public CountDownLatch(int count) {
if (count < 0) throw new IllegalArgumentException("count < 0");
this.sync = new Sync(count);
}
这个构造方法里面传入一个count,也就是说我们需要等待前面count个任务执行完毕之后才执行某个操作
CountDownLatch的countDown方法
/**
* Decrements the count of the latch, releasing all waiting threads if
* the count reaches zero.
*
* If the current count is greater than zero then it is decremented.
* If the new count is zero then all waiting threads are re-enabled for
* thread scheduling purposes.
*
*
If the current count equals zero then nothing happens.
*/
public void countDown() {
sync.releaseShared(1);
}
每个前置任务执行完毕后调用countDown方法使得计数器减1,如果计数器为0的话,那么调用await方法的任务会开始执行
CountDownLatch的await方法
/**
* Causes the current thread to wait until the latch has counted down to
* zero, unless the thread is {@linkplain Thread#interrupt interrupted}.
*
* If the current count is zero then this method returns immediately.
*
*
If the current count is greater than zero then the current
* thread becomes disabled for thread scheduling purposes and lies
* dormant until one of two things happen:
*
* - The count reaches zero due to invocations of the
* {@link #countDown} method; or
*
- Some other thread {@linkplain Thread#interrupt interrupts}
* the current thread.
*
*
* If the current thread:
*
* - has its interrupted status set on entry to this method; or
*
- is {@linkplain Thread#interrupt interrupted} while waiting,
*
* then {@link InterruptedException} is thrown and the current thread's
* interrupted status is cleared.
*
* @throws InterruptedException if the current thread is interrupted
* while waiting
*/
public void await() throws InterruptedException {
sync.acquireSharedInterruptibly(1);
}
在前置任务没有执行完毕的时候,所有调用await方法的线程都会处于等待状态,等待前置任务执行完毕后执行
那么具体怎么用来,这里我直接摘抄源码注释里的代码
class Driver { // ...
void main() throws InterruptedException {
CountDownLatch startSignal = new CountDownLatch(1);
CountDownLatch doneSignal = new CountDownLatch(N);
for (int i = 0; i < N; ++i) // create and start threads
new Thread(new Worker(startSignal, doneSignal)).start();
doSomethingElse(); // don't let run yet
startSignal.countDown(); // let all threads proceed
doSomethingElse();
doneSignal.await(); // wait for all to finish
}
}
class Worker implements Runnable {
private final CountDownLatch startSignal;
private final CountDownLatch doneSignal;
Worker(CountDownLatch startSignal, CountDownLatch doneSignal) {
this.startSignal = startSignal;
this.doneSignal = doneSignal;
}
public void run() {
try {
startSignal.await();
doWork();
doneSignal.countDown();
} catch (InterruptedException ex) {} // return;
}
void doWork() { ... }
}
以上例子中定义了一个startSignal,在每个线程(Worker)的run方法里面调用了startSignal.await,这样可以确保我们初始化的N个线程同时执行。在run方法里面任务执行完毕后又调用了doneSignal.countDown()让计数器减1,在for循环初始化完毕后调用了doneSignal.await()方法,这样就可以保证这N个线程都执行完毕后再执行下面的操作
class Driver2 { // ...
void main() throws InterruptedException {
CountDownLatch doneSignal = new CountDownLatch(N);
Executor e = ...
for (int i = 0; i < N; ++i) // create and start threads
e.execute(new WorkerRunnable(doneSignal, i));
doneSignal.await(); // wait for all to finish
}
}
class WorkerRunnable implements Runnable {
private final CountDownLatch doneSignal;
private final int i;
WorkerRunnable(CountDownLatch doneSignal, int i) {
this.doneSignal = doneSignal;
this.i = i;
}
public void run() {
try {
doWork(i);
doneSignal.countDown();
} catch (InterruptedException ex) {} // return;
}
void doWork() { ... }
}
以上一种用法咋一看好像跟第一个差不多,那么它的一大应用就是将一个任务拆分成N个part,在这N个part都执行完毕之后才执行接下来的操作,比如多线程下载,需要等待所有线程下载完成之后才算是这个文件下载完成,不明白多线程下载的同学请自行查阅资料
那么CountDownLatch底层是如何实现的呢,从上面贴出的CountDownLatch的构造方法里面我们可以看到,它的构造里面仅仅是使用传递进来的count初始化了一个Sync,那么这个Sync是什么呢,其实就是用来进行同步计数控制的AQS(AbstractQueuedSynchronizer),AbstractQueuedSynchronizer内部其实就是通过FIFO阻塞队列以及wait notify来实现同步的。而对于计数器的修改则是使用的CAS,CAS 就是Compare And Swap,不明白的自行查阅资料。Sync作为CountDownLatch的内部类存在,源码如下
/**
* Synchronization control For CountDownLatch.
* Uses AQS state to represent count.
*/
private static final class Sync extends AbstractQueuedSynchronizer {
private static final long serialVersionUID = 4982264981922014374L;
Sync(int count) {
setState(count);
}
int getCount() {
return getState();
}
protected int tryAcquireShared(int acquires) {
return (getState() == 0) ? 1 : -1;
}
protected boolean tryReleaseShared(int releases) {
// Decrement count; signal when transition to zero
for (;;) {
int c = getState();
if (c == 0)
return false;
int nextc = c-1;
if (compareAndSetState(c, nextc))
return nextc == 0;
}
}
}
我们先看看CountDownLatch.countDown()做了一些什么,从上面给出的countDown方法源码我们可以看到它仅仅是调用了sync.releaseShared(1),每次调用countDown()都会对原子值state进行修改,而调用await()方法的地方会一直执行一个死循环,这个死循环会在原子值state为0的时候结束,说白了就是相当于阻塞调用await方法的线程。大家查阅源码的时候一定要先看Sync类里面的tryAcquireShared()方法和tryReleaseShared()方法,这两个方法是重写的父类AbstractQueuedSynchronizer里面的方法,这两个方法在父类里面有调用,逻辑也不是很复杂,大家自行查阅源码里面的调用逻辑即可。AQS里面用到了Unsafe类,这个类类似c直接操作指针,是线程不安全的类,对于并发的处理实际上是依赖于CAS,有兴趣的同学可以好好研究!