diff --git a/pkg/cmd/server/defaults.go b/pkg/cmd/server/defaults.go index a534bbf6f1..da9923c516 100644 --- a/pkg/cmd/server/defaults.go +++ b/pkg/cmd/server/defaults.go @@ -22,6 +22,7 @@ import ( "github.com/jzelinskie/cobrautil/v2/cobraproclimits" "github.com/jzelinskie/cobrautil/v2/cobrazerolog" "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promauto" "github.com/prometheus/client_golang/prometheus/promhttp" "github.com/rs/zerolog" "go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc" @@ -30,6 +31,8 @@ import ( "google.golang.org/grpc/codes" "github.com/authzed/authzed-go/pkg/requestmeta" + v1 "github.com/authzed/authzed-go/proto/authzed/api/v1" + "github.com/authzed/grpcutil" "github.com/authzed/spicedb/internal/dispatch" "github.com/authzed/spicedb/internal/logging" @@ -333,6 +336,11 @@ func DefaultUnaryMiddleware(opts MiddlewareOption) (*MiddlewareChain[grpc.UnaryS WithInterceptor(serverversion.UnaryServerInterceptor(opts.EnableVersionResponse)). Done(), + NewUnaryMiddleware(). + WithName("LogicalChecksMetric"). + WithInterceptor(LogicalChecksMetricUnary). + Done(), + NewUnaryMiddleware(). WithName(DefaultInternalMiddlewareDispatch). WithInternal(true). @@ -406,6 +414,11 @@ func DefaultStreamingMiddleware(opts MiddlewareOption) (*MiddlewareChain[grpc.St WithInterceptor(serverversion.StreamServerInterceptor(opts.EnableVersionResponse)). Done(), + NewStreamMiddleware(). + WithName("LogicalChecksMetric"). + WithInterceptor(LogicalChecksMetricStreaming). + Done(), + NewStreamMiddleware(). WithName(DefaultInternalMiddlewareDispatch). WithInternal(true). @@ -465,6 +478,60 @@ func DefaultDispatchMiddleware(logger zerolog.Logger, authFunc grpcauth.AuthFunc } } +var LogicalChecksCounter = promauto.NewCounterVec(prometheus.CounterOpts{ + Namespace: "spicedb", + Subsystem: "logical", + Name: "relationship_checks_total", + Help: "Count of the logically checked relationships across various request types", +}, []string{"grpc_method", "grpc_service", "grpc_type"}) + +func LogicalChecksMetricUnary(ctx context.Context, request any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) { + response, err := handler(ctx, request) + if err == nil { + svc, method := grpcutil.SplitMethodName(info.FullMethod) + counter := LogicalChecksCounter.WithLabelValues(method, svc, "unary") + reportLogicalChecks(request, counter) + } + return response, err +} + +func LogicalChecksMetricStreaming(srv any, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { + svc, method := grpcutil.SplitMethodName(info.FullMethod) + counter := LogicalChecksCounter.WithLabelValues(method, svc, "server_stream") + wrapper := &logicalCheckWrapper{ + counter: counter, + ServerStream: stream, + } + return handler(srv, wrapper) +} + +type logicalCheckWrapper struct { + counter prometheus.Counter + grpc.ServerStream +} + +func (w *logicalCheckWrapper) SendMsg(m any) error { + if err := w.ServerStream.SendMsg(m); err != nil { + return err + } + reportLogicalChecks(m, w.counter) + return nil +} + +func reportLogicalChecks(msg any, counter prometheus.Counter) { + switch m := msg.(type) { + case *v1.CheckPermissionRequest: + counter.Add(1) + case *v1.CheckBulkPermissionsRequest: + counter.Add(float64(len(m.GetItems()))) + case *v1.LookupResourcesResponse: + counter.Add(1) + case *v1.LookupSubjectsResponse: + counter.Add(1) + default: + } +} + func InterceptorLogger(l zerolog.Logger) grpclog.Logger { return grpclog.LoggerFunc(func(ctx context.Context, lvl grpclog.Level, msg string, fields ...any) { l := l.With().Fields(fields).Logger()