之前对同步工具类闭锁CountDownLatch、信号量Semaphore、栅栏CycliBarrier有过了解,但是对其原理还不是清晰,在此从源码角度进行分析。
详细的实例代码地址:https://github.com/yq-debug/JavaExercise/tree/master/src/main/java/juc_sync
//允许一个线程或多个线程等待,直到其他线程的操作完成
public class CountDownLatch {
//Sync类继承自AQS类
private static final class Sync extends AbstractQueuedSynchronizer {
private static final long serialVersionUID = 4982264981922014374L;
//初始化Sync类
Sync(int count) {
setState(count);//设置同步状态
}
//获取同步状态
int getCount() {
return getState();
}
protected int tryAcquireShared(int acquires) {
return (getState() == 0) ? 1 : -1;
}
protected boolean tryReleaseShared(int releases) {
for (;;) {
int c = getState();
if (c == 0)
return false;
int nextc = c-1;
if (compareAndSetState(c, nextc))
return nextc == 0;
}
}
}
private final Sync sync;
//构造一个含有count的锁存器的对象
public CountDownLatch(int count) {
if (count < 0) throw new IllegalArgumentException("count < 0");
//初始化sync对象,此时传入的是锁存器的数量
this.sync = new Sync(count);
}
//阻塞当前线程直到线程阻塞器的计数为0
public void await() throws InterruptedException {
sync.acquireSharedInterruptibly(1);
}
//阻塞当前线程直到线程阻塞器的计数为0
public boolean await(long timeout, TimeUnit unit)
throws InterruptedException {
return sync.tryAcquireSharedNanos(1, unit.toNanos(timeout));
}
//减少锁存器的数量
public void countDown() {
sync.releaseShared(1);
}
//返回当前锁存器的数量
public long getCount() {
return sync.getCount();
}
//返回此闭锁的字符串表示
public String toString() {
return super.toString() + "[Count = " + sync.getCount() + "]";
}
}
CountDownLatch的作用是阻塞当前线程,直到其他线程操作完成之后,即闭锁中锁存器的数量为0时,再激活当前线程继续执行,与join()方法类似,都是将并行线程变为串行线程。
闭锁CountDownLatch在初始化的时候会传入一个参数,表示锁存器的数量,即代表要阻塞的线程的数量。
需要后执行的线程先进行阻塞,需要先行执行的线程在执行完之后将锁存器的数量减1,直到锁存器的数量为0或者时间超时触发阻塞线程执行。
此处使用的两个方法:countDown()锁存器的数量减1, await()阻塞当前线程直到锁存器的数量为0
使用场景:当线程的执行执行结果与执行顺序有关时(即会出现竞争冲突)则采用闭锁CountDownLatch使并发线程变为串行线程执行。
CountDownLatch的底层实现是AQS类
public class Semaphore implements java.io.Serializable {
private static final long serialVersionUID = -3222578661600680210L;
private final Sync sync;
abstract static class Sync extends AbstractQueuedSynchronizer {
private static final long serialVersionUID = 1192457210091910933L;
Sync(int permits) {
setState(permits);//设置同步状态
}
//获取许可证的数量(同步状态)
final int getPermits() {
return getState();
}
final int nonfairTryAcquireShared(int acquires) {
for (;;) {
int available = getState();
int remaining = available - acquires;
if (remaining < 0 ||
compareAndSetState(available, remaining))
return remaining;
}
}
protected final boolean tryReleaseShared(int releases) {
for (;;) {
int current = getState();
int next = current + releases;
if (next < current) // overflow
throw new Error("Maximum permit count exceeded");
if (compareAndSetState(current, next))
return true;
}
}
final void reducePermits(int reductions) {
for (;;) {
int current = getState();
int next = current - reductions;
if (next > current) // underflow
throw new Error("Permit count underflow");
if (compareAndSetState(current, next))
return;
}
}
final int drainPermits() {
for (;;) {
int current = getState();
if (current == 0 || compareAndSetState(current, 0))
return current;
}
}
}
//非公平信号量
static final class NonfairSync extends Sync {
private static final long serialVersionUID = -2694183684443567898L;
//构造函数
NonfairSync(int permits) {
super(permits);
}
protected int tryAcquireShared(int acquires) {
return nonfairTryAcquireShared(acquires);
}
}
//公平信号量
static final class FairSync extends Sync {
private static final long serialVersionUID = 2014338818796000944L;
//构造函数
FairSync(int permits) {
super(permits);
}
protected int tryAcquireShared(int acquires) {
for (;;) {
if (hasQueuedPredecessors())
return -1;
int available = getState();
int remaining = available - acquires;
if (remaining < 0 ||
compareAndSetState(available, remaining))
return remaining;
}
}
}
公平信号量与非公平信号量都是继承自Sync类,其构造方法也一样,但是其获取信号量的方法不一样。
默认创建非公平信号量
//创建permits个许可证的信号量对象,默认为非公平信号量
public Semaphore(int permits) {
sync = new NonfairSync(permits);
}
//创建信号量对象可以设置信号量类型
public Semaphore(int permits, boolean fair) {
sync = fair ? new FairSync(permits) : new NonfairSync(permits);
}
在线程没有得到许可证之前,线程一直处于阻塞状态,除非得到许可证或者线程中断
//获取一个许可证,在线程没有得到许可证之前,线程一直处于阻塞状态,除非得到许可证或者线程中断
public void acquire() throws InterruptedException {
sync.acquireSharedInterruptibly(1);
}
public void acquireUninterruptibly() {
sync.acquireShared(1);
}
//仅当在调用一个信号量时,此信号量的许可证可用时,才获取许可证。
public boolean tryAcquire() {
return sync.nonfairTryAcquireShared(1) >= 0;
}
public boolean tryAcquire(long timeout, TimeUnit unit)
throws InterruptedException {
return sync.tryAcquireSharedNanos(1, unit.toNanos(timeout));
}
public void acquire(int permits) throws InterruptedException {
if (permits < 0) throw new IllegalArgumentException();
sync.acquireSharedInterruptibly(permits);
}
public void acquireUninterruptibly(int permits) {
if (permits < 0) throw new IllegalArgumentException();
sync.acquireShared(permits);
}
public boolean tryAcquire(int permits) {
if (permits < 0) throw new IllegalArgumentException();
return sync.nonfairTryAcquireShared(permits) >= 0;
}
public boolean tryAcquire(int permits, long timeout, TimeUnit unit)
throws InterruptedException {
if (permits < 0) throw new IllegalArgumentException();
return sync.tryAcquireSharedNanos(permits, unit.toNanos(timeout));
}
线程在执行之前获取许可证,在执行结束之后要释放许可证,即信号量中的许可证数量加1
//释放许可证到信号量中,即许可证的数量加1
public void release() {
sync.releaseShared(1);
}
//释放许可证到信号量中
public void release(int permits) {
if (permits < 0) throw new IllegalArgumentException();
sync.releaseShared(permits);
}
public int availablePermits() {
return sync.getPermits();
}
public int drainPermits() {
return sync.drainPermits();
}
protected void reducePermits(int reduction) {
if (reduction < 0) throw new IllegalArgumentException();
sync.reducePermits(reduction);
}
//判断此信号量是否为公平信号量
public boolean isFair() {
return sync instanceof FairSync;
}
public final boolean hasQueuedThreads() {
return sync.hasQueuedThreads();
}
public final int getQueueLength() {
return sync.getQueueLength();
}
protected Collection<Thread> getQueuedThreads() {
return sync.getQueuedThreads();
}
//信号量的字符串表示
public String toString() {
return super.toString() + "[Permits = " + sync.getPermits() + "]";
}
}
Semaphore信号量机制通过许可证的数量来限制访问某个资源的线程的数量,即许可证的数量为资源的数量
公平信号量与非公平信号量的区别:
Semaphore默认创建非公平信号量
公平信号量是任何申请许可证的线程要按照FIFO队列来申请Semaphore的许可证
非公平信号量指的是任何申请许可证的线程都可以第一时间看见Semaphore中是否有可用的许可证,如果有则立即进行分配
获取许可证的方法acquire()
释放许可证的方法release()
可以将一个容器变为有界阻塞容器,即对容器中的资源分配许可证进行控制。
public class CyclicBarrier {
//创建屏障实例,此屏障实例会进行循环使用
private static class Generation {
boolean broken = false;
}
//创建非公平锁
private final ReentrantLock lock = new ReentrantLock();
//由非公平锁获取一个条件锁
private final Condition trip = lock.newCondition();
//这组中线程的数量
private final int parties;
//当所有的线程都到达栅栏处时要触发的动作线程
private final Runnable barrierCommand;
//当前屏障实例
private Generation generation = new Generation();
//等待的线程数量
private int count;
//更新屏障状态
private void nextGeneration() {
//唤醒所有的等待线程
trip.signalAll();
count = parties;
//更新屏障状态
generation = new Generation();
}
//打破屏障,此时将唤醒所有的线程
private void breakBarrier() {
generation.broken = true;
count = parties;
trip.signalAll();
}
private int dowait(boolean timed, long nanos)
throws InterruptedException, BrokenBarrierException,
TimeoutException {
final ReentrantLock lock = this.lock;
lock.lock();
try {
final Generation g = generation;
if (g.broken)
throw new BrokenBarrierException();
if (Thread.interrupted()) {
breakBarrier();
throw new InterruptedException();
}
int index = --count;
if (index == 0) { // tripped
boolean ranAction = false;
try {
final Runnable command = barrierCommand;
if (command != null)
command.run();
ranAction = true;
nextGeneration();
return 0;
} finally {
if (!ranAction)
breakBarrier();
}
}
// loop until tripped, broken, interrupted, or timed out
for (;;) {
try {
if (!timed)
trip.await();
else if (nanos > 0L)
nanos = trip.awaitNanos(nanos);
} catch (InterruptedException ie) {
if (g == generation && ! g.broken) {
breakBarrier();
throw ie;
} else {
// We're about to finish waiting even if we had not
// been interrupted, so this interrupt is deemed to
// "belong" to subsequent execution.
Thread.currentThread().interrupt();
}
}
if (g.broken)
throw new BrokenBarrierException();
if (g != generation)
return index;
if (timed && nanos <= 0L) {
breakBarrier();
throw new TimeoutException();
}
}
} finally {
lock.unlock();
}
}
//创建屏障对象,并制定等待的线程数量和触发动作的线程
public CyclicBarrier(int parties, Runnable barrierAction) {
if (parties <= 0) throw new IllegalArgumentException();
this.parties = parties;
this.count = parties;
this.barrierCommand = barrierAction;//设置动作线程
}
//创建屏障对象,并指定等待的线程的数量,并且预定义的操作为null
public CyclicBarrier(int parties) {
this(parties, null);
}
//获取所有线程的数量
public int getParties() {
return parties;
}
//将当前线程等待直到所有的线程都到达栅栏处
public int await() throws InterruptedException, BrokenBarrierException {
try {
return dowait(false, 0L);
} catch (TimeoutException toe) {
throw new Error(toe); // cannot happen
}
}
//将当前线程等待直到所有的线程都已经到达栅栏处或者指定的等待时间超时
public int await(long timeout, TimeUnit unit)
throws InterruptedException,
BrokenBarrierException,
TimeoutException {
return dowait(true, unit.toNanos(timeout));
}
//判断此屏障是否处于断开状态
public boolean isBroken() {
final ReentrantLock lock = this.lock;
lock.lock();
try {
return generation.broken;
} finally {
lock.unlock();
}
}
//将当前屏障设置为初始化状态
public void reset() {
final ReentrantLock lock = this.lock;
lock.lock();
try {
breakBarrier(); // break the current generation
nextGeneration(); // start a new generation
} finally {
lock.unlock();
}
}
//获取等待的线程的数量
public int getNumberWaiting() {
final ReentrantLock lock = this.lock;
lock.lock();
try {
return parties - count;
} finally {
lock.unlock();
}
}
}
应用场景要求:当在所有的线程都执行完毕,即达到一个栅栏处时,此时触发一个动作线程执行下一步的动作;例如在游戏中当所有的玩家都加载完毕之后,再同时进入游戏。
最后一个线程触发动作线程,此线程通过构造函数作为参数传入。
栅栏与闭锁的区别在于栅栏会有一个动作线程,最后一个线程到来时触发动作线程;闭锁是最后一个线程等待其他线程执行完成之后再执行