CountDownLatch实现原理

前言

Github:https://github.com/yihonglei/thinking-in-concurrent

一 CountDownLatch

1、介绍

CountDownLatch(同步工具类)允许一个或多个线程等待其他线程完成操作。

使用CountDownLatch时,需要指定一个整数值N,此值是线程将要等待的操作数。

当线程M为了要执行操作A而等待时,线程M需要调用await()方法。

await()方法让线程M进入休眠状态直到所有等待的操作A完成为止。

当等待的某个操作A执行完成(每一个处理),它使用countDown()方法来减少CountDownLatch类的内部计数器,

N每次减少1。当内部计数器递减为0时,CountDownLatch会唤醒所有调用await方法休眠的线程,即会唤醒M。

从而实现M执行前执行完操作A。

2、原理

CountDownLatch的构造函数接收一个int类型的参数作为计数器构造参数,如果你想等待N个点完成,这里就传入N。

当我们调用CountDownLatch的countDown()方法时,N就会减1,CountDownLatch的await()方法会阻塞当前线程,

直到N变成零。由于countDown()方法可以用在任何地方,所以这里说的N个点,可以是N个线程,也可以是1个线程里的

N个执行步骤。用在多个线程时,只需要把这个CountDownLatch的引用传递到线程里即可。

3、核心方法

countDown():用于减少计数器次数,每调用一次就会减少1,当锁释放完时,进行线程唤醒。

await():负责线程的阻塞,当CountDownLatch计数的值为0时,获取到锁,才返回主线程执行。

4、典型场景

CountDownLatch使用场景主要用于控制主线程等待所有子线程全部执行完成然后恢复主线程执行。

二 CountDownLatch实例

1、实例场景

我们需要批量的从数据库查询出数据进行处理。一般会想到用多线程去处理,

但是,有一个问题就是我们如何保证每一次查询的数据不是正在处理的数据?

   方法有很多种,可以在每一批数据处理完之后再去数据库取下一批数据,每一批数据采取多线程处理的方式。

我们也可以采用别的方案,这里只针对使用CountDownLatch来实现批量处理。

CountDownLatch控制主线程必须等待线程池子线程执行完才恢复执行主线程。

2、实例代码

package com.jpeony.concurrent.countdownlatch;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

/**
 * 多线程+CountDownLatch演示
 *
 * @author yihonglei
 */
public class CountDownLatchTest {
    // 线程池
    private static ExecutorService executorService = Executors.newFixedThreadPool(10);

    public static void main(String[] args) {
        int counterBatch = 1;
        try {
            // 数据循环处理
            while (true) {
                // 模拟数据库查询出的List
                List list = new ArrayList<>();
                for (int i = 0; i < 10; i++) {
                    list.add("user" + i);
                }
                // 计数器大小定义为集合大小,避免处理不一致导致主线程无限等待
                CountDownLatch countDownLatch = new CountDownLatch(list.size());
                // 循环处理List
                list.parallelStream().forEach(userId -> {
                    // 任务提交线程池
                    CompletableFuture.supplyAsync(() -> {
                        try {
                            // 处理用户数据
                            dealUser(userId);
                        } finally {
                            countDownLatch.countDown();
                        }
                        return 1;
                    }, executorService);
                });
                // 主线程等待所有子线程都执行完成时,恢复执行主线程
                countDownLatch.await();
                System.out.println("========================恢复主线程执行==========================");
                // 数据批次计数器
                counterBatch++;
                // 模拟执行5批
                if (counterBatch > 5) {
                    break;
                }
            }
            System.out.println("循环退出,程序执行完成,counterBatch=" + counterBatch);
            // 关闭线程池
            executorService.shutdown();
        } catch (Exception e) {
            System.out.println("异常日志");
        }
    }

    /**
     * 模拟根据用户Id处理用户数据的逻辑
     */
    public static void dealUser(String userId) {
        System.out.println("ThreadName:" + Thread.currentThread().getName() + ", userId:" + userId + " 处理完成!");
    }

}

运行结果:

