Skip to content

Commit

Permalink
Merge pull request onflow#3166 from onflow/bastian/fix-legacy-interse…
Browse files Browse the repository at this point in the history
…ction-type

Fix intersection type's legacy type getting converted to intersection type
  • Loading branch information
turbolent authored Mar 11, 2024
2 parents d073ef2 + 6ce5874 commit 4e712e7
Show file tree
Hide file tree
Showing 5 changed files with 242 additions and 26 deletions.
191 changes: 190 additions & 1 deletion migrations/entitlements/migration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3187,7 +3187,10 @@ func TestRehash(t *testing.T) {

storage, inter := newStorageAndInterpreter(t)

inter.SharedState.Config.CompositeTypeHandler = func(location common.Location, typeID interpreter.TypeID) *sema.CompositeType {
inter.SharedState.Config.CompositeTypeHandler = func(
location common.Location,
typeID interpreter.TypeID,
) *sema.CompositeType {

compositeType := &sema.CompositeType{
Location: fooAddressLocation,
Expand Down Expand Up @@ -3293,3 +3296,189 @@ func TestRehash(t *testing.T) {
)
})
}

func TestIntersectionTypeWithIntersectionLegacyType(t *testing.T) {

t.Parallel()

testAddress := common.Address{0x42}

const interface1QualifiedIdentifier = "SI1"
interfaceType1 := interpreter.NewInterfaceStaticType(
nil,
utils.TestLocation,
interface1QualifiedIdentifier,
utils.TestLocation.TypeID(nil, interface1QualifiedIdentifier),
)

const interface2QualifiedIdentifier = "SI2"
interfaceType2 := interpreter.NewInterfaceStaticType(
nil,
utils.TestLocation,
interface2QualifiedIdentifier,
utils.TestLocation.TypeID(nil, interface2QualifiedIdentifier),
)

ledger := NewTestLedger(nil, nil)

storageMapKey := interpreter.StringStorageMapKey("dict")

newStorageAndInterpreter := func(t *testing.T) (*runtime.Storage, *interpreter.Interpreter) {
storage := runtime.NewStorage(ledger, nil)
inter, err := interpreter.NewInterpreter(
nil,
utils.TestLocation,
&interpreter.Config{
Storage: storage,
// NOTE: disabled, because encoded and decoded values are expected to not match
AtreeValueValidationEnabled: false,
AtreeStorageValidationEnabled: true,
},
)
require.NoError(t, err)

return storage, inter
}

t.Run("prepare", func(t *testing.T) {

storage, inter := newStorageAndInterpreter(t)

expectedIntersection := interpreter.NewIntersectionStaticType(
nil,
[]*interpreter.InterfaceStaticType{
interfaceType1,
},
)
// NOTE: setting the legacy type to an intersection type
expectedIntersection.LegacyType = interpreter.NewIntersectionStaticType(
nil,
[]*interpreter.InterfaceStaticType{
interfaceType2,
},
)

storedValue := interpreter.NewTypeValue(
nil,
interpreter.NewReferenceStaticType(
nil,
interpreter.UnauthorizedAccess,
expectedIntersection,
),
)

storageMap := storage.GetStorageMap(
testAddress,
common.PathDomainStorage.Identifier(),
true,
)

storageMap.SetValue(inter,
storageMapKey,
storedValue,
)

err := storage.Commit(inter, false)
require.NoError(t, err)

err = storage.CheckHealth()
require.NoError(t, err)
})

t.Run("migrate", func(t *testing.T) {

storage, inter := newStorageAndInterpreter(t)

inter.SharedState.Config.InterfaceTypeHandler = func(
location common.Location,
typeID interpreter.TypeID,
) *sema.InterfaceType {

_, qualifiedIdentifier, err := common.DecodeTypeID(nil, string(typeID))
require.NoError(t, err)

return &sema.InterfaceType{
Location: TestLocation,
Identifier: qualifiedIdentifier,
CompositeKind: common.CompositeKindStructure,
Members: &sema.StringMemberOrderedMap{},
}
}

migration := migrations.NewStorageMigration(inter, storage)

reporter := newTestReporter()

migration.Migrate(
&migrations.AddressSliceIterator{
Addresses: []common.Address{
testAddress,
},
},
migration.NewValueMigrationsPathMigrator(
reporter,
NewEntitlementsMigration(inter),
),
)

err := migration.Commit()
require.NoError(t, err)

// Assert

err = storage.CheckHealth()
require.NoError(t, err)

assert.Empty(t, reporter.errors)

require.Equal(t,
map[struct {
interpreter.StorageKey
interpreter.StorageMapKey
}]struct{}{
{
StorageKey: interpreter.StorageKey{
Address: testAddress,
Key: common.PathDomainStorage.Identifier(),
},
StorageMapKey: storageMapKey,
}: {},
},
reporter.migrated,
)
})

t.Run("load", func(t *testing.T) {

storage, inter := newStorageAndInterpreter(t)

err := storage.CheckHealth()
require.NoError(t, err)

storageMap := storage.GetStorageMap(
testAddress,
common.PathDomainStorage.Identifier(),
false,
)

storedValue := storageMap.ReadValue(inter, storageMapKey)

require.IsType(t, interpreter.TypeValue{}, storedValue)

typeValue := storedValue.(interpreter.TypeValue)

expectedType := interpreter.NewReferenceStaticType(
nil,
interpreter.UnauthorizedAccess,
interpreter.NewIntersectionStaticType(
nil,
[]*interpreter.InterfaceStaticType{
// NOTE: this is the legacy type
interfaceType2,
},
),
)

require.Equal(t, expectedType, typeValue.Type)
})
}
32 changes: 16 additions & 16 deletions migrations/statictypes/statictype_migration.go
Original file line number Diff line number Diff line change
Expand Up @@ -247,30 +247,31 @@ func (m *StaticTypeMigration) maybeConvertStaticType(staticType, parentType inte

legacyType := rewrittenIntersectionType.LegacyType

var mergedIntersections bool

var convertedLegacyType interpreter.StaticType
if legacyType != nil {
convertedLegacyType = m.maybeConvertStaticType(legacyType, rewrittenIntersectionType)
switch ty := convertedLegacyType.(type) {
switch convertedLegacyType.(type) {
case nil,
*interpreter.CompositeStaticType,
interpreter.PrimitiveStaticType:
// valid
break

case *interpreter.IntersectionStaticType:
// If the legacy type was converted to an intersection type,
// then merge it into the resulting intersection type

legacyType = nil
convertedLegacyType = nil

convertedInterfaceTypes = append(
convertedInterfaceTypes,
ty.Types...,
)
mergedIntersections = true
// also valid, temporarily:
//
// Given an intersection type T{Us}, where T is a legacy type, and Us are interface types,
// and given T is converted to intersection type V,
// then the resulting type is V{Us} (e.g. when V is {Ws}, {Ws}{Us}).
//
// The resulting type is expected to be ("temporarily") invalid.
// The entitlements migrations will handle such cases,
// i.e. rewrite the type to a valid type (V/{Ws}).
//
// It is important to not merge the intersection types, e.g. into {Us, Ws},
// to ensure that the entitlement migration does not infer entitlements for this type,
// which would incorrectly also add entitlements for the legacy type (which was restricted).
break

default:
panic(fmt.Errorf(
Expand All @@ -290,8 +291,7 @@ func (m *StaticTypeMigration) maybeConvertStaticType(staticType, parentType inte
// even if the interface types in the set have not changed.
if len(rewrittenIntersectionType.Types) >= 2 ||
convertedInterfaceType ||
convertedLegacyType != nil ||
mergedIntersections {
convertedLegacyType != nil {

result := interpreter.NewIntersectionStaticType(nil, convertedInterfaceTypes)

Expand Down
27 changes: 19 additions & 8 deletions migrations/statictypes/statictype_migration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -605,7 +605,7 @@ func TestStaticTypeMigration(t *testing.T) {

})

t.Run("merge converted legacy type when intersection", func(t *testing.T) {
t.Run("legacy type gets converted intersection", func(t *testing.T) {

t.Parallel()

Expand Down Expand Up @@ -674,18 +674,29 @@ func TestStaticTypeMigration(t *testing.T) {
true,
)

// NOTE: the expected type {S2}{S1} is expected to be ("temporarily") invalid.
// The entitlements migrations will handle such cases, i.e. rewrite the type to a valid type ({S2}).
// This is important to ensure that the entitlement migration does not infer entitlements for {S1, S2}.

expectedIntersection := interpreter.NewIntersectionStaticType(
nil,
[]*interpreter.InterfaceStaticType{
interfaceType1,
},
)
expectedIntersection.LegacyType = interpreter.NewIntersectionStaticType(
nil,
[]*interpreter.InterfaceStaticType{
interfaceType2,
},
)

expected := interpreter.NewTypeValue(
nil,
interpreter.NewReferenceStaticType(
nil,
interpreter.UnauthorizedAccess,
interpreter.NewIntersectionStaticType(
nil,
[]*interpreter.InterfaceStaticType{
interfaceType1,
interfaceType2,
},
),
expectedIntersection,
),
)

Expand Down
2 changes: 2 additions & 0 deletions runtime/interpreter/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ type Config struct {
UUIDHandler UUIDHandlerFunc
// CompositeTypeHandler is used to load composite types
CompositeTypeHandler CompositeTypeHandlerFunc
// InterfaceTypeHandler is used to load interface types
InterfaceTypeHandler InterfaceTypeHandlerFunc
// CompositeValueFunctionsHandler is used to load composite value functions
CompositeValueFunctionsHandler CompositeValueFunctionsHandlerFunc
BaseActivationHandler func(location common.Location) *VariableActivation
Expand Down
16 changes: 15 additions & 1 deletion runtime/interpreter/interpreter.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,9 @@ type UUIDHandlerFunc func() (uint64, error)
// CompositeTypeHandlerFunc is a function that loads composite types.
type CompositeTypeHandlerFunc func(location common.Location, typeID TypeID) *sema.CompositeType

// InterfaceTypeHandlerFunc is a function that loads interface types.
type InterfaceTypeHandlerFunc func(location common.Location, typeID TypeID) *sema.InterfaceType

// CompositeValueFunctionsHandlerFunc is a function that loads composite value functions.
type CompositeValueFunctionsHandlerFunc func(
inter *Interpreter,
Expand Down Expand Up @@ -4767,7 +4770,18 @@ func (interpreter *Interpreter) GetInterfaceType(
typeID TypeID,
) (*sema.InterfaceType, error) {
if location == nil {
return nil, InterfaceMissingLocationError{QualifiedIdentifier: qualifiedIdentifier}
return nil, InterfaceMissingLocationError{
QualifiedIdentifier: qualifiedIdentifier,
}
}

config := interpreter.SharedState.Config
interfaceTypeHandler := config.InterfaceTypeHandler
if interfaceTypeHandler != nil {
interfaceType := interfaceTypeHandler(location, typeID)
if interfaceType != nil {
return interfaceType, nil
}
}

elaboration := interpreter.getElaboration(location)
Expand Down

0 comments on commit 4e712e7

Please sign in to comment.