From d3de29174cec542bb548ac8a239eb53b55cbfce7 Mon Sep 17 00:00:00 2001 From: Faye Amacker <33205765+fxamacker@users.noreply.github.com> Date: Tue, 18 Jul 2023 14:54:24 -0500 Subject: [PATCH 1/4] Add ReadOnly iterators and refactor other iterators This change: - Adds ReadOnly iterators that match current iterator API (except for the "ReadOnly" suffix added to some function names). - Refactors API of non-Readonly iterators because register inlining will require more parameters for MapIterator. For ReadOnly iterators, the caller is responsible for preventing changes to child containers during iteration because mutations of child containers are not guaranteed to persist. For non-ReadOnly iterators, two additional parameters are needed to update child container in parent map when child container is modified. --- array.go | 87 +++++++++++++++++++++++++++++---------------- array_bench_test.go | 4 +-- array_test.go | 48 ++++++++++++------------- cmd/stress/utils.go | 12 +++---- map.go | 75 +++++++++++++++++++++++--------------- map_test.go | 38 ++++++++++---------- utils_test.go | 4 +-- 7 files changed, 156 insertions(+), 112 deletions(-) diff --git a/array.go b/array.go index 40bb8686..b63b3e6e 100644 --- a/array.go +++ b/array.go @@ -3117,6 +3117,7 @@ type ArrayIterator struct { dataSlab *ArrayDataSlab index int remainingCount int + readOnly bool } func (i *ArrayIterator) Next() (Value, error) { @@ -3179,6 +3180,19 @@ func (a *Array) Iterator() (*ArrayIterator, error) { }, nil } +// ReadOnlyIterator returns readonly iterator for array elements. +// If elements of child containers are mutated, those changes +// are not guaranteed to persist. +func (a *Array) ReadOnlyIterator() (*ArrayIterator, error) { + iterator, err := a.Iterator() + if err != nil { + // Don't need to wrap error as external error because err is already categorized by Iterator(). + return nil, err + } + iterator.readOnly = true + return iterator, nil +} + func (a *Array) RangeIterator(startIndex uint64, endIndex uint64) (*ArrayIterator, error) { count := a.Count() @@ -3229,16 +3243,18 @@ func (a *Array) RangeIterator(startIndex uint64, endIndex uint64) (*ArrayIterato }, nil } -type ArrayIterationFunc func(element Value) (resume bool, err error) - -func (a *Array) Iterate(fn ArrayIterationFunc) error { - - iterator, err := a.Iterator() +func (a *Array) ReadOnlyRangeIterator(startIndex uint64, endIndex uint64) (*ArrayIterator, error) { + iterator, err := a.RangeIterator(startIndex, endIndex) if err != nil { - // Don't need to wrap error as external error because err is already categorized by Array.Iterator(). - return err + return nil, err } + iterator.readOnly = true + return iterator, nil +} + +type ArrayIterationFunc func(element Value) (resume bool, err error) +func iterate(iterator *ArrayIterator, fn ArrayIterationFunc) error { for { value, err := iterator.Next() if err != nil { @@ -3259,33 +3275,42 @@ func (a *Array) Iterate(fn ArrayIterationFunc) error { } } -func (a *Array) IterateRange(startIndex uint64, endIndex uint64, fn ArrayIterationFunc) error { +func (a *Array) Iterate(fn ArrayIterationFunc) error { + iterator, err := a.Iterator() + if err != nil { + // Don't need to wrap error as external error because err is already categorized by Array.Iterator(). + return err + } + return iterate(iterator, fn) +} + +func (a *Array) IterateReadOnly(fn ArrayIterationFunc) error { + iterator, err := a.ReadOnlyIterator() + if err != nil { + // Don't need to wrap error as external error because err is already categorized by Array.ReadOnlyIterator(). + return err + } + return iterate(iterator, fn) +} +func (a *Array) IterateRange(startIndex uint64, endIndex uint64, fn ArrayIterationFunc) error { iterator, err := a.RangeIterator(startIndex, endIndex) if err != nil { // Don't need to wrap error as external error because err is already categorized by Array.RangeIterator(). return err } + return iterate(iterator, fn) +} - for { - value, err := iterator.Next() - if err != nil { - // Don't need to wrap error as external error because err is already categorized by ArrayIterator.Next(). - return err - } - if value == nil { - return nil - } - resume, err := fn(value) - if err != nil { - // Wrap err as external error (if needed) because err is returned by ArrayIterationFunc callback. - return wrapErrorAsExternalErrorIfNeeded(err) - } - if !resume { - return nil - } +func (a *Array) IterateReadOnlyRange(startIndex uint64, endIndex uint64, fn ArrayIterationFunc) error { + iterator, err := a.ReadOnlyRangeIterator(startIndex, endIndex) + if err != nil { + // Don't need to wrap error as external error because err is already categorized by Array.ReadOnlyRangeIterator(). + return err } + return iterate(iterator, fn) } + func (a *Array) Count() uint64 { return uint64(a.root.Header().count) } @@ -3309,7 +3334,7 @@ func (a *Array) Type() TypeInfo { } func (a *Array) String() string { - iterator, err := a.Iterator() + iterator, err := a.ReadOnlyIterator() if err != nil { return err.Error() } @@ -3793,8 +3818,8 @@ func (i *ArrayLoadedValueIterator) Next() (Value, error) { return nil, nil } -// LoadedValueIterator returns iterator to iterate loaded array elements. -func (a *Array) LoadedValueIterator() (*ArrayLoadedValueIterator, error) { +// ReadOnlyLoadedValueIterator returns iterator to iterate loaded array elements. +func (a *Array) ReadOnlyLoadedValueIterator() (*ArrayLoadedValueIterator, error) { switch slab := a.root.(type) { case *ArrayDataSlab: @@ -3832,9 +3857,9 @@ func (a *Array) LoadedValueIterator() (*ArrayLoadedValueIterator, error) { } } -// IterateLoadedValues iterates loaded array values. -func (a *Array) IterateLoadedValues(fn ArrayIterationFunc) error { - iterator, err := a.LoadedValueIterator() +// IterateReadOnlyLoadedValues iterates loaded array values. +func (a *Array) IterateReadOnlyLoadedValues(fn ArrayIterationFunc) error { + iterator, err := a.ReadOnlyLoadedValueIterator() if err != nil { // Don't need to wrap error as external error because err is already categorized by Array.LoadedValueIterator(). return err diff --git a/array_bench_test.go b/array_bench_test.go index 572abff8..b8c06cd0 100644 --- a/array_bench_test.go +++ b/array_bench_test.go @@ -355,7 +355,7 @@ func benchmarkNewArrayFromAppend(b *testing.B, initialArraySize int) { for i := 0; i < b.N; i++ { copied, _ := NewArray(storage, array.Address(), array.Type()) - _ = array.Iterate(func(value Value) (bool, error) { + _ = array.IterateReadOnly(func(value Value) (bool, error) { _ = copied.Append(value) return true, nil }) @@ -379,7 +379,7 @@ func benchmarkNewArrayFromBatchData(b *testing.B, initialArraySize int) { b.StartTimer() for i := 0; i < b.N; i++ { - iter, err := array.Iterator() + iter, err := array.ReadOnlyIterator() require.NoError(b, err) copied, _ := NewArrayFromBatchData(storage, array.Address(), array.Type(), func() (Value, error) { diff --git a/array_test.go b/array_test.go index 4bd1b6fa..599f2365 100644 --- a/array_test.go +++ b/array_test.go @@ -100,7 +100,7 @@ func _verifyArray( // Verify array elements by iterator i := 0 - err = array.Iterate(func(v Value) (bool, error) { + err = array.IterateReadOnly(func(v Value) (bool, error) { valueEqual(t, expectedValues[i], v) i++ return true, nil @@ -679,7 +679,7 @@ func TestArrayIterate(t *testing.T) { require.NoError(t, err) i := uint64(0) - err = array.Iterate(func(v Value) (bool, error) { + err = array.IterateReadOnly(func(v Value) (bool, error) { i++ return true, nil }) @@ -706,7 +706,7 @@ func TestArrayIterate(t *testing.T) { } i := uint64(0) - err = array.Iterate(func(v Value) (bool, error) { + err = array.IterateReadOnly(func(v Value) (bool, error) { require.Equal(t, Uint64Value(i), v) i++ return true, nil @@ -743,7 +743,7 @@ func TestArrayIterate(t *testing.T) { } i := uint64(0) - err = array.Iterate(func(v Value) (bool, error) { + err = array.IterateReadOnly(func(v Value) (bool, error) { require.Equal(t, Uint64Value(i), v) i++ return true, nil @@ -776,7 +776,7 @@ func TestArrayIterate(t *testing.T) { } i := uint64(0) - err = array.Iterate(func(v Value) (bool, error) { + err = array.IterateReadOnly(func(v Value) (bool, error) { require.Equal(t, Uint64Value(i), v) i++ return true, nil @@ -812,7 +812,7 @@ func TestArrayIterate(t *testing.T) { i := uint64(0) j := uint64(1) - err = array.Iterate(func(v Value) (bool, error) { + err = array.IterateReadOnly(func(v Value) (bool, error) { require.Equal(t, Uint64Value(j), v) i++ j += 2 @@ -838,7 +838,7 @@ func TestArrayIterate(t *testing.T) { } i := 0 - err = array.Iterate(func(_ Value) (bool, error) { + err = array.IterateReadOnly(func(_ Value) (bool, error) { if i == count/2 { return false, nil } @@ -867,7 +867,7 @@ func TestArrayIterate(t *testing.T) { testErr := errors.New("test") i := 0 - err = array.Iterate(func(_ Value) (bool, error) { + err = array.IterateReadOnly(func(_ Value) (bool, error) { if i == count/2 { return false, testErr } @@ -893,7 +893,7 @@ func testArrayIterateRange(t *testing.T, array *Array, values []Value) { count := array.Count() // If startIndex > count, IterateRange returns SliceOutOfBoundsError - err = array.IterateRange(count+1, count+1, func(v Value) (bool, error) { + err = array.IterateReadOnlyRange(count+1, count+1, func(v Value) (bool, error) { i++ return true, nil }) @@ -906,7 +906,7 @@ func testArrayIterateRange(t *testing.T, array *Array, values []Value) { require.Equal(t, uint64(0), i) // If endIndex > count, IterateRange returns SliceOutOfBoundsError - err = array.IterateRange(0, count+1, func(v Value) (bool, error) { + err = array.IterateReadOnlyRange(0, count+1, func(v Value) (bool, error) { i++ return true, nil }) @@ -918,7 +918,7 @@ func testArrayIterateRange(t *testing.T, array *Array, values []Value) { // If startIndex > endIndex, IterateRange returns InvalidSliceIndexError if count > 0 { - err = array.IterateRange(1, 0, func(v Value) (bool, error) { + err = array.IterateReadOnlyRange(1, 0, func(v Value) (bool, error) { i++ return true, nil }) @@ -933,7 +933,7 @@ func testArrayIterateRange(t *testing.T, array *Array, values []Value) { for startIndex := uint64(0); startIndex <= count; startIndex++ { for endIndex := startIndex; endIndex <= count; endIndex++ { i = uint64(0) - err = array.IterateRange(startIndex, endIndex, func(v Value) (bool, error) { + err = array.IterateReadOnlyRange(startIndex, endIndex, func(v Value) (bool, error) { valueEqual(t, v, values[int(startIndex+i)]) i++ return true, nil @@ -1015,7 +1015,7 @@ func TestArrayIterateRange(t *testing.T) { startIndex := uint64(1) endIndex := uint64(5) count := endIndex - startIndex - err = array.IterateRange(startIndex, endIndex, func(_ Value) (bool, error) { + err = array.IterateReadOnlyRange(startIndex, endIndex, func(_ Value) (bool, error) { if i == count/2 { return false, nil } @@ -1044,7 +1044,7 @@ func TestArrayIterateRange(t *testing.T) { startIndex := uint64(1) endIndex := uint64(5) count := endIndex - startIndex - err = array.IterateRange(startIndex, endIndex, func(_ Value) (bool, error) { + err = array.IterateReadOnlyRange(startIndex, endIndex, func(_ Value) (bool, error) { if i == count/2 { return false, testErr } @@ -3059,7 +3059,7 @@ func TestEmptyArray(t *testing.T) { t.Run("iterate", func(t *testing.T) { i := uint64(0) - err := array.Iterate(func(v Value) (bool, error) { + err := array.IterateReadOnly(func(v Value) (bool, error) { i++ return true, nil }) @@ -3301,7 +3301,7 @@ func TestArrayFromBatchData(t *testing.T) { require.NoError(t, err) require.Equal(t, uint64(0), array.Count()) - iter, err := array.Iterator() + iter, err := array.ReadOnlyIterator() require.NoError(t, err) // Create a new array with new storage, new address, and original array's elements. @@ -3341,7 +3341,7 @@ func TestArrayFromBatchData(t *testing.T) { require.Equal(t, uint64(arraySize), array.Count()) - iter, err := array.Iterator() + iter, err := array.ReadOnlyIterator() require.NoError(t, err) // Create a new array with new storage, new address, and original array's elements. @@ -3385,7 +3385,7 @@ func TestArrayFromBatchData(t *testing.T) { require.Equal(t, uint64(arraySize), array.Count()) - iter, err := array.Iterator() + iter, err := array.ReadOnlyIterator() require.NoError(t, err) address := Address{2, 3, 4, 5, 6, 7, 8, 9} @@ -3435,7 +3435,7 @@ func TestArrayFromBatchData(t *testing.T) { require.Equal(t, uint64(36), array.Count()) - iter, err := array.Iterator() + iter, err := array.ReadOnlyIterator() require.NoError(t, err) storage := newTestPersistentStorage(t) @@ -3485,7 +3485,7 @@ func TestArrayFromBatchData(t *testing.T) { require.Equal(t, uint64(36), array.Count()) - iter, err := array.Iterator() + iter, err := array.ReadOnlyIterator() require.NoError(t, err) storage := newTestPersistentStorage(t) @@ -3531,7 +3531,7 @@ func TestArrayFromBatchData(t *testing.T) { require.Equal(t, uint64(arraySize), array.Count()) - iter, err := array.Iterator() + iter, err := array.ReadOnlyIterator() require.NoError(t, err) storage := newTestPersistentStorage(t) @@ -3586,7 +3586,7 @@ func TestArrayFromBatchData(t *testing.T) { err = array.Append(v) require.NoError(t, err) - iter, err := array.Iterator() + iter, err := array.ReadOnlyIterator() require.NoError(t, err) storage := newTestPersistentStorage(t) @@ -3942,7 +3942,7 @@ func TestArrayLoadedValueIterator(t *testing.T) { verifyArrayLoadedElements(t, array, values) i := 0 - err := array.IterateLoadedValues(func(v Value) (bool, error) { + err := array.IterateReadOnlyLoadedValues(func(v Value) (bool, error) { // At this point, iterator returned first element (v). // Remove all other nested composite elements (except first element) from storage. @@ -4627,7 +4627,7 @@ func createArrayWithSimpleAndChildArrayValues( func verifyArrayLoadedElements(t *testing.T, array *Array, expectedValues []Value) { i := 0 - err := array.IterateLoadedValues(func(v Value) (bool, error) { + err := array.IterateReadOnlyLoadedValues(func(v Value) (bool, error) { require.True(t, i < len(expectedValues)) valueEqual(t, expectedValues[i], v) i++ diff --git a/cmd/stress/utils.go b/cmd/stress/utils.go index c75296fe..96f72584 100644 --- a/cmd/stress/utils.go +++ b/cmd/stress/utils.go @@ -132,7 +132,7 @@ func copyValue(storage *atree.PersistentSlabStorage, address atree.Address, valu } func copyArray(storage *atree.PersistentSlabStorage, address atree.Address, array *atree.Array) (*atree.Array, error) { - iterator, err := array.Iterator() + iterator, err := array.ReadOnlyIterator() if err != nil { return nil, err } @@ -149,7 +149,7 @@ func copyArray(storage *atree.PersistentSlabStorage, address atree.Address, arra } func copyMap(storage *atree.PersistentSlabStorage, address atree.Address, m *atree.OrderedMap) (*atree.OrderedMap, error) { - iterator, err := m.Iterator() + iterator, err := m.ReadOnlyIterator() if err != nil { return nil, err } @@ -260,12 +260,12 @@ func arrayEqual(a atree.Value, b atree.Value) error { return fmt.Errorf("array %s count %d != array %s count %d", array1, array1.Count(), array2, array2.Count()) } - iterator1, err := array1.Iterator() + iterator1, err := array1.ReadOnlyIterator() if err != nil { return fmt.Errorf("failed to get array1 iterator: %w", err) } - iterator2, err := array2.Iterator() + iterator2, err := array2.ReadOnlyIterator() if err != nil { return fmt.Errorf("failed to get array2 iterator: %w", err) } @@ -309,12 +309,12 @@ func mapEqual(a atree.Value, b atree.Value) error { return fmt.Errorf("map %s count %d != map %s count %d", m1, m1.Count(), m2, m2.Count()) } - iterator1, err := m1.Iterator() + iterator1, err := m1.ReadOnlyIterator() if err != nil { return fmt.Errorf("failed to get m1 iterator: %w", err) } - iterator2, err := m2.Iterator() + iterator2, err := m2.ReadOnlyIterator() if err != nil { return fmt.Errorf("failed to get m2 iterator: %w", err) } diff --git a/map.go b/map.go index 9daa8e93..b47ce9bf 100644 --- a/map.go +++ b/map.go @@ -5076,7 +5076,7 @@ func (m *OrderedMap) Type() TypeInfo { } func (m *OrderedMap) String() string { - iterator, err := m.Iterator() + iterator, err := m.ReadOnlyIterator() if err != nil { return err.Error() } @@ -5135,19 +5135,19 @@ func (m *MapExtraData) decrementCount() { m.Count-- } -type MapElementIterator struct { +type mapElementIterator struct { storage SlabStorage elements elements index int - nestedIterator *MapElementIterator + nestedIterator *mapElementIterator } -func (i *MapElementIterator) Next() (key MapKey, value MapValue, err error) { +func (i *mapElementIterator) next() (key MapKey, value MapValue, err error) { if i.nestedIterator != nil { - key, value, err = i.nestedIterator.Next() + key, value, err = i.nestedIterator.next() if err != nil { - // Don't need to wrap error as external error because err is already categorized by MapElementIterator.Next(). + // Don't need to wrap error as external error because err is already categorized by mapElementIterator.next(). return nil, nil, err } if key != nil { @@ -5178,14 +5178,14 @@ func (i *MapElementIterator) Next() (key MapKey, value MapValue, err error) { return nil, nil, err } - i.nestedIterator = &MapElementIterator{ + i.nestedIterator = &mapElementIterator{ storage: i.storage, elements: elems, } i.index++ // Don't need to wrap error as external error because err is already categorized by MapElementIterator.Next(). - return i.nestedIterator.Next() + return i.nestedIterator.next() default: return nil, nil, NewSlabDataError(fmt.Errorf("unexpected element type %T during map iteration", e)) @@ -5197,8 +5197,10 @@ type MapElementIterationFunc func(Value) (resume bool, err error) type MapIterator struct { storage SlabStorage + comparator ValueComparator // TODO: use comparator and hip to update child element in parent map in register inlining. + hip HashInputProvider id SlabID - elemIterator *MapElementIterator + elemIterator *mapElementIterator } func (i *MapIterator) Next() (key Value, value Value, err error) { @@ -5215,7 +5217,7 @@ func (i *MapIterator) Next() (key Value, value Value, err error) { } var ks, vs Storable - ks, vs, err = i.elemIterator.Next() + ks, vs, err = i.elemIterator.next() if err != nil { // Don't need to wrap error as external error because err is already categorized by MapElementIterator.Next(). return nil, nil, err @@ -5256,7 +5258,7 @@ func (i *MapIterator) NextKey() (key Value, err error) { } var ks Storable - ks, _, err = i.elemIterator.Next() + ks, _, err = i.elemIterator.next() if err != nil { // Don't need to wrap error as external error because err is already categorized by MapElementIterator.Next(). return nil, err @@ -5291,7 +5293,7 @@ func (i *MapIterator) NextValue() (value Value, err error) { } var vs Storable - _, vs, err = i.elemIterator.Next() + _, vs, err = i.elemIterator.next() if err != nil { // Don't need to wrap error as external error because err is already categorized by MapElementIterator.Next(). return nil, err @@ -5329,7 +5331,7 @@ func (i *MapIterator) advance() error { i.id = dataSlab.next - i.elemIterator = &MapElementIterator{ + i.elemIterator = &mapElementIterator{ storage: i.storage, elements: dataSlab.elements, } @@ -5337,7 +5339,7 @@ func (i *MapIterator) advance() error { return nil } -func (m *OrderedMap) Iterator() (*MapIterator, error) { +func (m *OrderedMap) Iterator(comparator ValueComparator, hip HashInputProvider) (*MapIterator, error) { slab, err := firstMapDataSlab(m.Storage, m.root) if err != nil { // Don't need to wrap error as external error because err is already categorized by firstMapDataSlab(). @@ -5347,18 +5349,27 @@ func (m *OrderedMap) Iterator() (*MapIterator, error) { dataSlab := slab.(*MapDataSlab) return &MapIterator{ - storage: m.Storage, - id: dataSlab.next, - elemIterator: &MapElementIterator{ + storage: m.Storage, + comparator: comparator, + hip: hip, + id: dataSlab.next, + elemIterator: &mapElementIterator{ storage: m.Storage, elements: dataSlab.elements, }, }, nil } -func (m *OrderedMap) Iterate(fn MapEntryIterationFunc) error { +// ReadOnlyIterator returns readonly iterator for map elements. +// If elements of child containers are mutated, those changes +// are not guaranteed to persist. +func (m *OrderedMap) ReadOnlyIterator() (*MapIterator, error) { + return m.Iterator(nil, nil) +} + +func (m *OrderedMap) Iterate(comparator ValueComparator, hip HashInputProvider, fn MapEntryIterationFunc) error { - iterator, err := m.Iterator() + iterator, err := m.Iterator(comparator, hip) if err != nil { // Don't need to wrap error as external error because err is already categorized by OrderedMap.Iterator(). return err @@ -5385,9 +5396,13 @@ func (m *OrderedMap) Iterate(fn MapEntryIterationFunc) error { } } -func (m *OrderedMap) IterateKeys(fn MapElementIterationFunc) error { +func (m *OrderedMap) IterateReadOnly(fn MapEntryIterationFunc) error { + return m.Iterate(nil, nil, fn) +} + +func (m *OrderedMap) IterateReadOnlyKeys(fn MapElementIterationFunc) error { - iterator, err := m.Iterator() + iterator, err := m.ReadOnlyIterator() if err != nil { // Don't need to wrap error as external error because err is already categorized by OrderedMap.Iterator(). return err @@ -5414,9 +5429,9 @@ func (m *OrderedMap) IterateKeys(fn MapElementIterationFunc) error { } } -func (m *OrderedMap) IterateValues(fn MapElementIterationFunc) error { +func (m *OrderedMap) IterateValues(comparator ValueComparator, hip HashInputProvider, fn MapElementIterationFunc) error { - iterator, err := m.Iterator() + iterator, err := m.Iterator(comparator, hip) if err != nil { // Don't need to wrap error as external error because err is already categorized by OrderedMap.Iterator(). return err @@ -5443,6 +5458,10 @@ func (m *OrderedMap) IterateValues(fn MapElementIterationFunc) error { } } +func (m *OrderedMap) IterateReadOnlyValues(fn MapElementIterationFunc) error { + return m.IterateValues(nil, nil, fn) +} + type MapPopIterationFunc func(Storable, Storable) // PopIterate iterates and removes elements backward. @@ -6040,8 +6059,8 @@ func (i *MapLoadedValueIterator) Next() (Value, Value, error) { return nil, nil, nil } -// LoadedValueIterator returns iterator to iterate loaded map elements. -func (m *OrderedMap) LoadedValueIterator() (*MapLoadedValueIterator, error) { +// ReadOnlyLoadedValueIterator returns iterator to iterate loaded map elements. +func (m *OrderedMap) ReadOnlyLoadedValueIterator() (*MapLoadedValueIterator, error) { switch slab := m.root.(type) { case *MapDataSlab: @@ -6079,9 +6098,9 @@ func (m *OrderedMap) LoadedValueIterator() (*MapLoadedValueIterator, error) { } } -// IterateLoadedValues iterates loaded map values. -func (m *OrderedMap) IterateLoadedValues(fn MapEntryIterationFunc) error { - iterator, err := m.LoadedValueIterator() +// IterateReadOnlyLoadedValues iterates loaded map values. +func (m *OrderedMap) IterateReadOnlyLoadedValues(fn MapEntryIterationFunc) error { + iterator, err := m.ReadOnlyLoadedValueIterator() if err != nil { // Don't need to wrap error as external error because err is already categorized by OrderedMap.LoadedValueIterator(). return err diff --git a/map_test.go b/map_test.go index 2b0c5ec7..f05110ea 100644 --- a/map_test.go +++ b/map_test.go @@ -167,7 +167,7 @@ func _verifyMap( require.Equal(t, len(expectedKeyValues), len(sortedKeys)) i := 0 - err = m.Iterate(func(k, v Value) (bool, error) { + err = m.IterateReadOnly(func(k, v Value) (bool, error) { expectedKey := sortedKeys[i] expectedValue := expectedKeyValues[expectedKey] @@ -1123,7 +1123,7 @@ func TestMapIterate(t *testing.T) { // Iterate key value pairs i = uint64(0) - err = m.Iterate(func(k Value, v Value) (resume bool, err error) { + err = m.IterateReadOnly(func(k Value, v Value) (resume bool, err error) { valueEqual(t, sortedKeys[i], k) valueEqual(t, keyValues[k], v) i++ @@ -1135,7 +1135,7 @@ func TestMapIterate(t *testing.T) { // Iterate keys i = uint64(0) - err = m.IterateKeys(func(k Value) (resume bool, err error) { + err = m.IterateReadOnlyKeys(func(k Value) (resume bool, err error) { valueEqual(t, sortedKeys[i], k) i++ return true, nil @@ -1146,7 +1146,7 @@ func TestMapIterate(t *testing.T) { // Iterate values i = uint64(0) - err = m.IterateValues(func(v Value) (resume bool, err error) { + err = m.IterateReadOnlyValues(func(v Value) (resume bool, err error) { k := sortedKeys[i] valueEqual(t, keyValues[k], v) i++ @@ -1209,7 +1209,7 @@ func TestMapIterate(t *testing.T) { // Iterate key value pairs i := uint64(0) - err = m.Iterate(func(k Value, v Value) (resume bool, err error) { + err = m.IterateReadOnly(func(k Value, v Value) (resume bool, err error) { valueEqual(t, sortedKeys[i], k) valueEqual(t, keyValues[k], v) i++ @@ -1222,7 +1222,7 @@ func TestMapIterate(t *testing.T) { // Iterate keys i = uint64(0) - err = m.IterateKeys(func(k Value) (resume bool, err error) { + err = m.IterateReadOnlyKeys(func(k Value) (resume bool, err error) { valueEqual(t, sortedKeys[i], k) i++ return true, nil @@ -1234,7 +1234,7 @@ func TestMapIterate(t *testing.T) { // Iterate values i = uint64(0) - err = m.IterateValues(func(v Value) (resume bool, err error) { + err = m.IterateReadOnlyValues(func(v Value) (resume bool, err error) { valueEqual(t, keyValues[sortedKeys[i]], v) i++ return true, nil @@ -7895,7 +7895,7 @@ func TestEmptyMap(t *testing.T) { t.Run("iterate", func(t *testing.T) { i := 0 - err := m.Iterate(func(k Value, v Value) (bool, error) { + err := m.IterateReadOnly(func(k Value, v Value) (bool, error) { i++ return true, nil }) @@ -7933,7 +7933,7 @@ func TestMapFromBatchData(t *testing.T) { require.NoError(t, err) require.Equal(t, uint64(0), m.Count()) - iter, err := m.Iterator() + iter, err := m.ReadOnlyIterator() require.NoError(t, err) storage := newTestPersistentStorage(t) @@ -7980,7 +7980,7 @@ func TestMapFromBatchData(t *testing.T) { require.Equal(t, uint64(mapSize), m.Count()) - iter, err := m.Iterator() + iter, err := m.ReadOnlyIterator() require.NoError(t, err) var sortedKeys []Value @@ -8042,7 +8042,7 @@ func TestMapFromBatchData(t *testing.T) { require.Equal(t, uint64(mapSize), m.Count()) - iter, err := m.Iterator() + iter, err := m.ReadOnlyIterator() require.NoError(t, err) var sortedKeys []Value @@ -8107,7 +8107,7 @@ func TestMapFromBatchData(t *testing.T) { require.Equal(t, uint64(mapSize+1), m.Count()) - iter, err := m.Iterator() + iter, err := m.ReadOnlyIterator() require.NoError(t, err) var sortedKeys []Value @@ -8176,7 +8176,7 @@ func TestMapFromBatchData(t *testing.T) { require.Equal(t, uint64(mapSize+1), m.Count()) require.Equal(t, typeInfo, m.Type()) - iter, err := m.Iterator() + iter, err := m.ReadOnlyIterator() require.NoError(t, err) var sortedKeys []Value @@ -8239,7 +8239,7 @@ func TestMapFromBatchData(t *testing.T) { require.Equal(t, uint64(mapSize), m.Count()) - iter, err := m.Iterator() + iter, err := m.ReadOnlyIterator() require.NoError(t, err) storage := newTestPersistentStorage(t) @@ -8320,7 +8320,7 @@ func TestMapFromBatchData(t *testing.T) { require.Equal(t, uint64(mapSize), m.Count()) - iter, err := m.Iterator() + iter, err := m.ReadOnlyIterator() require.NoError(t, err) var sortedKeys []Value @@ -8404,7 +8404,7 @@ func TestMapFromBatchData(t *testing.T) { require.NoError(t, err) require.Nil(t, storable) - iter, err := m.Iterator() + iter, err := m.ReadOnlyIterator() require.NoError(t, err) var sortedKeys []Value @@ -9637,7 +9637,7 @@ func TestMapLoadedValueIterator(t *testing.T) { verifyMapLoadedElements(t, m, values) i := 0 - err := m.IterateLoadedValues(func(k Value, v Value) (bool, error) { + err := m.IterateReadOnlyLoadedValues(func(k Value, v Value) (bool, error) { // At this point, iterator returned first element (v). // Remove all other nested composite elements (except first element) from storage. @@ -10541,7 +10541,7 @@ func createMapWithSimpleAndChildArrayValues( func verifyMapLoadedElements(t *testing.T, m *OrderedMap, expectedValues [][2]Value) { i := 0 - err := m.IterateLoadedValues(func(k Value, v Value) (bool, error) { + err := m.IterateReadOnlyLoadedValues(func(k Value, v Value) (bool, error) { require.True(t, i < len(expectedValues)) valueEqual(t, expectedValues[i][0], k) valueEqual(t, expectedValues[i][1], v) @@ -12814,7 +12814,7 @@ func getInlinedChildMapsFromParentMap(t *testing.T, address Address, parentMap * children := make(map[Value]*mapInfo) - err := parentMap.IterateKeys(func(k Value) (bool, error) { + err := parentMap.IterateReadOnlyKeys(func(k Value) (bool, error) { if k == nil { return false, nil } diff --git a/utils_test.go b/utils_test.go index 6bcd6608..4584e71c 100644 --- a/utils_test.go +++ b/utils_test.go @@ -342,7 +342,7 @@ func valueEqual(t *testing.T, expected Value, actual Value) { func arrayEqual(t *testing.T, expected arrayValue, actual *Array) { require.Equal(t, uint64(len(expected)), actual.Count()) - iterator, err := actual.Iterator() + iterator, err := actual.ReadOnlyIterator() require.NoError(t, err) i := 0 @@ -363,7 +363,7 @@ func arrayEqual(t *testing.T, expected arrayValue, actual *Array) { func mapEqual(t *testing.T, expected mapValue, actual *OrderedMap) { require.Equal(t, uint64(len(expected)), actual.Count()) - iterator, err := actual.Iterator() + iterator, err := actual.ReadOnlyIterator() require.NoError(t, err) i := 0 From e88a73e60f126ba002263f2f4112a3bc968c7edb Mon Sep 17 00:00:00 2001 From: Faye Amacker <33205765+fxamacker@users.noreply.github.com> Date: Fri, 29 Sep 2023 08:13:05 -0500 Subject: [PATCH 2/4] Support value mutation from non-readonly iterators --- array.go | 55 ++++++++------- array_test.go | 125 ++++++++++++++++++++++++++++++++++ map.go | 99 ++++++++++++++++++--------- map_test.go | 183 ++++++++++++++++++++++++++++++++++++++++++++++---- 4 files changed, 396 insertions(+), 66 deletions(-) diff --git a/array.go b/array.go index b63b3e6e..fde54568 100644 --- a/array.go +++ b/array.go @@ -3112,12 +3112,13 @@ func (a *Array) Storable(_ SlabStorage, _ Address, maxInlineSize uint64) (Storab var emptyArrayIterator = &ArrayIterator{} type ArrayIterator struct { - storage SlabStorage - id SlabID - dataSlab *ArrayDataSlab - index int - remainingCount int - readOnly bool + array *Array + id SlabID + dataSlab *ArrayDataSlab + indexInArray int + indexInDataSlab int + remainingCount int + readOnly bool } func (i *ArrayIterator) Next() (Value, error) { @@ -3130,7 +3131,7 @@ func (i *ArrayIterator) Next() (Value, error) { return nil, nil } - slab, found, err := i.storage.Retrieve(i.id) + slab, found, err := i.array.Storage.Retrieve(i.id) if err != nil { // Wrap err as external error (if needed) because err is returned by SlabStorage interface. return nil, wrapErrorfAsExternalErrorIfNeeded(err, fmt.Sprintf("failed to retrieve slab %s", i.id)) @@ -3140,22 +3141,29 @@ func (i *ArrayIterator) Next() (Value, error) { } i.dataSlab = slab.(*ArrayDataSlab) - i.index = 0 + i.indexInDataSlab = 0 } var element Value var err error - if i.index < len(i.dataSlab.elements) { - element, err = i.dataSlab.elements[i.index].StoredValue(i.storage) + if i.indexInDataSlab < len(i.dataSlab.elements) { + element, err = i.dataSlab.elements[i.indexInDataSlab].StoredValue(i.array.Storage) if err != nil { // Wrap err as external error (if needed) because err is returned by Storable interface. return nil, wrapErrorfAsExternalErrorIfNeeded(err, "failed to get storable's stored value") } - i.index++ + if !i.readOnly { + // Set up notification callback in child value so + // when child value is modified parent a is notified. + i.array.setCallbackWithChild(uint64(i.indexInArray), element, maxInlineArrayElementSize) + } + + i.indexInDataSlab++ + i.indexInArray++ } - if i.index >= len(i.dataSlab.elements) { + if i.indexInDataSlab >= len(i.dataSlab.elements) { i.id = i.dataSlab.next i.dataSlab = nil } @@ -3173,7 +3181,7 @@ func (a *Array) Iterator() (*ArrayIterator, error) { } return &ArrayIterator{ - storage: a.Storage, + array: a, id: slab.SlabID(), dataSlab: slab, remainingCount: int(a.Count()), @@ -3235,11 +3243,12 @@ func (a *Array) RangeIterator(startIndex uint64, endIndex uint64) (*ArrayIterato } return &ArrayIterator{ - storage: a.Storage, - id: dataSlab.SlabID(), - dataSlab: dataSlab, - index: int(index), - remainingCount: int(numberOfElements), + array: a, + id: dataSlab.SlabID(), + dataSlab: dataSlab, + indexInArray: int(startIndex), + indexInDataSlab: int(index), + remainingCount: int(numberOfElements), }, nil } @@ -3254,7 +3263,7 @@ func (a *Array) ReadOnlyRangeIterator(startIndex uint64, endIndex uint64) (*Arra type ArrayIterationFunc func(element Value) (resume bool, err error) -func iterate(iterator *ArrayIterator, fn ArrayIterationFunc) error { +func iterateArray(iterator *ArrayIterator, fn ArrayIterationFunc) error { for { value, err := iterator.Next() if err != nil { @@ -3281,7 +3290,7 @@ func (a *Array) Iterate(fn ArrayIterationFunc) error { // Don't need to wrap error as external error because err is already categorized by Array.Iterator(). return err } - return iterate(iterator, fn) + return iterateArray(iterator, fn) } func (a *Array) IterateReadOnly(fn ArrayIterationFunc) error { @@ -3290,7 +3299,7 @@ func (a *Array) IterateReadOnly(fn ArrayIterationFunc) error { // Don't need to wrap error as external error because err is already categorized by Array.ReadOnlyIterator(). return err } - return iterate(iterator, fn) + return iterateArray(iterator, fn) } func (a *Array) IterateRange(startIndex uint64, endIndex uint64, fn ArrayIterationFunc) error { @@ -3299,7 +3308,7 @@ func (a *Array) IterateRange(startIndex uint64, endIndex uint64, fn ArrayIterati // Don't need to wrap error as external error because err is already categorized by Array.RangeIterator(). return err } - return iterate(iterator, fn) + return iterateArray(iterator, fn) } func (a *Array) IterateReadOnlyRange(startIndex uint64, endIndex uint64, fn ArrayIterationFunc) error { @@ -3308,7 +3317,7 @@ func (a *Array) IterateReadOnlyRange(startIndex uint64, endIndex uint64, fn Arra // Don't need to wrap error as external error because err is already categorized by Array.ReadOnlyRangeIterator(). return err } - return iterate(iterator, fn) + return iterateArray(iterator, fn) } func (a *Array) Count() uint64 { diff --git a/array_test.go b/array_test.go index 599f2365..c9fc2987 100644 --- a/array_test.go +++ b/array_test.go @@ -882,6 +882,67 @@ func TestArrayIterate(t *testing.T) { require.Equal(t, count/2, i) }) + + t.Run("mutation", func(t *testing.T) { + SetThreshold(256) + defer SetThreshold(1024) + + const arraySize = 15 + + typeInfo := testTypeInfo{42} + storage := newTestPersistentStorage(t) + address := Address{1, 2, 3, 4, 5, 6, 7, 8} + + array, err := NewArray(storage, address, typeInfo) + require.NoError(t, err) + + expectedValues := make([]Value, arraySize) + for i := uint64(0); i < arraySize; i++ { + childArray, err := NewArray(storage, address, typeInfo) + require.NoError(t, err) + + v := Uint64Value(i) + err = childArray.Append(v) + require.NoError(t, err) + + err = array.Append(childArray) + require.NoError(t, err) + + expectedValues[i] = arrayValue{v} + } + require.True(t, array.root.IsData()) + + sizeBeforeMutation := array.root.Header().size + + i := 0 + newElement := Uint64Value(0) + err = array.Iterate(func(v Value) (bool, error) { + childArray, ok := v.(*Array) + require.True(t, ok) + require.Equal(t, uint64(1), childArray.Count()) + require.True(t, childArray.Inlined()) + + err := childArray.Append(newElement) + require.NoError(t, err) + + expectedChildArrayValues, ok := expectedValues[i].(arrayValue) + require.True(t, ok) + + expectedChildArrayValues = append(expectedChildArrayValues, newElement) + expectedValues[i] = expectedChildArrayValues + + i++ + + require.Equal(t, array.root.Header().size, sizeBeforeMutation+uint32(i)*newElement.ByteSize()) + + return true, nil + }) + require.NoError(t, err) + require.Equal(t, arraySize, i) + require.True(t, array.root.IsData()) + + verifyArray(t, storage, typeInfo, address, array, expectedValues, false) + }) } func testArrayIterateRange(t *testing.T, array *Array, values []Value) { @@ -1058,6 +1119,70 @@ func TestArrayIterateRange(t *testing.T) { require.Equal(t, testErr, externalError.Unwrap()) require.Equal(t, count/2, i) }) + + t.Run("mutation", func(t *testing.T) { + SetThreshold(256) + defer SetThreshold(1024) + + const arraySize = 15 + + typeInfo := testTypeInfo{42} + storage := newTestPersistentStorage(t) + address := Address{1, 2, 3, 4, 5, 6, 7, 8} + + array, err := NewArray(storage, address, typeInfo) + require.NoError(t, err) + + expectedValues := make([]Value, arraySize) + for i := uint64(0); i < arraySize; i++ { + childArray, err := NewArray(storage, address, typeInfo) + require.NoError(t, err) + + v := Uint64Value(i) + err = childArray.Append(v) + require.NoError(t, err) + + err = array.Append(childArray) + require.NoError(t, err) + + expectedValues[i] = arrayValue{v} + } + require.True(t, array.root.IsData()) + + sizeBeforeMutation := array.root.Header().size + + i := 0 + startIndex := uint64(1) + endIndex := array.Count() - 2 + newElement := Uint64Value(0) + err = array.IterateRange(startIndex, endIndex, func(v Value) (bool, error) { + childArray, ok := v.(*Array) + require.True(t, ok) + require.Equal(t, uint64(1), childArray.Count()) + require.True(t, childArray.Inlined()) + + err := childArray.Append(newElement) + require.NoError(t, err) + + index := int(startIndex) + i + expectedChildArrayValues, ok := expectedValues[index].(arrayValue) + require.True(t, ok) + + expectedChildArrayValues = append(expectedChildArrayValues, newElement) + expectedValues[index] = expectedChildArrayValues + + i++ + + require.Equal(t, array.root.Header().size, sizeBeforeMutation+uint32(i)*newElement.ByteSize()) + + return true, nil + }) + require.NoError(t, err) + require.Equal(t, endIndex-startIndex, uint64(i)) + require.True(t, array.root.IsData()) + + verifyArray(t, storage, typeInfo, address, array, expectedValues, false) + }) } func TestArrayRootSlabID(t *testing.T) { diff --git a/map.go b/map.go index b47ce9bf..f337179a 100644 --- a/map.go +++ b/map.go @@ -5196,7 +5196,7 @@ type MapEntryIterationFunc func(Value, Value) (resume bool, err error) type MapElementIterationFunc func(Value) (resume bool, err error) type MapIterator struct { - storage SlabStorage + m *OrderedMap comparator ValueComparator // TODO: use comparator and hip to update child element in parent map in register inlining. hip HashInputProvider id SlabID @@ -5223,18 +5223,23 @@ func (i *MapIterator) Next() (key Value, value Value, err error) { return nil, nil, err } if ks != nil { - key, err = ks.StoredValue(i.storage) + key, err = ks.StoredValue(i.m.Storage) if err != nil { // Wrap err as external error (if needed) because err is returned by Storable interface. return nil, nil, wrapErrorfAsExternalErrorIfNeeded(err, "failed to get map key's stored value") } - value, err = vs.StoredValue(i.storage) + value, err = vs.StoredValue(i.m.Storage) if err != nil { // Wrap err as external error (if needed) because err is returned by Storable interface. return nil, nil, wrapErrorfAsExternalErrorIfNeeded(err, "failed to get map value's stored value") } + if i.comparator != nil && i.hip != nil { + maxInlineSize := maxInlineMapValueSize(uint64(ks.ByteSize())) + i.m.setCallbackWithChild(i.comparator, i.hip, key, value, maxInlineSize) + } + return key, value, nil } @@ -5264,7 +5269,7 @@ func (i *MapIterator) NextKey() (key Value, err error) { return nil, err } if ks != nil { - key, err = ks.StoredValue(i.storage) + key, err = ks.StoredValue(i.m.Storage) if err != nil { // Wrap err as external error (if needed) because err is returned by Storable interface. return nil, wrapErrorfAsExternalErrorIfNeeded(err, "failed to get map key's stored value") @@ -5292,19 +5297,30 @@ func (i *MapIterator) NextValue() (value Value, err error) { } } - var vs Storable - _, vs, err = i.elemIterator.next() + var ks, vs Storable + ks, vs, err = i.elemIterator.next() if err != nil { // Don't need to wrap error as external error because err is already categorized by MapElementIterator.Next(). return nil, err } if vs != nil { - value, err = vs.StoredValue(i.storage) + value, err = vs.StoredValue(i.m.Storage) if err != nil { // Wrap err as external error (if needed) because err is returned by Storable interface. return nil, wrapErrorfAsExternalErrorIfNeeded(err, "failed to get map value's stored value") } + if i.comparator != nil && i.hip != nil { + key, err := ks.StoredValue(i.m.Storage) + if err != nil { + // Wrap err as external error (if needed) because err is returned by Storable interface. + return nil, wrapErrorfAsExternalErrorIfNeeded(err, "failed to get map value's stored value") + } + + maxInlineSize := maxInlineMapValueSize(uint64(ks.ByteSize())) + i.m.setCallbackWithChild(i.comparator, i.hip, key, value, maxInlineSize) + } + return value, nil } @@ -5315,7 +5331,7 @@ func (i *MapIterator) NextValue() (value Value, err error) { } func (i *MapIterator) advance() error { - slab, found, err := i.storage.Retrieve(i.id) + slab, found, err := i.m.Storage.Retrieve(i.id) if err != nil { // Wrap err as external error (if needed) because err is returned by SlabStorage interface. return wrapErrorfAsExternalErrorIfNeeded(err, fmt.Sprintf("failed to retrieve slab %s", i.id)) @@ -5332,14 +5348,14 @@ func (i *MapIterator) advance() error { i.id = dataSlab.next i.elemIterator = &mapElementIterator{ - storage: i.storage, + storage: i.m.Storage, elements: dataSlab.elements, } return nil } -func (m *OrderedMap) Iterator(comparator ValueComparator, hip HashInputProvider) (*MapIterator, error) { +func (m *OrderedMap) iterator(comparator ValueComparator, hip HashInputProvider) (*MapIterator, error) { slab, err := firstMapDataSlab(m.Storage, m.root) if err != nil { // Don't need to wrap error as external error because err is already categorized by firstMapDataSlab(). @@ -5349,7 +5365,7 @@ func (m *OrderedMap) Iterator(comparator ValueComparator, hip HashInputProvider) dataSlab := slab.(*MapDataSlab) return &MapIterator{ - storage: m.Storage, + m: m, comparator: comparator, hip: hip, id: dataSlab.next, @@ -5360,21 +5376,22 @@ func (m *OrderedMap) Iterator(comparator ValueComparator, hip HashInputProvider) }, nil } +func (m *OrderedMap) Iterator(comparator ValueComparator, hip HashInputProvider) (*MapIterator, error) { + if comparator == nil || hip == nil { + return nil, NewUserError(fmt.Errorf("failed to create MapIterator: ValueComparator or HashInputProvider is nil")) + } + return m.iterator(comparator, hip) +} + // ReadOnlyIterator returns readonly iterator for map elements. // If elements of child containers are mutated, those changes // are not guaranteed to persist. func (m *OrderedMap) ReadOnlyIterator() (*MapIterator, error) { - return m.Iterator(nil, nil) + return m.iterator(nil, nil) } -func (m *OrderedMap) Iterate(comparator ValueComparator, hip HashInputProvider, fn MapEntryIterationFunc) error { - - iterator, err := m.Iterator(comparator, hip) - if err != nil { - // Don't need to wrap error as external error because err is already categorized by OrderedMap.Iterator(). - return err - } - +func iterateMap(iterator *MapIterator, fn MapEntryIterationFunc) error { + var err error var key, value Value for { key, value, err = iterator.Next() @@ -5396,8 +5413,22 @@ func (m *OrderedMap) Iterate(comparator ValueComparator, hip HashInputProvider, } } +func (m *OrderedMap) Iterate(comparator ValueComparator, hip HashInputProvider, fn MapEntryIterationFunc) error { + iterator, err := m.Iterator(comparator, hip) + if err != nil { + // Don't need to wrap error as external error because err is already categorized by OrderedMap.Iterator(). + return err + } + return iterateMap(iterator, fn) +} + func (m *OrderedMap) IterateReadOnly(fn MapEntryIterationFunc) error { - return m.Iterate(nil, nil, fn) + iterator, err := m.ReadOnlyIterator() + if err != nil { + // Don't need to wrap error as external error because err is already categorized by OrderedMap.ReadOnlyIterator(). + return err + } + return iterateMap(iterator, fn) } func (m *OrderedMap) IterateReadOnlyKeys(fn MapElementIterationFunc) error { @@ -5429,14 +5460,8 @@ func (m *OrderedMap) IterateReadOnlyKeys(fn MapElementIterationFunc) error { } } -func (m *OrderedMap) IterateValues(comparator ValueComparator, hip HashInputProvider, fn MapElementIterationFunc) error { - - iterator, err := m.Iterator(comparator, hip) - if err != nil { - // Don't need to wrap error as external error because err is already categorized by OrderedMap.Iterator(). - return err - } - +func iterateMapValues(iterator *MapIterator, fn MapElementIterationFunc) error { + var err error var value Value for { value, err = iterator.NextValue() @@ -5458,8 +5483,22 @@ func (m *OrderedMap) IterateValues(comparator ValueComparator, hip HashInputProv } } +func (m *OrderedMap) IterateValues(comparator ValueComparator, hip HashInputProvider, fn MapElementIterationFunc) error { + iterator, err := m.Iterator(comparator, hip) + if err != nil { + // Don't need to wrap error as external error because err is already categorized by OrderedMap.Iterator(). + return err + } + return iterateMapValues(iterator, fn) +} + func (m *OrderedMap) IterateReadOnlyValues(fn MapElementIterationFunc) error { - return m.IterateValues(nil, nil, fn) + iterator, err := m.ReadOnlyIterator() + if err != nil { + // Don't need to wrap error as external error because err is already categorized by OrderedMap.ReadOnlyIterator(). + return err + } + return iterateMapValues(iterator, fn) } type MapPopIterationFunc func(Storable, Storable) diff --git a/map_test.go b/map_test.go index f05110ea..8f7c5951 100644 --- a/map_test.go +++ b/map_test.go @@ -128,7 +128,7 @@ func verifyMap( typeInfo TypeInfo, address Address, m *OrderedMap, - keyValues map[Value]Value, + keyValues mapValue, sortedKeys []Value, hasNestedArrayMapElement bool, ) { @@ -1084,6 +1084,48 @@ func TestMapRemove(t *testing.T) { func TestMapIterate(t *testing.T) { + t.Run("empty", func(t *testing.T) { + + typeInfo := testTypeInfo{42} + address := Address{1, 2, 3, 4, 5, 6, 7, 8} + storage := newTestPersistentStorage(t) + + m, err := NewMap(storage, address, newBasicDigesterBuilder(), typeInfo) + require.NoError(t, err) + + // Iterate key value pairs + i := 0 + err = m.IterateReadOnly(func(k Value, v Value) (resume bool, err error) { + i++ + return true, nil + }) + + require.NoError(t, err) + require.Equal(t, 0, i) + + // Iterate keys + i = 0 + err = m.IterateReadOnlyKeys(func(k Value) (resume bool, err error) { + i++ + return true, nil + }) + + require.NoError(t, err) + require.Equal(t, 0, i) + + // Iterate values + i = 0 + err = m.IterateReadOnlyValues(func(v Value) (resume bool, err error) { + i++ + return true, nil + }) + + require.NoError(t, err) + require.Equal(t, 0, i) + + verifyMap(t, storage, typeInfo, address, m, mapValue{}, nil, false) + }) + t.Run("no collision", func(t *testing.T) { const ( mapSize = 2048 @@ -1200,13 +1242,9 @@ func TestMapIterate(t *testing.T) { } } - t.Log("created map of unique key value pairs") - // Sort keys by digest sort.Stable(keysByDigest{sortedKeys, digesterBuilder}) - t.Log("sorted keys by digests") - // Iterate key value pairs i := uint64(0) err = m.IterateReadOnly(func(k Value, v Value) (resume bool, err error) { @@ -1218,8 +1256,6 @@ func TestMapIterate(t *testing.T) { require.NoError(t, err) require.Equal(t, i, uint64(mapSize)) - t.Log("iterated key value pairs") - // Iterate keys i = uint64(0) err = m.IterateReadOnlyKeys(func(k Value) (resume bool, err error) { @@ -1230,8 +1266,6 @@ func TestMapIterate(t *testing.T) { require.NoError(t, err) require.Equal(t, i, uint64(mapSize)) - t.Log("iterated keys") - // Iterate values i = uint64(0) err = m.IterateReadOnlyValues(func(v Value) (resume bool, err error) { @@ -1242,10 +1276,134 @@ func TestMapIterate(t *testing.T) { require.NoError(t, err) require.Equal(t, i, uint64(mapSize)) - t.Log("iterated values") - verifyMap(t, storage, typeInfo, address, m, keyValues, sortedKeys, false) }) + + t.Run("mutation", func(t *testing.T) { + const ( + mapSize = 15 + valueStringSize = 16 + ) + + r := newRand(t) + + elementSize := digestSize + singleElementPrefixSize + Uint64Value(0).ByteSize() + NewStringValue(randStr(r, valueStringSize)).ByteSize() + + typeInfo := testTypeInfo{42} + address := Address{1, 2, 3, 4, 5, 6, 7, 8} + storage := newTestPersistentStorage(t) + digesterBuilder := newBasicDigesterBuilder() + + m, err := NewMap(storage, address, digesterBuilder, typeInfo) + require.NoError(t, err) + + keyValues := make(map[Value]Value, mapSize) + sortedKeys := make([]Value, 0, mapSize) + i := uint64(0) + for i := 0; i < mapSize; i++ { + ck := Uint64Value(0) + cv := NewStringValue(randStr(r, valueStringSize)) + + childMap, err := NewMap(storage, address, newBasicDigesterBuilder(), typeInfo) + require.NoError(t, err) + + existingStorable, err := childMap.Set(compare, hashInputProvider, ck, cv) + require.NoError(t, err) + require.Nil(t, existingStorable) + + k := Uint64Value(i) + sortedKeys = append(sortedKeys, k) + + existingStorable, err = m.Set(compare, hashInputProvider, k, childMap) + require.NoError(t, err) + require.Nil(t, existingStorable) + + require.Equal(t, uint64(1), childMap.Count()) + require.True(t, childMap.Inlined()) + + keyValues[k] = mapValue{ck: cv} + } + require.Equal(t, uint64(mapSize), m.Count()) + require.True(t, m.root.IsData()) + + verifyMap(t, storage, typeInfo, address, m, keyValues, nil, false) + + // Sort keys by digest + sort.Stable(keysByDigest{sortedKeys, digesterBuilder}) + + sizeBeforeMutation := m.root.Header().size + + // Iterate and mutate child map (inserting elements) + i = uint64(0) + err = m.Iterate(compare, hashInputProvider, func(k Value, v Value) (resume bool, err error) { + + childMap, ok := v.(*OrderedMap) + require.True(t, ok) + require.Equal(t, uint64(1), childMap.Count()) + require.True(t, childMap.Inlined()) + + newChildMapKey := Uint64Value(1) // Previous key is 0 + newChildMapValue := NewStringValue(randStr(r, valueStringSize)) + + existingStorable, err := childMap.Set(compare, hashInputProvider, newChildMapKey, newChildMapValue) + require.NoError(t, err) + require.Nil(t, existingStorable) + + expectedChildMapValues, ok := keyValues[k].(mapValue) + require.True(t, ok) + + expectedChildMapValues[newChildMapKey] = newChildMapValue + + valueEqual(t, sortedKeys[i], k) + valueEqual(t, keyValues[k], v) + i++ + + require.Equal(t, m.root.Header().size, sizeBeforeMutation+uint32(i)*elementSize) + + return true, nil + }) + + require.NoError(t, err) + require.Equal(t, uint64(mapSize), i) + + verifyMap(t, storage, typeInfo, address, m, keyValues, nil, false) + + sizeAfterInsertionMutation := m.root.Header().size + + // Iterate and mutate child map (removing elements) + i = uint64(0) + err = m.IterateValues(compare, hashInputProvider, func(v Value) (resume bool, err error) { + childMap, ok := v.(*OrderedMap) + require.True(t, ok) + require.Equal(t, uint64(2), childMap.Count()) + require.True(t, childMap.Inlined()) + + // Remove key 0 + ck := Uint64Value(0) + + existingKeyStorable, existingValueStorable, err := childMap.Remove(compare, hashInputProvider, ck) + require.NoError(t, err) + require.NotNil(t, existingKeyStorable) + require.NotNil(t, existingValueStorable) + + i++ + + require.Equal(t, m.root.Header().size, sizeAfterInsertionMutation-uint32(i)*elementSize) + return true, nil + }) + + require.NoError(t, err) + require.Equal(t, uint64(mapSize), i) + + for k := range keyValues { + expectedChildMapValues, ok := keyValues[k].(mapValue) + require.True(t, ok) + + delete(expectedChildMapValues, Uint64Value(0)) + } + + verifyMap(t, storage, typeInfo, address, m, keyValues, nil, false) + }) } func testMapDeterministicHashCollision(t *testing.T, r *rand.Rand, maxDigestLevel int) { @@ -7662,9 +7820,8 @@ func TestMapPopIterate(t *testing.T) { typeInfo := testTypeInfo{42} storage := newTestPersistentStorage(t) address := Address{1, 2, 3, 4, 5, 6, 7, 8} - digesterBuilder := newBasicDigesterBuilder() - m, err := NewMap(storage, address, digesterBuilder, typeInfo) + m, err := NewMap(storage, address, newBasicDigesterBuilder(), typeInfo) require.NoError(t, err) err = storage.Commit() From c07907da6b3b736d77d0bcd8cef920ad9e20ff08 Mon Sep 17 00:00:00 2001 From: Faye Amacker <33205765+fxamacker@users.noreply.github.com> Date: Tue, 3 Oct 2023 18:05:50 -0500 Subject: [PATCH 3/4] Add MapIterator.CanMutate() predicate function --- map.go | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/map.go b/map.go index 314aa04d..0c4e2732 100644 --- a/map.go +++ b/map.go @@ -5341,7 +5341,7 @@ type MapElementIterationFunc func(Value) (resume bool, err error) type MapIterator struct { m *OrderedMap - comparator ValueComparator // TODO: use comparator and hip to update child element in parent map in register inlining. + comparator ValueComparator hip HashInputProvider id SlabID elemIterator *mapElementIterator @@ -5379,7 +5379,7 @@ func (i *MapIterator) Next() (key Value, value Value, err error) { return nil, nil, wrapErrorfAsExternalErrorIfNeeded(err, "failed to get map value's stored value") } - if i.comparator != nil && i.hip != nil { + if i.CanMutate() { maxInlineSize := maxInlineMapValueSize(uint64(ks.ByteSize())) i.m.setCallbackWithChild(i.comparator, i.hip, key, value, maxInlineSize) } @@ -5454,7 +5454,7 @@ func (i *MapIterator) NextValue() (value Value, err error) { return nil, wrapErrorfAsExternalErrorIfNeeded(err, "failed to get map value's stored value") } - if i.comparator != nil && i.hip != nil { + if i.CanMutate() { key, err := ks.StoredValue(i.m.Storage) if err != nil { // Wrap err as external error (if needed) because err is returned by Storable interface. @@ -5520,11 +5520,19 @@ func (m *OrderedMap) iterator(comparator ValueComparator, hip HashInputProvider) }, nil } +func (i *MapIterator) CanMutate() bool { + return i.comparator != nil && i.hip != nil +} + func (m *OrderedMap) Iterator(comparator ValueComparator, hip HashInputProvider) (*MapIterator, error) { - if comparator == nil || hip == nil { + iterator, err := m.iterator(comparator, hip) + if err != nil { + return nil, err + } + if !iterator.CanMutate() { return nil, NewUserError(fmt.Errorf("failed to create MapIterator: ValueComparator or HashInputProvider is nil")) } - return m.iterator(comparator, hip) + return iterator, nil } // ReadOnlyIterator returns readonly iterator for map elements. From f6711898a763dc74205d6d541131c8eb9cba7483 Mon Sep 17 00:00:00 2001 From: Faye Amacker <33205765+fxamacker@users.noreply.github.com> Date: Tue, 3 Oct 2023 18:10:47 -0500 Subject: [PATCH 4/4] Add ArrayIterator.CanMutate() predicate function --- array.go | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/array.go b/array.go index c0e44238..52e9c22e 100644 --- a/array.go +++ b/array.go @@ -3359,6 +3359,10 @@ type ArrayIterator struct { readOnly bool } +func (i *ArrayIterator) CanMutate() bool { + return !i.readOnly +} + func (i *ArrayIterator) Next() (Value, error) { if i.remainingCount == 0 { return nil, nil @@ -3391,7 +3395,7 @@ func (i *ArrayIterator) Next() (Value, error) { return nil, wrapErrorfAsExternalErrorIfNeeded(err, "failed to get storable's stored value") } - if !i.readOnly { + if i.CanMutate() { // Set up notification callback in child value so // when child value is modified parent a is notified. i.array.setCallbackWithChild(uint64(i.indexInArray), element, maxInlineArrayElementSize)