ThreadName:pool-1-thread-3, userId:user4 处理完成!
ThreadName:pool-1-thread-7, userId:user9 处理完成!
ThreadName:pool-1-thread-2, userId:user7 处理完成!
ThreadName:pool-1-thread-9, userId:user3 处理完成!
ThreadName:pool-1-thread-5, userId:user2 处理完成!
ThreadName:pool-1-thread-4, userId:user8 处理完成!
ThreadName:pool-1-thread-1, userId:user0 处理完成!
ThreadName:pool-1-thread-8, userId:user6 处理完成!
ThreadName:pool-1-thread-10, userId:user5 处理完成!
ThreadName:pool-1-thread-6, userId:user1 处理完成!
========================恢复主线程执行==========================
ThreadName:pool-1-thread-3, userId:user8 处理完成!
ThreadName:pool-1-thread-7, userId:user1 处理完成!
ThreadName:pool-1-thread-9, userId:user9 处理完成!
ThreadName:pool-1-thread-2, userId:user4 处理完成!
ThreadName:pool-1-thread-7, userId:user6 处理完成!
ThreadName:pool-1-thread-9, userId:user0 处理完成!
ThreadName:pool-1-thread-4, userId:user3 处理完成!
ThreadName:pool-1-thread-5, userId:user2 处理完成!
ThreadName:pool-1-thread-3, userId:user7 处理完成!
ThreadName:pool-1-thread-1, userId:user5 处理完成!
========================恢复主线程执行==========================
ThreadName:pool-1-thread-2, userId:user5 处理完成!
ThreadName:pool-1-thread-7, userId:user8 处理完成!
ThreadName:pool-1-thread-10, userId:user0 处理完成!
ThreadName:pool-1-thread-4, userId:user1 处理完成!
ThreadName:pool-1-thread-8, userId:user2 处理完成!
ThreadName:pool-1-thread-5, userId:user7 处理完成!
ThreadName:pool-1-thread-9, userId:user6 处理完成!
ThreadName:pool-1-thread-6, userId:user4 处理完成!
ThreadName:pool-1-thread-2, userId:user9 处理完成!
ThreadName:pool-1-thread-7, userId:user3 处理完成!
========================恢复主线程执行==========================
ThreadName:pool-1-thread-3, userId:user1 处理完成!
ThreadName:pool-1-thread-1, userId:user8 处理完成!
ThreadName:pool-1-thread-8, userId:user2 处理完成!
ThreadName:pool-1-thread-9, userId:user6 处理完成!
ThreadName:pool-1-thread-2, userId:user3 处理完成!
ThreadName:pool-1-thread-1, userId:user4 处理完成!
ThreadName:pool-1-thread-5, userId:user0 处理完成!
ThreadName:pool-1-thread-10, userId:user5 处理完成!
ThreadName:pool-1-thread-3, userId:user9 处理完成!
ThreadName:pool-1-thread-4, userId:user7 处理完成!
========================恢复主线程执行==========================
ThreadName:pool-1-thread-6, userId:user0 处理完成!
ThreadName:pool-1-thread-7, userId:user8 处理完成!
ThreadName:pool-1-thread-2, userId:user3 处理完成!
ThreadName:pool-1-thread-5, userId:user5 处理完成!
ThreadName:pool-1-thread-8, userId:user1 处理完成!
ThreadName:pool-1-thread-10, userId:user6 处理完成!
ThreadName:pool-1-thread-1, userId:user7 处理完成!
ThreadName:pool-1-thread-7, userId:user2 处理完成!
ThreadName:pool-1-thread-6, userId:user4 处理完成!
ThreadName:pool-1-thread-9, userId:user9 处理完成!
========================恢复主线程执行==========================
循环退出,程序执行完成,counterBatch=6

程序分析:

1)模拟从数据库每一次取出一批数据,每批数据为10条;

2)CountDownLatch计数器大小设定与数据条数相同,这里就为10;

3)然后循环List,每一条数据创建一个线程,然后提交线程池,每一个线程处理完要调countDown(),每次减1。

4)主线程也就是这里的main线程,调用了await()方法,await()方法表示等待线程池的线程执行完成,恢复主线程执行,

即CountDownLatch计数器为0时恢复主线程,进行下一次的循环取批数据处理。

从而我们可以实现每一批数据取出后,交由线程池多线程处理,并且主线程会等待子线程都执行完成,

然后才恢复执行,进行下一次的循环取批处理,就不会出现取批次时取到正在处理的数据。

三 CountDownLatch源码分析(jdk8)

1、CountDownLatch(int count)构造函数

public CountDownLatch(int count) {
        if (count < 0) throw new IllegalArgumentException("count < 0");
        this.sync = new Sync(count);
    }

先从构造函数看起,传入int的count,对count进行校验,然后new Sync(count)。

Sync(int count) {
            setState(count);
        }

Sync为AQS的子类,在构造函数里面,通过setState设置state的值为count,state为volatile变量,保证多线程可见性。

2、await()方法

public void await() throws InterruptedException {
        sync.acquireSharedInterruptibly(1);
    }

