CountDownLatch学习笔记——常见用法,扩展用法、源码探究及自定义实现

CountDownLatch是一个一次性的线程同步工具,一般用于主线程等待多个工作线程均执行完毕之后,主线程再执行后续工作。

常见用法如下:

    public static void main(String[] args) {
        ExecutorService service = Executors.newFixedThreadPool(3);
        final CountDownLatch latch = new CountDownLatch(3);
        for (int i = 0; i < 3; i++) {
            Runnable runnable = new Runnable() {
                @Override
                public void run() {
                    try {
                        System.out.println("子线程" + Thread.currentThread().getName() + "开始执行");
                        Thread.sleep((long) (Math.random() * 10000));
                        System.out.println("子线程"+Thread.currentThread().getName()+"执行完成");
                        latch.countDown();//当前线程调用此方法,则计数减一
                    } catch (InterruptedException e) {
                        e.printStackTrace();
                    }
                }
            };
            service.execute(runnable);
        }

        try {
            System.out.println("主线程"+Thread.currentThread().getName()+"等待子线程执行完成...");
            latch.await();//阻塞当前线程,直到计数器的值为0
            System.out.println("主线程"+Thread.currentThread().getName()+"开始执行...");
        } catch (InterruptedException e) {
            e.printStackTrace();
        }
    }

上述代码中,首先初始化了一个状态数为3的CountDownLatch,然后启动了三个子线程,在每个线程中,分别调用CountDownLatch的countDown方法。同时在主线程中,调用CountDown的await方法区等待三个子线程完成。

执行结果

主线程main等待子线程执行完成...
子线程pool-1-thread-1开始执行
子线程pool-1-thread-3开始执行
子线程pool-1-thread-2开始执行
子线程pool-1-thread-2执行完成
子线程pool-1-thread-1执行完成
子线程pool-1-thread-3执行完成
主线程main开始执行...

实际上,等待CountDownLatch子线程执行完成的,也可以是多线程。例如如下代码,用两个线程等待

    public static void main(String[] args) {
        ExecutorService service = Executors.newFixedThreadPool(3);
        final CountDownLatch latch = new CountDownLatch(3);
        for (int i = 0; i < 3; i++) {
            Runnable runnable = new Runnable() {
                @Override
                public void run() {
                    try {
                        System.out.println("子线程" + Thread.currentThread().getName() + "开始执行");
                        Thread.sleep((long) (Math.random() * 10000));
                        System.out.println("子线程"+Thread.currentThread().getName()+"执行完成");
                        latch.countDown();//当前线程调用此方法,则计数减一
                    } catch (InterruptedException e) {
                        e.printStackTrace();
                    }
                }
            };
            service.execute(runnable);
        }

        try {
            Thread thread1 = new Thread(() -> {
                System.out.println("线程1等待子线程执行完成");
                try {
                    latch.await();
                } catch (InterruptedException e) {
                    throw new RuntimeException(e);
                }
                System.out.println("线程1开始执行");
            });
            thread1.start();

            Thread thread2 = new Thread(() -> {
                System.out.println("线程2等待子线程执行完成");
                try {
                    latch.await();
                } catch (InterruptedException e) {
                    throw new RuntimeException(e);
                }
                System.out.println("线程2开始执行");
            });
            thread2.start();

            System.in.read();
        } catch (Exception e) {
            e.printStackTrace();
        }
    }
子线程pool-1-thread-3开始执行
线程2等待子线程执行完成
子线程pool-1-thread-1开始执行
线程1等待子线程执行完成
子线程pool-1-thread-2开始执行
子线程pool-1-thread-3执行完成
子线程pool-1-thread-2执行完成
子线程pool-1-thread-1执行完成
线程2开始执行
线程1开始执行

那么CountDownLatch是如何实现可以多个主线程等待多个子线程完成的呢?让我们看一下它的源码

