From 4621fd80704065ca83a4c96d29e1a6933d25dcb2 Mon Sep 17 00:00:00 2001 From: Faye Amacker <33205765+fxamacker@users.noreply.github.com> Date: Thu, 29 Feb 2024 19:19:40 -0600 Subject: [PATCH] Fix data race in account grouping --- cmd/util/ledger/util/payload_grouping.go | 14 +++--- cmd/util/ledger/util/payload_grouping_test.go | 44 +++++++++++++++++++ 2 files changed, 51 insertions(+), 7 deletions(-) diff --git a/cmd/util/ledger/util/payload_grouping.go b/cmd/util/ledger/util/payload_grouping.go index 328fbedabc2..9f368abc45f 100644 --- a/cmd/util/ledger/util/payload_grouping.go +++ b/cmd/util/ledger/util/payload_grouping.go @@ -112,7 +112,7 @@ func GroupPayloadsByAccount( indexes := make([]int, 0, estimatedNumOfAccount) for i := 0; i < len(p); { indexes = append(indexes, i) - i = p.FindNextKeyIndex(i) + i = p.FindNextKeyIndexUntil(i, len(p)) } end = time.Now() @@ -177,17 +177,17 @@ func (s sortablePayloads) Swap(i, j int) { s[i], s[j] = s[j], s[i] } -func (s sortablePayloads) FindNextKeyIndex(i int) int { +func (s sortablePayloads) FindNextKeyIndexUntil(i int, upperBound int) int { low := i step := 1 - for low+step < len(s) && s.Compare(low+step, i) == 0 { + for low+step < upperBound && s.Compare(low+step, i) == 0 { low += step step *= 2 } high := low + step - if high > len(s) { - high = len(s) + if high > upperBound { + high = upperBound } for low < high { @@ -248,13 +248,13 @@ func mergeInto(source, buffer sortablePayloads, i int, mid int, j int) { // More elements in the both partitions to process. if source.Compare(left, right) <= 0 { // Move left partition elements with the same address to buffer. - nextLeft := source.FindNextKeyIndex(left) + nextLeft := source.FindNextKeyIndexUntil(left, mid) n := copy(buffer[k:], source[left:nextLeft]) left = nextLeft k += n } else { // Move right partition elements with the same address to buffer. - nextRight := source.FindNextKeyIndex(right) + nextRight := source.FindNextKeyIndexUntil(right, j) n := copy(buffer[k:], source[right:nextRight]) right = nextRight k += n diff --git a/cmd/util/ledger/util/payload_grouping_test.go b/cmd/util/ledger/util/payload_grouping_test.go index 96b50bd4e5b..9ab7392e5e6 100644 --- a/cmd/util/ledger/util/payload_grouping_test.go +++ b/cmd/util/ledger/util/payload_grouping_test.go @@ -29,6 +29,20 @@ func TestGroupPayloadsByAccount(t *testing.T) { require.Greater(t, groups.Len(), 1) } +func TestGroupPayloadsByAccountForDataRace(t *testing.T) { + log := zerolog.New(zerolog.NewTestWriter(t)) + + const accountSize = 4 + var payloads []*ledger.Payload + for i := 0; i < accountSize; i++ { + payloads = append(payloads, generateRandomPayloadsWithAddress(generateRandomAddress(), 100_000)...) + } + + const nWorkers = 8 + groups := util.GroupPayloadsByAccount(log, payloads, nWorkers) + require.Equal(t, accountSize, groups.Len()) +} + func TestGroupPayloadsByAccountCompareResults(t *testing.T) { log := zerolog.Nop() payloads := generateRandomPayloads(1000000) @@ -129,6 +143,36 @@ func generateRandomPayloads(n int) []*ledger.Payload { return payloads } +func generateRandomPayloadsWithAddress(address string, n int) []*ledger.Payload { + const meanPayloadsPerAccount = 100 + const minPayloadsPerAccount = 1 + + payloads := make([]*ledger.Payload, 0, n) + + for i := 0; i < n; { + + registersForAccount := minPayloadsPerAccount + int(rand2.ExpFloat64()*(meanPayloadsPerAccount-minPayloadsPerAccount)) + if registersForAccount > n-i { + registersForAccount = n - i + } + i += registersForAccount + + accountKey := convert.RegisterIDToLedgerKey(flow.RegisterID{ + Owner: address, + Key: generateRandomString(10), + }) + for j := 0; j < registersForAccount; j++ { + payloads = append(payloads, + ledger.NewPayload( + accountKey, + []byte(generateRandomString(10)), + )) + } + } + + return payloads +} + func generateRandomAccountKey() ledger.Key { return convert.RegisterIDToLedgerKey(flow.RegisterID{ Owner: generateRandomAddress(),