diff --git a/internal/datastore/postgres/postgres_shared_test.go b/internal/datastore/postgres/postgres_shared_test.go index e556bfd1c1..de9bf6076d 100644 --- a/internal/datastore/postgres/postgres_shared_test.go +++ b/internal/datastore/postgres/postgres_shared_test.go @@ -1372,16 +1372,8 @@ func GCQueriesServedByExpectedIndexes(t *testing.T, _ testdatastore.RunningEngin revision, err := ds.HeadRevision(ctx) require.NoError(err) - for { - wds, ok := ds.(datastore.UnwrappableDatastore) - if !ok { - break - } - ds = wds.Unwrap() - } - - casted, ok := ds.(common.GarbageCollector) - require.True(ok) + casted := datastore.UnwrapAs[common.GarbageCollector](ds) + require.NotNil(casted) _, err = casted.DeleteBeforeTx(context.Background(), revision) require.NoError(err) diff --git a/internal/datastore/proxy/schemacaching/caching.go b/internal/datastore/proxy/schemacaching/caching.go index df28d4d647..9259e19cfe 100644 --- a/internal/datastore/proxy/schemacaching/caching.go +++ b/internal/datastore/proxy/schemacaching/caching.go @@ -42,7 +42,7 @@ func NewCachingDatastoreProxy(delegate datastore.Datastore, c cache.Cache, gcWin } if cachingMode == JustInTimeCaching { - log.Info().Type("datastore-type", delegate).Msg("datastore driver explicitly asked to skip schema watch") + log.Info().Msg("schema watch explicitly disabled") return &definitionCachingProxy{ Datastore: delegate, c: c, @@ -51,27 +51,14 @@ func NewCachingDatastoreProxy(delegate datastore.Datastore, c cache.Cache, gcWin // Try to instantiate a schema cache that reads updates from the datastore's schema watch stream. If not possible, // fallback to the just-in-time caching proxy. - if watchable, ok := delegate.(datastore.SchemaWatchableDatastore); ok { + if watchable := datastore.UnwrapAs[datastore.SchemaWatchableDatastore](delegate); watchable != nil { + log.Info().Type("datastore-type", watchable).Msg("enabled schema caching") return createWatchingCacheProxy(watchable, c, gcWindow) } - unwrapped, ok := delegate.(datastore.UnwrappableDatastore) - if !ok { - log.Warn().Type("datastore-type", delegate).Msg("datastore driver does not support unwrapping; falling back to just-in-time caching") - return &definitionCachingProxy{ - Datastore: delegate, - c: c, - } + log.Info().Type("datastore-type", delegate).Msg("schema watch was enabled but datastore does not support it; falling back to just-in-time caching") + return &definitionCachingProxy{ + Datastore: delegate, + c: c, } - - watchable, ok := unwrapped.Unwrap().(datastore.SchemaWatchableDatastore) - if !ok { - log.Info().Type("datastore-type", delegate).Msg("datastore driver does not schema watch; falling back to just-in-time caching") - return &definitionCachingProxy{ - Datastore: delegate, - c: c, - } - } - - return createWatchingCacheProxy(watchable, c, gcWindow) } diff --git a/pkg/cmd/datastore.go b/pkg/cmd/datastore.go index 65fb4b92e7..c8fef084fc 100644 --- a/pkg/cmd/datastore.go +++ b/pkg/cmd/datastore.go @@ -64,16 +64,8 @@ func NewGCDatastoreCommand(programName string, cfg *datastore.Config) *cobra.Com return fmt.Errorf("failed to create datastore: %w", err) } - for { - wds, ok := ds.(dspkg.UnwrappableDatastore) - if !ok { - break - } - ds = wds.Unwrap() - } - - gc, ok := ds.(common.GarbageCollector) - if !ok { + gc := dspkg.UnwrapAs[common.GarbageCollector](ds) + if gc == nil { return fmt.Errorf("datastore of type %T does not support garbage collection", ds) } @@ -109,16 +101,8 @@ func NewRepairDatastoreCommand(programName string, cfg *datastore.Config) *cobra return fmt.Errorf("failed to create datastore: %w", err) } - for { - wds, ok := ds.(dspkg.UnwrappableDatastore) - if !ok { - break - } - ds = wds.Unwrap() - } - - repairable, ok := ds.(dspkg.RepairableDatastore) - if !ok { + repairable := dspkg.UnwrapAs[dspkg.RepairableDatastore](ds) + if repairable == nil { return fmt.Errorf("datastore of type %T does not support the repair operation", ds) } diff --git a/pkg/cmd/server/server.go b/pkg/cmd/server/server.go index 5118f34089..4727ab85d4 100644 --- a/pkg/cmd/server/server.go +++ b/pkg/cmd/server/server.go @@ -591,14 +591,11 @@ func (c *completedServerConfig) DispatchNetDialContext(ctx context.Context, s st func (c *completedServerConfig) Run(ctx context.Context) error { log.Ctx(ctx).Info().Type("datastore", c.ds).Msg("running server") - if unwrappableDS, ok := c.ds.(datastore.UnwrappableDatastore); ok { - log.Ctx(ctx).Info().Msg("checking for startable datastore") - if startableDS, ok := unwrappableDS.Unwrap().(datastore.StartableDatastore); ok { - log.Ctx(ctx).Info().Msg("Start-ing datastore") - err := startableDS.Start(ctx) - if err != nil { - return err - } + if startable := datastore.UnwrapAs[datastore.StartableDatastore](c.ds); startable != nil { + log.Ctx(ctx).Info().Msg("Start-ing datastore") + err := startable.Start(ctx) + if err != nil { + return err } } diff --git a/pkg/datastore/datastore.go b/pkg/datastore/datastore.go index 3bc9210ccf..4a946bdfd5 100644 --- a/pkg/datastore/datastore.go +++ b/pkg/datastore/datastore.go @@ -287,7 +287,7 @@ type ReadyState struct { // BulkWriteRelationshipSource is an interface for transferring relationships // to a backing datastore with a zero-copy methodology. type BulkWriteRelationshipSource interface { - // Returns a pointer to a relation tuple if one is available, or nil if + // Next Returns a pointer to a relation tuple if one is available, or nil if // there are no more or there was an error. // // Note: sources may re-use the same memory address for every tuple, data @@ -439,6 +439,30 @@ type UnwrappableDatastore interface { Unwrap() Datastore } +// UnwrapAs recursively attempts to unwrap the datastore into the specified type +// In none of the layers of the datastore implement the specified type, nil is returned. +func UnwrapAs[T any](datastore Datastore) T { + var ds T + uwds := datastore + + for { + var ok bool + ds, ok = uwds.(T) + if ok { + break + } + + wds, ok := uwds.(UnwrappableDatastore) + if !ok { + break + } + + uwds = wds.Unwrap() + } + + return ds +} + // Feature represents a capability that a datastore can support, plus an // optional message explaining the feature is available (or not). type Feature struct { @@ -502,10 +526,10 @@ type Revision interface { // Equal returns whether the revisions should be considered equal. Equal(Revision) bool - // Equal returns whether the receiver is provably greater than the right hand side. + // GreaterThan returns whether the receiver is probably greater than the right hand side. GreaterThan(Revision) bool - // Equal returns whether the receiver is provably less than the right hand side. + // LessThan returns whether the receiver is probably less than the right hand side. LessThan(Revision) bool } diff --git a/pkg/datastore/datastore_test.go b/pkg/datastore/datastore_test.go index b73b718cbf..b71d7a1058 100644 --- a/pkg/datastore/datastore_test.go +++ b/pkg/datastore/datastore_test.go @@ -1,8 +1,11 @@ package datastore import ( + "context" "testing" + "github.com/authzed/spicedb/pkg/datastore/options" + v1 "github.com/authzed/authzed-go/proto/authzed/api/v1" "github.com/stretchr/testify/require" ) @@ -104,3 +107,78 @@ func TestRelationshipsFilterFromPublicFilter(t *testing.T) { }) } } + +func TestUnwrapAs(t *testing.T) { + result := UnwrapAs[error](nil) + require.Nil(t, result) + + ds := fakeDatastore{delegate: fakeDatastore{fakeDatastoreError{}}} + result = UnwrapAs[error](ds) + require.NotNil(t, result) + require.IsType(t, fakeDatastoreError{}, result) + + errorable := fakeDatastoreError{} + result = UnwrapAs[error](errorable) + require.NotNil(t, result) + require.IsType(t, fakeDatastoreError{}, result) +} + +type fakeDatastoreError struct { + fakeDatastore +} + +func (e fakeDatastoreError) Error() string { + return "" +} + +type fakeDatastore struct { + delegate Datastore +} + +func (f fakeDatastore) Unwrap() Datastore { + return f.delegate +} + +func (f fakeDatastore) SnapshotReader(_ Revision) Reader { + return nil +} + +func (f fakeDatastore) ReadWriteTx(_ context.Context, _ TxUserFunc, _ ...options.RWTOptionsOption) (Revision, error) { + return nil, nil +} + +func (f fakeDatastore) OptimizedRevision(_ context.Context) (Revision, error) { + return nil, nil +} + +func (f fakeDatastore) HeadRevision(_ context.Context) (Revision, error) { + return nil, nil +} + +func (f fakeDatastore) CheckRevision(_ context.Context, _ Revision) error { + return nil +} + +func (f fakeDatastore) RevisionFromString(_ string) (Revision, error) { + return nil, nil +} + +func (f fakeDatastore) Watch(_ context.Context, _ Revision) (<-chan *RevisionChanges, <-chan error) { + return nil, nil +} + +func (f fakeDatastore) ReadyState(_ context.Context) (ReadyState, error) { + return ReadyState{}, nil +} + +func (f fakeDatastore) Features(_ context.Context) (*Features, error) { + return nil, nil +} + +func (f fakeDatastore) Statistics(_ context.Context) (Stats, error) { + return Stats{}, nil +} + +func (f fakeDatastore) Close() error { + return nil +}