多线程AQS (AbstractQueuedSynchronizer) 一、介绍 类如其名,抽象的队列式的同步器,AQS定义了一套多线程访问共享资源的同步器框架,许多同步类实现都依赖于它,如常用的ReentrantLock/Semaphore/CountDownLatch…。
它维护了一个volatile int state(代表共享资源)和一个FIFO线程等待队列(多线程争用资源被阻塞时会进入此队列)。这里volatile是核心关键词,具体volatile的语义,在此不述。state的访问方式有三种:
getState()
setState()
compareAndSetState()
AQS定义两种资源共享方式:Exclusive(独占,只有一个线程能执行,如ReentrantLock)和Share(共享,多个线程可同时执行,如Semaphore/CountDownLatch)。
不同的自定义同步器争用共享资源的方式也不同。自定义同步器在实现时只需要实现共享资源state的获取与释放方式即可 ,至于具体线程等待队列的维护(如获取资源失败入队/唤醒出队等),AQS已经在顶层实现好了。自定义同步器实现时主要实现以下几种方法:
isHeldExclusively():该线程是否正在独占资源。只有用到condition才需要去实现它。
tryAcquire(int):独占方式。尝试获取资源,成功则返回true,失败则返回false。
tryRelease(int):独占方式。尝试释放资源,成功则返回true,失败则返回false。
tryAcquireShared(int):共享方式。尝试获取资源。负数表示失败;0表示成功,但没有剩余可用资源;正数表示成功,且有剩余资源。
tryReleaseShared(int):共享方式。尝试释放资源,如果释放后允许唤醒后续等待结点返回true,否则返回false。
以ReentrantLock为例,state初始化为0,表示未锁定状态。A线程lock()时,会调用tryAcquire()独占该锁并将state+1。此后,其他线程再tryAcquire()时就会失败,直到A线程unlock()到state=0(即释放锁)为止,其它线程才有机会获取该锁。当然,释放锁之前,A线程自己是可以重复获取此锁的(state会累加),这就是可重入的概念。但要注意,获取多少次就要释放多么次,这样才能保证state是能回到零态的。
再以CountDownLatch以例,任务分为N个子线程去执行,state也初始化为N(注意N要与线程个数一致)。这N个子线程是并行执行的,每个子线程执行完后countDown()一次,state会CAS减1。等到所有子线程都执行完后(即state=0),会unpark()主调用线程,然后主调用线程就会从await()函数返回,继续后余动作。
一般来说,自定义同步器要么是独占方法,要么是共享方式,他们也只需实现tryAcquire-tryRelease、tryAcquireShared-tryReleaseShared中的一种即可。但AQS也支持自定义同步器同时实现独占和共享两种方式,如ReentrantReadWriteLock。
二、阅读ReentrantLock.lock源码 2.1 ReentrantLock.lock方法
1、lock方法
1 2 3 public void lock () { sync.acquire(1 ); }
sync 是Reentrant的一个属性,在代码中我们可以找到他的实例是对象是NonfairSync对象,他的类继承图是:可以发现最顶层的类实际上就是 AQS类。
2、acquire方法
实际上调用的是 AQS的acquire方法,因为NonfairSync对象没有重写该方法。该方法就是获取锁。
1 2 3 4 5 public final void acquire (int arg) { if (!tryAcquire(arg) && acquireQueued(addWaiter(Node.EXCLUSIVE), arg)) selfInterrupt(); }
该方法内先调用tryAcquire方法尝试获取锁。所以我们先看下tryAcquire方法,如果我们直接点进去会发先该方法直接抛出一个异常,那么显然这是错误的,实际上在运行过程调用的是NonfairSync的该方法,多态,也是设计模式中的模板方法。
3、NonfairSync.tryAcquire方法
1 2 3 protected final boolean tryAcquire (int acquires) { return nonfairTryAcquire(acquires); }
直接调用了另一方法,nonfairTryAcquire 。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 final boolean nonfairTryAcquire (int acquires) { final Thread current = Thread.currentThread(); int c = getState(); if (c == 0 ) { if (compareAndSetState(0 , acquires)) { setExclusiveOwnerThread(current); return true ; } } else if (current == getExclusiveOwnerThread()) { int nextc = c + acquires; if (nextc < 0 ) throw new Error ("Maximum lock count exceeded" ); setState(nextc); return true ; } return false ; }
如果没有获取到锁 那么第二步的if语句没有短路,则执行acquireQueued(addWaiter(Node.EXCLUSIVE), arg) 方法,其中Node.EXCLUSIVE指示节点正在以独占模式等待的标记。
4、AQS.addWaiter方法
为当前线程和给定模式创建并排队节点。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 private Node addWaiter (Node mode) { Node node = new Node (mode); for (;;) { Node oldTail = tail; if (oldTail != null ) { node.setPrevRelaxed(oldTail); if (compareAndSetTail(oldTail, node)) { oldTail.next = node; return node; } } else { initializeSyncQueue(); } } }
5、acquireQueued 方法
以独占不间断模式获取已在队列中的线程。用于条件等待方法以及获取。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 final boolean acquireQueued (final Node node, int arg) { boolean interrupted = false ; try { for (;;) { final Node p = node.predecessor(); if (p == head && tryAcquire(arg)) { setHead(node); p.next = null ; return interrupted; } if (shouldParkAfterFailedAcquire(p, node)) interrupted |= parkAndCheckInterrupt(); } } catch (Throwable t) { cancelAcquire(node); if (interrupted) selfInterrupt(); throw t; } }
2.2 ReentrantLock.unlock方法
1、unlock方法
1 2 3 public void unlock () { sync.release(1 ); }
与 lock方法一致,也是调用的 NonfairSync的方法,但是NonfairSync没有重写该方法所以调用的是AQS 的release方法
2、AQS.release 释放锁
1 2 3 4 5 6 7 8 9 10 11 12 13 public final boolean release (int arg) { if (tryRelease(arg)) { Node h = head; if (h != null && h.waitStatus != 0 ) unparkSuccessor(h); return true ; } return false ; }
这里实际是NonfairSync.tryRelease
3、NonfairSync.tryRelease
尝试释放锁,成功释放且state为0(锁重入),返回tree 反之false
这里 releases是 1 表示独占锁。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 protected final boolean tryRelease (int releases) { int c = getState() - releases; if (Thread.currentThread() != getExclusiveOwnerThread()) throw new IllegalMonitorStateException (); boolean free = false ; if (c == 0 ) { free = true ; setExclusiveOwnerThread(null ); } setState(c); return free; }
三、利用Aqs实现自己的线程工具类 3.1 CountDownLatchPlus 计数器 该工具类是 jdk的CountDownLatch加强版,作用是当指定方法被调用指定次数后将唤醒另一个阻塞的线程。
CountDownLatchPlus在CountDownLatch的基础上提供了两个方法,一个是resetCount、一个是countUp。
resetCount : 重置计数值为初始值
countUp:计数值加一
源码
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 public class CountDownLatchPlus { private final Sync sync; private final int initCount; public CountDownLatchPlus (int count) { if (count < 1 ){ throw new RuntimeException ("count应该大于0" ); } initCount = count; this .sync = new Sync (count); } public void await () { try { sync.acquireSharedInterruptibly(1 ); }catch (InterruptedException e){ throw new RuntimeException (e); } } public boolean await (long timeout, TimeUnit unit) { try { return sync.tryAcquireSharedNanos(1 , unit.toNanos(timeout)); }catch (InterruptedException e){ throw new RuntimeException (e); } } public void countDown () { sync.releaseShared(1 ); } public void countDownToEnd () { sync.releaseShared(initCount); } public void countUp () { sync.releaseShared(-1 ); } public int getCount () { return sync.getCount(); } public void resetCount () { sync.resetCount(initCount); } private static final class Sync extends AbstractQueuedSynchronizer { private static final long serialVersionUID = 4982264981998014374L ; public Sync (int count) { setState(count); } int getCount () { return getState(); } @Override protected int tryAcquireShared (int acquires) { return (getState() == 0 ) ? 1 : -1 ; } @Override protected boolean tryReleaseShared (int releases) { for (;;) { int c = getState(); if (c == 0 ){ return false ; } int nextc = c - releases; if (compareAndSetState(c, nextc)){ return nextc == 0 ; } } } void resetCount (int count) { while (true ){ int state = getState(); if (compareAndSetState(state,count)){ return ; } } } } }
3.2 BlockGetter 阻塞获取值 该工具类是一个消费者模式,一个线程提供值,一个线程获取值。
获取值得线程会导致阻塞,需要另外一个线程调用set方法唤醒另一个阻塞线程。
注意:不能提前set,一次set只能一次get。
源码
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 @Slf4j public class BlockGetter <T> { public BlockGetter () { this .sync = new Sync (); } private final Sync sync; private T data; private Optional<T> getData () { if (data == null ){ return Optional.empty(); } Optional<T> res = Optional.of(data); this .data = null ; return res; } public Optional<T> get () { sync.resetState(); sync.acquire(1 ); sync.resetState(); return getData(); } public Optional<T> get (long timeout, TimeUnit unit) { sync.resetState(); try { if (sync.tryAcquireNanos(1 , unit.toNanos(timeout))){ sync.resetState(); return getData(); }else { return Optional.empty(); } }catch (InterruptedException e){ throw new RuntimeException (e); } } public boolean set (T data) { if (sync.tryAcquire(1 )){ return false ; } this .data = data; sync.release(1 ); return true ; } private static final class Sync extends AbstractQueuedSynchronizer { private static final long serialVersionUID = 4982264981996014374L ; public Sync () { setState(1 ); } @Override protected boolean tryAcquire (int arg) { return getState() == 0 ; } @Override protected boolean tryRelease (int releases) { for (;;) { int c = getState(); if (c == 0 ){ return false ; } if (compareAndSetState(c, 0 )){ return true ; } } } void resetState () { setState(1 ); } } }