vlambda博客
学习文章列表

独占可重入锁ReentrantLock 的原理

独占可重入锁ReentrantLock 的原理

1. 类图结构

ReentrantLock 是可重入的独占锁, 同时只能有一个线程可以获取该锁,其他获取该锁的线程会被阻塞而被放入该锁的AQS阻塞队列里面。

ReentrantLock中有一个抽象静态内部类Sync继承自AQS,使用AQS的state代表锁的拥有次数(可重入锁)。ReentrantLock默认构造函数是非公平锁

 public ReentrantLock() {
     sync = new NonfairSync();
 }
 public ReentrantLock(boolean fair) {
     sync = fair ? new FairSync() : new NonfairSync();
 }

Sync的子类NonfairSync 和FairSync 分别实现了获取锁的非公平与公平策略.

在默认情况下state的值为0表示当前锁没有被任何线程持有。当一个线程第一次开始获取锁时会尝试使用CAS设置state的值为1,如果CAS设置成功则当前线程获取了锁,然后记录该锁的持有者为当前线程。在该线程没有释放锁的情况下第二次获取锁后,状态值设置为2,这就是可重入次数,在该线程释放锁时,使用CAS设置状态减1,直到状态值为0,则该线程释放了锁。

2 获取锁

2.1 void lock()

当一个线程调用该方法时,如果锁当前没有被其他线程占用并且当前线程之前没有获取过该锁,则当前线程会获取到该锁,然后设置当前锁的拥有者为当前线程,并设置AQS的state值为1,然后直接返回。如果当前线程已经获取了锁,则这次只是简单的把AQS的state值加1后返回。如果该锁已经被其他线程持有,则调用该方法的线程会被放入 AQS队列后阻塞挂起。

 public void lock() {
     sync.lock();
 }

ReentrantLock的lock()委托给sync类,根据创建ReentrantLock构造函数选择sync的实现是NonfairSync还是FairSync,这个锁是非公平锁还是公平锁。

2.1.1 非公平锁

 final void lock() {
     //1 CAS设置状态为1
     if (compareAndSetState(0, 1))
         setExclusiveOwnerThread(Thread.currentThread());
     else
         //2 调用AQS的acquire
         acquire(1);
 }

代码1 因为默认AQS的状态值为0,所以第一个调用Lock的线程会通过CAS设置状态值为1.CAS成功则表示当前线程获取到了锁,然后setExclusiveOwnerThread设置该锁持有者是当前线程。如果这时候有其他线程调用lock方法企图获取该锁,CAS会失败,然后调用AQS的acquire方法参数是1

AQS的acquire:

 public final void acquire(int arg) {
     //AQS并没有提供可用的tryAcquire,都是其子类实现
     if (!tryAcquire(arg) &&
         acquireQueued(addWaiter(Node.EXCLUSIVE), arg))
         selfInterrupt();
 }

我们直接看非公平锁的实现

 protected final boolean tryAcquire(int acquires) {
     return nonfairTryAcquire(acquires);
 }
 final boolean nonfairTryAcquire(int acquires) {
     final Thread current = Thread.currentThread();
     //先获取当前AQS状态值
     int c = getState();
     //如果是0,代表锁空闲,尝试CAS获取该锁
     if (c == 0) {
         if (compareAndSetState(0, acquires)) {
             setExclusiveOwnerThread(current);
             return true;
        }
    }//如果不是0,说明该锁已经被某个线程持有,查看当前线程是不是锁的持有者
     else if (current == getExclusiveOwnerThread()) {
         //如果是状态值+1
         int nextc = c + acquires;
         if (nextc < 0) // overflow
             throw new Error("Maximum lock count exceeded");
         setState(nextc);
         return true;
    }
     //当前AQS状态值既不是0,也不是当前线程持有者,返回false,加入AQS队列。
     return false;
 }

怎么不公平了呢?看下图:

独占可重入锁ReentrantLock 的原理

1 当线程1持有锁,

2 线程1释放锁的时候,会唤醒线程1的next节点线程2,进行lock操作。

3 此时线程4也执行lock—>tryAcquire->nonfairTryAcquire:发现当前AQS的状态值为0了,所以通过CAS设置获取到了该锁。

大家都在排队,明明应该是线程2改工作了,现在却被没有排队的线程4抢到了锁,不公平。

2.1.2 公平锁

公平锁的实现代码:

 protected final boolean tryAcquire(int acquires) {
     final Thread current = Thread.currentThread();
     int c = getState();
     if (c == 0) {
         //与非公平锁不一样的是添加了hasQueuedPredecessors()
         if (!hasQueuedPredecessors() &&
             compareAndSetState(0, acquires)) {
             setExclusiveOwnerThread(current);
             return true;
        }
    }
     else if (current == getExclusiveOwnerThread()) {
         int nextc = c + acquires;
         if (nextc < 0)
             throw new Error("Maximum lock count exceeded");
         setState(nextc);
         return true;
    }
     return false;
 }
 
 public final boolean hasQueuedPredecessors() {
     Node t = tail; // Read fields in reverse initialization order
     Node h = head;
     Node s;
     return h != t &&
        ((s = h.next) == null || s.thread != Thread.currentThread());
 }

画个图理解一下:

如果h == t 代表当前AQS队列为空,返回false.

如果h != t 并且 s == null 则说明有一个元素将要作为AQS 的第一个节点入队列返回true.

如果h!=t 并且s!=null和s.thread != Thread.cunentThread()则说明队列里面的第一个元素不是当前线程,那么返回true 。

如果队列中有节点返回true,则!true = false 发生短路,则不能抢占锁,乖乖入队吧。


