线程间通信之CountDownLatch

    之前看一个开源项目,里面用到了CountDownLatch,当时莫名其妙,不知道这个东西是干嘛的,后来查阅源码才知道这个东西这么好用,那么CountDownLatch是干嘛的呢,简单来说就是一个同步辅助工具类,使用它可以实现在某些线程执行完毕之后再执行另外一些线程,即某些线程执行的时候另外一个线程处于等待状态。我们平时需要在某一个子线程执行完毕之后再执行一些操作,当然可以直接在线程结束后调用我们想要执行的方法,或者采用接口回调的形式来实现,那么通过CountDownLatch来实现这种需求的话,形式上跟接口回调差不多。那么具体怎么使用呢,这里我们关注3个方法即可。
    CountDownLatch的构造方法

/**
     * Constructs a {@code CountDownLatch} initialized with the given count.
     *
     * @param count the number of times {@link #countDown} must be invoked
     *        before threads can pass through {@link #await}
     * @throws IllegalArgumentException if {@code count} is negative
     */
    public CountDownLatch(int count) {
        if (count < 0) throw new IllegalArgumentException("count < 0");
        this.sync = new Sync(count);
    }

    这个构造方法里面传入一个count,也就是说我们需要等待前面count个任务执行完毕之后才执行某个操作
    CountDownLatch的countDown方法

/**
     * Decrements the count of the latch, releasing all waiting threads if
     * the count reaches zero.
     *
     * 

If the current count is greater than zero then it is decremented. * If the new count is zero then all waiting threads are re-enabled for * thread scheduling purposes. * *

If the current count equals zero then nothing happens. */ public void countDown() { sync.releaseShared(1); }

    每个前置任务执行完毕后调用countDown方法使得计数器减1,如果计数器为0的话,那么调用await方法的任务会开始执行
    CountDownLatch的await方法

 /**
     * Causes the current thread to wait until the latch has counted down to
     * zero, unless the thread is {@linkplain Thread#interrupt interrupted}.
     *
     * 

If the current count is zero then this method returns immediately. * *

If the current count is greater than zero then the current * thread becomes disabled for thread scheduling purposes and lies * dormant until one of two things happen: *

    *
  • The count reaches zero due to invocations of the * {@link #countDown} method; or *
  • Some other thread {@linkplain Thread#interrupt interrupts} * the current thread. *
* *

If the current thread: *

    *
  • has its interrupted status set on entry to this method; or *
  • is {@linkplain Thread#interrupt interrupted} while waiting, *
* then {@link InterruptedException} is thrown and the current thread's * interrupted status is cleared. * * @throws InterruptedException if the current thread is interrupted * while waiting */ public void await() throws InterruptedException { sync.acquireSharedInterruptibly(1); }

    在前置任务没有执行完毕的时候,所有调用await方法的线程都会处于等待状态,等待前置任务执行完毕后执行
    那么具体怎么用来,这里我直接摘抄源码注释里的代码

class Driver { // ...
    void main() throws InterruptedException {
      CountDownLatch startSignal = new CountDownLatch(1);
      CountDownLatch doneSignal = new CountDownLatch(N);
 
      for (int i = 0; i < N; ++i) // create and start threads
        new Thread(new Worker(startSignal, doneSignal)).start();
 
      doSomethingElse();            // don't let run yet
      startSignal.countDown();      // let all threads proceed
      doSomethingElse();
      doneSignal.await();           // wait for all to finish
    }
  }
 
  class Worker implements Runnable {
    private final CountDownLatch startSignal;
    private final CountDownLatch doneSignal;
    Worker(CountDownLatch startSignal, CountDownLatch doneSignal) {
       this.startSignal = startSignal;
       this.doneSignal = doneSignal;
    }
    public void run() {
       try {
         startSignal.await();
         doWork();
         doneSignal.countDown();
       } catch (InterruptedException ex) {} // return;
    }
 
    void doWork() { ... }
  }

    以上例子中定义了一个startSignal,在每个线程(Worker)的run方法里面调用了startSignal.await,这样可以确保我们初始化的N个线程同时执行。在run方法里面任务执行完毕后又调用了doneSignal.countDown()让计数器减1,在for循环初始化完毕后调用了doneSignal.await()方法,这样就可以保证这N个线程都执行完毕后再执行下面的操作

class Driver2 { // ...
    void main() throws InterruptedException {
      CountDownLatch doneSignal = new CountDownLatch(N);
      Executor e = ...
 
      for (int i = 0; i < N; ++i) // create and start threads
       e.execute(new WorkerRunnable(doneSignal, i));
 
      doneSignal.await();           // wait for all to finish
    }
  }
 
  class WorkerRunnable implements Runnable {
    private final CountDownLatch doneSignal;
    private final int i;
    WorkerRunnable(CountDownLatch doneSignal, int i) {
       this.doneSignal = doneSignal;
       this.i = i;
    }
    public void run() {
       try {
         doWork(i);
         doneSignal.countDown();
       } catch (InterruptedException ex) {} // return;
    }
 
    void doWork() { ... }
  }

    以上一种用法咋一看好像跟第一个差不多,那么它的一大应用就是将一个任务拆分成N个part,在这N个part都执行完毕之后才执行接下来的操作,比如多线程下载,需要等待所有线程下载完成之后才算是这个文件下载完成,不明白多线程下载的同学请自行查阅资料
    那么CountDownLatch底层是如何实现的呢,从上面贴出的CountDownLatch的构造方法里面我们可以看到,它的构造里面仅仅是使用传递进来的count初始化了一个Sync,那么这个Sync是什么呢,其实就是用来进行同步计数控制的AQS(AbstractQueuedSynchronizer),AbstractQueuedSynchronizer内部其实就是通过FIFO阻塞队列以及wait notify来实现同步的。而对于计数器的修改则是使用的CAS,CAS 就是Compare And Swap,不明白的自行查阅资料。Sync作为CountDownLatch的内部类存在,源码如下

/**
     * Synchronization control For CountDownLatch.
     * Uses AQS state to represent count.
     */
    private static final class Sync extends AbstractQueuedSynchronizer {
        private static final long serialVersionUID = 4982264981922014374L;

        Sync(int count) {
            setState(count);
        }

        int getCount() {
            return getState();
        }

        protected int tryAcquireShared(int acquires) {
            return (getState() == 0) ? 1 : -1;
        }

        protected boolean tryReleaseShared(int releases) {
            // Decrement count; signal when transition to zero
            for (;;) {
                int c = getState();
                if (c == 0)
                    return false;
                int nextc = c-1;
                if (compareAndSetState(c, nextc))
                    return nextc == 0;
            }
        }
    }

我们先看看CountDownLatch.countDown()做了一些什么,从上面给出的countDown方法源码我们可以看到它仅仅是调用了sync.releaseShared(1),每次调用countDown()都会对原子值state进行修改,而调用await()方法的地方会一直执行一个死循环,这个死循环会在原子值state为0的时候结束,说白了就是相当于阻塞调用await方法的线程。大家查阅源码的时候一定要先看Sync类里面的tryAcquireShared()方法和tryReleaseShared()方法,这两个方法是重写的父类AbstractQueuedSynchronizer里面的方法,这两个方法在父类里面有调用,逻辑也不是很复杂,大家自行查阅源码里面的调用逻辑即可。AQS里面用到了Unsafe类,这个类类似c直接操作指针,是线程不安全的类,对于并发的处理实际上是依赖于CAS,有兴趣的同学可以好好研究!
   

你可能感兴趣的:(java多线程)