From 1a6d10078d75b0581c96c5bde5dbd4668dcaf438 Mon Sep 17 00:00:00 2001 From: ekexium Date: Thu, 23 Jan 2025 14:58:12 +0800 Subject: [PATCH 1/6] memdb: prevent iterator invalidation Signed-off-by: ekexium --- examples/gcworker/go.mod | 4 +- examples/rawkv/go.mod | 4 +- examples/txnkv/1pc_txn/go.mod | 4 +- examples/txnkv/async_commit/go.mod | 4 +- examples/txnkv/delete_range/go.mod | 4 +- examples/txnkv/go.mod | 4 +- examples/txnkv/pessimistic_txn/go.mod | 4 +- examples/txnkv/unsafedestoryrange/go.mod | 4 +- internal/unionstore/arena/arena.go | 16 ++ internal/unionstore/art/art.go | 21 +++ internal/unionstore/art/art_iterator.go | 9 +- internal/unionstore/memdb_art.go | 159 ++++++++++++++++ internal/unionstore/memdb_bench_test.go | 80 +++++++- internal/unionstore/memdb_rbt.go | 35 ++++ internal/unionstore/memdb_test.go | 229 ++++++++++++++++++++++- internal/unionstore/pipelined_memdb.go | 8 + internal/unionstore/union_store.go | 31 ++- tikv/unionstore_export.go | 2 + 18 files changed, 600 insertions(+), 22 deletions(-) diff --git a/examples/gcworker/go.mod b/examples/gcworker/go.mod index 7af1d6b76e..81562a4bdd 100644 --- a/examples/gcworker/go.mod +++ b/examples/gcworker/go.mod @@ -20,7 +20,7 @@ require ( github.com/grpc-ecosystem/go-grpc-middleware v1.1.0 // indirect github.com/opentracing/opentracing-go v1.2.0 // indirect github.com/pingcap/errors v0.11.5-0.20211224045212-9687c2b0f87c // indirect - github.com/pingcap/failpoint v0.0.0-20220801062533-2eaa32854a6c // indirect + github.com/pingcap/failpoint v0.0.0-20240528011301-b51a646c7c86 // indirect github.com/pingcap/kvproto v0.0.0-20241120071417-b5b7843d9037 // indirect github.com/pingcap/log v1.1.1-0.20221110025148-ca232912c9f3 // indirect github.com/pkg/errors v0.9.1 // indirect @@ -30,7 +30,7 @@ require ( github.com/prometheus/procfs v0.15.1 // indirect github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0 // indirect github.com/tiancaiamao/gp v0.0.0-20221230034425-4025bc8a4d4a // indirect - 
github.com/tikv/pd/client v0.0.0-20241220053006-461b86adc78d // indirect + github.com/tikv/pd/client v0.0.0-20250107032658-5c4ab57d68de // indirect github.com/twmb/murmur3 v1.1.3 // indirect go.etcd.io/etcd/api/v3 v3.5.10 // indirect go.etcd.io/etcd/client/pkg/v3 v3.5.10 // indirect diff --git a/examples/rawkv/go.mod b/examples/rawkv/go.mod index b779427b3c..041c3cb224 100644 --- a/examples/rawkv/go.mod +++ b/examples/rawkv/go.mod @@ -20,7 +20,7 @@ require ( github.com/grpc-ecosystem/go-grpc-middleware v1.1.0 // indirect github.com/opentracing/opentracing-go v1.2.0 // indirect github.com/pingcap/errors v0.11.5-0.20211224045212-9687c2b0f87c // indirect - github.com/pingcap/failpoint v0.0.0-20220801062533-2eaa32854a6c // indirect + github.com/pingcap/failpoint v0.0.0-20240528011301-b51a646c7c86 // indirect github.com/pingcap/kvproto v0.0.0-20241120071417-b5b7843d9037 // indirect github.com/pingcap/log v1.1.1-0.20221110025148-ca232912c9f3 // indirect github.com/pkg/errors v0.9.1 // indirect @@ -30,7 +30,7 @@ require ( github.com/prometheus/procfs v0.15.1 // indirect github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0 // indirect github.com/tiancaiamao/gp v0.0.0-20221230034425-4025bc8a4d4a // indirect - github.com/tikv/pd/client v0.0.0-20241220053006-461b86adc78d // indirect + github.com/tikv/pd/client v0.0.0-20250107032658-5c4ab57d68de // indirect github.com/twmb/murmur3 v1.1.3 // indirect go.etcd.io/etcd/api/v3 v3.5.10 // indirect go.etcd.io/etcd/client/pkg/v3 v3.5.10 // indirect diff --git a/examples/txnkv/1pc_txn/go.mod b/examples/txnkv/1pc_txn/go.mod index 4926479a1b..2fe90783db 100644 --- a/examples/txnkv/1pc_txn/go.mod +++ b/examples/txnkv/1pc_txn/go.mod @@ -20,7 +20,7 @@ require ( github.com/grpc-ecosystem/go-grpc-middleware v1.1.0 // indirect github.com/opentracing/opentracing-go v1.2.0 // indirect github.com/pingcap/errors v0.11.5-0.20211224045212-9687c2b0f87c // indirect - github.com/pingcap/failpoint v0.0.0-20220801062533-2eaa32854a6c // 
indirect + github.com/pingcap/failpoint v0.0.0-20240528011301-b51a646c7c86 // indirect github.com/pingcap/kvproto v0.0.0-20241120071417-b5b7843d9037 // indirect github.com/pingcap/log v1.1.1-0.20221110025148-ca232912c9f3 // indirect github.com/pkg/errors v0.9.1 // indirect @@ -30,7 +30,7 @@ require ( github.com/prometheus/procfs v0.15.1 // indirect github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0 // indirect github.com/tiancaiamao/gp v0.0.0-20221230034425-4025bc8a4d4a // indirect - github.com/tikv/pd/client v0.0.0-20241220053006-461b86adc78d // indirect + github.com/tikv/pd/client v0.0.0-20250107032658-5c4ab57d68de // indirect github.com/twmb/murmur3 v1.1.3 // indirect go.etcd.io/etcd/api/v3 v3.5.10 // indirect go.etcd.io/etcd/client/pkg/v3 v3.5.10 // indirect diff --git a/examples/txnkv/async_commit/go.mod b/examples/txnkv/async_commit/go.mod index 066120e1ec..236e443831 100644 --- a/examples/txnkv/async_commit/go.mod +++ b/examples/txnkv/async_commit/go.mod @@ -20,7 +20,7 @@ require ( github.com/grpc-ecosystem/go-grpc-middleware v1.1.0 // indirect github.com/opentracing/opentracing-go v1.2.0 // indirect github.com/pingcap/errors v0.11.5-0.20211224045212-9687c2b0f87c // indirect - github.com/pingcap/failpoint v0.0.0-20220801062533-2eaa32854a6c // indirect + github.com/pingcap/failpoint v0.0.0-20240528011301-b51a646c7c86 // indirect github.com/pingcap/kvproto v0.0.0-20241120071417-b5b7843d9037 // indirect github.com/pingcap/log v1.1.1-0.20221110025148-ca232912c9f3 // indirect github.com/pkg/errors v0.9.1 // indirect @@ -30,7 +30,7 @@ require ( github.com/prometheus/procfs v0.15.1 // indirect github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0 // indirect github.com/tiancaiamao/gp v0.0.0-20221230034425-4025bc8a4d4a // indirect - github.com/tikv/pd/client v0.0.0-20241220053006-461b86adc78d // indirect + github.com/tikv/pd/client v0.0.0-20250107032658-5c4ab57d68de // indirect github.com/twmb/murmur3 v1.1.3 // indirect 
go.etcd.io/etcd/api/v3 v3.5.10 // indirect go.etcd.io/etcd/client/pkg/v3 v3.5.10 // indirect diff --git a/examples/txnkv/delete_range/go.mod b/examples/txnkv/delete_range/go.mod index 2f9d244b9b..f599c50606 100644 --- a/examples/txnkv/delete_range/go.mod +++ b/examples/txnkv/delete_range/go.mod @@ -20,7 +20,7 @@ require ( github.com/grpc-ecosystem/go-grpc-middleware v1.1.0 // indirect github.com/opentracing/opentracing-go v1.2.0 // indirect github.com/pingcap/errors v0.11.5-0.20211224045212-9687c2b0f87c // indirect - github.com/pingcap/failpoint v0.0.0-20220801062533-2eaa32854a6c // indirect + github.com/pingcap/failpoint v0.0.0-20240528011301-b51a646c7c86 // indirect github.com/pingcap/kvproto v0.0.0-20241120071417-b5b7843d9037 // indirect github.com/pingcap/log v1.1.1-0.20221110025148-ca232912c9f3 // indirect github.com/pkg/errors v0.9.1 // indirect @@ -30,7 +30,7 @@ require ( github.com/prometheus/procfs v0.15.1 // indirect github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0 // indirect github.com/tiancaiamao/gp v0.0.0-20221230034425-4025bc8a4d4a // indirect - github.com/tikv/pd/client v0.0.0-20241220053006-461b86adc78d // indirect + github.com/tikv/pd/client v0.0.0-20250107032658-5c4ab57d68de // indirect github.com/twmb/murmur3 v1.1.3 // indirect go.etcd.io/etcd/api/v3 v3.5.10 // indirect go.etcd.io/etcd/client/pkg/v3 v3.5.10 // indirect diff --git a/examples/txnkv/go.mod b/examples/txnkv/go.mod index 5a23a1978e..06bf6c7f53 100644 --- a/examples/txnkv/go.mod +++ b/examples/txnkv/go.mod @@ -20,7 +20,7 @@ require ( github.com/grpc-ecosystem/go-grpc-middleware v1.1.0 // indirect github.com/opentracing/opentracing-go v1.2.0 // indirect github.com/pingcap/errors v0.11.5-0.20211224045212-9687c2b0f87c // indirect - github.com/pingcap/failpoint v0.0.0-20220801062533-2eaa32854a6c // indirect + github.com/pingcap/failpoint v0.0.0-20240528011301-b51a646c7c86 // indirect github.com/pingcap/kvproto v0.0.0-20241120071417-b5b7843d9037 // indirect 
github.com/pingcap/log v1.1.1-0.20221110025148-ca232912c9f3 // indirect github.com/pkg/errors v0.9.1 // indirect @@ -30,7 +30,7 @@ require ( github.com/prometheus/procfs v0.15.1 // indirect github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0 // indirect github.com/tiancaiamao/gp v0.0.0-20221230034425-4025bc8a4d4a // indirect - github.com/tikv/pd/client v0.0.0-20241220053006-461b86adc78d // indirect + github.com/tikv/pd/client v0.0.0-20250107032658-5c4ab57d68de // indirect github.com/twmb/murmur3 v1.1.3 // indirect go.etcd.io/etcd/api/v3 v3.5.10 // indirect go.etcd.io/etcd/client/pkg/v3 v3.5.10 // indirect diff --git a/examples/txnkv/pessimistic_txn/go.mod b/examples/txnkv/pessimistic_txn/go.mod index 016964ad93..da6c997c67 100644 --- a/examples/txnkv/pessimistic_txn/go.mod +++ b/examples/txnkv/pessimistic_txn/go.mod @@ -20,7 +20,7 @@ require ( github.com/grpc-ecosystem/go-grpc-middleware v1.1.0 // indirect github.com/opentracing/opentracing-go v1.2.0 // indirect github.com/pingcap/errors v0.11.5-0.20211224045212-9687c2b0f87c // indirect - github.com/pingcap/failpoint v0.0.0-20220801062533-2eaa32854a6c // indirect + github.com/pingcap/failpoint v0.0.0-20240528011301-b51a646c7c86 // indirect github.com/pingcap/kvproto v0.0.0-20241120071417-b5b7843d9037 // indirect github.com/pingcap/log v1.1.1-0.20221110025148-ca232912c9f3 // indirect github.com/pkg/errors v0.9.1 // indirect @@ -30,7 +30,7 @@ require ( github.com/prometheus/procfs v0.15.1 // indirect github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0 // indirect github.com/tiancaiamao/gp v0.0.0-20221230034425-4025bc8a4d4a // indirect - github.com/tikv/pd/client v0.0.0-20241220053006-461b86adc78d // indirect + github.com/tikv/pd/client v0.0.0-20250107032658-5c4ab57d68de // indirect github.com/twmb/murmur3 v1.1.3 // indirect go.etcd.io/etcd/api/v3 v3.5.10 // indirect go.etcd.io/etcd/client/pkg/v3 v3.5.10 // indirect diff --git a/examples/txnkv/unsafedestoryrange/go.mod 
b/examples/txnkv/unsafedestoryrange/go.mod index 7f5c8d11ea..91a04ec8cc 100644 --- a/examples/txnkv/unsafedestoryrange/go.mod +++ b/examples/txnkv/unsafedestoryrange/go.mod @@ -20,7 +20,7 @@ require ( github.com/grpc-ecosystem/go-grpc-middleware v1.1.0 // indirect github.com/opentracing/opentracing-go v1.2.0 // indirect github.com/pingcap/errors v0.11.5-0.20211224045212-9687c2b0f87c // indirect - github.com/pingcap/failpoint v0.0.0-20220801062533-2eaa32854a6c // indirect + github.com/pingcap/failpoint v0.0.0-20240528011301-b51a646c7c86 // indirect github.com/pingcap/kvproto v0.0.0-20241120071417-b5b7843d9037 // indirect github.com/pingcap/log v1.1.1-0.20221110025148-ca232912c9f3 // indirect github.com/pkg/errors v0.9.1 // indirect @@ -30,7 +30,7 @@ require ( github.com/prometheus/procfs v0.15.1 // indirect github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0 // indirect github.com/tiancaiamao/gp v0.0.0-20221230034425-4025bc8a4d4a // indirect - github.com/tikv/pd/client v0.0.0-20241220053006-461b86adc78d // indirect + github.com/tikv/pd/client v0.0.0-20250107032658-5c4ab57d68de // indirect github.com/twmb/murmur3 v1.1.3 // indirect go.etcd.io/etcd/api/v3 v3.5.10 // indirect go.etcd.io/etcd/client/pkg/v3 v3.5.10 // indirect diff --git a/internal/unionstore/arena/arena.go b/internal/unionstore/arena/arena.go index ef5c081921..9671ccb492 100644 --- a/internal/unionstore/arena/arena.go +++ b/internal/unionstore/arena/arena.go @@ -38,6 +38,9 @@ import ( "encoding/binary" "math" + "github.com/tikv/client-go/v2/internal/logutil" + "go.uber.org/zap" + "github.com/tikv/client-go/v2/kv" "go.uber.org/atomic" ) @@ -223,6 +226,19 @@ func (cp *MemDBCheckpoint) IsSamePosition(other *MemDBCheckpoint) bool { return cp.blocks == other.blocks && cp.offsetInBlock == other.offsetInBlock } +func (cp *MemDBCheckpoint) LessThan(cp2 *MemDBCheckpoint) bool { + if cp == nil || cp2 == nil { + logutil.BgLogger().Panic("unexpected nil checkpoint", zap.Any("cp", cp), zap.Any("cp2", 
cp2)) + } + if cp.blocks < cp2.blocks { + return true + } + if cp.blocks == cp2.blocks && cp.offsetInBlock < cp2.offsetInBlock { + return true + } + return false +} + func (a *MemdbArena) Checkpoint() MemDBCheckpoint { snap := MemDBCheckpoint{ blockSize: a.blockSize, diff --git a/internal/unionstore/art/art.go b/internal/unionstore/art/art.go index 36acb907a5..ee98bf1e64 100644 --- a/internal/unionstore/art/art.go +++ b/internal/unionstore/art/art.go @@ -52,6 +52,13 @@ type ART struct { lastTraversedNode atomic.Uint64 hitCount atomic.Uint64 missCount atomic.Uint64 + + // The counter of every write operation, used to invalidate iterators that were created before the write operation. + SeqNo int + // increased by 1 when an operation that may affect the content returned by "snapshot iter" (i.e. stage[0]) happens. + // It's used to invalidate snapshot iterators. + // invariant: no concurrent access to it + SnapshotSeqNo int } func New() *ART { @@ -115,6 +122,7 @@ func (t *ART) Set(key artKey, value []byte, ops ...kv.FlagsOp) error { } } + t.SeqNo++ if len(t.stages) == 0 { t.dirty = true } @@ -479,6 +487,10 @@ func (t *ART) RevertToCheckpoint(cp *arena.MemDBCheckpoint) { t.allocator.vlogAllocator.RevertToCheckpoint(t, cp) t.allocator.vlogAllocator.Truncate(cp) t.allocator.vlogAllocator.OnMemChange() + t.SeqNo++ + if len(t.stages) == 0 || t.stages[0].LessThan(cp) { + t.SnapshotSeqNo++ + } } func (t *ART) Stages() []arena.MemDBCheckpoint { @@ -498,7 +510,9 @@ func (t *ART) Release(h int) { if h != len(t.stages) { panic("cannot release staging buffer") } + t.SeqNo++ if h == 1 { + t.SnapshotSeqNo++ tail := t.checkpoint() if !t.stages[0].IsSamePosition(&tail) { t.dirty = true @@ -519,6 +533,11 @@ func (t *ART) Cleanup(h int) { panic(fmt.Sprintf("cannot cleanup staging buffer, h=%v, len(tree.stages)=%v", h, len(t.stages))) } + t.SeqNo++ + if h == 1 { + t.SnapshotSeqNo++ + } + cp := &t.stages[h-1] if !t.vlogInvalid { curr := t.checkpoint() @@ -542,6 +561,8 @@ func (t *ART) 
Reset() { t.allocator.nodeAllocator.Reset() t.allocator.vlogAllocator.Reset() t.lastTraversedNode.Store(arena.NullU64Addr) + t.SnapshotSeqNo++ + t.SeqNo++ } // DiscardValues releases the memory used by all values. diff --git a/internal/unionstore/art/art_iterator.go b/internal/unionstore/art/art_iterator.go index 2bf4fdba64..ba7ff0b855 100644 --- a/internal/unionstore/art/art_iterator.go +++ b/internal/unionstore/art/art_iterator.go @@ -56,6 +56,7 @@ func (t *ART) iter(lowerBound, upperBound []byte, reverse, includeFlags bool) (* // this avoids the initial value of currAddr equals to endAddr. currAddr: arena.BadAddr, endAddr: arena.NullAddr, + seqNo: t.SeqNo, } it.init(lowerBound, upperBound) if !it.valid { @@ -76,9 +77,12 @@ type Iterator struct { currLeaf *artLeaf currAddr arena.MemdbArenaAddr endAddr arena.MemdbArenaAddr + + // only when seqNo == art.seqNo, the iterator is valid. + seqNo int } -func (it *Iterator) Valid() bool { return it.valid } +func (it *Iterator) Valid() bool { return it.valid && it.seqNo == it.tree.SeqNo } func (it *Iterator) Key() []byte { return it.currLeaf.GetKey() } func (it *Iterator) Flags() kv.KeyFlags { return it.currLeaf.GetKeyFlags() } func (it *Iterator) Value() []byte { @@ -102,6 +106,9 @@ func (it *Iterator) Next() error { // iterate is finished return errors.New("Art: iterator is finished") } + if it.seqNo != it.tree.SeqNo { + return errors.New(fmt.Sprintf("seqNo mismatch: iter=%d, art=%d", it.seqNo, it.tree.SeqNo)) + } if it.currAddr == it.endAddr { it.valid = false return nil diff --git a/internal/unionstore/memdb_art.go b/internal/unionstore/memdb_art.go index c7c1b21d98..ebd63edce8 100644 --- a/internal/unionstore/memdb_art.go +++ b/internal/unionstore/memdb_art.go @@ -16,8 +16,11 @@ package unionstore import ( "context" + "fmt" "sync" + "github.com/pingcap/errors" + tikverr "github.com/tikv/client-go/v2/error" "github.com/tikv/client-go/v2/internal/unionstore/arena" "github.com/tikv/client-go/v2/internal/unionstore/art" 
@@ -151,6 +154,32 @@ func (db *artDBWithContext) IterReverse(upper, lower []byte) (Iterator, error) { return db.ART.IterReverse(upper, lower) } +func (db *artDBWithContext) ForEachInSnapshotRange(lower []byte, upper []byte, f func(k, v []byte) (stop bool, err error), reverse bool) error { + db.RLock() + defer db.RUnlock() + var iter Iterator + if reverse { + iter = db.SnapshotIterReverse(upper, lower) + } else { + iter = db.SnapshotIter(lower, upper) + } + defer iter.Close() + for iter.Valid() { + stop, err := f(iter.Key(), iter.Value()) + if err != nil { + return err + } + err = iter.Next() + if err != nil { + return err + } + if stop { + break + } + } + return nil +} + // SnapshotIter returns an Iterator for a snapshot of MemBuffer. func (db *artDBWithContext) SnapshotIter(lower, upper []byte) Iterator { return db.ART.SnapshotIter(lower, upper) @@ -165,3 +194,133 @@ func (db *artDBWithContext) SnapshotIterReverse(upper, lower []byte) Iterator { func (db *artDBWithContext) SnapshotGetter() Getter { return db.ART.SnapshotGetter() } + +type snapshotBatchedIter struct { + db *artDBWithContext + snapshotTruncateSeqNo int + lower []byte + upper []byte + reverse bool + + // current batch + kvs []KvPair + pos int + batchSize int + nextKey []byte +} + +func (db *artDBWithContext) BatchedSnapshotIter(lower, upper []byte, reverse bool) Iterator { + iter := &snapshotBatchedIter{ + db: db, + snapshotTruncateSeqNo: db.SnapshotSeqNo, + lower: lower, + upper: upper, + reverse: reverse, + batchSize: 4, + } + + // Position at first key immediately + iter.fillBatch() + return iter +} + +func (it *snapshotBatchedIter) fillBatch() error { + if it.snapshotTruncateSeqNo != it.db.SnapshotSeqNo { + return errors.New(fmt.Sprintf("invalid iter: truncation happened, iter's=%d, db's=%d", + it.snapshotTruncateSeqNo, it.db.SnapshotSeqNo)) + } + + it.db.RLock() + defer it.db.RUnlock() + + if it.kvs == nil { + it.kvs = make([]KvPair, 0, it.batchSize) + } else { + it.kvs = it.kvs[:0] + } + + var 
snapshotIter Iterator + if it.reverse { + searchUpper := it.upper + if it.nextKey != nil { + searchUpper = it.nextKey + } + snapshotIter = it.db.SnapshotIterReverse(searchUpper, it.lower) + } else { + searchLower := it.lower + if it.nextKey != nil { + searchLower = it.nextKey + } + snapshotIter = it.db.SnapshotIter(searchLower, it.upper) + } + defer snapshotIter.Close() + + // fill current batch + for i := 0; i < it.batchSize && snapshotIter.Valid(); i++ { + it.kvs = append(it.kvs, KvPair{ + Key: snapshotIter.Key(), + Value: snapshotIter.Value(), + }) + if err := snapshotIter.Next(); err != nil { + return err + } + } + + // update state + it.pos = 0 + if len(it.kvs) > 0 { + lastKV := it.kvs[len(it.kvs)-1] + if it.reverse { + it.nextKey = append([]byte(nil), lastKV.Key...) + } else { + it.nextKey = append(append([]byte(nil), lastKV.Key...), 0) + } + } else { + it.nextKey = nil + } + + it.batchSize = min(it.batchSize*2, 4096) + return nil +} + +func (it *snapshotBatchedIter) Valid() bool { + return it.snapshotTruncateSeqNo == it.db.SnapshotSeqNo && + it.pos < len(it.kvs) +} + +func (it *snapshotBatchedIter) Next() error { + if it.snapshotTruncateSeqNo != it.db.SnapshotSeqNo { + return errors.New( + fmt.Sprintf( + "invalid snapshotBatchedIter: truncation happened, iter's=%d, db's=%d", + it.snapshotTruncateSeqNo, + it.db.SnapshotSeqNo, + ), + ) + } + + it.pos++ + if it.pos >= len(it.kvs) { + return it.fillBatch() + } + return nil +} + +func (it *snapshotBatchedIter) Key() []byte { + if !it.Valid() { + return nil + } + return it.kvs[it.pos].Key +} + +func (it *snapshotBatchedIter) Value() []byte { + if !it.Valid() { + return nil + } + return it.kvs[it.pos].Value +} + +func (it *snapshotBatchedIter) Close() { + it.kvs = nil + it.nextKey = nil +} diff --git a/internal/unionstore/memdb_bench_test.go b/internal/unionstore/memdb_bench_test.go index 8a2c3e5d4b..07fa6d3ce9 100644 --- a/internal/unionstore/memdb_bench_test.go +++ b/internal/unionstore/memdb_bench_test.go @@ 
-172,13 +172,36 @@ func BenchmarkMemDbBufferRandom(b *testing.B) { } func BenchmarkMemDbIter(b *testing.B) { - fn := func(b *testing.B, buffer MemBuffer) { + fnIter := func(b *testing.B, buffer MemBuffer) { benchIterator(b, buffer) b.ReportAllocs() } - b.Run("RBT", func(b *testing.B) { fn(b, newRbtDBWithContext()) }) - b.Run("ART", func(b *testing.B) { fn(b, newArtDBWithContext()) }) + b.Run("RBT", func(b *testing.B) { fnIter(b, newRbtDBWithContext()) }) + b.Run("ART", func(b *testing.B) { fnIter(b, newArtDBWithContext()) }) +} + +func BenchmarkSnapshotIter(b *testing.B) { + f := func(b *testing.B, buffer MemBuffer) { + benchSnapshotIter(b, buffer) + b.ReportAllocs() + } + + fBatched := func(b *testing.B, buffer MemBuffer) { + benchBatchedSnapshotIter(b, buffer) + b.ReportAllocs() + } + + fForEach := func(b *testing.B, buffer MemBuffer) { + benchForEachInSnapshot(b, buffer) + b.ReportAllocs() + } + + b.Run("RBT-SnapshotIter", func(b *testing.B) { f(b, newRbtDBWithContext()) }) + // unimplemented for RBT + b.Run("ART-SnapshotIter", func(b *testing.B) { f(b, newArtDBWithContext()) }) + b.Run("ART-BatchedSnapshotIter", func(b *testing.B) { fBatched(b, newArtDBWithContext()) }) + b.Run("ART-ForEachInSnapshot", func(b *testing.B) { fForEach(b, newArtDBWithContext()) }) } func BenchmarkMemDbCreation(b *testing.B) { @@ -224,6 +247,40 @@ func benchIterator(b *testing.B, buffer MemBuffer) { if err != nil { b.Error(err) } + for iter.Valid() { + _ = iter.Key() + _ = iter.Value() + iter.Next() + } + iter.Close() + } +} + +func benchSnapshotIter(b *testing.B, buffer MemBuffer) { + for k := 0; k < opCnt; k++ { + buffer.Set(encodeInt(k), encodeInt(k)) + } + buffer.Staging() + b.ResetTimer() + for i := 0; i < b.N; i++ { + iter := buffer.SnapshotIter(nil, nil) + for iter.Valid() { + _ = iter.Value() + _ = iter.Key() + iter.Next() + } + iter.Close() + } +} + +func benchBatchedSnapshotIter(b *testing.B, buffer MemBuffer) { + for k := 0; k < opCnt; k++ { + buffer.Set(encodeInt(k), 
encodeInt(k)) + } + buffer.Staging() + b.ResetTimer() + for i := 0; i < b.N; i++ { + iter := buffer.BatchedSnapshotIter(nil, nil, false) for iter.Valid() { iter.Next() } @@ -231,6 +288,23 @@ func benchIterator(b *testing.B, buffer MemBuffer) { } } +func benchForEachInSnapshot(b *testing.B, buffer MemBuffer) { + for k := 0; k < opCnt; k++ { + buffer.Set(encodeInt(k), encodeInt(k)) + } + buffer.Staging() + b.ResetTimer() + f := func(key, value []byte) (bool, error) { + return false, nil + } + for i := 0; i < b.N; i++ { + err := buffer.ForEachInSnapshotRange(nil, nil, f, false) + if err != nil { + b.Error(err) + } + } +} + func BenchmarkMemBufferCache(b *testing.B) { fn := func(b *testing.B, buffer MemBuffer) { buf := make([][keySize]byte, b.N) diff --git a/internal/unionstore/memdb_rbt.go b/internal/unionstore/memdb_rbt.go index c805f49935..f45f941c46 100644 --- a/internal/unionstore/memdb_rbt.go +++ b/internal/unionstore/memdb_rbt.go @@ -161,6 +161,32 @@ func (db *rbtDBWithContext) IterReverse(upper, lower []byte) (Iterator, error) { return db.RBT.IterReverse(upper, lower) } +func (db *rbtDBWithContext) ForEachInSnapshotRange(lower []byte, upper []byte, f func(k, v []byte) (stop bool, err error), reverse bool) error { + db.RLock() + defer db.RUnlock() + var iter Iterator + if reverse { + iter = db.SnapshotIterReverse(upper, lower) + } else { + iter = db.SnapshotIter(lower, upper) + } + defer iter.Close() + for iter.Valid() { + stop, err := f(iter.Key(), iter.Value()) + if err != nil { + return err + } + err = iter.Next() + if err != nil { + return err + } + if stop { + break + } + } + return nil +} + // SnapshotIter returns an Iterator for a snapshot of MemBuffer. 
func (db *rbtDBWithContext) SnapshotIter(lower, upper []byte) Iterator { return db.RBT.SnapshotIter(lower, upper) @@ -175,3 +201,12 @@ func (db *rbtDBWithContext) SnapshotIterReverse(upper, lower []byte) Iterator { func (db *rbtDBWithContext) SnapshotGetter() Getter { return db.RBT.SnapshotGetter() } + +func (db *rbtDBWithContext) BatchedSnapshotIter(lower, upper []byte, reverse bool) Iterator { + // TODO: implement this + if reverse { + return db.SnapshotIterReverse(upper, lower) + } else { + return db.SnapshotIter(lower, upper) + } +} diff --git a/internal/unionstore/memdb_test.go b/internal/unionstore/memdb_test.go index 4721f837e1..2481fb3373 100644 --- a/internal/unionstore/memdb_test.go +++ b/internal/unionstore/memdb_test.go @@ -1336,7 +1336,7 @@ func TestSnapshotReaderWithWrite(t *testing.T) { h := db.Staging() defer db.Release(h) - iter := db.SnapshotIter([]byte{0, 0}, []byte{0, 255}) + iter := db.BatchedSnapshotIter([]byte{0, 0}, []byte{0, 255}, false) assert.Equal(t, iter.Key(), []byte{0, 0}) db.Set([]byte{0, byte(num)}, []byte{0, byte(num)}) // ART: node4/node16/node48 is freed and wait to be reused. 
@@ -1364,3 +1364,230 @@ func TestSnapshotReaderWithWrite(t *testing.T) { check(newRbtDBWithContext(), 48) check(newArtDBWithContext(), 48) } + +func TestBatchedSnapshotIter(t *testing.T) { + check := func(db *artDBWithContext, num int) { + // Insert test data + for i := 0; i < num; i++ { + db.Set([]byte{0, byte(i)}, []byte{0, byte(i)}) + } + h := db.Staging() + defer db.Release(h) + + // Create iterator - should be positioned at first key + iter := db.BatchedSnapshotIter([]byte{0, 0}, []byte{0, 255}, false) + defer iter.Close() + + // Should be able to read first key immediately + require.True(t, iter.Valid()) + require.Equal(t, []byte{0, 0}, iter.Key()) + + // Write additional data + db.Set([]byte{0, byte(num)}, []byte{0, byte(num)}) + for i := 0; i < num; i++ { + db.Set([]byte{1, byte(i)}, []byte{1, byte(i)}) + } + + // Verify iteration + i := 0 + for ; i < num; i++ { + require.True(t, iter.Valid()) + require.Equal(t, []byte{0, byte(i)}, iter.Key()) + require.Equal(t, []byte{0, byte(i)}, iter.Value()) + require.NoError(t, iter.Next()) + } + require.False(t, iter.Valid()) + } + + checkReverse := func(db *artDBWithContext, num int) { + for i := 0; i < num; i++ { + db.Set([]byte{0, byte(i)}, []byte{0, byte(i)}) + } + h := db.Staging() + defer db.Release(h) + + iter := db.BatchedSnapshotIter([]byte{0, 0}, []byte{0, 255}, true) + defer iter.Close() + + // Should be positioned at last key + require.True(t, iter.Valid()) + require.Equal(t, []byte{0, byte(num - 1)}, iter.Key()) + + db.Set([]byte{0, byte(num)}, []byte{0, byte(num)}) + for i := 0; i < num; i++ { + db.Set([]byte{1, byte(i)}, []byte{1, byte(i)}) + } + + i := num - 1 + for ; i >= 0; i-- { + require.True(t, iter.Valid()) + require.Equal(t, []byte{0, byte(i)}, iter.Key()) + require.Equal(t, []byte{0, byte(i)}, iter.Value()) + require.NoError(t, iter.Next()) + } + require.False(t, iter.Valid()) + } + + // Run size test cases + check(newArtDBWithContext(), 3) + check(newArtDBWithContext(), 17) + 
check(newArtDBWithContext(), 64) + + checkReverse(newArtDBWithContext(), 3) + checkReverse(newArtDBWithContext(), 17) + checkReverse(newArtDBWithContext(), 64) +} + +func TestBatchedSnapshotIterEdgeCase(t *testing.T) { + t.Run("EdgeCases", func(t *testing.T) { + db := newArtDBWithContext() + + // invalid range - should be invalid immediately + iter := db.BatchedSnapshotIter([]byte{1}, []byte{1}, false) + require.False(t, iter.Valid()) + iter.Close() + + // empty range - should be invalid immediately + iter = db.BatchedSnapshotIter([]byte{0}, []byte{1}, false) + require.False(t, iter.Valid()) + iter.Close() + + // Single element range + db.Set([]byte{1}, []byte{1}) + iter = db.BatchedSnapshotIter([]byte{1}, []byte{2}, false) + require.True(t, iter.Valid()) + require.Equal(t, []byte{1}, iter.Key()) + require.NoError(t, iter.Next()) + require.False(t, iter.Valid()) + iter.Close() + + // Multiple elements + db.Set([]byte{2}, []byte{2}) + db.Set([]byte{3}, []byte{3}) + db.Set([]byte{4}, []byte{4}) + + // Forward iteration [2,4) + iter = db.BatchedSnapshotIter([]byte{2}, []byte{4}, false) + vals := []byte{} + for iter.Valid() { + vals = append(vals, iter.Key()[0]) + require.NoError(t, iter.Next()) + } + require.Equal(t, []byte{2, 3}, vals) + iter.Close() + + // Reverse iteration [2,4) + iter = db.BatchedSnapshotIter([]byte{2}, []byte{4}, true) + vals = []byte{} + for iter.Valid() { + vals = append(vals, iter.Key()[0]) + require.NoError(t, iter.Next()) + } + require.Equal(t, []byte{3, 2}, vals) + iter.Close() + }) + + t.Run("BoundaryTests", func(t *testing.T) { + db := newArtDBWithContext() + keys := [][]byte{ + {1, 0}, {1, 2}, {1, 4}, {1, 6}, {1, 8}, + } + for _, k := range keys { + db.Set(k, k) + } + + // lower bound included + iter := db.BatchedSnapshotIter([]byte{1, 2}, []byte{1, 9}, false) + vals := []byte{} + for iter.Valid() { + vals = append(vals, iter.Key()[1]) + require.NoError(t, iter.Next()) + } + require.Equal(t, []byte{2, 4, 6, 8}, vals) + iter.Close() + + 
// upper bound excluded + iter = db.BatchedSnapshotIter([]byte{1, 0}, []byte{1, 6}, false) + vals = []byte{} + for iter.Valid() { + vals = append(vals, iter.Key()[1]) + require.NoError(t, iter.Next()) + } + require.Equal(t, []byte{0, 2, 4}, vals) + iter.Close() + + // reverse + iter = db.BatchedSnapshotIter([]byte{1, 0}, []byte{1, 6}, true) + vals = []byte{} + for iter.Valid() { + vals = append(vals, iter.Key()[1]) + require.NoError(t, iter.Next()) + } + require.Equal(t, []byte{4, 2, 0}, vals) + iter.Close() + }) + + t.Run("AlphabeticalOrder", func(t *testing.T) { + db := newArtDBWithContext() + keys := [][]byte{ + {2}, + {2, 1}, + {2, 1, 1}, + {2, 1, 1, 1}, + } + for _, k := range keys { + db.Set(k, k) + } + + // forward + iter := db.BatchedSnapshotIter([]byte{2}, []byte{3}, false) + count := 0 + for iter.Valid() { + require.Equal(t, keys[count], iter.Key()) + require.NoError(t, iter.Next()) + count++ + } + require.Equal(t, len(keys), count) + iter.Close() + + // reverse + iter = db.BatchedSnapshotIter([]byte{2}, []byte{3}, true) + count = len(keys) - 1 + for iter.Valid() { + require.Equal(t, keys[count], iter.Key()) + require.NoError(t, iter.Next()) + count-- + } + require.Equal(t, -1, count) + iter.Close() + }) + + t.Run("BatchSizeGrowth", func(t *testing.T) { + db := newArtDBWithContext() + for i := 0; i < 100; i++ { + db.Set([]byte{3, byte(i)}, []byte{3, byte(i)}) + } + + // forward + iter := db.BatchedSnapshotIter([]byte{3, 0}, []byte{3, 255}, false) + count := 0 + for iter.Valid() { + require.Equal(t, []byte{3, byte(count)}, iter.Key()) + require.NoError(t, iter.Next()) + count++ + } + require.Equal(t, 100, count) + iter.Close() + + // reverse + iter = db.BatchedSnapshotIter([]byte{3, 0}, []byte{3, 255}, true) + count = 99 + for iter.Valid() { + require.Equal(t, []byte{3, byte(count)}, iter.Key()) + require.NoError(t, iter.Next()) + count-- + } + require.Equal(t, -1, count) + iter.Close() + }) +} diff --git a/internal/unionstore/pipelined_memdb.go 
b/internal/unionstore/pipelined_memdb.go index 163a289f4c..888f2ecffb 100644 --- a/internal/unionstore/pipelined_memdb.go +++ b/internal/unionstore/pipelined_memdb.go @@ -412,6 +412,10 @@ func (p *PipelinedMemDB) IterReverse([]byte, []byte) (Iterator, error) { return nil, errors.New("pipelined memdb does not support IterReverse") } +func (db *PipelinedMemDB) ForEachInSnapshotRange(lower []byte, upper []byte, f func(k, v []byte) (bool, error), reverse bool) error { + return errors.New("pipelined memdb does not support ForEachInSnapshotRange") +} + // SetEntrySizeLimit sets the size limit for each entry and total buffer. func (p *PipelinedMemDB) SetEntrySizeLimit(entryLimit, _ uint64) { p.entryLimit = entryLimit @@ -550,3 +554,7 @@ func (p *PipelinedMemDB) GetMetrics() Metrics { func (p *PipelinedMemDB) MemHookSet() bool { return p.memChangeHook != nil } + +func (p *PipelinedMemDB) BatchedSnapshotIter(lower, upper []byte, reverse bool) Iterator { + panic("BatchedSnapshotIter is not supported for PipelinedMemDB") +} diff --git a/internal/unionstore/union_store.go b/internal/unionstore/union_store.go index 1a5f1a36b9..b6beaad010 100644 --- a/internal/unionstore/union_store.go +++ b/internal/unionstore/union_store.go @@ -162,6 +162,11 @@ func (us *KVUnionStore) SetEntrySizeLimit(entryLimit, bufferLimit uint64) { us.memBuffer.SetEntrySizeLimit(entryLimit, bufferLimit) } +type KvPair struct { + Key []byte + Value []byte +} + // MemBuffer is an interface that stores mutations that written during transaction execution. // It now unifies MemDB and PipelinedMemDB. // The implementations should follow the transaction guarantees: @@ -193,15 +198,39 @@ type MemBuffer interface { Delete([]byte) error // DeleteWithFlags deletes the key k in the MemBuffer with flags. DeleteWithFlags([]byte, ...kv.FlagsOp) error + // Iter implements the Retriever interface. Iter([]byte, []byte) (Iterator, error) // IterReverse implements the Retriever interface. 
IterReverse([]byte, []byte) (Iterator, error) // SnapshotIter returns an Iterator for a snapshot of MemBuffer. + // Deprecated: use ForEachInSnapshotRange or BatchedSnapshotIter instead. SnapshotIter([]byte, []byte) Iterator // SnapshotIterReverse returns a reversed Iterator for a snapshot of MemBuffer. + // Deprecated: use ForEachInSnapshotRange or BatchedSnapshotIter instead. SnapshotIterReverse([]byte, []byte) Iterator - // SnapshotGetter returns a Getter for a snapshot of MemBuffer. + + // ForEachInSnapshotRange scans the key-value pairs in the stage[0] snapshot if it exists, + // otherwise it uses the current checkpoint as snapshot. + // + // NOTE: returned kv-pairs are only valid during the iteration. If you want to use them after the iteration, + // you need to make a copy. + // + // The method is protected by a RWLock to prevent potential iterator invalidation, i.e. + // you cannot modify the MemBuffer during the iteration. + // + // Use it when you need to scan the whole range, otherwise consider using BatchedSnapshotIter. + ForEachInSnapshotRange(lower []byte, upper []byte, f func(k, v []byte) (stop bool, err error), reverse bool) error + + // BatchedSnapshotIter iterates in batches to prevent iterator invalidation: + // It does not save any iterator state, instead it copies the keys and values to a buffer. + // It behaves like SnapshotIter, but it is safe to use the returned keys and values after the iteration. + // Use it when you need on-demand "next", otherwise consider using ForEachInSnapshotRange. + // + // The iterator becomes invalid after a membuffer vlog truncation operation. + BatchedSnapshotIter(lower, upper []byte, reverse bool) Iterator + + //SnapshotGetter returns a Getter for a snapshot of MemBuffer. SnapshotGetter() Getter // InspectStage iterates all buffered keys and values in MemBuffer. 
InspectStage(handle int, f func([]byte, kv.KeyFlags, []byte)) diff --git a/tikv/unionstore_export.go b/tikv/unionstore_export.go index 80ee88f42a..efbaf93f2c 100644 --- a/tikv/unionstore_export.go +++ b/tikv/unionstore_export.go @@ -60,3 +60,5 @@ type MemDBCheckpoint = unionstore.MemDBCheckpoint // Metrics is the metrics of unionstore. type Metrics = unionstore.Metrics + +type KvPair = unionstore.KvPair From c46d4903553c667c147888c30bcf809366f56f99 Mon Sep 17 00:00:00 2001 From: ekexium Date: Thu, 23 Jan 2025 16:21:15 +0800 Subject: [PATCH 2/6] fix for snapshot iter Signed-off-by: ekexium --- internal/unionstore/art/art_iterator.go | 6 ++++-- internal/unionstore/art/art_snapshot.go | 1 + internal/unionstore/memdb_test.go | 2 +- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/internal/unionstore/art/art_iterator.go b/internal/unionstore/art/art_iterator.go index ba7ff0b855..05a23e62a4 100644 --- a/internal/unionstore/art/art_iterator.go +++ b/internal/unionstore/art/art_iterator.go @@ -80,9 +80,11 @@ type Iterator struct { // only when seqNo == art.seqNo, the iterator is valid. seqNo int + // ignoreSeqNo is used to ignore the seqNo check, used for snapshot iter before its full deprecation. 
+ ignoreSeqNo bool } -func (it *Iterator) Valid() bool { return it.valid && it.seqNo == it.tree.SeqNo } +func (it *Iterator) Valid() bool { return it.valid && (it.seqNo == it.tree.SeqNo || it.ignoreSeqNo) } func (it *Iterator) Key() []byte { return it.currLeaf.GetKey() } func (it *Iterator) Flags() kv.KeyFlags { return it.currLeaf.GetKeyFlags() } func (it *Iterator) Value() []byte { @@ -106,7 +108,7 @@ func (it *Iterator) Next() error { // iterate is finished return errors.New("Art: iterator is finished") } - if it.seqNo != it.tree.SeqNo { + if !it.ignoreSeqNo && it.seqNo != it.tree.SeqNo { return errors.New(fmt.Sprintf("seqNo mismatch: iter=%d, art=%d", it.seqNo, it.tree.SeqNo)) } if it.currAddr == it.endAddr { diff --git a/internal/unionstore/art/art_snapshot.go b/internal/unionstore/art/art_snapshot.go index 454634b234..6b240367ac 100644 --- a/internal/unionstore/art/art_snapshot.go +++ b/internal/unionstore/art/art_snapshot.go @@ -49,6 +49,7 @@ func (t *ART) newSnapshotIterator(start, end []byte, desc bool) *SnapIter { if err != nil { panic(err) } + inner.ignoreSeqNo = true it := &SnapIter{ Iterator: inner, cp: t.getSnapshot(), diff --git a/internal/unionstore/memdb_test.go b/internal/unionstore/memdb_test.go index 2481fb3373..3261d12e1b 100644 --- a/internal/unionstore/memdb_test.go +++ b/internal/unionstore/memdb_test.go @@ -1336,7 +1336,7 @@ func TestSnapshotReaderWithWrite(t *testing.T) { h := db.Staging() defer db.Release(h) - iter := db.BatchedSnapshotIter([]byte{0, 0}, []byte{0, 255}, false) + iter := db.SnapshotIter([]byte{0, 0}, []byte{0, 255}) assert.Equal(t, iter.Key(), []byte{0, 0}) db.Set([]byte{0, byte(num)}, []byte{0, byte(num)}) // ART: node4/node16/node48 is freed and wait to be reused. 
From 4c16a14e44ad1904520d24f5bd3bdc7d91e10179 Mon Sep 17 00:00:00 2001 From: ekexium Date: Thu, 23 Jan 2025 19:30:05 +0800 Subject: [PATCH 3/6] fix initKeysAndMutations Signed-off-by: ekexium --- internal/unionstore/art/art_iterator.go | 23 +++++++++++++++++++---- txnkv/transaction/2pc.go | 7 ++++++- 2 files changed, 25 insertions(+), 5 deletions(-) diff --git a/internal/unionstore/art/art_iterator.go b/internal/unionstore/art/art_iterator.go index 05a23e62a4..ee825f9dd6 100644 --- a/internal/unionstore/art/art_iterator.go +++ b/internal/unionstore/art/art_iterator.go @@ -19,6 +19,9 @@ import ( "fmt" "sort" + "github.com/tikv/client-go/v2/internal/logutil" + "go.uber.org/zap" + "github.com/pkg/errors" "github.com/tikv/client-go/v2/internal/unionstore/arena" "github.com/tikv/client-go/v2/kv" @@ -84,7 +87,21 @@ type Iterator struct { ignoreSeqNo bool } -func (it *Iterator) Valid() bool { return it.valid && (it.seqNo == it.tree.SeqNo || it.ignoreSeqNo) } +func (it *Iterator) checkSeqNo() { + if it.seqNo != it.tree.SeqNo && !it.ignoreSeqNo { + logutil.BgLogger().Panic( + "seqNo mismatch", + zap.Int("it seqNo", it.seqNo), + zap.Int("art seqNo", it.tree.SeqNo), + zap.Stack("stack"), + ) + } +} + +func (it *Iterator) Valid() bool { + it.checkSeqNo() + return it.valid +} func (it *Iterator) Key() []byte { return it.currLeaf.GetKey() } func (it *Iterator) Flags() kv.KeyFlags { return it.currLeaf.GetKeyFlags() } func (it *Iterator) Value() []byte { @@ -108,9 +125,7 @@ func (it *Iterator) Next() error { // iterate is finished return errors.New("Art: iterator is finished") } - if !it.ignoreSeqNo && it.seqNo != it.tree.SeqNo { - return errors.New(fmt.Sprintf("seqNo mismatch: iter=%d, art=%d", it.seqNo, it.tree.SeqNo)) - } + it.checkSeqNo() if it.currAddr == it.endAddr { it.valid = false return nil diff --git a/txnkv/transaction/2pc.go b/txnkv/transaction/2pc.go index 6a739b3616..14307bf38e 100644 --- a/txnkv/transaction/2pc.go +++ b/txnkv/transaction/2pc.go @@ -559,6 +559,7 @@ 
func (c *twoPhaseCommitter) initKeysAndMutations(ctx context.Context) error { var err error var assertionError error + toUpdatePrewriteOnly := make([][]byte, 0) for it := memBuf.IterWithFlags(nil, nil); it.Valid(); err = it.Next() { _ = err key := it.Key() @@ -607,7 +608,7 @@ func (c *twoPhaseCommitter) initKeysAndMutations(ctx context.Context) error { // due to `Op_CheckNotExists` doesn't prewrite lock, so mark those keys should not be used in commit-phase. op = kvrpcpb.Op_CheckNotExists checkCnt++ - memBuf.UpdateFlags(key, kv.SetPrewriteOnly) + toUpdatePrewriteOnly = append(toUpdatePrewriteOnly, key) } else { if flags.HasNewlyInserted() { // The delete-your-write keys in pessimistic transactions, only lock needed keys and skip @@ -682,6 +683,10 @@ func (c *twoPhaseCommitter) initKeysAndMutations(ctx context.Context) error { } } + for _, key := range toUpdatePrewriteOnly { + memBuf.UpdateFlags(key, kv.SetPrewriteOnly) + } + if c.mutations.Len() == 0 { return nil } From e1a3b5aa286619b52831ac365025518b2836a2db Mon Sep 17 00:00:00 2001 From: ekexium Date: Thu, 23 Jan 2025 20:21:30 +0800 Subject: [PATCH 4/6] more checks Signed-off-by: ekexium --- internal/unionstore/art/art_iterator.go | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/internal/unionstore/art/art_iterator.go b/internal/unionstore/art/art_iterator.go index ee825f9dd6..381d13753c 100644 --- a/internal/unionstore/art/art_iterator.go +++ b/internal/unionstore/art/art_iterator.go @@ -102,9 +102,19 @@ func (it *Iterator) Valid() bool { it.checkSeqNo() return it.valid } -func (it *Iterator) Key() []byte { return it.currLeaf.GetKey() } -func (it *Iterator) Flags() kv.KeyFlags { return it.currLeaf.GetKeyFlags() } + +func (it *Iterator) Key() []byte { + it.checkSeqNo() + return it.currLeaf.GetKey() +} + +func (it *Iterator) Flags() kv.KeyFlags { + it.checkSeqNo() + return it.currLeaf.GetKeyFlags() +} + func (it *Iterator) Value() []byte { + it.checkSeqNo() if 
it.currLeaf.vLogAddr.IsNull() { return nil } From 29fc98ea5f26a83902c15bdadacbd293536dd161 Mon Sep 17 00:00:00 2001 From: ekexium Date: Fri, 24 Jan 2025 12:50:35 +0800 Subject: [PATCH 5/6] optimize batched iter Signed-off-by: ekexium --- internal/unionstore/memdb_art.go | 55 +++++++++++++++++++++----------- 1 file changed, 36 insertions(+), 19 deletions(-) diff --git a/internal/unionstore/memdb_art.go b/internal/unionstore/memdb_art.go index ebd63edce8..4598a99564 100644 --- a/internal/unionstore/memdb_art.go +++ b/internal/unionstore/memdb_art.go @@ -203,7 +203,8 @@ type snapshotBatchedIter struct { reverse bool // current batch - kvs []KvPair + keys [][]byte + values [][]byte pos int batchSize int nextKey []byte @@ -216,10 +217,9 @@ func (db *artDBWithContext) BatchedSnapshotIter(lower, upper []byte, reverse boo lower: lower, upper: upper, reverse: reverse, - batchSize: 4, + batchSize: 32, } - // Position at first key immediately iter.fillBatch() return iter } @@ -233,10 +233,12 @@ func (it *snapshotBatchedIter) fillBatch() error { it.db.RLock() defer it.db.RUnlock() - if it.kvs == nil { - it.kvs = make([]KvPair, 0, it.batchSize) + if it.keys == nil || it.values == nil || cap(it.keys) < it.batchSize || cap(it.values) < it.batchSize { + it.keys = make([][]byte, 0, it.batchSize) + it.values = make([][]byte, 0, it.batchSize) } else { - it.kvs = it.kvs[:0] + it.keys = it.keys[:0] + it.values = it.values[:0] } var snapshotIter Iterator @@ -256,11 +258,12 @@ func (it *snapshotBatchedIter) fillBatch() error { defer snapshotIter.Close() // fill current batch + // Further optimization: let the underlying memdb support batch iter. 
for i := 0; i < it.batchSize && snapshotIter.Valid(); i++ { - it.kvs = append(it.kvs, KvPair{ - Key: snapshotIter.Key(), - Value: snapshotIter.Value(), - }) + it.keys = it.keys[:i+1] + it.values = it.values[:i+1] + it.keys[i] = snapshotIter.Key() + it.values[i] = snapshotIter.Value() if err := snapshotIter.Next(); err != nil { return err } @@ -268,12 +271,25 @@ func (it *snapshotBatchedIter) fillBatch() error { // update state it.pos = 0 - if len(it.kvs) > 0 { - lastKV := it.kvs[len(it.kvs)-1] + if len(it.keys) > 0 { + lastKey := it.keys[len(it.keys)-1] + keyLen := len(lastKey) + if it.reverse { - it.nextKey = append([]byte(nil), lastKV.Key...) + if cap(it.nextKey) >= keyLen { + it.nextKey = it.nextKey[:keyLen] + } else { + it.nextKey = make([]byte, keyLen) + } + copy(it.nextKey, lastKey) } else { - it.nextKey = append(append([]byte(nil), lastKV.Key...), 0) + if cap(it.nextKey) >= keyLen+1 { + it.nextKey = it.nextKey[:keyLen+1] + } else { + it.nextKey = make([]byte, keyLen+1) + } + copy(it.nextKey, lastKey) + it.nextKey[keyLen] = 0 } } else { it.nextKey = nil @@ -285,7 +301,7 @@ func (it *snapshotBatchedIter) fillBatch() error { func (it *snapshotBatchedIter) Valid() bool { return it.snapshotTruncateSeqNo == it.db.SnapshotSeqNo && - it.pos < len(it.kvs) + it.pos < len(it.keys) } func (it *snapshotBatchedIter) Next() error { @@ -300,7 +316,7 @@ func (it *snapshotBatchedIter) Next() error { } it.pos++ - if it.pos >= len(it.kvs) { + if it.pos >= len(it.keys) { return it.fillBatch() } return nil @@ -310,17 +326,18 @@ func (it *snapshotBatchedIter) Key() []byte { if !it.Valid() { return nil } - return it.kvs[it.pos].Key + return it.keys[it.pos] } func (it *snapshotBatchedIter) Value() []byte { if !it.Valid() { return nil } - return it.kvs[it.pos].Value + return it.values[it.pos] } func (it *snapshotBatchedIter) Close() { - it.kvs = nil + it.keys = nil + it.values = nil it.nextKey = nil } From e906cd537b58476bc4544c941a9c164124347875 Mon Sep 17 00:00:00 2001 From: 
ekexium Date: Fri, 24 Jan 2025 14:14:21 +0800 Subject: [PATCH 6/6] refine comment Signed-off-by: ekexium --- internal/unionstore/memdb_art.go | 5 +++++ internal/unionstore/union_store.go | 14 ++++++++++---- 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/internal/unionstore/memdb_art.go b/internal/unionstore/memdb_art.go index 4598a99564..306a4d6d34 100644 --- a/internal/unionstore/memdb_art.go +++ b/internal/unionstore/memdb_art.go @@ -19,6 +19,8 @@ import ( "fmt" "sync" + "github.com/tikv/client-go/v2/internal/logutil" + "github.com/pingcap/errors" tikverr "github.com/tikv/client-go/v2/error" @@ -211,6 +213,9 @@ type snapshotBatchedIter struct { } func (db *artDBWithContext) BatchedSnapshotIter(lower, upper []byte, reverse bool) Iterator { + if len(db.Stages()) == 0 { + logutil.BgLogger().Error("should not use BatchedSnapshotIter for a memdb without any staging buffer") + } iter := &snapshotBatchedIter{ db: db, snapshotTruncateSeqNo: db.SnapshotSeqNo, diff --git a/internal/unionstore/union_store.go b/internal/unionstore/union_store.go index b6beaad010..6127101863 100644 --- a/internal/unionstore/union_store.go +++ b/internal/unionstore/union_store.go @@ -200,8 +200,12 @@ type MemBuffer interface { DeleteWithFlags([]byte, ...kv.FlagsOp) error // Iter implements the Retriever interface. + // Any write operation to the memdb invalidates this iterator immediately after its creation. + // Attempting to use such an invalidated iterator will result in a panic. Iter([]byte, []byte) (Iterator, error) // IterReverse implements the Retriever interface. + // Any write operation to the memdb invalidates this iterator immediately after its creation. + // Attempting to use such an invalidated iterator will result in a panic. IterReverse([]byte, []byte) (Iterator, error) // SnapshotIter returns an Iterator for a snapshot of MemBuffer. // Deprecated: use ForEachInSnapshotRange or BatchedSnapshotIter instead. 
@@ -222,12 +226,14 @@ type MemBuffer interface { // Use it when you need to scan the whole range, otherwise consider using BatchedSnapshotIter. ForEachInSnapshotRange(lower []byte, upper []byte, f func(k, v []byte) (stop bool, err error), reverse bool) error - // BatchedSnapshotIter iterates in batches to prevent iterator invalidation: - // It does not save any iterator state, instead it copies the keys and values to a buffer. - // It behaves like SnapshotIter, but it is safe to use the returned keys and values after the iteration. + // BatchedSnapshotIter returns an iterator of the "snapshot", namely stage[0]. + // It iterates in batches and prevents iterator invalidation. + // // Use it when you need on-demand "next", otherwise consider using ForEachInSnapshotRange. + // NOTE: you should never use it when there are no stages. // - // The iterator becomes invalid after a membuffer vlog truncation operation. + // The iterator becomes invalid after any operation that may modify the "snapshot", + // e.g. RevertToCheckpoint or releasing stage[0]. BatchedSnapshotIter(lower, upper []byte, reverse bool) Iterator //SnapshotGetter returns a Getter for a snapshot of MemBuffer.