Skip to content

Commit

Permalink
Prevent mutation while iterating
Browse files Browse the repository at this point in the history
  • Loading branch information
SupunS committed Oct 19, 2023
1 parent a7dc05e commit 3ec9749
Show file tree
Hide file tree
Showing 6 changed files with 144 additions and 64 deletions.
25 changes: 16 additions & 9 deletions runtime/interpreter/interpreter_statement.go
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,7 @@ func (interpreter *Interpreter) VisitWhileStatement(statement *ast.WhileStatemen

var intOne = NewUnmeteredIntValueFromInt64(1)

func (interpreter *Interpreter) VisitForStatement(statement *ast.ForStatement) StatementResult {
func (interpreter *Interpreter) VisitForStatement(statement *ast.ForStatement) (result StatementResult) {

interpreter.activations.PushNewWithCurrent()
defer interpreter.activations.Pop()
Expand All @@ -339,28 +339,35 @@ func (interpreter *Interpreter) VisitForStatement(statement *ast.ForStatement) S
}

forStmtTypes := interpreter.Program.Elaboration.ForStatementType(statement)
iterator := iterable.Iterator(interpreter, forStmtTypes.ValueVariableType, locationRange)

var index IntValue
if statement.Index != nil {
index = NewIntValueFromInt64(interpreter, 0)
}

for {
value := iterator.Next(interpreter)
if value == nil {
return nil
}

executeBody := func(value Value) (resume bool) {
statementResult, done := interpreter.visitForStatementBody(statement, index, value)
if done {
return statementResult
result = statementResult
}

resume = !done

if statement.Index != nil {
index = index.Plus(interpreter, intOne, locationRange).(IntValue)
}

return
}

iterable.ForEach(
interpreter,
forStmtTypes.ValueVariableType,
executeBody,
locationRange,
)

return
}

func (interpreter *Interpreter) visitForStatementBody(
Expand Down
136 changes: 91 additions & 45 deletions runtime/interpreter/value.go
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,13 @@ type ContractValue interface {
// IterableValue is a value which can be iterated over, e.g. with a for-loop
type IterableValue interface {
Value
Iterator(interpreter *Interpreter, resultType sema.Type, locationRange LocationRange) ValueIterator
Iterator(interpreter *Interpreter) ValueIterator
ForEach(
interpreter *Interpreter,
elementType sema.Type,
function func(value Value) (resume bool),
locationRange LocationRange,
)
}

// ValueIterator is an iterator which returns values.
Expand Down Expand Up @@ -1580,12 +1586,31 @@ func (v *StringValue) ConformsToStaticType(
return true
}

func (v *StringValue) Iterator(_ *Interpreter, _ sema.Type, _ LocationRange) ValueIterator {
func (v *StringValue) Iterator(_ *Interpreter) ValueIterator {
return StringValueIterator{
graphemes: uniseg.NewGraphemes(v.Str),
}
}

func (v *StringValue) ForEach(
interpreter *Interpreter,
_ sema.Type,
function func(value Value) (resume bool),
_ LocationRange,
) {
iterator := v.Iterator(interpreter)
for {
value := iterator.Next(interpreter)
if value == nil {
return
}

if !function(value) {
return
}
}
}

type StringValueIterator struct {
graphemes *uniseg.Graphemes
}
Expand Down Expand Up @@ -1614,7 +1639,7 @@ type ArrayValueIterator struct {
atreeIterator *atree.ArrayIterator
}

func (v *ArrayValue) Iterator(_ *Interpreter, _ sema.Type, _ LocationRange) ValueIterator {
func (v *ArrayValue) Iterator(_ *Interpreter) ValueIterator {
arrayIterator, err := v.array.Iterator()
if err != nil {
panic(errors.NewExternalError(err))
Expand Down Expand Up @@ -3244,6 +3269,15 @@ func (v *ArrayValue) Map(
)
}

func (v *ArrayValue) ForEach(
interpreter *Interpreter,
_ sema.Type,
function func(value Value) (resume bool),
_ LocationRange,
) {
v.Iterate(interpreter, function)
}

// NumberValue
type NumberValue interface {
ComparableValue
Expand Down Expand Up @@ -20197,35 +20231,60 @@ func (*StorageReferenceValue) DeepRemove(_ *Interpreter) {

func (*StorageReferenceValue) isReference() {}

func (v *StorageReferenceValue) Iterator(
func (v *StorageReferenceValue) Iterator(_ *Interpreter) ValueIterator {
// Not used for now
panic(errors.NewUnreachableError())
}

func (v *StorageReferenceValue) ForEach(
interpreter *Interpreter,
resultType sema.Type,
elementType sema.Type,
function func(value Value) (resume bool),
locationRange LocationRange,
) ValueIterator {
) {
referencedValue := v.mustReferencedValue(interpreter, locationRange)
return referenceValueIterator(interpreter, referencedValue, resultType, locationRange)
forEachReference(
interpreter,
referencedValue,
elementType,
function,
locationRange,
)
}

func referenceValueIterator(
func forEachReference(
interpreter *Interpreter,
referencedValue Value,
resultType sema.Type,
elementType sema.Type,
function func(value Value) (resume bool),
locationRange LocationRange,
) ValueIterator {
) {
referencedIterable, ok := referencedValue.(IterableValue)
if !ok {
panic(errors.NewUnreachableError())
}

referencedValueIterator := referencedIterable.Iterator(interpreter, resultType, locationRange)
referenceType, isResultReference := sema.MaybeReferenceType(elementType)

updatedFunction := func(value Value) (resume bool) {
if isResultReference {
value = interpreter.getReferenceValue(value, elementType)
}

_, isResultReference := sema.GetReferenceType(resultType)
return function(value)
}

return ReferenceValueIterator{
iterator: referencedValueIterator,
resultType: resultType,
isResultReference: isResultReference,
referencedElementType := elementType
if isResultReference {
referencedElementType = referenceType.Type
}

referencedIterable.ForEach(
interpreter,
referencedElementType,
updatedFunction,
locationRange,
)
}

// EphemeralReferenceValue
Expand Down Expand Up @@ -20574,38 +20633,25 @@ func (*EphemeralReferenceValue) DeepRemove(_ *Interpreter) {

func (*EphemeralReferenceValue) isReference() {}

func (v *EphemeralReferenceValue) Iterator(
func (v *EphemeralReferenceValue) Iterator(_ *Interpreter) ValueIterator {
// Not used for now
panic(errors.NewUnreachableError())
}

func (v *EphemeralReferenceValue) ForEach(
interpreter *Interpreter,
resultType sema.Type,
elementType sema.Type,
function func(value Value) (resume bool),
locationRange LocationRange,
) ValueIterator {
) {
referencedValue := v.MustReferencedValue(interpreter, locationRange)
return referenceValueIterator(interpreter, referencedValue, resultType, locationRange)
}

// ReferenceValueIterator

type ReferenceValueIterator struct {
iterator ValueIterator
resultType sema.Type
isResultReference bool
}

var _ ValueIterator = ReferenceValueIterator{}

func (i ReferenceValueIterator) Next(interpreter *Interpreter) Value {
element := i.iterator.Next(interpreter)

if element == nil {
return nil
}

// For non-primitive values, return a reference.
if i.isResultReference {
return interpreter.getReferenceValue(element, i.resultType)
}

return element
forEachReference(
interpreter,
referencedValue,
elementType,
function,
locationRange,
)
}

// AddressValue
Expand Down
2 changes: 1 addition & 1 deletion runtime/sema/check_assignment.go
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,7 @@ func (checker *Checker) visitIndexExpressionAssignment(
elementType = checker.visitIndexExpression(indexExpression, true)

indexExprTypes := checker.Elaboration.IndexExpressionTypes(indexExpression)
indexedRefType, isReference := GetReferenceType(indexExprTypes.IndexedType)
indexedRefType, isReference := MaybeReferenceType(indexExprTypes.IndexedType)

if isReference &&
!mutableEntitledAccess.PermitsAccess(indexedRefType.Authorization) &&
Expand Down
4 changes: 2 additions & 2 deletions runtime/sema/check_member_expression.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,14 +116,14 @@ func shouldReturnReference(parentType, memberType Type, isAssignment bool) bool
return false
}

if _, isReference := GetReferenceType(parentType); !isReference {
if _, isReference := MaybeReferenceType(parentType); !isReference {
return false
}

return memberType.ContainFieldsOrElements()
}

func GetReferenceType(typ Type) (*ReferenceType, bool) {
func MaybeReferenceType(typ Type) (*ReferenceType, bool) {
unwrappedType := UnwrapOptionalType(typ)
refType, isReference := unwrappedType.(*ReferenceType)
return refType, isReference
Expand Down
2 changes: 1 addition & 1 deletion runtime/sema/check_variable_declaration.go
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ func (checker *Checker) recordReference(targetVariable *Variable, expr ast.Expre
return
}

if _, isReference := GetReferenceType(targetVariable.Type); !isReference {
if _, isReference := MaybeReferenceType(targetVariable.Type); !isReference {
return
}

Expand Down
39 changes: 33 additions & 6 deletions runtime/tests/interpreter/for_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -468,7 +468,7 @@ func TestInterpretEphemeralReferencesInForLoop(t *testing.T) {
require.NoError(t, err)
})

t.Run("Moved resource element", func(t *testing.T) {
t.Run("Mutating reference to resource array", func(t *testing.T) {
t.Parallel()

inter := parseCheckAndInterpret(t, `
Expand All @@ -482,6 +482,7 @@ func TestInterpretEphemeralReferencesInForLoop(t *testing.T) {
for element in arrayRef {
// Move the actual element
// This mutation should fail.
let oldElement <- arrayRef.remove(at: 0)
// Use the element reference
Expand All @@ -495,7 +496,33 @@ func TestInterpretEphemeralReferencesInForLoop(t *testing.T) {
`)

_, err := inter.Invoke("main")
require.ErrorAs(t, err, &interpreter.InvalidatedResourceReferenceError{})
require.ErrorAs(t, err, &interpreter.ContainerMutatedDuringIterationError{})
})

t.Run("Mutating reference to struct array", func(t *testing.T) {
t.Parallel()

inter := parseCheckAndInterpret(t, `
struct Foo{
fun sayHello() {}
}
fun main() {
let array = [Foo()]
let arrayRef = &array as auth(Mutate) &[Foo]
for element in arrayRef {
// Move the actual element
let oldElement = arrayRef.remove(at: 0)
// Use the element reference
element.sayHello()
}
}
`)

_, err := inter.Invoke("main")
require.NoError(t, err)
})
}

Expand All @@ -510,7 +537,7 @@ func TestInterpretStorageReferencesInForLoop(t *testing.T) {

inter, _ := testAccount(t, address, true, nil, `
fun test() {
var let = ["Hello", "World", "Foo", "Bar"]
let array = ["Hello", "World", "Foo", "Bar"]
account.storage.save(array, to: /storage/array)
let arrayRef = account.storage.borrow<&[String]>(from: /storage/array)!
Expand All @@ -533,7 +560,7 @@ func TestInterpretStorageReferencesInForLoop(t *testing.T) {
struct Foo{}
fun test() {
var let = [Foo(), Foo()]
let array = [Foo(), Foo()]
account.storage.save(array, to: /storage/array)
let arrayRef = account.storage.borrow<&[Foo]>(from: /storage/array)!
Expand All @@ -556,7 +583,7 @@ func TestInterpretStorageReferencesInForLoop(t *testing.T) {
resource Foo{}
fun test() {
var let <- [ <- create Foo(), <- create Foo()]
let array <- [ <- create Foo(), <- create Foo()]
account.storage.save(<- array, to: /storage/array)
let arrayRef = account.storage.borrow<&[Foo]>(from: /storage/array)!
Expand All @@ -579,7 +606,7 @@ func TestInterpretStorageReferencesInForLoop(t *testing.T) {
resource Foo{}
fun test() {
var let <- [ <- create Foo(), <- create Foo()]
let array <- [ <- create Foo(), <- create Foo()]
account.storage.save(<- array, to: /storage/array)
let arrayRef = account.storage.borrow<&[Foo]>(from: /storage/array)!
Expand Down

0 comments on commit 3ec9749

Please sign in to comment.