Skip to content

Commit

Permalink
refactor atomic counters
Browse files Browse the repository at this point in the history
  • Loading branch information
Termina1 committed Dec 25, 2024
1 parent 00c8c3d commit e587387
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 44 deletions.
120 changes: 80 additions & 40 deletions atomic_counter.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,20 +13,34 @@ import (

var ErrNotCounter error = fmt.Errorf("not a counter")
var ErrCounterNotLoaded error = fmt.Errorf("counter not loaded")
var ErrDecrementN error = fmt.Errorf("decrementing natural counter")

type AtomicCounter struct {
data atomic.Value
db *Chotki
wg sync.WaitGroup
rid rdx.ID
offset uint64
localValue atomic.Int64
tlv atomic.Value
rdt atomic.Value
loaded atomic.Bool
lock sync.Mutex
expiration time.Time
updatePeriod time.Duration
}

type atomicNcounter struct {
theirs uint64
total atomic.Uint64
}

type zpart struct {
total int64
revision int64
}

type atomicZCounter struct {
theirs int64
part atomic.Pointer[zpart]
}

// creates counter that has two properties
// - its atomic as long as you use single instance to do all increments, creating multiple instances will break this guarantee
// - it can ease CPU load if updatePeiod > 0, in that case it will not read from db backend
Expand All @@ -43,75 +57,101 @@ func NewAtomicCounter(db *Chotki, rid rdx.ID, offset uint64, updatePeriod time.D
}
}

func (a *AtomicCounter) load() error {
func (a *AtomicCounter) load() (any, error) {
a.lock.Lock()
defer a.lock.Unlock()
now := time.Now()
if a.loaded.Load() && now.Sub(a.expiration) < 0 {
return nil
if a.data.Load() != nil && now.Sub(a.expiration) < 0 {
return a.data.Load(), nil
}
a.wg.Wait()
rdt, tlv, err := a.db.ObjectFieldTLV(a.rid.ToOff(a.offset))
if err != nil {
return err
return nil, err
}
a.rdt.Store(rdt)
a.loaded.Store(true)
a.tlv.Store(tlv)
var data any
switch rdt {
case rdx.ZCounter:
a.localValue.Store(rdx.Znative(tlv))
total, mine, rev := rdx.Znative3(tlv, a.db.clock.Src())
part := zpart{total: total, revision: rev}
c := atomicZCounter{
theirs: total - mine,
part: atomic.Pointer[zpart]{},
}
c.part.Store(&part)
data = &c
case rdx.Natural:
a.localValue.Store(int64(rdx.Nnative(tlv)))
total, mine := rdx.Nnative2(tlv, a.db.clock.Src())
c := atomicNcounter{
theirs: total - mine,
total: atomic.Uint64{},
}
c.total.Add(total)
data = &c
default:
return ErrNotCounter
return nil, ErrNotCounter
}
a.data.Store(data)
a.expiration = now.Add(a.updatePeriod)
return nil
return data, nil
}

func (a *AtomicCounter) Get(ctx context.Context) (int64, error) {
err := a.load()
data, err := a.load()
if err != nil {
return 0, err
}
if !a.loaded.Load() {
switch c := data.(type) {
case *atomicNcounter:
return int64(c.total.Load()), nil
case *atomicZCounter:
return c.part.Load().total, nil
default:
return 0, ErrCounterNotLoaded
}
return a.localValue.Load(), nil
}

// Loads (if needed) and increments counter
func (a *AtomicCounter) Increment(ctx context.Context, val int64) (int64, error) {
err := a.load()
data, err := a.load()
if err != nil {
return 0, err
}
if !a.loaded.Load() {
return 0, ErrCounterNotLoaded
}

rdt := a.rdt.Load().(byte)
a.localValue.Add(val)
a.wg.Add(1)
defer a.wg.Done()
var dtlv []byte
var mtlv []byte
a.lock.Lock()
tlv := a.tlv.Load().([]byte)
switch rdt {
case rdx.Natural:
dtlv = rdx.Ndelta(tlv, uint64(a.localValue.Load()), a.db.Clock())
mtlv = rdx.Nmerge([][]byte{tlv, dtlv})
case rdx.ZCounter:
dtlv = rdx.Zdelta(tlv, a.localValue.Load(), a.db.Clock())
mtlv = rdx.Zmerge([][]byte{tlv, dtlv})
var result int64
var rdt byte
switch c := data.(type) {
case *atomicNcounter:
if val < 0 {
return 0, ErrDecrementN
}
nw := c.total.Add(uint64(val))
dtlv = rdx.Ntlvt(nw-c.theirs, a.db.clock.Src())
result = int64(nw)
rdt = rdx.Natural
case *atomicZCounter:
for {
current := c.part.Load()
nw := zpart{
total: current.total + val,
revision: current.revision + 1,
}
ok := c.part.CompareAndSwap(current, &nw)
if ok {
dtlv = rdx.Ztlvt(nw.total-c.theirs, a.db.clock.Src(), nw.revision)
result = nw.total
rdt = rdx.ZCounter
break
}
}
default:
a.lock.Unlock()
return 0, ErrNotCounter
return 0, ErrCounterNotLoaded
}
a.tlv.Store(mtlv)
a.lock.Unlock()
changes := make(protocol.Records, 0)
changes = append(changes, protocol.Record('F', rdx.ZipUint64(uint64(a.offset))))
changes = append(changes, protocol.Record(rdt, dtlv))
a.db.CommitPacket(ctx, 'E', a.rid.ZeroOff(), changes)
return a.localValue.Load(), nil
return result, nil
}
9 changes: 5 additions & 4 deletions atomic_counter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package chotki

import (
"context"
"fmt"
"os"
"testing"
"time"
Expand Down Expand Up @@ -93,27 +94,27 @@ func TestAtomicCounterWithPeriodicUpdate(t *testing.T) {
// first increment
res, err := counterA.Increment(ctx, 1)
assert.NoError(t, err)
assert.EqualValues(t, 1, res)
assert.EqualValues(t, 1, res, fmt.Sprintf("iteration %d", i))

syncData(a, b)

// increment from another replica
res, err = counterB.Increment(ctx, 1)
assert.NoError(t, err)
assert.EqualValues(t, 2, res)
assert.EqualValues(t, 2, res, fmt.Sprintf("iteration %d", i))

syncData(a, b)

// this increment does not account data from other replica because current value is cached
res, err = counterA.Increment(ctx, 1)
assert.NoError(t, err)
assert.EqualValues(t, 2, res)
assert.EqualValues(t, 2, res, fmt.Sprintf("iteration %d", i))

time.Sleep(100 * time.Millisecond)

// after wait we increment, and we get actual value
res, err = counterA.Increment(ctx, 1)
assert.NoError(t, err)
assert.EqualValues(t, 4, res)
assert.EqualValues(t, 4, res, fmt.Sprintf("iteration %d", i))
}
}
11 changes: 11 additions & 0 deletions rdx/NZ.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,17 @@ func Nnative(tlv []byte) (sum uint64) {
return
}

func Nnative2(tlv []byte, src uint64) (sum, mine uint64) {
it := FIRSTIterator{TLV: tlv}
for it.Next() {
if it.src == src {
mine = it.revz
}
sum += it.revz
}
return
}

// merge TLV values
func Nmerge(tlvs [][]byte) (merged []byte) {
ih := ItHeap[*NIterator]{}
Expand Down

0 comments on commit e587387

Please sign in to comment.