diff --git a/cmd/util/cmd/execution-state-extract/cmd.go b/cmd/util/cmd/execution-state-extract/cmd.go
index 55728b428a8..5cfeea0312f 100644
--- a/cmd/util/cmd/execution-state-extract/cmd.go
+++ b/cmd/util/cmd/execution-state-extract/cmd.go
@@ -2,11 +2,16 @@ package extract

 import (
 	"encoding/hex"
+	"fmt"
+	"os"
 	"path"
+	"strings"

 	"github.com/rs/zerolog/log"
 	"github.com/spf13/cobra"

+	runtimeCommon "github.com/onflow/cadence/runtime/common"
+
 	"github.com/onflow/flow-go/cmd/util/cmd/common"
 	"github.com/onflow/flow-go/model/bootstrap"
 	"github.com/onflow/flow-go/model/flow"
@@ -26,6 +31,9 @@ var (
 	flagNoReport                  bool
 	flagValidateMigration         bool
 	flagLogVerboseValidationError bool
+	flagInputPayloadFileName      string
+	flagOutputPayloadFileName     string
+	flagOutputPayloadByAddresses  string
 )

 var Cmd = &cobra.Command{
@@ -68,6 +76,35 @@ func init() {
 	Cmd.Flags().BoolVar(&flagLogVerboseValidationError, "log-verbose-validation-error", false,
 		"log entire Cadence values on validation error (atree migration)")

+	// If specified, the state will consist of payloads from the given input payload file.
+	// If not specified, the state will be extracted from the latest checkpoint file.
+	// This flag can reduce the total duration of state extraction when multiple migrations
+	// are involved, because it avoids repeatedly reading the checkpoint file to rebuild the trie.
+	// The input payload file must be created by a state extraction run with either
+	// flagOutputPayloadFileName or flagOutputPayloadByAddresses set.
+	Cmd.Flags().StringVar(
+		&flagInputPayloadFileName,
+		"input-payload-filename",
+		"",
+		"input payload file",
+	)
+
+	Cmd.Flags().StringVar(
+		&flagOutputPayloadFileName,
+		"output-payload-filename",
+		"",
+		"output payload file",
+	)
+
+	// Extract payloads of the specified addresses (comma-separated list of hex-encoded addresses)
+	// to the file specified by --output-payload-filename.
+	// If no address is specified (empty string), this flag is ignored.
+	Cmd.Flags().StringVar(
+		&flagOutputPayloadByAddresses,
+		"extract-payloads-by-address",
+		"",
+		"extract payloads of addresses (comma-separated hex-encoded addresses) to the file specified by --output-payload-filename",
+	)
 }

 func run(*cobra.Command, []string) {
@@ -78,6 +115,19 @@ func run(*cobra.Command, []string) {
 		return
 	}

+	if len(flagBlockHash) == 0 && len(flagStateCommitment) == 0 && len(flagInputPayloadFileName) == 0 {
+		log.Fatal().Msg("--block-hash or --state-commitment or --input-payload-filename must be specified")
+	}
+
+	if len(flagInputPayloadFileName) > 0 && (len(flagBlockHash) > 0 || len(flagStateCommitment) > 0) {
+		log.Fatal().Msg("--input-payload-filename cannot be used with --block-hash or --state-commitment")
+	}
+
+	// When flagOutputPayloadByAddresses is specified, flagOutputPayloadFileName is required.
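+	// For example (values are illustrative):
+	//   --output-payload-filename /var/flow/payloads.bin \
+	//   --extract-payloads-by-address e544175ee0461c4b,e467b9dd11fa00df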
+ if len(flagOutputPayloadFileName) == 0 && len(flagOutputPayloadByAddresses) > 0 { + log.Fatal().Msg("--extract-payloads-by-address requires --output-payload-filename to be specified") + } + if len(flagBlockHash) > 0 { blockID, err := flow.HexStringToIdentifier(flagBlockHash) if err != nil { @@ -112,20 +162,38 @@ func run(*cobra.Command, []string) { log.Info().Msgf("extracting state by state commitment: %x", stateCommitment) } - if len(flagBlockHash) == 0 && len(flagStateCommitment) == 0 { - log.Fatal().Msg("no --block-hash or --state-commitment was specified") + if len(flagInputPayloadFileName) > 0 { + if _, err := os.Stat(flagInputPayloadFileName); os.IsNotExist(err) { + log.Fatal().Msgf("payload input file %s doesn't exist", flagInputPayloadFileName) + } } - log.Info().Msgf("Extracting state from %s, exporting root checkpoint to %s, version: %v", - flagExecutionStateDir, - path.Join(flagOutputDir, bootstrap.FilenameWALRootCheckpoint), - 6, - ) + if len(flagOutputPayloadFileName) > 0 { + if _, err := os.Stat(flagOutputPayloadFileName); os.IsExist(err) { + log.Fatal().Msgf("payload output file %s exists", flagOutputPayloadFileName) + } + } + + var exportedAddresses []runtimeCommon.Address + + if len(flagOutputPayloadByAddresses) > 0 { + + addresses := strings.Split(flagOutputPayloadByAddresses, ",") - log.Info().Msgf("Block state commitment: %s from %v, output dir: %s", - hex.EncodeToString(stateCommitment[:]), - flagExecutionStateDir, - flagOutputDir) + for _, hexAddr := range addresses { + b, err := hex.DecodeString(strings.TrimSpace(hexAddr)) + if err != nil { + log.Fatal().Err(err).Msgf("cannot hex decode address %s for payload export", strings.TrimSpace(hexAddr)) + } + + addr, err := runtimeCommon.BytesToAddress(b) + if err != nil { + log.Fatal().Err(err).Msgf("cannot decode address %x for payload export", b) + } + + exportedAddresses = append(exportedAddresses, addr) + } + } // err := ensureCheckpointFileExist(flagExecutionStateDir) // if err != nil { @@ -148,14 +216,65 @@ func run(*cobra.Command, []string) { log.Warn().Msgf("atree migration has verbose validation error logging enabled which may increase size of log") } - err := extractExecutionState( - log.Logger, - flagExecutionStateDir, - stateCommitment, - flagOutputDir, - flagNWorker, - !flagNoMigration, - ) + var inputMsg string + if len(flagInputPayloadFileName) > 0 { + // Input is payloads + inputMsg = fmt.Sprintf("reading payloads from %s", flagInputPayloadFileName) + } else { + // Input is execution state + inputMsg = fmt.Sprintf("reading block state commitment %s from %s", + hex.EncodeToString(stateCommitment[:]), + flagExecutionStateDir, + ) + } + + var outputMsg string + if len(flagOutputPayloadFileName) > 0 { + // Output is payload file + if len(exportedAddresses) == 0 { + outputMsg = fmt.Sprintf("exporting all payloads to %s", flagOutputPayloadFileName) + } else { + outputMsg = fmt.Sprintf( + "exporting payloads by addresses %v to %s", + flagOutputPayloadByAddresses, + flagOutputPayloadFileName, + ) + } + } else { + // Output is checkpoint files + outputMsg = fmt.Sprintf( + "exporting root checkpoint to %s, version: %d", + path.Join(flagOutputDir, bootstrap.FilenameWALRootCheckpoint), + 6, + ) + } + + log.Info().Msgf("state extraction plan: %s, %s", inputMsg, outputMsg) + + var err error + if len(flagInputPayloadFileName) > 0 { + err = extractExecutionStateFromPayloads( + log.Logger, + flagExecutionStateDir, + flagOutputDir, + flagNWorker, + !flagNoMigration, + flagInputPayloadFileName, + flagOutputPayloadFileName, + 
exportedAddresses, + ) + } else { + err = extractExecutionState( + log.Logger, + flagExecutionStateDir, + stateCommitment, + flagOutputDir, + flagNWorker, + !flagNoMigration, + flagOutputPayloadFileName, + exportedAddresses, + ) + } if err != nil { log.Fatal().Err(err).Msgf("error extracting the execution state: %s", err.Error()) diff --git a/cmd/util/cmd/execution-state-extract/execution_state_extract.go b/cmd/util/cmd/execution-state-extract/execution_state_extract.go index b2146878898..945274c6be3 100644 --- a/cmd/util/cmd/execution-state-extract/execution_state_extract.go +++ b/cmd/util/cmd/execution-state-extract/execution_state_extract.go @@ -5,12 +5,15 @@ import ( "fmt" "math" "os" + "time" + "github.com/onflow/cadence/runtime/common" "github.com/rs/zerolog" "go.uber.org/atomic" migrators "github.com/onflow/flow-go/cmd/util/ledger/migrations" "github.com/onflow/flow-go/cmd/util/ledger/reporters" + "github.com/onflow/flow-go/cmd/util/ledger/util" "github.com/onflow/flow-go/ledger" "github.com/onflow/flow-go/ledger/common/hash" "github.com/onflow/flow-go/ledger/common/pathfinder" @@ -34,6 +37,8 @@ func extractExecutionState( outputDir string, nWorker int, // number of concurrent worker to migation payloads runMigrations bool, + outputPayloadFile string, + exportPayloadsByAddresses []common.Address, ) error { log.Info().Msg("init WAL") @@ -84,30 +89,7 @@ func extractExecutionState( <-compactor.Done() }() - var migrations []ledger.Migration - - if runMigrations { - rwf := reporters.NewReportFileWriterFactory(dir, log) - - migrations = []ledger.Migration{ - migrators.CreateAccountBasedMigration( - log, - nWorker, - []migrators.AccountBasedMigration{ - migrators.NewAtreeRegisterMigrator( - rwf, - flagValidateMigration, - flagLogVerboseValidationError, - ), - - &migrators.DeduplicateContractNamesMigration{}, - - // This will fix storage used discrepancies caused by the - // DeduplicateContractNamesMigration. 
-					&migrators.AccountUsageMigrator{},
-				}),
-		}
-	}
+	migrations := newMigrations(log, dir, nWorker, runMigrations)

 	newState := ledger.State(targetHash)

@@ -134,6 +116,25 @@ func extractExecutionState(
 		log.Error().Err(err).Msgf("can not generate report for migrated state: %v", newMigratedState)
 	}

+	exportPayloads := len(outputPayloadFile) > 0
+	if exportPayloads {
+		payloads := newTrie.AllPayloads()
+
+		exportedPayloadCount, err := util.CreatePayloadFile(
+			log,
+			outputPayloadFile,
+			payloads,
+			exportPayloadsByAddresses,
+		)
+		if err != nil {
+			return fmt.Errorf("cannot generate payloads file: %w", err)
+		}
+
+		log.Info().Msgf("Exported %d payloads out of %d payloads", exportedPayloadCount, len(payloads))
+
+		return nil
+	}
+
 	migratedState, err := createCheckpoint(
 		newTrie,
 		log,
@@ -191,3 +192,167 @@ func writeStatusFile(fileName string, e error) error {
 	err := os.WriteFile(fileName, checkpointStatusJson, 0644)
 	return err
 }
+
+func extractExecutionStateFromPayloads(
+	log zerolog.Logger,
+	dir string,
+	outputDir string,
+	nWorker int, // number of concurrent workers to migrate payloads
+	runMigrations bool,
+	inputPayloadFile string,
+	outputPayloadFile string,
+	exportPayloadsByAddresses []common.Address,
+) error {
+
+	payloads, err := util.ReadPayloadFile(log, inputPayloadFile)
+	if err != nil {
+		return err
+	}
+
+	log.Info().Msgf("read %d payloads", len(payloads))
+
+	migrations := newMigrations(log, dir, nWorker, runMigrations)
+
+	payloads, err = migratePayloads(log, payloads, migrations)
+	if err != nil {
+		return err
+	}
+
+	exportPayloads := len(outputPayloadFile) > 0
+	if exportPayloads {
+		exportedPayloadCount, err := util.CreatePayloadFile(
+			log,
+			outputPayloadFile,
+			payloads,
+			exportPayloadsByAddresses,
+		)
+		if err != nil {
+			return fmt.Errorf("cannot generate payloads file: %w", err)
+		}
+
+		log.Info().Msgf("Exported %d payloads out of %d payloads", exportedPayloadCount, len(payloads))
+
+		return nil
+	}
+
+	newTrie, err := createTrieFromPayloads(log, payloads)
+	if err != nil {
+		return err
+	}
+
+	migratedState, err := createCheckpoint(
+		newTrie,
+		log,
+		outputDir,
+		bootstrap.FilenameWALRootCheckpoint,
+	)
+	if err != nil {
+		return fmt.Errorf("cannot generate the output checkpoint: %w", err)
+	}
+
+	log.Info().Msgf(
+		"New state commitment for the exported state is: %s (base64: %s)",
+		migratedState.String(),
+		migratedState.Base64(),
+	)
+
+	return nil
+}
+
+func migratePayloads(logger zerolog.Logger, payloads []*ledger.Payload, migrations []ledger.Migration) ([]*ledger.Payload, error) {
+
+	if len(migrations) == 0 {
+		return payloads, nil
+	}
+
+	var err error
+	payloadCount := len(payloads)
+
+	// migrate payloads
+	for i, migrate := range migrations {
+		logger.Info().Msgf("migration %d/%d is underway", i, len(migrations))
+
+		start := time.Now()
+		payloads, err = migrate(payloads)
+		elapsed := time.Since(start)
+
+		if err != nil {
+			return nil, fmt.Errorf("error applying migration (%d): %w", i, err)
+		}
+
+		newPayloadCount := len(payloads)
+
+		if payloadCount != newPayloadCount {
+			logger.Warn().
+				Int("migration_step", i).
+				Int("expected_size", payloadCount).
+				Int("outcome_size", newPayloadCount).
+				Msg("payload count has changed during migration, make sure this is expected.")
+		}
+		logger.Info().Str("timeTaken", elapsed.String()).Msgf("migration %d is done", i)
+
+		payloadCount = newPayloadCount
+	}
+
+	return payloads, nil
+}
+
+func createTrieFromPayloads(logger zerolog.Logger, payloads []*ledger.Payload) (*trie.MTrie, error) {
+	// get paths
+	paths, err := pathfinder.PathsFromPayloads(payloads, complete.DefaultPathFinderVersion)
+	if err != nil {
+		return nil, fmt.Errorf("cannot export checkpoint, can't construct paths: %w", err)
+	}
+
+	logger.Info().Msgf("constructing a new trie with migrated payloads (count: %d)...", len(payloads))
+
+	emptyTrie := trie.NewEmptyMTrie()
+
+	derefPayloads := make([]ledger.Payload, len(payloads))
+	for i, p := range payloads {
+		derefPayloads[i] = *p
+	}
+
+	// no need to prune the data since it has already been pruned through migrations
+	applyPruning := false
+	newTrie, _, err := trie.NewTrieWithUpdatedRegisters(emptyTrie, paths, derefPayloads, applyPruning)
+	if err != nil {
+		return nil, fmt.Errorf("constructing updated trie failed: %w", err)
+	}
+
+	return newTrie, nil
+}
+
+func newMigrations(
+	log zerolog.Logger,
+	dir string,
+	nWorker int, // number of concurrent workers to migrate payloads
+	runMigrations bool,
+) []ledger.Migration {
+	if runMigrations {
+		rwf := reporters.NewReportFileWriterFactory(dir, log)
+
+		migrations := []ledger.Migration{
+			migrators.CreateAccountBasedMigration(
+				log,
+				nWorker,
+				[]migrators.AccountBasedMigration{
+					migrators.NewAtreeRegisterMigrator(
+						rwf,
+						flagValidateMigration,
+						flagLogVerboseValidationError,
+					),
+
+					&migrators.DeduplicateContractNamesMigration{},
+
+					// This will fix storage used discrepancies caused by the
+					// DeduplicateContractNamesMigration.
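+					// It does so by recomputing each account's storage_used
+					// register from the account's post-migration payloads.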
+ &migrators.AccountUsageMigrator{}, + }), + } + + return migrations + } + + return nil +} diff --git a/cmd/util/cmd/execution-state-extract/execution_state_extract_test.go b/cmd/util/cmd/execution-state-extract/execution_state_extract_test.go index 70f8ca6bc89..882c88df898 100644 --- a/cmd/util/cmd/execution-state-extract/execution_state_extract_test.go +++ b/cmd/util/cmd/execution-state-extract/execution_state_extract_test.go @@ -2,14 +2,20 @@ package extract import ( "crypto/rand" + "encoding/hex" "math" + "path/filepath" + "strings" "testing" "github.com/rs/zerolog" "github.com/stretchr/testify/require" "go.uber.org/atomic" + runtimeCommon "github.com/onflow/cadence/runtime/common" + "github.com/onflow/flow-go/cmd/util/cmd/common" + "github.com/onflow/flow-go/cmd/util/ledger/util" "github.com/onflow/flow-go/ledger" "github.com/onflow/flow-go/ledger/common/pathfinder" "github.com/onflow/flow-go/ledger/complete" @@ -66,6 +72,8 @@ func TestExtractExecutionState(t *testing.T) { outdir, 10, false, + "", + nil, ) require.Error(t, err) }) @@ -96,7 +104,7 @@ func TestExtractExecutionState(t *testing.T) { var stateCommitment = f.InitialState() - //saved data after updates + // saved data after updates keysValuesByCommit := make(map[string]map[string]keyPair) commitsByBlocks := make(map[flow.Identifier]ledger.State) blocksInOrder := make([]flow.Identifier, size) @@ -108,7 +116,7 @@ func TestExtractExecutionState(t *testing.T) { require.NoError(t, err) stateCommitment, _, err = f.Set(update) - //stateCommitment, err = f.UpdateRegisters(keys, values, stateCommitment) + // stateCommitment, err = f.UpdateRegisters(keys, values, stateCommitment) require.NoError(t, err) // generate random block and map it to state commitment @@ -135,13 +143,13 @@ func TestExtractExecutionState(t *testing.T) { err = db.Close() require.NoError(t, err) - //for blockID, stateCommitment := range commitsByBlocks { + // for blockID, stateCommitment := range commitsByBlocks { for i, blockID := range blocksInOrder { stateCommitment := commitsByBlocks[blockID] - //we need fresh output dir to prevent contamination + // we need fresh output dir to prevent contamination unittest.RunWithTempDir(t, func(outdir string) { Cmd.SetArgs([]string{ @@ -182,7 +190,7 @@ func TestExtractExecutionState(t *testing.T) { require.NoError(t, err) registerValues, err := storage.Get(query) - //registerValues, err := mForest.Read([]byte(stateCommitment), keys) + // registerValues, err := mForest.Read([]byte(stateCommitment), keys) require.NoError(t, err) for i, key := range keys { @@ -190,7 +198,7 @@ func TestExtractExecutionState(t *testing.T) { require.Equal(t, data[key.String()].value, registerValue) } - //make sure blocks after this one are not in checkpoint + // make sure blocks after this one are not in checkpoint // ie - extraction stops after hitting right hash for j := i + 1; j < len(blocksInOrder); j++ { @@ -207,6 +215,339 @@ func TestExtractExecutionState(t *testing.T) { }) } +// TestExtractPayloadsFromExecutionState tests state extraction with checkpoint as input and payload as output. +func TestExtractPayloadsFromExecutionState(t *testing.T) { + metr := &metrics.NoopCollector{} + + const payloadFileName = "root.payload" + + t.Run("all payloads", func(t *testing.T) { + withDirs(t, func(_, execdir, outdir string) { + + const ( + checkpointDistance = math.MaxInt // A large number to prevent checkpoint creation. 
+ checkpointsToKeep = 1 + ) + + outputPayloadFileName := filepath.Join(outdir, payloadFileName) + + size := 10 + + diskWal, err := wal.NewDiskWAL(zerolog.Nop(), nil, metrics.NewNoopCollector(), execdir, size, pathfinder.PathByteSize, wal.SegmentSize) + require.NoError(t, err) + f, err := complete.NewLedger(diskWal, size*10, metr, zerolog.Nop(), complete.DefaultPathFinderVersion) + require.NoError(t, err) + compactor, err := complete.NewCompactor(f, diskWal, zerolog.Nop(), uint(size), checkpointDistance, checkpointsToKeep, atomic.NewBool(false), &metrics.NoopCollector{}) + require.NoError(t, err) + <-compactor.Ready() + + var stateCommitment = f.InitialState() + + // Save generated data after updates + keysValues := make(map[string]keyPair) + + for i := 0; i < size; i++ { + keys, values := getSampleKeyValues(i) + + update, err := ledger.NewUpdate(stateCommitment, keys, values) + require.NoError(t, err) + + stateCommitment, _, err = f.Set(update) + require.NoError(t, err) + + for j, key := range keys { + keysValues[key.String()] = keyPair{ + key: key, + value: values[j], + } + } + } + + <-f.Done() + <-compactor.Done() + + tries, err := f.Tries() + require.NoError(t, err) + + err = wal.StoreCheckpointV6SingleThread(tries, execdir, "checkpoint.00000001", zerolog.Nop()) + require.NoError(t, err) + + // Export all payloads + Cmd.SetArgs([]string{ + "--execution-state-dir", execdir, + "--output-dir", outdir, + "--state-commitment", hex.EncodeToString(stateCommitment[:]), + "--no-migration", + "--no-report", + "--output-payload-filename", outputPayloadFileName, + "--chain", flow.Emulator.Chain().String()}) + + err = Cmd.Execute() + require.NoError(t, err) + + // Verify exported payloads. + payloadsFromFile, err := util.ReadPayloadFile(zerolog.Nop(), outputPayloadFileName) + require.NoError(t, err) + require.Equal(t, len(keysValues), len(payloadsFromFile)) + + for _, payloadFromFile := range payloadsFromFile { + k, err := payloadFromFile.Key() + require.NoError(t, err) + + kv, exist := keysValues[k.String()] + require.True(t, exist) + require.Equal(t, kv.value, payloadFromFile.Value()) + } + }) + }) + + t.Run("some payloads", func(t *testing.T) { + withDirs(t, func(_, execdir, outdir string) { + const ( + checkpointDistance = math.MaxInt // A large number to prevent checkpoint creation. 
+ checkpointsToKeep = 1 + ) + + outputPayloadFileName := filepath.Join(outdir, payloadFileName) + + size := 10 + + diskWal, err := wal.NewDiskWAL(zerolog.Nop(), nil, metrics.NewNoopCollector(), execdir, size, pathfinder.PathByteSize, wal.SegmentSize) + require.NoError(t, err) + f, err := complete.NewLedger(diskWal, size*10, metr, zerolog.Nop(), complete.DefaultPathFinderVersion) + require.NoError(t, err) + compactor, err := complete.NewCompactor(f, diskWal, zerolog.Nop(), uint(size), checkpointDistance, checkpointsToKeep, atomic.NewBool(false), &metrics.NoopCollector{}) + require.NoError(t, err) + <-compactor.Ready() + + var stateCommitment = f.InitialState() + + // Save generated data after updates + keysValues := make(map[string]keyPair) + + for i := 0; i < size; i++ { + keys, values := getSampleKeyValues(i) + + update, err := ledger.NewUpdate(stateCommitment, keys, values) + require.NoError(t, err) + + stateCommitment, _, err = f.Set(update) + require.NoError(t, err) + + for j, key := range keys { + keysValues[key.String()] = keyPair{ + key: key, + value: values[j], + } + } + } + + <-f.Done() + <-compactor.Done() + + tries, err := f.Tries() + require.NoError(t, err) + + err = wal.StoreCheckpointV6SingleThread(tries, execdir, "checkpoint.00000001", zerolog.Nop()) + require.NoError(t, err) + + const selectedAddressCount = 10 + selectedAddresses := make(map[string]struct{}) + selectedKeysValues := make(map[string]keyPair) + for k, kv := range keysValues { + owner := kv.key.KeyParts[0].Value + if len(owner) != runtimeCommon.AddressLength { + continue + } + + address, err := runtimeCommon.BytesToAddress(owner) + require.NoError(t, err) + + if len(selectedAddresses) < selectedAddressCount { + selectedAddresses[address.Hex()] = struct{}{} + } + + if _, exist := selectedAddresses[address.Hex()]; exist { + selectedKeysValues[k] = kv + } + } + + addresses := make([]string, 0, len(selectedAddresses)) + for address := range selectedAddresses { + addresses = append(addresses, address) + } + + // Export selected payloads + Cmd.SetArgs([]string{ + "--execution-state-dir", execdir, + "--output-dir", outdir, + "--state-commitment", hex.EncodeToString(stateCommitment[:]), + "--no-migration", + "--no-report", + "--output-payload-filename", outputPayloadFileName, + "--extract-payloads-by-address", strings.Join(addresses, ","), + "--chain", flow.Emulator.Chain().String()}) + + err = Cmd.Execute() + require.NoError(t, err) + + // Verify exported payloads. + payloadsFromFile, err := util.ReadPayloadFile(zerolog.Nop(), outputPayloadFileName) + require.NoError(t, err) + require.Equal(t, len(selectedKeysValues), len(payloadsFromFile)) + + for _, payloadFromFile := range payloadsFromFile { + k, err := payloadFromFile.Key() + require.NoError(t, err) + + kv, exist := selectedKeysValues[k.String()] + require.True(t, exist) + require.Equal(t, kv.value, payloadFromFile.Value()) + } + }) + }) +} + +// TestExtractStateFromPayloads tests state extraction with payload as input. 
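+// Payload input files for the subtests are generated with util.CreatePayloadFile.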
+func TestExtractStateFromPayloads(t *testing.T) { + + const payloadFileName = "root.payload" + + t.Run("create checkpoint", func(t *testing.T) { + withDirs(t, func(_, execdir, outdir string) { + size := 10 + + inputPayloadFileName := filepath.Join(execdir, payloadFileName) + + // Generate some data + keysValues := make(map[string]keyPair) + var payloads []*ledger.Payload + + for i := 0; i < size; i++ { + keys, values := getSampleKeyValues(i) + + for j, key := range keys { + keysValues[key.String()] = keyPair{ + key: key, + value: values[j], + } + + payloads = append(payloads, ledger.NewPayload(key, values[j])) + } + } + + numOfPayloadWritten, err := util.CreatePayloadFile( + zerolog.Nop(), + inputPayloadFileName, + payloads, + nil, + ) + require.NoError(t, err) + require.Equal(t, len(payloads), numOfPayloadWritten) + + // Export checkpoint file + Cmd.SetArgs([]string{ + "--execution-state-dir", execdir, + "--output-dir", outdir, + "--no-migration", + "--no-report", + "--state-commitment", "", + "--input-payload-filename", inputPayloadFileName, + "--output-payload-filename", "", + "--extract-payloads-by-address", "", + "--chain", flow.Emulator.Chain().String()}) + + err = Cmd.Execute() + require.NoError(t, err) + + tries, err := wal.OpenAndReadCheckpointV6(outdir, "root.checkpoint", zerolog.Nop()) + require.NoError(t, err) + require.Equal(t, 1, len(tries)) + + // Verify exported checkpoint + payloadsFromFile := tries[0].AllPayloads() + require.NoError(t, err) + require.Equal(t, len(keysValues), len(payloadsFromFile)) + + for _, payloadFromFile := range payloadsFromFile { + k, err := payloadFromFile.Key() + require.NoError(t, err) + + kv, exist := keysValues[k.String()] + require.True(t, exist) + + require.Equal(t, kv.value, payloadFromFile.Value()) + } + }) + + }) + + t.Run("create payloads", func(t *testing.T) { + withDirs(t, func(_, execdir, outdir string) { + inputPayloadFileName := filepath.Join(execdir, payloadFileName) + outputPayloadFileName := filepath.Join(outdir, "selected.payload") + + size := 10 + + // Generate some data + keysValues := make(map[string]keyPair) + var payloads []*ledger.Payload + + for i := 0; i < size; i++ { + keys, values := getSampleKeyValues(i) + + for j, key := range keys { + keysValues[key.String()] = keyPair{ + key: key, + value: values[j], + } + + payloads = append(payloads, ledger.NewPayload(key, values[j])) + } + } + + numOfPayloadWritten, err := util.CreatePayloadFile( + zerolog.Nop(), + inputPayloadFileName, + payloads, + nil, + ) + require.NoError(t, err) + require.Equal(t, len(payloads), numOfPayloadWritten) + + // Export all payloads + Cmd.SetArgs([]string{ + "--execution-state-dir", execdir, + "--output-dir", outdir, + "--no-migration", + "--no-report", + "--state-commitment", "", + "--input-payload-filename", inputPayloadFileName, + "--output-payload-filename", outputPayloadFileName, + "--extract-payloads-by-address", "", + "--chain", flow.Emulator.Chain().String()}) + + err = Cmd.Execute() + require.NoError(t, err) + + // Verify exported payloads. 
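+			// No address filter was given, so every generated payload should round-trip.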
+			payloadsFromFile, err := util.ReadPayloadFile(zerolog.Nop(), outputPayloadFileName)
+			require.NoError(t, err)
+			require.Equal(t, len(keysValues), len(payloadsFromFile))
+
+			for _, payloadFromFile := range payloadsFromFile {
+				k, err := payloadFromFile.Key()
+				require.NoError(t, err)
+
+				kv, exist := keysValues[k.String()]
+				require.True(t, exist)
+
+				require.Equal(t, kv.value, payloadFromFile.Value())
+			}
+		})
+	})
+}
+
 func getSampleKeyValues(i int) ([]ledger.Key, []ledger.Value) {
 	switch i {
 	case 0:
@@ -226,7 +567,8 @@ func getSampleKeyValues(i int) ([]ledger.Key, []ledger.Value) {
 		keys := make([]ledger.Key, 0)
 		values := make([]ledger.Value, 0)
 		for j := 0; j < 10; j++ {
-			address := make([]byte, 32)
+			// address := make([]byte, 32)
+			address := make([]byte, 8)
 			_, err := rand.Read(address)
 			if err != nil {
 				panic(err)
diff --git a/cmd/util/cmd/extract-payloads-by-address/cmd.go b/cmd/util/cmd/extract-payloads-by-address/cmd.go
new file mode 100644
index 00000000000..3e384bf5d05
--- /dev/null
+++ b/cmd/util/cmd/extract-payloads-by-address/cmd.go
@@ -0,0 +1,258 @@
+package extractpayloads
+
+import (
+	"bufio"
+	"bytes"
+	"encoding/binary"
+	"encoding/hex"
+	"fmt"
+	"io"
+	"os"
+	"strings"
+
+	"github.com/fxamacker/cbor/v2"
+	"github.com/rs/zerolog"
+	"github.com/rs/zerolog/log"
+	"github.com/spf13/cobra"
+
+	"github.com/onflow/cadence/runtime/common"
+
+	"github.com/onflow/flow-go/ledger"
+)
+
+const (
+	defaultBufioWriteSize = 1024 * 32
+	defaultBufioReadSize  = 1024 * 32
+
+	payloadEncodingVersion = 1
+)
+
+var (
+	flagInputPayloadFileName  string
+	flagOutputPayloadFileName string
+	flagAddresses             string
+)
+
+var Cmd = &cobra.Command{
+	Use:   "extract-payload-by-address",
+	Short: "Read a payload file and generate a payload file containing only payloads with the specified addresses",
+	Run:   run,
+}
+
+func init() {
+	Cmd.Flags().StringVar(
+		&flagInputPayloadFileName,
+		"input-filename",
+		"",
+		"Input payload file name")
+	_ = Cmd.MarkFlagRequired("input-filename")
+
+	Cmd.Flags().StringVar(
+		&flagOutputPayloadFileName,
+		"output-filename",
+		"",
+		"Output payload file name")
+	_ = Cmd.MarkFlagRequired("output-filename")
+
+	Cmd.Flags().StringVar(
+		&flagAddresses,
+		"addresses",
+		"",
+		"extract payloads of addresses (comma-separated hex-encoded addresses) to the file specified by --output-filename",
+	)
+	_ = Cmd.MarkFlagRequired("addresses")
+}
+
+func run(*cobra.Command, []string) {
+
+	if _, err := os.Stat(flagInputPayloadFileName); os.IsNotExist(err) {
+		log.Fatal().Msgf("Input file %s doesn't exist", flagInputPayloadFileName)
+	}
+
+	if _, err := os.Stat(flagOutputPayloadFileName); os.IsExist(err) {
+		log.Fatal().Msgf("Output file %s exists", flagOutputPayloadFileName)
+	}
+
+	addresses, err := parseAddresses(strings.Split(flagAddresses, ","))
+	if err != nil {
+		log.Fatal().Err(err).Msg("failed to parse addresses")
+	}
+
+	log.Info().Msgf(
+		"extracting payloads with addresses %v from %s to %s",
+		addresses,
+		flagInputPayloadFileName,
+		flagOutputPayloadFileName,
+	)
+
+	numOfPayloadWritten, err := extractPayloads(log.Logger, flagInputPayloadFileName, flagOutputPayloadFileName, addresses)
+	if err != nil {
+		log.Fatal().Err(err).Msg("failed to extract payloads")
+	}
+
+	err = overwritePayloadCountInFile(flagOutputPayloadFileName, numOfPayloadWritten)
+	if err != nil {
+		log.Fatal().Err(err).Msg("failed to overwrite payload count in output file")
+	}
+}
+
+func overwritePayloadCountInFile(output string, numOfPayloadWritten int) error {
+	in, err := os.OpenFile(output, os.O_RDWR, 0644)
+	if err != nil {
+		return fmt.Errorf("failed to open %s to write payload count: %w", output, err)
+	}
+	defer in.Close()
+
+	const cbor8BytesPositiveIntegerIndicator = 0x1b
+
+	var data [9]byte
+	data[0] = cbor8BytesPositiveIntegerIndicator
+	binary.BigEndian.PutUint64(data[1:], uint64(numOfPayloadWritten))
+
+	n, err := in.WriteAt(data[:], 0)
+	if err != nil {
+		return fmt.Errorf("failed to overwrite number of payloads in %s: %w", output, err)
+	}
+	if n != len(data) {
+		return fmt.Errorf("failed to overwrite number of payloads in %s: wrote %d bytes, expect %d bytes", output, n, len(data))
+	}
+
+	return nil
+}
+
+func extractPayloads(log zerolog.Logger, input, output string, addresses []common.Address) (int, error) {
+	in, err := os.Open(input)
+	if err != nil {
+		return 0, fmt.Errorf("failed to open %s: %w", input, err)
+	}
+	defer in.Close()
+
+	reader := bufio.NewReaderSize(in, defaultBufioReadSize)
+
+	out, err := os.Create(output)
+	if err != nil {
+		return 0, fmt.Errorf("failed to open %s: %w", output, err)
+	}
+	defer out.Close()
+
+	writer := bufio.NewWriterSize(out, defaultBufioWriteSize)
+	defer writer.Flush()
+
+	// Reserve a 9-byte header for the number of payloads.
+	var head [9]byte
+	_, err = writer.Write(head[:])
+	if err != nil {
+		return 0, fmt.Errorf("failed to write header for %s: %w", output, err)
+	}
+
+	// Need to flush buffer before encoding payloads.
+	writer.Flush()
+
+	enc := cbor.NewEncoder(writer)
+
+	const logIntervalForPayloads = 1_000_000
+	count := 0
+	err = readPayloadFile(log, reader, func(rawPayload []byte) error {
+
+		payload, err := ledger.DecodePayloadWithoutPrefix(rawPayload, false, payloadEncodingVersion)
+		if err != nil {
+			return fmt.Errorf("failed to decode payload 0x%x: %w", rawPayload, err)
+		}
+
+		k, err := payload.Key()
+		if err != nil {
+			return err
+		}
+
+		owner := k.KeyParts[0].Value
+
+		include := false
+		for _, address := range addresses {
+			if bytes.Equal(owner, address[:]) {
+				include = true
+				break
+			}
+		}
+
+		if include {
+			err = enc.Encode(rawPayload)
+			if err != nil {
+				return fmt.Errorf("failed to encode payload: %w", err)
+			}
+
+			count++
+			if count%logIntervalForPayloads == 0 {
+				log.Info().Msgf("wrote %d payloads", count)
+			}
+		}
+
+		return nil
+	})
+	if err != nil {
+		return 0, err
+	}
+
+	log.Info().Msgf("wrote %d payloads", count)
+	return count, nil
+}
+
+func parseAddresses(hexAddresses []string) ([]common.Address, error) {
+	if len(hexAddresses) == 0 {
+		return nil, fmt.Errorf("at least one address must be provided")
+	}
+
+	addresses := make([]common.Address, len(hexAddresses))
+	for i, hexAddr := range hexAddresses {
+		b, err := hex.DecodeString(strings.TrimSpace(hexAddr))
+		if err != nil {
+			return nil, fmt.Errorf("address is not hex encoded %s: %w", strings.TrimSpace(hexAddr), err)
+		}
+
+		addr, err := common.BytesToAddress(b)
+		if err != nil {
+			return nil, fmt.Errorf("cannot decode address %x", b)
+		}
+
+		addresses[i] = addr
+	}
+
+	return addresses, nil
+}
+
+func readPayloadFile(log zerolog.Logger, r io.Reader, processPayload func([]byte) error) error {
+	dec := cbor.NewDecoder(r)
+
+	var payloadCount int
+	err := dec.Decode(&payloadCount)
+	if err != nil {
+		return err
+	}
+
+	log.Info().Msgf("Processing input file with %d payloads", payloadCount)
+
+	const logIntervalForPayloads = 1_000_000
+	count := 0
+	for {
+		var rawPayload []byte
+		err = dec.Decode(&rawPayload)
+		if err == io.EOF {
+			break
+		}
+		if err != nil {
+			return err
+		}
+
+		
err = processPayload(rawPayload) + if err != nil { + return err + } + + count++ + if count%logIntervalForPayloads == 0 { + log.Info().Msgf("processed %d payloads", count) + } + } + + log.Info().Msgf("processed %d payloads", count) + return nil +} diff --git a/cmd/util/cmd/extract-payloads-by-address/extract_payloads_test.go b/cmd/util/cmd/extract-payloads-by-address/extract_payloads_test.go new file mode 100644 index 00000000000..443fed54518 --- /dev/null +++ b/cmd/util/cmd/extract-payloads-by-address/extract_payloads_test.go @@ -0,0 +1,241 @@ +package extractpayloads + +import ( + "bytes" + "crypto/rand" + "encoding/hex" + "path/filepath" + "strings" + "testing" + + "github.com/rs/zerolog" + "github.com/stretchr/testify/require" + + "github.com/onflow/cadence/runtime/common" + + "github.com/onflow/flow-go/cmd/util/ledger/util" + "github.com/onflow/flow-go/ledger" + "github.com/onflow/flow-go/utils/unittest" +) + +type keyPair struct { + key ledger.Key + value ledger.Value +} + +func TestExtractPayloads(t *testing.T) { + + t.Run("some payloads", func(t *testing.T) { + + unittest.RunWithTempDir(t, func(datadir string) { + + inputFile := filepath.Join(datadir, "input.payload") + outputFile := filepath.Join(datadir, "output.payload") + + size := 10 + + // Generate some data + keysValues := make(map[string]keyPair) + var payloads []*ledger.Payload + + for i := 0; i < size; i++ { + keys, values := getSampleKeyValues(i) + + for j, key := range keys { + keysValues[key.String()] = keyPair{ + key: key, + value: values[j], + } + + payloads = append(payloads, ledger.NewPayload(key, values[j])) + } + } + + numOfPayloadWritten, err := util.CreatePayloadFile(zerolog.Nop(), inputFile, payloads, nil) + require.NoError(t, err) + require.Equal(t, len(payloads), numOfPayloadWritten) + + const selectedAddressCount = 10 + selectedAddresses := make(map[string]struct{}) + selectedKeysValues := make(map[string]keyPair) + for k, kv := range keysValues { + owner := kv.key.KeyParts[0].Value + if len(owner) != common.AddressLength { + continue + } + + address, err := common.BytesToAddress(owner) + require.NoError(t, err) + + if len(selectedAddresses) < selectedAddressCount { + selectedAddresses[address.Hex()] = struct{}{} + } + + if _, exist := selectedAddresses[address.Hex()]; exist { + selectedKeysValues[k] = kv + } + } + + addresses := make([]string, 0, len(selectedAddresses)) + for address := range selectedAddresses { + addresses = append(addresses, address) + } + + // Export selected payloads + Cmd.SetArgs([]string{ + "--input-filename", inputFile, + "--output-filename", outputFile, + "--addresses", strings.Join(addresses, ","), + }) + + err = Cmd.Execute() + require.NoError(t, err) + + // Verify exported payloads. 
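+			// Only payloads owned by the selected addresses should appear in the output file.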
+ payloadsFromFile, err := util.ReadPayloadFile(zerolog.Nop(), outputFile) + require.NoError(t, err) + require.Equal(t, len(selectedKeysValues), len(payloadsFromFile)) + + for _, payloadFromFile := range payloadsFromFile { + k, err := payloadFromFile.Key() + require.NoError(t, err) + + kv, exist := selectedKeysValues[k.String()] + require.True(t, exist) + require.Equal(t, kv.value, payloadFromFile.Value()) + } + }) + }) + + t.Run("no payloads", func(t *testing.T) { + + emptyAddress := common.Address{} + + unittest.RunWithTempDir(t, func(datadir string) { + + inputFile := filepath.Join(datadir, "input.payload") + outputFile := filepath.Join(datadir, "output.payload") + + size := 10 + + // Generate some data + keysValues := make(map[string]keyPair) + var payloads []*ledger.Payload + + for i := 0; i < size; i++ { + keys, values := getSampleKeyValues(i) + + for j, key := range keys { + if bytes.Equal(key.KeyParts[0].Value, emptyAddress[:]) { + continue + } + keysValues[key.String()] = keyPair{ + key: key, + value: values[j], + } + + payloads = append(payloads, ledger.NewPayload(key, values[j])) + } + } + + numOfPayloadWritten, err := util.CreatePayloadFile(zerolog.Nop(), inputFile, payloads, nil) + require.NoError(t, err) + require.Equal(t, len(payloads), numOfPayloadWritten) + + // Export selected payloads + Cmd.SetArgs([]string{ + "--input-filename", inputFile, + "--output-filename", outputFile, + "--addresses", hex.EncodeToString(emptyAddress[:]), + }) + + err = Cmd.Execute() + require.NoError(t, err) + + // Verify exported payloads. + payloadsFromFile, err := util.ReadPayloadFile(zerolog.Nop(), outputFile) + require.NoError(t, err) + require.Equal(t, 0, len(payloadsFromFile)) + }) + }) +} + +func getSampleKeyValues(i int) ([]ledger.Key, []ledger.Value) { + switch i { + case 0: + return []ledger.Key{getKey("", "uuid"), getKey("", "account_address_state")}, + []ledger.Value{[]byte{'1'}, []byte{'A'}} + case 1: + return []ledger.Key{getKey("ADDRESS", "public_key_count"), + getKey("ADDRESS", "public_key_0"), + getKey("ADDRESS", "exists"), + getKey("ADDRESS", "storage_used")}, + []ledger.Value{[]byte{1}, []byte("PUBLICKEYXYZ"), []byte{1}, []byte{100}} + case 2: + // TODO change the contract_names to CBOR encoding + return []ledger.Key{getKey("ADDRESS", "contract_names"), getKey("ADDRESS", "code.mycontract")}, + []ledger.Value{[]byte("mycontract"), []byte("CONTRACT Content")} + default: + keys := make([]ledger.Key, 0) + values := make([]ledger.Value, 0) + for j := 0; j < 10; j++ { + // address := make([]byte, 32) + address := make([]byte, 8) + _, err := rand.Read(address) + if err != nil { + panic(err) + } + keys = append(keys, getKey(string(address), "test")) + values = append(values, getRandomCadenceValue()) + } + return keys, values + } +} + +func getKey(owner, key string) ledger.Key { + return ledger.Key{KeyParts: []ledger.KeyPart{ + {Type: uint16(0), Value: []byte(owner)}, + {Type: uint16(2), Value: []byte(key)}, + }, + } +} + +func getRandomCadenceValue() ledger.Value { + + randomPart := make([]byte, 10) + _, err := rand.Read(randomPart) + if err != nil { + panic(err) + } + valueBytes := []byte{ + // magic prefix + 0x0, 0xca, 0xde, 0x0, 0x4, + // tag + 0xd8, 132, + // array, 5 items follow + 0x85, + + // tag + 0xd8, 193, + // UTF-8 string, length 4 + 0x64, + // t, e, s, t + 0x74, 0x65, 0x73, 0x74, + + // nil + 0xf6, + + // positive integer 1 + 0x1, + + // array, 0 items follow + 0x80, + + // UTF-8 string, length 10 + 0x6a, + 0x54, 0x65, 0x73, 0x74, 0x53, 0x74, 0x72, 0x75, 0x63, 0x74, + } + 
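+	// Append the random suffix so repeated calls produce distinct values.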
+ valueBytes = append(valueBytes, randomPart...) + return ledger.Value(valueBytes) +} diff --git a/cmd/util/ledger/util/payload_file.go b/cmd/util/ledger/util/payload_file.go new file mode 100644 index 00000000000..6524cce8261 --- /dev/null +++ b/cmd/util/ledger/util/payload_file.go @@ -0,0 +1,201 @@ +package util + +import ( + "bufio" + "bytes" + "fmt" + "io" + "os" + + "github.com/fxamacker/cbor/v2" + "github.com/rs/zerolog" + + "github.com/onflow/cadence/runtime/common" + + "github.com/onflow/flow-go/ledger" +) + +const ( + defaultBufioWriteSize = 1024 * 32 + defaultBufioReadSize = 1024 * 32 + + payloadEncodingVersion = 1 +) + +func CreatePayloadFile( + logger zerolog.Logger, + payloadFile string, + payloads []*ledger.Payload, + addresses []common.Address, +) (int, error) { + + f, err := os.Create(payloadFile) + if err != nil { + return 0, fmt.Errorf("can't create %s: %w", payloadFile, err) + } + defer f.Close() + + writer := bufio.NewWriterSize(f, defaultBufioWriteSize) + if err != nil { + return 0, fmt.Errorf("can't create bufio writer for %s: %w", payloadFile, err) + } + defer writer.Flush() + + includeAllPayloads := len(addresses) == 0 + + if includeAllPayloads { + return writeAllPayloads(logger, writer, payloads) + } + + return writeSelectedPayloads(logger, writer, payloads, addresses) +} + +func writeAllPayloads(logger zerolog.Logger, w io.Writer, payloads []*ledger.Payload) (int, error) { + logger.Info().Msgf("writing %d payloads to file", len(payloads)) + + enc := cbor.NewEncoder(w) + + // Encode number of payloads + err := enc.Encode(len(payloads)) + if err != nil { + return 0, fmt.Errorf("failed to encode number of payloads %d in CBOR: %w", len(payloads), err) + } + + var payloadScratchBuffer [1024 * 2]byte + for _, p := range payloads { + + buf := ledger.EncodeAndAppendPayloadWithoutPrefix(payloadScratchBuffer[:0], p, payloadEncodingVersion) + + // Encode payload + err = enc.Encode(buf) + if err != nil { + return 0, err + } + } + + return len(payloads), nil +} + +func writeSelectedPayloads(logger zerolog.Logger, w io.Writer, payloads []*ledger.Payload, addresses []common.Address) (int, error) { + var includedPayloadCount int + + includedFlags := make([]bool, len(payloads)) + for i, p := range payloads { + include, err := includePayloadByAddresses(p, addresses) + if err != nil { + return 0, err + } + + includedFlags[i] = include + + if include { + includedPayloadCount++ + } + } + + logger.Info().Msgf("writing %d payloads to file", includedPayloadCount) + + enc := cbor.NewEncoder(w) + + // Encode number of payloads + err := enc.Encode(includedPayloadCount) + if err != nil { + return 0, fmt.Errorf("failed to encode number of payloads %d in CBOR: %w", includedPayloadCount, err) + } + + var payloadScratchBuffer [1024 * 2]byte + for i, included := range includedFlags { + if !included { + continue + } + + p := payloads[i] + + buf := ledger.EncodeAndAppendPayloadWithoutPrefix(payloadScratchBuffer[:0], p, payloadEncodingVersion) + + // Encode payload + err = enc.Encode(buf) + if err != nil { + return 0, err + } + } + + return includedPayloadCount, nil +} + +func includePayloadByAddresses(payload *ledger.Payload, addresses []common.Address) (bool, error) { + if len(addresses) == 0 { + // Include all payloads + return true, nil + } + + k, err := payload.Key() + if err != nil { + return false, fmt.Errorf("failed to get key from payload: %w", err) + } + + owner := k.KeyParts[0].Value + + for _, address := range addresses { + if bytes.Equal(owner, address[:]) { + return true, nil + } + } + 
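+	// The payload's owner matched none of the given addresses.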
+ return false, nil +} + +func ReadPayloadFile(logger zerolog.Logger, payloadFile string) ([]*ledger.Payload, error) { + + if _, err := os.Stat(payloadFile); os.IsNotExist(err) { + return nil, fmt.Errorf("%s doesn't exist", payloadFile) + } + + f, err := os.Open(payloadFile) + if err != nil { + return nil, fmt.Errorf("failed to open %s: %w", payloadFile, err) + } + defer f.Close() + + r := bufio.NewReaderSize(f, defaultBufioReadSize) + if err != nil { + return nil, fmt.Errorf("failed to create bufio reader for %s: %w", payloadFile, err) + } + + dec := cbor.NewDecoder(r) + + // Decode number of payloads + var payloadCount int + err = dec.Decode(&payloadCount) + if err != nil { + return nil, fmt.Errorf("failed to decode number of payload in CBOR: %w", err) + } + + logger.Info().Msgf("reading %d payloads from file", payloadCount) + + payloads := make([]*ledger.Payload, 0, payloadCount) + + for { + var rawPayload []byte + err := dec.Decode(&rawPayload) + if err == io.EOF { + break + } + if err != nil { + return nil, fmt.Errorf("failed to decode payload in CBOR: %w", err) + } + + payload, err := ledger.DecodePayloadWithoutPrefix(rawPayload, false, payloadEncodingVersion) + if err != nil { + return nil, fmt.Errorf("failed to decode payload 0x%x: %w", rawPayload, err) + } + + payloads = append(payloads, payload) + } + + if payloadCount != len(payloads) { + return nil, fmt.Errorf("failed to decode %s: expect %d payloads, got %d payloads", payloadFile, payloadCount, len(payloads)) + } + + return payloads, nil +} diff --git a/cmd/util/ledger/util/payload_file_test.go b/cmd/util/ledger/util/payload_file_test.go new file mode 100644 index 00000000000..d37da30444f --- /dev/null +++ b/cmd/util/ledger/util/payload_file_test.go @@ -0,0 +1,272 @@ +package util_test + +import ( + "bytes" + "crypto/rand" + "path/filepath" + "testing" + + "github.com/rs/zerolog" + "github.com/stretchr/testify/require" + + "github.com/onflow/cadence/runtime/common" + + "github.com/onflow/flow-go/cmd/util/ledger/util" + "github.com/onflow/flow-go/ledger" + "github.com/onflow/flow-go/utils/unittest" +) + +type keyPair struct { + key ledger.Key + value ledger.Value +} + +func TestPayloadFile(t *testing.T) { + + const fileName = "root.payload" + + t.Run("without filter", func(t *testing.T) { + unittest.RunWithTempDir(t, func(datadir string) { + size := 10 + + payloadFileName := filepath.Join(datadir, fileName) + + // Generate some data + keysValues := make(map[string]keyPair) + var payloads []*ledger.Payload + + for i := 0; i < size; i++ { + keys, values := getSampleKeyValues(i) + + for j, key := range keys { + keysValues[key.String()] = keyPair{ + key: key, + value: values[j], + } + + payloads = append(payloads, ledger.NewPayload(key, values[j])) + } + } + + numOfPayloadWritten, err := util.CreatePayloadFile( + zerolog.Nop(), + payloadFileName, + payloads, + nil, + ) + require.NoError(t, err) + require.Equal(t, len(payloads), numOfPayloadWritten) + + payloadsFromFile, err := util.ReadPayloadFile(zerolog.Nop(), payloadFileName) + require.NoError(t, err) + require.Equal(t, len(payloads), len(payloadsFromFile)) + + for _, payloadFromFile := range payloadsFromFile { + k, err := payloadFromFile.Key() + require.NoError(t, err) + + kv, exist := keysValues[k.String()] + require.True(t, exist) + + require.Equal(t, kv.value, payloadFromFile.Value()) + } + }) + }) + + t.Run("with filter", func(t *testing.T) { + unittest.RunWithTempDir(t, func(datadir string) { + size := 10 + + payloadFileName := filepath.Join(datadir, fileName) + + // 
Generate some data + keysValues := make(map[string]keyPair) + var payloads []*ledger.Payload + + for i := 0; i < size; i++ { + keys, values := getSampleKeyValues(i) + + for j, key := range keys { + keysValues[key.String()] = keyPair{ + key: key, + value: values[j], + } + + payloads = append(payloads, ledger.NewPayload(key, values[j])) + } + } + + const selectedAddressCount = 10 + selectedAddresses := make(map[common.Address]struct{}) + selectedKeysValues := make(map[string]keyPair) + for k, kv := range keysValues { + owner := kv.key.KeyParts[0].Value + if len(owner) != common.AddressLength { + continue + } + + address, err := common.BytesToAddress(owner) + require.NoError(t, err) + + if len(selectedAddresses) < selectedAddressCount { + selectedAddresses[address] = struct{}{} + } + + if _, exist := selectedAddresses[address]; exist { + selectedKeysValues[k] = kv + } + } + + addresses := make([]common.Address, 0, len(selectedAddresses)) + for address := range selectedAddresses { + addresses = append(addresses, address) + } + + numOfPayloadWritten, err := util.CreatePayloadFile( + zerolog.Nop(), + payloadFileName, + payloads, + addresses, + ) + require.NoError(t, err) + require.Equal(t, len(selectedKeysValues), numOfPayloadWritten) + + payloadsFromFile, err := util.ReadPayloadFile(zerolog.Nop(), payloadFileName) + require.NoError(t, err) + require.Equal(t, len(selectedKeysValues), len(payloadsFromFile)) + + for _, payloadFromFile := range payloadsFromFile { + k, err := payloadFromFile.Key() + require.NoError(t, err) + + kv, exist := selectedKeysValues[k.String()] + require.True(t, exist) + + require.Equal(t, kv.value, payloadFromFile.Value()) + } + }) + }) + + t.Run("no payloads found with filter", func(t *testing.T) { + emptyAddress := common.Address{} + + unittest.RunWithTempDir(t, func(datadir string) { + size := 10 + + payloadFileName := filepath.Join(datadir, fileName) + + // Generate some data + keysValues := make(map[string]keyPair) + var payloads []*ledger.Payload + + for i := 0; i < size; i++ { + keys, values := getSampleKeyValues(i) + + for j, key := range keys { + if bytes.Equal(key.KeyParts[0].Value, emptyAddress[:]) { + continue + } + keysValues[key.String()] = keyPair{ + key: key, + value: values[j], + } + + payloads = append(payloads, ledger.NewPayload(key, values[j])) + } + } + + numOfPayloadWritten, err := util.CreatePayloadFile( + zerolog.Nop(), + payloadFileName, + payloads, + []common.Address{emptyAddress}, + ) + require.NoError(t, err) + require.Equal(t, 0, numOfPayloadWritten) + + payloadsFromFile, err := util.ReadPayloadFile(zerolog.Nop(), payloadFileName) + require.NoError(t, err) + require.Equal(t, 0, len(payloadsFromFile)) + }) + }) +} + +func getSampleKeyValues(i int) ([]ledger.Key, []ledger.Value) { + switch i { + case 0: + return []ledger.Key{getKey("", "uuid"), getKey("", "account_address_state")}, + []ledger.Value{[]byte{'1'}, []byte{'A'}} + case 1: + return []ledger.Key{getKey("ADDRESS", "public_key_count"), + getKey("ADDRESS", "public_key_0"), + getKey("ADDRESS", "exists"), + getKey("ADDRESS", "storage_used")}, + []ledger.Value{[]byte{1}, []byte("PUBLICKEYXYZ"), []byte{1}, []byte{100}} + case 2: + // TODO change the contract_names to CBOR encoding + return []ledger.Key{getKey("ADDRESS", "contract_names"), getKey("ADDRESS", "code.mycontract")}, + []ledger.Value{[]byte("mycontract"), []byte("CONTRACT Content")} + default: + keys := make([]ledger.Key, 0) + values := make([]ledger.Value, 0) + for j := 0; j < 10; j++ { + // address := make([]byte, 32) + address := 
make([]byte, 8) + _, err := rand.Read(address) + if err != nil { + panic(err) + } + keys = append(keys, getKey(string(address), "test")) + values = append(values, getRandomCadenceValue()) + } + return keys, values + } +} + +func getKey(owner, key string) ledger.Key { + return ledger.Key{KeyParts: []ledger.KeyPart{ + {Type: uint16(0), Value: []byte(owner)}, + {Type: uint16(2), Value: []byte(key)}, + }, + } +} + +func getRandomCadenceValue() ledger.Value { + + randomPart := make([]byte, 10) + _, err := rand.Read(randomPart) + if err != nil { + panic(err) + } + valueBytes := []byte{ + // magic prefix + 0x0, 0xca, 0xde, 0x0, 0x4, + // tag + 0xd8, 132, + // array, 5 items follow + 0x85, + + // tag + 0xd8, 193, + // UTF-8 string, length 4 + 0x64, + // t, e, s, t + 0x74, 0x65, 0x73, 0x74, + + // nil + 0xf6, + + // positive integer 1 + 0x1, + + // array, 0 items follow + 0x80, + + // UTF-8 string, length 10 + 0x6a, + 0x54, 0x65, 0x73, 0x74, 0x53, 0x74, 0x72, 0x75, 0x63, 0x74, + } + + valueBytes = append(valueBytes, randomPart...) + return ledger.Value(valueBytes) +}