diff --git a/pkg/cmd/server/defaults.go b/pkg/cmd/server/defaults.go index 1212fcd84e..69f345cc7d 100644 --- a/pkg/cmd/server/defaults.go +++ b/pkg/cmd/server/defaults.go @@ -11,8 +11,10 @@ import ( "github.com/fatih/color" "github.com/go-logr/zerologr" grpcprom "github.com/grpc-ecosystem/go-grpc-middleware/providers/prometheus" + "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors" grpcauth "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/auth" grpclog "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/logging" + "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/selector" "github.com/jzelinskie/cobrautil/v2" "github.com/jzelinskie/cobrautil/v2/cobraotel" "github.com/jzelinskie/cobrautil/v2/cobrazerolog" @@ -113,25 +115,29 @@ func MetricsHandler(telemetryRegistry *prometheus.Registry, c *Config) http.Hand return mux } -var defaultGRPCLogOptions = []grpclog.Option{ - // the server has a deadline set, so we consider it a normal condition - // this makes sure we don't log them as errors - grpclog.WithLevels(func(code codes.Code) grpclog.Level { - if code == codes.DeadlineExceeded { - return grpclog.LevelInfo - } - return grpclog.DefaultServerCodeToLevel(code) - }), - grpclog.WithDurationField(func(duration time.Duration) grpclog.Fields { - return grpclog.Fields{"grpc.time_ms", duration.Milliseconds()} - }), - grpclog.WithFieldsFromContext(func(ctx context.Context) grpclog.Fields { - if span := trace.SpanContextFromContext(ctx); span.IsSampled() { - return grpclog.Fields{"traceID", span.TraceID().String()} - } - return nil - }), -} +// the server has a deadline set, so we consider it a normal condition +// this makes sure we don't log them as errors +var defaultCodeToLevel = grpclog.WithLevels(func(code codes.Code) grpclog.Level { + if code == codes.DeadlineExceeded { + return grpclog.LevelInfo + } + return grpclog.DefaultServerCodeToLevel(code) +}) + +var durationFieldOption = grpclog.WithDurationField(func(duration time.Duration) grpclog.Fields { + return grpclog.Fields{"grpc.time_ms", duration.Milliseconds()} +}) + +var traceIDFieldOption = grpclog.WithFieldsFromContext(func(ctx context.Context) grpclog.Fields { + if span := trace.SpanContextFromContext(ctx); span.IsSampled() { + return grpclog.Fields{"traceID", span.TraceID().String()} + } + return nil +}) + +var alwaysDebugOption = grpclog.WithLevels(func(code codes.Code) grpclog.Level { + return grpclog.LevelDebug +}) const ( DefaultMiddlewareRequestID = "requestid" @@ -168,6 +174,20 @@ func init() { GRPCMetricsUnaryInterceptor, GRPCMetricsStreamingInterceptor = createServerMetrics() } +const healthCheckRoute = "/grpc.health.v1.Health/Check" + +func matchesRoute(route string) func(_ context.Context, c interceptors.CallMeta) bool { + return func(_ context.Context, c interceptors.CallMeta) bool { + return c.FullMethod() == route + } +} + +func doesNotMatchRoute(route string) func(_ context.Context, c interceptors.CallMeta) bool { + return func(_ context.Context, c interceptors.CallMeta) bool { + return c.FullMethod() != route + } +} + // DefaultUnaryMiddleware generates the default middleware chain used for the public SpiceDB Unary gRPC methods func DefaultUnaryMiddleware(opts MiddlewareOption) (*MiddlewareChain[grpc.UnaryServerInterceptor], error) { chain, err := NewMiddlewareChain([]ReferenceableMiddleware[grpc.UnaryServerInterceptor]{ @@ -186,10 +206,18 @@ func DefaultUnaryMiddleware(opts MiddlewareOption) (*MiddlewareChain[grpc.UnaryS WithInterceptor(otelgrpc.UnaryServerInterceptor()). // nolint: staticcheck Done(), + NewUnaryMiddleware(). + WithName(DefaultMiddlewareGRPCLog + "-debug"). + WithInterceptor(selector.UnaryServerInterceptor( + grpclog.UnaryServerInterceptor(InterceptorLogger(opts.logger), determineEventsToLog(opts), alwaysDebugOption, durationFieldOption, traceIDFieldOption), + selector.MatchFunc(matchesRoute(healthCheckRoute)))). + Done(), + NewUnaryMiddleware(). WithName(DefaultMiddlewareGRPCLog). - WithInterceptor(grpclog.UnaryServerInterceptor(InterceptorLogger(opts.logger), determineEventsToLog(opts)...)). - EnsureAlreadyExecuted(DefaultMiddlewareOTelGRPC). // dependency so that OTel traceID is injected in logs + WithInterceptor(selector.UnaryServerInterceptor( + grpclog.UnaryServerInterceptor(InterceptorLogger(opts.logger), determineEventsToLog(opts), defaultCodeToLevel, durationFieldOption, traceIDFieldOption), + selector.MatchFunc(doesNotMatchRoute(healthCheckRoute)))). Done(), NewUnaryMiddleware(). @@ -253,10 +281,18 @@ func DefaultStreamingMiddleware(opts MiddlewareOption) (*MiddlewareChain[grpc.St WithInterceptor(otelgrpc.StreamServerInterceptor()). // nolint: staticcheck Done(), + NewStreamMiddleware(). + WithName(DefaultMiddlewareGRPCLog + "-debug"). + WithInterceptor(selector.StreamServerInterceptor( + grpclog.StreamServerInterceptor(InterceptorLogger(opts.logger), determineEventsToLog(opts), alwaysDebugOption, durationFieldOption, traceIDFieldOption), + selector.MatchFunc(matchesRoute(healthCheckRoute)))). + Done(), + NewStreamMiddleware(). WithName(DefaultMiddlewareGRPCLog). - WithInterceptor(grpclog.StreamServerInterceptor(InterceptorLogger(opts.logger), determineEventsToLog(opts)...)). - EnsureInterceptorAlreadyExecuted(DefaultMiddlewareOTelGRPC). // dependency so that OTel traceID is injected in logs + WithInterceptor(selector.StreamServerInterceptor( + grpclog.StreamServerInterceptor(InterceptorLogger(opts.logger), determineEventsToLog(opts), defaultCodeToLevel, durationFieldOption, traceIDFieldOption), + selector.MatchFunc(doesNotMatchRoute(healthCheckRoute)))). Done(), NewStreamMiddleware(). @@ -302,7 +338,7 @@ func DefaultStreamingMiddleware(opts MiddlewareOption) (*MiddlewareChain[grpc.St return &chain, err } -func determineEventsToLog(opts MiddlewareOption) []grpclog.Option { +func determineEventsToLog(opts MiddlewareOption) grpclog.Option { eventsToLog := []grpclog.LoggableEvent{grpclog.FinishCall} if opts.enableRequestLog { eventsToLog = append(eventsToLog, grpclog.PayloadReceived) @@ -312,10 +348,7 @@ func determineEventsToLog(opts MiddlewareOption) []grpclog.Option { eventsToLog = append(eventsToLog, grpclog.PayloadSent) } - logOnEvents := grpclog.WithLogOnEvents(eventsToLog...) - grpcLogOptions := append(defaultGRPCLogOptions, logOnEvents) - - return grpcLogOptions + return grpclog.WithLogOnEvents(eventsToLog...) } // DefaultDispatchMiddleware generates the default middleware chain used for the internal dispatch SpiceDB gRPC API @@ -323,7 +356,7 @@ func DefaultDispatchMiddleware(logger zerolog.Logger, authFunc grpcauth.AuthFunc return []grpc.UnaryServerInterceptor{ requestid.UnaryServerInterceptor(requestid.GenerateIfMissing(true)), logmw.UnaryServerInterceptor(logmw.ExtractMetadataField("x-request-id", "requestID")), - grpclog.UnaryServerInterceptor(InterceptorLogger(logger), defaultGRPCLogOptions...), + grpclog.UnaryServerInterceptor(InterceptorLogger(logger), defaultCodeToLevel, durationFieldOption), otelgrpc.UnaryServerInterceptor(), // nolint: staticcheck GRPCMetricsUnaryInterceptor, grpcauth.UnaryServerInterceptor(authFunc), @@ -332,7 +365,7 @@ func DefaultDispatchMiddleware(logger zerolog.Logger, authFunc grpcauth.AuthFunc }, []grpc.StreamServerInterceptor{ requestid.StreamServerInterceptor(requestid.GenerateIfMissing(true)), logmw.StreamServerInterceptor(logmw.ExtractMetadataField("x-request-id", "requestID")), - grpclog.StreamServerInterceptor(InterceptorLogger(logger), defaultGRPCLogOptions...), + grpclog.StreamServerInterceptor(InterceptorLogger(logger), defaultCodeToLevel, durationFieldOption), otelgrpc.StreamServerInterceptor(), // nolint: staticcheck GRPCMetricsStreamingInterceptor, grpcauth.StreamServerInterceptor(authFunc),