diff --git a/runtime/sema/check_function.go b/runtime/sema/check_function.go index c84275e3af..5c2fe0dea0 100644 --- a/runtime/sema/check_function.go +++ b/runtime/sema/check_function.go @@ -187,6 +187,7 @@ func (checker *Checker) checkFunction( functionActivation.InitializationInfo = initializationInfo if functionBlock != nil { + oldMappedAccess := checker.entitlementMappingInScope if mappedAccess, isMappedAccess := access.(*EntitlementMapAccess); isMappedAccess { checker.entitlementMappingInScope = mappedAccess.Type } @@ -199,7 +200,7 @@ func (checker *Checker) checkFunction( ) }) - checker.entitlementMappingInScope = nil + checker.entitlementMappingInScope = oldMappedAccess if mustExit { returnType := functionType.ReturnTypeAnnotation.Type diff --git a/runtime/sema/check_member_expression.go b/runtime/sema/check_member_expression.go index 68580b9205..fff37b9d5d 100644 --- a/runtime/sema/check_member_expression.go +++ b/runtime/sema/check_member_expression.go @@ -319,16 +319,7 @@ func (checker *Checker) visitMember(expression *ast.MemberExpression) (accessedT shouldSubstituteAuthorization := !member.Access.Equal(resultingAuthorization) if shouldSubstituteAuthorization { - switch ty := resultingType.(type) { - case *FunctionType: - resultingType = NewSimpleFunctionType( - ty.Purity, - ty.Parameters, - ty.ReturnTypeAnnotation.Map(checker.memoryGauge, make(map[*TypeParameter]*TypeParameter), substituteConcreteAuthorization), - ) - default: - resultingType = resultingType.Map(checker.memoryGauge, make(map[*TypeParameter]*TypeParameter), substituteConcreteAuthorization) - } + resultingType = resultingType.Map(checker.memoryGauge, make(map[*TypeParameter]*TypeParameter), substituteConcreteAuthorization) } // Check that the member access is not to a function of resource type diff --git a/runtime/sema/checker.go b/runtime/sema/checker.go index d206d15469..9390cf6c8e 100644 --- a/runtime/sema/checker.go +++ b/runtime/sema/checker.go @@ -1256,20 +1256,22 @@ func (checker *Checker) functionType( parameterList *ast.ParameterList, returnTypeAnnotation *ast.TypeAnnotation, ) *FunctionType { + + oldMappedAccess := checker.entitlementMappingInScope + if mapAccess, isMapAccess := access.(*EntitlementMapAccess); isMapAccess { + checker.entitlementMappingInScope = mapAccess.Type + } else { + checker.entitlementMappingInScope = nil + } + convertedParameters := checker.parameters(parameterList) convertedReturnTypeAnnotation := VoidTypeAnnotation if returnTypeAnnotation != nil { - // to allow entitlement mapping types to be used in the return annotation only of - // a mapped accessor function, we introduce a "variable" into the typing scope while - // checking the return - if mapAccess, isMapAccess := access.(*EntitlementMapAccess); isMapAccess { - checker.entitlementMappingInScope = mapAccess.Type - } convertedReturnTypeAnnotation = checker.ConvertTypeAnnotation(returnTypeAnnotation) - checker.entitlementMappingInScope = nil } + checker.entitlementMappingInScope = oldMappedAccess return &FunctionType{ Purity: PurityFromAnnotation(purity), diff --git a/runtime/tests/checker/entitlements_test.go b/runtime/tests/checker/entitlements_test.go index 4ad05c99e5..9a7bc9d4c5 100644 --- a/runtime/tests/checker/entitlements_test.go +++ b/runtime/tests/checker/entitlements_test.go @@ -1461,19 +1461,78 @@ func TestCheckBasicEntitlementMappingAccess(t *testing.T) { require.IsType(t, &sema.InvalidMappedEntitlementMemberError{}, errs[0]) }) + t.Run("accessor function with mapped ref arg", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheck(t, ` + entitlement E + entitlement F + entitlement G + entitlement H + entitlement mapping M { + E -> F + G -> H + } + struct interface S { + access(M) fun foo(_ arg: auth(M) &Int): auth(M) &Int + } + + fun foo(s: auth(E) &{S}) { + s.foo(&1 as auth(F) &Int) + } + `) + + assert.NoError(t, err) + }) + t.Run("accessor function with invalid mapped ref arg", func(t *testing.T) { t.Parallel() _, err := ParseAndCheck(t, ` - entitlement mapping M {} + entitlement E + entitlement F + entitlement G + entitlement H + entitlement mapping M { + E -> F + G -> H + } struct interface S { - access(M) fun foo(arg: auth(M) &Int): auth(M) &Int + access(M) fun foo(_ arg: auth(M) &Int): auth(M) &Int + } + + fun foo(s: auth(E) &{S}) { + s.foo(&1 as auth(H) &Int) } `) errs := RequireCheckerErrors(t, err, 1) - require.IsType(t, &sema.InvalidMappedAuthorizationOutsideOfFieldError{}, errs[0]) + require.IsType(t, &sema.TypeMismatchError{}, errs[0]) + }) + + t.Run("accessor function with full mapped ref arg", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheck(t, ` + entitlement E + entitlement F + entitlement G + entitlement H + entitlement mapping M { + E -> F + G -> H + } + struct interface S { + access(M) fun foo(_ arg: auth(M) &Int): auth(M) &Int + } + + fun foo(s: {S}) { + s.foo(&1 as auth(F, H) &Int) + } + `) + + assert.NoError(t, err) }) t.Run("multiple mappings conjunction", func(t *testing.T) { @@ -7736,4 +7795,42 @@ func TestCheckEntitlementMappingComplexFields(t *testing.T) { require.IsType(t, &sema.InvalidAccessError{}, errors[1]) }) + t.Run("lambda escape", func(t *testing.T) { + + t.Parallel() + + _, err := ParseAndCheck(t, ` + entitlement Inner1 + entitlement Inner2 + entitlement Outer1 + entitlement Outer2 + + entitlement mapping MyMap { + Outer1 -> Inner1 + Outer2 -> Inner2 + } + struct InnerObj { + access(Inner1) fun first(): Int{ return 9999 } + access(Inner2) fun second(): Int{ return 8888 } + } + + struct FuncGenerator { + access(MyMap) fun generate(): auth(MyMap) &Int? { + // cannot declare lambda with mapped entitlement + fun innerFunc(_ param: auth(MyMap) &InnerObj): Int { + return 123; + } + var f = innerFunc; // will fail if we're called via a reference + return nil; + } + } + + fun test() { + (&FuncGenerator() as auth(Outer1) &FuncGenerator).generate() + } + `) + + errors := RequireCheckerErrors(t, err, 1) + require.IsType(t, &sema.InvalidMappedAuthorizationOutsideOfFieldError{}, errors[0]) + }) } diff --git a/runtime/tests/interpreter/entitlements_test.go b/runtime/tests/interpreter/entitlements_test.go index f950245eb0..ad6dfeda60 100644 --- a/runtime/tests/interpreter/entitlements_test.go +++ b/runtime/tests/interpreter/entitlements_test.go @@ -2079,6 +2079,83 @@ func TestInterpretEntitlementMappingAccessors(t *testing.T) { ).Equal(value.(*interpreter.EphemeralReferenceValue).Authorization), ) }) + + t.Run("accessor function with mapped ref arg", func(t *testing.T) { + t.Parallel() + + inter := parseCheckAndInterpret(t, ` + entitlement E + entitlement F + entitlement G + entitlement H + entitlement mapping M { + E -> F + G -> H + } + struct S { + access(M) fun foo(_ arg: auth(M) &Int): auth(M) &Int { + return arg + } + } + + fun test(): auth(F) &Int { + let s = S() + let sRef = &s as auth(E) &S + return sRef.foo(&1) + } + `) + + value, err := inter.Invoke("test") + require.NoError(t, err) + + require.True( + t, + interpreter.NewEntitlementSetAuthorization( + nil, + func() []common.TypeID { return []common.TypeID{"S.test.F"} }, + 1, + sema.Conjunction, + ).Equal(value.(*interpreter.EphemeralReferenceValue).Authorization), + ) + }) + + t.Run("accessor function with full mapped ref arg", func(t *testing.T) { + t.Parallel() + + inter := parseCheckAndInterpret(t, ` + entitlement E + entitlement F + entitlement G + entitlement H + entitlement mapping M { + E -> F + G -> H + } + struct S { + access(M) fun foo(_ arg: auth(M) &Int): auth(M) &Int { + return arg + } + } + + fun test(): auth(F, H) &Int { + let s = S() + return s.foo(&1) + } + `) + + value, err := inter.Invoke("test") + require.NoError(t, err) + + require.True( + t, + interpreter.NewEntitlementSetAuthorization( + nil, + func() []common.TypeID { return []common.TypeID{"S.test.F", "S.test.H"} }, + 2, + sema.Conjunction, + ).Equal(value.(*interpreter.EphemeralReferenceValue).Authorization), + ) + }) } func TestInterpretEntitledAttachments(t *testing.T) { @@ -3419,49 +3496,4 @@ func TestInterpretEntitlementMappingComplexFields(t *testing.T) { value, ) }) - - t.Run("lambda escape", func(t *testing.T) { - - t.Parallel() - - inter := parseCheckAndInterpret(t, ` - entitlement Inner1 - entitlement Inner2 - entitlement Outer1 - entitlement Outer2 - - entitlement mapping MyMap { - Outer1 -> Inner1 - Outer2 -> Inner2 - } - struct InnerObj { - access(Inner1) fun first(): Int{ return 9999 } - access(Inner2) fun second(): Int{ return 8888 } - } - - struct FuncGenerator { - access(MyMap) fun generate(): auth(MyMap) &Int? { - fun innerFunc(_ param: auth(MyMap) &InnerObj): Int { - return 123; - } - var f = innerFunc; // will fail if we're called via a reference - return nil; - } - } - - fun test() { - (&FuncGenerator() as auth(Outer1) &FuncGenerator).generate() - } - `) - - value, err := inter.Invoke("test") - require.NoError(t, err) - - AssertValuesEqual( - t, - inter, - interpreter.NewUnmeteredIntValueFromInt64(9999+8888), - value, - ) - }) }