AQS-semaphore&CyclicBarrier&CountDownLatch源码学习

上文:jdk-BlockingQueue源码学习

源码下载:https://gitee.com/hong99/jdk8


semaphore&cyclicbarrier&CountDownLatch的介绍

semaphore基础功能

semaphore简称信号量,主要用于控制访问特定资源的线程数目,底层用的是AQS的状记state。

package com.aqs;

import java.util.concurrent.Semaphore;

/**
 * @author: csh
 * @Date: 2022/12/13 21:11
 * @Description:信号线学习
 */
public class SemaphoreStudy {
    public static void main(String[] args) {
        //创建10个线程
        Semaphore semaphore = new Semaphore(2);
        for (int i = 0; i < 10; i++) {
            new Thread(new Task("线程"+i,semaphore)).start();
        }
    }

    static class Task extends Thread{
        Semaphore semaphore;

        public Task( String name, Semaphore semaphore) {
            this.setName(name);
            this.semaphore = semaphore;
        }

        @Override
        public void run() {
            try {
                semaphore.acquire();
                System.out.println(Thread.currentThread().getName()+"获取到线程");
                Thread.sleep(1000);
                semaphore.release();
                System.out.println(Thread.currentThread().getName()+"释放了线程");
            }catch (Exception e){
                e.printStackTrace();
            }
        }
    }
}

结果

Thread-3获取到线程
Thread-1获取到线程
Thread-3释放了线程
Thread-7获取到线程
Thread-5获取到线程
Thread-1释放了线程
Thread-7释放了线程
Thread-9获取到线程
Thread-11获取到线程
Thread-5释放了线程
Thread-9释放了线程
Thread-15获取到线程
Thread-11释放了线程
Thread-13获取到线程
Thread-15释放了线程
Thread-17获取到线程
Thread-13释放了线程
Thread-19获取到线程
Thread-17释放了线程
Thread-19释放了线程

Process finished with exit code 0

从以上可以看到通过控制这个信号量可以从而控制线程的访问,很多限流场景其实也是类似的实现。

cyclicbarrier基础功能了解

cyclicbarrier栅栏屏障,主要是指定线程数让线待到线程数执行完再继续往下走,像很多限流或者条件数达到才让走,也是类似的逻辑。

简单举例:比如渡河,拼满10个才过河,所以在这个过程中,没满10人就不走,满10人了,往下走,这中间就是等待。

package com.aqs;

import java.util.concurrent.BrokenBarrierException;
import java.util.concurrent.CyclicBarrier;

/**
 * @author: csh
 * @Date: 2022/12/16 21:02
 * @Description:栅栏学习
 */
public class CyclicBarrierStudy {

    static class CyclicBarrierFactory implements Runnable{
        private CyclicBarrier cyclicBarrier;
        private int index;

        public CyclicBarrierFactory(CyclicBarrier cyclicBarrier, int index) {
            this.cyclicBarrier = cyclicBarrier;
            this.index = index;
        }

        @Override
        public void run() {
            try {
                System.out.println(Thread.currentThread().getName()+"当前线坐标:"+index);
                index--;
                cyclicBarrier.await();
            }catch (Exception e){
                e.printStackTrace();
            }
        }
    }

    public static void main(String[] args) throws BrokenBarrierException, InterruptedException {
        CyclicBarrier cyclicBarrier = new CyclicBarrier(10, new Runnable() {
            @Override
            public void run() {
                System.out.println("准备完毕,准备执行任务!");
            }
        });
        for (int i = 0; i < 10; i++) {
            new Thread(new CyclicBarrierFactory(cyclicBarrier,i)).start();
        }
        //等待
        cyclicBarrier.await();
        System.out.println("全部执行完成!");

    }
}

结果

Thread-7当前线坐标:7
Thread-3当前线坐标:3
Thread-2当前线坐标:2
Thread-0当前线坐标:0
Thread-5当前线坐标:5
Thread-9当前线坐标:9
Thread-8当前线坐标:8
Thread-1当前线坐标:1
Thread-6当前线坐标:6
Thread-4当前线坐标:4
准备完毕,准备执行任务!
全部执行完成!

CountDownLatch基础功能了解

