Skip to content

Commit 9cfa369

Browse files
committed
syncs: add ShardedInt expvar.Var type
ShardedInt provides an int type expvar.Var that supports more efficient writes at high frequencies (one order of magnitude on an M1 Max, much more on NUMA systems). There are two implementations of ShardValue, one that abuses sync.Pool that will work on current public Go versions, and one that takes a dependency on a runtime.TailscaleP function exposed in Tailscale's Go fork. The sync.Pool variant has about 10x the throughput of a single atomic integer on an M1 Max, and the runtime.TailscaleP variant is about 10x faster than the sync.Pool variant. Neither variant has perfect distribution, or perfectly always avoids cross-CPU sharing, as there is no locking or affinity to ensure that the time of yield is on the same core as the time of core biasing, but in the average case the distributions are enough to provide substantially better performance. See golang/go#18802 for a related upstream proposal. Updates tailscale/go#109 Updates tailscale/corp#25450 Signed-off-by: James Tucker <[email protected]>
1 parent 00a4504 commit 9cfa369

7 files changed

+372
-1
lines changed

go.toolchain.rev

+1-1
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
e005697288a8d2fadc87bb7c3e2c74778d08554a
1+
161c3b79ed91039e65eb148f2547dea6b91e2247

syncs/shardedint.go

+63
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
// Copyright (c) Tailscale Inc & AUTHORS
2+
// SPDX-License-Identifier: BSD-3-Clause
3+
4+
package syncs
5+
6+
import (
7+
"encoding/json"
8+
"sync/atomic"
9+
10+
"golang.org/x/sys/cpu"
11+
)
12+
13+
// ShardedInt provides a sharded atomic int64 value that optimizes high
14+
// frequency (Mhz range and above) writes in highly parallel workloads.
15+
// The zero value is not safe for use; use [NewShardedInt].
16+
// ShardedInt implements the expvar.Var interface.
17+
type ShardedInt struct {
18+
sv *ShardValue[intShard]
19+
}
20+
21+
// NewShardedInt returns a new [ShardedInt].
22+
func NewShardedInt() *ShardedInt {
23+
return &ShardedInt{
24+
sv: NewShardValue[intShard](),
25+
}
26+
}
27+
28+
// Add adds delta to the value.
29+
func (m *ShardedInt) Add(delta int64) {
30+
m.sv.One(func(v *intShard) {
31+
v.Add(delta)
32+
})
33+
}
34+
35+
type intShard struct {
36+
atomic.Int64
37+
_ cpu.CacheLinePad // avoid false sharing of neighboring shards
38+
}
39+
40+
// Value returns the current value.
41+
func (m *ShardedInt) Value() int64 {
42+
var v int64
43+
for s := range m.sv.All {
44+
v += s.Load()
45+
}
46+
return v
47+
}
48+
49+
// GetDistribution returns the current value in each shard.
50+
// This is intended for observability/debugging only.
51+
func (m *ShardedInt) GetDistribution() []int64 {
52+
v := make([]int64, 0, m.sv.Len())
53+
for s := range m.sv.All {
54+
v = append(v, s.Load())
55+
}
56+
return v
57+
}
58+
59+
// String implements the expvar.Var interface
60+
func (m *ShardedInt) String() string {
61+
v, _ := json.Marshal(m.Value())
62+
return string(v)
63+
}

syncs/shardedint_test.go