调用CountDownLatch内部类Sync父类AbstractQueuedSynchronizer的模板方法acquireSharedInterruptibly()尝试获取共享锁。

public final void acquireSharedInterruptibly(int arg)
            throws InterruptedException {
        // 1
        if (Thread.interrupted())
            throw new InterruptedException();
        // 2
        if (tryAcquireShared(arg) < 0)
            // 3
            doAcquireSharedInterruptibly(arg);
    }

第1步:判断线程是否中断,中断则抛出线程中断异常;

第2步:tryAcquireShared(arg)方法尝试获取共享锁,当state为0时,返回1才能获取锁,主线程会继续执行

否则返回-1,获取不到锁,则调用await的线程(主线程)通过doAcquireSharedInterruptibly(arg)方法进行阻塞操作

这里可以结合实例理解为main主线程被阻塞,那么主线程在哪里被唤醒的?在countDown()方法里进行主线程唤醒。

protected int tryAcquireShared(int acquires) {
            return (getState() == 0) ? 1 : -1;
        }

第3步:doAcquireSharedInterruptibly(arg)如何阻塞主线程?

/**
     * Acquires in shared interruptible mode.
     * @param arg the acquire argument
     */
    private void doAcquireSharedInterruptibly(int arg)
        throws InterruptedException {
        // 基于共享模式创建节点
        final Node node = addWaiter(Node.SHARED);
        boolean failed = true;
        try {
            for (;;) {
                // 获取当前节点的前驱节点
                final Node p = node.predecessor();
                if (p == head) {
                    // 尝试获取共享锁
                    int r = tryAcquireShared(arg);
                    if (r >= 0) {
                        setHeadAndPropagate(node, r);
                        p.next = null; // help GC
                        failed = false;
                        return;
                    }
                }
                // 获取锁失败后暂停线程
                if (shouldParkAfterFailedAcquire(p, node) &&
                    parkAndCheckInterrupt())
                    throw new InterruptedException();
            }
        } finally {
            // 如果获取锁失败,线程已经被暂停了,取消尝试获取锁的操作
            if (failed)
                cancelAcquire(node);
        }
    }

addWaiter(Node mode):初始化队里,并基于当前线程构建节点添加到队列尾部。

private Node addWaiter(Node mode) {
        // 基于当前线程构建Node
        Node node = new Node(Thread.currentThread(), mode);
        // Try the fast path of enq; backup to full enq on failure
        // 先尝试通过compareAndSetTail快速添加队列节点,不行再通过enq入队。
        Node pred = tail;
        // 添加第一个队列节点时,尾节点是空的,不会走快速添加,之后才会走CAS快速添加
        if (pred != null) {
            node.prev = pred;
            if (compareAndSetTail(pred, node)) {
                pred.next = node;
                return node;
            }
        }
        // 第一次添加节点,走这里,老哥
        enq(node);
        return node;
    }

enq(final Node node):初始化队列并添加当前线程构建的节点到队尾。

private Node enq(final Node node) {
        for (;;) {
            // 获取尾节点
            Node t = tail;
            // 第一次循环,t是null,会进入if判断,compareAndSetHead设置new Node()到队列,
            // 这个时候队列只有一个节点,就是头结点,也是尾节点
            if (t == null) { // Must initialize
                if (compareAndSetHead(new Node()))
                    tail = head;
            } else {// 节点插入队尾
                // 第二次循环时,当前节点的前驱节点
                node.prev = t;
                // 节点添加到队尾
                if (compareAndSetTail(t, node)) {
                    // t的下一个节点指向node,跟头结点建立引用,形成链表
                    t.next = node;
                    // 返回t(这个时候队列的头结点是new Node(),尾节点是我们传进来的node,队列里只有两个节点)
                    return t;
                }
            }
        }
    }

shouldParkAfterFailedAcquire(Node pred, Node node):设置节点的状态为等待唤醒状态。

private static boolean shouldParkAfterFailedAcquire(Node pred, Node node) {
        int ws = pred.waitStatus;
        if (ws == Node.SIGNAL)
            /*
             * This node has already set status asking a release
             * to signal it, so it can safely park.
             */
            return true;
        if (ws > 0) {
            /*
             * Predecessor was cancelled. Skip over predecessors and
             * indicate retry.
             */
            do {
                node.prev = pred = pred.prev;
            } while (pred.waitStatus > 0);
            pred.next = node;
        } else {
            /*
             * waitStatus must be 0 or PROPAGATE.  Indicate that we
             * need a signal, but don't park yet.  Caller will need to
             * retry to make sure it cannot acquire before parking.
             */
            compareAndSetWaitStatus(pred, ws, Node.SIGNAL);
        }
        return false;
    }