CountDownLatch跟CyclicBarrier很像,但是区别是CyclicBarrier主要是针对线程和并发的控制并且可以重置(重复使用),而CountDownLatch不能重置(只能用一次) ,主要以计数为主。

package com.aqs;

import java.util.concurrent.CountDownLatch;

/**
 * @author: csh
 * @Date: 2022/12/16 23:33
 * @Description:线程计数器学习
 */
public class CountDownLatchStudy {


    static class ThreadFactory implements Runnable {
        private CountDownLatch countDownLatch;

        public ThreadFactory(CountDownLatch countDownLatch) {
            this.countDownLatch = countDownLatch;
        }

        @Override
        public void run() {
            System.out.println("当前线程统计数量剩余" + (countDownLatch.getCount() - 1) + "执行了!");
            try {
                Thread.sleep(100);
            } catch (InterruptedException e) {
                e.printStackTrace();
            }
            countDownLatch.countDown();
        }
    }

    public static void main(String[] args) throws InterruptedException {
        CountDownLatch countDownLatch = new CountDownLatch(10);
        for (int i = 0; i < 10; i++) {
            new Thread(new ThreadFactory(countDownLatch)).start();
        }
        while (countDownLatch.getCount() > 1) {
            System.out.println("线程等待中,当前还有线程:" + countDownLatch.getCount());
        }
        countDownLatch.await();
        System.out.println("全部执行完毕!");
    }
}

结果

当前线程统计数量剩余9执行了!
当前线程统计数量剩余9执行了!
线程等待中,当前还有线程:10
当前线程统计数量剩余9执行了!
当前线程统计数量剩余9执行了!
当前线程统计数量剩余9执行了!
当前线程统计数量剩余9执行了!
当前线程统计数量剩余9执行了!
当前线程统计数量剩余9执行了!
当前线程统计数量剩余9执行了!
当前线程统计数量剩余9执行了!
线程等待中,当前还有线程:3
全部执行完毕!

源码学习

java.util.concurrent.Semaphore  源码学习

AQS-semaphore&CyclicBarrier&CountDownLatch源码学习_第1张图片

semaphore是通过AQS进行实现锁的功能,可以指定是公平锁或非公平锁。当然与重入锁实现有点像(可以参考前文),下面看看一些公开方法。

方法名称

描述

备注

Semaphore(int)

构造方法


Semaphore(int,boolean)

构建方法

true为公平锁,false为非公平

acquire()

获取锁


acquireUninterruptibly()

获取锁

带中断

tryAcquire()

尝试获取锁


tryAcquire(long,TimeUnit)

尝试获取锁

带超时

release()

释放锁


acquire(int)

获取指定线程


acquireUninterruptibly(int)

获取指定线程

带中断

tryAcquire(int)

尝试获取指定数量线程


tryAcquire(int,long,TimeUnit)

尝试获取指定数量线程

带超时

release(int)

释放指定线程数量的锁


availablePermits()

获取状态


drainPermits()

藜取当前状态


isFair()

是否非公平


hasQueuedThreads

是否有排队


getQueueLength

队列长度


toString

转成字符串


package java.util.concurrent;
import java.util.Collection;
import java.util.concurrent.locks.AbstractQueuedSynchronizer;
//信号量源码实现
public class Semaphore implements java.io.Serializable {
    private static final long serialVersionUID = -3222578661600680210L;
    /** All mechanics via AbstractQueuedSynchronizer subclass */
    private final Sync sync;

