侧边栏壁纸
博主头像
落叶人生博主等级

走进秋风,寻找秋天的落叶

  • 累计撰写 130562 篇文章
  • 累计创建 28 个标签
  • 累计收到 9 条评论
标签搜索

目 录CONTENT

文章目录

Java并发编程之CountDownLatch源码解析

2023-12-15 星期五 / 0 评论 / 0 点赞 / 37 阅读 / 11440 字

一、导语最近在学习并发编程原理,所以准备整理一下自己学到的知识,先写一篇CountDownLatch的源码分析,之后希望可以慢慢写完整个并发编程。二、什么是CountDownLatchCountDow

一、导语

最近在学习并发编程原理,所以准备整理一下自己学到的知识,先写一篇CountDownLatch的源码分析,之后希望可以慢慢写完整个并发编程。

二、什么是CountDownLatch

CountDownLatch是java的JUC并发包里的一个工具类,可以理解为一个倒计时器,主要是用来控制多个线程之间的通信。
比如有一个主线程A,它要等待其他4个子线程执行完毕之后才能执行,此时就可以利用CountDownLatch来实现这种功能了。

三、简单使用

public static void main(String[] args){	System.out.println("主线程和他的两个小兄弟约好去吃火锅");	System.out.println("主线程进入了饭店");	System.out.println("主线程想要开始动筷子吃饭");	//new一个计数器,初始值为2,当计数器为0时,主线程开始执行	CountDownLatch latch = new CountDownLatch(2);		 new Thread(){             public void run() {                 try {                    System.out.println("子线程1——小兄弟A 正在到饭店的路上");                    Thread.sleep(3000);                    System.out.println("子线程1——小兄弟A 到饭店了");		    //一个小兄弟到了,计数器-1                    latch.countDown();                } catch (InterruptedException e) {                    e.printStackTrace();                }             };         }.start();		 	 new Thread(){             public void run() {                 try {                    System.out.println("子线程2——小兄弟B 正在到饭店的路上");                    Thread.sleep(3000);                    System.out.println("子线程2——小兄弟B 到饭店了");		    //另一个小兄弟到了,计数器-1                    latch.countDown();                } catch (InterruptedException e) {                    e.printStackTrace();                }             };         }.start();		//主线程等待,直到其他两个小兄弟也进入饭店(计数器==0),主线程才能吃饭	 latch.await();	 System.out.println("主线程终于可以开始吃饭了~");}

四、源码分析

核心代码:

CountDownLatch latch = new CountDownLatch(1);        latch.await();        latch.countDown();

其中构造函数的参数是计数器的值;
await()方法是用来阻塞线程,直到计数器的值为0
countDown()方法是执行计数器-1操作

1、首先来看构造函数的代码

public CountDownLatch(int count) {        if (count < 0) throw new IllegalArgumentException("count < 0");        this.sync = new Sync(count);    }

这段代码很简单,首先if判断传入的count是否<0,如果小于0直接抛异常。
然后new一个类Sync,这个Sync是什么呢?我们一起来看下

private static final class Sync extends AbstractQueuedSynchronizer {        private static final long serialVersionUID = 4982264981922014374L;        Sync(int count) {            setState(count);        }        int getCount() {            return getState();        }	//尝试获取共享锁        protected int tryAcquireShared(int acquires) {            return (getState() == 0) ? 1 : -1;        }	//尝试释放共享锁        protected boolean tryReleaseShared(int releases) {            // Decrement count; signal when transition to zero            for (;;) {                int c = getState();                if (c == 0)                    return false;                int nextc = c-1;                if (compareAndSetState(c, nextc))                    return nextc == 0;            }        }    }

可以看到Sync是一个内部类,继承了AQS,AQS是一个同步器,之后我们会详细讲。
其中有几个核心点:

  1. 变量 state是父类AQS里面的变量,在这里的语义是计数器的值
  2. getState()方法也是父类AQS里的方法,很简单,就是获取state的值
  3. tryAcquireShared和tryReleaseShared也是父类AQS里面的方法,在这里CountDownLatch对他们进行了重写,先有个印象,之后详讲。

2、了解了CountDownLatch的构造函数之后,我们再来看它的核心代码,首先是await()。

public void await() throws InterruptedException {        sync.acquireSharedInterruptibly(1);    }

可以看到,其实是通过内部类Sync调用了父类AQS的acquireSharedInterruptibly()方法。

public final void acquireSharedInterruptibly(int arg)            throws InterruptedException {	//判断线程是否是中断状态        if (Thread.interrupted())            throw new InterruptedException();	//尝试获取state的值        if (tryAcquireShared(arg) < 0)//step1            doAcquireSharedInterruptibly(arg);//step2    }

tryAcquireShared(arg)这个方法就是我们刚才在Sync内看到的重写父类AQS的方法,意思就是判断是否getState() == 0,如果state为0,返回1,则step1处不进入if体内acquireSharedInterruptibly(int arg)方法执行完毕。若state!=0,则返回-1,进入if体内step2处。

下面我们来看acquireSharedInterruptibly(int arg)方法:

private void doAcquireSharedInterruptibly(int arg)        throws InterruptedException {	//step1、把当前线程封装为共享类型的Node,加入队列尾部        final Node node = addWaiter(Node.SHARED);        boolean failed = true;        try {            for (;;) {		//step2、获取当前node的前一个元素                final Node p = node.predecessor();		//step3、如果前一个元素是队首                if (p == head) {		    //step4、再次调用tryAcquireShared()方法,判断state的值是否为0                    int r = tryAcquireShared(arg);		    //step5、如果state的值==0                    if (r >= 0) {			//step6、设置当前node为队首,并尝试释放共享锁                        setHeadAndPropagate(node, r);                        p.next = null; // help GC                        failed = false;                        return;                    }                }		//step7、是否可以安心挂起当前线程,是就挂起;并且判断当前线程是否中断                if (shouldParkAfterFailedAcquire(p, node) &&                    parkAndCheckInterrupt())                    throw new InterruptedException();            }        } finally {	//step8、如果出现异常,failed没有更新为false,则把当前node从队列中取消            if (failed)                cancelAcquire(node);        }    }

按照代码中的注释,我们可以大概了解该方法的内容,下面我们来仔细看下其中调用的一些方法是干什么的。
1、首先看addWaiter()

//step1private Node addWaiter(Node mode) {	//把当前线程封装为node        Node node = new Node(Thread.currentThread(), mode);        // Try the fast path of enq; backup to full enq on failure	//获取当前队列的队尾tail,并赋值给pred        Node pred = tail;	//如果pred!=null,即当前队尾不为null        if (pred != null) {	//把当前队尾tail,变成当前node的前继节点            node.prev = pred;	    //cas更新当前node为新的队尾            if (compareAndSetTail(pred, node)) {                pred.next = node;                return node;            }        }	//如果队尾为空,走enq方法        enq(node);//step1.1        return node;    }-----------------------------------------------------------------//step1.1private Node enq(final Node node) {        for (;;) {            Node t = tail;	    //如果队尾tail为null,初始化队列            if (t == null) { // Must initialize		//cas设置一个新的空node为队首                if (compareAndSetHead(new Node()))                    tail = head;            } else {		//cas把当前node设置为新队尾,把前队尾设置成当前node的前继节点                node.prev = t;                if (compareAndSetTail(t, node)) {                    t.next = node;                    return t;                }            }        }    }

2、接下来我们在来看setHeadAndPropagate()方法,看其内部实现

//step6private void setHeadAndPropagate(Node node, int propagate) {	//获取队首head        Node h = head; // Record old head for check below	//设置当前node为队首,并取消node所关联的线程        setHead(node);	//        if (propagate > 0 || h == null || h.waitStatus < 0 ||            (h = head) == null || h.waitStatus < 0) {            Node s = node.next;	    //如果当前node的后继节点为null或者是shared类型的            if (s == null || s.isShared())		//释放锁,唤醒下一个线程                doReleaseShared();//step6.1        }    }--------------------------------------------------------------------//step6.1private void doReleaseShared() {        for (;;) {	    //找到头节点            Node h = head;            if (h != null && h != tail) {		//获取头节点状态                int ws = h.waitStatus;                if (ws == Node.SIGNAL) {                    if (!compareAndSetWaitStatus(h, Node.SIGNAL, 0))                        continue;            // loop to recheck cases		    //唤醒head节点的next节点                    unparkSuccessor(h);                }                else if (ws == 0 &&                         !compareAndSetWaitStatus(h, 0, Node.PROPAGATE))                    continue;                // loop on failed CAS            }            if (h == head)                   // loop if head changed                break;        }    }

3、接下来我们来看countDown()方法。

public void countDown() {        sync.releaseShared(1);    }

可以看到调用的是父类AQS的releaseShared 方法

public final boolean releaseShared(int arg) {	//state-1        if (tryReleaseShared(arg)) {//step1	    //唤醒等待线程,内部调用的是LockSupport.unpark方法            doReleaseShared();//step2            return true;        }        return false;    }------------------------------------------------------------------//step1protected boolean tryReleaseShared(int releases) {            // Decrement count; signal when transition to zero            for (;;) {		//获取当前state的值                int c = getState();                if (c == 0)                    return false;                int nextc = c-1;		//cas操作来进行原子减1                if (compareAndSetState(c, nextc))                    return nextc == 0;            }        }

五、总结

CountDownLatch主要是通过计数器state来控制是否可以执行其他操作,如果不能就通过LockSupport.park()方法挂起线程,直到其他线程执行完毕后唤醒它。下面我们通过一个简单的图来帮助我们理解一下:
PS:本人也是还在学习的路上,理解的也不是特别透彻,如有错误,愿倾听教诲。^_^

广告 广告

评论区