diff --git a/src/main/java/io/lettuce/core/protocol/SharedLock.java b/src/main/java/io/lettuce/core/protocol/SharedLock.java index 13a9cb8cfe..c3ad425c16 100644 --- a/src/main/java/io/lettuce/core/protocol/SharedLock.java +++ b/src/main/java/io/lettuce/core/protocol/SharedLock.java @@ -26,6 +26,8 @@ class SharedLock { private final Lock lock = new ReentrantLock(); + private final ThreadLocal threadWriters = ThreadLocal.withInitial(() -> 0); + private volatile long writers = 0; private volatile Thread exclusiveLockOwner; @@ -45,6 +47,7 @@ void incrementWriters() { if (WRITERS.get(this) >= 0) { WRITERS.incrementAndGet(this); + threadWriters.set(threadWriters.get() + 1); return; } } @@ -63,6 +66,7 @@ void decrementWriters() { } WRITERS.decrementAndGet(this); + threadWriters.set(threadWriters.get() - 1); } /** @@ -121,7 +125,8 @@ private void lockWritersExclusive() { try { for (;;) { - if (WRITERS.compareAndSet(this, 0, -1)) { + // allow reentrant exclusive lock by comparing writers count and threadWriters count + if (WRITERS.compareAndSet(this, threadWriters.get(), -1)) { exclusiveLockOwner = Thread.currentThread(); return; } @@ -137,9 +142,13 @@ private void lockWritersExclusive() { private void unlockWritersExclusive() { if (exclusiveLockOwner == Thread.currentThread()) { - if (WRITERS.incrementAndGet(this) == 0) { + // check exclusive look not reentrant first + if (WRITERS.compareAndSet(this, -1, threadWriters.get())) { exclusiveLockOwner = null; + return; } + // otherwise unlock until no more reentrant left + WRITERS.incrementAndGet(this); } } diff --git a/src/test/java/io/lettuce/core/protocol/SharedLockTest.java b/src/test/java/io/lettuce/core/protocol/SharedLockTest.java new file mode 100644 index 0000000000..a10e712ae4 --- /dev/null +++ b/src/test/java/io/lettuce/core/protocol/SharedLockTest.java @@ -0,0 +1,57 @@ +package io.lettuce.core.protocol; + +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; + +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; + +public class SharedLockTest { + + @Test + public void safety_on_reentrant_lock_exclusive_on_writers() throws InterruptedException { + final SharedLock sharedLock = new SharedLock(); + CountDownLatch cnt = new CountDownLatch(1); + try { + sharedLock.incrementWriters(); + + String result = sharedLock.doExclusive(() -> { + return sharedLock.doExclusive(() -> { + return "ok"; + }); + }); + if ("ok".equals(result)) { + cnt.countDown(); + } + } finally { + sharedLock.decrementWriters(); + } + + boolean await = cnt.await(1, TimeUnit.SECONDS); + Assertions.assertTrue(await); + + // verify writers won't be negative after finally decrementWriters + String result = sharedLock.doExclusive(() -> { + return sharedLock.doExclusive(() -> { + return "ok"; + }); + }); + + Assertions.assertEquals("ok", result); + + // and other writers should be passed after exclusive lock released + CountDownLatch cntOtherThread = new CountDownLatch(1); + new Thread(() -> { + try { + sharedLock.incrementWriters(); + cntOtherThread.countDown(); + } finally { + sharedLock.decrementWriters(); + } + }).start(); + + await = cntOtherThread.await(1, TimeUnit.SECONDS); + Assertions.assertTrue(await); + } + +}