diff --git a/app/upgrade.go b/app/upgrade.go index ac396758..aaf84e0f 100644 --- a/app/upgrade.go +++ b/app/upgrade.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" + "cosmossdk.io/collections" sdkerrors "cosmossdk.io/errors" upgradetypes "cosmossdk.io/x/upgrade/types" "github.com/cosmos/cosmos-sdk/types/module" @@ -35,14 +36,6 @@ func (app *InitiaApp) RegisterUpgradeHandlers(cfg module.Configurator) { // 2. update vm data with new seperator and add checksums of each module - type KV struct { - key []byte - value []byte - } - kvs := make([]KV, 0) - - rmKeys := make([][]byte, 0) - // Previous: // ModuleSeparator = byte(0) // ResourceSeparator = byte(1) @@ -56,7 +49,7 @@ func (app *InitiaApp) RegisterUpgradeHandlers(cfg module.Configurator) { // TableEntrySeparator = byte(3) // TableInfoSeparator = byte(4) - err = app.MoveKeeper.VMStore.Walk(ctx, nil, func(key, value []byte) (stop bool, err error) { + err = app.MoveKeeper.VMStore.Walk(ctx, new(collections.Range[[]byte]).Descending(), func(key, value []byte) (stop bool, err error) { key = bytes.Clone(key) value = bytes.Clone(value) @@ -73,34 +66,22 @@ func (app *InitiaApp) RegisterUpgradeHandlers(cfg module.Configurator) { } else if separator >= movetypes.TableInfoSeparator { return true, errors.New("unknown prefix") } else { - rmKeys = append(rmKeys, bytes.Clone(key)) + err = app.MoveKeeper.VMStore.Remove(ctx, bytes.Clone(key)) + if err != nil { + return true, err + } } - key[cursor] = key[cursor] + 1 - kvs = append(kvs, KV{ - key: key, - value: value, - }) + err = app.MoveKeeper.VMStore.Set(ctx, key, value) + if err != nil { + return true, err + } return false, nil }) if err != nil { return nil, err } - for _, key := range rmKeys { - err = app.MoveKeeper.VMStore.Remove(ctx, key) - if err != nil { - return nil, err - } - } - - for _, kv := range kvs { - err = app.MoveKeeper.VMStore.Set(ctx, kv.key, kv.value) - if err != nil { - return nil, err - } - } - // 3. update new modules codesBz, err := vmprecompile.ReadStdlib("object_code_deployment.mv", "coin.mv", "cosmos.mv", "dex.mv", "json.mv", "bech32.mv", "hash.mv", "collection.mv") diff --git a/x/move/keeper/keeper.go b/x/move/keeper/keeper.go index 92b10b21..1ebafc6f 100644 --- a/x/move/keeper/keeper.go +++ b/x/move/keeper/keeper.go @@ -543,7 +543,16 @@ func (k Keeper) SetTableEntry( // IterateVMStore iterate VMStore store for genesis export func (k Keeper) IterateVMStore(ctx context.Context, cb func(*types.Module, *types.Checksum, *types.Resource, *types.TableInfo, *types.TableEntry)) error { - err := k.VMStore.Walk(ctx, nil, func(key, value []byte) (stop bool, err error) { + return k.walkVMStore(ctx, cb, nil) +} + +// ReverseIterateVMStore iterate VMStore store for genesis export +func (k Keeper) ReverseIterateVMStore(ctx context.Context, cb func(*types.Module, *types.Checksum, *types.Resource, *types.TableInfo, *types.TableEntry)) error { + return k.walkVMStore(ctx, cb, new(collections.Range[[]byte]).Descending()) +} + +func (k Keeper) walkVMStore(ctx context.Context, cb func(*types.Module, *types.Checksum, *types.Resource, *types.TableInfo, *types.TableEntry), ranger collections.Ranger[[]byte]) error { + err := k.VMStore.Walk(ctx, ranger, func(key, value []byte) (stop bool, err error) { cursor := types.AddressBytesLength addrBytes := key[:cursor] separator := key[cursor] diff --git a/x/move/keeper/keeper_test.go b/x/move/keeper/keeper_test.go index 4662c90b..c6e812ff 100644 --- a/x/move/keeper/keeper_test.go +++ b/x/move/keeper/keeper_test.go @@ -181,16 +181,20 @@ func TestIterateVMStore(t *testing.T) { input.MoveKeeper.SetChecksum(ctx, vmtypes.StdAddress, "BasicCoin", basicCoinChecksum[:]) input.MoveKeeper.SetResource(ctx, vmtypes.TestAddress, structTag, data) - input.MoveKeeper.SetTableInfo(ctx, types.TableInfo{ - Address: vmtypes.TestAddress.String(), - KeyType: "u64", - ValueType: "u64", - }) + input.MoveKeeper.SetTableEntry(ctx, types.TableEntry{ Address: vmtypes.TestAddress.String(), KeyBytes: []byte{1, 2, 3}, ValueBytes: []byte{4, 5, 6}, }) + input.MoveKeeper.SetTableInfo(ctx, types.TableInfo{ + Address: vmtypes.TestAddress.String(), + KeyType: "u64", + ValueType: "u64", + }) + + counter := 0 + input.MoveKeeper.IterateVMStore(ctx, func(module *types.Module, checksum *types.Checksum, resource *types.Resource, tableInfo *types.TableInfo, tableEntry *types.TableEntry) { if module != nil && module.ModuleName == "BasicCoin" { require.Equal(t, types.Module{ @@ -199,6 +203,8 @@ func TestIterateVMStore(t *testing.T) { RawBytes: basicCoinModule, UpgradePolicy: types.UpgradePolicy_COMPATIBLE, }, *module) + require.Equal(t, 0, counter) + counter++ } if checksum != nil && checksum.ModuleName == "BasicCoin" { @@ -207,6 +213,8 @@ func TestIterateVMStore(t *testing.T) { ModuleName: "BasicCoin", Checksum: basicCoinChecksum[:], }, *checksum) + require.Equal(t, 1, counter) + counter++ } if resource != nil && resource.Address == "0x2" { @@ -215,6 +223,18 @@ func TestIterateVMStore(t *testing.T) { StructTag: structTagStr, RawBytes: data, }, *resource) + require.Equal(t, 2, counter) + counter++ + } + + if tableEntry != nil && tableEntry.Address == "0x2" { + require.Equal(t, types.TableEntry{ + Address: vmtypes.TestAddress.String(), + KeyBytes: []byte{1, 2, 3}, + ValueBytes: []byte{4, 5, 6}, + }, *tableEntry) + require.Equal(t, 3, counter) + counter++ } if tableInfo != nil && tableInfo.Address == "0x2" { @@ -223,6 +243,72 @@ func TestIterateVMStore(t *testing.T) { KeyType: "u64", ValueType: "u64", }, *tableInfo) + require.Equal(t, 4, counter) + counter++ + } + }) +} + +func TestReverseIterateVMStore(t *testing.T) { + ctx, input := createDefaultTestInput(t) + + input.MoveKeeper.SetModule(ctx, vmtypes.StdAddress, "BasicCoin", basicCoinModule) + + structTagStr := "0x1::BasicCoin::Coin<0x1::BasicCoin::Initia>" + structTag, err := vmapi.ParseStructTag(structTagStr) + require.NoError(t, err) + + data, err := vmtypes.SerializeUint64(100) + require.NoError(t, err) + + basicCoinChecksum := types.ModuleBzToChecksum(basicCoinModule) + input.MoveKeeper.SetChecksum(ctx, vmtypes.StdAddress, "BasicCoin", basicCoinChecksum[:]) + + input.MoveKeeper.SetResource(ctx, vmtypes.TestAddress, structTag, data) + + input.MoveKeeper.SetTableEntry(ctx, types.TableEntry{ + Address: vmtypes.TestAddress.String(), + KeyBytes: []byte{1, 2, 3}, + ValueBytes: []byte{4, 5, 6}, + }) + + input.MoveKeeper.SetTableInfo(ctx, types.TableInfo{ + Address: vmtypes.TestAddress.String(), + KeyType: "u64", + ValueType: "u64", + }) + + counter := 0 + input.MoveKeeper.ReverseIterateVMStore(ctx, func(module *types.Module, checksum *types.Checksum, resource *types.Resource, tableInfo *types.TableInfo, tableEntry *types.TableEntry) { + if module != nil && module.ModuleName == "BasicCoin" { + require.Equal(t, types.Module{ + Address: "0x1", + ModuleName: "BasicCoin", + RawBytes: basicCoinModule, + UpgradePolicy: types.UpgradePolicy_COMPATIBLE, + }, *module) + require.Equal(t, 4, counter) + counter++ + } + + if checksum != nil && checksum.ModuleName == "BasicCoin" { + require.Equal(t, types.Checksum{ + Address: "0x1", + ModuleName: "BasicCoin", + Checksum: basicCoinChecksum[:], + }, *checksum) + require.Equal(t, 3, counter) + counter++ + } + + if resource != nil && resource.Address == "0x2" { + require.Equal(t, types.Resource{ + Address: "0x2", + StructTag: structTagStr, + RawBytes: data, + }, *resource) + require.Equal(t, 2, counter) + counter++ } if tableEntry != nil && tableEntry.Address == "0x2" { @@ -231,6 +317,18 @@ func TestIterateVMStore(t *testing.T) { KeyBytes: []byte{1, 2, 3}, ValueBytes: []byte{4, 5, 6}, }, *tableEntry) + require.Equal(t, 1, counter) + counter++ + } + + if tableInfo != nil && tableInfo.Address == "0x2" { + require.Equal(t, types.TableInfo{ + Address: vmtypes.TestAddress.String(), + KeyType: "u64", + ValueType: "u64", + }, *tableInfo) + require.Equal(t, 0, counter) + counter++ } }) }