    //继承aqs
    abstract static class Sync extends AbstractQueuedSynchronizer {
        private static final long serialVersionUID = 1192457210091910933L;
        //构造方法带初始线程数
        Sync(int permits) {
            setState(permits);
        }
        //获取同步状态
        final int getPermits() {
            return getState();
        }
        //没有锁的话尝试获取锁
        final int nonfairTryAcquireShared(int acquires) {
            for (;;) {
                //获取状态
                int available = getState();
                //如果不是取消和共享进尝试更改状态 一直循环
                int remaining = available - acquires;
                if (remaining < 0 ||
                    compareAndSetState(available, remaining))
                    return remaining;
            }
        }
        //尝试释放指定个数锁
        protected final boolean tryReleaseShared(int releases) {
            //循环
            for (;;) {
                //获取同步状态
                int current = getState();
                //当前状态-解锁个数
                int next = current + releases;
                //小于当前数量 跑出异常
                if (next < current) // overflow
                    throw new Error("Maximum permit count exceeded");
                //更新状态
                if (compareAndSetState(current, next))
                    return true;
            }
        }
        //减少指定数量的线程许可
        final void reducePermits(int reductions) {
            for (;;) {
                int current = getState();
                int next = current - reductions;
                //大于当前数量 抛出异常
                if (next > current) // underflow
                    throw new Error("Permit count underflow");
                //尝试释放
                if (compareAndSetState(current, next))
                    return;
            }
        }
        //清空当前线程的所有许可
        final int drainPermits() {
            for (;;) {
                int current = getState();
                //全部置为初始化
                if (current == 0 || compareAndSetState(current, 0))
                    return current;
            }
        }
    }

    //非公平实现
    static final class NonfairSync extends Sync {
        private static final long serialVersionUID = -2694183684443567898L;
        //构造方法
        NonfairSync(int permits) {
            super(permits);
        }
        //非公平获取线程资源(随机)
        protected int tryAcquireShared(int acquires) {
            return nonfairTryAcquireShared(acquires);
        }
    }

    //公平锁实现
    static final class FairSync extends Sync {
        private static final long serialVersionUID = 2014338818796000944L;
        //构造方法 指定初始数量
        FairSync(int permits) {
            super(permits);
        }
        //公平获取acquires个线程资源
        protected int tryAcquireShared(int acquires) {
            for (;;) {
                //如果有排队返回-1
                if (hasQueuedPredecessors())
                    return -1;
                //获取当前状态
                int available = getState();
                //如果得到后的值小于0或更新成功 则返回
                int remaining = available - acquires;
                if (remaining < 0 ||
                    compareAndSetState(available, remaining))
                    return remaining;
            }
        }
    }

    //指定线程数量构建方法(默认为非公平锁)
    public Semaphore(int permits) {
        sync = new NonfairSync(permits);
    }

    //指定线程数量及锁类型构建方法 true公平锁 false非公平锁
    public Semaphore(int permits, boolean fair) {
        sync = fair ? new FairSync(permits) : new NonfairSync(permits);
    }

    //带中断异常释放锁
    public void acquire() throws InterruptedException {
        sync.acquireSharedInterruptibly(1);
    }

    //获取一个信号信许可
    public void acquireUninterruptibly() {
        sync.acquireShared(1);
    }

    //获取许可
    public boolean tryAcquire() {
        return sync.nonfairTryAcquireShared(1) >= 0;
    }
    //带超时时间获取一个许可
    public boolean tryAcquire(long timeout, TimeUnit unit)
        throws InterruptedException {
        return sync.tryAcquireSharedNanos(1, unit.toNanos(timeout));
    }

    //释放一个许可
    public void release() {
        sync.releaseShared(1);
    }

    //带中断释放指定permits数量的许可
    public void acquire(int permits) throws InterruptedException {
        if (permits < 0) throw new IllegalArgumentException();
        sync.acquireSharedInterruptibly(permits);
    }

    //获取指定数量的许可
    public void acquireUninterruptibly(int permits) {
        if (permits < 0) throw new IllegalArgumentException();
        sync.acquireShared(permits);
    }

    //获取指定数量的许全部获取成功才返回true
    public boolean tryAcquire(int permits) {
        if (permits < 0) throw new IllegalArgumentException();
        return sync.nonfairTryAcquireShared(permits) >= 0;
    }

    //带过期时间及判断是否中断获取指定数量的许可
    public boolean tryAcquire(int permits, long timeout, TimeUnit unit)
        throws InterruptedException {
        if (permits < 0) throw new IllegalArgumentException();
        return sync.tryAcquireSharedNanos(permits, unit.toNanos(timeout));
    }

    //释放指定数量的许可
    public void release(int permits) {
        if (permits < 0) throw new IllegalArgumentException();
        sync.releaseShared(permits);
    }

    //查看剩余许可
    public int availablePermits() {
        return sync.getPermits();
    }

