Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extend RWMutex Interface for locker package #1135

Merged
merged 9 commits into from
Feb 11, 2025
71 changes: 63 additions & 8 deletions pkg/locker/locker.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ type Locker struct {

// lockCtr is used by Locker to represent a lock with a given name.
type lockCtr struct {
mu sync.Mutex
mu sync.RWMutex
// waiters is the number of waiters waiting to acquire the lock
// this is int32 instead of uint32 so we can add `-1` in `dec()`
waiters int32
Expand All @@ -70,7 +70,7 @@ func (l *lockCtr) count() int32 {
return atomic.LoadInt32(&l.waiters)
}

// Lock locks the mutex
// Lock locks the mutex for writing
func (l *lockCtr) Lock() {
l.mu.Lock()
}
Expand All @@ -80,11 +80,21 @@ func (l *lockCtr) TryLock() bool {
return l.mu.TryLock()
}

// Unlock unlocks the mutex
// Unlock unlocks the mutex for writing
func (l *lockCtr) Unlock() {
l.mu.Unlock()
}

// RLock locks the mutex for reading
func (l *lockCtr) RLock() {
l.mu.RLock()
}

// RUnlock unlocks the mutex for reading
func (l *lockCtr) RUnlock() {
l.mu.RUnlock()
}

// New creates a new Locker
func New() *Locker {
return &Locker{
Expand All @@ -111,9 +121,7 @@ func (l *Locker) Lock(name string) {
l.mu.Unlock()

// Lock the nameLock outside the main mutex so we don't block other operations
// once locked then we can decrement the number of waiters for this lock
nameLock.Lock()
nameLock.dec()
}

// TryLock locks a mutex with the given name. If it doesn't exist, one is created.
Expand All @@ -135,9 +143,7 @@ func (l *Locker) TryLock(name string) bool {
l.mu.Unlock()

// Lock the nameLock outside the main mutex so we don't block other operations
// once locked then we can decrement the number of waiters for this lock
succeeded := nameLock.TryLock()
nameLock.dec()

return succeeded
}
Expand All @@ -152,10 +158,59 @@ func (l *Locker) Unlock(name string) error {
return ErrNoSuchLock
}

nameLock.Unlock()
// Decrement waiters here to ensure the lock isn't deleted prematurely
// while another goroutine might still be using it.
nameLock.dec()

if nameLock.count() == 0 {
delete(l.locks, name)
}

l.mu.Unlock()
return nil
}

// RLock acquires a read lock for the given name.
// If there is no lock for that name, a new one is created.
func (l *Locker) RLock(name string) {
l.mu.Lock()
if l.locks == nil {
l.locks = make(map[string]*lockCtr)
}

nameLock, exists := l.locks[name]
if !exists {
nameLock = &lockCtr{}
l.locks[name] = nameLock
}

// increment the nameLock waiters while inside the main mutex
// this makes sure that the lock isn't deleted if `Lock` and `Unlock` are called concurrently
nameLock.inc()
l.mu.Unlock()

// Lock the nameLock outside the main mutex so we don't block other operations
nameLock.RLock()
}

// RUnlock releases a read lock for the given name.
func (l *Locker) RUnlock(name string) error {
l.mu.Lock()
nameLock, exists := l.locks[name]
if !exists {
l.mu.Unlock()
return ErrNoSuchLock
}

nameLock.RUnlock()
// Decrement waiters here to ensure the lock isn't deleted prematurely
// while another goroutine might still be using it.
nameLock.dec()

if nameLock.count() == 0 {
delete(l.locks, name)
}
nameLock.Unlock()

l.mu.Unlock()
return nil
Expand Down
156 changes: 153 additions & 3 deletions pkg/locker/locker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ func TestLockerLock(t *testing.T) {
l.Lock("test")
ctr := l.locks["test"]

if ctr.count() != 0 {
t.Fatalf("expected waiters to be 0, got :%d", ctr.waiters)
if ctr.count() != 1 {
t.Fatalf("expected waiters to be 1, got :%d", ctr.waiters)
}

chDone := make(chan struct{})
Expand All @@ -59,7 +59,7 @@ func TestLockerLock(t *testing.T) {
chWaiting := make(chan struct{})
go func() {
for range time.Tick(1 * time.Millisecond) {
if ctr.count() == 1 {
if ctr.count() == 2 {
close(chWaiting)
break
}
Expand Down Expand Up @@ -88,6 +88,10 @@ func TestLockerLock(t *testing.T) {
t.Fatalf("lock should have completed")
}

if err := l.Unlock("test"); err != nil {
t.Fatal(err)
}

if ctr.count() != 0 {
t.Fatalf("expected waiters to be 0, got: %d", ctr.count())
}
Expand Down Expand Up @@ -165,3 +169,149 @@ func TestTryLock(t *testing.T) {
}
}
}

func TestRWLockerRLock(t *testing.T) {
l := New()
l.RLock("test")
ctr := l.locks["test"]

if ctr.count() != 1 {
t.Fatalf("expected waiters to be 1, got :%d", ctr.waiters)
}

chDone := make(chan struct{})
go func() {
l.RLock("test")
close(chDone)
}()

select {
case <-chDone:
case <-time.After(3 * time.Second):
t.Fatalf("lock should have completed")
}

if err := l.RUnlock("test"); err != nil {
t.Fatal(err)
}

if ctr.count() != 1 {
t.Fatalf("expected waiters to be 1, got: %d", ctr.count())
}

if _, exists := l.locks["test"]; !exists {
t.Fatal("expected lock not to be deleted")
}

if err := l.RUnlock("test"); err != nil {
t.Fatal(err)
}

if ctr.count() != 0 {
t.Fatalf("expected waiters to be 0, got: %d", ctr.count())
}

if _, exists := l.locks["test"]; exists {
t.Fatal("expected lock to be deleted")
}
}

func TestLockRLock(t *testing.T) {
l := New()

// RLock after Lock
l.RLock("test")

chDone := make(chan struct{})
go func() {
l.Lock("test")
close(chDone)
}()

select {
case <-chDone:
t.Fatal("lock should not have returned while it was still held")
default:
}

if err := l.RUnlock("test"); err != nil {
t.Fatal(err)
}

select {
case <-chDone:
case <-time.After(3 * time.Second):
t.Fatalf("lock should have completed")
}

if err := l.Unlock("test"); err != nil {
t.Fatal(err)
}

// Lock after RLock
l.Lock("test")

chDone = make(chan struct{})
go func() {
l.RLock("test")
close(chDone)
}()

select {
case <-chDone:
t.Fatal("lock should not have returned while it was still held")
default:
}

if err := l.Unlock("test"); err != nil {
t.Fatal(err)
}

select {
case <-chDone:
case <-time.After(3 * time.Second):
t.Fatalf("lock should have completed")
}

if err := l.RUnlock("test"); err != nil {
t.Fatal(err)
}
}

func TestRWLockerConcurrency(t *testing.T) {
l := New()

var wg sync.WaitGroup
for i := 0; i <= 1000; i++ {
wg.Add(1)
go func(i int) {
if i%2 == 0 {
l.Lock("test")
// if there is a concurrency issue, will very likely panic here
assert.NoError(t, l.Unlock("test"))
} else {
l.RLock("test")
// if there is a concurrency issue, will very likely panic here
assert.NoError(t, l.RUnlock("test"))
}
wg.Done()
}(i)
}

chDone := make(chan struct{})
go func() {
wg.Wait()
close(chDone)
}()

select {
case <-chDone:
case <-time.After(10 * time.Second):
t.Fatal("timeout waiting for locks to complete")
}

// Since everything has unlocked this should not exist anymore
if ctr, exists := l.locks["test"]; exists {
t.Fatalf("lock should not exist: %v", ctr)
}
}
22 changes: 22 additions & 0 deletions server/backend/sync/locker.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,12 @@ type Locker interface {

// Unlock unlocks the mutex.
Unlock(ctx context.Context) error

// RLock acquires a read lock with a cancelable context.
RLock(ctx context.Context) error

// RUnlock releases a read lock previously acquired by RLock.
RUnlock(ctx context.Context) error
}

type internalLocker struct {
Expand Down Expand Up @@ -104,3 +110,19 @@ func (il *internalLocker) Unlock(_ context.Context) error {

return nil
}

// RLock locks the mutex for reading..
func (il *internalLocker) RLock(_ context.Context) error {
il.locks.RLock(il.key)

return nil
}

// RUnlock unlocks the read lock.
func (il *internalLocker) RUnlock(_ context.Context) error {
if err := il.locks.RUnlock(il.key); err != nil {
return err
}

return nil
}
2 changes: 2 additions & 0 deletions server/rpc/yorkie_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,7 @@ func (s *yorkieServer) PushPullChanges(
}

project := projects.From(ctx)

if pack.HasChanges() {
locker, err := s.backend.Locker.NewLocker(
ctx,
Expand Down Expand Up @@ -519,6 +520,7 @@ func (s *yorkieServer) RemoveDocument(
}

project := projects.From(ctx)

if pack.HasChanges() {
locker, err := s.backend.Locker.NewLocker(ctx, packs.PushPullKey(project.ID, pack.DocumentKey))
if err != nil {
Expand Down
28 changes: 28 additions & 0 deletions test/bench/locker_bench_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
package bench

import (
"fmt"
"math/rand"
"strconv"
"testing"
Expand Down Expand Up @@ -65,3 +66,30 @@ func BenchmarkLockerMoreKeys(b *testing.B) {
}
})
}

func BenchmarkRWLocker(b *testing.B) {
b.SetParallelism(128)

rates := []int{2, 10, 100, 1000}
for _, rate := range rates {
b.Run(fmt.Sprintf("RWLock rate %d", rate), func(b *testing.B) {
benchmarkRWLockerParallel(rate, b)
})
}
}

func benchmarkRWLockerParallel(rate int, b *testing.B) {
l := locker.New()

b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
if rand.Intn(rate) == 0 {
l.Lock("test")
assert.NoError(b, l.Unlock("test"))
} else {
l.RLock("test")
assert.NoError(b, l.RUnlock("test"))
}
}
})
}
Loading