Skip to content

Commit

Permalink
Handle optionals and resource tracking
Browse files Browse the repository at this point in the history
  • Loading branch information
SupunS committed Oct 18, 2023
1 parent c659d79 commit a7dc05e
Show file tree
Hide file tree
Showing 9 changed files with 187 additions and 43 deletions.
3 changes: 2 additions & 1 deletion runtime/interpreter/interpreter_statement.go
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,8 @@ func (interpreter *Interpreter) VisitForStatement(statement *ast.ForStatement) S
panic(errors.NewUnreachableError())
}

iterator := iterable.Iterator(interpreter, locationRange)
forStmtTypes := interpreter.Program.Elaboration.ForStatementType(statement)
iterator := iterable.Iterator(interpreter, forStmtTypes.ValueVariableType, locationRange)

var index IntValue
if statement.Index != nil {
Expand Down
47 changes: 26 additions & 21 deletions runtime/interpreter/value.go
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ type ContractValue interface {
// IterableValue is a value which can be iterated over, e.g. with a for-loop
type IterableValue interface {
Value
Iterator(interpreter *Interpreter, locationRange LocationRange) ValueIterator
Iterator(interpreter *Interpreter, resultType sema.Type, locationRange LocationRange) ValueIterator
}

// ValueIterator is an iterator which returns values.
Expand Down Expand Up @@ -1580,7 +1580,7 @@ func (v *StringValue) ConformsToStaticType(
return true
}

func (v *StringValue) Iterator(_ *Interpreter, _ LocationRange) ValueIterator {
func (v *StringValue) Iterator(_ *Interpreter, _ sema.Type, _ LocationRange) ValueIterator {
return StringValueIterator{
graphemes: uniseg.NewGraphemes(v.Str),
}
Expand Down Expand Up @@ -1614,7 +1614,7 @@ type ArrayValueIterator struct {
atreeIterator *atree.ArrayIterator
}

func (v *ArrayValue) Iterator(_ *Interpreter, _ LocationRange) ValueIterator {
func (v *ArrayValue) Iterator(_ *Interpreter, _ sema.Type, _ LocationRange) ValueIterator {
arrayIterator, err := v.array.Iterator()
if err != nil {
panic(errors.NewExternalError(err))
Expand Down Expand Up @@ -20197,34 +20197,34 @@ func (*StorageReferenceValue) DeepRemove(_ *Interpreter) {

func (*StorageReferenceValue) isReference() {}

func (v *StorageReferenceValue) Iterator(interpreter *Interpreter, locationRange LocationRange) ValueIterator {
func (v *StorageReferenceValue) Iterator(
interpreter *Interpreter,
resultType sema.Type,
locationRange LocationRange,
) ValueIterator {
referencedValue := v.mustReferencedValue(interpreter, locationRange)
return referenceValueIterator(interpreter, referencedValue, v.BorrowedType, locationRange)
return referenceValueIterator(interpreter, referencedValue, resultType, locationRange)
}

func referenceValueIterator(
interpreter *Interpreter,
referencedValue Value,
borrowedType sema.Type,
resultType sema.Type,
locationRange LocationRange,
) ValueIterator {
referencedIterable, ok := referencedValue.(IterableValue)
if !ok {
panic(errors.NewUnreachableError())
}

referencedValueIterator := referencedIterable.Iterator(interpreter, locationRange)
referencedValueIterator := referencedIterable.Iterator(interpreter, resultType, locationRange)

referencedType, ok := borrowedType.(sema.ValueIndexableType)
if !ok {
panic(errors.NewUnreachableError())
}

elementType := referencedType.ElementType(false)
_, isResultReference := sema.GetReferenceType(resultType)

return ReferenceValueIterator{
iterator: referencedValueIterator,
elementType: elementType,
iterator: referencedValueIterator,
resultType: resultType,
isResultReference: isResultReference,
}
}

Expand Down Expand Up @@ -20574,16 +20574,21 @@ func (*EphemeralReferenceValue) DeepRemove(_ *Interpreter) {

func (*EphemeralReferenceValue) isReference() {}

func (v *EphemeralReferenceValue) Iterator(interpreter *Interpreter, locationRange LocationRange) ValueIterator {
func (v *EphemeralReferenceValue) Iterator(
interpreter *Interpreter,
resultType sema.Type,
locationRange LocationRange,
) ValueIterator {
referencedValue := v.MustReferencedValue(interpreter, locationRange)
return referenceValueIterator(interpreter, referencedValue, v.BorrowedType, locationRange)
return referenceValueIterator(interpreter, referencedValue, resultType, locationRange)
}

// ReferenceValueIterator

type ReferenceValueIterator struct {
iterator ValueIterator
elementType sema.Type
iterator ValueIterator
resultType sema.Type
isResultReference bool
}

var _ ValueIterator = ReferenceValueIterator{}
Expand All @@ -20596,8 +20601,8 @@ func (i ReferenceValueIterator) Next(interpreter *Interpreter) Value {
}

// For non-primitive values, return a reference.
if i.elementType.ContainFieldsOrElements() {
return NewEphemeralReferenceValue(interpreter, UnauthorizedAccess, element, i.elementType)
if i.isResultReference {
return interpreter.getReferenceValue(element, i.resultType)
}

return element
Expand Down
2 changes: 1 addition & 1 deletion runtime/sema/check_assignment.go
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,7 @@ func (checker *Checker) visitIndexExpressionAssignment(
elementType = checker.visitIndexExpression(indexExpression, true)

indexExprTypes := checker.Elaboration.IndexExpressionTypes(indexExpression)
indexedRefType, isReference := referenceType(indexExprTypes.IndexedType)
indexedRefType, isReference := GetReferenceType(indexExprTypes.IndexedType)

if isReference &&
!mutableEntitledAccess.PermitsAccess(indexedRefType.Authorization) &&
Expand Down
12 changes: 10 additions & 2 deletions runtime/sema/check_for.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,11 +66,14 @@ func (checker *Checker) VisitForStatement(statement *ast.ForStatement) (_ struct
checker.recordVariableDeclarationOccurrence(identifier, variable)
}

var indexType Type

if statement.Index != nil {
index := statement.Index.Identifier
indexType = IntType
indexVariable, err := checker.valueActivations.declare(variableDeclaration{
identifier: index,
ty: IntType,
ty: indexType,
kind: common.DeclarationKindConstant,
pos: statement.Index.Pos,
isConstant: true,
Expand All @@ -84,6 +87,11 @@ func (checker *Checker) VisitForStatement(statement *ast.ForStatement) (_ struct
}
}

checker.Elaboration.SetForStatementType(statement, ForStatementTypes{
IndexVariableType: indexType,
ValueVariableType: loopVariableType,
})

// The body of the loop will maybe be evaluated.
// That means that resource invalidations and
// returns are not definite, but only potential.
Expand Down Expand Up @@ -133,7 +141,7 @@ func (checker *Checker) loopVariableType(valueType Type, hasPosition ast.HasPosi
// Case (a): Element type is a container type.
// Then the loop-var must also be a reference type.
if referencedIterableElementType.ContainFieldsOrElements() {
return NewReferenceType(checker.memoryGauge, UnauthorizedAccess, referencedIterableElementType)
return checker.getReferenceType(referencedIterableElementType, false, UnauthorizedAccess)
}

// Case (b): Element type is a primitive type.
Expand Down
4 changes: 2 additions & 2 deletions runtime/sema/check_member_expression.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,14 +116,14 @@ func shouldReturnReference(parentType, memberType Type, isAssignment bool) bool
return false
}

if _, isReference := referenceType(parentType); !isReference {
if _, isReference := GetReferenceType(parentType); !isReference {
return false
}

return memberType.ContainFieldsOrElements()
}

func referenceType(typ Type) (*ReferenceType, bool) {
func GetReferenceType(typ Type) (*ReferenceType, bool) {
unwrappedType := UnwrapOptionalType(typ)
refType, isReference := unwrappedType.(*ReferenceType)
return refType, isReference
Expand Down
2 changes: 1 addition & 1 deletion runtime/sema/check_variable_declaration.go
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ func (checker *Checker) recordReference(targetVariable *Variable, expr ast.Expre
return
}

if _, isReference := referenceType(targetVariable.Type); !isReference {
if _, isReference := GetReferenceType(targetVariable.Type); !isReference {
return
}

Expand Down
20 changes: 20 additions & 0 deletions runtime/sema/elaboration.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,13 +111,19 @@ type ExpressionTypes struct {
ExpectedType Type
}

type ForStatementTypes struct {
IndexVariableType Type
ValueVariableType Type
}

type Elaboration struct {
interfaceTypesAndDeclarationsBiMap *bimap.BiMap[*InterfaceType, *ast.InterfaceDeclaration]
entitlementTypesAndDeclarationsBiMap *bimap.BiMap[*EntitlementType, *ast.EntitlementDeclaration]
entitlementMapTypesAndDeclarationsBiMap *bimap.BiMap[*EntitlementMapType, *ast.EntitlementMappingDeclaration]

fixedPointExpressionTypes map[*ast.FixedPointExpression]Type
swapStatementTypes map[*ast.SwapStatement]SwapStatementTypes
forStatementTypes map[*ast.ForStatement]ForStatementTypes
assignmentStatementTypes map[*ast.AssignmentStatement]AssignmentStatementTypes
compositeDeclarationTypes map[ast.CompositeLikeDeclaration]*CompositeType
compositeTypeDeclarations map[*CompositeType]ast.CompositeLikeDeclaration
Expand Down Expand Up @@ -1032,3 +1038,17 @@ func (e *Elaboration) GetSemanticAccess(access ast.Access) (semaAccess Access, p
semaAccess, present = e.semanticAccesses[access]
return
}

func (e *Elaboration) SetForStatementType(statement *ast.ForStatement, types ForStatementTypes) {
if e.forStatementTypes == nil {
e.forStatementTypes = map[*ast.ForStatement]ForStatementTypes{}
}
e.forStatementTypes[statement] = types
}

func (e *Elaboration) ForStatementType(statement *ast.ForStatement) (types ForStatementTypes) {
if e.forStatementTypes == nil {
return
}
return e.forStatementTypes[statement]
}
19 changes: 19 additions & 0 deletions runtime/tests/checker/for_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -431,4 +431,23 @@ func TestCheckReferencesInForLoop(t *testing.T) {
errors := RequireCheckerErrors(t, err, 1)
assert.IsType(t, &sema.TypeMismatchError{}, errors[0])
})

t.Run("Optional array", func(t *testing.T) {
t.Parallel()

_, err := ParseAndCheck(t, `
struct Foo{}
fun main() {
var array: [Foo?] = [Foo(), Foo()]
var arrayRef = &array as &[Foo?]
for element in arrayRef {
let e: &Foo? = element // Should be an optional reference
}
}
`)

require.NoError(t, err)
})
}
Loading

0 comments on commit a7dc05e

Please sign in to comment.