多线程AQS (AbstractQueuedSynchronizer)

一、介绍

类如其名,抽象的队列式的同步器,AQS定义了一套多线程访问共享资源的同步器框架,许多同步类实现都依赖于它,如常用的ReentrantLock/Semaphore/CountDownLatch…。

image-20221202151828359

它维护了一个volatile int state(代表共享资源)和一个FIFO线程等待队列(多线程争用资源被阻塞时会进入此队列)。这里volatile是核心关键词,具体volatile的语义,在此不述。state的访问方式有三种:

  • getState()
  • setState()
  • compareAndSetState()

  AQS定义两种资源共享方式:Exclusive(独占,只有一个线程能执行,如ReentrantLock)和Share(共享,多个线程可同时执行,如Semaphore/CountDownLatch)。

  不同的自定义同步器争用共享资源的方式也不同。自定义同步器在实现时只需要实现共享资源state的获取与释放方式即可,至于具体线程等待队列的维护(如获取资源失败入队/唤醒出队等),AQS已经在顶层实现好了。自定义同步器实现时主要实现以下几种方法:

  • isHeldExclusively():该线程是否正在独占资源。只有用到condition才需要去实现它。
  • tryAcquire(int):独占方式。尝试获取资源,成功则返回true,失败则返回false。
  • tryRelease(int):独占方式。尝试释放资源,成功则返回true,失败则返回false。
  • tryAcquireShared(int):共享方式。尝试获取资源。负数表示失败;0表示成功,但没有剩余可用资源;正数表示成功,且有剩余资源。
  • tryReleaseShared(int):共享方式。尝试释放资源,如果释放后允许唤醒后续等待结点返回true,否则返回false。

  以ReentrantLock为例,state初始化为0,表示未锁定状态。A线程lock()时,会调用tryAcquire()独占该锁并将state+1。此后,其他线程再tryAcquire()时就会失败,直到A线程unlock()到state=0(即释放锁)为止,其它线程才有机会获取该锁。当然,释放锁之前,A线程自己是可以重复获取此锁的(state会累加),这就是可重入的概念。但要注意,获取多少次就要释放多么次,这样才能保证state是能回到零态的。

  再以CountDownLatch以例,任务分为N个子线程去执行,state也初始化为N(注意N要与线程个数一致)。这N个子线程是并行执行的,每个子线程执行完后countDown()一次,state会CAS减1。等到所有子线程都执行完后(即state=0),会unpark()主调用线程,然后主调用线程就会从await()函数返回,继续后余动作。

  一般来说,自定义同步器要么是独占方法,要么是共享方式,他们也只需实现tryAcquire-tryRelease、tryAcquireShared-tryReleaseShared中的一种即可。但AQS也支持自定义同步器同时实现独占和共享两种方式,如ReentrantReadWriteLock。

二、阅读ReentrantLock.lock源码

2.1 ReentrantLock.lock方法

image-20221202152239094

1、lock方法

1
2
3
public void lock() {    
sync.acquire(1);
}

sync 是Reentrant的一个属性,在代码中我们可以找到他的实例是对象是NonfairSync对象,他的类继承图是:可以发现最顶层的类实际上就是 AQS类。

image-20221202152424022

2、acquire方法

实际上调用的是 AQS的acquire方法,因为NonfairSync对象没有重写该方法。该方法就是获取锁。

1
2
3
4
5
public final void acquire(int arg) {
if (!tryAcquire(arg) &&
acquireQueued(addWaiter(Node.EXCLUSIVE), arg))
selfInterrupt();
}

该方法内先调用tryAcquire方法尝试获取锁。所以我们先看下tryAcquire方法,如果我们直接点进去会发先该方法直接抛出一个异常,那么显然这是错误的,实际上在运行过程调用的是NonfairSync的该方法,多态,也是设计模式中的模板方法。

3、NonfairSync.tryAcquire方法

1
2
3
protected final boolean tryAcquire(int acquires) {
return nonfairTryAcquire(acquires);
}

直接调用了另一方法,nonfairTryAcquire 。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
final boolean nonfairTryAcquire(int acquires) {
// 获取当前线程
final Thread current = Thread.currentThread();
// 获取状态,调用的是AQS的getState。
int c = getState();
// 该锁有没有被占用
if (c == 0) {
//尝试cas该值
if (compareAndSetState(0, acquires)) {
// 成功则该线程占有锁
setExclusiveOwnerThread(current);
return true;
}
}
// 判断占有锁的是不是当前线程
else if (current == getExclusiveOwnerThread()) {
// 是的话则是锁重入。state需要加1
int nextc = c + acquires;
if (nextc < 0) // overflow
throw new Error("Maximum lock count exceeded");
// 设置同步状态的值
setState(nextc);
return true;
}
//没有获取到锁。
return false;
}

