在之前的文章Java编程拾遗『线程协作』中,介绍了一些线程协作的场景,并做了简单的代码实现,比如上文中的MyLatch、AssemblePoint、MySemaphore等。其实上篇文章介绍的线程协作场景,在Java API中都有响应实现。本篇文章就来介绍一下,Java API中提供的一些线程协作工具及使用场景。本篇文章会介绍以下几种协作工具:
Java API中提供的这几种线程协作工具分别跟之前文章实现的MyLatch、AssemblePoint、MySemaphore对应。区别在于之前我们自定义的协作类是通过wait/notify实现的,而Java API中提供的线程协作类是通过AQS实现的,效率会更高一些。
在之前的文章,我们通过wait/notify实现了一个简单的门栓MyLatch。同时提到,Java并发包中已经提供了类似工具CountDownLatch。它的大概含义是指,它相当于是一个门栓,一开始是关闭的,所有希望通过该门的线程都需要等待,然后开始倒计时,倒计时变为0后,门栓打开,等待的所有线程都可以通过,它是一次性的,打开后就不能再关上了。
public CountDownLatch(int count)
构造函数中的count就是用来计数的数字
多个线程可以基于这个计数进行协作,它的主要方法有:
public void await() throws InterruptedException
public boolean await(long timeout, TimeUnit unit) throws InterruptedException
public void countDown()
await()检查计数是否为0,如果大于0,就等待,await()可以被中断,也可以设置最长等待时间。countDown检查计数,如果已经为0,直接返回,否则减少计数,如果新的计数变为0,则唤醒所有等待的线程。
在之前Java编程拾遗『线程协作』一文中,我们使用MyLatch实现了同时开始和等待结束两种协作模式,这里我们就使用Java API提供的CountDownLatch来实现这一功能。
还用那个运动员赛跑和裁判发令枪的例子,当裁判发令枪响起,所有运动员开始跑步。如下:
public class RacerWithLatchDemo {
static class Racer extends Thread {
private CountDownLatch latch;
public Racer(CountDownLatch latch) {
this.latch = latch;
}
@Override
public void run() {
try {
this.latch.await();
System.out.println("start run "
+ Thread.currentThread().getName());
} catch (InterruptedException ignored) {
}
}
}
public static void main(String[] args) throws Exception {
int num = 10;
CountDownLatch latch = new CountDownLatch(1);
Thread[] racers = new Thread[num];
for (int i = 0; i < num; i++) {
racers[i] = new Racer(latch);
racers[i].start();
}
Thread.sleep(2000);
latch.countDown();
}
}
所有的运动员线程都等待计数变为0,当所有运动员线程都启动之后,由于此时计数为1,那么所有的运动员线程都将阻塞等待(AQS中是通过LockSupport.park()实现的的)。随后主线程将计数变为0,并唤醒所有等待的Racer线程,所有Racer线程同时启动。也就是通过CountDownLatch实现多线程协作,同时开始。
除了同时开始,CountDownLatch还用于另外一个经典的多线程协作场景,等待结束。就是讲线程的角色分为两种,其中一种线程要等待另一种线程执行结束之后继续执行,一般用于异步任务执行,主线程汇总各自线程的执行结果。
public class Worker extends Thread {
private CountDownLatch latch;
public Worker(CountDownLatch latch) {
this.latch = latch;
}
@Override
public void run() {
try {
// 模拟线程运行
Thread.sleep((int) (Math.random() * 1000));
} catch (InterruptedException ignored) {
//ignore
} finally {
latch.countDown();
}
}
}
public class CountDownLatchTest {
public static void main(String[] args) throws Exception{
int workerNum = 100;
CountDownLatch latch = new CountDownLatch(workerNum);
Worker[] workers = new Worker[workerNum];
for (int i = 0; i < workerNum; i++) {
workers[i] = new Worker(latch);
workers[i].start();
}
latch.await();
System.out.println("collect worker results");
}
}
代码比较简单,就是子线程负责减小倒计时计数,主线程等待倒计时计数为0。当所有子线程都启动成功,但是所有子线程还未完全执行结束前,主线程调用latch.await()会阻塞主线程。当所有的自线程执行结束后,倒计时计数也就变为0了,最后一个将倒计时计数变为0的线程会唤醒阻塞的祝线程。
这里还有一点要单独提一下,Worker线程中,countDown方法的调用是在finally块中调用的,这样是为了保证工作线程发生异常的情况下也会被调用,使主线程能够从await调用中返回。
在之前的文章,我们使用wait/notify实现了一个简单的集合点AssemblePoint,同时提到,Java并发包中已经提供了类似工具,就是CyclicBarrier。它的大概含义是指,它相当于是一个栅栏,所有线程在到达该栅栏后都需要等待其他线程,等所有线程都到达后再一起通过,它是循环的,可以用作重复的同步。
CyclicBarrier底层通过显示锁ReentrantLock实现,特别适用于并行迭代计算,每个线程负责一部分计算,然后在栅栏处等待其他线程完成,所有线程到齐后,交换数据和计算结果,再进行下一次迭代。
与CountDownLatch类似,它也有一个数字,表示的是参与的线程个数:
public CyclicBarrier(int parties)
CyclicBarrier还有一个构造方法,接受一个Runnable参数,如下所示:
public CyclicBarrier(int parties, Runnable barrierAction)
这个参数表示当所有线程到达栅栏后,在所有线程执行下一步动作前,运行参数barrierAction中的动作,这个动作由最后一个到达栅栏的线程执行。
除了构造函数,CyclicBarrier最主要的方法就是await,如下:
public int await() throws InterruptedException, BrokenBarrierException
public int await(long timeout, TimeUnit unit) throws InterruptedException, BrokenBarrierException, TimeoutException
await在等待其他线程到达栅栏,调用await后,表示自己已经到达,如果自己是最后一个到达的,就执行可选的命令barrierAction,执行后,唤醒所有等待的线程,然后重置内部的同步计数,以循环使用。
await可以被中断,可以限定最长等待时间,中断或超时后会抛出异常。需要说明的是异常BrokenBarrierException,它表示栅栏被破坏了,什么意思呢?在CyclicBarrier中,参与的线程是互相影响的,只要其中一个线程在调用await时被中断了,或者超时了,栅栏就会被破坏,此外,如果栅栏动作抛出了异常,栅栏也会被破坏,被破坏后,所有在调用await的线程就会退出,抛出BrokenBarrierException。
下面来看个CyclicBarrier的简单示例:
public class CyclicBarrierDemo {
private CyclicBarrier cyclicBarrier;
private List> partialResults = Collections.synchronizedList(new ArrayList<>());
private Random random = new Random();
private int NUM_PARTIAL_RESULTS;
private int NUM_WORKERS;
class NumberCruncherThread implements Runnable {
@Override
public void run() {
String thisThreadName = Thread.currentThread().getName();
List partialResult = new ArrayList<>();
// Crunch some numbers and store the partial result
for (int i = 0; i < NUM_PARTIAL_RESULTS; i++) {
Integer num = random.nextInt(10);
System.out.println(thisThreadName
+ ": Crunching some numbers! Final result - " + num);
partialResult.add(num);
}
partialResults.add(partialResult);
try {
System.out.println(thisThreadName
+ " waiting for others to reach barrier.");
cyclicBarrier.await();
} catch (InterruptedException | BrokenBarrierException e) {
// ...
}
}
}
class AggregatorThread implements Runnable {
@Override
public void run() {
String thisThreadName = Thread.currentThread().getName();
System.out.println(
thisThreadName + ": Computing sum of " + NUM_WORKERS
+ " workers, having " + NUM_PARTIAL_RESULTS + " results each.");
int sum = 0;
for (List threadResult : partialResults) {
System.out.print("Adding ");
for (Integer partialResult : threadResult) {
System.out.print(partialResult+" ");
sum += partialResult;
}
System.out.println();
}
System.out.println(thisThreadName + ": Final result = " + sum);
}
}
public void runSimulation(int numWorkers, int numberOfPartialResults) {
NUM_PARTIAL_RESULTS = numberOfPartialResults;
NUM_WORKERS = numWorkers;
cyclicBarrier = new CyclicBarrier(NUM_WORKERS, new AggregatorThread());
System.out.println("Spawning " + NUM_WORKERS
+ " worker threads to compute "
+ NUM_PARTIAL_RESULTS + " partial results each");
for (int i = 0; i < NUM_WORKERS; i++) {
Thread worker = new Thread(new NumberCruncherThread());
worker.setName("Thread " + i);
worker.start();
}
}
public static void main(String[] args) {
CyclicBarrierDemo demo = new CyclicBarrierDemo();
demo.runSimulation(5, 3);
}
}
执行结果:
Thread 0: Crunching some numbers! Final result - 0
Thread 1: Crunching some numbers! Final result - 1
Thread 0: Crunching some numbers! Final result - 5
Thread 1: Crunching some numbers! Final result - 8
Thread 0: Crunching some numbers! Final result - 0
Thread 1: Crunching some numbers! Final result - 6
Thread 0 waiting for others to reach barrier.
Thread 1 waiting for others to reach barrier.
Thread 2: Crunching some numbers! Final result - 7
Thread 2: Crunching some numbers! Final result - 0
Thread 2: Crunching some numbers! Final result - 3
Thread 2 waiting for others to reach barrier.
Thread 3: Crunching some numbers! Final result - 5
Thread 3: Crunching some numbers! Final result - 9
Thread 3: Crunching some numbers! Final result - 8
Thread 3 waiting for others to reach barrier.
Thread 4: Crunching some numbers! Final result - 3
Thread 4: Crunching some numbers! Final result - 3
Thread 4: Crunching some numbers! Final result - 3
Thread 4 waiting for others to reach barrier.
Thread 4: Computing sum of 5 workers, having 3 results each.
Adding 0 5 0
Adding 1 8 6
Adding 7 0 3
Adding 5 9 8
Adding 3 3 3
Thread 4: Final result = 61
可以看到所有线程到达栅栏之后的执行动作AggregatorThread,是由最后到达栅栏的线程Thread 4执行的。
CyclicBarrier与CountDownLatch看起来可能容易混淆,这里来总结一下:
锁都是限制只有一个线程可以同时访问一个资源。现实中,资源往往有多个,但每个同时只能被一个线程访问,比如火车上的卫生间。有的单个资源即使可以被并发访问,但并发访问数多了可能影响性能,所以希望限制并发访问的线程数。在之前的文章,我们通过wait/notify实现了一个简单的信号量类MySemphore,用于控制并发访问的数量。Java API中提供了相应的信号量机制——Semaphore。
public Semaphore(int permits)
public Semaphore(int permits, boolean fair)
阻塞获取许可,响应中断
阻塞获取许可,不响应中断
批量获取多个许可,响应中断
批量获取多个许可,不响应中断
尝试获取许可,获取不到立即返回
限定等待时间尝试获取许可,在指定时间内未获取成功,直接返回
释放许可
我们看个简单的示例,限制并发访问的用户数不超过100,代码如下:
public class AccessControlService {
public static class ConcurrentLimitException extends RuntimeException {
private static final long serialVersionUID = 1L;
}
private static final int MAX_PERMITS = 100;
private Semaphore permits = new Semaphore(MAX_PERMITS, true);
public boolean login(String name, String password) {
if (!permits.tryAcquire()) {
// 同时登录用户数超过限制
throw new ConcurrentLimitException();
}
// ..其他验证
return true;
}
public void logout(String name) {
permits.release();
}
}
需要说明的是,如果我们将permits的值设为1,你可能会认为它就变成了一般的锁,不过,它与一般的锁是不同的。一般锁只能由持有锁的线程释放,而Semaphore表示的只是一个许可数,任意线程都可以调用其release方法。主要的锁实现类ReentrantLock是可重入的,而Semaphore不是,每一次的acquire调用都会消耗一个许可,比如,看下面代码段:
Semaphore permits = new Semaphore(1);
permits.acquire();
permits.acquire();
System.out.println("acquired");
程序会阻塞在第二个acquire调用,永远都不会输出”acquired”。
信号量的基本原理比较简单,也是基于AQS实现的,permits表示共享的锁个数,acquire方法就是检查锁个数是否大于0,大于则减一,获取成功,否则就等待,release就是将锁个数加一,唤醒第一个等待的线程。
参考链接:
1. Java API
2. 《Java编程的逻辑》
3. CyclicBarrier in Java