From d963e322d137039209f29519483a2044200eb553 Mon Sep 17 00:00:00 2001 From: Termina1 Date: Thu, 26 Dec 2024 16:42:39 +0200 Subject: [PATCH] fix atomic counter race --- atomic_counter.go | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/atomic_counter.go b/atomic_counter.go index b3375a5..0198b78 100644 --- a/atomic_counter.go +++ b/atomic_counter.go @@ -58,10 +58,15 @@ func NewAtomicCounter(db *Chotki, rid rdx.ID, offset uint64, updatePeriod time.D } func (a *AtomicCounter) load() (any, error) { + now := time.Now() + if a.data.Load() != nil && now.Sub(a.expiration) < 0 { + a.wg.Add(1) + return a.data.Load(), nil + } a.lock.Lock() defer a.lock.Unlock() - now := time.Now() if a.data.Load() != nil && now.Sub(a.expiration) < 0 { + a.wg.Add(1) return a.data.Load(), nil } a.wg.Wait() @@ -93,6 +98,7 @@ func (a *AtomicCounter) load() (any, error) { } a.data.Store(data) a.expiration = now.Add(a.updatePeriod) + a.wg.Add(1) return data, nil } @@ -101,6 +107,7 @@ func (a *AtomicCounter) Get(ctx context.Context) (int64, error) { if err != nil { return 0, err } + defer a.wg.Done() switch c := data.(type) { case *atomicNcounter: return int64(c.total.Load()), nil @@ -117,7 +124,6 @@ func (a *AtomicCounter) Increment(ctx context.Context, val int64) (int64, error) if err != nil { return 0, err } - a.wg.Add(1) defer a.wg.Done() var dtlv []byte var result int64