首先是定义了一个AbstractQueuedSynchronizer的子类Sync:

    private static final class Sync extends AbstractQueuedSynchronizer {
        private static final long serialVersionUID = 4982264981922014374L;

        Sync(int count) {
            setState(count);
        }

        int getCount() {
            return getState();
        }

        protected int tryAcquireShared(int acquires) {
            return (getState() == 0) ? 1 : -1;
        }

        protected boolean tryReleaseShared(int releases) {
            // Decrement count; signal when transition to zero
            for (;;) {
                int c = getState();
                if (c == 0)
                    return false;
                int nextc = c - 1;
                if (compareAndSetState(c, nextc))
                    return nextc == 0;
            }
        }
    }

可以看出,该AQS的子类中,重新实现了tryAcquireShared和tryReleaseShared方法。说明它是基于AQS的共享模式实现的。

再看构造器方法

    public CountDownLatch(int count) {
        if (count < 0) throw new IllegalArgumentException("count < 0");
        this.sync = new Sync(count);
    }

参数代表子线程数。

await方法

    public void await() throws InterruptedException {
        sync.acquireSharedInterruptibly(1);
    }

方法中调用了acquireSharedInterruptibly方法。该方法采用共享获取模式。我们知道在AQS中,acquire和release都有独占和共享模式(共享模式在方法名中加了Shared)。其中共享模式与独占模式的区别在于,共享模式在AQS队列中的节点在获取到锁之后,会通知它的下一个节点,它的下一个节点也会尝试获取。(这就是共享锁的意义)

    public final void acquireSharedInterruptibly(int arg)
        throws InterruptedException {
        if (Thread.interrupted() ||
            (tryAcquireShared(arg) < 0 &&
             acquire(null, arg, true, true, false, 0L) < 0))
            throw new InterruptedException();
    }

调用acquire方法时,第三个参数传入true,代表是共享锁模式

    final int acquire(Node node, int arg, boolean shared,
                      boolean interruptible, boolean timed, long time) {
        Thread current = Thread.currentThread();
        byte spins = 0, postSpins = 0;   // retries upon unpark of first thread
        boolean interrupted = false, first = false;
        Node pred = null;                // predecessor of node when enqueued

        /*
         * Repeatedly:
         *  Check if node now first
         *    if so, ensure head stable, else ensure valid predecessor
         *  if node is first or not yet enqueued, try acquiring
         *  else if node not yet created, create it
         *  else if not yet enqueued, try once to enqueue
         *  else if woken from park, retry (up to postSpins times)
         *  else if WAITING status not set, set and retry
         *  else park and clear WAITING status, and check cancellation
         */

        for (;;) {
            if (!first && (pred = (node == null) ? null : node.prev) != null &&
                !(first = (head == pred))) {
                if (pred.status < 0) {
                    cleanQueue();           // predecessor cancelled
                    continue;
                } else if (pred.prev == null) {
                    Thread.onSpinWait();    // ensure serialization
                    continue;
                }
            }
            if (first || pred == null) {
                boolean acquired;
                try {
                    if (shared)
                        acquired = (tryAcquireShared(arg) >= 0);
                    else
                        acquired = tryAcquire(arg);
                } catch (Throwable ex) {
                    cancelAcquire(node, interrupted, false);
                    throw ex;
                }
                if (acquired) {
                    if (first) {
                        node.prev = null;
                        head = node;
                        pred.next = null;
                        node.waiter = null;
                        if (shared)
                            // 重点是这里,如果是共享模式,会通知下一个节点
                            signalNextIfShared(node);
                        if (interrupted)
                            current.interrupt();
                    }
                    return 1;
                }
            }
            if (node == null) {                 // allocate; retry before enqueue
                if (shared)
                    node = new SharedNode();
                else
                    node = new ExclusiveNode();
            } else if (pred == null) {          // try to enqueue
                node.waiter = current;
                Node t = tail;
                node.setPrevRelaxed(t);         // avoid unnecessary fence
                if (t == null)
                    tryInitializeHead();
                else if (!casTail(t, node))
                    node.setPrevRelaxed(null);  // back out
                else
                    t.next = node;
            } else if (first && spins != 0) {
                --spins;                        // reduce unfairness on rewaits
                Thread.onSpinWait();
            } else if (node.status == 0) {
                node.status = WAITING;          // enable signal and recheck
            } else {
                long nanos;
                spins = postSpins = (byte)((postSpins << 1) | 1);
                if (!timed)
                    LockSupport.park(this);
                else if ((nanos = time - System.nanoTime()) > 0L)
                    LockSupport.parkNanos(this, nanos);
                else
                    break;
                node.clearStatus();
                if ((interrupted |= Thread.interrupted()) && interruptible)
                    break;
            }
        }
        return cancelAcquire(node, interrupted, interruptible);
    }