+94
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
// Copyright (c) Tailscale Inc & AUTHORS
2+
// SPDX-License-Identifier: BSD-3-Clause
3+
4+
package syncs
5+
6+
import (
7+
"expvar"
8+
"sync"
9+
"testing"
10+
11+
"tailscale.com/tstest"
12+
)
13+
14+
func BenchmarkShardedInt(b *testing.B) {
15+
b.ReportAllocs()
16+
17+
b.Run("expvar", func(b *testing.B) {
18+
var m expvar.Int
19+
b.RunParallel(func(pb *testing.PB) {
20+
for pb.Next() {
21+
m.Add(1)
22+
}
23+
})
24+
})
25+
26+
b.Run("sharded int", func(b *testing.B) {
27+
m := NewShardedInt()
28+
b.RunParallel(func(pb *testing.PB) {
29+
for pb.Next() {
30+
m.Add(1)
31+
}
32+
})
33+
})
34+
}
35+
36+
func TestShardedInt(t *testing.T) {
37+
t.Run("basics", func(t *testing.T) {
38+
m := NewShardedInt()
39+
if got, want := m.Value(), int64(0); got != want {
40+
t.Errorf("got %v, want %v", got, want)
41+
}
42+
m.Add(1)
43+
if got, want := m.Value(), int64(1); got != want {
44+
t.Errorf("got %v, want %v", got, want)
45+
}
46+
m.Add(2)
47+
if got, want := m.Value(), int64(3); got != want {
48+
t.Errorf("got %v, want %v", got, want)
49+
}
50+
m.Add(-1)
51+
if got, want := m.Value(), int64(2); got != want {
52+
t.Errorf("got %v, want %v", got, want)
53+
}
54+
})
55+
56+
t.Run("high concurrency", func(t *testing.T) {
57+
m := NewShardedInt()
58+
wg := sync.WaitGroup{}
59+
numWorkers := 1000
60+
numIncrements := 1000
61+
wg.Add(numWorkers)
62+
for i := 0; i < numWorkers; i++ {
63+
go func() {
64+
defer wg.Done()
65+
for i := 0; i < numIncrements; i++ {
66+
m.Add(1)
67+
}
68+
}()
69+
}
70+
wg.Wait()
71+
if got, want := m.Value(), int64(numWorkers*numIncrements); got != want {
72+
t.Errorf("got %v, want %v", got, want)
73+
}
74+
for i, shard := range m.GetDistribution() {
75+
t.Logf("shard %d: %d", i, shard)
76+
}
77+
})
78+
79+
t.Run("allocs", func(t *testing.T) {
80+
m := NewShardedInt()
81+
tstest.MinAllocsPerRun(t, 0, func() {
82+
m.Add(1)
83+
_ = m.Value()
84+
})
85+
86+
// TODO(raggi): fix access to expvar's internal append based
87+
// interface, unfortunately it's not currently closed for external
88+
// use, this will alloc when it escapes.
89+
tstest.MinAllocsPerRun(t, 0, func() {
90+
m.Add(1)
91+
_ = m.String()
92+
})
93+
})
94+
}

syncs/shardvalue.go

+36
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
// Copyright (c) Tailscale Inc & AUTHORS
2+
// SPDX-License-Identifier: BSD-3-Clause
3+
4+
package syncs
5+
6+
// TODO(raggi): this implementation is still imperfect as it will still result
7+
// in cross CPU sharing periodically, we instead really want a per-CPU shard
8+
// key, but the limitations of calling platform code make reaching for even the
9+
// getcpu vdso very painful. See https://github.com/golang/go/issues/18802, and
10+
// hopefully one day we can replace with a primitive that falls out of that
11+
// work.
12+
13+
// ShardValue contains a value sharded over a set of shards.
14+
// In order to be useful, T should be aligned to cache lines.
15+
// Users must organize that usage in One and All is concurrency safe.
16+
// The zero value is not safe for use; use [NewShardValue].
17+
type ShardValue[T any] struct {
18+
shards []T
19+
20+
// empty struct under tailscale_go builds
21+
pool shardValuePool
22+
}
23+
24+
// Len returns the number of shards.
25+
func (sp *ShardValue[T]) Len() int {
26+
return len(sp.shards)
27+
}
28+
29+
// All yields a pointer to the value in each shard.
30+
func (sp *ShardValue[T]) All(yield func(*T) bool) {
31+
for i := range sp.shards {
32+
if !yield(&sp.shards[i]) {
33+
return
34+
}
35+
}
36+
}

syncs/shardvalue_go.go

+36
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
// Copyright (c) Tailscale Inc & AUTHORS
2+
// SPDX-License-Identifier: BSD-3-Clause
3+
4+
//go:build !tailscale_go
5+
6+
package syncs
7+
8+
import (
9+
"runtime"
10+
"sync"
11+
"sync/atomic"
12+
)
13+
14+
// shardValuePool combines a shard-allocation counter with a sync.Pool,
// used to hand out roughly per-P shard pointers on stock Go toolchains.
type shardValuePool struct {
	atomic.Int64 // counts shards handed out by New
	sync.Pool
}
18+
19+
// NewShardValue constructs a new ShardValue[T] with a shard per CPU.
20+
func NewShardValue[T any]() *ShardValue[T] {
21+
sp := &ShardValue[T]{
22+
shards: make([]T, runtime.NumCPU()),
23+
}
24+
sp.pool.New = func() any {
25+
i := sp.pool.Add(1) - 1
26+
return &sp.shards[i%int64(len(sp.shards))]
27+
}
28+
return sp
29+
}
30+
31+
// One yields a pointer to a single shard value with best-effort P-locality.
32+
func (sp *ShardValue[T]) One(yield func(*T)) {
33+
v := sp.pool.Get().(*T)
34+
yield(v)
35+
sp.pool.Put(v)
36+
}

syncs/shardvalue_tailscale.go

