diff --git a/internal/datastore/mysql/readwrite.go b/internal/datastore/mysql/readwrite.go index 5f790b5f13..36c8726f2a 100644 --- a/internal/datastore/mysql/readwrite.go +++ b/internal/datastore/mysql/readwrite.go @@ -151,8 +151,7 @@ func (rwt *mysqlReadWriteTXN) WriteRelationships(ctx context.Context, mutations } // Ensure the tuples are the same. - // TODO(jschorr): Use a faster method then string comparison. - if tuple.MustString(mut.Tuple) == tuple.MustString(foundTpl) { + if tuple.Equal(mut.Tuple, foundTpl) { delete(createAndTouchMutationsByTuple, tplString) continue } diff --git a/internal/graph/check.go b/internal/graph/check.go index 09def5029f..2904723f9a 100644 --- a/internal/graph/check.go +++ b/internal/graph/check.go @@ -197,15 +197,6 @@ func (cc *ConcurrentChecker) checkInternal(ctx context.Context, req ValidatedChe return combineResultWithFoundResources(cc.checkUsersetRewrite(ctx, crc, relation.UsersetRewrite), membershipSet) } -func onrEqual(lhs, rhs *core.ObjectAndRelation) bool { - // Properties are sorted by highest to lowest cardinality to optimize for short-circuiting. - return lhs.ObjectId == rhs.ObjectId && lhs.Relation == rhs.Relation && lhs.Namespace == rhs.Namespace -} - -func onrEqualOrWildcard(tpl, target *core.ObjectAndRelation) bool { - return onrEqual(tpl, target) || (tpl.ObjectId == tuple.PublicWildcard && tpl.Namespace == target.Namespace) -} - type directDispatch struct { resourceType *core.RelationReference resourceIds []string @@ -298,7 +289,7 @@ func (cc *ConcurrentChecker) checkDirect(ctx context.Context, crc currentRequest // If the subject of the relationship matches the target subject, then we've found // a result. - if !onrEqualOrWildcard(tpl.Subject, crc.parentReq.Subject) { + if !tuple.OnrEqualOrWildcard(tpl.Subject, crc.parentReq.Subject) { tplString, err := tuple.String(tpl) if err != nil { return checkResultError(err, emptyMetadata) diff --git a/pkg/tuple/onr.go b/pkg/tuple/onr.go index bb8f5473de..a0685c9104 100644 --- a/pkg/tuple/onr.go +++ b/pkg/tuple/onr.go @@ -108,3 +108,12 @@ func StringsONRs(onrs []*core.ObjectAndRelation) []string { sort.Strings(onrstrings) return onrstrings } + +func OnrEqual(lhs, rhs *core.ObjectAndRelation) bool { + // Properties are sorted by highest to lowest cardinality to optimize for short-circuiting. + return lhs.ObjectId == rhs.ObjectId && lhs.Relation == rhs.Relation && lhs.Namespace == rhs.Namespace +} + +func OnrEqualOrWildcard(tpl, target *core.ObjectAndRelation) bool { + return OnrEqual(tpl, target) || (tpl.ObjectId == PublicWildcard && tpl.Namespace == target.Namespace) +} diff --git a/pkg/tuple/tuple.go b/pkg/tuple/tuple.go index 6ca65f7f56..7b97cbd72b 100644 --- a/pkg/tuple/tuple.go +++ b/pkg/tuple/tuple.go @@ -120,6 +120,14 @@ func StringWithoutCaveat(tpl *core.RelationTuple) string { return fmt.Sprintf("%s@%s", StringONR(tpl.ResourceAndRelation), StringONR(tpl.Subject)) } +func MustStringCaveat(caveat *core.ContextualizedCaveat) string { + caveatString, err := StringCaveat(caveat) + if err != nil { + panic(err) + } + return caveatString +} + // StringCaveat converts a contextualized caveat to a string. If the caveat is nil or empty, returns empty string. func StringCaveat(caveat *core.ContextualizedCaveat) (string, error) { if caveat == nil || caveat.CaveatName == "" { @@ -263,6 +271,11 @@ func Delete(tpl *core.RelationTuple) *core.RelationTupleUpdate { } } +func Equal(lhs, rhs *core.RelationTuple) bool { + // TODO(jschorr): Use a faster method then string comparison for caveats. + return OnrEqual(lhs.ResourceAndRelation, rhs.ResourceAndRelation) && OnrEqual(lhs.Subject, rhs.Subject) && MustStringCaveat(lhs.Caveat) == MustStringCaveat(rhs.Caveat) +} + // MustToRelationship converts a RelationTuple into a Relationship. Will panic if // the RelationTuple does not validate. func MustToRelationship(tpl *core.RelationTuple) *v1.Relationship { diff --git a/pkg/tuple/tuple_test.go b/pkg/tuple/tuple_test.go index 23774f7993..32bac7046f 100644 --- a/pkg/tuple/tuple_test.go +++ b/pkg/tuple/tuple_test.go @@ -546,3 +546,175 @@ func TestCopyRelationTupleToRelationship(t *testing.T) { }) } } + +func TestEqual(t *testing.T) { + equalTestCases := []*core.RelationTuple{ + makeTuple( + ObjectAndRelation("testns", "testobj", "testrel"), + ObjectAndRelation("user", "testusr", "..."), + ), + MustWithCaveat( + makeTuple( + ObjectAndRelation("testns", "testobj", "testrel"), + ObjectAndRelation("user", "testusr", "..."), + ), + "somecaveat", + map[string]any{ + "context": map[string]any{ + "deeply": map[string]any{ + "nested": true, + }, + }, + }, + ), + } + + for _, tc := range equalTestCases { + t.Run(MustString(tc), func(t *testing.T) { + require := require.New(t) + require.True(Equal(tc, tc.CloneVT())) + }) + } + + notEqualTestCases := []struct { + name string + lhs *core.RelationTuple + rhs *core.RelationTuple + }{ + { + name: "Mismatch Resource Type", + lhs: makeTuple( + ObjectAndRelation("testns1", "testobj", "testrel"), + ObjectAndRelation("user", "testusr", "..."), + ), + rhs: makeTuple( + ObjectAndRelation("testns2", "testobj", "testrel"), + ObjectAndRelation("user", "testusr", "..."), + ), + }, + { + name: "Mismatch Resource ID", + lhs: makeTuple( + ObjectAndRelation("testns", "testobj1", "testrel"), + ObjectAndRelation("user", "testusr", "..."), + ), + rhs: makeTuple( + ObjectAndRelation("testns", "testobj2", "testrel"), + ObjectAndRelation("user", "testusr", "..."), + ), + }, + { + name: "Mismatch Resource Relationship", + lhs: makeTuple( + ObjectAndRelation("testns", "testobj", "testrel1"), + ObjectAndRelation("user", "testusr", "..."), + ), + rhs: makeTuple( + ObjectAndRelation("testns", "testobj", "testrel2"), + ObjectAndRelation("user", "testusr", "..."), + ), + }, + { + name: "Mismatch Subject Type", + lhs: makeTuple( + ObjectAndRelation("testns", "testobj", "testrel"), + ObjectAndRelation("user1", "testusr", "..."), + ), + rhs: makeTuple( + ObjectAndRelation("testns", "testobj", "testrel"), + ObjectAndRelation("user2", "testusr", "..."), + ), + }, + { + name: "Mismatch Subject ID", + lhs: makeTuple( + ObjectAndRelation("testns", "testobj", "testrel"), + ObjectAndRelation("user", "testusr1", "..."), + ), + rhs: makeTuple( + ObjectAndRelation("testns", "testobj", "testrel"), + ObjectAndRelation("user", "testusr2", "..."), + ), + }, + { + name: "Mismatch Subject Relationship", + lhs: makeTuple( + ObjectAndRelation("testns", "testobj", "testrel"), + ObjectAndRelation("user", "testusr", "testrel1"), + ), + rhs: makeTuple( + ObjectAndRelation("testns", "testobj", "testrel"), + ObjectAndRelation("user", "testusr", "testrel2"), + ), + }, + { + name: "Mismatch Caveat Name", + lhs: MustWithCaveat( + makeTuple( + ObjectAndRelation("testns", "testobj", "testrel"), + ObjectAndRelation("user", "testusr", "..."), + ), + "somecaveat1", + map[string]any{ + "context": map[string]any{ + "deeply": map[string]any{ + "nested": true, + }, + }, + }, + ), + rhs: MustWithCaveat( + makeTuple( + ObjectAndRelation("testns", "testobj", "testrel"), + ObjectAndRelation("user", "testusr", "..."), + ), + "somecaveat2", + map[string]any{ + "context": map[string]any{ + "deeply": map[string]any{ + "nested": true, + }, + }, + }, + ), + }, + { + name: "Mismatch Caveat Content", + lhs: MustWithCaveat( + makeTuple( + ObjectAndRelation("testns", "testobj", "testrel"), + ObjectAndRelation("user", "testusr", "..."), + ), + "somecaveat", + map[string]any{ + "context": map[string]any{ + "deeply": map[string]any{ + "nested": "1", + }, + }, + }, + ), + rhs: MustWithCaveat( + makeTuple( + ObjectAndRelation("testns", "testobj", "testrel"), + ObjectAndRelation("user", "testusr", "..."), + ), + "somecaveat", + map[string]any{ + "context": map[string]any{ + "deeply": map[string]any{ + "nested": "2", + }, + }, + }, + ), + }, + } + + for _, tc := range notEqualTestCases { + t.Run(tc.name, func(t *testing.T) { + require := require.New(t) + require.False(Equal(tc.lhs, tc.rhs)) + }) + } +}