2.2 lockInterruptibly()

该方法与lock()方法不同的是,它对中断进行响应,就是当前线程在调用该方法时,如果其他线程调用了当前线程的interrupt()方法,则当前线程会抛出InterruptedException异常,然后返回。

 public final void acquireInterruptibly(int arg)
             throws InterruptedException {
     //如果当前线程被中断,则直接抛出异常
     if (Thread.interrupted())
         throw new InterruptedException();
     if (!tryAcquire(arg))
         //调用AQS可被中断的方法
         doAcquireInterruptibly(arg);
 }

2.3 tryLock()

尝试获取锁,如果当前该锁没有被其他线程持有,则当前线程获取该锁并返回true,否则返回false.此方法不会阻塞当前线程。

 public boolean tryLock() {
     return sync.nonfairTryAcquire(1);
 }
 
 final boolean nonfairTryAcquire(int acquires) {
     final Thread current = Thread.currentThread();
     int c = getState();
     if (c == 0) {
         if (compareAndSetState(0, acquires)) {
             setExclusiveOwnerThread(current);
             return true;
        }
    }
     else if (current == getExclusiveOwnerThread()) {
         int nextc = c + acquires;
         if (nextc < 0) // overflow
             throw new Error("Maximum lock count exceeded");
         setState(nextc);
         return true;
    }
     return false;
 }


这个与非公平锁调用的是同一个方法,也就是说tryLock是用的非公平策略。

2.4 tryLock(long timeout, TimeUnit unit)

尝试获取锁,与tryLock()不同的是,可以设置超时时间,如果超时时间到没有获取到锁则返回false.

 public boolean tryLock(long timeout, TimeUnit unit)
             throws InterruptedException {
     return sync.tryAcquireNanos(1, unit.toNanos(timeout));
 }

调用AQS的tryAcquireNanos

3 释放锁

3.1 void unlock()

如果当前线程持有该锁,则调用该方法会让该线程对该线程持有的AQS状态值减1,如果减去1后当前状态值为0,则当前线程会释放锁,否则仅仅是减去1而已。如果当前线程没有持有该锁而调用了该方法则会抛出IllegalMonitorStateException异常。

 public void unlock() {
     sync.release(1);
 }
 
 //AQS
 public final boolean release(int arg) {
     if (tryRelease(arg)) {
         Node h = head;
         if (h != null && h.waitStatus != 0)
             unparkSuccessor(h);
         return true;
    }
     return false;
 }
 
 protected final boolean tryRelease(int releases) {
     //获取AQS的state值减去1
     int c = getState() - releases;
     //当前线程不是锁持有者线程抛出异常
     if (Thread.currentThread() != getExclusiveOwnerThread())
         throw new IllegalMonitorStateException();
     boolean free = false;
     //如果当前锁的可重入次数是0,则标识释放成功并清掉锁持有线程
     if (c == 0) {
         free = true;
         setExclusiveOwnerThread(null);
    }
     setState(c);
     return free;
 }

我们看一下ArrayList的添加元素在多线程下的情况:

 ArrayList list = new ArrayList(10);
 Thread thread1 = new Thread(() -> {
     for (int i = 1; i <= 5; i++) {
         list.add("i=" + i);
 
    }
 });
 Thread thread2 = new Thread(() -> {
     for (int i = 6; i <= 10; i++) {
         list.add("i=" + i);
 
    }
 });
 thread1.start();
 thread2.start();
 
 try {
     TimeUnit.SECONDS.sleep(1L);
 } catch (InterruptedException e) {
     e.printStackTrace();
 }
 for (int i = 0; i < list.size(); i++) {
     System.out.println("第" + (i+1) + "个元素是: " + list.get(i));
 }

初始化ArrayList集合,线程1插入5个元素,线程2插入5个元素,最后打印集合:

 第1个元素是: i=1
 第2个元素是: i=7
 第3个元素是: i=2
 第4个元素是: i=8
 第5个元素是: i=3
 第6个元素是: i=9
 第7个元素是: i=10
 第8个元素是: i=4
 第9个元素是: i=5

只有9个元素,少了一个元素,这就是在线程1,线程2都是对数组0赋值的时候被覆盖了。

4 实现简单的线程安全的ArrayList

 public class ReentrantLockList {
     private ArrayList<String> arrayList = new ArrayList<>(10);
     private ReentrantLock lock = new ReentrantLock();
 
     /**
      * 添加元素
      */
     public void add(String e) {
         //加锁
         lock.lock();
         try {
             arrayList.add(e);
        } catch (Exception exception) {
             exception.printStackTrace();
        } finally {
             //释放锁
             lock.unlock();
        }
    }
     public String get(int index) {
         //加锁
         lock.lock();
         try {
             return arrayList.get(index);
        } catch (Exception exception) {
             exception.printStackTrace();
        } finally {
             //释放锁
             lock.unlock();
        }
         return null;
    }
     public static void main(String[] args) {
         ReentrantLockList lockList = new ReentrantLockList();
 
         Thread thread1 = new Thread(() -> {
             for (int i = 1; i <= 5; i++) {
                 lockList.add("i=" + i);
 
            }
        });
         Thread thread2 = new Thread(() -> {
             for (int i = 6; i <= 10; i++) {
                 lockList.add("i=" + i);
 
            }
        });
         thread1.start();
         thread2.start();
 
         try {
             TimeUnit.SECONDS.sleep(1L);
        } catch (InterruptedException e) {
             e.printStackTrace();
        }
         for (int i = 0; i < lockList.arrayList.size(); i++) {
             System.out.println("第" + (i+1) + "个元素是: " + lockList.get(i));
        }
    }
 }