+23
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
// Copyright (c) Tailscale Inc & AUTHORS
2+
// SPDX-License-Identifier: BSD-3-Clause
3+
4+
// TODO(raggi): update build tag after toolchain update
5+
//go:build tailscale_go
6+
7+
package syncs
8+
9+
import (
10+
"runtime"
11+
)
12+
13+
// shardValuePool is unused in this variant; the runtime supplies P-locality.
type shardValuePool struct{}
14+
15+
// NewShardValue constructs a new ShardValue[T] with a shard per CPU.
16+
func NewShardValue[T any]() *ShardValue[T] {
17+
return &ShardValue[T]{shards: make([]T, runtime.NumCPU())}
18+
}
19+
20+
// One yields a pointer to a single shard value with best-effort P-locality.
21+
func (sp *ShardValue[T]) One(f func(*T)) {
22+
f(&sp.shards[runtime.TailscaleCurrentP()%len(sp.shards)])
23+
}

syncs/shardvalue_test.go

+119
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
// Copyright (c) Tailscale Inc & AUTHORS
2+
// SPDX-License-Identifier: BSD-3-Clause
3+
4+
package syncs
5+
6+
import (
7+
"math"
8+
"runtime"
9+
"sync"
10+
"sync/atomic"
11+
"testing"
12+
13+
"golang.org/x/sys/cpu"
14+
)
15+
16+
func TestShardValue(t *testing.T) {
17+
type intVal struct {
18+
atomic.Int64
19+
_ cpu.CacheLinePad
20+
}
21+
22+
t.Run("One", func(t *testing.T) {
23+
sv := NewShardValue[intVal]()
24+
sv.One(func(v *intVal) {
25+
v.Store(10)
26+
})
27+
28+
var v int64
29+
for i := range sv.shards {
30+
v += sv.shards[i].Load()
31+
}
32+
if v != 10 {
33+
t.Errorf("got %v, want 10", v)
34+
}
35+
})
36+
37+
t.Run("All", func(t *testing.T) {
38+
sv := NewShardValue[intVal]()
39+
for i := range sv.shards {
40+
sv.shards[i].Store(int64(i))
41+
}
42+
43+
var total int64
44+
sv.All(func(v *intVal) bool {
45+
total += v.Load()
46+
return true
47+
})
48+
// triangle coefficient lower one order due to 0 index
49+
want := int64(len(sv.shards) * (len(sv.shards) - 1) / 2)
50+
if total != want {
51+
t.Errorf("got %v, want %v", total, want)
52+
}
53+
})
54+
55+
t.Run("Len", func(t *testing.T) {
56+
sv := NewShardValue[intVal]()
57+
if got, want := sv.Len(), runtime.NumCPU(); got != want {
58+
t.Errorf("got %v, want %v", got, want)
59+
}
60+
})
61+
62+
t.Run("distribution", func(t *testing.T) {
63+
sv := NewShardValue[intVal]()
64+
65+
goroutines := 1000
66+
iterations := 10000
67+
var wg sync.WaitGroup
68+
wg.Add(goroutines)
69+
for i := 0; i < goroutines; i++ {
70+
go func() {
71+
defer wg.Done()
72+
for i := 0; i < iterations; i++ {
73+
sv.One(func(v *intVal) {
74+
v.Add(1)
75+
})
76+
}
77+
}()
78+
}
79+
wg.Wait()
80+
81+
var (
82+
total int64
83+
distribution []int64
84+
)
85+
t.Logf("distribution:")
86+
sv.All(func(v *intVal) bool {
87+
total += v.Load()
88+
distribution = append(distribution, v.Load())
89+
t.Logf("%d", v.Load())
90+
return true
91+
})
92+
93+
if got, want := total, int64(goroutines*iterations); got != want {
94+
t.Errorf("got %v, want %v", got, want)
95+
}
96+
if got, want := len(distribution), runtime.NumCPU(); got != want {
97+
t.Errorf("got %v, want %v", got, want)
98+
}
99+
100+
mean := total / int64(len(distribution))
101+
for _, v := range distribution {
102+
if v < mean/10 || v > mean*10 {
103+
t.Logf("distribution is very unbalanced: %v", distribution)
104+
}
105+
}
106+
t.Logf("mean: %d", mean)
107+
108+
var standardDev int64
109+
for _, v := range distribution {
110+
standardDev += ((v - mean) * (v - mean))
111+
}
112+
standardDev = int64(math.Sqrt(float64(standardDev / int64(len(distribution)))))
113+
t.Logf("stdev: %d", standardDev)
114+
115+
if standardDev > mean/3 {
116+
t.Logf("standard deviation is too high: %v", standardDev)
117+
}
118+
})
119+
}

0 commit comments

Comments
 (0)