diff --git a/atomic_counter.go b/atomic_counter.go index 5fd40e1..b3375a5 100644 --- a/atomic_counter.go +++ b/atomic_counter.go @@ -13,20 +13,34 @@ import ( var ErrNotCounter error = fmt.Errorf("not a counter") var ErrCounterNotLoaded error = fmt.Errorf("counter not loaded") +var ErrDecrementN error = fmt.Errorf("decrementing natural counter") type AtomicCounter struct { + data atomic.Value db *Chotki + wg sync.WaitGroup rid rdx.ID offset uint64 - localValue atomic.Int64 - tlv atomic.Value - rdt atomic.Value - loaded atomic.Bool lock sync.Mutex expiration time.Time updatePeriod time.Duration } +type atomicNcounter struct { + theirs uint64 + total atomic.Uint64 +} + +type zpart struct { + total int64 + revision int64 +} + +type atomicZCounter struct { + theirs int64 + part atomic.Pointer[zpart] +} + // creates counter that has two properties // - its atomic as long as you use single instance to do all increments, creating multiple instances will break this guarantee // - it can ease CPU load if updatePeiod > 0, in that case it will not read from db backend @@ -43,75 +57,101 @@ func NewAtomicCounter(db *Chotki, rid rdx.ID, offset uint64, updatePeriod time.D } } -func (a *AtomicCounter) load() error { +func (a *AtomicCounter) load() (any, error) { a.lock.Lock() defer a.lock.Unlock() now := time.Now() - if a.loaded.Load() && now.Sub(a.expiration) < 0 { - return nil + if a.data.Load() != nil && now.Sub(a.expiration) < 0 { + return a.data.Load(), nil } + a.wg.Wait() rdt, tlv, err := a.db.ObjectFieldTLV(a.rid.ToOff(a.offset)) if err != nil { - return err + return nil, err } - a.rdt.Store(rdt) - a.loaded.Store(true) - a.tlv.Store(tlv) + var data any switch rdt { case rdx.ZCounter: - a.localValue.Store(rdx.Znative(tlv)) + total, mine, rev := rdx.Znative3(tlv, a.db.clock.Src()) + part := zpart{total: total, revision: rev} + c := atomicZCounter{ + theirs: total - mine, + part: atomic.Pointer[zpart]{}, + } + c.part.Store(&part) + data = &c case rdx.Natural: - a.localValue.Store(int64(rdx.Nnative(tlv))) + total, mine := rdx.Nnative2(tlv, a.db.clock.Src()) + c := atomicNcounter{ + theirs: total - mine, + total: atomic.Uint64{}, + } + c.total.Add(total) + data = &c default: - return ErrNotCounter + return nil, ErrNotCounter } + a.data.Store(data) a.expiration = now.Add(a.updatePeriod) - return nil + return data, nil } func (a *AtomicCounter) Get(ctx context.Context) (int64, error) { - err := a.load() + data, err := a.load() if err != nil { return 0, err } - if !a.loaded.Load() { + switch c := data.(type) { + case *atomicNcounter: + return int64(c.total.Load()), nil + case *atomicZCounter: + return c.part.Load().total, nil + default: return 0, ErrCounterNotLoaded } - return a.localValue.Load(), nil } // Loads (if needed) and increments counter func (a *AtomicCounter) Increment(ctx context.Context, val int64) (int64, error) { - err := a.load() + data, err := a.load() if err != nil { return 0, err } - if !a.loaded.Load() { - return 0, ErrCounterNotLoaded - } - - rdt := a.rdt.Load().(byte) - a.localValue.Add(val) + a.wg.Add(1) + defer a.wg.Done() var dtlv []byte - var mtlv []byte - a.lock.Lock() - tlv := a.tlv.Load().([]byte) - switch rdt { - case rdx.Natural: - dtlv = rdx.Ndelta(tlv, uint64(a.localValue.Load()), a.db.Clock()) - mtlv = rdx.Nmerge([][]byte{tlv, dtlv}) - case rdx.ZCounter: - dtlv = rdx.Zdelta(tlv, a.localValue.Load(), a.db.Clock()) - mtlv = rdx.Zmerge([][]byte{tlv, dtlv}) + var result int64 + var rdt byte + switch c := data.(type) { + case *atomicNcounter: + if val < 0 { + return 0, ErrDecrementN + } + nw := c.total.Add(uint64(val)) + dtlv = rdx.Ntlvt(nw-c.theirs, a.db.clock.Src()) + result = int64(nw) + rdt = rdx.Natural + case *atomicZCounter: + for { + current := c.part.Load() + nw := zpart{ + total: current.total + val, + revision: current.revision + 1, + } + ok := c.part.CompareAndSwap(current, &nw) + if ok { + dtlv = rdx.Ztlvt(nw.total-c.theirs, a.db.clock.Src(), nw.revision) + result = nw.total + rdt = rdx.ZCounter + break + } + } default: - a.lock.Unlock() - return 0, ErrNotCounter + return 0, ErrCounterNotLoaded } - a.tlv.Store(mtlv) - a.lock.Unlock() changes := make(protocol.Records, 0) changes = append(changes, protocol.Record('F', rdx.ZipUint64(uint64(a.offset)))) changes = append(changes, protocol.Record(rdt, dtlv)) a.db.CommitPacket(ctx, 'E', a.rid.ZeroOff(), changes) - return a.localValue.Load(), nil + return result, nil } diff --git a/atomic_counter_test.go b/atomic_counter_test.go index 38d96f0..a7f4c39 100644 --- a/atomic_counter_test.go +++ b/atomic_counter_test.go @@ -2,6 +2,7 @@ package chotki import ( "context" + "fmt" "os" "testing" "time" @@ -93,27 +94,27 @@ func TestAtomicCounterWithPeriodicUpdate(t *testing.T) { // first increment res, err := counterA.Increment(ctx, 1) assert.NoError(t, err) - assert.EqualValues(t, 1, res) + assert.EqualValues(t, 1, res, fmt.Sprintf("iteration %d", i)) syncData(a, b) // increment from another replica res, err = counterB.Increment(ctx, 1) assert.NoError(t, err) - assert.EqualValues(t, 2, res) + assert.EqualValues(t, 2, res, fmt.Sprintf("iteration %d", i)) syncData(a, b) // this increment does not account data from other replica because current value is cached res, err = counterA.Increment(ctx, 1) assert.NoError(t, err) - assert.EqualValues(t, 2, res) + assert.EqualValues(t, 2, res, fmt.Sprintf("iteration %d", i)) time.Sleep(100 * time.Millisecond) // after wait we increment, and we get actual value res, err = counterA.Increment(ctx, 1) assert.NoError(t, err) - assert.EqualValues(t, 4, res) + assert.EqualValues(t, 4, res, fmt.Sprintf("iteration %d", i)) } } diff --git a/rdx/NZ.go b/rdx/NZ.go index b4000e0..c784995 100644 --- a/rdx/NZ.go +++ b/rdx/NZ.go @@ -44,6 +44,17 @@ func Nnative(tlv []byte) (sum uint64) { return } +func Nnative2(tlv []byte, src uint64) (sum, mine uint64) { + it := FIRSTIterator{TLV: tlv} + for it.Next() { + if it.src == src { + mine = it.revz + } + sum += it.revz + } + return +} + // merge TLV values func Nmerge(tlvs [][]byte) (merged []byte) { ih := ItHeap[*NIterator]{}