    //获取并返回立即可用的许可
    public int drainPermits() {
        return sync.drainPermits();
    }

    //减少指定数量的许可(不等待)
    protected void reducePermits(int reduction) {
        if (reduction < 0) throw new IllegalArgumentException();
        sync.reducePermits(reduction);
    }

    //判断是否公平锁,如果是 true 否 false
    public boolean isFair() {
        return sync instanceof FairSync;
    }

    //判断当前是否有排队线程
    public final boolean hasQueuedThreads() {
        return sync.hasQueuedThreads();
    }

    //获取队列长度
    public final int getQueueLength() {
        return sync.getQueueLength();
    }

    //获取队列的线程集合
    protected Collection getQueuedThreads() {
        return sync.getQueuedThreads();
    }

    //转字符串方法
    public String toString() {
        return super.toString() + "[Permits = " + sync.getPermits() + "]";
    }
}

java.util.concurrent.CyclicBarrier  源码学习

AQS-semaphore&CyclicBarrier&CountDownLatch源码学习_第2张图片

package java.util.concurrent;
import java.util.concurrent.locks.Condition;
import java.util.concurrent.locks.ReentrantLock;
//栅栏源码实现
public class CyclicBarrier {
    //供栅栏判断的条件
    private static class Generation {
        boolean broken = false;
    }

    /** 重入锁*/
    private final ReentrantLock lock = new ReentrantLock();
    /** 上锁条件 */
    private final Condition trip = lock.newCondition();
    /** 锁数量*/
    private final int parties;
    /* 运行的线程*/
    private final Runnable barrierCommand;
    /** 当前状态 */
    private Generation generation = new Generation();

    //已执行数量
    private int count;

    //重置屏bujj 唤醒所有锁 并更新执行状态为可执行
    private void nextGeneration() {
        // signal completion of last generation
        trip.signalAll();
        // set up next generation
        count = parties;
        generation = new Generation();
    }

    //打破屏障
    private void breakBarrier() {
        //更新状太卡
        generation.broken = true;
        //赋值总数
        count = parties;
        //幻醒
        trip.signalAll();
    }

    //带超时间的
    private int dowait(boolean timed, long nanos)
        throws InterruptedException, BrokenBarrierException,
               TimeoutException {
        //获取锁并上锁
        final ReentrantLock lock = this.lock;
        lock.lock();
        try {
            //获取当前实例
            final Generation g = generation;
            //如果为直 直接抛出异常
            if (g.broken)
                throw new BrokenBarrierException();
            //如果线程中断直接打破壁垒并抛出异常
            if (Thread.interrupted()) {
                breakBarrier();
                throw new InterruptedException();
            }
            //总数减1 获取当前下标
            int index = --count;
            if (index == 0) { // tripped
                //执行标记
                boolean ranAction = false;
                try {
                    //获取线程
                    final Runnable command = barrierCommand;
                    //不为空则执行
                    if (command != null)
                        command.run();
                    //设为直
                    ranAction = true;
                    //重置屏bujj
                    nextGeneration();
                    //返回0
                    return 0;
                } finally {
                    //如果为false则打破屏障
                    if (!ranAction)
                        breakBarrier();
                }
            }
            //循环直理 其中不超时则等待 如果直过时间中断退出 抛出异常
            for (;;) {
                try {
                    if (!timed)
                        trip.await();
                    else if (nanos > 0L)
                        nanos = trip.awaitNanos(nanos);
                } catch (InterruptedException ie) {
                    if (g == generation && ! g.broken) {
                        breakBarrier();
                        throw ie;
                    } else {
                        // We're about to finish waiting even if we had not
                        // been interrupted, so this interrupt is deemed to
                        // "belong" to subsequent execution.
                        Thread.currentThread().interrupt();
                    }
                }
                //为打破状态
                if (g.broken)
                    throw new BrokenBarrierException();
                //返回下标
                if (g != generation)
                    return index;
                //超过时间抛出异常
                if (timed && nanos <= 0L) {
                    breakBarrier();
                    throw new TimeoutException();
                }
            }
        } finally {
            //最终解锁
            lock.unlock();
        }
    }

