diff --git a/internal/datastore/postgres/common/pgx.go b/internal/datastore/postgres/common/pgx.go index 1e9f1db85e..fbc4b05b0a 100644 --- a/internal/datastore/postgres/common/pgx.go +++ b/internal/datastore/postgres/common/pgx.go @@ -14,6 +14,7 @@ import ( "github.com/jackc/pgx/v5/pgconn" "github.com/jackc/pgx/v5/pgxpool" "github.com/jackc/pgx/v5/tracelog" + "github.com/rs/zerolog" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/trace" @@ -126,7 +127,14 @@ func ConfigurePGXLogger(connConfig *pgx.ConnConfig) { } } - l := zerologadapter.NewLogger(log.Logger, zerologadapter.WithSubDictionary("pgx")) + l := zerologadapter.NewLogger(log.Logger, zerologadapter.WithoutPGXModule(), zerologadapter.WithSubDictionary("pgx"), + zerologadapter.WithContextFunc(func(ctx context.Context, z zerolog.Context) zerolog.Context { + if logger := log.Ctx(ctx); logger != nil { + return logger.With() + } + + return z + })) addTracer(connConfig, &tracelog.TraceLog{Logger: levelMappingFn(l), LogLevel: tracelog.LogLevelInfo}) } diff --git a/internal/graph/context.go b/internal/graph/context.go index 7dcec087ab..1485fa0313 100644 --- a/internal/graph/context.go +++ b/internal/graph/context.go @@ -7,6 +7,7 @@ import ( log "github.com/authzed/spicedb/internal/logging" datastoremw "github.com/authzed/spicedb/internal/middleware/datastore" + "github.com/authzed/spicedb/pkg/middleware/requestid" ) // branchContext returns a context disconnected from the parent context, but populated with the datastore. @@ -26,5 +27,7 @@ func branchContext(ctx context.Context) (context.Context, func(cancelErr error)) detachedContext = loggerFromContext.WithContext(detachedContext) } + detachedContext = requestid.PropagateIfExists(ctx, detachedContext) + return context.WithCancelCause(detachedContext) } diff --git a/pkg/cmd/server/server.go b/pkg/cmd/server/server.go index 1431c41231..cb3968d821 100644 --- a/pkg/cmd/server/server.go +++ b/pkg/cmd/server/server.go @@ -41,6 +41,7 @@ import ( datastorecfg "github.com/authzed/spicedb/pkg/cmd/datastore" "github.com/authzed/spicedb/pkg/cmd/util" "github.com/authzed/spicedb/pkg/datastore" + "github.com/authzed/spicedb/pkg/middleware/requestid" "github.com/authzed/spicedb/pkg/spiceerrors" ) @@ -276,6 +277,12 @@ func (c *Config) Complete(ctx context.Context) (RunnableServer, error) { combineddispatch.GrpcDialOpts( grpc.WithStatsHandler(otelgrpc.NewClientHandler()), grpc.WithDefaultServiceConfig(hashringConfigJSON), + grpc.WithChainUnaryInterceptor( + requestid.UnaryClientInterceptor(), + ), + grpc.WithChainStreamInterceptor( + requestid.StreamClientInterceptor(), + ), ), combineddispatch.MetricsEnabled(c.DispatchClientMetricsEnabled), combineddispatch.PrometheusSubsystem(c.DispatchClientMetricsPrefix), diff --git a/pkg/middleware/requestid/requestid.go b/pkg/middleware/requestid/requestid.go index 58feb32848..7a0904ab9f 100644 --- a/pkg/middleware/requestid/requestid.go +++ b/pkg/middleware/requestid/requestid.go @@ -41,32 +41,22 @@ type handleRequestID struct { requestIDGenerator IDGenerator } -func (r *handleRequestID) ServerReporter(ctx context.Context, _ interceptors.CallMeta) (interceptors.Reporter, context.Context) { - var requestID string - var haveRequestID bool - md, ok := metadata.FromIncomingContext(ctx) - if ok { - var requestIDs []string - requestIDs, haveRequestID = md[metadataKey] - if haveRequestID { - requestID = requestIDs[0] - } - } +func (r *handleRequestID) ClientReporter(ctx context.Context, meta interceptors.CallMeta) (interceptors.Reporter, context.Context) { + haveRequestID, requestID, ctx := r.fromContextOrGenerate(ctx) - if !haveRequestID && r.generateIfMissing { - requestID, haveRequestID = r.requestIDGenerator(), true + if haveRequestID { + ctx = requestmeta.SetRequestHeaders(ctx, map[requestmeta.RequestMetadataHeaderKey]string{ + requestmeta.RequestIDKey: requestID, + }) + } - // Inject the newly generated request ID into the metadata - if md == nil { - md = metadata.New(nil) - } + return interceptors.NoopReporter{}, ctx +} - md.Set(metadataKey, requestID) - ctx = metadata.NewIncomingContext(ctx, md) - } +func (r *handleRequestID) ServerReporter(ctx context.Context, _ interceptors.CallMeta) (interceptors.Reporter, context.Context) { + haveRequestID, requestID, ctx := r.fromContextOrGenerate(ctx) if haveRequestID { - ctx = metadata.AppendToOutgoingContext(ctx, metadataKey, requestID) err := responsemeta.SetResponseHeaderMetadata(ctx, map[responsemeta.ResponseMetadataHeaderKey]string{ responsemeta.RequestID: requestID, }) @@ -83,18 +73,82 @@ func (r *handleRequestID) ServerReporter(ctx context.Context, _ interceptors.Cal return interceptors.NoopReporter{}, ctx } -// UnaryServerInterceptor returns a new interceptor which handles request IDs according +func (r *handleRequestID) fromContextOrGenerate(ctx context.Context) (bool, string, context.Context) { + haveRequestID, requestID, md := fromContext(ctx) + + if !haveRequestID && r.generateIfMissing { + requestID = r.requestIDGenerator() + haveRequestID = true + + // Inject the newly generated request ID into the metadata + if md == nil { + md = metadata.New(nil) + } + + md.Set(metadataKey, requestID) + ctx = metadata.NewIncomingContext(ctx, md) + } + + return haveRequestID, requestID, ctx +} + +func fromContext(ctx context.Context) (bool, string, metadata.MD) { + var requestID string + var haveRequestID bool + md, ok := metadata.FromIncomingContext(ctx) + if ok { + var requestIDs []string + requestIDs, haveRequestID = md[metadataKey] + if haveRequestID { + requestID = requestIDs[0] + } + } + + return haveRequestID, requestID, md +} + +// PropagateIfExists copies the request ID from the source context to the target context if it exists. +// The updated target context is returned. +func PropagateIfExists(source, target context.Context) context.Context { + exists, requestID, _ := fromContext(source) + + if exists { + targetMD, _ := metadata.FromIncomingContext(target) + if targetMD == nil { + targetMD = metadata.New(nil) + } + + targetMD.Set(metadataKey, requestID) + return metadata.NewIncomingContext(target, targetMD) + } + + return target +} + +// UnaryServerInterceptor returns a new interceptor which handles server request IDs according // to the provided options. func UnaryServerInterceptor(opts ...Option) grpc.UnaryServerInterceptor { return interceptors.UnaryServerInterceptor(createReporter(opts)) } -// StreamServerInterceptor returns a new interceptor which handles request IDs according +// StreamServerInterceptor returns a new interceptor which handles server request IDs according // to the provided options. func StreamServerInterceptor(opts ...Option) grpc.StreamServerInterceptor { return interceptors.StreamServerInterceptor(createReporter(opts)) } +// UnaryClientInterceptor returns a new interceptor which handles client request IDs according +// to the provided options. +func UnaryClientInterceptor(opts ...Option) grpc.UnaryClientInterceptor { + return interceptors.UnaryClientInterceptor(createReporter(opts)) +} + +// StreamClientInterceptor returns a new interceptor which handles client requestIDs according +// to the provided options. +func StreamClientInterceptor(opts ...Option) grpc.StreamClientInterceptor { + return interceptors.StreamClientInterceptor(createReporter(opts)) +} + func createReporter(opts []Option) *handleRequestID { reporter := &handleRequestID{ requestIDGenerator: GenerateRequestID,