diff --git a/internal/datastore/proxy/schemacaching/standardcache.go b/internal/datastore/proxy/schemacaching/standardcache.go index 9273cb98db..e6132aebbd 100644 --- a/internal/datastore/proxy/schemacaching/standardcache.go +++ b/internal/datastore/proxy/schemacaching/standardcache.go @@ -125,7 +125,7 @@ func listAndCache[T schemaDefinition]( continue } - remainingToLoad.Remove(name) + remainingToLoad.Delete(name) loaded := loadedRaw.(*cacheEntry) foundDefs = append(foundDefs, datastore.RevisionedDefinition[T]{ Definition: loaded.definition.(T), diff --git a/internal/datastore/proxy/schemacaching/watchingcache.go b/internal/datastore/proxy/schemacaching/watchingcache.go index dbb604f84d..89407b1b74 100644 --- a/internal/datastore/proxy/schemacaching/watchingcache.go +++ b/internal/datastore/proxy/schemacaching/watchingcache.go @@ -567,7 +567,7 @@ func (swc *schemaWatchCache[T]) readDefinitionsWithNames(ctx context.Context, na } swc.definitionsReadCachedCounter.WithLabelValues(swc.kind).Inc() - remainingNames.Remove(name) + remainingNames.Delete(name) if !found.wasNotFound { foundDefs = append(foundDefs, found.revisionedDefinition) } diff --git a/internal/dispatch/graph/check_test.go b/internal/dispatch/graph/check_test.go index a25fe7469e..98e4f9b8ca 100644 --- a/internal/dispatch/graph/check_test.go +++ b/internal/dispatch/graph/check_test.go @@ -292,7 +292,7 @@ func TestCheckMetadata(t *testing.T) { } func addFrame(trace *v1.CheckDebugTrace, foundFrames *mapz.Set[string]) { - foundFrames.Add(fmt.Sprintf("%s:%s#%s", trace.Request.ResourceRelation.Namespace, strings.Join(trace.Request.ResourceIds, ","), trace.Request.ResourceRelation.Relation)) + foundFrames.Insert(fmt.Sprintf("%s:%s#%s", trace.Request.ResourceRelation.Namespace, strings.Join(trace.Request.ResourceIds, ","), trace.Request.ResourceRelation.Relation)) for _, subTrace := range trace.SubProblems { addFrame(subTrace, foundFrames) } diff --git a/internal/dispatch/graph/lookupresources_test.go b/internal/dispatch/graph/lookupresources_test.go index e678d3c826..fbc73287b4 100644 --- a/internal/dispatch/graph/lookupresources_test.go +++ b/internal/dispatch/graph/lookupresources_test.go @@ -212,7 +212,7 @@ func TestSimpleLookupResourcesWithCursor(t *testing.T) { require.Equal(1, len(stream.Results())) - found.Add(stream.Results()[0].ResolvedResource.ResourceId) + found.Insert(stream.Results()[0].ResolvedResource.ResourceId) require.Equal(tc.expectedFirst, found.AsSlice()) cursor := stream.Results()[0].AfterResponseCursor @@ -233,7 +233,7 @@ func TestSimpleLookupResourcesWithCursor(t *testing.T) { require.NoError(err) for _, result := range stream.Results() { - found.Add(result.ResolvedResource.ResourceId) + found.Insert(result.ResolvedResource.ResourceId) } foundResults := found.AsSlice() @@ -585,7 +585,7 @@ func TestLookupResourcesOverSchemaWithCursors(t *testing.T) { } for _, result := range stream.Results() { - foundResourceIDs.Add(result.ResolvedResource.ResourceId) + foundResourceIDs.Insert(result.ResolvedResource.ResourceId) currentCursor = result.AfterResponseCursor } diff --git a/internal/dispatch/graph/reachableresources_test.go b/internal/dispatch/graph/reachableresources_test.go index 712ca66feb..e1b1af6054 100644 --- a/internal/dispatch/graph/reachableresources_test.go +++ b/internal/dispatch/graph/reachableresources_test.go @@ -827,7 +827,7 @@ func TestReachableResourcesCursors(t *testing.T) { count := 0 for _, result := range stream2.Results() { count++ - foundResources.Add(result.Resource.ResourceId) + foundResources.Insert(result.Resource.ResourceId) } require.LessOrEqual(t, count, 310) @@ -1262,7 +1262,7 @@ func TestReachableResourcesOverSchema(t *testing.T) { } for _, result := range stream.Results() { - foundResourceIDs.Add(result.Resource.ResourceId) + foundResourceIDs.Insert(result.Resource.ResourceId) currentCursor = result.AfterResponseCursor } @@ -1433,12 +1433,12 @@ func TestReachableResourcesWithCachingInParallelTest(t *testing.T) { for i := 0; i < 410; i++ { if i < 250 { - expectedResources.Add(fmt.Sprintf("res%03d", i)) + expectedResources.Insert(fmt.Sprintf("res%03d", i)) testRels = append(testRels, tuple.MustParse(fmt.Sprintf("resource:res%03d#viewer@user:tom", i))) } if i > 200 { - expectedResources.Add(fmt.Sprintf("res%03d", i)) + expectedResources.Insert(fmt.Sprintf("res%03d", i)) testRels = append(testRels, tuple.MustParse(fmt.Sprintf("resource:res%03d#editor@user:tom", i))) } } @@ -1488,7 +1488,7 @@ func TestReachableResourcesWithCachingInParallelTest(t *testing.T) { foundResources := mapz.NewSet[string]() for _, result := range stream.Results() { - foundResources.Add(result.Resource.ResourceId) + foundResources.Insert(result.Resource.ResourceId) } expectedResourcesSlice := expectedResources.AsSlice() diff --git a/internal/graph/resourcesubjectsmap.go b/internal/graph/resourcesubjectsmap.go index 552d93344d..62a3076098 100644 --- a/internal/graph/resourcesubjectsmap.go +++ b/internal/graph/resourcesubjectsmap.go @@ -196,9 +196,9 @@ func (rsm dispatchableResourcesSubjectMap) mapFoundResource(foundResource *v1.Re } for _, info := range infos { - forSubjectIDs.Add(info.subjectID) + forSubjectIDs.Insert(info.subjectID) if !info.isCaveated { - nonCaveatedSubjectIDs.Add(info.subjectID) + nonCaveatedSubjectIDs.Insert(info.subjectID) } } } diff --git a/internal/namespace/util.go b/internal/namespace/util.go index 96820648e8..4e085d6a6e 100644 --- a/internal/namespace/util.go +++ b/internal/namespace/util.go @@ -55,7 +55,7 @@ type TypeAndRelationToCheck struct { func CheckNamespaceAndRelations(ctx context.Context, checks []TypeAndRelationToCheck, ds datastore.Reader) error { nsNames := mapz.NewSet[string]() for _, toCheck := range checks { - nsNames.Add(toCheck.NamespaceName) + nsNames.Insert(toCheck.NamespaceName) } if nsNames.IsEmpty() { @@ -150,12 +150,12 @@ func ReadNamespaceAndTypes( func ListReferencedNamespaces(nsdefs []*core.NamespaceDefinition) []string { referencedNamespaceNamesSet := mapz.NewSet[string]() for _, nsdef := range nsdefs { - referencedNamespaceNamesSet.Add(nsdef.Name) + referencedNamespaceNamesSet.Insert(nsdef.Name) for _, relation := range nsdef.Relation { if relation.GetTypeInformation() != nil { for _, allowedRel := range relation.GetTypeInformation().AllowedDirectRelations { - referencedNamespaceNamesSet.Add(allowedRel.GetNamespace()) + referencedNamespaceNamesSet.Insert(allowedRel.GetNamespace()) } } } diff --git a/internal/relationships/validation.go b/internal/relationships/validation.go index aca86d075f..143d8285f3 100644 --- a/internal/relationships/validation.go +++ b/internal/relationships/validation.go @@ -87,10 +87,10 @@ func loadNamespacesAndCaveats(ctx context.Context, rels []*core.RelationTuple, r referencedNamespaceNames := mapz.NewSet[string]() referencedCaveatNamesWithContext := mapz.NewSet[string]() for _, rel := range rels { - referencedNamespaceNames.Add(rel.ResourceAndRelation.Namespace) - referencedNamespaceNames.Add(rel.Subject.Namespace) + referencedNamespaceNames.Insert(rel.ResourceAndRelation.Namespace) + referencedNamespaceNames.Insert(rel.Subject.Namespace) if hasNonEmptyCaveatContext(rel) { - referencedCaveatNamesWithContext.Add(rel.Caveat.CaveatName) + referencedCaveatNamesWithContext.Insert(rel.Caveat.CaveatName) } } diff --git a/internal/services/integrationtesting/consistency_test.go b/internal/services/integrationtesting/consistency_test.go index 0ba43a2bd3..059dc92d49 100644 --- a/internal/services/integrationtesting/consistency_test.go +++ b/internal/services/integrationtesting/consistency_test.go @@ -253,7 +253,7 @@ func validateRelationshipReads(t *testing.T, vctx validationContext) { foundRelationshipsSet := mapz.NewSet[string]() for _, rel := range foundRelationships { - foundRelationshipsSet.Add(tuple.MustString(rel)) + foundRelationshipsSet.Insert(tuple.MustString(rel)) } require.True(t, foundRelationshipsSet.Has(tuple.MustString(relationship)), "missing expected relationship %s in read results: %s", tuple.MustString(relationship), foundRelationshipsSet.AsSlice()) diff --git a/internal/services/shared/schema.go b/internal/services/shared/schema.go index c9ee8059f7..42ed4d353e 100644 --- a/internal/services/shared/schema.go +++ b/internal/services/shared/schema.go @@ -34,7 +34,7 @@ func ValidateSchemaChanges(ctx context.Context, compiled *compiler.CompiledSchem return nil, err } - newCaveatDefNames.Add(caveatDef.Name) + newCaveatDefNames.Insert(caveatDef.Name) } // 2) Validate the namespaces defined. @@ -58,7 +58,7 @@ func ValidateSchemaChanges(ctx context.Context, compiled *compiler.CompiledSchem return nil, err } - newObjectDefNames.Add(nsdef.Name) + newObjectDefNames.Insert(nsdef.Name) } return &ValidatedSchemaChanges{ @@ -119,7 +119,7 @@ func ApplySchemaChangesOverExisting( for _, existingCaveat := range existingCaveats { existingCaveatDefMap[existingCaveat.Name] = existingCaveat - existingCaveatDefNames.Add(existingCaveat.Name) + existingCaveatDefNames.Insert(existingCaveat.Name) } // For each caveat definition, perform a diff and ensure the changes will not result in type errors. @@ -142,7 +142,7 @@ func ApplySchemaChangesOverExisting( existingObjectDefNames := mapz.NewSet[string]() for _, existingDef := range existingObjectDefs { existingObjectDefMap[existingDef.Name] = existingDef - existingObjectDefNames.Add(existingDef.Name) + existingObjectDefNames.Insert(existingDef.Name) } // For each definition, perform a diff and ensure the changes will not result in any diff --git a/internal/services/v1/debug_test.go b/internal/services/v1/debug_test.go index 9404e6e26f..b925a33609 100644 --- a/internal/services/v1/debug_test.go +++ b/internal/services/v1/debug_test.go @@ -49,7 +49,7 @@ func expectDebugFrames(permissionNames ...string) rda { for _, sp := range debugInfo.Check.GetSubProblems().Traces { for _, permissionName := range permissionNames { if sp.Permission == permissionName { - found.Add(permissionName) + found.Insert(permissionName) } } } diff --git a/internal/testutil/subjects.go b/internal/testutil/subjects.go index a62c506bb5..2ed3544dd1 100644 --- a/internal/testutil/subjects.go +++ b/internal/testutil/subjects.go @@ -324,7 +324,7 @@ func combinatorialValues(names []string) []map[string]bool { // collectReferencedNames collects all referenced caveat names into the given set. func collectReferencedNames(expr *core.CaveatExpression, nameSet *mapz.Set[string]) { if expr.GetCaveat() != nil { - nameSet.Add(expr.GetCaveat().CaveatName) + nameSet.Insert(expr.GetCaveat().CaveatName) return } diff --git a/pkg/cmd/server/middleware.go b/pkg/cmd/server/middleware.go index c6042dc9d4..eac34f83f5 100644 --- a/pkg/cmd/server/middleware.go +++ b/pkg/cmd/server/middleware.go @@ -97,7 +97,7 @@ const ( func (mc *MiddlewareChain[T]) Names() *mapz.Set[string] { names := mapz.NewSet[string]() for _, mw := range mc.chain { - names.Add(mw.Name) + names.Insert(mw.Name) } return names } diff --git a/pkg/datastore/test/watch.go b/pkg/datastore/test/watch.go index b59b549638..f4f6babbe9 100644 --- a/pkg/datastore/test/watch.go +++ b/pkg/datastore/test/watch.go @@ -580,17 +580,17 @@ func verifyMixedUpdates( foundChanges := mapz.NewSet[string]() for _, changedDef := range change.ChangedDefinitions { - foundChanges.Add("changed:" + changedDef.GetName()) + foundChanges.Insert("changed:" + changedDef.GetName()) } for _, deleted := range change.DeletedNamespaces { - foundChanges.Add("deleted-ns:" + deleted) + foundChanges.Insert("deleted-ns:" + deleted) } for _, deleted := range change.DeletedCaveats { - foundChanges.Add("deleted-caveat:" + deleted) + foundChanges.Insert("deleted-caveat:" + deleted) } for _, update := range change.RelationshipChanges { - foundChanges.Add("rel:" + fmt.Sprintf("OPERATION_%s(%s)", update.Operation, tuple.StringWithoutCaveat(update.Tuple))) + foundChanges.Insert("rel:" + fmt.Sprintf("OPERATION_%s(%s)", update.Operation, tuple.StringWithoutCaveat(update.Tuple))) } found := foundChanges.AsSlice() diff --git a/pkg/diff/namespace/diff.go b/pkg/diff/namespace/diff.go index 78d5a0cc40..c2f6e82fbd 100644 --- a/pkg/diff/namespace/diff.go +++ b/pkg/diff/namespace/diff.go @@ -2,7 +2,6 @@ package namespace import ( "github.com/google/go-cmp/cmp" - "github.com/scylladb/go-set/strset" "golang.org/x/exp/slices" "google.golang.org/protobuf/testing/protocmp" @@ -132,16 +131,16 @@ func DiffNamespaces(existing *core.NamespaceDefinition, updated *core.NamespaceD // Collect up relations and check. existingRels := map[string]*core.Relation{} - existingRelNames := strset.New() + existingRelNames := mapz.NewSet[string]() existingPerms := map[string]*core.Relation{} - existingPermNames := strset.New() + existingPermNames := mapz.NewSet[string]() updatedRels := map[string]*core.Relation{} - updatedRelNames := strset.New() + updatedRelNames := mapz.NewSet[string]() updatedPerms := map[string]*core.Relation{} - updatedPermNames := strset.New() + updatedPermNames := mapz.NewSet[string]() for _, relation := range existing.Relation { _, ok := existingRels[relation.Name] @@ -173,35 +172,39 @@ func DiffNamespaces(existing *core.NamespaceDefinition, updated *core.NamespaceD } } - for _, removed := range strset.Difference(existingRelNames, updatedRelNames).List() { + _ = existingRelNames.Subtract(updatedRelNames).ForEach(func(removed string) error { deltas = append(deltas, Delta{ Type: RemovedRelation, RelationName: removed, }) - } + return nil + }) - for _, added := range strset.Difference(updatedRelNames, existingRelNames).List() { + _ = updatedRelNames.Subtract(existingRelNames).ForEach(func(added string) error { deltas = append(deltas, Delta{ Type: AddedRelation, RelationName: added, }) - } + return nil + }) - for _, removed := range strset.Difference(existingPermNames, updatedPermNames).List() { + _ = existingPermNames.Subtract(updatedPermNames).ForEach(func(removed string) error { deltas = append(deltas, Delta{ Type: RemovedPermission, RelationName: removed, }) - } + return nil + }) - for _, added := range strset.Difference(updatedPermNames, existingPermNames).List() { + _ = updatedPermNames.Subtract(existingPermNames).ForEach(func(added string) error { deltas = append(deltas, Delta{ Type: AddedPermission, RelationName: added, }) - } + return nil + }) - for _, shared := range strset.Intersection(existingPermNames, updatedPermNames).List() { + _ = existingPermNames.Intersect(updatedPermNames).ForEach(func(shared string) error { existingPerm := existingPerms[shared] updatedPerm := updatedPerms[shared] @@ -222,9 +225,10 @@ func DiffNamespaces(existing *core.NamespaceDefinition, updated *core.NamespaceD RelationName: shared, }) } - } + return nil + }) - for _, shared := range strset.Intersection(existingRelNames, updatedRelNames).List() { + _ = existingRelNames.Intersect(updatedRelNames).ForEach(func(shared string) error { existingRel := existingRels[shared] updatedRel := updatedRels[shared] @@ -273,22 +277,26 @@ func DiffNamespaces(existing *core.NamespaceDefinition, updated *core.NamespaceD updatedAllowedRels.Add(source) } - for _, removed := range existingAllowedRels.Subtract(updatedAllowedRels).AsSlice() { + _ = existingAllowedRels.Subtract(updatedAllowedRels).ForEach(func(removed string) error { deltas = append(deltas, Delta{ Type: RelationAllowedTypeRemoved, RelationName: shared, AllowedType: allowedRelsBySource[removed], }) - } + return nil + }) - for _, added := range updatedAllowedRels.Subtract(existingAllowedRels).AsSlice() { + _ = updatedAllowedRels.Subtract(existingAllowedRels).ForEach(func(added string) error { deltas = append(deltas, Delta{ Type: RelationAllowedTypeAdded, RelationName: shared, AllowedType: allowedRelsBySource[added], }) - } - } + return nil + }) + + return nil + }) return &Diff{ existing: existing, diff --git a/pkg/genutil/mapz/countingmap.go b/pkg/genutil/mapz/countingmap.go index f8ca2b4268..7cebacdd70 100644 --- a/pkg/genutil/mapz/countingmap.go +++ b/pkg/genutil/mapz/countingmap.go @@ -43,7 +43,7 @@ func (cmm *CountingMultiMap[T, Q]) Remove(key T, value Q) { return } - values.Remove(value) + values.Delete(value) if values.IsEmpty() { delete(cmm.valuesByKey, key) } diff --git a/pkg/genutil/mapz/set.go b/pkg/genutil/mapz/set.go index dca178d94a..21bd52be99 100644 --- a/pkg/genutil/mapz/set.go +++ b/pkg/genutil/mapz/set.go @@ -1,7 +1,11 @@ package mapz import ( + "maps" + "github.com/rs/zerolog" + + expmaps "golang.org/x/exp/maps" ) // Set implements a very basic generic set. @@ -37,15 +41,14 @@ func (s *Set[T]) Add(value T) bool { return true } -// Remove removes the value from the set, returning whether -// the element was present when the call was made. -func (s *Set[T]) Remove(value T) bool { - if !s.Has(value) { - return false - } +// Insert adds the given value to the set. +func (s *Set[T]) Insert(value T) { + s.values[value] = struct{}{} +} +// Delete removes the value from the set, returning nothing. +func (s *Set[T]) Delete(value T) { delete(s.values, value) - return true } // Extend adds all the values to the set. @@ -60,7 +63,7 @@ func (s *Set[T]) Extend(values []T) { func (s *Set[T]) IntersectionDifference(other *Set[T]) *Set[T] { for value := range s.values { if !other.Has(value) { - s.Remove(value) + delete(s.values, value) } } return s @@ -69,21 +72,22 @@ func (s *Set[T]) IntersectionDifference(other *Set[T]) *Set[T] { // RemoveAll removes all values from this set found in the other set. func (s *Set[T]) RemoveAll(other *Set[T]) { for value := range other.values { - s.Remove(value) + delete(s.values, value) } } // Subtract subtracts the other set from this set, returning a new set. func (s *Set[T]) Subtract(other *Set[T]) *Set[T] { - newSet := NewSet[T]() - newSet.Extend(s.AsSlice()) - newSet.RemoveAll(other) - return newSet + cpy := s.Copy() + cpy.RemoveAll(other) + return cpy } // Copy returns a copy of this set. func (s *Set[T]) Copy() *Set[T] { - return NewSet(s.AsSlice()...) + return &Set[T]{ + values: maps.Clone(s.values), + } } // Intersect removes any values from this set that @@ -92,7 +96,7 @@ func (s *Set[T]) Intersect(other *Set[T]) *Set[T] { cpy := s.Copy() for value := range cpy.values { if !other.Has(value) { - cpy.Remove(value) + delete(cpy.values, value) } } return cpy @@ -100,17 +104,7 @@ func (s *Set[T]) Intersect(other *Set[T]) *Set[T] { // Equal returns true if both sets have the same elements func (s *Set[T]) Equal(other *Set[T]) bool { - for value := range s.values { - if !other.Has(value) { - return false - } - } - for value := range other.values { - if !s.Has(value) { - return false - } - } - return true + return maps.Equal(s.values, other.values) } // IsEmpty returns true if the set is empty. @@ -124,11 +118,7 @@ func (s *Set[T]) AsSlice() []T { return nil } - slice := make([]T, 0, len(s.values)) - for value := range s.values { - slice = append(slice, value) - } - return slice + return expmaps.Keys(s.values) } // Len returns the length of the set. diff --git a/pkg/genutil/mapz/set_test.go b/pkg/genutil/mapz/set_test.go index fb8289b30a..7bdbef37a7 100644 --- a/pkg/genutil/mapz/set_test.go +++ b/pkg/genutil/mapz/set_test.go @@ -32,9 +32,9 @@ func TestSetOperations(t *testing.T) { sort.Strings(slice) require.Equal(t, slice, []string{"hello", "heyo", "hi"}) - // Remove some items. - require.True(t, set.Remove("hi")) - require.False(t, set.Remove("hi")) + // Delete some items. + set.Delete("hi") + set.Delete("hi") require.False(t, set.Has("hi")) require.True(t, set.Has("hello")) @@ -95,6 +95,12 @@ func TestSetIntersect(t *testing.T) { slice := set.AsSlice() sort.Strings(slice) require.Equal(t, []string{"1", "2", "3", "4"}, slice) + + // Perform in reverse. + updated = NewSet[string]("1", "2", "3", "5").Intersect(set) + updatedSlice = updated.AsSlice() + sort.Strings(updatedSlice) + require.Equal(t, []string{"1", "2", "3"}, updatedSlice) } func TestSetSubtract(t *testing.T) { @@ -162,3 +168,106 @@ func TestSetIntersectionDifference(t *testing.T) { }) } } + +func BenchmarkAdd(b *testing.B) { + set := NewSet[int]() + for i := 0; i < b.N; i++ { + set.Add(i) + } +} + +func BenchmarkInsert(b *testing.B) { + set := NewSet[int]() + for i := 0; i < b.N; i++ { + set.Insert(i) + } +} + +func BenchmarkCopy(b *testing.B) { + set := NewSet[int]() + for i := 0; i < b.N; i++ { + set.Add(i) + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + set.Copy() + } +} + +func BenchmarkHas(b *testing.B) { + set := NewSet[int]() + for i := 0; i < b.N; i++ { + set.Add(i) + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + set.Has(i) + } +} + +func BenchmarkDelete(b *testing.B) { + set := NewSet[int]() + for i := 0; i < b.N; i++ { + set.Add(i) + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + set.Delete(i) + } +} + +func BenchmarkIntersect(b *testing.B) { + set := NewSet[int]() + for i := 0; i < b.N; i++ { + set.Add(i) + } + other := NewSet[int]() + for i := 0; i < b.N; i++ { + other.Add(i) + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + set.Intersect(other) + } +} + +func BenchmarkSubtract(b *testing.B) { + set := NewSet[int]() + for i := 0; i < b.N; i++ { + set.Add(i) + } + other := NewSet[int]() + for i := 0; i < b.N; i++ { + other.Add(i) + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + set.Subtract(other) + } +} + +func BenchmarkAsSlice(b *testing.B) { + set := NewSet[int]() + for i := 0; i < b.N; i++ { + set.Add(i) + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + set.AsSlice() + } +} + +func BenchmarkEqual(b *testing.B) { + set := NewSet[int]() + for i := 0; i < b.N; i++ { + set.Add(i) + } + other := NewSet[int]() + for i := 0; i < b.N; i++ { + other.Add(i) + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + set.Equal(other) + } +} diff --git a/pkg/tuple/onrset.go b/pkg/tuple/onrset.go index bc647a77a5..0496d1103d 100644 --- a/pkg/tuple/onrset.go +++ b/pkg/tuple/onrset.go @@ -1,6 +1,10 @@ package tuple import ( + "maps" + + expmaps "golang.org/x/exp/maps" + core "github.com/authzed/spicedb/pkg/proto/core/v1" ) @@ -83,9 +87,8 @@ func (ons *ONRSet) Subtract(otherSet *ONRSet) *ONRSet { // With returns a copy of this ONR set with the given element added. func (ons *ONRSet) With(onr *core.ObjectAndRelation) *ONRSet { - updated := NewONRSet() - for _, current := range ons.onrs { - updated.Add(current) + updated := &ONRSet{ + onrs: maps.Clone(ons.onrs), } updated.Add(onr) return updated @@ -93,9 +96,8 @@ func (ons *ONRSet) With(onr *core.ObjectAndRelation) *ONRSet { // Union returns a copy of this ONR set with the other set's elements added in. func (ons *ONRSet) Union(otherSet *ONRSet) *ONRSet { - updated := NewONRSet() - for _, current := range ons.onrs { - updated.Add(current) + updated := &ONRSet{ + onrs: maps.Clone(ons.onrs), } for _, current := range otherSet.onrs { updated.Add(current) @@ -105,9 +107,5 @@ func (ons *ONRSet) Union(otherSet *ONRSet) *ONRSet { // AsSlice returns the ONRs found in the set as a slice. func (ons *ONRSet) AsSlice() []*core.ObjectAndRelation { - slice := make([]*core.ObjectAndRelation, 0, len(ons.onrs)) - for _, onr := range ons.onrs { - slice = append(slice, onr) - } - return slice + return expmaps.Values(ons.onrs) } diff --git a/pkg/tuple/onrset_test.go b/pkg/tuple/onrset_test.go new file mode 100644 index 0000000000..43417d7e63 --- /dev/null +++ b/pkg/tuple/onrset_test.go @@ -0,0 +1,143 @@ +package tuple + +import ( + "testing" + + "github.com/stretchr/testify/require" + + core "github.com/authzed/spicedb/pkg/proto/core/v1" +) + +func TestONRSet(t *testing.T) { + set := NewONRSet() + require.True(t, set.IsEmpty()) + require.Equal(t, uint32(0), set.Length()) + + require.True(t, set.Add(ParseONR("resource:1#viewer"))) + require.False(t, set.IsEmpty()) + require.Equal(t, uint32(1), set.Length()) + + require.True(t, set.Add(ParseONR("resource:2#viewer"))) + require.True(t, set.Add(ParseONR("resource:3#viewer"))) + require.Equal(t, uint32(3), set.Length()) + + require.False(t, set.Add(ParseONR("resource:1#viewer"))) + require.True(t, set.Add(ParseONR("resource:1#editor"))) + + require.True(t, set.Has(ParseONR("resource:1#viewer"))) + require.True(t, set.Has(ParseONR("resource:1#editor"))) + require.False(t, set.Has(ParseONR("resource:1#owner"))) + require.False(t, set.Has(ParseONR("resource:1#admin"))) + require.False(t, set.Has(ParseONR("resource:1#reader"))) + + require.True(t, set.Has(ParseONR("resource:2#viewer"))) +} + +func TestONRSetUpdate(t *testing.T) { + set := NewONRSet() + set.Update([]*core.ObjectAndRelation{ + ParseONR("resource:1#viewer"), + ParseONR("resource:2#viewer"), + ParseONR("resource:3#viewer"), + }) + require.Equal(t, uint32(3), set.Length()) + + set.Update([]*core.ObjectAndRelation{ + ParseONR("resource:1#viewer"), + ParseONR("resource:1#editor"), + ParseONR("resource:1#owner"), + ParseONR("resource:1#admin"), + ParseONR("resource:1#reader"), + }) + require.Equal(t, uint32(7), set.Length()) +} + +func TestONRSetIntersect(t *testing.T) { + set1 := NewONRSet() + set1.Update([]*core.ObjectAndRelation{ + ParseONR("resource:1#viewer"), + ParseONR("resource:2#viewer"), + ParseONR("resource:3#viewer"), + }) + + set2 := NewONRSet() + set2.Update([]*core.ObjectAndRelation{ + ParseONR("resource:1#viewer"), + ParseONR("resource:1#editor"), + ParseONR("resource:1#owner"), + ParseONR("resource:1#admin"), + ParseONR("resource:2#viewer"), + ParseONR("resource:1#reader"), + }) + + require.Equal(t, uint32(2), set1.Intersect(set2).Length()) + require.Equal(t, uint32(2), set2.Intersect(set1).Length()) +} + +func TestONRSetSubtract(t *testing.T) { + set1 := NewONRSet() + set1.Update([]*core.ObjectAndRelation{ + ParseONR("resource:1#viewer"), + ParseONR("resource:2#viewer"), + ParseONR("resource:3#viewer"), + }) + + set2 := NewONRSet() + set2.Update([]*core.ObjectAndRelation{ + ParseONR("resource:1#viewer"), + ParseONR("resource:1#editor"), + ParseONR("resource:1#owner"), + ParseONR("resource:1#admin"), + ParseONR("resource:2#viewer"), + ParseONR("resource:1#reader"), + }) + + require.Equal(t, uint32(1), set1.Subtract(set2).Length()) + require.Equal(t, uint32(4), set2.Subtract(set1).Length()) +} + +func TestONRSetUnion(t *testing.T) { + set1 := NewONRSet() + set1.Update([]*core.ObjectAndRelation{ + ParseONR("resource:1#viewer"), + ParseONR("resource:2#viewer"), + ParseONR("resource:3#viewer"), + }) + + set2 := NewONRSet() + set2.Update([]*core.ObjectAndRelation{ + ParseONR("resource:1#viewer"), + ParseONR("resource:1#editor"), + ParseONR("resource:1#owner"), + ParseONR("resource:1#admin"), + ParseONR("resource:2#viewer"), + ParseONR("resource:1#reader"), + }) + + require.Equal(t, uint32(7), set1.Union(set2).Length()) + require.Equal(t, uint32(7), set2.Union(set1).Length()) +} + +func TestONRSetWith(t *testing.T) { + set1 := NewONRSet() + set1.Update([]*core.ObjectAndRelation{ + ParseONR("resource:1#viewer"), + ParseONR("resource:2#viewer"), + ParseONR("resource:3#viewer"), + }) + + added := set1.With(ParseONR("resource:1#editor")) + require.Equal(t, uint32(3), set1.Length()) + require.Equal(t, uint32(4), added.Length()) +} + +func TestONRSetAsSlice(t *testing.T) { + set := NewONRSet() + set.Update([]*core.ObjectAndRelation{ + ParseONR("resource:1#viewer"), + ParseONR("resource:2#viewer"), + ParseONR("resource:3#viewer"), + }) + + require.Equal(t, 3, len(set.AsSlice())) +}