Merge main into rw-lock
raararaara committed Feb 5, 2025
1 parent 2ba679b commit fd01e46
Showing 4 changed files with 198 additions and 30 deletions.
143 changes: 130 additions & 13 deletions pkg/locker/locker.go
@@ -49,27 +49,49 @@ type Locker struct {

// lockCtr is used by Locker to represent a lock with a given name.
type lockCtr struct {
mu sync.Mutex
mu sync.RWMutex
// waiters is the number of waiters waiting to acquire the lock
// this is int32 instead of uint32 so we can add `-1` in `dec()`
// this is int32 instead of uint32 so we can add `-1` in `decWaiters()`
waiters int32
// readers is the number of readers currently holding RLock.
readers int32
// writer is 1 if currently holding Lock, otherwise 0.
writer int32
}

// inc increments the number of waiters waiting for the lock
func (l *lockCtr) inc() {
// incWaiters increments the number of waiters waiting for the lock
func (l *lockCtr) incWaiters() {
atomic.AddInt32(&l.waiters, 1)
}

// dec decrements the number of waiters waiting on the lock
func (l *lockCtr) dec() {
// decWaiters decrements the number of waiters waiting on the lock
func (l *lockCtr) decWaiters() {
atomic.AddInt32(&l.waiters, -1)
}

func (l *lockCtr) incReaders() {
atomic.AddInt32(&l.readers, 1)
}

func (l *lockCtr) decReaders() {
atomic.AddInt32(&l.readers, -1)
}

func (l *lockCtr) setWriter(val int32) {
atomic.StoreInt32(&l.writer, val)
}

// count gets the current number of waiters
func (l *lockCtr) count() int32 {
return atomic.LoadInt32(&l.waiters)
}

func (l *lockCtr) canDelete() bool {
return atomic.LoadInt32(&l.waiters) == 0 &&
atomic.LoadInt32(&l.readers) == 0 &&
atomic.LoadInt32(&l.writer) == 0
}

// Lock locks the mutex
func (l *lockCtr) Lock() {
l.mu.Lock()
@@ -85,6 +107,21 @@ func (l *lockCtr) Unlock() {
l.mu.Unlock()
}

// RLock locks the mutex
func (l *lockCtr) RLock() {
l.mu.RLock()
}

// TryRLock tries to lock the mutex.
func (l *lockCtr) TryRLock() bool {
return l.mu.TryRLock()
}

// RUnlock unlocks the mutex
func (l *lockCtr) RUnlock() {
l.mu.RUnlock()
}

// New creates a new Locker
func New() *Locker {
return &Locker{
@@ -107,13 +144,15 @@ func (l *Locker) Lock(name string) {

// increment the nameLock waiters while inside the main mutex
// this makes sure that the lock isn't deleted if `Lock` and `Unlock` are called concurrently
nameLock.inc()
nameLock.incWaiters()
l.mu.Unlock()

// Lock the nameLock outside the main mutex so we don't block other operations
// once locked then we can decrement the number of waiters for this lock
nameLock.Lock()
nameLock.dec()

nameLock.decWaiters()
nameLock.setWriter(1)
}

// TryLock locks a mutex with the given name. If it doesn't exist, one is created.
@@ -131,13 +170,14 @@ func (l *Locker) TryLock(name string) bool {

// increment the nameLock waiters while inside the main mutex
// this makes sure that the lock isn't deleted if `Lock` and `Unlock` are called concurrently
nameLock.inc()
nameLock.incWaiters()
l.mu.Unlock()

// Lock the nameLock outside the main mutex so we don't block other operations
// once locked then we can decrement the number of waiters for this lock
succeeded := nameLock.TryLock()
nameLock.dec()
nameLock.decWaiters()
if succeeded {
nameLock.setWriter(1)
}

return succeeded
}
@@ -146,17 +186,94 @@ func (l *Locker) TryLock(name string) bool {
// If the given lock is not being waited on by any other callers, it is deleted
func (l *Locker) Unlock(name string) error {
l.mu.Lock()
defer l.mu.Unlock()

nameLock, exists := l.locks[name]
if !exists {
l.mu.Unlock()
return ErrNoSuchLock
}

if nameLock.count() == 0 {
nameLock.Unlock()
nameLock.setWriter(0)

if nameLock.canDelete() {
delete(l.locks, name)
}
nameLock.Unlock()

return nil
}

// RLock acquires a read lock for the given name.
// If there is no lock for that name, a new one is created.
func (l *Locker) RLock(name string) {
l.mu.Lock()
if l.locks == nil {
l.locks = make(map[string]*lockCtr)
}

nameLock, exists := l.locks[name]
if !exists {
nameLock = &lockCtr{}
l.locks[name] = nameLock
}

// 01. Increase waiters inside the global lock
nameLock.incWaiters()
l.mu.Unlock()

// 02. Acquire RLock
nameLock.RLock()

// 03. Decrease waiters and increase readers
nameLock.decWaiters()
nameLock.incReaders()
}

// TryRLock attempts to acquire a read lock for the given name.
// It returns true on success and false if the lock is currently held by a writer.
func (l *Locker) TryRLock(name string) bool {
l.mu.Lock()
if l.locks == nil {
l.locks = make(map[string]*lockCtr)
}

nameLock, exists := l.locks[name]
if !exists {
nameLock = &lockCtr{}
l.locks[name] = nameLock
}

// increment the nameLock waiters while inside the main mutex
// this makes sure that the lock isn't deleted if `Lock` and `Unlock` are called concurrently
nameLock.incWaiters()
l.mu.Unlock()

// Lock the nameLock outside the main mutex so we don't block other operations
// once locked then we can decrement the number of waiters for this lock
succeeded := nameLock.TryRLock()

nameLock.decWaiters()
if succeeded {
nameLock.incReaders()
}

return succeeded
}

// RUnlock releases a read lock for the given name.
func (l *Locker) RUnlock(name string) error {
l.mu.Lock()
defer l.mu.Unlock()

nameLock, exists := l.locks[name]
if !exists {
return ErrNoSuchLock
}

nameLock.RUnlock()
nameLock.decReaders()

if nameLock.canDelete() {
delete(l.locks, name)
}

return nil
}
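
A minimal usage sketch of the named read/write locker introduced above. It assumes the module path github.com/yorkie-team/yorkie and uses an illustrative key "doc-1"; multiple readers may hold the same named lock at once, while Lock stays exclusive.

package main

import (
	"fmt"

	"github.com/yorkie-team/yorkie/pkg/locker"
)

func main() {
	l := locker.New()

	// Two readers can hold the "doc-1" lock at the same time.
	l.RLock("doc-1")
	l.RLock("doc-1")
	fmt.Println("two readers hold doc-1")

	// Every RLock must be paired with an RUnlock; the entry is removed
	// from the map once no waiters, readers, or writer remain.
	if err := l.RUnlock("doc-1"); err != nil {
		fmt.Println("RUnlock:", err)
	}
	if err := l.RUnlock("doc-1"); err != nil {
		fmt.Println("RUnlock:", err)
	}

	// Write access is still exclusive.
	l.Lock("doc-1")
	if err := l.Unlock("doc-1"); err != nil {
		fmt.Println("Unlock:", err)
	}
}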
10 changes: 6 additions & 4 deletions pkg/locker/locker_test.go
@@ -29,15 +29,15 @@ import (

func TestLockCounter(t *testing.T) {
l := &lockCtr{}
l.inc()
l.incWaiters()

if l.waiters != 1 {
t.Fatal("counter inc failed")
t.Fatal("counter incWaiters failed")
}

l.dec()
l.decWaiters()
if l.waiters != 0 {
t.Fatal("counter dec failed")
t.Fatal("counter decWaiters failed")
}
}

@@ -119,8 +119,10 @@ func TestLockerConcurrency(t *testing.T) {
for i := 0; i <= 1000; i++ {
wg.Add(1)
go func() {
//fmt.Println("locked: ")
l.Lock("test")
// if there is a concurrency issue, will very likely panic here
//fmt.Println("unLock: ")
assert.NoError(t, l.Unlock("test"))
wg.Done()
}()
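
A hedged sketch of a companion concurrency test for the read path, mirroring TestLockerConcurrency above; the test name and iteration count are illustrative, and it reuses the testing, sync, and assert imports already present in locker_test.go.

func TestRLockerConcurrency(t *testing.T) {
	l := New()

	var wg sync.WaitGroup
	for i := 0; i < 1000; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			// Readers do not exclude one another, so these goroutines
			// overlap; RUnlock must still balance the reader count.
			l.RLock("test")
			assert.NoError(t, l.RUnlock("test"))
		}()
	}
	wg.Wait()
}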
31 changes: 31 additions & 0 deletions server/backend/sync/locker.go
@@ -73,6 +73,12 @@ type Locker interface {

// Unlock unlocks the mutex.
Unlock(ctx context.Context) error

// RLock acquires a read lock with a cancelable context.
RLock(ctx context.Context) error

// RUnlock releases a read lock previously acquired by RLock.
RUnlock(ctx context.Context) error
}

type internalLocker struct {
@@ -104,3 +110,28 @@ func (il *internalLocker) Unlock(_ context.Context) error {

return nil
}

// RLock locks the mutex for reading.
func (il *internalLocker) RLock(_ context.Context) error {
il.locks.RLock(il.key)

return nil
}

// TryRLock locks the mutex for reading if not already locked by another session.
func (il *internalLocker) TryRLock(_ context.Context) error {
if !il.locks.TryRLock(il.key) {
return ErrAlreadyLocked
}

return nil
}

// RUnlock unlocks the read lock.
func (il *internalLocker) RUnlock(_ context.Context) error {
if err := il.locks.RUnlock(il.key); err != nil {
return err
}

return nil
}
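
A small helper sketch showing how a caller might wrap work in the read lock exposed by this interface. The helper name withReadLock is hypothetical, and the import paths for the sync and logging packages are assumed from the repository layout.

package example

import (
	"context"

	"github.com/yorkie-team/yorkie/server/backend/sync"
	"github.com/yorkie-team/yorkie/server/logging"
)

// withReadLock acquires the read lock, runs fn, and releases the lock,
// logging (rather than returning) any error from RUnlock, as the server
// code below does.
func withReadLock(ctx context.Context, locker sync.Locker, fn func() error) error {
	if err := locker.RLock(ctx); err != nil {
		return err
	}
	defer func() {
		if err := locker.RUnlock(ctx); err != nil {
			logging.DefaultLogger().Error(err)
		}
	}()
	return fn()
}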
44 changes: 31 additions & 13 deletions server/rpc/yorkie_server.go
@@ -303,15 +303,15 @@ func (s *yorkieServer) PushPullChanges(
}

project := projects.From(ctx)
if pack.HasChanges() {
locker, err := s.backend.Locker.NewLocker(
ctx,
packs.PushPullKey(project.ID, pack.DocumentKey),
)
if err != nil {
return nil, err
}
locker, err := s.backend.Locker.NewLocker(
ctx,
packs.PushPullKey(project.ID, pack.DocumentKey),
)
if err != nil {
return nil, err
}

if pack.HasChanges() {
if err := locker.Lock(ctx); err != nil {
return nil, err
}
Expand All @@ -320,6 +320,15 @@ func (s *yorkieServer) PushPullChanges(
logging.DefaultLogger().Error(err)
}
}()
} else {
if err := locker.RLock(ctx); err != nil {
return nil, err
}
defer func() {
if err := locker.RUnlock(ctx); err != nil {
logging.DefaultLogger().Error(err)
}
}()
}

clientInfo, err := clients.FindActiveClientInfo(ctx, s.backend, types.ClientRefKey{
@@ -519,12 +528,12 @@ func (s *yorkieServer) RemoveDocument(
}

project := projects.From(ctx)
if pack.HasChanges() {
locker, err := s.backend.Locker.NewLocker(ctx, packs.PushPullKey(project.ID, pack.DocumentKey))
if err != nil {
return nil, err
}
locker, err := s.backend.Locker.NewLocker(ctx, packs.PushPullKey(project.ID, pack.DocumentKey))
if err != nil {
return nil, err
}

if pack.HasChanges() {
if err := locker.Lock(ctx); err != nil {
return nil, err
}
Expand All @@ -533,6 +542,15 @@ func (s *yorkieServer) RemoveDocument(
logging.DefaultLogger().Error(err)
}
}()
} else {
if err := locker.RLock(ctx); err != nil {
return nil, err
}
defer func() {
if err := locker.RUnlock(ctx); err != nil {
logging.DefaultLogger().Error(err)
}
}()
}

clientInfo, err := clients.FindActiveClientInfo(ctx, s.backend, types.ClientRefKey{
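
The same read-or-write choice appears in both handlers above. A hedged sketch of how it could be factored out follows; lockPack is a hypothetical helper name not present in this commit, it assumes the sync.Locker interface also declares Lock (as the calls above imply), and it reuses the context, sync, and logging imports shown in the previous sketch.

// lockPack takes the write lock when the pack carries changes and the read
// lock otherwise, returning an unlock callback for the caller to defer.
func lockPack(ctx context.Context, locker sync.Locker, hasChanges bool) (func(), error) {
	if hasChanges {
		if err := locker.Lock(ctx); err != nil {
			return nil, err
		}
		return func() {
			if err := locker.Unlock(ctx); err != nil {
				logging.DefaultLogger().Error(err)
			}
		}, nil
	}

	if err := locker.RLock(ctx); err != nil {
		return nil, err
	}
	return func() {
		if err := locker.RUnlock(ctx); err != nil {
			logging.DefaultLogger().Error(err)
		}
	}, nil
}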
