CountDownLatch是一个一次性的线程同步工具,一般用于主线程等待多个工作线程均执行完毕之后,主线程再执行后续工作。
常见用法如下:
public static void main(String[] args) {
ExecutorService service = Executors.newFixedThreadPool(3);
final CountDownLatch latch = new CountDownLatch(3);
for (int i = 0; i < 3; i++) {
Runnable runnable = new Runnable() {
@Override
public void run() {
try {
System.out.println("子线程" + Thread.currentThread().getName() + "开始执行");
Thread.sleep((long) (Math.random() * 10000));
System.out.println("子线程"+Thread.currentThread().getName()+"执行完成");
latch.countDown();//当前线程调用此方法,则计数减一
} catch (InterruptedException e) {
e.printStackTrace();
}
}
};
service.execute(runnable);
}
try {
System.out.println("主线程"+Thread.currentThread().getName()+"等待子线程执行完成...");
latch.await();//阻塞当前线程,直到计数器的值为0
System.out.println("主线程"+Thread.currentThread().getName()+"开始执行...");
} catch (InterruptedException e) {
e.printStackTrace();
}
}
上述代码中,首先初始化了一个状态数为3的CountDownLatch,然后启动了三个子线程,在每个线程中,分别调用CountDownLatch的countDown方法。同时在主线程中,调用CountDown的await方法区等待三个子线程完成。
执行结果
主线程main等待子线程执行完成...
子线程pool-1-thread-1开始执行
子线程pool-1-thread-3开始执行
子线程pool-1-thread-2开始执行
子线程pool-1-thread-2执行完成
子线程pool-1-thread-1执行完成
子线程pool-1-thread-3执行完成
主线程main开始执行...
实际上,等待CountDownLatch子线程执行完成的,也可以是多线程。例如如下代码,用两个线程等待
public static void main(String[] args) {
ExecutorService service = Executors.newFixedThreadPool(3);
final CountDownLatch latch = new CountDownLatch(3);
for (int i = 0; i < 3; i++) {
Runnable runnable = new Runnable() {
@Override
public void run() {
try {
System.out.println("子线程" + Thread.currentThread().getName() + "开始执行");
Thread.sleep((long) (Math.random() * 10000));
System.out.println("子线程"+Thread.currentThread().getName()+"执行完成");
latch.countDown();//当前线程调用此方法,则计数减一
} catch (InterruptedException e) {
e.printStackTrace();
}
}
};
service.execute(runnable);
}
try {
Thread thread1 = new Thread(() -> {
System.out.println("线程1等待子线程执行完成");
try {
latch.await();
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
System.out.println("线程1开始执行");
});
thread1.start();
Thread thread2 = new Thread(() -> {
System.out.println("线程2等待子线程执行完成");
try {
latch.await();
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
System.out.println("线程2开始执行");
});
thread2.start();
System.in.read();
} catch (Exception e) {
e.printStackTrace();
}
}
子线程pool-1-thread-3开始执行
线程2等待子线程执行完成
子线程pool-1-thread-1开始执行
线程1等待子线程执行完成
子线程pool-1-thread-2开始执行
子线程pool-1-thread-3执行完成
子线程pool-1-thread-2执行完成
子线程pool-1-thread-1执行完成
线程2开始执行
线程1开始执行
那么CountDownLatch是如何实现可以多个主线程等待多个子线程完成的呢?让我们看一下它的源码
首先是定义了一个AbstractQueuedSynchronizer的子类Sync:
private static final class Sync extends AbstractQueuedSynchronizer {
private static final long serialVersionUID = 4982264981922014374L;
Sync(int count) {
setState(count);
}
int getCount() {
return getState();
}
protected int tryAcquireShared(int acquires) {
return (getState() == 0) ? 1 : -1;
}
protected boolean tryReleaseShared(int releases) {
// Decrement count; signal when transition to zero
for (;;) {
int c = getState();
if (c == 0)
return false;
int nextc = c - 1;
if (compareAndSetState(c, nextc))
return nextc == 0;
}
}
}
可以看出,该AQS的子类中,重新实现了tryAcquireShared和tryReleaseShared方法。说明它是基于AQS的共享模式实现的。
再看构造器方法
public CountDownLatch(int count) {
if (count < 0) throw new IllegalArgumentException("count < 0");
this.sync = new Sync(count);
}
参数代表子线程数。
await方法
public void await() throws InterruptedException {
sync.acquireSharedInterruptibly(1);
}
方法中调用了acquireSharedInterruptibly方法。该方法采用共享获取模式。我们知道在AQS中,acquire和release都有独占和共享模式(共享模式在方法名中加了Shared)。其中共享模式与独占模式的区别在于,共享模式在AQS队列中的节点在获取到锁之后,会通知它的下一个节点,它的下一个节点也会尝试获取。(这就是共享锁的意义)
public final void acquireSharedInterruptibly(int arg)
throws InterruptedException {
if (Thread.interrupted() ||
(tryAcquireShared(arg) < 0 &&
acquire(null, arg, true, true, false, 0L) < 0))
throw new InterruptedException();
}
调用acquire方法时,第三个参数传入true,代表是共享锁模式
final int acquire(Node node, int arg, boolean shared,
boolean interruptible, boolean timed, long time) {
Thread current = Thread.currentThread();
byte spins = 0, postSpins = 0; // retries upon unpark of first thread
boolean interrupted = false, first = false;
Node pred = null; // predecessor of node when enqueued
/*
* Repeatedly:
* Check if node now first
* if so, ensure head stable, else ensure valid predecessor
* if node is first or not yet enqueued, try acquiring
* else if node not yet created, create it
* else if not yet enqueued, try once to enqueue
* else if woken from park, retry (up to postSpins times)
* else if WAITING status not set, set and retry
* else park and clear WAITING status, and check cancellation
*/
for (;;) {
if (!first && (pred = (node == null) ? null : node.prev) != null &&
!(first = (head == pred))) {
if (pred.status < 0) {
cleanQueue(); // predecessor cancelled
continue;
} else if (pred.prev == null) {
Thread.onSpinWait(); // ensure serialization
continue;
}
}
if (first || pred == null) {
boolean acquired;
try {
if (shared)
acquired = (tryAcquireShared(arg) >= 0);
else
acquired = tryAcquire(arg);
} catch (Throwable ex) {
cancelAcquire(node, interrupted, false);
throw ex;
}
if (acquired) {
if (first) {
node.prev = null;
head = node;
pred.next = null;
node.waiter = null;
if (shared)
// 重点是这里,如果是共享模式,会通知下一个节点
signalNextIfShared(node);
if (interrupted)
current.interrupt();
}
return 1;
}
}
if (node == null) { // allocate; retry before enqueue
if (shared)
node = new SharedNode();
else
node = new ExclusiveNode();
} else if (pred == null) { // try to enqueue
node.waiter = current;
Node t = tail;
node.setPrevRelaxed(t); // avoid unnecessary fence
if (t == null)
tryInitializeHead();
else if (!casTail(t, node))
node.setPrevRelaxed(null); // back out
else
t.next = node;
} else if (first && spins != 0) {
--spins; // reduce unfairness on rewaits
Thread.onSpinWait();
} else if (node.status == 0) {
node.status = WAITING; // enable signal and recheck
} else {
long nanos;
spins = postSpins = (byte)((postSpins << 1) | 1);
if (!timed)
LockSupport.park(this);
else if ((nanos = time - System.nanoTime()) > 0L)
LockSupport.parkNanos(this, nanos);
else
break;
node.clearStatus();
if ((interrupted |= Thread.interrupted()) && interruptible)
break;
}
}
return cancelAcquire(node, interrupted, interruptible);
}
重点是我添加注释那一行
private static void signalNextIfShared(Node h) {
Node s;
if (h != null && (s = h.next) != null &&
(s instanceof SharedNode) && s.status != 0) {
s.getAndUnsetStatus(WAITING);
LockSupport.unpark(s.waiter);
}
}
LockSupport的unpark就代表唤醒线程。
countDown方法
public void countDown() {
sync.releaseShared(1);
}
public final boolean releaseShared(int arg) {
if (tryReleaseShared(arg)) {
signalNext(head);
return true;
}
return false;
}
这里tryReleaseShared方法会调用CountDownLatch中Sync实现的方法
protected boolean tryReleaseShared(int releases) {
// Decrement count; signal when transition to zero
for (;;) {
int c = getState();
if (c == 0)
return false;
int nextc = c - 1;
if (compareAndSetState(c, nextc))
return nextc == 0;
}
}
它会判断必须当前这一次release将state减为0才返回true。
那么,能否实现只允许一个线程等待的CountDownLatch呢?简单地说,可以将CountDownLatch中涉及到acquire和release的方法都换成独占版的,同时稍微改一下内部类Sync的tryAcquire方法。
下面是我自定义的MyCountDownLatch类
public class MyCountDownLatch {
/**
* Synchronization control For CountDownLatch.
* Uses AQS state to represent count.
*/
private static final class Sync extends AbstractQueuedSynchronizer {
private static final long serialVersionUID = 4982264981922014374L;
Sync(int count) {
setState(count);
}
int getCount() {
return getState();
}
protected boolean tryAcquire(int acquires) {
return (getState() == 0) ? true : false;
}
protected boolean tryRelease(int releases) {
// Decrement count; signal when transition to zero
for (;;) {
int c = getState();
if (c == 0)
return false;
int nextc = c - 1;
if (compareAndSetState(c, nextc))
return nextc == 0;
}
}
}
private final Sync sync;
/**
* Constructs a {@code CountDownLatch} initialized with the given count.
*
* @param count the number of times {@link #countDown} must be invoked
* before threads can pass through {@link #await}
* @throws IllegalArgumentException if {@code count} is negative
*/
public MyCountDownLatch(int count) {
if (count < 0) throw new IllegalArgumentException("count < 0");
this.sync = new Sync(count);
}
/**
* Causes the current thread to wait until the latch has counted down to
* zero, unless the thread is {@linkplain Thread#interrupt interrupted}.
*
* If the current count is zero then this method returns immediately.
*
*
If the current count is greater than zero then the current
* thread becomes disabled for thread scheduling purposes and lies
* dormant until one of two things happen:
*
* - The count reaches zero due to invocations of the
* {@link #countDown} method; or
*
- Some other thread {@linkplain Thread#interrupt interrupts}
* the current thread.
*
*
* If the current thread:
*
* - has its interrupted status set on entry to this method; or
*
- is {@linkplain Thread#interrupt interrupted} while waiting,
*
* then {@link InterruptedException} is thrown and the current thread's
* interrupted status is cleared.
*
* @throws InterruptedException if the current thread is interrupted
* while waiting
*/
public void await() throws InterruptedException {
sync.acquireInterruptibly(1);
}
/**
* Causes the current thread to wait until the latch has counted down to
* zero, unless the thread is {@linkplain Thread#interrupt interrupted},
* or the specified waiting time elapses.
*
* If the current count is zero then this method returns immediately
* with the value {@code true}.
*
*
If the current count is greater than zero then the current
* thread becomes disabled for thread scheduling purposes and lies
* dormant until one of three things happen:
*
* - The count reaches zero due to invocations of the
* {@link #countDown} method; or
*
- Some other thread {@linkplain Thread#interrupt interrupts}
* the current thread; or
*
- The specified waiting time elapses.
*
*
* If the count reaches zero then the method returns with the
* value {@code true}.
*
*
If the current thread:
*
* - has its interrupted status set on entry to this method; or
*
- is {@linkplain Thread#interrupt interrupted} while waiting,
*
* then {@link InterruptedException} is thrown and the current thread's
* interrupted status is cleared.
*
* If the specified waiting time elapses then the value {@code false}
* is returned. If the time is less than or equal to zero, the method
* will not wait at all.
*
* @param timeout the maximum time to wait
* @param unit the time unit of the {@code timeout} argument
* @return {@code true} if the count reached zero and {@code false}
* if the waiting time elapsed before the count reached zero
* @throws InterruptedException if the current thread is interrupted
* while waiting
*/
public boolean await(long timeout, TimeUnit unit)
throws InterruptedException {
return sync.tryAcquireNanos(1, unit.toNanos(timeout));
}
/**
* Decrements the count of the latch, releasing all waiting threads if
* the count reaches zero.
*
*
If the current count is greater than zero then it is decremented.
* If the new count is zero then all waiting threads are re-enabled for
* thread scheduling purposes.
*
*
If the current count equals zero then nothing happens.
*/
public void countDown() {
sync.release(1);
}
/**
* Returns the current count.
*
*
This method is typically used for debugging and testing purposes.
*
* @return the current count
*/
public long getCount() {
return sync.getCount();
}
/**
* Returns a string identifying this latch, as well as its state.
* The state, in brackets, includes the String {@code "Count ="}
* followed by the current count.
*
* @return a string identifying this latch, as well as its state
*/
public String toString() {
return super.toString() + "[Count = " + sync.getCount() + "]";
}
}
其中,主要是将调用AQS的共享模式的方法地方,都换成了独占模式的地方,同时修改了一下tryAcquire方法,改之后是
protected boolean tryAcquire(int acquires) {
return (getState() == 0) ? true : false;
}
主要是因为tryAcquire和tryAcquireShared的返回值类型不一样,前者是boolean,后者是int。
此时,再使用两个主线程去等待三个子线程完成
public static void main(String[] args) {
ExecutorService service = Executors.newFixedThreadPool(3);
final MyCountDownLatch latch = new MyCountDownLatch(3);
for (int i = 0; i < 3; i++) {
Runnable runnable = new Runnable() {
@Override
public void run() {
try {
System.out.println("子线程" + Thread.currentThread().getName() + "开始执行");
Thread.sleep((long) (Math.random() * 10000));
System.out.println("子线程"+Thread.currentThread().getName()+"执行完成");
latch.countDown();//当前线程调用此方法,则计数减一
} catch (InterruptedException e) {
e.printStackTrace();
}
}
};
service.execute(runnable);
}
try {
Thread thread1 = new Thread(() -> {
System.out.println("线程1等待子线程执行完成");
try {
latch.await();
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
System.out.println("线程1开始执行");
});
thread1.start();
Thread thread2 = new Thread(() -> {
System.out.println("线程2等待子线程执行完成");
try {
latch.await();
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
System.out.println("线程2开始执行");
});
thread2.start();
System.in.read();
} catch (Exception e) {
e.printStackTrace();
}
}
输出
子线程pool-1-thread-3开始执行
子线程pool-1-thread-1开始执行
子线程pool-1-thread-2开始执行
线程1等待子线程执行完成
线程2等待子线程执行完成
子线程pool-1-thread-2执行完成
子线程pool-1-thread-3执行完成
子线程pool-1-thread-1执行完成
线程1开始执行
可见只有线程1在MyCountDownLatch的state减为0之后恢复了执行。而线程2却一直在AQS队列中,无法恢复执行。
当然这个MyCountDownLatch类只是为了验证而写的类,不能用于正式环境,因为如果多个主线程等待,就会造成只有一个线程能恢复执行,而其余线程永久阻塞的后果。