重点是我添加注释那一行

    private static void signalNextIfShared(Node h) {
        Node s;
        if (h != null && (s = h.next) != null &&
            (s instanceof SharedNode) && s.status != 0) {
            s.getAndUnsetStatus(WAITING);
            LockSupport.unpark(s.waiter);
        }
    }

LockSupport的unpark就代表唤醒线程。

countDown方法

    public void countDown() {
        sync.releaseShared(1);
    }
    public final boolean releaseShared(int arg) {
        if (tryReleaseShared(arg)) {
            signalNext(head);
            return true;
        }
        return false;
    }

这里tryReleaseShared方法会调用CountDownLatch中Sync实现的方法

        protected boolean tryReleaseShared(int releases) {
            // Decrement count; signal when transition to zero
            for (;;) {
                int c = getState();
                if (c == 0)
                    return false;
                int nextc = c - 1;
                if (compareAndSetState(c, nextc))
                    return nextc == 0;
            }
        }

它会判断必须当前这一次release将state减为0才返回true。

那么,能否实现只允许一个线程等待的CountDownLatch呢?简单地说,可以将CountDownLatch中涉及到acquire和release的方法都换成独占版的,同时稍微改一下内部类Sync的tryAcquire方法。

下面是我自定义的MyCountDownLatch类

public class MyCountDownLatch {
    /**
     * Synchronization control For CountDownLatch.
     * Uses AQS state to represent count.
     */
    private static final class Sync extends AbstractQueuedSynchronizer {
        private static final long serialVersionUID = 4982264981922014374L;

        Sync(int count) {
            setState(count);
        }

        int getCount() {
            return getState();
        }

        protected boolean tryAcquire(int acquires) {
            return (getState() == 0) ? true : false;
        }

        protected boolean tryRelease(int releases) {
            // Decrement count; signal when transition to zero
            for (;;) {
                int c = getState();
                if (c == 0)
                    return false;
                int nextc = c - 1;
                if (compareAndSetState(c, nextc))
                    return nextc == 0;
            }
        }
    }

    private final Sync sync;

    /**
     * Constructs a {@code CountDownLatch} initialized with the given count.
     *
     * @param count the number of times {@link #countDown} must be invoked
     *        before threads can pass through {@link #await}
     * @throws IllegalArgumentException if {@code count} is negative
     */
    public MyCountDownLatch(int count) {
        if (count < 0) throw new IllegalArgumentException("count < 0");
        this.sync = new Sync(count);
    }

