From 00dd3fd233336a97160885c6c4c224c95dc40d95 Mon Sep 17 00:00:00 2001 From: Supun Setunga Date: Mon, 16 Oct 2023 17:29:27 -0700 Subject: [PATCH 1/7] Support iterating references to iterables --- runtime/interpreter/interpreter_statement.go | 2 +- runtime/interpreter/value.go | 45 +++++++- runtime/sema/check_for.go | 101 ++++++++++++------ runtime/tests/checker/for_test.go | 103 +++++++++++++++++++ runtime/tests/interpreter/for_test.go | 92 +++++++++++++++++ 5 files changed, 309 insertions(+), 34 deletions(-) diff --git a/runtime/interpreter/interpreter_statement.go b/runtime/interpreter/interpreter_statement.go index ca1ea8cedd..7273cb307e 100644 --- a/runtime/interpreter/interpreter_statement.go +++ b/runtime/interpreter/interpreter_statement.go @@ -338,7 +338,7 @@ func (interpreter *Interpreter) VisitForStatement(statement *ast.ForStatement) S panic(errors.NewUnreachableError()) } - iterator := iterable.Iterator(interpreter) + iterator := iterable.Iterator(interpreter, locationRange) var index IntValue if statement.Index != nil { diff --git a/runtime/interpreter/value.go b/runtime/interpreter/value.go index bb08c02e3f..db901c20eb 100644 --- a/runtime/interpreter/value.go +++ b/runtime/interpreter/value.go @@ -234,7 +234,7 @@ type ContractValue interface { // IterableValue is a value which can be iterated over, e.g. with a for-loop type IterableValue interface { Value - Iterator(interpreter *Interpreter) ValueIterator + Iterator(interpreter *Interpreter, locationRange LocationRange) ValueIterator } // ValueIterator is an iterator which returns values. @@ -1580,7 +1580,7 @@ func (v *StringValue) ConformsToStaticType( return true } -func (v *StringValue) Iterator(_ *Interpreter) ValueIterator { +func (v *StringValue) Iterator(_ *Interpreter, _ LocationRange) ValueIterator { return StringValueIterator{ graphemes: uniseg.NewGraphemes(v.Str), } @@ -1614,7 +1614,7 @@ type ArrayValueIterator struct { atreeIterator *atree.ArrayIterator } -func (v *ArrayValue) Iterator(_ *Interpreter) ValueIterator { +func (v *ArrayValue) Iterator(_ *Interpreter, _ LocationRange) ValueIterator { arrayIterator, err := v.array.Iterator() if err != nil { panic(errors.NewExternalError(err)) @@ -20212,6 +20212,7 @@ var _ TypeIndexableValue = &EphemeralReferenceValue{} var _ MemberAccessibleValue = &EphemeralReferenceValue{} var _ AuthorizedValue = &EphemeralReferenceValue{} var _ ReferenceValue = &EphemeralReferenceValue{} +var _ IterableValue = &EphemeralReferenceValue{} func NewUnmeteredEphemeralReferenceValue( authorization Authorization, @@ -20541,6 +20542,44 @@ func (*EphemeralReferenceValue) DeepRemove(_ *Interpreter) { func (*EphemeralReferenceValue) isReference() {} +func (v *EphemeralReferenceValue) Iterator(interpreter *Interpreter, locationRange LocationRange) ValueIterator { + referencedValue := v.MustReferencedValue(interpreter, locationRange) + referenceValueIterator := referencedValue.(IterableValue).Iterator(interpreter, locationRange) + + referencedType, ok := v.BorrowedType.(sema.ValueIndexableType) + if !ok { + panic(errors.NewUnreachableError()) + } + + elementType := referencedType.ElementType(false) + + return ReferenceValueIterator{ + iterator: referenceValueIterator, + elementType: elementType, + } +} + +type ReferenceValueIterator struct { + iterator ValueIterator + elementType sema.Type +} + +var _ ValueIterator = ReferenceValueIterator{} + +func (i ReferenceValueIterator) Next(interpreter *Interpreter) Value { + element := i.iterator.Next(interpreter) + + if element == nil { + return nil + } + + if i.elementType.ContainFieldsOrElements() { + return NewEphemeralReferenceValue(interpreter, UnauthorizedAccess, element, i.elementType) + } + + return element +} + // AddressValue type AddressValue common.Address diff --git a/runtime/sema/check_for.go b/runtime/sema/check_for.go index 6bb6b88083..68af706468 100644 --- a/runtime/sema/check_for.go +++ b/runtime/sema/check_for.go @@ -43,41 +43,17 @@ func (checker *Checker) VisitForStatement(statement *ast.ForStatement) (_ struct valueType := checker.VisitExpression(valueExpression, expectedType) - var elementType Type = InvalidType - - if !valueType.IsInvalidType() { - - // Only get the element type if the array is not a resource array. - // Otherwise, in addition to the `UnsupportedResourceForLoopError`, - // the loop variable will be declared with the resource-typed element type, - // leading to an additional `ResourceLossError`. - - if valueType.IsResourceType() { - checker.report( - &UnsupportedResourceForLoopError{ - Range: ast.NewRangeFromPositioned(checker.memoryGauge, valueExpression), - }, - ) - } else if arrayType, ok := valueType.(ArrayType); ok { - elementType = arrayType.ElementType(false) - } else if valueType == StringType { - elementType = CharacterType - } else { - checker.report( - &TypeMismatchWithDescriptionError{ - ExpectedTypeDescription: "array", - ActualType: valueType, - Range: ast.NewRangeFromPositioned(checker.memoryGauge, valueExpression), - }, - ) - } - } + // Only get the element type if the array is not a resource array. + // Otherwise, in addition to the `UnsupportedResourceForLoopError`, + // the loop variable will be declared with the resource-typed element type, + // leading to an additional `ResourceLossError`. + loopVariableType := checker.loopVariableType(valueType, valueExpression) identifier := statement.Identifier.Identifier variable, err := checker.valueActivations.declare(variableDeclaration{ identifier: identifier, - ty: elementType, + ty: loopVariableType, kind: common.DeclarationKindConstant, pos: statement.Identifier.Pos, isConstant: true, @@ -123,3 +99,68 @@ func (checker *Checker) VisitForStatement(statement *ast.ForStatement) (_ struct return } + +func (checker *Checker) loopVariableType(valueType Type, hasPosition ast.HasPosition) Type { + if valueType.IsInvalidType() { + return InvalidType + } + + // Resources cannot be looped. + if valueType.IsResourceType() { + checker.report( + &UnsupportedResourceForLoopError{ + Range: ast.NewRangeFromPositioned(checker.memoryGauge, hasPosition), + }, + ) + return InvalidType + } + + // If it's a reference, check whether the referenced type is iterable. + // If yes, then determine the loop-var type depending on the + // element-type of the referenced type. + // If that element type is: + // a) A container type, then the loop-var is also a reference-type. + // b) A primitive type, then the loop-var is the concrete type itself. + + if referenceType, ok := valueType.(*ReferenceType); ok { + referencedType := referenceType.Type + referencedIterableElementType := checker.iterableElementType(referencedType, hasPosition) + + if referencedIterableElementType.IsInvalidType() { + return referencedIterableElementType + } + + // Case (a): Element type is a container type. + // Then the loop-var must also be a reference type. + if referencedIterableElementType.ContainFieldsOrElements() { + return NewReferenceType(checker.memoryGauge, UnauthorizedAccess, referencedIterableElementType) + } + + // Case (b): Element type is a primitive type. + // Then the loop-var must be the concrete type. + return referencedIterableElementType + } + + // If it's not a reference, then simply get the element type. + return checker.iterableElementType(valueType, hasPosition) +} + +func (checker *Checker) iterableElementType(valueType Type, hasPosition ast.HasPosition) Type { + if arrayType, ok := valueType.(ArrayType); ok { + return arrayType.ElementType(false) + } + + if valueType == StringType { + return CharacterType + } + + checker.report( + &TypeMismatchWithDescriptionError{ + ExpectedTypeDescription: "array", + ActualType: valueType, + Range: ast.NewRangeFromPositioned(checker.memoryGauge, hasPosition), + }, + ) + + return InvalidType +} diff --git a/runtime/tests/checker/for_test.go b/runtime/tests/checker/for_test.go index 9ed443c605..71c1d69537 100644 --- a/runtime/tests/checker/for_test.go +++ b/runtime/tests/checker/for_test.go @@ -22,6 +22,7 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/onflow/cadence/runtime/sema" ) @@ -273,3 +274,105 @@ func TestCheckInvalidForShadowing(t *testing.T) { assert.IsType(t, &sema.RedeclarationError{}, errs[0]) } + +func TestCheckReferencesInForLoop(t *testing.T) { + + t.Parallel() + + t.Run("Primitive array", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheck(t, ` + fun main() { + var array = ["Hello", "World", "Foo", "Bar"] + var arrayRef = &array as &[String] + + for element in arrayRef { + let e: String = element + } + } + `) + + require.NoError(t, err) + }) + + t.Run("Struct array", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheck(t, ` + struct Foo{} + + fun main() { + var array = [Foo(), Foo()] + var arrayRef = &array as &[Foo] + + for element in arrayRef { + let e: &Foo = element + } + } + `) + + require.NoError(t, err) + }) + + t.Run("Resource array", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheck(t, ` + resource Foo{} + + fun main() { + var array <- [ <- create Foo(), <- create Foo()] + var arrayRef = &array as &[Foo] + + for element in arrayRef { + let e: &Foo = element + } + + destroy array + } + `) + + require.NoError(t, err) + }) + + t.Run("Dictionary", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheck(t, ` + struct Foo{} + + fun main() { + var foo = {"foo": Foo()} + var fooRef = &foo as &{String: Foo} + + for element in fooRef { + let e: &Foo = element + } + } + `) + + errors := RequireCheckerErrors(t, err, 1) + assert.IsType(t, &sema.TypeMismatchWithDescriptionError{}, errors[0]) + }) + + t.Run("Non iterable", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheck(t, ` + struct Foo{} + + fun main() { + var foo = Foo() + var fooRef = &foo as &Foo + + for element in fooRef { + let e: &Foo = element + } + } + `) + + errors := RequireCheckerErrors(t, err, 1) + assert.IsType(t, &sema.TypeMismatchWithDescriptionError{}, errors[0]) + }) +} diff --git a/runtime/tests/interpreter/for_test.go b/runtime/tests/interpreter/for_test.go index dc58e368ec..cfea4e4547 100644 --- a/runtime/tests/interpreter/for_test.go +++ b/runtime/tests/interpreter/for_test.go @@ -294,3 +294,95 @@ func TestInterpretForStatementCapturing(t *testing.T) { ArrayElements(inter, arrayValue), ) } + +func TestInterpretReferencesInForLoop(t *testing.T) { + + t.Parallel() + + t.Run("Primitive array", func(t *testing.T) { + t.Parallel() + + inter := parseCheckAndInterpret(t, ` + fun main() { + var array = ["Hello", "World", "Foo", "Bar"] + var arrayRef = &array as &[String] + + for element in arrayRef { + let e: String = element + } + } + `) + + _, err := inter.Invoke("main") + require.NoError(t, err) + }) + + t.Run("Struct array", func(t *testing.T) { + t.Parallel() + + inter := parseCheckAndInterpret(t, ` + struct Foo{} + + fun main() { + var array = [Foo(), Foo()] + var arrayRef = &array as &[Foo] + + for element in arrayRef { + let e: &Foo = element + } + } + `) + + _, err := inter.Invoke("main") + require.NoError(t, err) + }) + + t.Run("Resource array", func(t *testing.T) { + t.Parallel() + + inter := parseCheckAndInterpret(t, ` + resource Foo{} + + fun main() { + var array <- [ <- create Foo(), <- create Foo()] + var arrayRef = &array as &[Foo] + + for element in arrayRef { + let e: &Foo = element + } + + destroy array + } + `) + + _, err := inter.Invoke("main") + require.NoError(t, err) + }) + + t.Run("Moved resource array", func(t *testing.T) { + t.Parallel() + + inter := parseCheckAndInterpret(t, ` + resource Foo{} + + fun main() { + var array <- [ <- create Foo(), <- create Foo()] + var arrayRef = returnSameRef(&array as &[Foo]) + var movedArray <- array + + for element in arrayRef { + let e: &Foo = element + } + + destroy movedArray + } + + fun returnSameRef(_ ref: &[Foo]): &[Foo] { + return ref + } + `) + + _, err := inter.Invoke("main") + require.ErrorAs(t, err, &interpreter.InvalidatedResourceReferenceError{}) + }) +} From 9162f3e2aa2d969c4870d04e94f71c8f40dca14e Mon Sep 17 00:00:00 2001 From: Supun Setunga Date: Tue, 17 Oct 2023 10:16:25 -0700 Subject: [PATCH 2/7] Support looping storage references --- runtime/interpreter/value.go | 49 +++++++++---- runtime/tests/interpreter/for_test.go | 102 +++++++++++++++++++++++++- 2 files changed, 137 insertions(+), 14 deletions(-) diff --git a/runtime/interpreter/value.go b/runtime/interpreter/value.go index db901c20eb..64620efbe4 100644 --- a/runtime/interpreter/value.go +++ b/runtime/interpreter/value.go @@ -19844,6 +19844,7 @@ var _ TypeIndexableValue = &StorageReferenceValue{} var _ MemberAccessibleValue = &StorageReferenceValue{} var _ AuthorizedValue = &StorageReferenceValue{} var _ ReferenceValue = &StorageReferenceValue{} +var _ IterableValue = &StorageReferenceValue{} func NewUnmeteredStorageReferenceValue( authorization Authorization, @@ -20196,6 +20197,37 @@ func (*StorageReferenceValue) DeepRemove(_ *Interpreter) { func (*StorageReferenceValue) isReference() {} +func (v *StorageReferenceValue) Iterator(interpreter *Interpreter, locationRange LocationRange) ValueIterator { + referencedValue := v.mustReferencedValue(interpreter, locationRange) + return referenceValueIterator(interpreter, referencedValue, v.BorrowedType, locationRange) +} + +func referenceValueIterator( + interpreter *Interpreter, + referencedValue Value, + borrowedType sema.Type, + locationRange LocationRange, +) ValueIterator { + referencedIterable, ok := referencedValue.(IterableValue) + if !ok { + panic(errors.NewUnreachableError()) + } + + referencedValueIterator := referencedIterable.Iterator(interpreter, locationRange) + + referencedType, ok := borrowedType.(sema.ValueIndexableType) + if !ok { + panic(errors.NewUnreachableError()) + } + + elementType := referencedType.ElementType(false) + + return ReferenceValueIterator{ + iterator: referencedValueIterator, + elementType: elementType, + } +} + // EphemeralReferenceValue type EphemeralReferenceValue struct { @@ -20544,21 +20576,11 @@ func (*EphemeralReferenceValue) isReference() {} func (v *EphemeralReferenceValue) Iterator(interpreter *Interpreter, locationRange LocationRange) ValueIterator { referencedValue := v.MustReferencedValue(interpreter, locationRange) - referenceValueIterator := referencedValue.(IterableValue).Iterator(interpreter, locationRange) - - referencedType, ok := v.BorrowedType.(sema.ValueIndexableType) - if !ok { - panic(errors.NewUnreachableError()) - } - - elementType := referencedType.ElementType(false) - - return ReferenceValueIterator{ - iterator: referenceValueIterator, - elementType: elementType, - } + return referenceValueIterator(interpreter, referencedValue, v.BorrowedType, locationRange) } +// ReferenceValueIterator + type ReferenceValueIterator struct { iterator ValueIterator elementType sema.Type @@ -20573,6 +20595,7 @@ func (i ReferenceValueIterator) Next(interpreter *Interpreter) Value { return nil } + // For non-primitive values, return a reference. if i.elementType.ContainFieldsOrElements() { return NewEphemeralReferenceValue(interpreter, UnauthorizedAccess, element, i.elementType) } diff --git a/runtime/tests/interpreter/for_test.go b/runtime/tests/interpreter/for_test.go index cfea4e4547..f1440f564a 100644 --- a/runtime/tests/interpreter/for_test.go +++ b/runtime/tests/interpreter/for_test.go @@ -24,6 +24,7 @@ import ( "github.com/stretchr/testify/require" "github.com/onflow/cadence/runtime/common" + "github.com/onflow/cadence/runtime/sema" . "github.com/onflow/cadence/runtime/tests/utils" "github.com/onflow/cadence/runtime/interpreter" @@ -295,7 +296,7 @@ func TestInterpretForStatementCapturing(t *testing.T) { ) } -func TestInterpretReferencesInForLoop(t *testing.T) { +func TestInterpretEphemeralReferencesInForLoop(t *testing.T) { t.Parallel() @@ -386,3 +387,102 @@ func TestInterpretReferencesInForLoop(t *testing.T) { require.ErrorAs(t, err, &interpreter.InvalidatedResourceReferenceError{}) }) } + +func TestInterpretStorageReferencesInForLoop(t *testing.T) { + + t.Parallel() + + t.Run("Primitive array", func(t *testing.T) { + t.Parallel() + + address := interpreter.NewUnmeteredAddressValueFromBytes([]byte{42}) + + inter, _ := testAccount(t, address, true, nil, ` + fun test() { + var array = ["Hello", "World", "Foo", "Bar"] + account.storage.save(array, to: /storage/array) + + let arrayRef = account.storage.borrow<&[String]>(from: /storage/array)! + + for element in arrayRef { + let e: String = element // Must be the concrete string + } + }`, sema.Config{}) + + _, err := inter.Invoke("test") + require.NoError(t, err) + }) + + t.Run("Struct array", func(t *testing.T) { + t.Parallel() + + address := interpreter.NewUnmeteredAddressValueFromBytes([]byte{42}) + + inter, _ := testAccount(t, address, true, nil, ` + struct Foo{} + + fun test() { + var array = [Foo(), Foo()] + account.storage.save(array, to: /storage/array) + + let arrayRef = account.storage.borrow<&[Foo]>(from: /storage/array)! + + for element in arrayRef { + let e: &Foo = element // Must be a reference + } + }`, sema.Config{}) + + _, err := inter.Invoke("test") + require.NoError(t, err) + }) + + t.Run("Resource array", func(t *testing.T) { + t.Parallel() + + address := interpreter.NewUnmeteredAddressValueFromBytes([]byte{42}) + + inter, _ := testAccount(t, address, true, nil, ` + resource Foo{} + + fun test() { + var array <- [ <- create Foo(), <- create Foo()] + account.storage.save(<- array, to: /storage/array) + + let arrayRef = account.storage.borrow<&[Foo]>(from: /storage/array)! + + for element in arrayRef { + let e: &Foo = element // Must be a reference + } + }`, sema.Config{}) + + _, err := inter.Invoke("test") + require.NoError(t, err) + }) + + t.Run("Moved resource array", func(t *testing.T) { + t.Parallel() + + address := interpreter.NewUnmeteredAddressValueFromBytes([]byte{42}) + + inter, _ := testAccount(t, address, true, nil, ` + resource Foo{} + + fun test() { + var array <- [ <- create Foo(), <- create Foo()] + account.storage.save(<- array, to: /storage/array) + + let arrayRef = account.storage.borrow<&[Foo]>(from: /storage/array)! + + let movedArray <- account.storage.load<@[Foo]>(from: /storage/array)! + + for element in arrayRef { + let e: &Foo = element // Must be a reference + } + + destroy movedArray + }`, sema.Config{}) + + _, err := inter.Invoke("test") + require.ErrorAs(t, err, &interpreter.DereferenceError{}) + }) +} From 4942b145e3c86a7b1da0369284beea7ce16b7302 Mon Sep 17 00:00:00 2001 From: Supun Setunga Date: Tue, 17 Oct 2023 11:21:32 -0700 Subject: [PATCH 3/7] Add test for invalid type --- runtime/tests/checker/for_test.go | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/runtime/tests/checker/for_test.go b/runtime/tests/checker/for_test.go index 71c1d69537..69ae3a5598 100644 --- a/runtime/tests/checker/for_test.go +++ b/runtime/tests/checker/for_test.go @@ -375,4 +375,21 @@ func TestCheckReferencesInForLoop(t *testing.T) { errors := RequireCheckerErrors(t, err, 1) assert.IsType(t, &sema.TypeMismatchWithDescriptionError{}, errors[0]) }) + + t.Run("Non existing type", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheck(t, ` + fun main() { + var foo = Foo() + var fooRef = &foo as &Foo + + for element in fooRef {} + } + `) + + errors := RequireCheckerErrors(t, err, 2) + assert.IsType(t, &sema.NotDeclaredError{}, errors[0]) + assert.IsType(t, &sema.NotDeclaredError{}, errors[1]) + }) } From c659d79c67e50a58a2b1f981488c2be00858d77f Mon Sep 17 00:00:00 2001 From: Supun Setunga Date: Wed, 18 Oct 2023 08:52:19 -0700 Subject: [PATCH 4/7] Add test for auth references in for-loop --- runtime/tests/checker/for_test.go | 39 +++++++++++++++++++++++++++ runtime/tests/interpreter/for_test.go | 20 ++++++++++++++ 2 files changed, 59 insertions(+) diff --git a/runtime/tests/checker/for_test.go b/runtime/tests/checker/for_test.go index 69ae3a5598..d52c33e1d0 100644 --- a/runtime/tests/checker/for_test.go +++ b/runtime/tests/checker/for_test.go @@ -392,4 +392,43 @@ func TestCheckReferencesInForLoop(t *testing.T) { assert.IsType(t, &sema.NotDeclaredError{}, errors[0]) assert.IsType(t, &sema.NotDeclaredError{}, errors[1]) }) + + t.Run("Auth ref", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheck(t, ` + struct Foo{} + + fun main() { + var array = [Foo(), Foo()] + var arrayRef = &array as auth(Mutate) &[Foo] + + for element in arrayRef { + let e: &Foo = element // should be non-auth + } + } + `) + + require.NoError(t, err) + }) + + t.Run("Auth ref invalid", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheck(t, ` + struct Foo{} + + fun main() { + var array = [Foo(), Foo()] + var arrayRef = &array as auth(Mutate) &[Foo] + + for element in arrayRef { + let e: auth(Mutate) &Foo = element // should be non-auth + } + } + `) + + errors := RequireCheckerErrors(t, err, 1) + assert.IsType(t, &sema.TypeMismatchError{}, errors[0]) + }) } diff --git a/runtime/tests/interpreter/for_test.go b/runtime/tests/interpreter/for_test.go index f1440f564a..a30116a234 100644 --- a/runtime/tests/interpreter/for_test.go +++ b/runtime/tests/interpreter/for_test.go @@ -386,6 +386,26 @@ func TestInterpretEphemeralReferencesInForLoop(t *testing.T) { _, err := inter.Invoke("main") require.ErrorAs(t, err, &interpreter.InvalidatedResourceReferenceError{}) }) + + t.Run("Auth ref", func(t *testing.T) { + t.Parallel() + + inter := parseCheckAndInterpret(t, ` + struct Foo{} + + fun main() { + var array = [Foo(), Foo()] + var arrayRef = &array as auth(Mutate) &[Foo] + + for element in arrayRef { + let e: &Foo = element // Should be non-auth + } + } + `) + + _, err := inter.Invoke("main") + require.NoError(t, err) + }) } func TestInterpretStorageReferencesInForLoop(t *testing.T) { From a7dc05ec25ecf9e80578ef1b18076f1c72b0ffbf Mon Sep 17 00:00:00 2001 From: Supun Setunga Date: Wed, 18 Oct 2023 12:10:27 -0700 Subject: [PATCH 5/7] Handle optionals and resource tracking --- runtime/interpreter/interpreter_statement.go | 3 +- runtime/interpreter/value.go | 47 +++---- runtime/sema/check_assignment.go | 2 +- runtime/sema/check_for.go | 12 +- runtime/sema/check_member_expression.go | 4 +- runtime/sema/check_variable_declaration.go | 2 +- runtime/sema/elaboration.go | 20 +++ runtime/tests/checker/for_test.go | 19 +++ runtime/tests/interpreter/for_test.go | 121 ++++++++++++++++--- 9 files changed, 187 insertions(+), 43 deletions(-) diff --git a/runtime/interpreter/interpreter_statement.go b/runtime/interpreter/interpreter_statement.go index 7273cb307e..01f84b176c 100644 --- a/runtime/interpreter/interpreter_statement.go +++ b/runtime/interpreter/interpreter_statement.go @@ -338,7 +338,8 @@ func (interpreter *Interpreter) VisitForStatement(statement *ast.ForStatement) S panic(errors.NewUnreachableError()) } - iterator := iterable.Iterator(interpreter, locationRange) + forStmtTypes := interpreter.Program.Elaboration.ForStatementType(statement) + iterator := iterable.Iterator(interpreter, forStmtTypes.ValueVariableType, locationRange) var index IntValue if statement.Index != nil { diff --git a/runtime/interpreter/value.go b/runtime/interpreter/value.go index 64620efbe4..829352ce29 100644 --- a/runtime/interpreter/value.go +++ b/runtime/interpreter/value.go @@ -234,7 +234,7 @@ type ContractValue interface { // IterableValue is a value which can be iterated over, e.g. with a for-loop type IterableValue interface { Value - Iterator(interpreter *Interpreter, locationRange LocationRange) ValueIterator + Iterator(interpreter *Interpreter, resultType sema.Type, locationRange LocationRange) ValueIterator } // ValueIterator is an iterator which returns values. @@ -1580,7 +1580,7 @@ func (v *StringValue) ConformsToStaticType( return true } -func (v *StringValue) Iterator(_ *Interpreter, _ LocationRange) ValueIterator { +func (v *StringValue) Iterator(_ *Interpreter, _ sema.Type, _ LocationRange) ValueIterator { return StringValueIterator{ graphemes: uniseg.NewGraphemes(v.Str), } @@ -1614,7 +1614,7 @@ type ArrayValueIterator struct { atreeIterator *atree.ArrayIterator } -func (v *ArrayValue) Iterator(_ *Interpreter, _ LocationRange) ValueIterator { +func (v *ArrayValue) Iterator(_ *Interpreter, _ sema.Type, _ LocationRange) ValueIterator { arrayIterator, err := v.array.Iterator() if err != nil { panic(errors.NewExternalError(err)) @@ -20197,15 +20197,19 @@ func (*StorageReferenceValue) DeepRemove(_ *Interpreter) { func (*StorageReferenceValue) isReference() {} -func (v *StorageReferenceValue) Iterator(interpreter *Interpreter, locationRange LocationRange) ValueIterator { +func (v *StorageReferenceValue) Iterator( + interpreter *Interpreter, + resultType sema.Type, + locationRange LocationRange, +) ValueIterator { referencedValue := v.mustReferencedValue(interpreter, locationRange) - return referenceValueIterator(interpreter, referencedValue, v.BorrowedType, locationRange) + return referenceValueIterator(interpreter, referencedValue, resultType, locationRange) } func referenceValueIterator( interpreter *Interpreter, referencedValue Value, - borrowedType sema.Type, + resultType sema.Type, locationRange LocationRange, ) ValueIterator { referencedIterable, ok := referencedValue.(IterableValue) @@ -20213,18 +20217,14 @@ func referenceValueIterator( panic(errors.NewUnreachableError()) } - referencedValueIterator := referencedIterable.Iterator(interpreter, locationRange) + referencedValueIterator := referencedIterable.Iterator(interpreter, resultType, locationRange) - referencedType, ok := borrowedType.(sema.ValueIndexableType) - if !ok { - panic(errors.NewUnreachableError()) - } - - elementType := referencedType.ElementType(false) + _, isResultReference := sema.GetReferenceType(resultType) return ReferenceValueIterator{ - iterator: referencedValueIterator, - elementType: elementType, + iterator: referencedValueIterator, + resultType: resultType, + isResultReference: isResultReference, } } @@ -20574,16 +20574,21 @@ func (*EphemeralReferenceValue) DeepRemove(_ *Interpreter) { func (*EphemeralReferenceValue) isReference() {} -func (v *EphemeralReferenceValue) Iterator(interpreter *Interpreter, locationRange LocationRange) ValueIterator { +func (v *EphemeralReferenceValue) Iterator( + interpreter *Interpreter, + resultType sema.Type, + locationRange LocationRange, +) ValueIterator { referencedValue := v.MustReferencedValue(interpreter, locationRange) - return referenceValueIterator(interpreter, referencedValue, v.BorrowedType, locationRange) + return referenceValueIterator(interpreter, referencedValue, resultType, locationRange) } // ReferenceValueIterator type ReferenceValueIterator struct { - iterator ValueIterator - elementType sema.Type + iterator ValueIterator + resultType sema.Type + isResultReference bool } var _ ValueIterator = ReferenceValueIterator{} @@ -20596,8 +20601,8 @@ func (i ReferenceValueIterator) Next(interpreter *Interpreter) Value { } // For non-primitive values, return a reference. - if i.elementType.ContainFieldsOrElements() { - return NewEphemeralReferenceValue(interpreter, UnauthorizedAccess, element, i.elementType) + if i.isResultReference { + return interpreter.getReferenceValue(element, i.resultType) } return element diff --git a/runtime/sema/check_assignment.go b/runtime/sema/check_assignment.go index 0b0a20fba1..8745f1fea7 100644 --- a/runtime/sema/check_assignment.go +++ b/runtime/sema/check_assignment.go @@ -329,7 +329,7 @@ func (checker *Checker) visitIndexExpressionAssignment( elementType = checker.visitIndexExpression(indexExpression, true) indexExprTypes := checker.Elaboration.IndexExpressionTypes(indexExpression) - indexedRefType, isReference := referenceType(indexExprTypes.IndexedType) + indexedRefType, isReference := GetReferenceType(indexExprTypes.IndexedType) if isReference && !mutableEntitledAccess.PermitsAccess(indexedRefType.Authorization) && diff --git a/runtime/sema/check_for.go b/runtime/sema/check_for.go index 68af706468..da3fe156fb 100644 --- a/runtime/sema/check_for.go +++ b/runtime/sema/check_for.go @@ -66,11 +66,14 @@ func (checker *Checker) VisitForStatement(statement *ast.ForStatement) (_ struct checker.recordVariableDeclarationOccurrence(identifier, variable) } + var indexType Type + if statement.Index != nil { index := statement.Index.Identifier + indexType = IntType indexVariable, err := checker.valueActivations.declare(variableDeclaration{ identifier: index, - ty: IntType, + ty: indexType, kind: common.DeclarationKindConstant, pos: statement.Index.Pos, isConstant: true, @@ -84,6 +87,11 @@ func (checker *Checker) VisitForStatement(statement *ast.ForStatement) (_ struct } } + checker.Elaboration.SetForStatementType(statement, ForStatementTypes{ + IndexVariableType: indexType, + ValueVariableType: loopVariableType, + }) + // The body of the loop will maybe be evaluated. // That means that resource invalidations and // returns are not definite, but only potential. @@ -133,7 +141,7 @@ func (checker *Checker) loopVariableType(valueType Type, hasPosition ast.HasPosi // Case (a): Element type is a container type. // Then the loop-var must also be a reference type. if referencedIterableElementType.ContainFieldsOrElements() { - return NewReferenceType(checker.memoryGauge, UnauthorizedAccess, referencedIterableElementType) + return checker.getReferenceType(referencedIterableElementType, false, UnauthorizedAccess) } // Case (b): Element type is a primitive type. diff --git a/runtime/sema/check_member_expression.go b/runtime/sema/check_member_expression.go index e673a579a1..2cb3ab83ce 100644 --- a/runtime/sema/check_member_expression.go +++ b/runtime/sema/check_member_expression.go @@ -116,14 +116,14 @@ func shouldReturnReference(parentType, memberType Type, isAssignment bool) bool return false } - if _, isReference := referenceType(parentType); !isReference { + if _, isReference := GetReferenceType(parentType); !isReference { return false } return memberType.ContainFieldsOrElements() } -func referenceType(typ Type) (*ReferenceType, bool) { +func GetReferenceType(typ Type) (*ReferenceType, bool) { unwrappedType := UnwrapOptionalType(typ) refType, isReference := unwrappedType.(*ReferenceType) return refType, isReference diff --git a/runtime/sema/check_variable_declaration.go b/runtime/sema/check_variable_declaration.go index 971abbbe5d..e5722ebf75 100644 --- a/runtime/sema/check_variable_declaration.go +++ b/runtime/sema/check_variable_declaration.go @@ -264,7 +264,7 @@ func (checker *Checker) recordReference(targetVariable *Variable, expr ast.Expre return } - if _, isReference := referenceType(targetVariable.Type); !isReference { + if _, isReference := GetReferenceType(targetVariable.Type); !isReference { return } diff --git a/runtime/sema/elaboration.go b/runtime/sema/elaboration.go index a64add82da..8211be53b2 100644 --- a/runtime/sema/elaboration.go +++ b/runtime/sema/elaboration.go @@ -111,6 +111,11 @@ type ExpressionTypes struct { ExpectedType Type } +type ForStatementTypes struct { + IndexVariableType Type + ValueVariableType Type +} + type Elaboration struct { interfaceTypesAndDeclarationsBiMap *bimap.BiMap[*InterfaceType, *ast.InterfaceDeclaration] entitlementTypesAndDeclarationsBiMap *bimap.BiMap[*EntitlementType, *ast.EntitlementDeclaration] @@ -118,6 +123,7 @@ type Elaboration struct { fixedPointExpressionTypes map[*ast.FixedPointExpression]Type swapStatementTypes map[*ast.SwapStatement]SwapStatementTypes + forStatementTypes map[*ast.ForStatement]ForStatementTypes assignmentStatementTypes map[*ast.AssignmentStatement]AssignmentStatementTypes compositeDeclarationTypes map[ast.CompositeLikeDeclaration]*CompositeType compositeTypeDeclarations map[*CompositeType]ast.CompositeLikeDeclaration @@ -1032,3 +1038,17 @@ func (e *Elaboration) GetSemanticAccess(access ast.Access) (semaAccess Access, p semaAccess, present = e.semanticAccesses[access] return } + +func (e *Elaboration) SetForStatementType(statement *ast.ForStatement, types ForStatementTypes) { + if e.forStatementTypes == nil { + e.forStatementTypes = map[*ast.ForStatement]ForStatementTypes{} + } + e.forStatementTypes[statement] = types +} + +func (e *Elaboration) ForStatementType(statement *ast.ForStatement) (types ForStatementTypes) { + if e.forStatementTypes == nil { + return + } + return e.forStatementTypes[statement] +} diff --git a/runtime/tests/checker/for_test.go b/runtime/tests/checker/for_test.go index d52c33e1d0..edb5148290 100644 --- a/runtime/tests/checker/for_test.go +++ b/runtime/tests/checker/for_test.go @@ -431,4 +431,23 @@ func TestCheckReferencesInForLoop(t *testing.T) { errors := RequireCheckerErrors(t, err, 1) assert.IsType(t, &sema.TypeMismatchError{}, errors[0]) }) + + t.Run("Optional array", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheck(t, ` + struct Foo{} + + fun main() { + var array: [Foo?] = [Foo(), Foo()] + var arrayRef = &array as &[Foo?] + + for element in arrayRef { + let e: &Foo? = element // Should be an optional reference + } + } + `) + + require.NoError(t, err) + }) } diff --git a/runtime/tests/interpreter/for_test.go b/runtime/tests/interpreter/for_test.go index a30116a234..7c4738c0fd 100644 --- a/runtime/tests/interpreter/for_test.go +++ b/runtime/tests/interpreter/for_test.go @@ -305,8 +305,8 @@ func TestInterpretEphemeralReferencesInForLoop(t *testing.T) { inter := parseCheckAndInterpret(t, ` fun main() { - var array = ["Hello", "World", "Foo", "Bar"] - var arrayRef = &array as &[String] + let array = ["Hello", "World", "Foo", "Bar"] + let arrayRef = &array as &[String] for element in arrayRef { let e: String = element @@ -325,8 +325,8 @@ func TestInterpretEphemeralReferencesInForLoop(t *testing.T) { struct Foo{} fun main() { - var array = [Foo(), Foo()] - var arrayRef = &array as &[Foo] + let array = [Foo(), Foo()] + let arrayRef = &array as &[Foo] for element in arrayRef { let e: &Foo = element @@ -345,8 +345,8 @@ func TestInterpretEphemeralReferencesInForLoop(t *testing.T) { resource Foo{} fun main() { - var array <- [ <- create Foo(), <- create Foo()] - var arrayRef = &array as &[Foo] + let array <- [ <- create Foo(), <- create Foo()] + let arrayRef = &array as &[Foo] for element in arrayRef { let e: &Foo = element @@ -367,9 +367,9 @@ func TestInterpretEphemeralReferencesInForLoop(t *testing.T) { resource Foo{} fun main() { - var array <- [ <- create Foo(), <- create Foo()] - var arrayRef = returnSameRef(&array as &[Foo]) - var movedArray <- array + let array <- [ <- create Foo(), <- create Foo()] + let arrayRef = returnSameRef(&array as &[Foo]) + let movedArray <- array for element in arrayRef { let e: &Foo = element @@ -394,8 +394,8 @@ func TestInterpretEphemeralReferencesInForLoop(t *testing.T) { struct Foo{} fun main() { - var array = [Foo(), Foo()] - var arrayRef = &array as auth(Mutate) &[Foo] + let array = [Foo(), Foo()] + let arrayRef = &array as auth(Mutate) &[Foo] for element in arrayRef { let e: &Foo = element // Should be non-auth @@ -406,6 +406,97 @@ func TestInterpretEphemeralReferencesInForLoop(t *testing.T) { _, err := inter.Invoke("main") require.NoError(t, err) }) + + t.Run("Optional array", func(t *testing.T) { + t.Parallel() + + inter := parseCheckAndInterpret(t, ` + struct Foo{} + + fun main() { + let array: [Foo?] = [Foo(), Foo()] + let arrayRef = &array as &[Foo?] + + for element in arrayRef { + let e: &Foo? = element // Should be an optional reference + } + } + `) + + _, err := inter.Invoke("main") + require.NoError(t, err) + }) + + t.Run("Nil array", func(t *testing.T) { + t.Parallel() + + inter := parseCheckAndInterpret(t, ` + struct Foo{} + + fun main() { + let array: [Foo?] = [nil, nil] + let arrayRef = &array as &[Foo?] + + for element in arrayRef { + let e: &Foo? = element // Should be an optional reference + } + } + `) + + _, err := inter.Invoke("main") + require.NoError(t, err) + }) + + t.Run("Reference array", func(t *testing.T) { + t.Parallel() + + inter := parseCheckAndInterpret(t, ` + struct Foo{} + + fun main() { + let elementRef = &Foo() as &Foo + let array: [&Foo] = [elementRef, elementRef] + let arrayRef = &array as &[&Foo] + + for element in arrayRef { + let e: &Foo = element + } + } + `) + + _, err := inter.Invoke("main") + require.NoError(t, err) + }) + + t.Run("Moved resource element", func(t *testing.T) { + t.Parallel() + + inter := parseCheckAndInterpret(t, ` + resource Foo{ + fun sayHello() {} + } + + fun main() { + let array <- [ <- create Foo()] + let arrayRef = &array as auth(Mutate) &[Foo] + + for element in arrayRef { + // Move the actual element + let oldElement <- arrayRef.remove(at: 0) + + // Use the element reference + element.sayHello() + + destroy oldElement + } + + destroy array + } + `) + + _, err := inter.Invoke("main") + require.ErrorAs(t, err, &interpreter.InvalidatedResourceReferenceError{}) + }) } func TestInterpretStorageReferencesInForLoop(t *testing.T) { @@ -419,7 +510,7 @@ func TestInterpretStorageReferencesInForLoop(t *testing.T) { inter, _ := testAccount(t, address, true, nil, ` fun test() { - var array = ["Hello", "World", "Foo", "Bar"] + var let = ["Hello", "World", "Foo", "Bar"] account.storage.save(array, to: /storage/array) let arrayRef = account.storage.borrow<&[String]>(from: /storage/array)! @@ -442,7 +533,7 @@ func TestInterpretStorageReferencesInForLoop(t *testing.T) { struct Foo{} fun test() { - var array = [Foo(), Foo()] + var let = [Foo(), Foo()] account.storage.save(array, to: /storage/array) let arrayRef = account.storage.borrow<&[Foo]>(from: /storage/array)! @@ -465,7 +556,7 @@ func TestInterpretStorageReferencesInForLoop(t *testing.T) { resource Foo{} fun test() { - var array <- [ <- create Foo(), <- create Foo()] + var let <- [ <- create Foo(), <- create Foo()] account.storage.save(<- array, to: /storage/array) let arrayRef = account.storage.borrow<&[Foo]>(from: /storage/array)! @@ -488,7 +579,7 @@ func TestInterpretStorageReferencesInForLoop(t *testing.T) { resource Foo{} fun test() { - var array <- [ <- create Foo(), <- create Foo()] + var let <- [ <- create Foo(), <- create Foo()] account.storage.save(<- array, to: /storage/array) let arrayRef = account.storage.borrow<&[Foo]>(from: /storage/array)! From 3ec974939a960226b821b6b394d31430e69b80ec Mon Sep 17 00:00:00 2001 From: Supun Setunga Date: Thu, 19 Oct 2023 14:41:54 -0700 Subject: [PATCH 6/7] Prevent mutation while iterating --- runtime/interpreter/interpreter_statement.go | 25 ++-- runtime/interpreter/value.go | 136 +++++++++++++------ runtime/sema/check_assignment.go | 2 +- runtime/sema/check_member_expression.go | 4 +- runtime/sema/check_variable_declaration.go | 2 +- runtime/tests/interpreter/for_test.go | 39 +++++- 6 files changed, 144 insertions(+), 64 deletions(-) diff --git a/runtime/interpreter/interpreter_statement.go b/runtime/interpreter/interpreter_statement.go index 01f84b176c..f0d726bec6 100644 --- a/runtime/interpreter/interpreter_statement.go +++ b/runtime/interpreter/interpreter_statement.go @@ -313,7 +313,7 @@ func (interpreter *Interpreter) VisitWhileStatement(statement *ast.WhileStatemen var intOne = NewUnmeteredIntValueFromInt64(1) -func (interpreter *Interpreter) VisitForStatement(statement *ast.ForStatement) StatementResult { +func (interpreter *Interpreter) VisitForStatement(statement *ast.ForStatement) (result StatementResult) { interpreter.activations.PushNewWithCurrent() defer interpreter.activations.Pop() @@ -339,28 +339,35 @@ func (interpreter *Interpreter) VisitForStatement(statement *ast.ForStatement) S } forStmtTypes := interpreter.Program.Elaboration.ForStatementType(statement) - iterator := iterable.Iterator(interpreter, forStmtTypes.ValueVariableType, locationRange) var index IntValue if statement.Index != nil { index = NewIntValueFromInt64(interpreter, 0) } - for { - value := iterator.Next(interpreter) - if value == nil { - return nil - } - + executeBody := func(value Value) (resume bool) { statementResult, done := interpreter.visitForStatementBody(statement, index, value) if done { - return statementResult + result = statementResult } + resume = !done + if statement.Index != nil { index = index.Plus(interpreter, intOne, locationRange).(IntValue) } + + return } + + iterable.ForEach( + interpreter, + forStmtTypes.ValueVariableType, + executeBody, + locationRange, + ) + + return } func (interpreter *Interpreter) visitForStatementBody( diff --git a/runtime/interpreter/value.go b/runtime/interpreter/value.go index 829352ce29..01126bbaa7 100644 --- a/runtime/interpreter/value.go +++ b/runtime/interpreter/value.go @@ -234,7 +234,13 @@ type ContractValue interface { // IterableValue is a value which can be iterated over, e.g. with a for-loop type IterableValue interface { Value - Iterator(interpreter *Interpreter, resultType sema.Type, locationRange LocationRange) ValueIterator + Iterator(interpreter *Interpreter) ValueIterator + ForEach( + interpreter *Interpreter, + elementType sema.Type, + function func(value Value) (resume bool), + locationRange LocationRange, + ) } // ValueIterator is an iterator which returns values. @@ -1580,12 +1586,31 @@ func (v *StringValue) ConformsToStaticType( return true } -func (v *StringValue) Iterator(_ *Interpreter, _ sema.Type, _ LocationRange) ValueIterator { +func (v *StringValue) Iterator(_ *Interpreter) ValueIterator { return StringValueIterator{ graphemes: uniseg.NewGraphemes(v.Str), } } +func (v *StringValue) ForEach( + interpreter *Interpreter, + _ sema.Type, + function func(value Value) (resume bool), + _ LocationRange, +) { + iterator := v.Iterator(interpreter) + for { + value := iterator.Next(interpreter) + if value == nil { + return + } + + if !function(value) { + return + } + } +} + type StringValueIterator struct { graphemes *uniseg.Graphemes } @@ -1614,7 +1639,7 @@ type ArrayValueIterator struct { atreeIterator *atree.ArrayIterator } -func (v *ArrayValue) Iterator(_ *Interpreter, _ sema.Type, _ LocationRange) ValueIterator { +func (v *ArrayValue) Iterator(_ *Interpreter) ValueIterator { arrayIterator, err := v.array.Iterator() if err != nil { panic(errors.NewExternalError(err)) @@ -3244,6 +3269,15 @@ func (v *ArrayValue) Map( ) } +func (v *ArrayValue) ForEach( + interpreter *Interpreter, + _ sema.Type, + function func(value Value) (resume bool), + _ LocationRange, +) { + v.Iterate(interpreter, function) +} + // NumberValue type NumberValue interface { ComparableValue @@ -20197,35 +20231,60 @@ func (*StorageReferenceValue) DeepRemove(_ *Interpreter) { func (*StorageReferenceValue) isReference() {} -func (v *StorageReferenceValue) Iterator( +func (v *StorageReferenceValue) Iterator(_ *Interpreter) ValueIterator { + // Not used for now + panic(errors.NewUnreachableError()) +} + +func (v *StorageReferenceValue) ForEach( interpreter *Interpreter, - resultType sema.Type, + elementType sema.Type, + function func(value Value) (resume bool), locationRange LocationRange, -) ValueIterator { +) { referencedValue := v.mustReferencedValue(interpreter, locationRange) - return referenceValueIterator(interpreter, referencedValue, resultType, locationRange) + forEachReference( + interpreter, + referencedValue, + elementType, + function, + locationRange, + ) } -func referenceValueIterator( +func forEachReference( interpreter *Interpreter, referencedValue Value, - resultType sema.Type, + elementType sema.Type, + function func(value Value) (resume bool), locationRange LocationRange, -) ValueIterator { +) { referencedIterable, ok := referencedValue.(IterableValue) if !ok { panic(errors.NewUnreachableError()) } - referencedValueIterator := referencedIterable.Iterator(interpreter, resultType, locationRange) + referenceType, isResultReference := sema.MaybeReferenceType(elementType) + + updatedFunction := func(value Value) (resume bool) { + if isResultReference { + value = interpreter.getReferenceValue(value, elementType) + } - _, isResultReference := sema.GetReferenceType(resultType) + return function(value) + } - return ReferenceValueIterator{ - iterator: referencedValueIterator, - resultType: resultType, - isResultReference: isResultReference, + referencedElementType := elementType + if isResultReference { + referencedElementType = referenceType.Type } + + referencedIterable.ForEach( + interpreter, + referencedElementType, + updatedFunction, + locationRange, + ) } // EphemeralReferenceValue @@ -20574,38 +20633,25 @@ func (*EphemeralReferenceValue) DeepRemove(_ *Interpreter) { func (*EphemeralReferenceValue) isReference() {} -func (v *EphemeralReferenceValue) Iterator( +func (v *EphemeralReferenceValue) Iterator(_ *Interpreter) ValueIterator { + // Not used for now + panic(errors.NewUnreachableError()) +} + +func (v *EphemeralReferenceValue) ForEach( interpreter *Interpreter, - resultType sema.Type, + elementType sema.Type, + function func(value Value) (resume bool), locationRange LocationRange, -) ValueIterator { +) { referencedValue := v.MustReferencedValue(interpreter, locationRange) - return referenceValueIterator(interpreter, referencedValue, resultType, locationRange) -} - -// ReferenceValueIterator - -type ReferenceValueIterator struct { - iterator ValueIterator - resultType sema.Type - isResultReference bool -} - -var _ ValueIterator = ReferenceValueIterator{} - -func (i ReferenceValueIterator) Next(interpreter *Interpreter) Value { - element := i.iterator.Next(interpreter) - - if element == nil { - return nil - } - - // For non-primitive values, return a reference. - if i.isResultReference { - return interpreter.getReferenceValue(element, i.resultType) - } - - return element + forEachReference( + interpreter, + referencedValue, + elementType, + function, + locationRange, + ) } // AddressValue diff --git a/runtime/sema/check_assignment.go b/runtime/sema/check_assignment.go index 8745f1fea7..9c3c141978 100644 --- a/runtime/sema/check_assignment.go +++ b/runtime/sema/check_assignment.go @@ -329,7 +329,7 @@ func (checker *Checker) visitIndexExpressionAssignment( elementType = checker.visitIndexExpression(indexExpression, true) indexExprTypes := checker.Elaboration.IndexExpressionTypes(indexExpression) - indexedRefType, isReference := GetReferenceType(indexExprTypes.IndexedType) + indexedRefType, isReference := MaybeReferenceType(indexExprTypes.IndexedType) if isReference && !mutableEntitledAccess.PermitsAccess(indexedRefType.Authorization) && diff --git a/runtime/sema/check_member_expression.go b/runtime/sema/check_member_expression.go index 2cb3ab83ce..b4fbb5ca8f 100644 --- a/runtime/sema/check_member_expression.go +++ b/runtime/sema/check_member_expression.go @@ -116,14 +116,14 @@ func shouldReturnReference(parentType, memberType Type, isAssignment bool) bool return false } - if _, isReference := GetReferenceType(parentType); !isReference { + if _, isReference := MaybeReferenceType(parentType); !isReference { return false } return memberType.ContainFieldsOrElements() } -func GetReferenceType(typ Type) (*ReferenceType, bool) { +func MaybeReferenceType(typ Type) (*ReferenceType, bool) { unwrappedType := UnwrapOptionalType(typ) refType, isReference := unwrappedType.(*ReferenceType) return refType, isReference diff --git a/runtime/sema/check_variable_declaration.go b/runtime/sema/check_variable_declaration.go index e5722ebf75..c7eba52dda 100644 --- a/runtime/sema/check_variable_declaration.go +++ b/runtime/sema/check_variable_declaration.go @@ -264,7 +264,7 @@ func (checker *Checker) recordReference(targetVariable *Variable, expr ast.Expre return } - if _, isReference := GetReferenceType(targetVariable.Type); !isReference { + if _, isReference := MaybeReferenceType(targetVariable.Type); !isReference { return } diff --git a/runtime/tests/interpreter/for_test.go b/runtime/tests/interpreter/for_test.go index 7c4738c0fd..4689ea3dc0 100644 --- a/runtime/tests/interpreter/for_test.go +++ b/runtime/tests/interpreter/for_test.go @@ -468,7 +468,7 @@ func TestInterpretEphemeralReferencesInForLoop(t *testing.T) { require.NoError(t, err) }) - t.Run("Moved resource element", func(t *testing.T) { + t.Run("Mutating reference to resource array", func(t *testing.T) { t.Parallel() inter := parseCheckAndInterpret(t, ` @@ -482,6 +482,7 @@ func TestInterpretEphemeralReferencesInForLoop(t *testing.T) { for element in arrayRef { // Move the actual element + // This mutation should fail. let oldElement <- arrayRef.remove(at: 0) // Use the element reference @@ -495,7 +496,33 @@ func TestInterpretEphemeralReferencesInForLoop(t *testing.T) { `) _, err := inter.Invoke("main") - require.ErrorAs(t, err, &interpreter.InvalidatedResourceReferenceError{}) + require.ErrorAs(t, err, &interpreter.ContainerMutatedDuringIterationError{}) + }) + + t.Run("Mutating reference to struct array", func(t *testing.T) { + t.Parallel() + + inter := parseCheckAndInterpret(t, ` + struct Foo{ + fun sayHello() {} + } + + fun main() { + let array = [Foo()] + let arrayRef = &array as auth(Mutate) &[Foo] + + for element in arrayRef { + // Move the actual element + let oldElement = arrayRef.remove(at: 0) + + // Use the element reference + element.sayHello() + } + } + `) + + _, err := inter.Invoke("main") + require.NoError(t, err) }) } @@ -510,7 +537,7 @@ func TestInterpretStorageReferencesInForLoop(t *testing.T) { inter, _ := testAccount(t, address, true, nil, ` fun test() { - var let = ["Hello", "World", "Foo", "Bar"] + let array = ["Hello", "World", "Foo", "Bar"] account.storage.save(array, to: /storage/array) let arrayRef = account.storage.borrow<&[String]>(from: /storage/array)! @@ -533,7 +560,7 @@ func TestInterpretStorageReferencesInForLoop(t *testing.T) { struct Foo{} fun test() { - var let = [Foo(), Foo()] + let array = [Foo(), Foo()] account.storage.save(array, to: /storage/array) let arrayRef = account.storage.borrow<&[Foo]>(from: /storage/array)! @@ -556,7 +583,7 @@ func TestInterpretStorageReferencesInForLoop(t *testing.T) { resource Foo{} fun test() { - var let <- [ <- create Foo(), <- create Foo()] + let array <- [ <- create Foo(), <- create Foo()] account.storage.save(<- array, to: /storage/array) let arrayRef = account.storage.borrow<&[Foo]>(from: /storage/array)! @@ -579,7 +606,7 @@ func TestInterpretStorageReferencesInForLoop(t *testing.T) { resource Foo{} fun test() { - var let <- [ <- create Foo(), <- create Foo()] + let array <- [ <- create Foo(), <- create Foo()] account.storage.save(<- array, to: /storage/array) let arrayRef = account.storage.borrow<&[Foo]>(from: /storage/array)! From e8fbb185a9bc666962d15f4f3a3ab7f254120398 Mon Sep 17 00:00:00 2001 From: Supun Setunga Date: Fri, 20 Oct 2023 11:03:53 -0700 Subject: [PATCH 7/7] Add more tests for string looping --- runtime/tests/interpreter/for_test.go | 162 +++++++++++++++++++++----- 1 file changed, 136 insertions(+), 26 deletions(-) diff --git a/runtime/tests/interpreter/for_test.go b/runtime/tests/interpreter/for_test.go index 4689ea3dc0..ab14bd90d8 100644 --- a/runtime/tests/interpreter/for_test.go +++ b/runtime/tests/interpreter/for_test.go @@ -226,35 +226,105 @@ func TestInterpretForString(t *testing.T) { t.Parallel() - inter := parseCheckAndInterpret(t, ` - fun test(): [Character] { - let characters: [Character] = [] - let hello = "👪❤️" - for c in hello { - characters.append(c) - } - return characters - } - `) + t.Run("basic", func(t *testing.T) { - value, err := inter.Invoke("test") - require.NoError(t, err) + inter := parseCheckAndInterpret(t, ` + fun test(): [Character] { + let characters: [Character] = [] + let hello = "👪❤️" + for c in hello { + characters.append(c) + } + return characters + } + `) - RequireValuesEqual( - t, - inter, - interpreter.NewArrayValue( + value, err := inter.Invoke("test") + require.NoError(t, err) + + RequireValuesEqual( + t, inter, - interpreter.EmptyLocationRange, - &interpreter.VariableSizedStaticType{ - Type: interpreter.PrimitiveStaticTypeCharacter, - }, - common.ZeroAddress, - interpreter.NewUnmeteredCharacterValue("👪"), - interpreter.NewUnmeteredCharacterValue("❤️"), - ), - value, - ) + interpreter.NewArrayValue( + inter, + interpreter.EmptyLocationRange, + &interpreter.VariableSizedStaticType{ + Type: interpreter.PrimitiveStaticTypeCharacter, + }, + common.ZeroAddress, + interpreter.NewUnmeteredCharacterValue("👪"), + interpreter.NewUnmeteredCharacterValue("❤️"), + ), + value, + ) + }) + + t.Run("return", func(t *testing.T) { + + inter := parseCheckAndInterpret(t, ` + fun test(): [Character] { + let characters: [Character] = [] + let hello = "abc" + for c in hello { + characters.append(c) + return characters + } + return characters + } + `) + + value, err := inter.Invoke("test") + require.NoError(t, err) + + RequireValuesEqual( + t, + inter, + interpreter.NewArrayValue( + inter, + interpreter.EmptyLocationRange, + &interpreter.VariableSizedStaticType{ + Type: interpreter.PrimitiveStaticTypeCharacter, + }, + common.ZeroAddress, + interpreter.NewUnmeteredCharacterValue("a"), + ), + value, + ) + }) + + t.Run("break", func(t *testing.T) { + + inter := parseCheckAndInterpret(t, ` + fun test(): [Character] { + let characters: [Character] = [] + let hello = "abc" + for c in hello { + characters.append(c) + break + } + return characters + } + `) + + value, err := inter.Invoke("test") + require.NoError(t, err) + + RequireValuesEqual( + t, + inter, + interpreter.NewArrayValue( + inter, + interpreter.EmptyLocationRange, + &interpreter.VariableSizedStaticType{ + Type: interpreter.PrimitiveStaticTypeCharacter, + }, + common.ZeroAddress, + interpreter.NewUnmeteredCharacterValue("a"), + ), + value, + ) + }) + } func TestInterpretForStatementCapturing(t *testing.T) { @@ -524,6 +594,46 @@ func TestInterpretEphemeralReferencesInForLoop(t *testing.T) { _, err := inter.Invoke("main") require.NoError(t, err) }) + + t.Run("String ref", func(t *testing.T) { + t.Parallel() + + inter := parseCheckAndInterpret(t, ` + fun main(): [Character] { + let s = "Hello" + let sRef = &s as &String + let characters: [Character] = [] + + for char in sRef { + characters.append(char) + } + + return characters + } + `) + + value, err := inter.Invoke("main") + require.NoError(t, err) + + RequireValuesEqual( + t, + inter, + interpreter.NewArrayValue( + inter, + interpreter.EmptyLocationRange, + &interpreter.VariableSizedStaticType{ + Type: interpreter.PrimitiveStaticTypeCharacter, + }, + common.ZeroAddress, + interpreter.NewUnmeteredCharacterValue("H"), + interpreter.NewUnmeteredCharacterValue("e"), + interpreter.NewUnmeteredCharacterValue("l"), + interpreter.NewUnmeteredCharacterValue("l"), + interpreter.NewUnmeteredCharacterValue("o"), + ), + value, + ) + }) } func TestInterpretStorageReferencesInForLoop(t *testing.T) {