    //构造方法
    public CyclicBarrier(int parties, Runnable barrierAction) {
        if (parties <= 0) throw new IllegalArgumentException();
        this.parties = parties;
        this.count = parties;
        this.barrierCommand = barrierAction;
    }

    //指定数量构造方法
    public CyclicBarrier(int parties) {
        this(parties, null);
    }

    //获取剩余未执行数量
    public int getParties() {
        return parties;
    }

    //等待 (一直等到全部线程执行完) 这里是循环致全部结束
    public int await() throws InterruptedException, BrokenBarrierException {
        try {
            return dowait(false, 0L);
        } catch (TimeoutException toe) {
            throw new Error(toe); // cannot happen
        }
    }

    //带超时时间的等待
    public int await(long timeout, TimeUnit unit)
        throws InterruptedException,
               BrokenBarrierException,
               TimeoutException {
        return dowait(true, unit.toNanos(timeout));
    }

    //是否打破状态
    public boolean isBroken() {
        final ReentrantLock lock = this.lock;
        lock.lock();
        try {
            return generation.broken;
        } finally {
            lock.unlock();
        }
    }

    //重置(带锁)
    public void reset() {
        final ReentrantLock lock = this.lock;
        lock.lock();
        try {
            breakBarrier(); // break the current generation
            nextGeneration(); // start a new generation
        } finally {
            lock.unlock();
        }
    }

    //获取当前执行数量
    public int getNumberWaiting() {
        final ReentrantLock lock = this.lock;
        lock.lock();
        try {
            return parties - count;
        } finally {
            lock.unlock();
        }
    }
}

cyclicbarrier内部是通过重入锁来实现,其实也是aqs的一种实现方式,只是这种比较独立,利用了重入锁的功能而以,没有其他那么复杂。

java.util.concurrent.CountDownLatch  源码学习

AQS-semaphore&CyclicBarrier&CountDownLatch源码学习_第3张图片

//计数器实现
public class CountDownLatch {
    /**
     * 
     *使用aqs计数
     */
    private static final class Sync extends AbstractQueuedSynchronizer {
        private static final long serialVersionUID = 4982264981922014374L;
        //同步锁
        Sync(int count) {
            setState(count);
        }
        //获取总数(state)
        int getCount() {
            return getState();
        }
        //尝试指定数量的计数
        protected int tryAcquireShared(int acquires) {
            return (getState() == 0) ? 1 : -1;
        }
        //获取指定数量的计数
        protected boolean tryReleaseShared(int releases) {
            // Decrement count; signal when transition to zero
            //循环
            for (;;) {
                int c = getState();
                if (c == 0)
                    return false;
                int nextc = c-1;
                if (compareAndSetState(c, nextc))
                    return nextc == 0;
            }
        }
    }
    //同步锁
    private final Sync sync;

    //构造方法 默认用了同步锁 数量必须大于0
    public CountDownLatch(int count) {
        if (count < 0) throw new IllegalArgumentException("count < 0");
        this.sync = new Sync(count);
    }

    //等待方法的实现 如果已中断抛出异常 未中断则一直cas等待
    public void await() throws InterruptedException {
        sync.acquireSharedInterruptibly(1);
    }

    //带超时等待
    public boolean await(long timeout, TimeUnit unit)
        throws InterruptedException {
        return sync.tryAcquireSharedNanos(1, unit.toNanos(timeout));
    }

    //减少一个数量
    public void countDown() {
        sync.releaseShared(1);
    }

    //获取当前待处理数量
    public long getCount() {
        return sync.getCount();
    }

    //转string方法
    public String toString() {
        return super.toString() + "[Count = " + sync.getCount() + "]";
    }
}

可以看到这个计数器更简单,当然使用场景也很有限一般都是一次性的。

最后

    上次看了几个阻塞队列,看得我真的不太想全部去看一遍,因为实在太长了,大几千行的,后面看了这几个计数器实在简单,再之几位同学又一直问这几个索性就看了下。其实这几个看似在工作中用得不多,但是在各类计数或者并发框架中底层很多都使用到了,比如信号量,一般会用来做限流,而计数器或栅栏一般会出现在错误次数或者达到一定量的场景下定时的通知或次数的告警。

你可能感兴趣的:(学习,java,jvm,开发语言)