From 063c693fa0fbf56dc0d3070f335c0dbc226530a7 Mon Sep 17 00:00:00 2001 From: Joseph Schorr Date: Thu, 12 Dec 2024 11:50:36 -0500 Subject: [PATCH] Fix bulk export of relationships with caveats The structs change made the export not properly copy the caveats --- internal/services/v1/experimental.go | 42 ++++++-------- internal/services/v1/experimental_test.go | 39 +++++-------- internal/services/v1/permissions.go | 27 ++++++--- internal/services/v1/permissions_test.go | 39 +++++-------- internal/services/v1/relationships_test.go | 67 ++++++++++++++++++++++ 5 files changed, 133 insertions(+), 81 deletions(-) diff --git a/internal/services/v1/experimental.go b/internal/services/v1/experimental.go index fe892a9b5e..de68d3a678 100644 --- a/internal/services/v1/experimental.go +++ b/internal/services/v1/experimental.go @@ -12,6 +12,7 @@ import ( "google.golang.org/grpc" "google.golang.org/grpc/codes" + "google.golang.org/protobuf/types/known/timestamppb" "github.com/ccoveille/go-safecast" "github.com/jzelinskie/stringz" @@ -154,12 +155,6 @@ func (a *bulkLoadAdapter) Next(_ context.Context) (*tuple.Relationship, error) { a.currentBatch = batch.Relationships a.numSent = 0 - for _, rel := range batch.Relationships { - if rel.OptionalExpiresAt != nil { - return nil, fmt.Errorf("expiration time is not currently supported") - } - } - a.awaitingNamespaces, a.awaitingCaveats = extractBatchNewReferencedNamespacesAndCaveats( a.currentBatch, a.referencedNamespaceMap, @@ -172,6 +167,13 @@ func (a *bulkLoadAdapter) Next(_ context.Context) (*tuple.Relationship, error) { return nil, nil } + a.current.RelationshipReference.Resource.ObjectType = a.currentBatch[a.numSent].Resource.ObjectType + a.current.RelationshipReference.Resource.ObjectID = a.currentBatch[a.numSent].Resource.ObjectId + a.current.RelationshipReference.Resource.Relation = a.currentBatch[a.numSent].Relation + a.current.Subject.ObjectType = a.currentBatch[a.numSent].Subject.Object.ObjectType + a.current.Subject.ObjectID = a.currentBatch[a.numSent].Subject.Object.ObjectId + a.current.Subject.Relation = stringz.DefaultEmpty(a.currentBatch[a.numSent].Subject.OptionalRelation, tuple.Ellipsis) + if a.currentBatch[a.numSent].OptionalCaveat != nil { a.caveat.CaveatName = a.currentBatch[a.numSent].OptionalCaveat.CaveatName a.caveat.Context = a.currentBatch[a.numSent].OptionalCaveat.Context @@ -180,22 +182,6 @@ func (a *bulkLoadAdapter) Next(_ context.Context) (*tuple.Relationship, error) { a.current.OptionalCaveat = nil } - if a.caveat.CaveatName != "" { - a.current.OptionalCaveat = &a.caveat - } else { - a.current.OptionalCaveat = nil - } - - a.current.OptionalIntegrity = nil - a.current.OptionalExpiration = nil - - a.current.RelationshipReference.Resource.ObjectType = a.currentBatch[a.numSent].Resource.ObjectType - a.current.RelationshipReference.Resource.ObjectID = a.currentBatch[a.numSent].Resource.ObjectId - a.current.RelationshipReference.Resource.Relation = a.currentBatch[a.numSent].Relation - a.current.Subject.ObjectType = a.currentBatch[a.numSent].Subject.Object.ObjectType - a.current.Subject.ObjectID = a.currentBatch[a.numSent].Subject.Object.ObjectId - a.current.Subject.Relation = stringz.DefaultEmpty(a.currentBatch[a.numSent].Subject.OptionalRelation, tuple.Ellipsis) - if a.currentBatch[a.numSent].OptionalExpiresAt != nil { t := a.currentBatch[a.numSent].OptionalExpiresAt.AsTime() a.current.OptionalExpiration = &t @@ -203,6 +189,8 @@ func (a *bulkLoadAdapter) Next(_ context.Context) (*tuple.Relationship, error) { a.current.OptionalExpiration = nil } + a.current.OptionalIntegrity = nil + if err := relationships.ValidateOneRelationship( a.referencedNamespaceMap, a.referencedCaveatMap, @@ -432,9 +420,15 @@ func BulkExport(ctx context.Context, ds datastore.ReadOnlyDatastore, batchSize u if rel.OptionalCaveat != nil { caveatArray[offset].CaveatName = rel.OptionalCaveat.CaveatName caveatArray[offset].Context = rel.OptionalCaveat.Context + v1Rel.OptionalCaveat = &caveatArray[offset] + } else { + v1Rel.OptionalCaveat = nil + } + + if rel.OptionalExpiration != nil { + v1Rel.OptionalExpiresAt = timestamppb.New(*rel.OptionalExpiration) } else { - caveatArray[offset].CaveatName = "" - caveatArray[offset].Context = nil + v1Rel.OptionalExpiresAt = nil } } diff --git a/internal/services/v1/experimental_test.go b/internal/services/v1/experimental_test.go index ecb123cf04..595878c3ea 100644 --- a/internal/services/v1/experimental_test.go +++ b/internal/services/v1/experimental_test.go @@ -74,7 +74,7 @@ func TestBulkImportRelationships(t *testing.T) { for i := uint64(0); i < batchSize; i++ { if withCaveats { - batch = append(batch, relWithCaveat( + batch = append(batch, mustRelWithCaveatAndContext( tf.DocumentNS.Name, strconv.Itoa(batchNum)+"_"+strconv.FormatUint(i, 10), "caveated_viewer", @@ -82,6 +82,7 @@ func TestBulkImportRelationships(t *testing.T) { strconv.FormatUint(i, 10), "", "test", + map[string]any{"secret": strconv.FormatUint(i, 10)}, )) } else { batch = append(batch, rel( @@ -177,6 +178,8 @@ func TestBulkExportRelationships(t *testing.T) { {tf.FolderNS.Name, "owner"}, {tf.DocumentNS.Name, "editor"}, {tf.FolderNS.Name, "editor"}, + {tf.DocumentNS.Name, "caveated_viewer"}, + {tf.DocumentNS.Name, "expiring_viewer"}, } totalToWrite := 1_000 @@ -184,16 +187,9 @@ func TestBulkExportRelationships(t *testing.T) { batch := make([]*v1.Relationship, totalToWrite) for i := range batch { nsAndRel := nsAndRels[i%len(nsAndRels)] - rel := rel( - nsAndRel.namespace, - strconv.Itoa(i), - nsAndRel.relation, - tf.UserNS.Name, - strconv.Itoa(i), - "", - ) - batch[i] = rel - expectedRels.Add(tuple.MustV1RelString(rel)) + v1rel := relationshipForBulkTesting(nsAndRel, i) + batch[i] = v1rel + expectedRels.Add(tuple.MustV1RelString(v1rel)) } ctx := context.Background() @@ -280,7 +276,7 @@ func TestBulkExportRelationshipsWithFilter(t *testing.T) { &v1.RelationshipFilter{ ResourceType: tf.DocumentNS.Name, }, - 500, + 625, }, { "filter by resource ID", @@ -302,7 +298,7 @@ func TestBulkExportRelationshipsWithFilter(t *testing.T) { ResourceType: tf.DocumentNS.Name, OptionalResourceIdPrefix: "1", }, - 55, + 69, }, { "filter by invalid resource type", @@ -335,31 +331,26 @@ func TestBulkExportRelationshipsWithFilter(t *testing.T) { {tf.FolderNS.Name, "owner"}, {tf.DocumentNS.Name, "editor"}, {tf.FolderNS.Name, "editor"}, + {tf.DocumentNS.Name, "caveated_viewer"}, + {tf.DocumentNS.Name, "expiring_viewer"}, } expectedRels := set.NewStringSetWithSize(1000) batch := make([]*v1.Relationship, 1000) for i := range batch { nsAndRel := nsAndRels[i%len(nsAndRels)] - rel := rel( - nsAndRel.namespace, - strconv.Itoa(i), - nsAndRel.relation, - tf.UserNS.Name, - strconv.Itoa(i), - "", - ) - batch[i] = rel + v1rel := relationshipForBulkTesting(nsAndRel, i) + batch[i] = v1rel if tc.filter != nil { filter, err := datastore.RelationshipsFilterFromPublicFilter(tc.filter) require.NoError(err) - if !filter.Test(tuple.FromV1Relationship(rel)) { + if !filter.Test(tuple.FromV1Relationship(v1rel)) { continue } } - expectedRels.Add(tuple.MustV1RelString(rel)) + expectedRels.Add(tuple.MustV1RelString(v1rel)) } require.Equal(tc.expectedCount, expectedRels.Size()) diff --git a/internal/services/v1/permissions.go b/internal/services/v1/permissions.go index 8ab94a5580..8105cd62b8 100644 --- a/internal/services/v1/permissions.go +++ b/internal/services/v1/permissions.go @@ -17,6 +17,7 @@ import ( "google.golang.org/grpc/status" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/types/known/structpb" + "google.golang.org/protobuf/types/known/timestamppb" cexpr "github.com/authzed/spicedb/internal/caveats" dispatchpkg "github.com/authzed/spicedb/internal/dispatch" @@ -768,6 +769,13 @@ func (a *loadBulkAdapter) Next(_ context.Context) (*tuple.Relationship, error) { return nil, nil } + a.current.RelationshipReference.Resource.ObjectType = a.currentBatch[a.numSent].Resource.ObjectType + a.current.RelationshipReference.Resource.ObjectID = a.currentBatch[a.numSent].Resource.ObjectId + a.current.RelationshipReference.Resource.Relation = a.currentBatch[a.numSent].Relation + a.current.Subject.ObjectType = a.currentBatch[a.numSent].Subject.Object.ObjectType + a.current.Subject.ObjectID = a.currentBatch[a.numSent].Subject.Object.ObjectId + a.current.Subject.Relation = stringz.DefaultEmpty(a.currentBatch[a.numSent].Subject.OptionalRelation, tuple.Ellipsis) + if a.currentBatch[a.numSent].OptionalCaveat != nil { a.caveat.CaveatName = a.currentBatch[a.numSent].OptionalCaveat.CaveatName a.caveat.Context = a.currentBatch[a.numSent].OptionalCaveat.Context @@ -776,15 +784,6 @@ func (a *loadBulkAdapter) Next(_ context.Context) (*tuple.Relationship, error) { a.current.OptionalCaveat = nil } - a.current.OptionalIntegrity = nil - - a.current.RelationshipReference.Resource.ObjectType = a.currentBatch[a.numSent].Resource.ObjectType - a.current.RelationshipReference.Resource.ObjectID = a.currentBatch[a.numSent].Resource.ObjectId - a.current.RelationshipReference.Resource.Relation = a.currentBatch[a.numSent].Relation - a.current.Subject.ObjectType = a.currentBatch[a.numSent].Subject.Object.ObjectType - a.current.Subject.ObjectID = a.currentBatch[a.numSent].Subject.Object.ObjectId - a.current.Subject.Relation = stringz.DefaultEmpty(a.currentBatch[a.numSent].Subject.OptionalRelation, tuple.Ellipsis) - if a.currentBatch[a.numSent].OptionalExpiresAt != nil { t := a.currentBatch[a.numSent].OptionalExpiresAt.AsTime() a.current.OptionalExpiration = &t @@ -792,6 +791,8 @@ func (a *loadBulkAdapter) Next(_ context.Context) (*tuple.Relationship, error) { a.current.OptionalExpiration = nil } + a.current.OptionalIntegrity = nil + if err := relationships.ValidateOneRelationship( a.referencedNamespaceMap, a.referencedCaveatMap, @@ -996,9 +997,17 @@ func ExportBulk(ctx context.Context, ds datastore.Datastore, batchSize uint64, r if rel.OptionalCaveat != nil { caveatArray[offset].CaveatName = rel.OptionalCaveat.CaveatName caveatArray[offset].Context = rel.OptionalCaveat.Context + v1Rel.OptionalCaveat = &caveatArray[offset] } else { caveatArray[offset].CaveatName = "" caveatArray[offset].Context = nil + v1Rel.OptionalCaveat = nil + } + + if rel.OptionalExpiration != nil { + v1Rel.OptionalExpiresAt = timestamppb.New(*rel.OptionalExpiration) + } else { + v1Rel.OptionalExpiresAt = nil } } diff --git a/internal/services/v1/permissions_test.go b/internal/services/v1/permissions_test.go index 2815209dae..5a72b2f6a9 100644 --- a/internal/services/v1/permissions_test.go +++ b/internal/services/v1/permissions_test.go @@ -2087,7 +2087,7 @@ func TestImportBulkRelationships(t *testing.T) { for i := uint64(0); i < batchSize; i++ { if withTrait == "caveated_viewer" { - batch = append(batch, relWithCaveat( + batch = append(batch, mustRelWithCaveatAndContext( tf.DocumentNS.Name, strconv.Itoa(batchNum)+"_"+strconv.FormatUint(i, 10), "caveated_viewer", @@ -2095,6 +2095,7 @@ func TestImportBulkRelationships(t *testing.T) { strconv.FormatUint(i, 10), "", "test", + map[string]any{"secret": strconv.FormatUint(i, 10)}, )) } else if withTrait == "expiring_viewer" { batch = append(batch, relWithExpiration( @@ -2197,6 +2198,8 @@ func TestExportBulkRelationships(t *testing.T) { {tf.FolderNS.Name, "owner"}, {tf.DocumentNS.Name, "editor"}, {tf.FolderNS.Name, "editor"}, + {tf.DocumentNS.Name, "caveated_viewer"}, + {tf.DocumentNS.Name, "expiring_viewer"}, } totalToWrite := 1_000 @@ -2204,16 +2207,9 @@ func TestExportBulkRelationships(t *testing.T) { batch := make([]*v1.Relationship, totalToWrite) for i := range batch { nsAndRel := nsAndRels[i%len(nsAndRels)] - rel := rel( - nsAndRel.namespace, - strconv.Itoa(i), - nsAndRel.relation, - tf.UserNS.Name, - strconv.Itoa(i), - "", - ) - batch[i] = rel - expectedRels.Add(tuple.MustV1RelString(rel)) + v1rel := relationshipForBulkTesting(nsAndRel, i) + batch[i] = v1rel + expectedRels.Add(tuple.MustV1RelString(v1rel)) } ctx := context.Background() @@ -2300,7 +2296,7 @@ func TestExportBulkRelationshipsWithFilter(t *testing.T) { &v1.RelationshipFilter{ ResourceType: tf.DocumentNS.Name, }, - 500, + 625, }, { "filter by resource ID", @@ -2322,7 +2318,7 @@ func TestExportBulkRelationshipsWithFilter(t *testing.T) { ResourceType: tf.DocumentNS.Name, OptionalResourceIdPrefix: "1", }, - 55, + 69, }, { "filter by invalid resource type", @@ -2354,31 +2350,26 @@ func TestExportBulkRelationshipsWithFilter(t *testing.T) { {tf.FolderNS.Name, "owner"}, {tf.DocumentNS.Name, "editor"}, {tf.FolderNS.Name, "editor"}, + {tf.DocumentNS.Name, "caveated_viewer"}, + {tf.DocumentNS.Name, "expiring_viewer"}, } expectedRels := set.NewStringSetWithSize(1000) batch := make([]*v1.Relationship, 1000) for i := range batch { nsAndRel := nsAndRels[i%len(nsAndRels)] - rel := rel( - nsAndRel.namespace, - strconv.Itoa(i), - nsAndRel.relation, - tf.UserNS.Name, - strconv.Itoa(i), - "", - ) - batch[i] = rel + v1rel := relationshipForBulkTesting(nsAndRel, i) + batch[i] = v1rel if tc.filter != nil { filter, err := datastore.RelationshipsFilterFromPublicFilter(tc.filter) require.NoError(err) - if !filter.Test(tuple.FromV1Relationship(rel)) { + if !filter.Test(tuple.FromV1Relationship(v1rel)) { continue } } - expectedRels.Add(tuple.MustV1RelString(rel)) + expectedRels.Add(tuple.MustV1RelString(v1rel)) } require.Equal(tc.expectedCount, expectedRels.Size()) diff --git a/internal/services/v1/relationships_test.go b/internal/services/v1/relationships_test.go index fe9443dead..928e60f977 100644 --- a/internal/services/v1/relationships_test.go +++ b/internal/services/v1/relationships_test.go @@ -6,6 +6,7 @@ import ( "fmt" "io" "maps" + "strconv" "strings" "testing" "time" @@ -672,6 +673,72 @@ func relWithCaveat(resType, resID, relation, subType, subID, subRel, caveatName } } +func mustRelWithCaveatAndContext(resType, resID, relation, subType, subID, subRel, caveatName string, context map[string]any) *v1.Relationship { + sctx, err := structpb.NewStruct(context) + if err != nil { + panic(err) + } + + return &v1.Relationship{ + Resource: &v1.ObjectReference{ + ObjectType: resType, + ObjectId: resID, + }, + Relation: relation, + Subject: &v1.SubjectReference{ + Object: &v1.ObjectReference{ + ObjectType: subType, + ObjectId: subID, + }, + OptionalRelation: subRel, + }, + OptionalCaveat: &v1.ContextualizedCaveat{ + CaveatName: caveatName, + Context: sctx, + }, + } +} + +func relationshipForBulkTesting(nsAndRel struct { + namespace string + relation string +}, i int, +) *v1.Relationship { + if nsAndRel.relation == "caveated_viewer" { + return mustRelWithCaveatAndContext( + nsAndRel.namespace, + strconv.Itoa(i), + nsAndRel.relation, + tf.UserNS.Name, + strconv.Itoa(i), + "", + "test", + map[string]any{"secret": strconv.Itoa(i)}, + ) + } + + if nsAndRel.relation == "expiring_viewer" { + return relWithExpiration( + nsAndRel.namespace, + strconv.Itoa(i), + nsAndRel.relation, + tf.UserNS.Name, + strconv.Itoa(i), + "", + time.Now().Add(time.Hour), + ) + } + + return rel( + nsAndRel.namespace, + strconv.Itoa(i), + nsAndRel.relation, + tf.UserNS.Name, + strconv.Itoa(i), + "", + ) +} + func TestInvalidWriteRelationship(t *testing.T) { testCases := []struct { name string