From 3238e578c9b612e670f37a275610ab6f5deb52cf Mon Sep 17 00:00:00 2001 From: Daniel Sainati Date: Tue, 17 Oct 2023 11:56:57 -0400 Subject: [PATCH 1/6] properly substitute types in nested mapped references --- runtime/sema/check_member_expression.go | 4 +- runtime/tests/checker/entitlements_test.go | 257 +++++++++++++++++++++ 2 files changed, 259 insertions(+), 2 deletions(-) diff --git a/runtime/sema/check_member_expression.go b/runtime/sema/check_member_expression.go index 4d1d044a0d..e43d651ae8 100644 --- a/runtime/sema/check_member_expression.go +++ b/runtime/sema/check_member_expression.go @@ -330,10 +330,10 @@ func (checker *Checker) visitMember(expression *ast.MemberExpression) (accessedT resultingType = NewSimpleFunctionType( ty.Purity, ty.Parameters, - NewTypeAnnotation(substituteConcreteAuthorization(ty.ReturnTypeAnnotation.Type)), + ty.ReturnTypeAnnotation.Map(checker.memoryGauge, make(map[*TypeParameter]*TypeParameter), substituteConcreteAuthorization), ) default: - resultingType = substituteConcreteAuthorization(resultingType) + resultingType = resultingType.Map(checker.memoryGauge, make(map[*TypeParameter]*TypeParameter), substituteConcreteAuthorization) } } diff --git a/runtime/tests/checker/entitlements_test.go b/runtime/tests/checker/entitlements_test.go index 727c0f613d..3ae52fead5 100644 --- a/runtime/tests/checker/entitlements_test.go +++ b/runtime/tests/checker/entitlements_test.go @@ -7440,3 +7440,260 @@ func TestInterpretMappingEscalation(t *testing.T) { }) } + +func TestCheckEntitlementMappingComplexFields(t *testing.T) { + + t.Parallel() + + t.Run("array mapped field", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheck(t, ` + entitlement Inner1 + entitlement Inner2 + entitlement Outer1 + entitlement Outer2 + + entitlement mapping MyMap { + Outer1 -> Inner1 + Outer2 -> Inner2 + } + struct InnerObj { + access(Inner1) fun first(): Int{ return 9999 } + access(Inner2) fun second(): Int{ return 8888 } + } + + struct Carrier{ + access(MyMap) let arr: [auth(MyMap) &InnerObj] + init() { + self.arr = [&InnerObj()] + } + } + + fun foo() { + let x: auth(Inner1, Inner2) &InnerObj = Carrier().arr[0] + x.first() + x.second() + } + `) + + require.NoError(t, err) + }) + + t.Run("array mapped field via reference", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheck(t, ` + entitlement Inner1 + entitlement Inner2 + entitlement Outer1 + entitlement Outer2 + + entitlement mapping MyMap { + Outer1 -> Inner1 + Outer2 -> Inner2 + } + struct InnerObj { + access(Inner1) fun first(): Int{ return 9999 } + access(Inner2) fun second(): Int{ return 8888 } + } + + struct Carrier{ + access(MyMap) let arr: [auth(MyMap) &InnerObj] + init() { + self.arr = [&InnerObj()] + } + } + + fun foo() { + let x = (&Carrier() as auth(Outer1) &Carrier).arr[0] + x.first() // ok + x.second() // fails + } + `) + + errors := RequireCheckerErrors(t, err, 1) + require.IsType(t, &sema.InvalidAccessError{}, errors[0]) + }) + + t.Run("array mapped function", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheck(t, ` + entitlement Inner1 + entitlement Inner2 + entitlement Outer1 + entitlement Outer2 + + entitlement mapping MyMap { + Outer1 -> Inner1 + Outer2 -> Inner2 + } + struct InnerObj { + access(Inner1) fun first(): Int{ return 9999 } + access(Inner2) fun second(): Int{ return 8888 } + } + + struct Carrier{ + access(MyMap) fun getArr(): [auth(MyMap) &InnerObj] { + return [&InnerObj()] + } + } + + + `) + + errors := RequireCheckerErrors(t, err, 1) + require.IsType(t, &sema.InvalidMappedEntitlementMemberError{}, errors[0]) + }) + + t.Run("dictionary mapped field", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheck(t, ` + entitlement Inner1 + entitlement Inner2 + entitlement Outer1 + entitlement Outer2 + + entitlement mapping MyMap { + Outer1 -> Inner1 + Outer2 -> Inner2 + } + struct InnerObj { + access(Inner1) fun first(): Int{ return 9999 } + access(Inner2) fun second(): Int{ return 8888 } + } + + struct Carrier{ + access(MyMap) let dict: {String: auth(MyMap) &InnerObj} + init() { + self.dict = {"": &InnerObj()} + } + } + + fun foo() { + let x: auth(Inner1, Inner2) &InnerObj = Carrier().dict[""]! + x.first() + x.second() + } + `) + + require.NoError(t, err) + }) + + t.Run("dictionary mapped field via reference", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheck(t, ` + entitlement Inner1 + entitlement Inner2 + entitlement Outer1 + entitlement Outer2 + + entitlement mapping MyMap { + Outer1 -> Inner1 + Outer2 -> Inner2 + } + struct InnerObj { + access(Inner1) fun first(): Int{ return 9999 } + access(Inner2) fun second(): Int{ return 8888 } + } + + struct Carrier{ + access(MyMap) let dict: {String: auth(MyMap) &InnerObj} + init() { + self.dict = {"": &InnerObj()} + } + } + + fun foo() { + let x = (&Carrier() as auth(Outer1) &Carrier).dict[""]! + x.first() // ok + x.second() // fails + } + `) + + errors := RequireCheckerErrors(t, err, 1) + require.IsType(t, &sema.InvalidAccessError{}, errors[0]) + }) + + t.Run("array mapped function", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheck(t, ` + entitlement Inner1 + entitlement Inner2 + entitlement Outer1 + entitlement Outer2 + + entitlement mapping MyMap { + Outer1 -> Inner1 + Outer2 -> Inner2 + } + struct InnerObj { + access(Inner1) fun first(): Int{ return 9999 } + access(Inner2) fun second(): Int{ return 8888 } + } + + struct Carrier{ + access(MyMap) fun getDict(): {String: auth(MyMap) &InnerObj} { + return {"": &InnerObj()} + } + } + + + `) + + errors := RequireCheckerErrors(t, err, 1) + require.IsType(t, &sema.InvalidMappedEntitlementMemberError{}, errors[0]) + }) + + t.Run("lambda mapped array field", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheck(t, ` + entitlement Inner1 + entitlement Inner2 + entitlement Outer1 + entitlement Outer2 + + entitlement mapping MyMap { + Outer1 -> Inner1 + Outer2 -> Inner2 + } + struct InnerObj { + access(Inner1) fun first(): Int{ return 9999 } + access(Inner2) fun second(): Int{ return 8888 } + } + + struct Carrier{ + access(MyMap) let fnArr: [fun(auth(MyMap) &InnerObj): auth(MyMap) &InnerObj] + init() { + let innerObj = &InnerObj() as auth(Inner1, Inner2) &InnerObj + self.fnArr = [fun(_ x: &InnerObj): auth(Inner1, Inner2) &InnerObj { + return innerObj + }] + } + + } + + fun foo() { + let x = (&Carrier() as auth(Outer1) &Carrier).fnArr[0] + x(&InnerObj()).first() // ok + + x(&InnerObj() as auth(Inner1) &InnerObj).first() // ok + + x(&InnerObj() as auth(Inner2) &InnerObj).first() // mismatch + + x(&InnerObj()).second() // fails + } + + `) + + errors := RequireCheckerErrors(t, err, 2) + require.IsType(t, &sema.TypeMismatchError{}, errors[0]) + require.IsType(t, &sema.InvalidAccessError{}, errors[1]) + }) + +} From 81fce7eac7a04c69261dd9d8f0cafff8b44f6a85 Mon Sep 17 00:00:00 2001 From: Daniel Sainati Date: Tue, 17 Oct 2023 12:23:28 -0400 Subject: [PATCH 2/6] add more tests --- runtime/sema/check_member_expression.go | 6 - runtime/tests/checker/entitlements_test.go | 84 ++++++++ .../tests/interpreter/entitlements_test.go | 192 ++++++++++++++++++ 3 files changed, 276 insertions(+), 6 deletions(-) diff --git a/runtime/sema/check_member_expression.go b/runtime/sema/check_member_expression.go index e43d651ae8..68580b9205 100644 --- a/runtime/sema/check_member_expression.go +++ b/runtime/sema/check_member_expression.go @@ -312,12 +312,6 @@ func (checker *Checker) visitMember(expression *ast.MemberExpression) (accessedT switch ty := resultingType.(type) { case *ReferenceType: return NewReferenceType(checker.memoryGauge, resultingAuthorization, ty.Type) - case *OptionalType: - switch innerTy := ty.Type.(type) { - case *ReferenceType: - return NewOptionalType(checker.memoryGauge, - NewReferenceType(checker.memoryGauge, resultingAuthorization, innerTy.Type)) - } } return resultingType } diff --git a/runtime/tests/checker/entitlements_test.go b/runtime/tests/checker/entitlements_test.go index 3ae52fead5..9e04558f18 100644 --- a/runtime/tests/checker/entitlements_test.go +++ b/runtime/tests/checker/entitlements_test.go @@ -7547,6 +7547,90 @@ func TestCheckEntitlementMappingComplexFields(t *testing.T) { require.IsType(t, &sema.InvalidMappedEntitlementMemberError{}, errors[0]) }) + t.Run("array mapped field escape", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheck(t, ` + entitlement Inner1 + entitlement Inner2 + entitlement Outer1 + entitlement Outer2 + + entitlement mapping MyMap { + Outer1 -> Inner1 + Outer2 -> Inner2 + } + struct InnerObj { + access(Inner1) fun first(): Int{ return 9999 } + access(Inner2) fun second(): Int{ return 8888 } + } + + struct Carrier{ + access(MyMap) let arr: [auth(MyMap) &InnerObj] + init() { + self.arr = [&InnerObj()] + } + } + + struct TranslatorStruct { + access(self) var carrier: &Carrier; + access(MyMap) fun translate(): auth(MyMap) &InnerObj { + return self.carrier.arr[0] // type mismatch + } + init(_ carrier: &Carrier) { + self.carrier = carrier + } + } + `) + + errors := RequireCheckerErrors(t, err, 1) + require.IsType(t, &sema.TypeMismatchError{}, errors[0]) + }) + + t.Run("lambda escape", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheck(t, ` + entitlement Inner1 + entitlement Inner2 + entitlement Outer1 + entitlement Outer2 + + entitlement mapping MyMap { + Outer1 -> Inner1 + Outer2 -> Inner2 + } + struct InnerObj { + access(Inner1) fun first(): Int{ return 9999 } + access(Inner2) fun second(): Int{ return 8888 } + } + + struct Carrier{ + access(MyMap) let arr: [auth(MyMap) &InnerObj] + init() { + self.arr = [&InnerObj()] + } + } + + struct FuncGenerator { + access(MyMap) fun generate(): auth(MyMap) &Int? { + fun innerFunc(_ param: auth(MyMap) &InnerObj): Int { + return 123; + } + var f = innerFunc; // will fail if we're called via a reference + return nil; + } + } + + fun foo() { + (&FuncGenerator() as auth(Outer1) &FuncGenerator).generate() + } + `) + + errors := RequireCheckerErrors(t, err, 1) + require.IsType(t, &sema.TypeMismatchError{}, errors[0]) + }) + t.Run("dictionary mapped field", func(t *testing.T) { t.Parallel() diff --git a/runtime/tests/interpreter/entitlements_test.go b/runtime/tests/interpreter/entitlements_test.go index 42e7c57f60..2befbf6d92 100644 --- a/runtime/tests/interpreter/entitlements_test.go +++ b/runtime/tests/interpreter/entitlements_test.go @@ -3280,3 +3280,195 @@ func TestInterpretMappingInclude(t *testing.T) { ) }) } + +func TestInterpretEntitlementMappingComplexFields(t *testing.T) { + t.Parallel() + + t.Run("array field", func(t *testing.T) { + + t.Parallel() + + inter := parseCheckAndInterpret(t, ` + entitlement Inner1 + entitlement Inner2 + entitlement Outer1 + entitlement Outer2 + + entitlement mapping MyMap { + Outer1 -> Inner1 + Outer2 -> Inner2 + } + struct InnerObj { + access(Inner1) fun first(): Int{ return 9999 } + access(Inner2) fun second(): Int{ return 8888 } + } + + struct Carrier{ + access(MyMap) let arr: [auth(MyMap) &InnerObj] + init() { + self.arr = [&InnerObj()] + } + } + + fun test(): Int { + let x = Carrier().arr[0] + return x.first() + x.second() + } + `) + + value, err := inter.Invoke("test") + require.NoError(t, err) + + AssertValuesEqual( + t, + inter, + interpreter.NewUnmeteredIntValueFromInt64(9999+8888), + value, + ) + }) + + t.Run("dictionary field", func(t *testing.T) { + + t.Parallel() + + inter := parseCheckAndInterpret(t, ` + entitlement Inner1 + entitlement Inner2 + entitlement Outer1 + entitlement Outer2 + + entitlement mapping MyMap { + Outer1 -> Inner1 + Outer2 -> Inner2 + } + struct InnerObj { + access(Inner1) fun first(): Int{ return 9999 } + access(Inner2) fun second(): Int{ return 8888 } + } + + struct Carrier{ + access(MyMap) let dict: {String: auth(MyMap) &InnerObj} + init() { + self.dict = {"": &InnerObj()} + } + } + + fun test(): Int { + let x = Carrier().dict[""]! + return x.first() + x.second() + } + `) + + value, err := inter.Invoke("test") + require.NoError(t, err) + + AssertValuesEqual( + t, + inter, + interpreter.NewUnmeteredIntValueFromInt64(9999+8888), + value, + ) + }) + + t.Run("lambda array field", func(t *testing.T) { + + t.Parallel() + + inter := parseCheckAndInterpret(t, ` + entitlement Inner1 + entitlement Inner2 + entitlement Outer1 + entitlement Outer2 + + entitlement mapping MyMap { + Outer1 -> Inner1 + Outer2 -> Inner2 + } + struct InnerObj { + access(Inner1) fun first(): Int{ return 9999 } + access(Inner2) fun second(): Int{ return 8888 } + } + + struct Carrier{ + access(MyMap) let fnArr: [fun(auth(MyMap) &InnerObj): auth(MyMap) &InnerObj] + init() { + let innerObj = &InnerObj() as auth(Inner1, Inner2) &InnerObj + self.fnArr = [fun(_ x: &InnerObj): auth(Inner1, Inner2) &InnerObj { + return innerObj + }] + } + + } + + fun test(): Int { + let carrier = Carrier() + let ref1 = &carrier as auth(Outer1) &Carrier + let ref2 = &carrier as auth(Outer2) &Carrier + return ref1.fnArr[0](&InnerObj() as auth(Inner1) &InnerObj).first() + + ref2.fnArr[0](&InnerObj() as auth(Inner2) &InnerObj).second() + } + `) + + value, err := inter.Invoke("test") + require.NoError(t, err) + + AssertValuesEqual( + t, + inter, + interpreter.NewUnmeteredIntValueFromInt64(9999+8888), + value, + ) + }) + + t.Run("lambda escape", func(t *testing.T) { + + t.Parallel() + + inter := parseCheckAndInterpret(t, ` + entitlement Inner1 + entitlement Inner2 + entitlement Outer1 + entitlement Outer2 + + entitlement mapping MyMap { + Outer1 -> Inner1 + Outer2 -> Inner2 + } + struct InnerObj { + access(Inner1) fun first(): Int{ return 9999 } + access(Inner2) fun second(): Int{ return 8888 } + } + + struct Carrier{ + access(MyMap) let arr: [auth(MyMap) &InnerObj] + init() { + self.arr = [&InnerObj()] + } + } + + struct FuncGenerator { + access(MyMap) fun generate(): auth(MyMap) &Int? { + fun innerFunc(_ param: auth(MyMap) &InnerObj): Int { + return 123; + } + var f = innerFunc; // will fail if we're called via a reference + return nil; + } + } + + fun test() { + (&FuncGenerator() as auth(Outer1) &FuncGenerator).generate() + } + `) + + value, err := inter.Invoke("test") + require.NoError(t, err) + + AssertValuesEqual( + t, + inter, + interpreter.NewUnmeteredIntValueFromInt64(9999+8888), + value, + ) + }) +} From 3e0f26a16c45ffdf3f1b615471d58a3f74e66d9d Mon Sep 17 00:00:00 2001 From: Daniel Sainati Date: Tue, 17 Oct 2023 12:46:18 -0400 Subject: [PATCH 3/6] add test for lambda escape --- runtime/tests/interpreter/entitlements_test.go | 7 ------- 1 file changed, 7 deletions(-) diff --git a/runtime/tests/interpreter/entitlements_test.go b/runtime/tests/interpreter/entitlements_test.go index 2befbf6d92..f950245eb0 100644 --- a/runtime/tests/interpreter/entitlements_test.go +++ b/runtime/tests/interpreter/entitlements_test.go @@ -3439,13 +3439,6 @@ func TestInterpretEntitlementMappingComplexFields(t *testing.T) { access(Inner2) fun second(): Int{ return 8888 } } - struct Carrier{ - access(MyMap) let arr: [auth(MyMap) &InnerObj] - init() { - self.arr = [&InnerObj()] - } - } - struct FuncGenerator { access(MyMap) fun generate(): auth(MyMap) &Int? { fun innerFunc(_ param: auth(MyMap) &InnerObj): Int { From 346f71adf80a6f04f3d3636bc2ed66edb709d288 Mon Sep 17 00:00:00 2001 From: Daniel Sainati Date: Tue, 17 Oct 2023 13:06:58 -0400 Subject: [PATCH 4/6] add defensive check --- runtime/interpreter/errors.go | 17 +++++++++ runtime/interpreter/interpreter.go | 8 ++++ runtime/tests/checker/entitlements_test.go | 44 ---------------------- 3 files changed, 25 insertions(+), 44 deletions(-) diff --git a/runtime/interpreter/errors.go b/runtime/interpreter/errors.go index 123b553402..e7b6a46ec2 100644 --- a/runtime/interpreter/errors.go +++ b/runtime/interpreter/errors.go @@ -671,6 +671,23 @@ func (e ValueTransferTypeError) Error() string { ) } +// UnexpectedMappedEntitlementError +type UnexpectedMappedEntitlementError struct { + Type sema.Type + LocationRange +} + +var _ errors.InternalError = UnexpectedMappedEntitlementError{} + +func (UnexpectedMappedEntitlementError) IsInternalError() {} + +func (e UnexpectedMappedEntitlementError) Error() string { + return fmt.Sprintf( + "invalid transfer of value: found an unexpected runtime mapped entitlement `%s`", + e.Type.QualifiedString(), + ) +} + // ResourceConstructionError type ResourceConstructionError struct { CompositeType *sema.CompositeType diff --git a/runtime/interpreter/interpreter.go b/runtime/interpreter/interpreter.go index e35845802d..66233c2369 100644 --- a/runtime/interpreter/interpreter.go +++ b/runtime/interpreter/interpreter.go @@ -2115,6 +2115,14 @@ func (interpreter *Interpreter) convert(value Value, valueType, targetType sema. // transferring a reference at runtime does not change its entitlements; this is so that an upcast reference // can later be downcast back to its original entitlement set + // check defensively that we never create a runtime mapped entitlement value + if _, isMappedAuth := unwrappedTargetType.Authorization.(*sema.EntitlementMapAccess); isMappedAuth { + panic(UnexpectedMappedEntitlementError{ + Type: unwrappedTargetType, + LocationRange: locationRange, + }) + } + switch ref := value.(type) { case *EphemeralReferenceValue: return NewEphemeralReferenceValue( diff --git a/runtime/tests/checker/entitlements_test.go b/runtime/tests/checker/entitlements_test.go index 9e04558f18..4ad05c99e5 100644 --- a/runtime/tests/checker/entitlements_test.go +++ b/runtime/tests/checker/entitlements_test.go @@ -7587,50 +7587,6 @@ func TestCheckEntitlementMappingComplexFields(t *testing.T) { require.IsType(t, &sema.TypeMismatchError{}, errors[0]) }) - t.Run("lambda escape", func(t *testing.T) { - t.Parallel() - - _, err := ParseAndCheck(t, ` - entitlement Inner1 - entitlement Inner2 - entitlement Outer1 - entitlement Outer2 - - entitlement mapping MyMap { - Outer1 -> Inner1 - Outer2 -> Inner2 - } - struct InnerObj { - access(Inner1) fun first(): Int{ return 9999 } - access(Inner2) fun second(): Int{ return 8888 } - } - - struct Carrier{ - access(MyMap) let arr: [auth(MyMap) &InnerObj] - init() { - self.arr = [&InnerObj()] - } - } - - struct FuncGenerator { - access(MyMap) fun generate(): auth(MyMap) &Int? { - fun innerFunc(_ param: auth(MyMap) &InnerObj): Int { - return 123; - } - var f = innerFunc; // will fail if we're called via a reference - return nil; - } - } - - fun foo() { - (&FuncGenerator() as auth(Outer1) &FuncGenerator).generate() - } - `) - - errors := RequireCheckerErrors(t, err, 1) - require.IsType(t, &sema.TypeMismatchError{}, errors[0]) - }) - t.Run("dictionary mapped field", func(t *testing.T) { t.Parallel() From 83636936379b3992862a3227755b85948682d964 Mon Sep 17 00:00:00 2001 From: Daniel Sainati Date: Tue, 17 Oct 2023 13:32:05 -0400 Subject: [PATCH 5/6] prevent creation of mapped lambdas --- runtime/sema/check_function.go | 3 +- runtime/sema/check_member_expression.go | 11 +- runtime/sema/checker.go | 16 ++- runtime/tests/checker/entitlements_test.go | 103 ++++++++++++++- .../tests/interpreter/entitlements_test.go | 122 +++++++++++------- 5 files changed, 189 insertions(+), 66 deletions(-) diff --git a/runtime/sema/check_function.go b/runtime/sema/check_function.go index c84275e3af..5c2fe0dea0 100644 --- a/runtime/sema/check_function.go +++ b/runtime/sema/check_function.go @@ -187,6 +187,7 @@ func (checker *Checker) checkFunction( functionActivation.InitializationInfo = initializationInfo if functionBlock != nil { + oldMappedAccess := checker.entitlementMappingInScope if mappedAccess, isMappedAccess := access.(*EntitlementMapAccess); isMappedAccess { checker.entitlementMappingInScope = mappedAccess.Type } @@ -199,7 +200,7 @@ func (checker *Checker) checkFunction( ) }) - checker.entitlementMappingInScope = nil + checker.entitlementMappingInScope = oldMappedAccess if mustExit { returnType := functionType.ReturnTypeAnnotation.Type diff --git a/runtime/sema/check_member_expression.go b/runtime/sema/check_member_expression.go index 68580b9205..fff37b9d5d 100644 --- a/runtime/sema/check_member_expression.go +++ b/runtime/sema/check_member_expression.go @@ -319,16 +319,7 @@ func (checker *Checker) visitMember(expression *ast.MemberExpression) (accessedT shouldSubstituteAuthorization := !member.Access.Equal(resultingAuthorization) if shouldSubstituteAuthorization { - switch ty := resultingType.(type) { - case *FunctionType: - resultingType = NewSimpleFunctionType( - ty.Purity, - ty.Parameters, - ty.ReturnTypeAnnotation.Map(checker.memoryGauge, make(map[*TypeParameter]*TypeParameter), substituteConcreteAuthorization), - ) - default: - resultingType = resultingType.Map(checker.memoryGauge, make(map[*TypeParameter]*TypeParameter), substituteConcreteAuthorization) - } + resultingType = resultingType.Map(checker.memoryGauge, make(map[*TypeParameter]*TypeParameter), substituteConcreteAuthorization) } // Check that the member access is not to a function of resource type diff --git a/runtime/sema/checker.go b/runtime/sema/checker.go index d206d15469..9390cf6c8e 100644 --- a/runtime/sema/checker.go +++ b/runtime/sema/checker.go @@ -1256,20 +1256,22 @@ func (checker *Checker) functionType( parameterList *ast.ParameterList, returnTypeAnnotation *ast.TypeAnnotation, ) *FunctionType { + + oldMappedAccess := checker.entitlementMappingInScope + if mapAccess, isMapAccess := access.(*EntitlementMapAccess); isMapAccess { + checker.entitlementMappingInScope = mapAccess.Type + } else { + checker.entitlementMappingInScope = nil + } + convertedParameters := checker.parameters(parameterList) convertedReturnTypeAnnotation := VoidTypeAnnotation if returnTypeAnnotation != nil { - // to allow entitlement mapping types to be used in the return annotation only of - // a mapped accessor function, we introduce a "variable" into the typing scope while - // checking the return - if mapAccess, isMapAccess := access.(*EntitlementMapAccess); isMapAccess { - checker.entitlementMappingInScope = mapAccess.Type - } convertedReturnTypeAnnotation = checker.ConvertTypeAnnotation(returnTypeAnnotation) - checker.entitlementMappingInScope = nil } + checker.entitlementMappingInScope = oldMappedAccess return &FunctionType{ Purity: PurityFromAnnotation(purity), diff --git a/runtime/tests/checker/entitlements_test.go b/runtime/tests/checker/entitlements_test.go index 4ad05c99e5..9a7bc9d4c5 100644 --- a/runtime/tests/checker/entitlements_test.go +++ b/runtime/tests/checker/entitlements_test.go @@ -1461,19 +1461,78 @@ func TestCheckBasicEntitlementMappingAccess(t *testing.T) { require.IsType(t, &sema.InvalidMappedEntitlementMemberError{}, errs[0]) }) + t.Run("accessor function with mapped ref arg", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheck(t, ` + entitlement E + entitlement F + entitlement G + entitlement H + entitlement mapping M { + E -> F + G -> H + } + struct interface S { + access(M) fun foo(_ arg: auth(M) &Int): auth(M) &Int + } + + fun foo(s: auth(E) &{S}) { + s.foo(&1 as auth(F) &Int) + } + `) + + assert.NoError(t, err) + }) + t.Run("accessor function with invalid mapped ref arg", func(t *testing.T) { t.Parallel() _, err := ParseAndCheck(t, ` - entitlement mapping M {} + entitlement E + entitlement F + entitlement G + entitlement H + entitlement mapping M { + E -> F + G -> H + } struct interface S { - access(M) fun foo(arg: auth(M) &Int): auth(M) &Int + access(M) fun foo(_ arg: auth(M) &Int): auth(M) &Int + } + + fun foo(s: auth(E) &{S}) { + s.foo(&1 as auth(H) &Int) } `) errs := RequireCheckerErrors(t, err, 1) - require.IsType(t, &sema.InvalidMappedAuthorizationOutsideOfFieldError{}, errs[0]) + require.IsType(t, &sema.TypeMismatchError{}, errs[0]) + }) + + t.Run("accessor function with full mapped ref arg", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheck(t, ` + entitlement E + entitlement F + entitlement G + entitlement H + entitlement mapping M { + E -> F + G -> H + } + struct interface S { + access(M) fun foo(_ arg: auth(M) &Int): auth(M) &Int + } + + fun foo(s: {S}) { + s.foo(&1 as auth(F, H) &Int) + } + `) + + assert.NoError(t, err) }) t.Run("multiple mappings conjunction", func(t *testing.T) { @@ -7736,4 +7795,42 @@ func TestCheckEntitlementMappingComplexFields(t *testing.T) { require.IsType(t, &sema.InvalidAccessError{}, errors[1]) }) + t.Run("lambda escape", func(t *testing.T) { + + t.Parallel() + + _, err := ParseAndCheck(t, ` + entitlement Inner1 + entitlement Inner2 + entitlement Outer1 + entitlement Outer2 + + entitlement mapping MyMap { + Outer1 -> Inner1 + Outer2 -> Inner2 + } + struct InnerObj { + access(Inner1) fun first(): Int{ return 9999 } + access(Inner2) fun second(): Int{ return 8888 } + } + + struct FuncGenerator { + access(MyMap) fun generate(): auth(MyMap) &Int? { + // cannot declare lambda with mapped entitlement + fun innerFunc(_ param: auth(MyMap) &InnerObj): Int { + return 123; + } + var f = innerFunc; // will fail if we're called via a reference + return nil; + } + } + + fun test() { + (&FuncGenerator() as auth(Outer1) &FuncGenerator).generate() + } + `) + + errors := RequireCheckerErrors(t, err, 1) + require.IsType(t, &sema.InvalidMappedAuthorizationOutsideOfFieldError{}, errors[0]) + }) } diff --git a/runtime/tests/interpreter/entitlements_test.go b/runtime/tests/interpreter/entitlements_test.go index f950245eb0..ad6dfeda60 100644 --- a/runtime/tests/interpreter/entitlements_test.go +++ b/runtime/tests/interpreter/entitlements_test.go @@ -2079,6 +2079,83 @@ func TestInterpretEntitlementMappingAccessors(t *testing.T) { ).Equal(value.(*interpreter.EphemeralReferenceValue).Authorization), ) }) + + t.Run("accessor function with mapped ref arg", func(t *testing.T) { + t.Parallel() + + inter := parseCheckAndInterpret(t, ` + entitlement E + entitlement F + entitlement G + entitlement H + entitlement mapping M { + E -> F + G -> H + } + struct S { + access(M) fun foo(_ arg: auth(M) &Int): auth(M) &Int { + return arg + } + } + + fun test(): auth(F) &Int { + let s = S() + let sRef = &s as auth(E) &S + return sRef.foo(&1) + } + `) + + value, err := inter.Invoke("test") + require.NoError(t, err) + + require.True( + t, + interpreter.NewEntitlementSetAuthorization( + nil, + func() []common.TypeID { return []common.TypeID{"S.test.F"} }, + 1, + sema.Conjunction, + ).Equal(value.(*interpreter.EphemeralReferenceValue).Authorization), + ) + }) + + t.Run("accessor function with full mapped ref arg", func(t *testing.T) { + t.Parallel() + + inter := parseCheckAndInterpret(t, ` + entitlement E + entitlement F + entitlement G + entitlement H + entitlement mapping M { + E -> F + G -> H + } + struct S { + access(M) fun foo(_ arg: auth(M) &Int): auth(M) &Int { + return arg + } + } + + fun test(): auth(F, H) &Int { + let s = S() + return s.foo(&1) + } + `) + + value, err := inter.Invoke("test") + require.NoError(t, err) + + require.True( + t, + interpreter.NewEntitlementSetAuthorization( + nil, + func() []common.TypeID { return []common.TypeID{"S.test.F", "S.test.H"} }, + 2, + sema.Conjunction, + ).Equal(value.(*interpreter.EphemeralReferenceValue).Authorization), + ) + }) } func TestInterpretEntitledAttachments(t *testing.T) { @@ -3419,49 +3496,4 @@ func TestInterpretEntitlementMappingComplexFields(t *testing.T) { value, ) }) - - t.Run("lambda escape", func(t *testing.T) { - - t.Parallel() - - inter := parseCheckAndInterpret(t, ` - entitlement Inner1 - entitlement Inner2 - entitlement Outer1 - entitlement Outer2 - - entitlement mapping MyMap { - Outer1 -> Inner1 - Outer2 -> Inner2 - } - struct InnerObj { - access(Inner1) fun first(): Int{ return 9999 } - access(Inner2) fun second(): Int{ return 8888 } - } - - struct FuncGenerator { - access(MyMap) fun generate(): auth(MyMap) &Int? { - fun innerFunc(_ param: auth(MyMap) &InnerObj): Int { - return 123; - } - var f = innerFunc; // will fail if we're called via a reference - return nil; - } - } - - fun test() { - (&FuncGenerator() as auth(Outer1) &FuncGenerator).generate() - } - `) - - value, err := inter.Invoke("test") - require.NoError(t, err) - - AssertValuesEqual( - t, - inter, - interpreter.NewUnmeteredIntValueFromInt64(9999+8888), - value, - ) - }) } From 5defa03916416745118b01ef44f2f55cffcc0cc8 Mon Sep 17 00:00:00 2001 From: Daniel Sainati Date: Tue, 17 Oct 2023 15:00:40 -0400 Subject: [PATCH 6/6] respond to review --- runtime/sema/check_function.go | 31 +++++++++++++++++-------------- runtime/sema/checker.go | 2 +- 2 files changed, 18 insertions(+), 15 deletions(-) diff --git a/runtime/sema/check_function.go b/runtime/sema/check_function.go index 5c2fe0dea0..b221611f0f 100644 --- a/runtime/sema/check_function.go +++ b/runtime/sema/check_function.go @@ -187,20 +187,23 @@ func (checker *Checker) checkFunction( functionActivation.InitializationInfo = initializationInfo if functionBlock != nil { - oldMappedAccess := checker.entitlementMappingInScope - if mappedAccess, isMappedAccess := access.(*EntitlementMapAccess); isMappedAccess { - checker.entitlementMappingInScope = mappedAccess.Type - } - - checker.InNewPurityScope(functionType.Purity == FunctionPurityView, func() { - checker.visitFunctionBlock( - functionBlock, - functionType.ReturnTypeAnnotation, - checkResourceLoss, - ) - }) - - checker.entitlementMappingInScope = oldMappedAccess + func() { + oldMappedAccess := checker.entitlementMappingInScope + if mappedAccess, isMappedAccess := access.(*EntitlementMapAccess); isMappedAccess { + checker.entitlementMappingInScope = mappedAccess.Type + } else { + checker.entitlementMappingInScope = nil + } + defer func() { checker.entitlementMappingInScope = oldMappedAccess }() + + checker.InNewPurityScope(functionType.Purity == FunctionPurityView, func() { + checker.visitFunctionBlock( + functionBlock, + functionType.ReturnTypeAnnotation, + checkResourceLoss, + ) + }) + }() if mustExit { returnType := functionType.ReturnTypeAnnotation.Type diff --git a/runtime/sema/checker.go b/runtime/sema/checker.go index 9390cf6c8e..3c1656d14f 100644 --- a/runtime/sema/checker.go +++ b/runtime/sema/checker.go @@ -1263,6 +1263,7 @@ func (checker *Checker) functionType( } else { checker.entitlementMappingInScope = nil } + defer func() { checker.entitlementMappingInScope = oldMappedAccess }() convertedParameters := checker.parameters(parameterList) @@ -1271,7 +1272,6 @@ func (checker *Checker) functionType( convertedReturnTypeAnnotation = checker.ConvertTypeAnnotation(returnTypeAnnotation) } - checker.entitlementMappingInScope = oldMappedAccess return &FunctionType{ Purity: PurityFromAnnotation(purity),