如果没有获取到锁 那么第二步的if语句没有短路,则执行acquireQueued(addWaiter(Node.EXCLUSIVE), arg) 方法,其中Node.EXCLUSIVE指示节点正在以独占模式等待的标记。

4、AQS.addWaiter方法

为当前线程和给定模式创建并排队节点。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
private Node addWaiter(Node mode) {
//等待锁队列的一个节点
Node node = new Node(mode);
// 一直加入
for (;;) {
// 获取最后一个节点
Node oldTail = tail;
//最后节点是不是空
if (oldTail != null) {
//不是空则加到结束节点的后面。
node.setPrevRelaxed(oldTail);
// cas设置新的最后节点
if (compareAndSetTail(oldTail, node)) {
oldTail.next = node;
return node;
}
} else {
//是空则初始化队列,会将一个new Node()设置成最后节点。
initializeSyncQueue();
}
}
}

5、acquireQueued 方法

以独占不间断模式获取已在队列中的线程。用于条件等待方法以及获取。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
final boolean acquireQueued(final Node node, int arg) {
//用于结束循环
boolean interrupted = false;
try {
for (;;) {
// 获取上一个节点
final Node p = node.predecessor();
// 如果p是头节点则说明下一个获取锁的是当前线程,tryAcquire 尝试获取锁
if (p == head && tryAcquire(arg)) {
// 获取到锁之后,设置头为当前节点。
setHead(node);
// 释放原头节点
p.next = null; // help GC
// 返回
return interrupted;
}
//检查并更新无法获取的节点的状态。 如果线程应阻塞,则返回true。
//这是所有采集循环中的主要信号控制。要求pred == node.prev。
// 主要就是为了达到阻塞效果。
if (shouldParkAfterFailedAcquire(p, node))
interrupted |= parkAndCheckInterrupt();
}
} catch (Throwable t) {
cancelAcquire(node);
if (interrupted)
selfInterrupt();
throw t;
}
}

2.2 ReentrantLock.unlock方法

image-20221202152707524

1、unlock方法

1
2
3
public void unlock() {
sync.release(1);
}

与 lock方法一致,也是调用的 NonfairSync的方法,但是NonfairSync没有重写该方法所以调用的是AQS 的release方法

2、AQS.release 释放锁

1
2
3
4
5
6
7
8
9
10
11
12
13
public final boolean release(int arg) {
//调用tryRelease
if (tryRelease(arg)) {
// 释放成功获取头节点
Node h = head;
if (h != null && h.waitStatus != 0)
// 头结点不为null 且 头结点状态不为0。
//唤醒节点的后继者(如果存在)。
unparkSuccessor(h);
return true;
}
return false;
}

这里实际是NonfairSync.tryRelease

3、NonfairSync.tryRelease

尝试释放锁,成功释放且state为0(锁重入),返回tree 反之false

这里 releases是 1 表示独占锁。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
protected final boolean tryRelease(int releases) {
// 同步状态 - releases
int c = getState() - releases;
// 如果不是当前线程不是执行获得锁的线程。抛出异常
if (Thread.currentThread() != getExclusiveOwnerThread())
throw new IllegalMonitorStateException();
//返回结果
boolean free = false;
// 如果c == 0 说明不是锁重入 释放成功。
if (c == 0) {
free = true;
// 设置占用锁的线程为空。
setExclusiveOwnerThread(null);
}
// 更新状态
setState(c);
return free;
}

三、利用Aqs实现自己的线程工具类

3.1 CountDownLatchPlus 计数器

该工具类是 jdk的CountDownLatch加强版,作用是当指定方法被调用指定次数后将唤醒另一个阻塞的线程。

CountDownLatchPlus在CountDownLatch的基础上提供了两个方法,一个是resetCount、一个是countUp。

  • resetCount : 重置计数值为初始值
  • countUp:计数值加一

