diff --git a/pkg/locker/locker.go b/pkg/locker/locker.go index e0af69dd8..ce89c62ce 100644 --- a/pkg/locker/locker.go +++ b/pkg/locker/locker.go @@ -49,27 +49,49 @@ type Locker struct { // lockCtr is used by Locker to represent a lock with a given name. type lockCtr struct { - mu sync.Mutex + mu sync.RWMutex // waiters is the number of waiters waiting to acquire the lock - // this is int32 instead of uint32 so we can add `-1` in `dec()` + // this is int32 instead of uint32 so we can add `-1` in `decWaiters()` waiters int32 + // readers is the number of readers currently holding RLock. + readers int32 + // writer is 1 if currently holding Lock, otherwise 0. + writer int32 } -// inc increments the number of waiters waiting for the lock -func (l *lockCtr) inc() { +// incWaiters increments the number of waiters waiting for the lock +func (l *lockCtr) incWaiters() { atomic.AddInt32(&l.waiters, 1) } -// dec decrements the number of waiters waiting on the lock -func (l *lockCtr) dec() { +// decWaiters decrements the number of waiters waiting on the lock +func (l *lockCtr) decWaiters() { atomic.AddInt32(&l.waiters, -1) } +func (l *lockCtr) incReaders() { + atomic.AddInt32(&l.readers, 1) +} + +func (l *lockCtr) decReaders() { + atomic.AddInt32(&l.readers, -1) +} + +func (l *lockCtr) setWriter(val int32) { + atomic.StoreInt32(&l.writer, val) +} + // count gets the current number of waiters func (l *lockCtr) count() int32 { return atomic.LoadInt32(&l.waiters) } +func (l *lockCtr) canDelete() bool { + return atomic.LoadInt32(&l.waiters) == 0 && + atomic.LoadInt32(&l.readers) == 0 && + atomic.LoadInt32(&l.writer) == 0 +} + // Lock locks the mutex func (l *lockCtr) Lock() { l.mu.Lock() @@ -85,6 +107,21 @@ func (l *lockCtr) Unlock() { l.mu.Unlock() } +// RLock locks the mutex +func (l *lockCtr) RLock() { + l.mu.RLock() +} + +// TryRLock tries to lock the mutex. +func (l *lockCtr) TryRLock() bool { + return l.mu.TryRLock() +} + +// RUnlock unlocks the mutex +func (l *lockCtr) RUnlock() { + l.mu.RUnlock() +} + // New creates a new Locker func New() *Locker { return &Locker{ @@ -107,13 +144,15 @@ func (l *Locker) Lock(name string) { // increment the nameLock waiters while inside the main mutex // this makes sure that the lock isn't deleted if `Lock` and `Unlock` are called concurrently - nameLock.inc() + nameLock.incWaiters() l.mu.Unlock() // Lock the nameLock outside the main mutex so we don't block other operations // once locked then we can decrement the number of waiters for this lock nameLock.Lock() - nameLock.dec() + + nameLock.decWaiters() + nameLock.setWriter(1) } // TryLock locks a mutex with the given name. If it doesn't exist, one is created. @@ -131,13 +170,14 @@ func (l *Locker) TryLock(name string) bool { // increment the nameLock waiters while inside the main mutex // this makes sure that the lock isn't deleted if `Lock` and `Unlock` are called concurrently - nameLock.inc() + nameLock.incWaiters() l.mu.Unlock() // Lock the nameLock outside the main mutex so we don't block other operations // once locked then we can decrement the number of waiters for this lock succeeded := nameLock.TryLock() - nameLock.dec() + nameLock.decWaiters() + nameLock.setWriter(1) return succeeded } @@ -146,17 +186,94 @@ func (l *Locker) TryLock(name string) bool { // If the given lock is not being waited on by any other callers, it is deleted func (l *Locker) Unlock(name string) error { l.mu.Lock() + defer l.mu.Unlock() + nameLock, exists := l.locks[name] if !exists { - l.mu.Unlock() return ErrNoSuchLock } - if nameLock.count() == 0 { + nameLock.Unlock() + nameLock.setWriter(0) + + if nameLock.canDelete() { delete(l.locks, name) } - nameLock.Unlock() + return nil +} + +// RLock acquires a read lock for the given name. +// If there is no lock for that name, a new one is created. +func (l *Locker) RLock(name string) { + l.mu.Lock() + if l.locks == nil { + l.locks = make(map[string]*lockCtr) + } + + nameLock, exists := l.locks[name] + if !exists { + nameLock = &lockCtr{} + l.locks[name] = nameLock + } + + // 01. Increase waiters inside the global lock + nameLock.incWaiters() + l.mu.Unlock() + + // 02. Acquire RLock + nameLock.RLock() + + // 03. Decrease waiters and increase readers + nameLock.decWaiters() + nameLock.incReaders() +} + +// TryRLock attempts to acquire a read lock for the given name. +// Returns true if success, false if the lock is currently held by a writer. +func (l *Locker) TryRLock(name string) bool { + l.mu.Lock() + if l.locks == nil { + l.locks = make(map[string]*lockCtr) + } + + nameLock, exists := l.locks[name] + if !exists { + nameLock = &lockCtr{} + l.locks[name] = nameLock + } + + // increment the nameLock waiters while inside the main mutex + // this makes sure that the lock isn't deleted if `Lock` and `Unlock` are called concurrently + nameLock.incWaiters() l.mu.Unlock() + + // Lock the nameLock outside the main mutex so we don't block other operations + // once locked then we can decrement the number of waiters for this lock + succeeded := nameLock.TryRLock() + + nameLock.decWaiters() + nameLock.incReaders() + + return succeeded +} + +// RUnlock releases a read lock for the given name. +func (l *Locker) RUnlock(name string) error { + l.mu.Lock() + defer l.mu.Unlock() + + nameLock, exists := l.locks[name] + if !exists { + return ErrNoSuchLock + } + + nameLock.RUnlock() + nameLock.decReaders() + + if nameLock.canDelete() { + delete(l.locks, name) + } + return nil } diff --git a/pkg/locker/locker_test.go b/pkg/locker/locker_test.go index 3d3496aaf..d5dd6fb35 100644 --- a/pkg/locker/locker_test.go +++ b/pkg/locker/locker_test.go @@ -29,15 +29,15 @@ import ( func TestLockCounter(t *testing.T) { l := &lockCtr{} - l.inc() + l.incWaiters() if l.waiters != 1 { - t.Fatal("counter inc failed") + t.Fatal("counter incWaiters failed") } - l.dec() + l.decWaiters() if l.waiters != 0 { - t.Fatal("counter dec failed") + t.Fatal("counter decWaiters failed") } } @@ -119,8 +119,10 @@ func TestLockerConcurrency(t *testing.T) { for i := 0; i <= 1000; i++ { wg.Add(1) go func() { + //fmt.Println("locked: ") l.Lock("test") // if there is a concurrency issue, will very likely panic here + //fmt.Println("unLock: ") assert.NoError(t, l.Unlock("test")) wg.Done() }() diff --git a/server/backend/sync/locker.go b/server/backend/sync/locker.go index 4d8eaa96d..559378b40 100644 --- a/server/backend/sync/locker.go +++ b/server/backend/sync/locker.go @@ -73,6 +73,12 @@ type Locker interface { // Unlock unlocks the mutex. Unlock(ctx context.Context) error + + // RLock acquires a read lock with a cancelable context. + RLock(ctx context.Context) error + + // RUnlock releases a read lock previously acquired by RLock. + RUnlock(ctx context.Context) error } type internalLocker struct { @@ -104,3 +110,28 @@ func (il *internalLocker) Unlock(_ context.Context) error { return nil } + +// RLock locks the mutex for reading.. +func (il *internalLocker) RLock(_ context.Context) error { + il.locks.RLock(il.key) + + return nil +} + +// TryRLock locks the mutex for reading if not already locked by another session. +func (il *internalLocker) TryRLock(_ context.Context) error { + if !il.locks.TryRLock(il.key) { + return ErrAlreadyLocked + } + + return nil +} + +// RUnlock unlocks the read lock. +func (il *internalLocker) RUnlock(_ context.Context) error { + if err := il.locks.RUnlock(il.key); err != nil { + return err + } + + return nil +} diff --git a/server/rpc/yorkie_server.go b/server/rpc/yorkie_server.go index 7625ca4b6..f5d29ebd8 100644 --- a/server/rpc/yorkie_server.go +++ b/server/rpc/yorkie_server.go @@ -303,15 +303,15 @@ func (s *yorkieServer) PushPullChanges( } project := projects.From(ctx) - if pack.HasChanges() { - locker, err := s.backend.Locker.NewLocker( - ctx, - packs.PushPullKey(project.ID, pack.DocumentKey), - ) - if err != nil { - return nil, err - } + locker, err := s.backend.Locker.NewLocker( + ctx, + packs.PushPullKey(project.ID, pack.DocumentKey), + ) + if err != nil { + return nil, err + } + if pack.HasChanges() { if err := locker.Lock(ctx); err != nil { return nil, err } @@ -320,6 +320,15 @@ func (s *yorkieServer) PushPullChanges( logging.DefaultLogger().Error(err) } }() + } else { + if err := locker.RLock(ctx); err != nil { + return nil, err + } + defer func() { + if err := locker.RUnlock(ctx); err != nil { + logging.DefaultLogger().Error(err) + } + }() } clientInfo, err := clients.FindActiveClientInfo(ctx, s.backend, types.ClientRefKey{ @@ -519,12 +528,12 @@ func (s *yorkieServer) RemoveDocument( } project := projects.From(ctx) - if pack.HasChanges() { - locker, err := s.backend.Locker.NewLocker(ctx, packs.PushPullKey(project.ID, pack.DocumentKey)) - if err != nil { - return nil, err - } + locker, err := s.backend.Locker.NewLocker(ctx, packs.PushPullKey(project.ID, pack.DocumentKey)) + if err != nil { + return nil, err + } + if pack.HasChanges() { if err := locker.Lock(ctx); err != nil { return nil, err } @@ -533,6 +542,15 @@ func (s *yorkieServer) RemoveDocument( logging.DefaultLogger().Error(err) } }() + } else { + if err := locker.RLock(ctx); err != nil { + return nil, err + } + defer func() { + if err := locker.RUnlock(ctx); err != nil { + logging.DefaultLogger().Error(err) + } + }() } clientInfo, err := clients.FindActiveClientInfo(ctx, s.backend, types.ClientRefKey{