diff --git a/cmd/util/ledger/migrations/contract_checking_migration_test.go b/cmd/util/ledger/migrations/contract_checking_migration_test.go new file mode 100644 index 00000000000..5f7cb84553a --- /dev/null +++ b/cmd/util/ledger/migrations/contract_checking_migration_test.go @@ -0,0 +1,203 @@ +package migrations + +import ( + "fmt" + "sort" + "testing" + + "github.com/onflow/cadence/runtime/common" + "github.com/onflow/cadence/runtime/interpreter" + coreContracts "github.com/onflow/flow-core-contracts/lib/go/contracts" + "github.com/onflow/flow-core-contracts/lib/go/templates" + "github.com/rs/zerolog" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/onflow/flow-go/cmd/util/ledger/util/registers" + "github.com/onflow/flow-go/fvm/environment" + "github.com/onflow/flow-go/fvm/systemcontracts" + "github.com/onflow/flow-go/model/flow" +) + +func oldExampleTokenCode(fungibleTokenAddress flow.Address) string { + return fmt.Sprintf( + ` + import FungibleToken from 0x%s + + pub contract ExampleToken: FungibleToken { + pub var totalSupply: UFix64 + + pub resource Vault: FungibleToken.Provider, FungibleToken.Receiver, FungibleToken.Balance { + pub var balance: UFix64 + + init(balance: UFix64) { + self.balance = balance + } + + pub fun withdraw(amount: UFix64): @FungibleToken.Vault { + self.balance = self.balance - amount + emit TokensWithdrawn(amount: amount, from: self.owner?.address) + return <-create Vault(balance: amount) + } + + pub fun deposit(from: @FungibleToken.Vault) { + let vault <- from as! @ExampleToken.Vault + self.balance = self.balance + vault.balance + emit TokensDeposited(amount: vault.balance, to: self.owner?.address) + vault.balance = 0.0 + destroy vault + } + + destroy() { + if self.balance > 0.0 { + ExampleToken.totalSupply = ExampleToken.totalSupply - self.balance + } + } + } + + pub fun createEmptyVault(): @Vault { + return <-create Vault(balance: 0.0) + } + + init() { + self.totalSupply = 0.0 + } + } + `, + fungibleTokenAddress.Hex(), + ) +} + +func TestContractCheckingMigrationProgramRecovery(t *testing.T) { + + t.Parallel() + + registersByAccount := registers.NewByAccount() + + // Set up contracts + + const chainID = flow.Testnet + + systemContracts := systemcontracts.SystemContractsForChain(chainID) + + contracts := map[flow.Address]map[string][]byte{} + + addContract := func(address flow.Address, name string, code []byte) { + addressContracts, ok := contracts[address] + if !ok { + addressContracts = map[string][]byte{} + contracts[address] = addressContracts + } + require.Empty(t, addressContracts[name]) + addressContracts[name] = code + } + + addSystemContract := func(systemContract systemcontracts.SystemContract, code []byte) { + addContract(systemContract.Address, systemContract.Name, code) + } + + env := templates.Environment{} + + addSystemContract( + systemContracts.ViewResolver, + coreContracts.ViewResolver(), + ) + env.ViewResolverAddress = systemContracts.ViewResolver.Address.Hex() + + addSystemContract( + systemContracts.Burner, + coreContracts.Burner(), + ) + env.BurnerAddress = systemContracts.Burner.Address.Hex() + + addSystemContract( + systemContracts.FungibleToken, + coreContracts.FungibleToken(env), + ) + + // Use an old version of the ExampleToken contract, + // and "deploy" it at some arbitrary, high (i.e. non-system) address + exampleTokenAddress, err := chainID.Chain().AddressAtIndex(1000) + require.NoError(t, err) + addContract( + exampleTokenAddress, + "ExampleToken", + []byte(oldExampleTokenCode(systemContracts.FungibleToken.Address)), + ) + + for address, addressContracts := range contracts { + + for contractName, code := range addressContracts { + + err := registersByAccount.Set( + string(address[:]), + flow.ContractKey(contractName), + code, + ) + require.NoError(t, err) + } + + contractNames := make([]string, 0, len(addressContracts)) + for contractName := range addressContracts { + contractNames = append(contractNames, contractName) + } + sort.Strings(contractNames) + + encodedContractNames, err := environment.EncodeContractNames(contractNames) + require.NoError(t, err) + + err = registersByAccount.Set( + string(address[:]), + flow.ContractNamesKey, + encodedContractNames, + ) + require.NoError(t, err) + } + + programs := map[common.Location]*interpreter.Program{} + + rwf := &testReportWriterFactory{} + + // Run contract checking migration + + log := zerolog.Nop() + checkingMigration := NewContractCheckingMigration( + log, + rwf, + chainID, + false, + nil, + programs, + ) + + err = checkingMigration(registersByAccount) + require.NoError(t, err) + + reporter := rwf.reportWriters[contractCheckingReporterName] + + assert.Equal(t, + []any{ + contractCheckingSuccess{ + AccountAddress: common.Address(systemContracts.ViewResolver.Address), + ContractName: systemcontracts.ContractNameViewResolver, + Code: string(coreContracts.ViewResolver()), + }, + contractCheckingSuccess{ + AccountAddress: common.Address(systemContracts.Burner.Address), + ContractName: systemcontracts.ContractNameBurner, + Code: string(coreContracts.Burner()), + }, + contractCheckingSuccess{ + AccountAddress: common.Address(systemContracts.FungibleToken.Address), + ContractName: systemcontracts.ContractNameFungibleToken, + Code: string(coreContracts.FungibleToken(env)), + }, + contractCheckingSuccess{ + AccountAddress: common.Address(exampleTokenAddress), + ContractName: "ExampleToken", + Code: oldExampleTokenCode(systemContracts.FungibleToken.Address), + }, + }, + reporter.entries, + ) +} diff --git a/cmd/util/ledger/migrations/migrator_runtime.go b/cmd/util/ledger/migrations/migrator_runtime.go index 3185bdbd2f5..ba72224d520 100644 --- a/cmd/util/ledger/migrations/migrator_runtime.go +++ b/cmd/util/ledger/migrations/migrator_runtime.go @@ -45,6 +45,7 @@ type InterpreterMigrationRuntimeConfig struct { } func (c InterpreterMigrationRuntimeConfig) NewRuntimeInterface( + chainID flow.ChainID, transactionState state.NestedTransactionPreparer, accounts environment.Accounts, ) ( @@ -90,6 +91,7 @@ func (c InterpreterMigrationRuntimeConfig) NewRuntimeInterface( } return util.NewMigrationRuntimeInterface( + chainID, getCodeFunc, getContractNames, getOrLoadProgram, @@ -141,6 +143,7 @@ func NewInterpreterMigrationRuntime( }) runtimeInterface, err := config.NewRuntimeInterface( + chainID, basicMigrationRuntime.TransactionState, basicMigrationRuntime.Accounts, ) diff --git a/cmd/util/ledger/util/migration_runtime_interface.go b/cmd/util/ledger/util/migration_runtime_interface.go index 7de4e871804..f91c4116096 100644 --- a/cmd/util/ledger/util/migration_runtime_interface.go +++ b/cmd/util/ledger/util/migration_runtime_interface.go @@ -9,6 +9,7 @@ import ( "github.com/onflow/cadence/runtime/common" "github.com/onflow/cadence/runtime/interpreter" + "github.com/onflow/flow-go/fvm/environment" "github.com/onflow/flow-go/fvm/storage/derived" "github.com/onflow/flow-go/fvm/storage/state" "github.com/onflow/flow-go/model/flow" @@ -36,6 +37,7 @@ type GerOrLoadProgramListenerFunc func( // It only allows parsing and checking of contracts. type MigrationRuntimeInterface struct { runtime.EmptyRuntimeInterface + chainID flow.ChainID GetContractCodeFunc GetContractCodeFunc GetContractNamesFunc GetContractNamesFunc GetOrLoadProgramFunc GetOrLoadProgramFunc @@ -45,12 +47,14 @@ type MigrationRuntimeInterface struct { var _ runtime.Interface = &MigrationRuntimeInterface{} func NewMigrationRuntimeInterface( + chainID flow.ChainID, getCodeFunc GetContractCodeFunc, getContractNamesFunc GetContractNamesFunc, getOrLoadProgramFunc GetOrLoadProgramFunc, getOrLoadProgramListenerFunc GerOrLoadProgramListenerFunc, ) *MigrationRuntimeInterface { return &MigrationRuntimeInterface{ + chainID: chainID, GetContractCodeFunc: getCodeFunc, GetContractNamesFunc: getContractNamesFunc, GetOrLoadProgramFunc: getOrLoadProgramFunc, @@ -168,8 +172,11 @@ func (m *MigrationRuntimeInterface) GetOrLoadProgram( return getOrLoadProgram(location, load) } -func (m *MigrationRuntimeInterface) RecoverProgram(_ *ast.Program, _ common.Location) (*ast.Program, error) { - return nil, nil +func (m *MigrationRuntimeInterface) RecoverProgram( + program *ast.Program, + location common.Location, +) (*ast.Program, error) { + return environment.RecoverProgram(nil, m.chainID, program, location) } type migrationTransactionPreparer struct { diff --git a/fvm/environment/facade_env.go b/fvm/environment/facade_env.go index be9a5d12f05..f15def32c3b 100644 --- a/fvm/environment/facade_env.go +++ b/fvm/environment/facade_env.go @@ -11,6 +11,7 @@ import ( "github.com/onflow/flow-go/fvm/storage/snapshot" "github.com/onflow/flow-go/fvm/storage/state" "github.com/onflow/flow-go/fvm/tracing" + "github.com/onflow/flow-go/model/flow" ) var _ Environment = &facadeEnvironment{} @@ -336,7 +337,17 @@ func (*facadeEnvironment) GetInterpreterSharedState() *interpreter.SharedState { return nil } -func (env *facadeEnvironment) RecoverProgram(_ *ast.Program, _ common.Location) (*ast.Program, error) { - // NO-OP - return nil, nil +func (env *facadeEnvironment) RecoverProgram(program *ast.Program, location common.Location) (*ast.Program, error) { + // Enabled on all networks but Mainnet, + // until https://github.com/onflow/flips/pull/283 got approved. + if env.chain.ChainID() == flow.Mainnet { + return nil, nil + } + + return RecoverProgram( + env, + env.chain.ChainID(), + program, + location, + ) } diff --git a/fvm/environment/program_recovery.go b/fvm/environment/program_recovery.go new file mode 100644 index 00000000000..9f2485c617d --- /dev/null +++ b/fvm/environment/program_recovery.go @@ -0,0 +1,229 @@ +package environment + +import ( + "fmt" + + "github.com/onflow/cadence/runtime/ast" + "github.com/onflow/cadence/runtime/common" + "github.com/onflow/cadence/runtime/parser" + "github.com/onflow/cadence/runtime/sema" + + "github.com/onflow/flow-go/fvm/systemcontracts" + "github.com/onflow/flow-go/model/flow" +) + +func RecoverProgram( + memoryGauge common.MemoryGauge, + chainID flow.ChainID, + program *ast.Program, + location common.Location, +) ( + *ast.Program, + error, +) { + addressLocation, ok := location.(common.AddressLocation) + if !ok { + return nil, nil + } + + sc := systemcontracts.SystemContractsForChain(chainID) + + fungibleTokenAddress := common.Address(sc.FungibleToken.Address) + + if !isFungibleTokenContract(program, fungibleTokenAddress) { + return nil, nil + } + + contractName := addressLocation.Name + + code := RecoveredFungibleTokenCode(fungibleTokenAddress, contractName) + + return parser.ParseProgram(memoryGauge, []byte(code), parser.Config{}) +} + +func RecoveredFungibleTokenCode(fungibleTokenAddress common.Address, contractName string) string { + return fmt.Sprintf( + //language=Cadence + ` + import FungibleToken from %s + + access(all) + contract %s: FungibleToken { + + access(all) + var totalSupply: UFix64 + + init() { + self.totalSupply = 0.0 + } + + access(all) + view fun getContractViews(resourceType: Type?): [Type] { + panic("getContractViews is not implemented") + } + + access(all) + fun resolveContractView(resourceType: Type?, viewType: Type): AnyStruct? { + panic("resolveContractView is not implemented") + } + + access(all) + resource Vault: FungibleToken.Vault { + + access(all) + var balance: UFix64 + + init(balance: UFix64) { + self.balance = balance + } + + access(FungibleToken.Withdraw) + fun withdraw(amount: UFix64): @{FungibleToken.Vault} { + panic("withdraw is not implemented") + } + + access(all) + view fun isAvailableToWithdraw(amount: UFix64): Bool { + panic("isAvailableToWithdraw is not implemented") + } + + access(all) + fun deposit(from: @{FungibleToken.Vault}) { + panic("deposit is not implemented") + } + + access(all) + fun createEmptyVault(): @{FungibleToken.Vault} { + panic("createEmptyVault is not implemented") + } + + access(all) + view fun getViews(): [Type] { + panic("getViews is not implemented") + } + + access(all) + fun resolveView(_ view: Type): AnyStruct? { + panic("resolveView is not implemented") + } + } + + access(all) + fun createEmptyVault(vaultType: Type): @{FungibleToken.Vault} { + panic("createEmptyVault is not implemented") + } + } + `, + fungibleTokenAddress.HexWithPrefix(), + contractName, + ) +} + +func importsAddressLocation(program *ast.Program, address common.Address, name string) bool { + importDeclarations := program.ImportDeclarations() + + // Check if the location is imported by any import declaration + for _, importDeclaration := range importDeclarations { + + // The import declaration imports from the same address + importedLocation, ok := importDeclaration.Location.(common.AddressLocation) + if !ok || importedLocation.Address != address { + continue + } + + // The import declaration imports all identifiers, so also the location + if len(importDeclaration.Identifiers) == 0 { + return true + } + + // The import declaration imports specific identifiers, so check if the location is imported + for _, identifier := range importDeclaration.Identifiers { + if identifier.Identifier == name { + return true + } + } + } + + return false +} + +func declaresConformanceTo(conformingDeclaration ast.ConformingDeclaration, name string) bool { + for _, conformance := range conformingDeclaration.ConformanceList() { + if conformance.Identifier.Identifier == name { + return true + } + } + + return false +} + +func isNominalType(ty ast.Type, name string) bool { + nominalType, ok := ty.(*ast.NominalType) + return ok && + len(nominalType.NestedIdentifiers) == 0 && + nominalType.Identifier.Identifier == name +} + +const fungibleTokenTypeIdentifier = "FungibleToken" +const fungibleTokenTypeTotalSupplyFieldName = "totalSupply" +const fungibleTokenVaultTypeIdentifier = "Vault" +const fungibleTokenVaultTypeBalanceFieldName = "balance" + +func isFungibleTokenContract(program *ast.Program, fungibleTokenAddress common.Address) bool { + + // Check if the contract imports the FungibleToken contract + if !importsAddressLocation(program, fungibleTokenAddress, fungibleTokenTypeIdentifier) { + return false + } + + contractDeclaration := program.SoleContractDeclaration() + if contractDeclaration == nil { + return false + } + + // Check if the contract implements the FungibleToken interface + if !declaresConformanceTo(contractDeclaration, fungibleTokenTypeIdentifier) { + return false + } + + // Check if the contract has a totalSupply field + totalSupplyFieldDeclaration := getField(contractDeclaration, fungibleTokenTypeTotalSupplyFieldName) + if totalSupplyFieldDeclaration == nil { + return false + } + + // Check if the totalSupply field is of type UFix64 + if !isNominalType(totalSupplyFieldDeclaration.TypeAnnotation.Type, sema.UFix64TypeName) { + return false + } + + // Check if the contract has a Vault resource + + vaultDeclaration := contractDeclaration.Members.CompositesByIdentifier()[fungibleTokenVaultTypeIdentifier] + if vaultDeclaration == nil { + return false + } + + // Check if the Vault resource has a balance field + balanceFieldDeclaration := getField(vaultDeclaration, fungibleTokenVaultTypeBalanceFieldName) + if balanceFieldDeclaration == nil { + return false + } + + // Check if the balance field is of type UFix64 + if !isNominalType(balanceFieldDeclaration.TypeAnnotation.Type, sema.UFix64TypeName) { + return false + } + + return true +} + +func getField(declaration *ast.CompositeDeclaration, name string) *ast.FieldDeclaration { + for _, fieldDeclaration := range declaration.Members.Fields() { + if fieldDeclaration.Identifier.Identifier == name { + return fieldDeclaration + } + } + + return nil +} diff --git a/fvm/evm/stdlib/checking.go b/fvm/evm/stdlib/checking.go index da6b4f2fdee..d4862cb56c6 100644 --- a/fvm/evm/stdlib/checking.go +++ b/fvm/evm/stdlib/checking.go @@ -4,7 +4,6 @@ import ( "fmt" "github.com/onflow/cadence/runtime" - "github.com/onflow/cadence/runtime/ast" "github.com/onflow/cadence/runtime/common" "github.com/onflow/cadence/runtime/interpreter" ) @@ -93,7 +92,3 @@ func (r *checkingInterface) GetOrLoadProgram( func (r *checkingInterface) GetAccountContractCode(location common.AddressLocation) (code []byte, err error) { return r.SystemContractCodes[location], nil } - -func (*checkingInterface) RecoverProgram(_ *ast.Program, _ common.Location) (*ast.Program, error) { - return nil, nil -}