源码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
/**
* 模仿jdk的{@link java.util.concurrent.CountDownLatch} , 对其进行了加强。
* 本类可以重复使用,不仅可以-1,也可以加一。提供重置方法。
* @author xia17
* @date 2022/12/2
*/
public class CountDownLatchPlus {


/***
* AbstractQueuedSynchronizer
*/
private final Sync sync;

/** 初始count */
private final int initCount;

public CountDownLatchPlus(int count) {
if (count < 1){
throw new RuntimeException("count应该大于0");
}
initCount = count;
this.sync = new Sync(count);
}

/**
* 等待
*/
public void await(){
try {
sync.acquireSharedInterruptibly(1);
}catch (InterruptedException e){
throw new RuntimeException(e);
}
}

/**
* 等待
* @param timeout 超时时间
* @param unit 单位
* @return 是否成功
*/
public boolean await(long timeout, TimeUnit unit) {
try {
return sync.tryAcquireSharedNanos(1, unit.toNanos(timeout));
}catch (InterruptedException e){
throw new RuntimeException(e);
}
}

/**
* 计数减一
*/
public void countDown(){
sync.releaseShared(1);
}

/**
* 计数直接减到结束
*/
public void countDownToEnd(){
sync.releaseShared(initCount);
}

/**
* 计数加一
*/
public void countUp(){
sync.releaseShared(-1);
}

/**
* 返回count
* @return /
*/
public int getCount(){
return sync.getCount();
}

/**
* 重置count
*/
public void resetCount(){
sync.resetCount(initCount);
}


private static final class Sync extends AbstractQueuedSynchronizer{
private static final long serialVersionUID = 4982264981998014374L;

public Sync(int count) {
setState(count);
}

int getCount(){
return getState();
}

/**
* 尝试获取锁
* @param acquires the acquire argument. This value is always the one
* passed to an acquire method, or is the value saved on entry
* to a condition wait. The value is otherwise uninterpreted
* and can represent anything you like.
*/
@Override
protected int tryAcquireShared(int acquires) {
return (getState() == 0) ? 1 : -1;
}

/**
* 释放锁
* @param releases the release argument. This value is always the one
* passed to a release method, or the current state value upon
* entry to a condition wait. The value is otherwise
* uninterpreted and can represent anything you like.
* @return 返回true将会释放锁
*/
@Override
protected boolean tryReleaseShared(int releases) {
// Decrement count; signal when transition to zero
for (;;) {
int c = getState();
if (c == 0){
return false;
}
int nextc = c - releases;
if (compareAndSetState(c, nextc)){
return nextc == 0;
}
}
}

/**
* 重置 count
* @param count /
*/
void resetCount(int count){
while (true){
int state = getState();
if (compareAndSetState(state,count)){
return;
}
}
}

}


}

3.2 BlockGetter 阻塞获取值

该工具类是一个消费者模式,一个线程提供值,一个线程获取值。

获取值得线程会导致阻塞,需要另外一个线程调用set方法唤醒另一个阻塞线程。

注意:不能提前set,一次set只能一次get。

源码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
/**
* 阻塞式获取值,每次Get都需要调用一次set.
* <p>get():每次成功获取到值后,将会把锁的状态重置未锁住</p>
* <p>set():如果当前锁是没有锁住的状态,将会返回false,即设置失败。</p>
* @author xia17
* @date 2022/12/3
*/
@Slf4j
public class BlockGetter<T> {


public BlockGetter() {
this.sync = new Sync();
}

private final Sync sync;

private T data;


private Optional<T> getData(){
if (data == null){
return Optional.empty();
}
Optional<T> res = Optional.of(data);
this.data = null;
return res;
}


/**
* 获取值
* @return /
*/
public Optional<T> get() {
// 将锁的状态还原至锁住状态
sync.resetState();
// 获取锁
sync.acquire(1);
// 将锁的状态还原至锁住状态
sync.resetState();
return getData();
}

/**
* 获取值
* @param timeout 超时时间
* @param unit 单位
* @return /
*/
public Optional<T> get(long timeout, TimeUnit unit) {
// 将锁的状态还原至锁住状态
sync.resetState();
try {
if (sync.tryAcquireNanos(1, unit.toNanos(timeout))){
// 将锁的状态还原至锁住状态
sync.resetState();
return getData();
}else {
return Optional.empty();
}
}catch (InterruptedException e){
throw new RuntimeException(e);
}
}


/**
* 设置值
* 如果锁是锁住状态将返回false。
* @param data 数据
* @return /
*/
public boolean set(T data){
// 尝试获取锁
if (sync.tryAcquire(1)){
// 可以拿到锁不设置值
return false;
}
this.data = data;
// 释放锁
sync.release(1);
return true;
}


/**
* 锁工具
* state取值 0是锁住状态 1是未锁住
*/
private static final class Sync extends AbstractQueuedSynchronizer {
private static final long serialVersionUID = 4982264981996014374L;

public Sync() {
setState(1);
}


@Override
protected boolean tryAcquire(int arg) {
return getState() == 0;
}



/**
* 释放锁
* @param releases /
* @return 返回true将会释放锁
*/
@Override
protected boolean tryRelease(int releases) {
// Decrement count; signal when transition to zero
for (;;) {
int c = getState();
if (c == 0){
return false;
}
if (compareAndSetState(c, 0)){
return true;
}
}
}


/**
* 锁住
*/
void resetState(){
setState(1);
}

}




}