Using CountDownLatch and CyclicBarrier, and How They Are Implemented

I. Using CountDownLatch and CyclicBarrier

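CountDownLatch:

Definition: a thread synchronization aid that lets one or more threads block until a set of operations being performed in other threads completes. One notable property is that a thread calling countDown is not required to wait for the count to reach zero before proceeding; the latch only prevents threads from passing an await until the count has reached zero.

Usage: initialize a CountDownLatch with a given count. Each call to countDown() decrements the count by 1, and calls to await block until the count reaches 0. Once it reaches 0, all waiting threads are released and any later call to await returns immediately.

Scenarios: (1) A CountDownLatch initialized with a count of 1 serves as a simple on/off latch, or gate: all threads calling await wait at the gate until a thread opens it by calling countDown(). (2) A CountDownLatch initialized with N (N >= 1) makes one thread wait until N threads have completed some action, or until some action has been completed N times.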

Note: the count of a CountDownLatch cannot be reset. If you need a count that resets, consider using a CyclicBarrier.
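The one-shot behaviour described above can be observed directly. The following is a minimal sketch; the class and method names (OneShotLatchDemo and its helpers) are made up for illustration, and only the standard CountDownLatch API is used.

```java
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;

public class OneShotLatchDemo {

    /** Counts a latch down past zero and returns the count that remains. */
    public static long countAfterExtraCountDowns() {
        CountDownLatch latch = new CountDownLatch(2);
        latch.countDown();
        latch.countDown();   // the count is now 0 and the latch is open
        latch.countDown();   // no effect: the count never drops below 0
        try {
            latch.await();   // returns immediately because the count is 0
        } catch (InterruptedException e) {
            throw new IllegalStateException(e);
        }
        return latch.getCount();
    }

    /** A timed await on a latch that nobody counts down just times out. */
    public static boolean timedAwaitOnClosedLatch() {
        CountDownLatch latch = new CountDownLatch(1);
        try {
            // await(timeout, unit) returns false if the count is still > 0
            return latch.await(50, TimeUnit.MILLISECONDS);
        } catch (InterruptedException e) {
            throw new IllegalStateException(e);
        }
    }

    public static void main(String[] args) {
        System.out.println("count after extra countDown calls = " + countAfterExtraCountDowns());
        System.out.println("closed latch opened within timeout = " + timedAwaitOnClosedLatch());
    }
}
```

Once the count has reached zero there is no way to close the latch again; a new CountDownLatch has to be created for the next round.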

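Practice: the code below uses 10 threads, each summing one group of numbers. The 10 threads must logically start computing at the same time (truly simultaneous execution is not guaranteed: with fewer CPU cores than threads they cannot all run in parallel), and no thread may return until all 10 threads have finished computing.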

MyCalculator.java:

package simple.demo;

import java.util.concurrent.Callable;
import java.util.concurrent.CountDownLatch;

/**
 * @author jianying.wcj
 * @date 2013-8-2
 */
public class MyCalculator implements Callable<Integer> {

    /** Start switch: opened once by the main thread. */
    private final CountDownLatch startSwitch;

    /** Stop switch: opened when every worker has finished computing. */
    private final CountDownLatch stopSwitch;

    /** Number of the group to sum (group n covers (n-1)*10+1 .. n*10). */
    private final int groupNum;

    /** Constructor. */
    public MyCalculator(CountDownLatch startSwitch, CountDownLatch stopSwitch, int groupNum) {
        this.startSwitch = startSwitch;
        this.stopSwitch = stopSwitch;
        this.groupNum = groupNum;
    }

    @Override
    public Integer call() throws Exception {
        // Wait until the main thread opens the start gate.
        startSwitch.await();
        int res = compute();
        System.out.println(Thread.currentThread().getName() + " is ok wait other thread...");
        // Report completion, then wait until all other workers have finished too.
        stopSwitch.countDown();
        stopSwitch.await();
        System.out.println(Thread.currentThread().getName() + " is stop! the group" + groupNum + " temp result is sum=" + res);
        return res;
    }

    /**
     * Sums the numbers of this group.
     * @return the sum of (groupNum-1)*10+1 .. groupNum*10
     */
    public int compute() {
        int sum = 0;
        for (int i = (groupNum - 1) * 10 + 1; i <= groupNum * 10; i++) {
            sum += i;
        }
        return sum;
    }
}

MyTest.java:

package simple.demo;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

public class MyTest {

    private int groupNum = 10;

    /** Start and stop switches. */
    private CountDownLatch startSwitch = new CountDownLatch(1);

    private CountDownLatch stopSwitch = new CountDownLatch(groupNum);

    /** Thread pool. */
    private ExecutorService service = Executors.newFixedThreadPool(groupNum);

    /** Holds the computation results. */
    private List<Future<Integer>> result = new ArrayList<Future<Integer>>();

    /** Starts groupNum threads to compute the values. */
    public void init() {
        for (int i = 1; i <= groupNum; i++) {
            result.add(service.submit(new MyCalculator(startSwitch, stopSwitch, i)));
        }
        System.out.println("init is ok!");
    }

    public void printRes() throws InterruptedException, ExecutionException {
        int sum = 0;
        for (Future<Integer> f : result) {
            sum += f.get();
        }
        System.out.println("the result is " + sum);
    }

    public void start() {
        // Open the start gate: all workers blocked in startSwitch.await() proceed.
        this.startSwitch.countDown();
    }

    public void stop() throws InterruptedException {
        // Wait until every worker has counted down, then shut the pool down.
        this.stopSwitch.await();
        this.service.shutdown();
    }

    public static void main(String[] args) throws IOException, InterruptedException, ExecutionException {
        BufferedReader reader = new BufferedReader(new InputStreamReader(System.in));

        MyTest myTest = new MyTest();
        myTest.init();
        System.out.println("please enter start command....");

        reader.readLine();
        myTest.start();
        myTest.stop();

        myTest.printRes();
    }
}

Output:

init is ok!
please enter start command....

pool-1-thread-1 is ok wait other thread...
pool-1-thread-2 is ok wait other thread...
pool-1-thread-3 is ok wait other thread...
pool-1-thread-4 is ok wait other thread...
pool-1-thread-6 is ok wait other thread...
pool-1-thread-5 is ok wait other thread...
pool-1-thread-8 is ok wait other thread...
pool-1-thread-7 is ok wait other thread...
pool-1-thread-9 is ok wait other thread...
pool-1-thread-10 is ok wait other thread...
pool-1-thread-10 is stop! the group10 temp result is sum=955
pool-1-thread-1 is stop! the group1 temp result is sum=55
pool-1-thread-2 is stop! the group2 temp result is sum=155
pool-1-thread-3 is stop! the group3 temp result is sum=255
pool-1-thread-4 is stop! the group4 temp result is sum=355
pool-1-thread-6 is stop! the group6 temp result is sum=555
pool-1-thread-5 is stop! the group5 temp result is sum=455
pool-1-thread-8 is stop! the group8 temp result is sum=755
pool-1-thread-7 is stop! the group7 temp result is sum=655
pool-1-thread-9 is stop! the group9 temp result is sum=855
the result is 5050

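CyclicBarrier:

Definition: a synchronization aid that allows a set of threads to wait for each other until they all reach a common barrier point, after which they continue or return together. CyclicBarrier also supports an optional Runnable command that runs once per barrier point, after the last thread of the group arrives and before any thread is released. This barrier action is useful for updating shared state before the participating threads continue.

Usage: initialize a CyclicBarrier with a count N. Each call to await blocks the calling thread and registers one arrival (internally the implementation counts down from N). When the N-th thread arrives, all blocked threads are woken up. The barrier then resets automatically, so later calls to await block again until another N threads have arrived.

Scenario: initialize a CyclicBarrier with N and call await in each of N threads; the N threads then continue together only after every one of them has reached its await call.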
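To illustrate the reuse and the barrier action described above, here is a small sketch. BarrierReuseDemo and its run method are hypothetical names chosen for this example; the barrier itself is the standard java.util.concurrent.CyclicBarrier.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CyclicBarrier;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

public class BarrierReuseDemo {

    /**
     * Sends `parties` threads through the same barrier `rounds` times and
     * returns how many times the barrier action was executed.
     */
    public static int run(int parties, int rounds) {
        AtomicInteger actionRuns = new AtomicInteger();
        // The action runs once per barrier point, in the last thread to arrive.
        CyclicBarrier barrier = new CyclicBarrier(parties, actionRuns::incrementAndGet);
        ExecutorService pool = Executors.newFixedThreadPool(parties);
        try {
            List<Future<?>> futures = new ArrayList<>();
            for (int t = 0; t < parties; t++) {
                futures.add(pool.submit(() -> {
                    for (int r = 0; r < rounds; r++) {
                        barrier.await(); // blocks until all parties reach round r
                    }
                    return null;
                }));
            }
            for (Future<?> f : futures) {
                f.get(5, TimeUnit.SECONDS);
            }
            return actionRuns.get();
        } catch (Exception e) {
            throw new IllegalStateException(e);
        } finally {
            pool.shutdownNow();
        }
    }

    public static void main(String[] args) {
        System.out.println("barrier action ran " + run(3, 4) + " times");
    }
}
```

The same barrier object serves every round without being re-created, which is the main practical difference from CountDownLatch.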

Practice: the same task as in the CountDownLatch practice above:

MyCalculator.java:

package simple.demo;

import java.util.concurrent.Callable;
import java.util.concurrent.CyclicBarrier;

public class MyCalculator implements Callable<Integer> {

    /** Start switch: trips when all workers and the main thread have arrived. */
    private final CyclicBarrier startSwitch;

    /** Stop switch: trips when every worker has finished computing. */
    private final CyclicBarrier stopSwitch;

    /** Number of the group to sum (group n covers (n-1)*10+1 .. n*10). */
    private final int groupNum;

    /** Constructor. */
    public MyCalculator(CyclicBarrier startSwitch, CyclicBarrier stopSwitch, int groupNum) {
        this.startSwitch = startSwitch;
        this.stopSwitch = stopSwitch;
        this.groupNum = groupNum;
    }

    @Override
    public Integer call() throws Exception {
        // Wait at the start barrier until the main thread arrives as well.
        startSwitch.await();
        int res = compute();
        System.out.println(Thread.currentThread().getName() + " is ok wait other thread...");
        // Wait at the stop barrier until all workers have finished.
        stopSwitch.await();
        System.out.println(Thread.currentThread().getName() + " is stop! the group" + groupNum + " temp result is sum=" + res);
        return res;
    }

    /**
     * Sums the numbers of this group.
     * @return the sum of (groupNum-1)*10+1 .. groupNum*10
     */
    public int compute() {
        int sum = 0;
        for (int i = (groupNum - 1) * 10 + 1; i <= groupNum * 10; i++) {
            sum += i;
        }
        return sum;
    }
}

MyTest.java:

package simple.demo;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.BrokenBarrierException;
import java.util.concurrent.CyclicBarrier;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

public class MyTest {

    private int groupNum = 10;

    /** Start and stop switches. */
    // groupNum workers plus the main thread, which arrives last in start().
    private CyclicBarrier startSwitch = new CyclicBarrier(groupNum + 1);

    private CyclicBarrier stopSwitch = new CyclicBarrier(groupNum);

    /** Thread pool. */
    private ExecutorService service = Executors.newFixedThreadPool(groupNum);

    /** Holds the computation results. */
    private List<Future<Integer>> result = new ArrayList<Future<Integer>>();

    /** Starts groupNum threads to compute the values. */
    public void init() {
        for (int i = 1; i <= groupNum; i++) {
            result.add(service.submit(new MyCalculator(startSwitch, stopSwitch, i)));
        }
        System.out.println("init is ok!");
    }

    public void printRes() throws InterruptedException, ExecutionException {
        int sum = 0;
        for (Future<Integer> f : result) {
            sum += f.get();
        }
        System.out.println("the result is " + sum);
    }

    public void start() throws InterruptedException, BrokenBarrierException {
        // The main thread is the last party of the start barrier.
        this.startSwitch.await();
    }

    public void stop() throws InterruptedException {
        // Already submitted tasks still run to completion after shutdown().
        this.service.shutdown();
    }

    public static void main(String[] args) throws IOException, InterruptedException, ExecutionException, BrokenBarrierException {
        BufferedReader reader = new BufferedReader(new InputStreamReader(System.in));

        MyTest myTest = new MyTest();
        myTest.init();
        System.out.println("please enter start command....");

        reader.readLine();

        myTest.start();
        myTest.stop();

        myTest.printRes();
    }
}

Output:

init is ok!
please enter start command....

pool-1-thread-1 is ok wait other thread...
pool-1-thread-2 is ok wait other thread...
pool-1-thread-3 is ok wait other thread...
pool-1-thread-4 is ok wait other thread...
pool-1-thread-5 is ok wait other thread...
pool-1-thread-6 is ok wait other thread...
pool-1-thread-7 is ok wait other thread...
pool-1-thread-8 is ok wait other thread...
pool-1-thread-9 is ok wait other thread...
pool-1-thread-10 is ok wait other thread...
pool-1-thread-10 is stop! the group10 temp result is sum=955
pool-1-thread-1 is stop! the group1 temp result is sum=55
pool-1-thread-2 is stop! the group2 temp result is sum=155
pool-1-thread-3 is stop! the group3 temp result is sum=255
pool-1-thread-5 is stop! the group5 temp result is sum=455
pool-1-thread-6 is stop! the group6 temp result is sum=555
pool-1-thread-4 is stop! the group4 temp result is sum=355
pool-1-thread-8 is stop! the group8 temp result is sum=755
pool-1-thread-7 is stop! the group7 temp result is sum=655
pool-1-thread-9 is stop! the group9 temp result is sum=855
the result is 5050

II. How CountDownLatch and CyclicBarrier Are Implemented

[Figure: CountDownLatch class diagram]

CountDownLatch is implemented on top of AQS (AbstractQueuedSynchronizer): it defines an inner class Sync, and Sync extends AQS. The key source code is as follows.
The await method:

 /**

 * Causes the current thread to wait until the latch has counted down to

 * zero, unless the thread is {@linkplain Thread#interrupt interrupted}.

 *

 * <p>If the current count is zero then this method returns immediately.

 *

 * <p>If the current count is greater than zero then the current

 * thread becomes disabled for thread scheduling purposes and lies

 * dormant until one of two things happen:

 * <ul>

 * <li>The count reaches zero due to invocations of the

 * {@link #countDown} method; or

 * <li>Some other thread {@linkplain Thread#interrupt interrupts}

 * the current thread.

 * </ul>

 *

 * <p>If the current thread:

 * <ul>

 * <li>has its interrupted status set on entry to this method; or

 * <li>is {@linkplain Thread#interrupt interrupted} while waiting,

 * </ul>

 * then {@link InterruptedException} is thrown and the current thread's

 * interrupted status is cleared.

 *

 * @throws InterruptedException if the current thread is interrupted

 *         while waiting

 */

public void await() throws InterruptedException {

    sync.acquireSharedInterruptibly(1);

}
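The interruption rule stated in the Javadoc above (an interrupted status that is already set on entry causes InterruptedException and is then cleared) can be checked with a short sketch. LatchInterruptDemo is an illustrative name, not part of the JDK.

```java
import java.util.concurrent.CountDownLatch;

public class LatchInterruptDemo {

    /**
     * Calls await() with the interrupted status already set.
     * Returns true if await() threw InterruptedException and left the
     * interrupted status cleared.
     */
    public static boolean awaitWhileInterrupted() {
        CountDownLatch latch = new CountDownLatch(1);
        Thread.currentThread().interrupt();   // set the status before entering await()
        try {
            latch.await();
            return false;                     // not expected: await() should have thrown
        } catch (InterruptedException expected) {
            // Thread.interrupted() reports and clears the status; it should
            // already be false here because await() cleared it when throwing.
            return !Thread.interrupted();
        }
    }

    public static void main(String[] args) {
        System.out.println("threw and cleared status: " + awaitWhileInterrupted());
    }
}
```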

The countDown method:

/**

 * Decrements the count of the latch, releasing all waiting threads if

 * the count reaches zero.

 *

 * <p>If the current count is greater than zero then it is decremented.

 * If the new count is zero then all waiting threads are re-enabled for

 * thread scheduling purposes.

 *

 * <p>If the current count equals zero then nothing happens.

 */

public void countDown() {

     sync.releaseShared(1);

}

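The above are the definitions of the two key methods of CountDownLatch, await and countDown; their behaviour is explained by the Javadoc comments. In essence, CountDownLatch simply makes use of the state field of AQS. AQS itself gives state no fixed meaning: in ReentrantLock it records the reentrant hold count, whereas in CountDownLatch it holds the remaining count. The inner class Sync of CountDownLatch overrides the AQS method tryAcquireShared, which is defined as:

protected int tryAcquireShared(int acquires) {
    return getState() == 0 ? 1 : -1;
}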
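To make the role of state concrete, here is a minimal sketch of a latch built directly on AbstractQueuedSynchronizer. It follows the same pattern as the excerpts in this article, but SimpleLatch is a simplified illustration (no timed await, no argument checking) and not the JDK's CountDownLatch class.

```java
import java.util.concurrent.locks.AbstractQueuedSynchronizer;

public class SimpleLatch {

    /** Synchronizer whose state field holds the remaining count. */
    private static final class Sync extends AbstractQueuedSynchronizer {
        Sync(int count) {
            setState(count);
        }

        int remaining() {
            return getState();
        }

        @Override
        protected int tryAcquireShared(int ignored) {
            // Non-negative: acquisition succeeds, the caller does not block.
            // Negative: the caller is queued and parked by AQS.
            return getState() == 0 ? 1 : -1;
        }

        @Override
        protected boolean tryReleaseShared(int ignored) {
            for (;;) {
                int c = getState();
                if (c == 0) {
                    return false;           // already open, nothing to do
                }
                int next = c - 1;
                if (compareAndSetState(c, next)) {
                    return next == 0;       // true wakes the queued waiters
                }
            }
        }
    }

    private final Sync sync;

    public SimpleLatch(int count) {
        this.sync = new Sync(count);
    }

    public void await() throws InterruptedException {
        sync.acquireSharedInterruptibly(1);
    }

    public void countDown() {
        sync.releaseShared(1);
    }

    public int getCount() {
        return sync.remaining();
    }

    /** Demo: a waiting thread is released once the count reaches zero. */
    public static boolean demo() {
        SimpleLatch latch = new SimpleLatch(2);
        Thread waiter = new Thread(() -> {
            try {
                latch.await();
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            }
        });
        waiter.start();
        latch.countDown();
        latch.countDown();
        try {
            waiter.join(5000);
        } catch (InterruptedException e) {
            throw new IllegalStateException(e);
        }
        return !waiter.isAlive() && latch.getCount() == 0;
    }

    public static void main(String[] args) {
        System.out.println("waiter released: " + demo());
    }
}
```

All queueing, parking and waking is inherited from AQS; the subclass only decides, from the value of state, whether an acquire may pass.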

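The initial value of state is the count passed to the CountDownLatch constructor. When Sync calls the AQS method acquireSharedInterruptibly, the return value of tryAcquireShared(int acquires) is checked: if it is not negative (the count is 0), the call returns at once; if it is negative (the count is still greater than 0), the thread is queued and suspended. The AQS method that suspends the thread is: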

/**
 * Acquires in shared interruptible mode.
 * @param arg the acquire argument
 */
private void doAcquireSharedInterruptibly(int arg)
    throws InterruptedException {
    final Node node = addWaiter(Node.SHARED);
    try {
        for (;;) {
            final Node p = node.predecessor();
            if (p == head) {
                int r = tryAcquireShared(arg);
                if (r >= 0) {
                    setHeadAndPropagate(node, r);
                    p.next = null; // help GC
                    return;
                }
            }
            if (shouldParkAfterFailedAcquire(p, node) &&
                parkAndCheckInterrupt())
                break;
        }
    } catch (RuntimeException ex) {
        cancelAcquire(node);
        throw ex;
    }
    // Arrive here only if interrupted
    cancelAcquire(node);
    throw new InterruptedException();
}

When countDown is called on a CountDownLatch, it invokes tryReleaseShared, the AQS method overridden by the inner class Sync. It is defined as follows:

protected boolean tryReleaseShared(int releases) {
    // Decrement count; signal when transition to zero
    for (;;) {
        int c = getState();
        if (c == 0)
            return false;
        int nextc = c-1;
        if (compareAndSetState(c, nextc))
            return nextc == 0;
    }
}

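As shown, each call decrements state by 1 until it reaches 0. Only the call that moves state from 1 to 0 returns true, and on that result releaseShared wakes up the waiting threads. That is all for CountDownLatch.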
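The retry loop around compareAndSetState guarantees that, however many threads count down concurrently, exactly one call observes the transition to zero. The sketch below reproduces the loop with an AtomicInteger standing in for the AQS state; CasCountdown and zeroTransitions are hypothetical names used only for this illustration.

```java
import java.util.concurrent.atomic.AtomicInteger;

public class CasCountdown {

    private final AtomicInteger state;

    public CasCountdown(int count) {
        this.state = new AtomicInteger(count);
    }

    /** Returns true only for the single call that moves the counter from 1 to 0. */
    public boolean countDown() {
        for (;;) {
            int c = state.get();
            if (c == 0) {
                return false;                   // already at zero
            }
            int next = c - 1;
            if (state.compareAndSet(c, next)) { // retry if another thread won the race
                return next == 0;
            }
        }
    }

    /** Counts how many countDown() calls reported the transition to zero. */
    public static int zeroTransitions(int count, int threads, int callsPerThread) {
        CasCountdown countdown = new CasCountdown(count);
        AtomicInteger hits = new AtomicInteger();
        Thread[] workers = new Thread[threads];
        for (int i = 0; i < threads; i++) {
            workers[i] = new Thread(() -> {
                for (int k = 0; k < callsPerThread; k++) {
                    if (countdown.countDown()) {
                        hits.incrementAndGet();
                    }
                }
            });
            workers[i].start();
        }
        try {
            for (Thread w : workers) {
                w.join();
            }
        } catch (InterruptedException e) {
            throw new IllegalStateException(e);
        }
        return hits.get();
    }

    public static void main(String[] args) {
        // 8 threads x 50 calls = 400 calls against a count of 100
        System.out.println("zero transitions: " + zeroTransitions(100, 8, 50));
    }
}
```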

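[Figure: CyclicBarrier class diagram]

CyclicBarrier is implemented on top of ReentrantLock, and ReentrantLock is in turn based on AQS, so ultimately CyclicBarrier is AQS-based as well. Internally, CyclicBarrier uses a Condition of the ReentrantLock to wake up the threads waiting at the barrier. The key source code is as follows.
The await method: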

/**

 * Waits until all {@linkplain #getParties parties} have invoked

 * <tt>await</tt> on this barrier.

 *

 * <p>If the current thread is not the last to arrive then it is

 * disabled for thread scheduling purposes and lies dormant until

 * one of the following things happens:

 * <ul>

 * <li>The last thread arrives; or

 * <li>Some other thread {@linkplain Thread#interrupt interrupts}

 * the current thread; or

 * <li>Some other thread {@linkplain Thread#interrupt interrupts}

 * one of the other waiting threads; or

 * <li>Some other thread times out while waiting for barrier; or

 * <li>Some other thread invokes {@link #reset} on this barrier.

 * </ul>

 *

 * <p>If the current thread:

 * <ul>

 * <li>has its interrupted status set on entry to this method; or

 * <li>is {@linkplain Thread#interrupt interrupted} while waiting

 * </ul>

 * then {@link InterruptedException} is thrown and the current thread's

 * interrupted status is cleared.

 *

 * <p>If the barrier is {@link #reset} while any thread is waiting,

 * or if the barrier {@linkplain #isBroken is broken} when

 * <tt>await</tt> is invoked, or while any thread is waiting, then

 * {@link BrokenBarrierException} is thrown.

 *

 * <p>If any thread is {@linkplain Thread#interrupt interrupted} while waiting,

 * then all other waiting threads will throw

 * {@link BrokenBarrierException} and the barrier is placed in the broken

 * state.

 *

 * <p>If the current thread is the last thread to arrive, and a

 * non-null barrier action was supplied in the constructor, then the

 * current thread runs the action before allowing the other threads to

 * continue.

 * If an exception occurs during the barrier action then that exception

 * will be propagated in the current thread and the barrier is placed in

 * the broken state.

 *

 * @return the arrival index of the current thread, where index

 * <tt>{@link #getParties()} - 1</tt> indicates the first

 * to arrive and zero indicates the last to arrive

 * @throws InterruptedException if the current thread was interrupted

 * while waiting

 * @throws BrokenBarrierException if <em>another</em> thread was

 * interrupted or timed out while the current thread was

 * waiting, or the barrier was reset, or the barrier was

 * broken when {@code await} was called, or the barrier

 * action (if present) failed due an exception.

 */

public int await() throws InterruptedException, BrokenBarrierException {

    try {

      return dowait(false, 0L);

    } catch (TimeoutException toe) {

      throw new Error(toe); // cannot happen;

    }

}
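The broken state described in the Javadoc above can be reproduced in a single thread by letting a timed await expire. This is a small sketch; BrokenBarrierDemo is an illustrative name, while await(long, TimeUnit), isBroken() and reset() are standard CyclicBarrier methods.

```java
import java.util.concurrent.BrokenBarrierException;
import java.util.concurrent.CyclicBarrier;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;

public class BrokenBarrierDemo {

    /** Returns a short trace of what happened to the barrier. */
    public static String run() {
        CyclicBarrier barrier = new CyclicBarrier(2);
        StringBuilder trace = new StringBuilder();

        try {
            // Only one of the two parties ever arrives, so this times out.
            barrier.await(50, TimeUnit.MILLISECONDS);
        } catch (TimeoutException e) {
            trace.append("timeout,");
        } catch (InterruptedException | BrokenBarrierException e) {
            trace.append("unexpected,");
        }

        // The timeout left the barrier in the broken state.
        trace.append("broken=").append(barrier.isBroken()).append(',');

        try {
            barrier.await();   // fails immediately instead of blocking
        } catch (BrokenBarrierException e) {
            trace.append("rejected,");
        } catch (InterruptedException e) {
            trace.append("unexpected,");
        }

        barrier.reset();       // starts a fresh, unbroken generation
        trace.append("afterReset=").append(barrier.isBroken());
        return trace.toString();
    }

    public static void main(String[] args) {
        System.out.println(run());
    }
}
```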

The private dowait method:

/**
 * Main barrier code, covering the various policies.
 */
private int dowait(boolean timed, long nanos)
    throws InterruptedException, BrokenBarrierException,
           TimeoutException {
    final ReentrantLock lock = this.lock;
    lock.lock();
    try {
        final Generation g = generation;

        if (g.broken)
            throw new BrokenBarrierException();

        if (Thread.interrupted()) {
            breakBarrier();
            throw new InterruptedException();
        }

        int index = --count;
        if (index == 0) {  // tripped
            boolean ranAction = false;
            try {
                final Runnable command = barrierCommand;
                if (command != null)
                    command.run();
                ranAction = true;
                nextGeneration();
                return 0;
            } finally {
                if (!ranAction)
                    breakBarrier();
            }
        }

        // loop until tripped, broken, interrupted, or timed out
        for (;;) {
            try {
                if (!timed)
                    trip.await();
                else if (nanos > 0L)
                    nanos = trip.awaitNanos(nanos);
            } catch (InterruptedException ie) {
                if (g == generation && ! g.broken) {
                    breakBarrier();
                    throw ie;
                } else {
                    // We're about to finish waiting even if we had not
                    // been interrupted, so this interrupt is deemed to
                    // "belong" to subsequent execution.
                    Thread.currentThread().interrupt();
                }
            }

            if (g.broken)
                throw new BrokenBarrierException();

            if (g != generation)
                return index;

            if (timed && nanos <= 0L) {
                breakBarrier();
                throw new TimeoutException();
            }
        }
    } finally {
        lock.unlock();
    }
}

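As the dowait method shows, each call decrements count by 1 (the result is stored in index). When it reaches 0, the last thread to arrive runs the barrier command, if any, and then calls nextGeneration(), which wakes all waiting threads with trip.signalAll(), resets count to parties and starts a new generation. If instead a waiting thread is interrupted or times out, or the barrier command throws an exception, breakBarrier() is called. The breakBarrier method is implemented as: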

 /**

 * Sets current barrier generation as broken and wakes up everyone.

 * Called only while holding lock.

 */

private void breakBarrier() {

   generation.broken = true;

   count = parties;

   trip.signalAll();

}

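breakBarrier calls trip.signalAll() to wake up all waiting threads (trip is defined as Condition trip = lock.newCondition()), and nextGeneration() does the same on the normal path. CyclicBarrier is therefore a straightforward application of the exclusive lock ReentrantLock.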
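The lock-plus-condition pattern can be condensed into a much smaller barrier. The following is a simplified sketch under stated assumptions: SimpleBarrier is a hypothetical class, it uses a generation counter instead of the JDK's Generation object, and it has no broken state, no timeout and no barrier action, so an interrupted waiter would leave it unusable.

```java
import java.util.concurrent.locks.Condition;
import java.util.concurrent.locks.ReentrantLock;

public class SimpleBarrier {

    private final ReentrantLock lock = new ReentrantLock();
    private final Condition trip = lock.newCondition();
    private final int parties;
    private int count;        // parties still to arrive in the current generation
    private long generation;  // incremented every time the barrier trips

    public SimpleBarrier(int parties) {
        if (parties <= 0) {
            throw new IllegalArgumentException("parties must be positive");
        }
        this.parties = parties;
        this.count = parties;
    }

    /** Blocks until `parties` threads have arrived; returns the arrival index. */
    public int await() throws InterruptedException {
        lock.lock();
        try {
            long arrivedIn = generation;
            int index = --count;
            if (index == 0) {           // last arrival: trip the barrier
                generation++;
                count = parties;        // reset for the next round
                trip.signalAll();
                return 0;
            }
            // Loop to guard against spurious wakeups.
            while (arrivedIn == generation) {
                trip.await();
            }
            return index;
        } finally {
            lock.unlock();
        }
    }

    public long generations() {
        lock.lock();
        try {
            return generation;
        } finally {
            lock.unlock();
        }
    }

    /** Demo: `parties` threads pass the barrier `rounds` times. */
    public static long demo(int parties, int rounds) {
        SimpleBarrier barrier = new SimpleBarrier(parties);
        Thread[] workers = new Thread[parties];
        for (int i = 0; i < parties; i++) {
            workers[i] = new Thread(() -> {
                try {
                    for (int r = 0; r < rounds; r++) {
                        barrier.await();
                    }
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                }
            });
            workers[i].start();
        }
        try {
            for (Thread w : workers) {
                w.join(5000);
            }
        } catch (InterruptedException e) {
            throw new IllegalStateException(e);
        }
        return barrier.generations();
    }

    public static void main(String[] args) {
        System.out.println("generations completed: " + demo(3, 4));
    }
}
```

Comparing a captured generation value with the current one is what lets a woken thread tell "my round has finished" apart from a spurious wakeup, and it is also what makes the barrier reusable.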
