From 6a29d88781bf6417aea2c9f0446daa39fcaa17cf Mon Sep 17 00:00:00 2001 From: Joseph Schorr Date: Thu, 5 Dec 2024 13:13:33 -0500 Subject: [PATCH] Add missing wiring of service label through With* methods in options --- pkg/cmd/server/defaults.go | 2 + pkg/cmd/server/defaults_test.go | 94 +++++++++++++++++++++++++++++++++ 2 files changed, 96 insertions(+) create mode 100644 pkg/cmd/server/defaults_test.go diff --git a/pkg/cmd/server/defaults.go b/pkg/cmd/server/defaults.go index a5d08291ef..a534bbf6f1 100644 --- a/pkg/cmd/server/defaults.go +++ b/pkg/cmd/server/defaults.go @@ -218,6 +218,7 @@ func (m MiddlewareOption) WithDatastoreMiddleware(middleware Middleware) Middlew EnableRequestLog: m.EnableRequestLog, EnableResponseLog: m.EnableResponseLog, DisableGRPCHistogram: m.DisableGRPCHistogram, + MiddlewareServiceLabel: m.MiddlewareServiceLabel, unaryDatastoreMiddleware: &unary, streamDatastoreMiddleware: &stream, } @@ -244,6 +245,7 @@ func (m MiddlewareOption) WithDatastore(ds datastore.Datastore) MiddlewareOption EnableRequestLog: m.EnableRequestLog, EnableResponseLog: m.EnableResponseLog, DisableGRPCHistogram: m.DisableGRPCHistogram, + MiddlewareServiceLabel: m.MiddlewareServiceLabel, unaryDatastoreMiddleware: &unary, streamDatastoreMiddleware: &stream, } diff --git a/pkg/cmd/server/defaults_test.go b/pkg/cmd/server/defaults_test.go new file mode 100644 index 0000000000..5b3ea59f60 --- /dev/null +++ b/pkg/cmd/server/defaults_test.go @@ -0,0 +1,94 @@ +package server + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/rs/zerolog" + "github.com/stretchr/testify/require" + + "github.com/authzed/spicedb/internal/datastore/memdb" + "github.com/authzed/spicedb/internal/dispatch" + "github.com/authzed/spicedb/internal/middleware/pertoken" +) + +func TestWithDatastore(t *testing.T) { + someLogger := zerolog.Nop() + someAuthFunc := func(ctx context.Context) (context.Context, error) { + return nil, fmt.Errorf("expected auth error") + } + var someDispatcher dispatch.Dispatcher + + opts := MiddlewareOption{ + someLogger, + someAuthFunc, + true, + someDispatcher, + true, + true, + false, + "service", + nil, + nil, + } + + someDS, err := memdb.NewMemdbDatastore(0, time.Hour, time.Hour) + require.NoError(t, err) + + withDS := opts.WithDatastore(someDS) + require.NotNil(t, withDS) + require.NotNil(t, withDS.unaryDatastoreMiddleware) + require.NotNil(t, withDS.streamDatastoreMiddleware) + + require.Equal(t, opts.Logger, withDS.Logger) + require.Equal(t, opts.DispatcherForMiddleware, withDS.DispatcherForMiddleware) + require.Equal(t, opts.EnableRequestLog, withDS.EnableRequestLog) + require.Equal(t, opts.EnableResponseLog, withDS.EnableResponseLog) + require.Equal(t, opts.DisableGRPCHistogram, withDS.DisableGRPCHistogram) + require.Equal(t, opts.MiddlewareServiceLabel, withDS.MiddlewareServiceLabel) + + _, authError := withDS.AuthFunc(context.Background()) + require.Error(t, authError) + require.ErrorContains(t, authError, "expected auth error") +} + +func TestWithDatastoreMiddleware(t *testing.T) { + someLogger := zerolog.Nop() + someAuthFunc := func(ctx context.Context) (context.Context, error) { + return nil, fmt.Errorf("expected auth error") + } + var someDispatcher dispatch.Dispatcher + + opts := MiddlewareOption{ + someLogger, + someAuthFunc, + true, + someDispatcher, + true, + true, + false, + "anotherservice", + nil, + nil, + } + + someMiddleware := pertoken.NewMiddleware(nil) + + withDS := opts.WithDatastoreMiddleware(someMiddleware) + require.NotNil(t, withDS) + require.NotNil(t, withDS.unaryDatastoreMiddleware) + require.NotNil(t, withDS.streamDatastoreMiddleware) + + require.Equal(t, opts.Logger, withDS.Logger) + require.Equal(t, opts.DispatcherForMiddleware, withDS.DispatcherForMiddleware) + require.Equal(t, opts.EnableRequestLog, withDS.EnableRequestLog) + require.Equal(t, opts.EnableResponseLog, withDS.EnableResponseLog) + require.Equal(t, opts.DisableGRPCHistogram, withDS.DisableGRPCHistogram) + require.Equal(t, opts.MiddlewareServiceLabel, withDS.MiddlewareServiceLabel) + + _, authError := withDS.AuthFunc(context.Background()) + require.Error(t, authError) + require.ErrorContains(t, authError, "expected auth error") +}