boolean parkAndCheckInterrupt():调用LockSupport.park暂停当前线程,并返回线程是否中断的状态。

private final boolean parkAndCheckInterrupt() {
        // 暂停当前的线程
        LockSupport.park(this);
        // 获取线程是否中断的状态
        return Thread.interrupted();
    }

调用await()方法的现在在这里被暂停的,后期通过countDown()里面的逻辑进行唤醒。 

3、countDown()方法

调用countDown()方法,每调用一次state就会减1。

public void countDown() {
        sync.releaseShared(1);
    }

调用CountDownLatch内部类Sync的releaseShared()方法,arg传值为1。

public final boolean releaseShared(int arg) {
        // 1
        if (tryReleaseShared(arg)) {
            // 2
            doReleaseShared();
            return true;
        }
        return false;
    }

第1步:执行tryReleaseShared(arg)方法,返回true或false,尝试去释放共享锁。

protected boolean tryReleaseShared(int releases) {
            // Decrement count; signal when transition to zero
            for (;;) {// 自旋减1
                int c = getState();
                if (c == 0)
                    return false;
                int nextc = c-1;
                if (compareAndSetState(c, nextc))
                    return nextc == 0;// nextc减到为0时,返回true
            }
        }

即当最后一次进行countDown()操作时state为1,即c为1,则nextc为0,进行CAS操作后,state变为0,返回true,

则执行doReleaseShared()方法。

第2步:执行doReleaseShared():方法释放共享锁,唤醒调用await()等待线程

private void doReleaseShared() {
        /*
         * Ensure that a release propagates, even if there are other
         * in-progress acquires/releases.  This proceeds in the usual
         * way of trying to unparkSuccessor of head if it needs
         * signal. But if it does not, status is set to PROPAGATE to
         * ensure that upon release, propagation continues.
         * Additionally, we must loop in case a new node is added
         * while we are doing this. Also, unlike other uses of
         * unparkSuccessor, we need to know if CAS to reset status
         * fails, if so rechecking.
         */
        for (;;) {
            // 获取头结点
            Node h = head;
            // 判断头结点不为空,并且不是尾节点,则进入if逻辑
            if (h != null && h != tail) {
                int ws = h.waitStatus;
                if (ws == Node.SIGNAL) {// 头结点的状态为Node.SIGNAL
                    if (!compareAndSetWaitStatus(h, Node.SIGNAL, 0))
                        continue;            // loop to recheck cases
                    unparkSuccessor(h);// 唤醒头节点的后续节点线程
                }
                else if (ws == 0 &&
                         !compareAndSetWaitStatus(h, 0, Node.PROPAGATE))
                    continue;                // loop on failed CAS
            }
            // 队列里面只有头结点时,退出锁的循环释放
            if (h == head)                   // loop if head changed
                break;
        }
    }

unparkSuccessor(Node node):唤醒后续节点线程。

/**
     * Wakes up node's successor, if one exists.
     *
     * @param node the node
     */
    private void unparkSuccessor(Node node) {
        /*
         * If status is negative (i.e., possibly needing signal) try
         * to clear in anticipation of signalling.  It is OK if this
         * fails or if status is changed by waiting thread.
         */
        int ws = node.waitStatus;
        if (ws < 0)
            compareAndSetWaitStatus(node, ws, 0);

        /*
         * Thread to unpark is held in successor, which is normally
         * just the next node.  But if cancelled or apparently null,
         * traverse backwards from tail to find the actual
         * non-cancelled successor.
         */
        // node是外层传入的头节点,s为头节点的后继节点
        Node s = node.next;
        if (s == null || s.waitStatus > 0) {
            s = null;
            for (Node t = tail; t != null && t != node; t = t.prev)
                if (t.waitStatus <= 0)
                    s = t;
        }
        if (s != null)
            // 唤醒线程
            LockSupport.unpark(s.thread);
    }

主线程一开始被构建在Node节点中作为成员变量,被LockSupport.park暂停了,这里当state为0时获取锁到锁,

通过LockSupport.unpark唤醒主线程,当线程唤醒后,调用await()的线程会继续执行,去获取到锁,继续执行代码。

四 CountDownLatch总结

1、CountDownLatch常用于线程控制,批量处理。

2、A操作调用countDown()减少计数器数值,M调用await()一直等待,直到countDown()将state减为0时恢复主线程执行。

你可能感兴趣的:(#,---多线程/高并发,Thinking,In,Concurrent)