CountDownLatch源码分析

CountDownLatch源码分析

上一节讲到CopyOnWriteArrayList,其中用到了CountDownLatch,于是想这节就直接讲讲CountDownLatch的原理吧,顺便说下CountDownLatch的简单用法。

CountDownLatch用法


想来想去还是把上一个demo贴过来吧,如下:

CopyOnWriteArrayList list = new CopyOnWriteArrayList<>();
        ExecutorService executorService = Executors.newFixedThreadPool(10);
        CountDownLatch countDownLatch = new CountDownLatch(9);
        for (int i = 0; i < 9; i++) {
            executorService.execute(() -> {
                Random random = new Random(System.currentTimeMillis());
                list.add(random.nextInt());
                countDownLatch.countDown();
            });
        }
        countDownLatch.await();
        executorService.shutdown();
        System.out.println(list.toString());

这里CountDownLatch会阻塞住executorService的shutdown操作,只有等线程池中的线程运行结束了,这里await方法才会放开限制,继续向下执行。

从这,初步开来,CountDownLatch就是初始化一个count值,countDown方法就是count方法减一操作,await方法即是监听count值为0时,则放开当前方法限制,今日下一步操作,那么这里的count变量应该就是一个cas的操作,否则会受到并发导致count值计算不准确的问题,下面就来看下CountDownLatch的具体源码。

CountDownLatch源码


CountDownLatch对象的初始化为:

CountDownLatch countDownLatch = new CountDownLatch(9);

进入CountDownLatch的构造方法。

 /**
     * Constructs a {@code CountDownLatch} initialized with the given count.
     *
     * @param count the number of times {@link #countDown} must be invoked
     *        before threads can pass through {@link #await}
     * @throws IllegalArgumentException if {@code count} is negative
     */
    public CountDownLatch(int count) {
        if (count < 0) throw new IllegalArgumentException("count < 0");
        this.sync = new Sync(count);
    }

Sync(int count) {
            setState(count);
}

protected final void setState(int newState) {
        state = newState;
}

这里的操作比较简单,和之前猜想的相同,对state进行赋值。

我们继续看countDown方法。

public void countDown() {
        sync.releaseShared(1);
    }
public final boolean releaseShared(int arg) {
        if (tryReleaseShared(arg)) {
            doReleaseShared();
            return true;
        }
        return false;
    }

protected boolean tryReleaseShared(int releases) {
            // Decrement count; signal when transition to zero
            for (;;) {
                int c = getState();
                if (c == 0)
                    return false;
                int nextc = c-1;
                if (compareAndSetState(c, nextc))
                    return nextc == 0;
            }
        }
    }

countDwon主要功能方法为tryReleaseShared,这个方法主要功能就是不断获取当前state值,如果当前值为0,则直接返回false,这个情况应该就是几个线程在达到count值时其他线程的情况。

int c = getState();
 if (c == 0)
      return false;

然后此处最后一个到达的线程的操作就是对当前值减一后,cas操作,然后判断nextc值是否为0,如果为0,则返回true.

if (compareAndSetState(c, nextc))
       return nextc == 0;

这里返回true,就回到了releaseShared方法。

if (tryReleaseShared(arg)) {
            doReleaseShared();
            return true;
        }

此处进入doReleaseShared方法,这个方法涉及到AQS的共享锁模式,这个在后面单独讲AQS的时候进行讲述,在这就不讲了。

继续看CountDownLatch的源码方法,在countDown方法看完之后,我们继续看await方法,如下:

public void await() throws InterruptedException {
        sync.acquireSharedInterruptibly(1);
    }
public final void acquireSharedInterruptibly(int arg)
            throws InterruptedException {
        if (Thread.interrupted())
            throw new InterruptedException();
        if (tryAcquireShared(arg) < 0)
            doAcquireSharedInterruptibly(arg);
    }
protected int tryAcquireShared(int acquires) {
            return (getState() == 0) ? 1 : -1;
}

此处对state值进行判断,state为0,返回1,非0返回-1,在线程没有结束时,进入到doAcquireSharedInterruptibly方法中,这个方法用途就是阻塞当前线程,对state值进行判断,如果值大于1,即所有线程都工作完成后,进行一系列清理操作,最终返回,然后CountDownLatch的工作就完成了。

CountDownLatch的源码还是比较简单的,不过其核心知识点应该在AQS中,这块在最后进行讲述。

你可能感兴趣的:(java技术与应用)