    /**
     * Causes the current thread to wait until the latch has counted down to
     * zero, unless the thread is {@linkplain Thread#interrupt interrupted}.
     *
     * 

If the current count is zero then this method returns immediately. * *

If the current count is greater than zero then the current * thread becomes disabled for thread scheduling purposes and lies * dormant until one of two things happen: *

    *
  • The count reaches zero due to invocations of the * {@link #countDown} method; or *
  • Some other thread {@linkplain Thread#interrupt interrupts} * the current thread. *
* *

If the current thread: *

    *
  • has its interrupted status set on entry to this method; or *
  • is {@linkplain Thread#interrupt interrupted} while waiting, *
* then {@link InterruptedException} is thrown and the current thread's * interrupted status is cleared. * * @throws InterruptedException if the current thread is interrupted * while waiting */ public void await() throws InterruptedException { sync.acquireInterruptibly(1); } /** * Causes the current thread to wait until the latch has counted down to * zero, unless the thread is {@linkplain Thread#interrupt interrupted}, * or the specified waiting time elapses. * *

If the current count is zero then this method returns immediately * with the value {@code true}. * *

If the current count is greater than zero then the current * thread becomes disabled for thread scheduling purposes and lies * dormant until one of three things happen: *

    *
  • The count reaches zero due to invocations of the * {@link #countDown} method; or *
  • Some other thread {@linkplain Thread#interrupt interrupts} * the current thread; or *
  • The specified waiting time elapses. *
* *

If the count reaches zero then the method returns with the * value {@code true}. * *

If the current thread: *

    *
  • has its interrupted status set on entry to this method; or *
  • is {@linkplain Thread#interrupt interrupted} while waiting, *
* then {@link InterruptedException} is thrown and the current thread's * interrupted status is cleared. * *

If the specified waiting time elapses then the value {@code false} * is returned. If the time is less than or equal to zero, the method * will not wait at all. * * @param timeout the maximum time to wait * @param unit the time unit of the {@code timeout} argument * @return {@code true} if the count reached zero and {@code false} * if the waiting time elapsed before the count reached zero * @throws InterruptedException if the current thread is interrupted * while waiting */ public boolean await(long timeout, TimeUnit unit) throws InterruptedException { return sync.tryAcquireNanos(1, unit.toNanos(timeout)); } /** * Decrements the count of the latch, releasing all waiting threads if * the count reaches zero. * *

If the current count is greater than zero then it is decremented. * If the new count is zero then all waiting threads are re-enabled for * thread scheduling purposes. * *

If the current count equals zero then nothing happens. */ public void countDown() { sync.release(1); } /** * Returns the current count. * *

This method is typically used for debugging and testing purposes. * * @return the current count */ public long getCount() { return sync.getCount(); } /** * Returns a string identifying this latch, as well as its state. * The state, in brackets, includes the String {@code "Count ="} * followed by the current count. * * @return a string identifying this latch, as well as its state */ public String toString() { return super.toString() + "[Count = " + sync.getCount() + "]"; } }

其中,主要是将调用AQS的共享模式的方法地方,都换成了独占模式的地方,同时修改了一下tryAcquire方法,改之后是

        protected boolean tryAcquire(int acquires) {
            return (getState() == 0) ? true : false;
        }

主要是因为tryAcquire和tryAcquireShared的返回值类型不一样,前者是boolean,后者是int。

此时,再使用两个主线程去等待三个子线程完成

    public static void main(String[] args) {
        ExecutorService service = Executors.newFixedThreadPool(3);
        final MyCountDownLatch latch = new MyCountDownLatch(3);
        for (int i = 0; i < 3; i++) {
            Runnable runnable = new Runnable() {
                @Override
                public void run() {
                    try {
                        System.out.println("子线程" + Thread.currentThread().getName() + "开始执行");
                        Thread.sleep((long) (Math.random() * 10000));
                        System.out.println("子线程"+Thread.currentThread().getName()+"执行完成");
                        latch.countDown();//当前线程调用此方法,则计数减一
                    } catch (InterruptedException e) {
                        e.printStackTrace();
                    }
                }
            };
            service.execute(runnable);
        }

        try {
            Thread thread1 = new Thread(() -> {
                System.out.println("线程1等待子线程执行完成");
                try {
                    latch.await();
                } catch (InterruptedException e) {
                    throw new RuntimeException(e);
                }
                System.out.println("线程1开始执行");
            });
            thread1.start();

            Thread thread2 = new Thread(() -> {
                System.out.println("线程2等待子线程执行完成");
                try {
                    latch.await();
                } catch (InterruptedException e) {
                    throw new RuntimeException(e);
                }
                System.out.println("线程2开始执行");
            });
            thread2.start();

            System.in.read();
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

输出

子线程pool-1-thread-3开始执行
子线程pool-1-thread-1开始执行
子线程pool-1-thread-2开始执行
线程1等待子线程执行完成
线程2等待子线程执行完成
子线程pool-1-thread-2执行完成
子线程pool-1-thread-3执行完成
子线程pool-1-thread-1执行完成
线程1开始执行

可见只有线程1在MyCountDownLatch的state减为0之后恢复了执行。而线程2却一直在AQS队列中,无法恢复执行。

当然这个MyCountDownLatch类只是为了验证而写的类,不能用于正式环境,因为如果多个主线程等待,就会造成只有一个线程能恢复执行,而其余线程永久阻塞的后果。

你可能感兴趣的:(学习,java,多线程)