diff --git a/runtime/interpreter/interpreter_statement.go b/runtime/interpreter/interpreter_statement.go index 7273cb307e..01f84b176c 100644 --- a/runtime/interpreter/interpreter_statement.go +++ b/runtime/interpreter/interpreter_statement.go @@ -338,7 +338,8 @@ func (interpreter *Interpreter) VisitForStatement(statement *ast.ForStatement) S panic(errors.NewUnreachableError()) } - iterator := iterable.Iterator(interpreter, locationRange) + forStmtTypes := interpreter.Program.Elaboration.ForStatementType(statement) + iterator := iterable.Iterator(interpreter, forStmtTypes.ValueVariableType, locationRange) var index IntValue if statement.Index != nil { diff --git a/runtime/interpreter/value.go b/runtime/interpreter/value.go index 64620efbe4..829352ce29 100644 --- a/runtime/interpreter/value.go +++ b/runtime/interpreter/value.go @@ -234,7 +234,7 @@ type ContractValue interface { // IterableValue is a value which can be iterated over, e.g. with a for-loop type IterableValue interface { Value - Iterator(interpreter *Interpreter, locationRange LocationRange) ValueIterator + Iterator(interpreter *Interpreter, resultType sema.Type, locationRange LocationRange) ValueIterator } // ValueIterator is an iterator which returns values. @@ -1580,7 +1580,7 @@ func (v *StringValue) ConformsToStaticType( return true } -func (v *StringValue) Iterator(_ *Interpreter, _ LocationRange) ValueIterator { +func (v *StringValue) Iterator(_ *Interpreter, _ sema.Type, _ LocationRange) ValueIterator { return StringValueIterator{ graphemes: uniseg.NewGraphemes(v.Str), } @@ -1614,7 +1614,7 @@ type ArrayValueIterator struct { atreeIterator *atree.ArrayIterator } -func (v *ArrayValue) Iterator(_ *Interpreter, _ LocationRange) ValueIterator { +func (v *ArrayValue) Iterator(_ *Interpreter, _ sema.Type, _ LocationRange) ValueIterator { arrayIterator, err := v.array.Iterator() if err != nil { panic(errors.NewExternalError(err)) @@ -20197,15 +20197,19 @@ func (*StorageReferenceValue) DeepRemove(_ *Interpreter) { func (*StorageReferenceValue) isReference() {} -func (v *StorageReferenceValue) Iterator(interpreter *Interpreter, locationRange LocationRange) ValueIterator { +func (v *StorageReferenceValue) Iterator( + interpreter *Interpreter, + resultType sema.Type, + locationRange LocationRange, +) ValueIterator { referencedValue := v.mustReferencedValue(interpreter, locationRange) - return referenceValueIterator(interpreter, referencedValue, v.BorrowedType, locationRange) + return referenceValueIterator(interpreter, referencedValue, resultType, locationRange) } func referenceValueIterator( interpreter *Interpreter, referencedValue Value, - borrowedType sema.Type, + resultType sema.Type, locationRange LocationRange, ) ValueIterator { referencedIterable, ok := referencedValue.(IterableValue) @@ -20213,18 +20217,14 @@ func referenceValueIterator( panic(errors.NewUnreachableError()) } - referencedValueIterator := referencedIterable.Iterator(interpreter, locationRange) + referencedValueIterator := referencedIterable.Iterator(interpreter, resultType, locationRange) - referencedType, ok := borrowedType.(sema.ValueIndexableType) - if !ok { - panic(errors.NewUnreachableError()) - } - - elementType := referencedType.ElementType(false) + _, isResultReference := sema.GetReferenceType(resultType) return ReferenceValueIterator{ - iterator: referencedValueIterator, - elementType: elementType, + iterator: referencedValueIterator, + resultType: resultType, + isResultReference: isResultReference, } } @@ -20574,16 +20574,21 @@ func (*EphemeralReferenceValue) DeepRemove(_ *Interpreter) { func (*EphemeralReferenceValue) isReference() {} -func (v *EphemeralReferenceValue) Iterator(interpreter *Interpreter, locationRange LocationRange) ValueIterator { +func (v *EphemeralReferenceValue) Iterator( + interpreter *Interpreter, + resultType sema.Type, + locationRange LocationRange, +) ValueIterator { referencedValue := v.MustReferencedValue(interpreter, locationRange) - return referenceValueIterator(interpreter, referencedValue, v.BorrowedType, locationRange) + return referenceValueIterator(interpreter, referencedValue, resultType, locationRange) } // ReferenceValueIterator type ReferenceValueIterator struct { - iterator ValueIterator - elementType sema.Type + iterator ValueIterator + resultType sema.Type + isResultReference bool } var _ ValueIterator = ReferenceValueIterator{} @@ -20596,8 +20601,8 @@ func (i ReferenceValueIterator) Next(interpreter *Interpreter) Value { } // For non-primitive values, return a reference. - if i.elementType.ContainFieldsOrElements() { - return NewEphemeralReferenceValue(interpreter, UnauthorizedAccess, element, i.elementType) + if i.isResultReference { + return interpreter.getReferenceValue(element, i.resultType) } return element diff --git a/runtime/sema/check_assignment.go b/runtime/sema/check_assignment.go index 0b0a20fba1..8745f1fea7 100644 --- a/runtime/sema/check_assignment.go +++ b/runtime/sema/check_assignment.go @@ -329,7 +329,7 @@ func (checker *Checker) visitIndexExpressionAssignment( elementType = checker.visitIndexExpression(indexExpression, true) indexExprTypes := checker.Elaboration.IndexExpressionTypes(indexExpression) - indexedRefType, isReference := referenceType(indexExprTypes.IndexedType) + indexedRefType, isReference := GetReferenceType(indexExprTypes.IndexedType) if isReference && !mutableEntitledAccess.PermitsAccess(indexedRefType.Authorization) && diff --git a/runtime/sema/check_for.go b/runtime/sema/check_for.go index 68af706468..da3fe156fb 100644 --- a/runtime/sema/check_for.go +++ b/runtime/sema/check_for.go @@ -66,11 +66,14 @@ func (checker *Checker) VisitForStatement(statement *ast.ForStatement) (_ struct checker.recordVariableDeclarationOccurrence(identifier, variable) } + var indexType Type + if statement.Index != nil { index := statement.Index.Identifier + indexType = IntType indexVariable, err := checker.valueActivations.declare(variableDeclaration{ identifier: index, - ty: IntType, + ty: indexType, kind: common.DeclarationKindConstant, pos: statement.Index.Pos, isConstant: true, @@ -84,6 +87,11 @@ func (checker *Checker) VisitForStatement(statement *ast.ForStatement) (_ struct } } + checker.Elaboration.SetForStatementType(statement, ForStatementTypes{ + IndexVariableType: indexType, + ValueVariableType: loopVariableType, + }) + // The body of the loop will maybe be evaluated. // That means that resource invalidations and // returns are not definite, but only potential. @@ -133,7 +141,7 @@ func (checker *Checker) loopVariableType(valueType Type, hasPosition ast.HasPosi // Case (a): Element type is a container type. // Then the loop-var must also be a reference type. if referencedIterableElementType.ContainFieldsOrElements() { - return NewReferenceType(checker.memoryGauge, UnauthorizedAccess, referencedIterableElementType) + return checker.getReferenceType(referencedIterableElementType, false, UnauthorizedAccess) } // Case (b): Element type is a primitive type. diff --git a/runtime/sema/check_member_expression.go b/runtime/sema/check_member_expression.go index e673a579a1..2cb3ab83ce 100644 --- a/runtime/sema/check_member_expression.go +++ b/runtime/sema/check_member_expression.go @@ -116,14 +116,14 @@ func shouldReturnReference(parentType, memberType Type, isAssignment bool) bool return false } - if _, isReference := referenceType(parentType); !isReference { + if _, isReference := GetReferenceType(parentType); !isReference { return false } return memberType.ContainFieldsOrElements() } -func referenceType(typ Type) (*ReferenceType, bool) { +func GetReferenceType(typ Type) (*ReferenceType, bool) { unwrappedType := UnwrapOptionalType(typ) refType, isReference := unwrappedType.(*ReferenceType) return refType, isReference diff --git a/runtime/sema/check_variable_declaration.go b/runtime/sema/check_variable_declaration.go index 971abbbe5d..e5722ebf75 100644 --- a/runtime/sema/check_variable_declaration.go +++ b/runtime/sema/check_variable_declaration.go @@ -264,7 +264,7 @@ func (checker *Checker) recordReference(targetVariable *Variable, expr ast.Expre return } - if _, isReference := referenceType(targetVariable.Type); !isReference { + if _, isReference := GetReferenceType(targetVariable.Type); !isReference { return } diff --git a/runtime/sema/elaboration.go b/runtime/sema/elaboration.go index a64add82da..8211be53b2 100644 --- a/runtime/sema/elaboration.go +++ b/runtime/sema/elaboration.go @@ -111,6 +111,11 @@ type ExpressionTypes struct { ExpectedType Type } +type ForStatementTypes struct { + IndexVariableType Type + ValueVariableType Type +} + type Elaboration struct { interfaceTypesAndDeclarationsBiMap *bimap.BiMap[*InterfaceType, *ast.InterfaceDeclaration] entitlementTypesAndDeclarationsBiMap *bimap.BiMap[*EntitlementType, *ast.EntitlementDeclaration] @@ -118,6 +123,7 @@ type Elaboration struct { fixedPointExpressionTypes map[*ast.FixedPointExpression]Type swapStatementTypes map[*ast.SwapStatement]SwapStatementTypes + forStatementTypes map[*ast.ForStatement]ForStatementTypes assignmentStatementTypes map[*ast.AssignmentStatement]AssignmentStatementTypes compositeDeclarationTypes map[ast.CompositeLikeDeclaration]*CompositeType compositeTypeDeclarations map[*CompositeType]ast.CompositeLikeDeclaration @@ -1032,3 +1038,17 @@ func (e *Elaboration) GetSemanticAccess(access ast.Access) (semaAccess Access, p semaAccess, present = e.semanticAccesses[access] return } + +func (e *Elaboration) SetForStatementType(statement *ast.ForStatement, types ForStatementTypes) { + if e.forStatementTypes == nil { + e.forStatementTypes = map[*ast.ForStatement]ForStatementTypes{} + } + e.forStatementTypes[statement] = types +} + +func (e *Elaboration) ForStatementType(statement *ast.ForStatement) (types ForStatementTypes) { + if e.forStatementTypes == nil { + return + } + return e.forStatementTypes[statement] +} diff --git a/runtime/tests/checker/for_test.go b/runtime/tests/checker/for_test.go index d52c33e1d0..edb5148290 100644 --- a/runtime/tests/checker/for_test.go +++ b/runtime/tests/checker/for_test.go @@ -431,4 +431,23 @@ func TestCheckReferencesInForLoop(t *testing.T) { errors := RequireCheckerErrors(t, err, 1) assert.IsType(t, &sema.TypeMismatchError{}, errors[0]) }) + + t.Run("Optional array", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheck(t, ` + struct Foo{} + + fun main() { + var array: [Foo?] = [Foo(), Foo()] + var arrayRef = &array as &[Foo?] + + for element in arrayRef { + let e: &Foo? = element // Should be an optional reference + } + } + `) + + require.NoError(t, err) + }) } diff --git a/runtime/tests/interpreter/for_test.go b/runtime/tests/interpreter/for_test.go index a30116a234..7c4738c0fd 100644 --- a/runtime/tests/interpreter/for_test.go +++ b/runtime/tests/interpreter/for_test.go @@ -305,8 +305,8 @@ func TestInterpretEphemeralReferencesInForLoop(t *testing.T) { inter := parseCheckAndInterpret(t, ` fun main() { - var array = ["Hello", "World", "Foo", "Bar"] - var arrayRef = &array as &[String] + let array = ["Hello", "World", "Foo", "Bar"] + let arrayRef = &array as &[String] for element in arrayRef { let e: String = element @@ -325,8 +325,8 @@ func TestInterpretEphemeralReferencesInForLoop(t *testing.T) { struct Foo{} fun main() { - var array = [Foo(), Foo()] - var arrayRef = &array as &[Foo] + let array = [Foo(), Foo()] + let arrayRef = &array as &[Foo] for element in arrayRef { let e: &Foo = element @@ -345,8 +345,8 @@ func TestInterpretEphemeralReferencesInForLoop(t *testing.T) { resource Foo{} fun main() { - var array <- [ <- create Foo(), <- create Foo()] - var arrayRef = &array as &[Foo] + let array <- [ <- create Foo(), <- create Foo()] + let arrayRef = &array as &[Foo] for element in arrayRef { let e: &Foo = element @@ -367,9 +367,9 @@ func TestInterpretEphemeralReferencesInForLoop(t *testing.T) { resource Foo{} fun main() { - var array <- [ <- create Foo(), <- create Foo()] - var arrayRef = returnSameRef(&array as &[Foo]) - var movedArray <- array + let array <- [ <- create Foo(), <- create Foo()] + let arrayRef = returnSameRef(&array as &[Foo]) + let movedArray <- array for element in arrayRef { let e: &Foo = element @@ -394,8 +394,8 @@ func TestInterpretEphemeralReferencesInForLoop(t *testing.T) { struct Foo{} fun main() { - var array = [Foo(), Foo()] - var arrayRef = &array as auth(Mutate) &[Foo] + let array = [Foo(), Foo()] + let arrayRef = &array as auth(Mutate) &[Foo] for element in arrayRef { let e: &Foo = element // Should be non-auth @@ -406,6 +406,97 @@ func TestInterpretEphemeralReferencesInForLoop(t *testing.T) { _, err := inter.Invoke("main") require.NoError(t, err) }) + + t.Run("Optional array", func(t *testing.T) { + t.Parallel() + + inter := parseCheckAndInterpret(t, ` + struct Foo{} + + fun main() { + let array: [Foo?] = [Foo(), Foo()] + let arrayRef = &array as &[Foo?] + + for element in arrayRef { + let e: &Foo? = element // Should be an optional reference + } + } + `) + + _, err := inter.Invoke("main") + require.NoError(t, err) + }) + + t.Run("Nil array", func(t *testing.T) { + t.Parallel() + + inter := parseCheckAndInterpret(t, ` + struct Foo{} + + fun main() { + let array: [Foo?] = [nil, nil] + let arrayRef = &array as &[Foo?] + + for element in arrayRef { + let e: &Foo? = element // Should be an optional reference + } + } + `) + + _, err := inter.Invoke("main") + require.NoError(t, err) + }) + + t.Run("Reference array", func(t *testing.T) { + t.Parallel() + + inter := parseCheckAndInterpret(t, ` + struct Foo{} + + fun main() { + let elementRef = &Foo() as &Foo + let array: [&Foo] = [elementRef, elementRef] + let arrayRef = &array as &[&Foo] + + for element in arrayRef { + let e: &Foo = element + } + } + `) + + _, err := inter.Invoke("main") + require.NoError(t, err) + }) + + t.Run("Moved resource element", func(t *testing.T) { + t.Parallel() + + inter := parseCheckAndInterpret(t, ` + resource Foo{ + fun sayHello() {} + } + + fun main() { + let array <- [ <- create Foo()] + let arrayRef = &array as auth(Mutate) &[Foo] + + for element in arrayRef { + // Move the actual element + let oldElement <- arrayRef.remove(at: 0) + + // Use the element reference + element.sayHello() + + destroy oldElement + } + + destroy array + } + `) + + _, err := inter.Invoke("main") + require.ErrorAs(t, err, &interpreter.InvalidatedResourceReferenceError{}) + }) } func TestInterpretStorageReferencesInForLoop(t *testing.T) { @@ -419,7 +510,7 @@ func TestInterpretStorageReferencesInForLoop(t *testing.T) { inter, _ := testAccount(t, address, true, nil, ` fun test() { - var array = ["Hello", "World", "Foo", "Bar"] + var let = ["Hello", "World", "Foo", "Bar"] account.storage.save(array, to: /storage/array) let arrayRef = account.storage.borrow<&[String]>(from: /storage/array)! @@ -442,7 +533,7 @@ func TestInterpretStorageReferencesInForLoop(t *testing.T) { struct Foo{} fun test() { - var array = [Foo(), Foo()] + var let = [Foo(), Foo()] account.storage.save(array, to: /storage/array) let arrayRef = account.storage.borrow<&[Foo]>(from: /storage/array)! @@ -465,7 +556,7 @@ func TestInterpretStorageReferencesInForLoop(t *testing.T) { resource Foo{} fun test() { - var array <- [ <- create Foo(), <- create Foo()] + var let <- [ <- create Foo(), <- create Foo()] account.storage.save(<- array, to: /storage/array) let arrayRef = account.storage.borrow<&[Foo]>(from: /storage/array)! @@ -488,7 +579,7 @@ func TestInterpretStorageReferencesInForLoop(t *testing.T) { resource Foo{} fun test() { - var array <- [ <- create Foo(), <- create Foo()] + var let <- [ <- create Foo(), <- create Foo()] account.storage.save(<- array, to: /storage/array) let arrayRef = account.storage.borrow<&[Foo]>(from: /storage/array)!