From 0f37897672862587468381724cc129a08d2a0d47 Mon Sep 17 00:00:00 2001 From: Joseph Schorr Date: Wed, 11 Dec 2024 16:21:58 -0500 Subject: [PATCH] Implement a combined builder pattern for relationship SQL construction This moves the behavior out of Spanner datastore and into a common lib where possible --- internal/datastore/common/relationships.go | 55 +--- internal/datastore/common/schema.go | 122 +++++++++ internal/datastore/common/schema_options.go | 219 +++++++++++++++ internal/datastore/common/sql.go | 254 ++++++++++-------- internal/datastore/common/sql_test.go | 94 +++---- internal/datastore/crdb/crdb.go | 51 ++-- internal/datastore/crdb/reader.go | 2 +- .../datastore/dsfortesting/dsfortesting.go | 46 ++-- internal/datastore/mysql/datastore.go | 40 +-- internal/datastore/mysql/reader.go | 2 +- internal/datastore/postgres/common/pgx.go | 15 +- internal/datastore/postgres/postgres.go | 40 +-- internal/datastore/postgres/reader.go | 2 +- internal/datastore/spanner/reader.go | 51 ++-- internal/datastore/spanner/spanner.go | 38 +-- 15 files changed, 685 insertions(+), 346 deletions(-) create mode 100644 internal/datastore/common/schema.go create mode 100644 internal/datastore/common/schema_options.go diff --git a/internal/datastore/common/relationships.go b/internal/datastore/common/relationships.go index 7860b18650..3b6f8a5156 100644 --- a/internal/datastore/common/relationships.go +++ b/internal/datastore/common/relationships.go @@ -17,26 +17,6 @@ import ( const errUnableToQueryRels = "unable to query relationships: %w" -// StaticValueOrAddColumnForSelect adds a column to the list of columns to select if the value -// is not static, otherwise it sets the value to the static value. -func StaticValueOrAddColumnForSelect(colsToSelect []any, queryInfo QueryInfo, colName string, field *string) []any { - if queryInfo.Schema.ColumnOptimization == ColumnOptimizationOptionNone { - // If column optimization is disabled, always add the column to the list of columns to select. - colsToSelect = append(colsToSelect, field) - return colsToSelect - } - - // If the value is static, set the field to it and return. - if found, ok := queryInfo.FilteringValues[colName]; ok && found.SingleValue != nil { - *field = *found.SingleValue - return colsToSelect - } - - // Otherwise, add the column to the list of columns to select, as the value is not static. - colsToSelect = append(colsToSelect, field) - return colsToSelect -} - // Querier is an interface for querying the database. type Querier[R Rows] interface { QueryFunc(ctx context.Context, f func(context.Context, R) error, sql string, args ...any) error @@ -60,10 +40,14 @@ type closeRows interface { } // QueryRelationships queries relationships for the given query and transaction. -func QueryRelationships[R Rows, C ~map[string]any](ctx context.Context, queryInfo QueryInfo, sqlStatement string, args []any, span trace.Span, tx Querier[R], withIntegrity bool) (datastore.RelationshipIterator, error) { +func QueryRelationships[R Rows, C ~map[string]any](ctx context.Context, builder RelationshipsQueryBuilder, span trace.Span, tx Querier[R]) (datastore.RelationshipIterator, error) { defer span.End() - colsToSelect := make([]any, 0, 8) + sqlString, args, err := builder.SelectSQL() + if err != nil { + return nil, fmt.Errorf(errUnableToQueryRels, err) + } + var resourceObjectType string var resourceObjectID string var resourceRelation string @@ -78,26 +62,9 @@ func QueryRelationships[R Rows, C ~map[string]any](ctx context.Context, queryInf var integrityHash []byte var timestamp time.Time - colsToSelect = StaticValueOrAddColumnForSelect(colsToSelect, queryInfo, queryInfo.Schema.ColNamespace, &resourceObjectType) - colsToSelect = StaticValueOrAddColumnForSelect(colsToSelect, queryInfo, queryInfo.Schema.ColObjectID, &resourceObjectID) - colsToSelect = StaticValueOrAddColumnForSelect(colsToSelect, queryInfo, queryInfo.Schema.ColRelation, &resourceRelation) - colsToSelect = StaticValueOrAddColumnForSelect(colsToSelect, queryInfo, queryInfo.Schema.ColUsersetNamespace, &subjectObjectType) - colsToSelect = StaticValueOrAddColumnForSelect(colsToSelect, queryInfo, queryInfo.Schema.ColUsersetObjectID, &subjectObjectID) - colsToSelect = StaticValueOrAddColumnForSelect(colsToSelect, queryInfo, queryInfo.Schema.ColUsersetRelation, &subjectRelation) - - if !queryInfo.SkipCaveats || queryInfo.Schema.ColumnOptimization == ColumnOptimizationOptionNone { - colsToSelect = append(colsToSelect, &caveatName, &caveatCtx) - } - - colsToSelect = append(colsToSelect, &expiration) - - if withIntegrity { - colsToSelect = append(colsToSelect, &integrityKeyID, &integrityHash, ×tamp) - } - - if len(colsToSelect) == 0 { - var unused int - colsToSelect = append(colsToSelect, &unused) + colsToSelect, err := ColumnsToSelect(builder, &resourceObjectType, &resourceObjectID, &resourceRelation, &subjectObjectType, &subjectObjectID, &subjectRelation, &caveatName, &caveatCtx, &expiration, &integrityKeyID, &integrityHash, ×tamp) + if err != nil { + return nil, fmt.Errorf(errUnableToQueryRels, err) } return func(yield func(tuple.Relationship, error) bool) { @@ -117,7 +84,7 @@ func QueryRelationships[R Rows, C ~map[string]any](ctx context.Context, queryInf } var caveat *corev1.ContextualizedCaveat - if !queryInfo.SkipCaveats || queryInfo.Schema.ColumnOptimization == ColumnOptimizationOptionNone { + if !builder.SkipCaveats || builder.Schema.ColumnOptimization == ColumnOptimizationOptionNone { if caveatName.Valid { var err error caveat, err = ContextualizedCaveatFrom(caveatName.String, caveatCtx) @@ -171,7 +138,7 @@ func QueryRelationships[R Rows, C ~map[string]any](ctx context.Context, queryInf span.AddEvent("Rels loaded", trace.WithAttributes(attribute.Int("relCount", relCount))) return nil - }, sqlStatement, args...) + }, sqlString, args...) if err != nil { if !yield(tuple.Relationship{}, err) { return diff --git a/internal/datastore/common/schema.go b/internal/datastore/common/schema.go new file mode 100644 index 0000000000..8f31f929a6 --- /dev/null +++ b/internal/datastore/common/schema.go @@ -0,0 +1,122 @@ +package common + +import ( + sq "github.com/Masterminds/squirrel" + + "github.com/authzed/spicedb/pkg/spiceerrors" +) + +// SchemaInformation holds the schema information from the SQL datastore implementation. +// +//go:generate go run github.com/ecordell/optgen -output schema_options.go . SchemaInformation +type SchemaInformation struct { + RelationshipTableName string `debugmap:"visible"` + + ColNamespace string `debugmap:"visible"` + ColObjectID string `debugmap:"visible"` + ColRelation string `debugmap:"visible"` + ColUsersetNamespace string `debugmap:"visible"` + ColUsersetObjectID string `debugmap:"visible"` + ColUsersetRelation string `debugmap:"visible"` + ColCaveatName string `debugmap:"visible"` + ColCaveatContext string `debugmap:"visible"` + ColExpiration string `debugmap:"visible"` + + ColIntegrityKeyID string `debugmap:"visible"` + ColIntegrityHash string `debugmap:"visible"` + ColIntegrityTimestamp string `debugmap:"visible"` + + // PaginationFilterType is the type of pagination filter to use for this schema. + PaginationFilterType PaginationFilterType `debugmap:"visible"` + + // PlaceholderFormat is the format of placeholders to use for this schema. + PlaceholderFormat sq.PlaceholderFormat `debugmap:"visible"` + + // NowFunction is the function to use to get the current time in the datastore. + NowFunction string `debugmap:"visible"` + + // ColumnOptimization is the optimization to use for columns in the schema, if any. + ColumnOptimization ColumnOptimizationOption `debugmap:"visible"` + + // WithIntegrityColumns is a flag to indicate if the schema has integrity columns. + WithIntegrityColumns bool `debugmap:"visible"` +} + +func (si SchemaInformation) debugValidate() { + spiceerrors.DebugAssert(func() bool { + si.mustValidate() + return true + }, "SchemaInformation failed to validate") +} + +func (si SchemaInformation) mustValidate() { + if si.RelationshipTableName == "" { + panic("RelationshipTableName is required") + } + + if si.ColNamespace == "" { + panic("ColNamespace is required") + } + + if si.ColObjectID == "" { + panic("ColObjectID is required") + } + + if si.ColRelation == "" { + panic("ColRelation is required") + } + + if si.ColUsersetNamespace == "" { + panic("ColUsersetNamespace is required") + } + + if si.ColUsersetObjectID == "" { + panic("ColUsersetObjectID is required") + } + + if si.ColUsersetRelation == "" { + panic("ColUsersetRelation is required") + } + + if si.ColCaveatName == "" { + panic("ColCaveatName is required") + } + + if si.ColCaveatContext == "" { + panic("ColCaveatContext is required") + } + + if si.ColExpiration == "" { + panic("ColExpiration is required") + } + + if si.WithIntegrityColumns { + if si.ColIntegrityKeyID == "" { + panic("ColIntegrityKeyID is required") + } + + if si.ColIntegrityHash == "" { + panic("ColIntegrityHash is required") + } + + if si.ColIntegrityTimestamp == "" { + panic("ColIntegrityTimestamp is required") + } + } + + if si.NowFunction == "" { + panic("NowFunction is required") + } + + if si.ColumnOptimization == ColumnOptimizationOptionUnknown { + panic("ColumnOptimization is required") + } + + if si.PaginationFilterType == PaginationFilterTypeUnknown { + panic("PaginationFilterType is required") + } + + if si.PlaceholderFormat == nil { + panic("PlaceholderFormat is required") + } +} diff --git a/internal/datastore/common/schema_options.go b/internal/datastore/common/schema_options.go new file mode 100644 index 0000000000..3aed7f64e8 --- /dev/null +++ b/internal/datastore/common/schema_options.go @@ -0,0 +1,219 @@ +// Code generated by github.com/ecordell/optgen. DO NOT EDIT. +package common + +import ( + squirrel "github.com/Masterminds/squirrel" + defaults "github.com/creasty/defaults" + helpers "github.com/ecordell/optgen/helpers" +) + +type SchemaInformationOption func(s *SchemaInformation) + +// NewSchemaInformationWithOptions creates a new SchemaInformation with the passed in options set +func NewSchemaInformationWithOptions(opts ...SchemaInformationOption) *SchemaInformation { + s := &SchemaInformation{} + for _, o := range opts { + o(s) + } + return s +} + +// NewSchemaInformationWithOptionsAndDefaults creates a new SchemaInformation with the passed in options set starting from the defaults +func NewSchemaInformationWithOptionsAndDefaults(opts ...SchemaInformationOption) *SchemaInformation { + s := &SchemaInformation{} + defaults.MustSet(s) + for _, o := range opts { + o(s) + } + return s +} + +// ToOption returns a new SchemaInformationOption that sets the values from the passed in SchemaInformation +func (s *SchemaInformation) ToOption() SchemaInformationOption { + return func(to *SchemaInformation) { + to.RelationshipTableName = s.RelationshipTableName + to.ColNamespace = s.ColNamespace + to.ColObjectID = s.ColObjectID + to.ColRelation = s.ColRelation + to.ColUsersetNamespace = s.ColUsersetNamespace + to.ColUsersetObjectID = s.ColUsersetObjectID + to.ColUsersetRelation = s.ColUsersetRelation + to.ColCaveatName = s.ColCaveatName + to.ColCaveatContext = s.ColCaveatContext + to.ColExpiration = s.ColExpiration + to.ColIntegrityKeyID = s.ColIntegrityKeyID + to.ColIntegrityHash = s.ColIntegrityHash + to.ColIntegrityTimestamp = s.ColIntegrityTimestamp + to.PaginationFilterType = s.PaginationFilterType + to.PlaceholderFormat = s.PlaceholderFormat + to.NowFunction = s.NowFunction + to.ColumnOptimization = s.ColumnOptimization + to.WithIntegrityColumns = s.WithIntegrityColumns + } +} + +// DebugMap returns a map form of SchemaInformation for debugging +func (s SchemaInformation) DebugMap() map[string]any { + debugMap := map[string]any{} + debugMap["RelationshipTableName"] = helpers.DebugValue(s.RelationshipTableName, false) + debugMap["ColNamespace"] = helpers.DebugValue(s.ColNamespace, false) + debugMap["ColObjectID"] = helpers.DebugValue(s.ColObjectID, false) + debugMap["ColRelation"] = helpers.DebugValue(s.ColRelation, false) + debugMap["ColUsersetNamespace"] = helpers.DebugValue(s.ColUsersetNamespace, false) + debugMap["ColUsersetObjectID"] = helpers.DebugValue(s.ColUsersetObjectID, false) + debugMap["ColUsersetRelation"] = helpers.DebugValue(s.ColUsersetRelation, false) + debugMap["ColCaveatName"] = helpers.DebugValue(s.ColCaveatName, false) + debugMap["ColCaveatContext"] = helpers.DebugValue(s.ColCaveatContext, false) + debugMap["ColExpiration"] = helpers.DebugValue(s.ColExpiration, false) + debugMap["ColIntegrityKeyID"] = helpers.DebugValue(s.ColIntegrityKeyID, false) + debugMap["ColIntegrityHash"] = helpers.DebugValue(s.ColIntegrityHash, false) + debugMap["ColIntegrityTimestamp"] = helpers.DebugValue(s.ColIntegrityTimestamp, false) + debugMap["PaginationFilterType"] = helpers.DebugValue(s.PaginationFilterType, false) + debugMap["PlaceholderFormat"] = helpers.DebugValue(s.PlaceholderFormat, false) + debugMap["NowFunction"] = helpers.DebugValue(s.NowFunction, false) + debugMap["ColumnOptimization"] = helpers.DebugValue(s.ColumnOptimization, false) + debugMap["WithIntegrityColumns"] = helpers.DebugValue(s.WithIntegrityColumns, false) + return debugMap +} + +// SchemaInformationWithOptions configures an existing SchemaInformation with the passed in options set +func SchemaInformationWithOptions(s *SchemaInformation, opts ...SchemaInformationOption) *SchemaInformation { + for _, o := range opts { + o(s) + } + return s +} + +// WithOptions configures the receiver SchemaInformation with the passed in options set +func (s *SchemaInformation) WithOptions(opts ...SchemaInformationOption) *SchemaInformation { + for _, o := range opts { + o(s) + } + return s +} + +// WithRelationshipTableName returns an option that can set RelationshipTableName on a SchemaInformation +func WithRelationshipTableName(relationshipTableName string) SchemaInformationOption { + return func(s *SchemaInformation) { + s.RelationshipTableName = relationshipTableName + } +} + +// WithColNamespace returns an option that can set ColNamespace on a SchemaInformation +func WithColNamespace(colNamespace string) SchemaInformationOption { + return func(s *SchemaInformation) { + s.ColNamespace = colNamespace + } +} + +// WithColObjectID returns an option that can set ColObjectID on a SchemaInformation +func WithColObjectID(colObjectID string) SchemaInformationOption { + return func(s *SchemaInformation) { + s.ColObjectID = colObjectID + } +} + +// WithColRelation returns an option that can set ColRelation on a SchemaInformation +func WithColRelation(colRelation string) SchemaInformationOption { + return func(s *SchemaInformation) { + s.ColRelation = colRelation + } +} + +// WithColUsersetNamespace returns an option that can set ColUsersetNamespace on a SchemaInformation +func WithColUsersetNamespace(colUsersetNamespace string) SchemaInformationOption { + return func(s *SchemaInformation) { + s.ColUsersetNamespace = colUsersetNamespace + } +} + +// WithColUsersetObjectID returns an option that can set ColUsersetObjectID on a SchemaInformation +func WithColUsersetObjectID(colUsersetObjectID string) SchemaInformationOption { + return func(s *SchemaInformation) { + s.ColUsersetObjectID = colUsersetObjectID + } +} + +// WithColUsersetRelation returns an option that can set ColUsersetRelation on a SchemaInformation +func WithColUsersetRelation(colUsersetRelation string) SchemaInformationOption { + return func(s *SchemaInformation) { + s.ColUsersetRelation = colUsersetRelation + } +} + +// WithColCaveatName returns an option that can set ColCaveatName on a SchemaInformation +func WithColCaveatName(colCaveatName string) SchemaInformationOption { + return func(s *SchemaInformation) { + s.ColCaveatName = colCaveatName + } +} + +// WithColCaveatContext returns an option that can set ColCaveatContext on a SchemaInformation +func WithColCaveatContext(colCaveatContext string) SchemaInformationOption { + return func(s *SchemaInformation) { + s.ColCaveatContext = colCaveatContext + } +} + +// WithColExpiration returns an option that can set ColExpiration on a SchemaInformation +func WithColExpiration(colExpiration string) SchemaInformationOption { + return func(s *SchemaInformation) { + s.ColExpiration = colExpiration + } +} + +// WithColIntegrityKeyID returns an option that can set ColIntegrityKeyID on a SchemaInformation +func WithColIntegrityKeyID(colIntegrityKeyID string) SchemaInformationOption { + return func(s *SchemaInformation) { + s.ColIntegrityKeyID = colIntegrityKeyID + } +} + +// WithColIntegrityHash returns an option that can set ColIntegrityHash on a SchemaInformation +func WithColIntegrityHash(colIntegrityHash string) SchemaInformationOption { + return func(s *SchemaInformation) { + s.ColIntegrityHash = colIntegrityHash + } +} + +// WithColIntegrityTimestamp returns an option that can set ColIntegrityTimestamp on a SchemaInformation +func WithColIntegrityTimestamp(colIntegrityTimestamp string) SchemaInformationOption { + return func(s *SchemaInformation) { + s.ColIntegrityTimestamp = colIntegrityTimestamp + } +} + +// WithPaginationFilterType returns an option that can set PaginationFilterType on a SchemaInformation +func WithPaginationFilterType(paginationFilterType PaginationFilterType) SchemaInformationOption { + return func(s *SchemaInformation) { + s.PaginationFilterType = paginationFilterType + } +} + +// WithPlaceholderFormat returns an option that can set PlaceholderFormat on a SchemaInformation +func WithPlaceholderFormat(placeholderFormat squirrel.PlaceholderFormat) SchemaInformationOption { + return func(s *SchemaInformation) { + s.PlaceholderFormat = placeholderFormat + } +} + +// WithNowFunction returns an option that can set NowFunction on a SchemaInformation +func WithNowFunction(nowFunction string) SchemaInformationOption { + return func(s *SchemaInformation) { + s.NowFunction = nowFunction + } +} + +// WithColumnOptimization returns an option that can set ColumnOptimization on a SchemaInformation +func WithColumnOptimization(columnOptimization ColumnOptimizationOption) SchemaInformationOption { + return func(s *SchemaInformation) { + s.ColumnOptimization = columnOptimization + } +} + +// WithWithIntegrityColumns returns an option that can set WithIntegrityColumns on a SchemaInformation +func WithWithIntegrityColumns(withIntegrityColumns bool) SchemaInformationOption { + return func(s *SchemaInformation) { + s.WithIntegrityColumns = withIntegrityColumns + } +} diff --git a/internal/datastore/common/sql.go b/internal/datastore/common/sql.go index 83fc2cbd8e..a0a92f65fd 100644 --- a/internal/datastore/common/sql.go +++ b/internal/datastore/common/sql.go @@ -2,8 +2,10 @@ package common import ( "context" + "maps" "math" "strings" + "time" sq "github.com/Masterminds/squirrel" v1 "github.com/authzed/authzed-go/proto/authzed/api/v1" @@ -51,10 +53,12 @@ var ( type PaginationFilterType uint8 const ( + PaginationFilterTypeUnknown PaginationFilterType = iota + // TupleComparison uses a comparison with a compound key, // e.g. (namespace, object_id, relation) > ('ns', '123', 'viewer') // which is not compatible with all datastores. - TupleComparison PaginationFilterType = iota + TupleComparison // ExpandedLogicComparison comparison uses a nested tree of ANDs and ORs to properly // filter out already received relationships. Useful for databases that do not support @@ -66,80 +70,15 @@ const ( type ColumnOptimizationOption int const ( + ColumnOptimizationOptionUnknown ColumnOptimizationOption = iota + // ColumnOptimizationOptionNone is the default option, which does not optimize the static columns. - ColumnOptimizationOptionNone ColumnOptimizationOption = iota + ColumnOptimizationOptionNone // ColumnOptimizationOptionStaticValue is an option that optimizes the column for a static value. ColumnOptimizationOptionStaticValues ) -// SchemaInformation holds the schema information from the SQL datastore implementation. -type SchemaInformation struct { - RelationshipTableName string - ColNamespace string - ColObjectID string - ColRelation string - ColUsersetNamespace string - ColUsersetObjectID string - ColUsersetRelation string - ColCaveatName string - ColCaveatContext string - ColExpiration string - - // PaginationFilterType is the type of pagination filter to use for this schema. - PaginationFilterType PaginationFilterType - - // PlaceholderFormat is the format of placeholders to use for this schema. - PlaceholderFormat sq.PlaceholderFormat - - // NowFunction is the function to use to get the current time in the datastore. - NowFunction string - - // ColumnOptimization is the optimization to use for columns in the schema, if any. - ColumnOptimization ColumnOptimizationOption - - // ExtaFields are additional fields that are not part of the core schema, but are - // requested by the caller for this query. - ExtraFields []string -} - -// NewSchemaInformation creates a new SchemaInformation object for a query. -func NewSchemaInformation( - relationshipTableName, - colNamespace, - colObjectID, - colRelation, - colUsersetNamespace, - colUsersetObjectID, - colUsersetRelation, - colCaveatName string, - colCaveatContext string, - colExpiration string, - paginationFilterType PaginationFilterType, - placeholderFormat sq.PlaceholderFormat, - nowFunction string, - columnOptionizationOption ColumnOptimizationOption, - extraFields ...string, -) SchemaInformation { - return SchemaInformation{ - relationshipTableName, - colNamespace, - colObjectID, - colRelation, - colUsersetNamespace, - colUsersetObjectID, - colUsersetRelation, - colCaveatName, - colCaveatContext, - colExpiration, - paginationFilterType, - placeholderFormat, - nowFunction, - columnOptionizationOption, - extraFields, - } -} - type ColumnTracker struct { SingleValue *string } @@ -160,6 +99,8 @@ type SchemaQueryFilterer struct { // relationships. This method will automatically filter the columns retrieved from the database, only // selecting the columns that are not already specified with a single static value in the query. func NewSchemaQueryFiltererForRelationshipsSelect(schema SchemaInformation, filterMaximumIDCount uint16, extraFields ...string) SchemaQueryFilterer { + schema.debugValidate() + if filterMaximumIDCount == 0 { filterMaximumIDCount = 100 log.Warn().Msg("SchemaQueryFilterer: filterMaximumIDCount not set, defaulting to 100") @@ -186,6 +127,8 @@ func NewSchemaQueryFiltererForRelationshipsSelect(schema SchemaInformation, filt // relationships, with a custom starting query. Unlike NewSchemaQueryFiltererForRelationshipsSelect, // this method will not auto-filter the columns retrieved from the database. func NewSchemaQueryFiltererWithStartingQuery(schema SchemaInformation, startingQuery sq.SelectBuilder, filterMaximumIDCount uint16) SchemaQueryFilterer { + schema.debugValidate() + if filterMaximumIDCount == 0 { filterMaximumIDCount = 100 log.Warn().Msg("SchemaQueryFilterer: filterMaximumIDCount not set, defaulting to 100") @@ -665,13 +608,16 @@ func (sqf SchemaQueryFilterer) limit(limit uint64) SchemaQueryFilterer { return sqf } -// QueryExecutor is a tuple query runner shared by SQL implementations of the datastore. -type QueryExecutor struct { +// QueryRelationshipsExecutor is a relationships query runner shared by SQL implementations of the datastore. +type QueryRelationshipsExecutor struct { Executor ExecuteReadRelsQueryFunc } +// ExecuteReadRelsQueryFunc is a function that can be used to execute a single rendered SQL query. +type ExecuteReadRelsQueryFunc func(ctx context.Context, builder RelationshipsQueryBuilder) (datastore.RelationshipIterator, error) + // ExecuteQuery executes the query. -func (tqs QueryExecutor) ExecuteQuery( +func (exc QueryRelationshipsExecutor) ExecuteQuery( ctx context.Context, query SchemaQueryFilterer, opts ...options.QueryOptionsOption, @@ -682,8 +628,10 @@ func (tqs QueryExecutor) ExecuteQuery( queryOpts := options.NewQueryOptionsWithOptions(opts...) + // Add sort order. query = query.TupleOrder(queryOpts.Sort) + // Add cursor. if queryOpts.After != nil { if queryOpts.Sort == options.Unsorted { return nil, datastore.ErrCursorsWithoutSorting @@ -692,6 +640,7 @@ func (tqs QueryExecutor) ExecuteQuery( query = query.After(queryOpts.After, queryOpts.Sort) } + // Add limit. var limit uint64 // NOTE: we use a uint here because it lines up with the // assignments in this function, but we set it to MaxInt64 @@ -706,70 +655,149 @@ func (tqs QueryExecutor) ExecuteQuery( query = query.limit(limit) } - toExecute := query - - // Set the column names to select. - columnNamesToSelect := make([]string, 0, 8+len(query.extraFields)) + // Add FROM clause. + from := query.schema.RelationshipTableName + if query.fromSuffix != "" { + from += " " + query.fromSuffix + } - columnNamesToSelect = checkColumn(columnNamesToSelect, query.schema.ColumnOptimization, query.filteringColumnTracker, query.schema.ColNamespace) - columnNamesToSelect = checkColumn(columnNamesToSelect, query.schema.ColumnOptimization, query.filteringColumnTracker, query.schema.ColObjectID) - columnNamesToSelect = checkColumn(columnNamesToSelect, query.schema.ColumnOptimization, query.filteringColumnTracker, query.schema.ColRelation) - columnNamesToSelect = checkColumn(columnNamesToSelect, query.schema.ColumnOptimization, query.filteringColumnTracker, query.schema.ColUsersetNamespace) - columnNamesToSelect = checkColumn(columnNamesToSelect, query.schema.ColumnOptimization, query.filteringColumnTracker, query.schema.ColUsersetObjectID) - columnNamesToSelect = checkColumn(columnNamesToSelect, query.schema.ColumnOptimization, query.filteringColumnTracker, query.schema.ColUsersetRelation) + query.queryBuilder = query.queryBuilder.From(from) - if !queryOpts.SkipCaveats || query.schema.ColumnOptimization == ColumnOptimizationOptionNone { - columnNamesToSelect = append(columnNamesToSelect, query.schema.ColCaveatName, query.schema.ColCaveatContext) + builder := RelationshipsQueryBuilder{ + Schema: query.schema, + SkipCaveats: queryOpts.SkipCaveats, + filteringValues: query.filteringColumnTracker, + baseQueryBuilder: query, } - columnNamesToSelect = append(columnNamesToSelect, query.schema.ColExpiration) + return exc.Executor(ctx, builder) +} + +// RelationshipsQueryBuilder is a builder for producing the SQL and arguments necessary for reading +// relationships. +type RelationshipsQueryBuilder struct { + Schema SchemaInformation + SkipCaveats bool - selectingNoColumns := false - columnNamesToSelect = append(columnNamesToSelect, query.schema.ExtraFields...) - if len(columnNamesToSelect) == 0 { - columnNamesToSelect = append(columnNamesToSelect, "1") - selectingNoColumns = true + filteringValues map[string]ColumnTracker + baseQueryBuilder SchemaQueryFilterer +} + +// SelectSQL returns the SQL and arguments necessary for reading relationships. +func (b RelationshipsQueryBuilder) SelectSQL() (string, []any, error) { + // Set the column names to select. + columnCount := 9 + if b.Schema.WithIntegrityColumns { + columnCount += 3 } + columnNamesToSelect := make([]string, 0, columnCount) - toExecute.queryBuilder = toExecute.queryBuilder.Columns(columnNamesToSelect...) + columnNamesToSelect = b.checkColumn(columnNamesToSelect, b.Schema.ColNamespace) + columnNamesToSelect = b.checkColumn(columnNamesToSelect, b.Schema.ColObjectID) + columnNamesToSelect = b.checkColumn(columnNamesToSelect, b.Schema.ColRelation) + columnNamesToSelect = b.checkColumn(columnNamesToSelect, b.Schema.ColUsersetNamespace) + columnNamesToSelect = b.checkColumn(columnNamesToSelect, b.Schema.ColUsersetObjectID) + columnNamesToSelect = b.checkColumn(columnNamesToSelect, b.Schema.ColUsersetRelation) - from := query.schema.RelationshipTableName - if query.fromSuffix != "" { - from += " " + query.fromSuffix + if !b.SkipCaveats || b.Schema.ColumnOptimization == ColumnOptimizationOptionNone { + columnNamesToSelect = append(columnNamesToSelect, b.Schema.ColCaveatName, b.Schema.ColCaveatContext) } - toExecute.queryBuilder = toExecute.queryBuilder.From(from) + columnNamesToSelect = append(columnNamesToSelect, b.Schema.ColExpiration) - sql, args, err := toExecute.queryBuilder.ToSql() - if err != nil { - return nil, err + if b.Schema.WithIntegrityColumns { + columnNamesToSelect = append(columnNamesToSelect, b.Schema.ColIntegrityKeyID, b.Schema.ColIntegrityHash, b.Schema.ColIntegrityTimestamp) + } + + if len(columnNamesToSelect) == 0 { + columnNamesToSelect = append(columnNamesToSelect, "1") } - return tqs.Executor(ctx, QueryInfo{query.schema, query.filteringColumnTracker, queryOpts.SkipCaveats, selectingNoColumns}, sql, args) + sqlBuilder := b.baseQueryBuilder.queryBuilder + sqlBuilder = sqlBuilder.Columns(columnNamesToSelect...) + + return sqlBuilder.ToSql() } -func checkColumn(columns []string, option ColumnOptimizationOption, tracker map[string]ColumnTracker, colName string) []string { - if option == ColumnOptimizationOptionNone { +// FilteringValuesForTesting returns the filtering values. For test use only. +func (b RelationshipsQueryBuilder) FilteringValuesForTesting() map[string]ColumnTracker { + return maps.Clone(b.filteringValues) +} + +func (b RelationshipsQueryBuilder) checkColumn(columns []string, colName string) []string { + if b.Schema.ColumnOptimization == ColumnOptimizationOptionNone { return append(columns, colName) } - if r, ok := tracker[colName]; !ok || r.SingleValue == nil { + if r, ok := b.filteringValues[colName]; !ok || r.SingleValue == nil { return append(columns, colName) } return columns } -// QueryInfo holds the schema information and filtering values for a query. -type QueryInfo struct { - Schema SchemaInformation - FilteringValues map[string]ColumnTracker - SkipCaveats bool - SelectingNoColumns bool +func (b RelationshipsQueryBuilder) staticValueOrAddColumnForSelect(colsToSelect []any, colName string, field *string) []any { + if b.Schema.ColumnOptimization == ColumnOptimizationOptionNone { + // If column optimization is disabled, always add the column to the list of columns to select. + colsToSelect = append(colsToSelect, field) + return colsToSelect + } + + // If the value is static, set the field to it and return. + if found, ok := b.filteringValues[colName]; ok && found.SingleValue != nil { + *field = *found.SingleValue + return colsToSelect + } + + // Otherwise, add the column to the list of columns to select, as the value is not static. + colsToSelect = append(colsToSelect, field) + return colsToSelect } -// ExecuteReadRelsQueryFunc is a function that can be used to execute a single rendered SQL query. -type ExecuteReadRelsQueryFunc func(ctx context.Context, queryInfo QueryInfo, sql string, args []any) (datastore.RelationshipIterator, error) +// ColumnsToSelect returns the columns to select for a given query. The columns provided are +// the references to the slots in which the values for each relationship will be placed. +func ColumnsToSelect[CN any, CC any, EC any]( + b RelationshipsQueryBuilder, + resourceObjectType *string, + resourceObjectID *string, + resourceRelation *string, + subjectObjectType *string, + subjectObjectID *string, + subjectRelation *string, + caveatName *CN, + caveatCtx *CC, + expiration EC, + + integrityKeyID *string, + integrityHash *[]byte, + timestamp *time.Time, +) ([]any, error) { + columnCount := 9 + if b.Schema.WithIntegrityColumns { + columnCount += 3 + } + colsToSelect := make([]any, 0, columnCount) + + colsToSelect = b.staticValueOrAddColumnForSelect(colsToSelect, b.Schema.ColNamespace, resourceObjectType) + colsToSelect = b.staticValueOrAddColumnForSelect(colsToSelect, b.Schema.ColObjectID, resourceObjectID) + colsToSelect = b.staticValueOrAddColumnForSelect(colsToSelect, b.Schema.ColRelation, resourceRelation) + colsToSelect = b.staticValueOrAddColumnForSelect(colsToSelect, b.Schema.ColUsersetNamespace, subjectObjectType) + colsToSelect = b.staticValueOrAddColumnForSelect(colsToSelect, b.Schema.ColUsersetObjectID, subjectObjectID) + colsToSelect = b.staticValueOrAddColumnForSelect(colsToSelect, b.Schema.ColUsersetRelation, subjectRelation) + + if !b.SkipCaveats || b.Schema.ColumnOptimization == ColumnOptimizationOptionNone { + colsToSelect = append(colsToSelect, caveatName, caveatCtx) + } -// TxCleanupFunc is a function that should be executed when the caller of -// TransactionFactory is done with the transaction. -type TxCleanupFunc func(context.Context) + colsToSelect = append(colsToSelect, expiration) + + if b.Schema.WithIntegrityColumns { + colsToSelect = append(colsToSelect, integrityKeyID, integrityHash, timestamp) + } + + if len(colsToSelect) == 0 { + var unused int + colsToSelect = append(colsToSelect, &unused) + } + + return colsToSelect, nil +} diff --git a/internal/datastore/common/sql_test.go b/internal/datastore/common/sql_test.go index 35cd4cc833..6b2e2c40a4 100644 --- a/internal/datastore/common/sql_test.go +++ b/internal/datastore/common/sql_test.go @@ -558,23 +558,23 @@ func TestSchemaQueryFilterer(t *testing.T) { for _, test := range tests { test := test t.Run(test.name, func(t *testing.T) { - schema := NewSchemaInformation( - "relationtuples", - "ns", - "object_id", - "relation", - "subject_ns", - "subject_object_id", - "subject_relation", - "caveat", - "caveat_context", - "expiration", - TupleComparison, - sq.Question, - "NOW", - ColumnOptimizationOptionStaticValues, + schema := NewSchemaInformationWithOptions( + WithRelationshipTableName("relationtuples"), + WithColNamespace("ns"), + WithColObjectID("object_id"), + WithColRelation("relation"), + WithColUsersetNamespace("subject_ns"), + WithColUsersetObjectID("subject_object_id"), + WithColUsersetRelation("subject_relation"), + WithColCaveatName("caveat"), + WithColCaveatContext("caveat_context"), + WithColExpiration("expiration"), + WithPlaceholderFormat(sq.Question), + WithPaginationFilterType(TupleComparison), + WithColumnOptimization(ColumnOptimizationOptionStaticValues), + WithNowFunction("NOW"), ) - filterer := NewSchemaQueryFiltererForRelationshipsSelect(schema, 100) + filterer := NewSchemaQueryFiltererForRelationshipsSelect(*schema, 100) ran := test.run(filterer) foundStaticColumns := []string{} @@ -598,13 +598,12 @@ func TestSchemaQueryFilterer(t *testing.T) { func TestExecuteQuery(t *testing.T) { tcs := []struct { - name string - run func(filterer SchemaQueryFilterer) SchemaQueryFilterer - options []options.QueryOptionsOption - expectedSQL string - expectedArgs []any - expectedSelectingNoColumns bool - expectedSkipCaveats bool + name string + run func(filterer SchemaQueryFilterer) SchemaQueryFilterer + options []options.QueryOptionsOption + expectedSQL string + expectedArgs []any + expectedSkipCaveats bool }{ { name: "filter by static resource type", @@ -695,10 +694,9 @@ func TestExecuteQuery(t *testing.T) { options: []options.QueryOptionsOption{ options.WithSkipCaveats(true), }, - expectedSkipCaveats: true, - expectedSelectingNoColumns: false, - expectedSQL: "SELECT expiration FROM relationtuples WHERE (expiration IS NULL OR expiration > NOW()) AND ns = ? AND object_id = ? AND relation = ? AND subject_ns = ? AND subject_object_id = ? AND subject_relation = ?", - expectedArgs: []any{"sometype", "someobj", "somerel", "subns", "subid", "subrel"}, + expectedSkipCaveats: true, + expectedSQL: "SELECT expiration FROM relationtuples WHERE (expiration IS NULL OR expiration > NOW()) AND ns = ? AND object_id = ? AND relation = ? AND subject_ns = ? AND subject_object_id = ? AND subject_relation = ?", + expectedArgs: []any{"sometype", "someobj", "somerel", "subns", "subid", "subrel"}, }, { name: "filter by static everything (except one field) without caveats", @@ -820,33 +818,35 @@ func TestExecuteQuery(t *testing.T) { for _, tc := range tcs { t.Run(tc.name, func(t *testing.T) { - schema := NewSchemaInformation( - "relationtuples", - "ns", - "object_id", - "relation", - "subject_ns", - "subject_object_id", - "subject_relation", - "caveat", - "caveat_context", - "expiration", - TupleComparison, - sq.Question, - "NOW", - ColumnOptimizationOptionStaticValues, + schema := NewSchemaInformationWithOptions( + WithRelationshipTableName("relationtuples"), + WithColNamespace("ns"), + WithColObjectID("object_id"), + WithColRelation("relation"), + WithColUsersetNamespace("subject_ns"), + WithColUsersetObjectID("subject_object_id"), + WithColUsersetRelation("subject_relation"), + WithColCaveatName("caveat"), + WithColCaveatContext("caveat_context"), + WithColExpiration("expiration"), + WithPlaceholderFormat(sq.Question), + WithPaginationFilterType(TupleComparison), + WithColumnOptimization(ColumnOptimizationOptionStaticValues), + WithNowFunction("NOW"), ) - filterer := NewSchemaQueryFiltererForRelationshipsSelect(schema, 100) + filterer := NewSchemaQueryFiltererForRelationshipsSelect(*schema, 100) ran := tc.run(filterer) var wasRun bool - fake := QueryExecutor{ - Executor: func(ctx context.Context, queryInfo QueryInfo, sql string, args []any) (datastore.RelationshipIterator, error) { + fake := QueryRelationshipsExecutor{ + Executor: func(ctx context.Context, builder RelationshipsQueryBuilder) (datastore.RelationshipIterator, error) { + sql, args, err := builder.SelectSQL() + require.NoError(t, err) + wasRun = true require.Equal(t, tc.expectedSQL, sql) require.Equal(t, tc.expectedArgs, args) - require.Equal(t, tc.expectedSelectingNoColumns, queryInfo.SelectingNoColumns) - require.Equal(t, tc.expectedSkipCaveats, queryInfo.SkipCaveats) + require.Equal(t, tc.expectedSkipCaveats, builder.SkipCaveats) return nil, nil }, } diff --git a/internal/datastore/crdb/crdb.go b/internal/datastore/crdb/crdb.go index 9395800175..ced1386356 100644 --- a/internal/datastore/crdb/crdb.go +++ b/internal/datastore/crdb/crdb.go @@ -199,33 +199,30 @@ func newCRDBDatastore(ctx context.Context, url string, options ...Option) (datas return nil, fmt.Errorf("invalid head migration found for cockroach: %w", err) } - var extraFields []string relTableName := tableTuple if config.withIntegrity { relTableName = tableTupleWithIntegrity - extraFields = []string{ - colIntegrityKeyID, - colIntegrityHash, - colTimestamp, - } } - schema := common.NewSchemaInformation( - relTableName, - colNamespace, - colObjectID, - colRelation, - colUsersetNamespace, - colUsersetObjectID, - colUsersetRelation, - colCaveatContextName, - colCaveatContext, - colExpiration, - common.ExpandedLogicComparison, - sq.Dollar, - "NOW", - config.columnOptimizationOption, - extraFields..., + schema := common.NewSchemaInformationWithOptions( + common.WithRelationshipTableName(relTableName), + common.WithColNamespace(colNamespace), + common.WithColObjectID(colObjectID), + common.WithColRelation(colRelation), + common.WithColUsersetNamespace(colUsersetNamespace), + common.WithColUsersetObjectID(colUsersetObjectID), + common.WithColUsersetRelation(colUsersetRelation), + common.WithColCaveatName(colCaveatContextName), + common.WithColCaveatContext(colCaveatContext), + common.WithColExpiration(colExpiration), + common.WithColIntegrityKeyID(colIntegrityKeyID), + common.WithColIntegrityHash(colIntegrityHash), + common.WithColIntegrityTimestamp(colTimestamp), + common.WithPaginationFilterType(common.ExpandedLogicComparison), + common.WithPlaceholderFormat(sq.Dollar), + common.WithNowFunction("NOW"), + common.WithColumnOptimization(config.columnOptimizationOption), + common.WithWithIntegrityColumns(config.withIntegrity), ) ds := &crdbDatastore{ @@ -249,7 +246,7 @@ func newCRDBDatastore(ctx context.Context, url string, options ...Option) (datas filterMaximumIDCount: config.filterMaximumIDCount, supportsIntegrity: config.withIntegrity, gcWindow: config.gcWindow, - schema: schema, + schema: *schema, } ds.RemoteClockRevisions.SetNowFunc(ds.headRevisionInternal) @@ -349,8 +346,8 @@ type crdbDatastore struct { } func (cds *crdbDatastore) SnapshotReader(rev datastore.Revision) datastore.Reader { - executor := common.QueryExecutor{ - Executor: pgxcommon.NewPGXExecutorWithIntegrityOption(cds.readPool, cds.supportsIntegrity), + executor := common.QueryRelationshipsExecutor{ + Executor: pgxcommon.NewPGXQueryRelationshipsExecutor(cds.readPool), } withAsOfSystemTime := func(query sq.SelectBuilder, tableName string) sq.SelectBuilder { @@ -375,8 +372,8 @@ func (cds *crdbDatastore) ReadWriteTx( err := cds.writePool.BeginFunc(ctx, func(tx pgx.Tx) error { querier := pgxcommon.QuerierFuncsFor(tx) - executor := common.QueryExecutor{ - Executor: pgxcommon.NewPGXExecutorWithIntegrityOption(querier, cds.supportsIntegrity), + executor := common.QueryRelationshipsExecutor{ + Executor: pgxcommon.NewPGXQueryRelationshipsExecutor(querier), } // Write metadata onto the transaction. diff --git a/internal/datastore/crdb/reader.go b/internal/datastore/crdb/reader.go index ce9a950b05..e252a16c92 100644 --- a/internal/datastore/crdb/reader.go +++ b/internal/datastore/crdb/reader.go @@ -39,7 +39,7 @@ var ( type crdbReader struct { query pgxcommon.DBFuncQuerier - executor common.QueryExecutor + executor common.QueryRelationshipsExecutor keyer overlapKeyer overlapKeySet keySet fromWithAsOfSystemTime func(query sq.SelectBuilder, tableName string) sq.SelectBuilder diff --git a/internal/datastore/dsfortesting/dsfortesting.go b/internal/datastore/dsfortesting/dsfortesting.go index 5e83d9ce44..b6fc7f5d9c 100644 --- a/internal/datastore/dsfortesting/dsfortesting.go +++ b/internal/datastore/dsfortesting/dsfortesting.go @@ -49,24 +49,24 @@ func (vr validatingReader) QueryRelationships( filter datastore.RelationshipsFilter, options ...options.QueryOptionsOption, ) (datastore.RelationshipIterator, error) { - schema := common.NewSchemaInformation( - "relationtuples", - "ns", - "object_id", - "relation", - "subject_ns", - "subject_object_id", - "subject_relation", - "caveat", - "caveat_context", - "expiration", - common.TupleComparison, - sq.Question, - "NOW", - common.ColumnOptimizationOptionStaticValues, + schema := common.NewSchemaInformationWithOptions( + common.WithRelationshipTableName("relationtuples"), + common.WithColNamespace("ns"), + common.WithColObjectID("object_id"), + common.WithColRelation("relation"), + common.WithColUsersetNamespace("subject_ns"), + common.WithColUsersetObjectID("subject_object_id"), + common.WithColUsersetRelation("subject_relation"), + common.WithColCaveatName("caveat"), + common.WithColCaveatContext("caveat_context"), + common.WithColExpiration("expiration"), + common.WithPlaceholderFormat(sq.Question), + common.WithPaginationFilterType(common.TupleComparison), + common.WithColumnOptimization(common.ColumnOptimizationOptionStaticValues), + common.WithNowFunction("NOW"), ) - qBuilder, err := common.NewSchemaQueryFiltererForRelationshipsSelect(schema, 100). + qBuilder, err := common.NewSchemaQueryFiltererForRelationshipsSelect(*schema, 100). FilterWithRelationshipsFilter(filter) if err != nil { return nil, err @@ -74,21 +74,21 @@ func (vr validatingReader) QueryRelationships( // Run the filter through the common SQL ellison system and ensure that any // relationships return have values matching the static fields, if applicable. - var queryInfo *common.QueryInfo - executor := common.QueryExecutor{ - Executor: func(ctx context.Context, qi common.QueryInfo, sql string, args []any) (datastore.RelationshipIterator, error) { - queryInfo = &qi + var builder *common.RelationshipsQueryBuilder + executor := common.QueryRelationshipsExecutor{ + Executor: func(ctx context.Context, b common.RelationshipsQueryBuilder) (datastore.RelationshipIterator, error) { + builder = &b return nil, nil }, } _, _ = executor.ExecuteQuery(ctx, qBuilder, options...) - if queryInfo == nil { - return nil, fmt.Errorf("no query info returned") + if builder == nil { + return nil, fmt.Errorf("no builder returned") } checkStaticField := func(returnedValue string, fieldName string) error { - if found, ok := queryInfo.FilteringValues[fieldName]; ok && found.SingleValue != nil { + if found, ok := builder.FilteringValuesForTesting()[fieldName]; ok && found.SingleValue != nil { if returnedValue != *found.SingleValue { return fmt.Errorf("static field `%s` does not match expected value `%s`: `%s", fieldName, returnedValue, *found.SingleValue) } diff --git a/internal/datastore/mysql/datastore.go b/internal/datastore/mysql/datastore.go index e8c5b6c903..ea39afd629 100644 --- a/internal/datastore/mysql/datastore.go +++ b/internal/datastore/mysql/datastore.go @@ -244,21 +244,21 @@ func newMySQLDatastore(ctx context.Context, uri string, replicaIndex int, option -1*config.gcWindow.Seconds(), ) - schema := common.NewSchemaInformation( - driver.RelationTuple(), - colNamespace, - colObjectID, - colRelation, - colUsersetNamespace, - colUsersetObjectID, - colUsersetRelation, - colCaveatName, - colCaveatContext, - colExpiration, - common.ExpandedLogicComparison, - sq.Question, - "NOW", - config.columnOptimizationOption, + schema := common.NewSchemaInformationWithOptions( + common.WithRelationshipTableName(driver.RelationTuple()), + common.WithColNamespace(colNamespace), + common.WithColObjectID(colObjectID), + common.WithColRelation(colRelation), + common.WithColUsersetNamespace(colUsersetNamespace), + common.WithColUsersetObjectID(colUsersetObjectID), + common.WithColUsersetRelation(colUsersetRelation), + common.WithColCaveatName(colCaveatName), + common.WithColCaveatContext(colCaveatContext), + common.WithColExpiration(colExpiration), + common.WithPaginationFilterType(common.ExpandedLogicComparison), + common.WithPlaceholderFormat(sq.Question), + common.WithNowFunction("NOW"), + common.WithColumnOptimization(config.columnOptimizationOption), ) store := &Datastore{ @@ -282,7 +282,7 @@ func newMySQLDatastore(ctx context.Context, uri string, replicaIndex int, option readTxOptions: &sql.TxOptions{Isolation: sql.LevelSerializable, ReadOnly: true}, maxRetries: config.maxRetries, analyzeBeforeStats: config.analyzeBeforeStats, - schema: schema, + schema: *schema, CachedOptimizedRevisions: revisions.NewCachedOptimizedRevisions( maxRevisionStaleness, ), @@ -332,7 +332,7 @@ func (mds *Datastore) SnapshotReader(rev datastore.Revision) datastore.Reader { return tx, tx.Rollback, nil } - executor := common.QueryExecutor{ + executor := common.QueryRelationshipsExecutor{ Executor: newMySQLExecutor(mds.db), } @@ -375,7 +375,7 @@ func (mds *Datastore) ReadWriteTx( return tx, noCleanup, nil } - executor := common.QueryExecutor{ + executor := common.QueryRelationshipsExecutor{ Executor: newMySQLExecutor(tx), } @@ -468,9 +468,9 @@ func newMySQLExecutor(tx querier) common.ExecuteReadRelsQueryFunc { // // Prepared statements are also not used given they perform poorly on environments where connections have // short lifetime (e.g. to gracefully handle load-balancer connection drain) - return func(ctx context.Context, queryInfo common.QueryInfo, sqlQuery string, args []interface{}) (datastore.RelationshipIterator, error) { + return func(ctx context.Context, builder common.RelationshipsQueryBuilder) (datastore.RelationshipIterator, error) { span := trace.SpanFromContext(ctx) - return common.QueryRelationships[common.Rows, structpbWrapper](ctx, queryInfo, sqlQuery, args, span, asQueryableTx{tx}, false) + return common.QueryRelationships[common.Rows, structpbWrapper](ctx, builder, span, asQueryableTx{tx}) } } diff --git a/internal/datastore/mysql/reader.go b/internal/datastore/mysql/reader.go index 592b844575..c963fb808f 100644 --- a/internal/datastore/mysql/reader.go +++ b/internal/datastore/mysql/reader.go @@ -23,7 +23,7 @@ type mysqlReader struct { *QueryBuilder txSource txFactory - executor common.QueryExecutor + executor common.QueryRelationshipsExecutor aliveFilter queryFilterer filterMaximumIDCount uint16 schema common.SchemaInformation diff --git a/internal/datastore/postgres/common/pgx.go b/internal/datastore/postgres/common/pgx.go index d811ce5426..72ed14c3f9 100644 --- a/internal/datastore/postgres/common/pgx.go +++ b/internal/datastore/postgres/common/pgx.go @@ -21,18 +21,11 @@ import ( "github.com/authzed/spicedb/pkg/datastore" ) -// NewPGXExecutor creates an executor that uses the pgx library to make the specified queries. -func NewPGXExecutor(querier DBFuncQuerier) common.ExecuteReadRelsQueryFunc { - return func(ctx context.Context, queryInfo common.QueryInfo, sql string, args []any) (datastore.RelationshipIterator, error) { +// NewPGXQueryRelationshipsExecutor creates an executor that uses the pgx library to make the specified queries. +func NewPGXQueryRelationshipsExecutor(querier DBFuncQuerier) common.ExecuteReadRelsQueryFunc { + return func(ctx context.Context, builder common.RelationshipsQueryBuilder) (datastore.RelationshipIterator, error) { span := trace.SpanFromContext(ctx) - return common.QueryRelationships[pgx.Rows, map[string]any](ctx, queryInfo, sql, args, span, querier, false) - } -} - -func NewPGXExecutorWithIntegrityOption(querier DBFuncQuerier, withIntegrity bool) common.ExecuteReadRelsQueryFunc { - return func(ctx context.Context, queryInfo common.QueryInfo, sql string, args []any) (datastore.RelationshipIterator, error) { - span := trace.SpanFromContext(ctx) - return common.QueryRelationships[pgx.Rows, map[string]any](ctx, queryInfo, sql, args, span, querier, withIntegrity) + return common.QueryRelationships[pgx.Rows, map[string]any](ctx, builder, span, querier) } } diff --git a/internal/datastore/postgres/postgres.go b/internal/datastore/postgres/postgres.go index 5cc4df1e46..f69fd2d8cd 100644 --- a/internal/datastore/postgres/postgres.go +++ b/internal/datastore/postgres/postgres.go @@ -315,21 +315,21 @@ func newPostgresDatastore( maxRevisionStaleness := time.Duration(float64(config.revisionQuantization.Nanoseconds())* config.maxRevisionStalenessPercent) * time.Nanosecond - schema := common.NewSchemaInformation( - tableTuple, - colNamespace, - colObjectID, - colRelation, - colUsersetNamespace, - colUsersetObjectID, - colUsersetRelation, - colCaveatContextName, - colCaveatContext, - colExpiration, - common.TupleComparison, - sq.Dollar, - "NOW", - config.columnOptimizationOption, + schema := common.NewSchemaInformationWithOptions( + common.WithRelationshipTableName(tableTuple), + common.WithColNamespace(colNamespace), + common.WithColObjectID(colObjectID), + common.WithColRelation(colRelation), + common.WithColUsersetNamespace(colUsersetNamespace), + common.WithColUsersetObjectID(colUsersetObjectID), + common.WithColUsersetRelation(colUsersetRelation), + common.WithColCaveatName(colCaveatContextName), + common.WithColCaveatContext(colCaveatContext), + common.WithColExpiration(colExpiration), + common.WithPaginationFilterType(common.TupleComparison), + common.WithPlaceholderFormat(sq.Dollar), + common.WithNowFunction("NOW"), + common.WithColumnOptimization(config.columnOptimizationOption), ) datastore := &pgDatastore{ @@ -357,7 +357,7 @@ func newPostgresDatastore( isPrimary: isPrimary, inStrictReadMode: config.readStrictMode, filterMaximumIDCount: config.filterMaximumIDCount, - schema: schema, + schema: *schema, } if isPrimary && config.readStrictMode { @@ -433,8 +433,8 @@ func (pgd *pgDatastore) SnapshotReader(revRaw datastore.Revision) datastore.Read queryFuncs = strictReaderQueryFuncs{wrapped: queryFuncs, revision: rev} } - executor := common.QueryExecutor{ - Executor: pgxcommon.NewPGXExecutor(queryFuncs), + executor := common.QueryRelationshipsExecutor{ + Executor: pgxcommon.NewPGXQueryRelationshipsExecutor(queryFuncs), } return &pgReader{ @@ -476,8 +476,8 @@ func (pgd *pgDatastore) ReadWriteTx( } queryFuncs := pgxcommon.QuerierFuncsFor(pgd.readPool) - executor := common.QueryExecutor{ - Executor: pgxcommon.NewPGXExecutor(queryFuncs), + executor := common.QueryRelationshipsExecutor{ + Executor: pgxcommon.NewPGXQueryRelationshipsExecutor(queryFuncs), } rwt := &pgReadWriteTXN{ diff --git a/internal/datastore/postgres/reader.go b/internal/datastore/postgres/reader.go index 8ab3097e60..ed0ad792d1 100644 --- a/internal/datastore/postgres/reader.go +++ b/internal/datastore/postgres/reader.go @@ -17,7 +17,7 @@ import ( type pgReader struct { query pgxcommon.DBFuncQuerier - executor common.QueryExecutor + executor common.QueryRelationshipsExecutor aliveFilter queryFilterer filterMaximumIDCount uint16 schema common.SchemaInformation diff --git a/internal/datastore/spanner/reader.go b/internal/datastore/spanner/reader.go index fb23c0e06e..2dc8ee5022 100644 --- a/internal/datastore/spanner/reader.go +++ b/internal/datastore/spanner/reader.go @@ -31,7 +31,7 @@ type readTX interface { type txFactory func() readTX type spannerReader struct { - executor common.QueryExecutor + executor common.QueryRelationshipsExecutor txSource txFactory filterMaximumIDCount uint16 schema common.SchemaInformation @@ -173,10 +173,17 @@ func (sr spannerReader) ReverseQueryRelationships( var errStopIterator = fmt.Errorf("stop iteration") func queryExecutor(txSource txFactory) common.ExecuteReadRelsQueryFunc { - return func(ctx context.Context, queryInfo common.QueryInfo, sql string, args []any) (datastore.RelationshipIterator, error) { + return func(ctx context.Context, builder common.RelationshipsQueryBuilder) (datastore.RelationshipIterator, error) { return func(yield func(tuple.Relationship, error) bool) { span := trace.SpanFromContext(ctx) span.AddEvent("Query issued to database") + + sql, args, err := builder.SelectSQL() + if err != nil { + yield(tuple.Relationship{}, err) + return + } + iter := txSource().Query(ctx, statementFromSQL(sql, args)) defer iter.Stop() @@ -196,24 +203,28 @@ func queryExecutor(txSource txFactory) common.ExecuteReadRelsQueryFunc { var caveatCtx spanner.NullJSON var expirationOrNull spanner.NullTime - colsToSelect := make([]any, 0, 8) - - colsToSelect = common.StaticValueOrAddColumnForSelect(colsToSelect, queryInfo, queryInfo.Schema.ColNamespace, &resourceObjectType) - colsToSelect = common.StaticValueOrAddColumnForSelect(colsToSelect, queryInfo, queryInfo.Schema.ColObjectID, &resourceObjectID) - colsToSelect = common.StaticValueOrAddColumnForSelect(colsToSelect, queryInfo, queryInfo.Schema.ColRelation, &relation) - colsToSelect = common.StaticValueOrAddColumnForSelect(colsToSelect, queryInfo, queryInfo.Schema.ColUsersetNamespace, &subjectObjectType) - colsToSelect = common.StaticValueOrAddColumnForSelect(colsToSelect, queryInfo, queryInfo.Schema.ColUsersetObjectID, &subjectObjectID) - colsToSelect = common.StaticValueOrAddColumnForSelect(colsToSelect, queryInfo, queryInfo.Schema.ColUsersetRelation, &subjectRelation) - - if !queryInfo.SkipCaveats || queryInfo.Schema.ColumnOptimization == common.ColumnOptimizationOptionNone { - colsToSelect = append(colsToSelect, &caveatName, &caveatCtx) - } - - colsToSelect = append(colsToSelect, &expirationOrNull) - - if len(colsToSelect) == 0 { - var unused int64 - colsToSelect = append(colsToSelect, &unused) + // NOTE: these are unused in Spanner, but necessary for the ColumnsToSelect call. + var integrityKeyID string + var integrityHash []byte + var timestamp time.Time + + colsToSelect, err := common.ColumnsToSelect(builder, + &resourceObjectType, + &resourceObjectID, + &relation, + &subjectObjectType, + &subjectObjectID, + &subjectRelation, + &caveatName, + &caveatCtx, + &expirationOrNull, + &integrityKeyID, + &integrityHash, + ×tamp, + ) + if err != nil { + yield(tuple.Relationship{}, err) + return } if err := iter.Do(func(row *spanner.Row) error { diff --git a/internal/datastore/spanner/spanner.go b/internal/datastore/spanner/spanner.go index af61f45d26..fe31b92f86 100644 --- a/internal/datastore/spanner/spanner.go +++ b/internal/datastore/spanner/spanner.go @@ -177,6 +177,23 @@ func NewSpannerDatastore(ctx context.Context, database string, opts ...Option) ( return nil, fmt.Errorf("invalid head migration found for spanner: %w", err) } + schema := common.NewSchemaInformationWithOptions( + common.WithRelationshipTableName(tableRelationship), + common.WithColNamespace(colNamespace), + common.WithColObjectID(colObjectID), + common.WithColRelation(colRelation), + common.WithColUsersetNamespace(colUsersetNamespace), + common.WithColUsersetObjectID(colUsersetObjectID), + common.WithColUsersetRelation(colUsersetRelation), + common.WithColCaveatName(colCaveatName), + common.WithColCaveatContext(colCaveatContext), + common.WithColExpiration(colExpiration), + common.WithPaginationFilterType(common.ExpandedLogicComparison), + common.WithPlaceholderFormat(sq.AtP), + common.WithNowFunction("CURRENT_TIMESTAMP"), + common.WithColumnOptimization(config.columnOptimizationOption), + ) + ds := &spannerDatastore{ RemoteClockRevisions: revisions.NewRemoteClockRevisions( defaultChangeStreamRetention, @@ -197,22 +214,7 @@ func NewSpannerDatastore(ctx context.Context, database string, opts ...Option) ( cachedEstimatedBytesPerRelationshipLock: sync.RWMutex{}, tableSizesStatsTable: tableSizesStatsTable, filterMaximumIDCount: config.filterMaximumIDCount, - schema: common.NewSchemaInformation( - tableRelationship, - colNamespace, - colObjectID, - colRelation, - colUsersetNamespace, - colUsersetObjectID, - colUsersetRelation, - colCaveatName, - colCaveatContext, - colExpiration, - common.ExpandedLogicComparison, - sq.AtP, - "CURRENT_TIMESTAMP", - config.columnOptimizationOption, - ), + schema: *schema, } // Optimized revision and revision checking use a stale read for the // current timestamp. @@ -260,7 +262,7 @@ func (sd *spannerDatastore) SnapshotReader(revisionRaw datastore.Revision) datas txSource := func() readTX { return &traceableRTX{delegate: sd.client.Single().WithTimestampBound(spanner.ReadTimestamp(r.Time()))} } - executor := common.QueryExecutor{Executor: queryExecutor(txSource)} + executor := common.QueryRelationshipsExecutor{Executor: queryExecutor(txSource)} return spannerReader{executor, txSource, sd.filterMaximumIDCount, sd.schema} } @@ -308,7 +310,7 @@ func (sd *spannerDatastore) ReadWriteTx(ctx context.Context, fn datastore.TxUser } } - executor := common.QueryExecutor{Executor: queryExecutor(txSource)} + executor := common.QueryRelationshipsExecutor{Executor: queryExecutor(txSource)} rwt := spannerReadWriteTXN{ spannerReader{executor, txSource, sd.filterMaximumIDCount, sd.schema}, spannerRWT,