diff --git a/flow/cmd/api.go b/flow/cmd/api.go index 58e6beac0f..ca225e4292 100644 --- a/flow/cmd/api.go +++ b/flow/cmd/api.go @@ -23,8 +23,8 @@ import ( "google.golang.org/grpc/health/grpc_health_v1" "google.golang.org/grpc/reflection" - "github.com/PeerDB-io/peer-flow/auth" "github.com/PeerDB-io/peer-flow/generated/protos" + "github.com/PeerDB-io/peer-flow/middleware" "github.com/PeerDB-io/peer-flow/peerdbenv" "github.com/PeerDB-io/peer-flow/shared" peerflow "github.com/PeerDB-io/peer-flow/workflows" @@ -213,14 +213,23 @@ func APIMain(ctx context.Context, args *APIServerParams) error { return fmt.Errorf("unable to create Temporal client: %w", err) } - options, err := auth.AuthGrpcMiddleware([]string{ + authGrpcMiddleware, err := middleware.AuthGrpcMiddleware([]string{ grpc_health_v1.Health_Check_FullMethodName, grpc_health_v1.Health_Watch_FullMethodName, }) if err != nil { return err } - grpcServer := grpc.NewServer(options...) + + requestLoggingMiddleware := middleware.RequestLoggingMiddleWare() + + // Interceptors are executed in the order they are passed to, so unauthorized requests are not logged + interceptors := grpc.ChainUnaryInterceptor( + authGrpcMiddleware, + requestLoggingMiddleware, + ) + + grpcServer := grpc.NewServer(interceptors) catalogPool, err := peerdbenv.GetCatalogConnectionPoolFromEnv(ctx) if err != nil { diff --git a/flow/middleware/logging.go b/flow/middleware/logging.go new file mode 100644 index 0000000000..51932700fe --- /dev/null +++ b/flow/middleware/logging.go @@ -0,0 +1,31 @@ +package middleware + +import ( + "context" + "log/slog" + + "google.golang.org/grpc" + + "github.com/PeerDB-io/peer-flow/peerdbenv" +) + +func RequestLoggingMiddleWare() grpc.UnaryServerInterceptor { + if !peerdbenv.PeerDBRAPIRequestLoggingEnabled() { + slog.Info("Request Logging Interceptor is disabled") + return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) { + return handler(ctx, req) + } + } + slog.Info("Setting up request logging middleware") + return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) { + slog.Info("Received gRPC request", slog.String("method", info.FullMethod)) + + resp, err := handler(ctx, req) + if err != nil { + slog.Error("gRPC request failed", slog.String("method", info.FullMethod), slog.Any("error", err)) + } else { + slog.Info("gRPC request completed successfully", slog.String("method", info.FullMethod)) + } + return resp, err + } +} diff --git a/flow/auth/middleware.go b/flow/middleware/oauth.go similarity index 81% rename from flow/auth/middleware.go rename to flow/middleware/oauth.go index bb3ee34da5..52bbc03672 100644 --- a/flow/auth/middleware.go +++ b/flow/middleware/oauth.go @@ -1,4 +1,4 @@ -package auth +package middleware import ( "context" @@ -34,7 +34,7 @@ type identityProvider struct { issuer string } -func AuthGrpcMiddleware(unauthenticatedMethods []string) ([]grpc.ServerOption, error) { +func AuthGrpcMiddleware(unauthenticatedMethods []string) (grpc.UnaryServerInterceptor, error) { oauthConfig := peerdbenv.GetPeerDBOAuthConfig() oauthJwtClaims := map[string]string{} if oauthConfig.OAuthJwtClaimKey != "" { @@ -57,7 +57,9 @@ func AuthGrpcMiddleware(unauthenticatedMethods []string) ([]grpc.ServerOption, e slog.Warn("authentication is disabled") - return nil, nil + return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) { + return handler(ctx, req) + }, nil } if err != nil { @@ -68,36 +70,24 @@ func AuthGrpcMiddleware(unauthenticatedMethods []string) ([]grpc.ServerOption, e for _, method := range unauthenticatedMethods { unauthenticatedMethodsMap[method] = struct{}{} } - return []grpc.ServerOption{ - grpc.ChainUnaryInterceptor(func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) { - slog.Info("Received gRPC request", slog.String("method", info.FullMethod)) - - if _, unauthorized := unauthenticatedMethodsMap[info.FullMethod]; !unauthorized { - var authHeader string - authHeaders := metadata.ValueFromIncomingContext(ctx, "Authorization") - if len(authHeaders) == 1 { - authHeader = authHeaders[0] - } else if len(authHeaders) > 1 { - slog.Warn("Multiple Authorization headers supplied, request rejected", slog.String("method", info.FullMethod)) - return nil, status.Errorf(codes.Unauthenticated, "multiple Authorization headers supplied, request rejected") - } - _, err := validateRequestToken(authHeader, cfg.OauthJwtCustomClaims, ip...) - if err != nil { - slog.Debug("Failed to validate request token", slog.String("method", info.FullMethod), slog.Any("error", err)) - return nil, status.Errorf(codes.Unauthenticated, "%s", err.Error()) - } + return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) { + if _, unauthorized := unauthenticatedMethodsMap[info.FullMethod]; !unauthorized { + var authHeader string + authHeaders := metadata.ValueFromIncomingContext(ctx, "Authorization") + if len(authHeaders) == 1 { + authHeader = authHeaders[0] + } else if len(authHeaders) > 1 { + slog.Warn("Multiple Authorization headers supplied, request rejected", slog.String("method", info.FullMethod)) + return nil, status.Errorf(codes.Unauthenticated, "multiple Authorization headers supplied, request rejected") } - - resp, err := handler(ctx, req) - + _, err := validateRequestToken(authHeader, cfg.OauthJwtCustomClaims, ip...) if err != nil { - slog.Error("gRPC request failed", slog.String("method", info.FullMethod), slog.Any("error", err)) - } else { - slog.Info("gRPC request completed successfully", slog.String("method", info.FullMethod)) + slog.Debug("Failed to validate request token", slog.String("method", info.FullMethod), slog.Any("error", err)) + return nil, status.Errorf(codes.Unauthenticated, "%s", err.Error()) } + } - return resp, err - }), + return handler(ctx, req) }, nil } diff --git a/flow/peerdbenv/config.go b/flow/peerdbenv/config.go index ecae67037f..e033b87195 100644 --- a/flow/peerdbenv/config.go +++ b/flow/peerdbenv/config.go @@ -5,6 +5,7 @@ import ( "encoding/json" "fmt" "log/slog" + "strconv" "strings" "time" @@ -156,3 +157,12 @@ func PeerDBGetIncidentIoUrl() string { func PeerDBGetIncidentIoToken() string { return GetEnvString("PEERDB_INCIDENTIO_TOKEN", "") } + +func PeerDBRAPIRequestLoggingEnabled() bool { + requestLoggingEnabled, err := strconv.ParseBool(GetEnvString("PEERDB_API_REQUEST_LOGGING_ENABLED", "false")) + if err != nil { + slog.Error("failed to parse PEERDB_API_REQUEST_LOGGING_ENABLED to bool", "error", err) + return false + } + return requestLoggingEnabled +} diff --git a/flow/peerdbenv/oauth.go b/flow/peerdbenv/oauth.go index cd76b30193..54b2f04425 100644 --- a/flow/peerdbenv/oauth.go +++ b/flow/peerdbenv/oauth.go @@ -1,6 +1,9 @@ package peerdbenv -import "strconv" +import ( + "log/slog" + "strconv" +) type PeerDBOAuthConfig struct { // there can be more complex use cases where domain != issuer, but we handle them later if required @@ -18,6 +21,7 @@ func GetPeerDBOAuthConfig() PeerDBOAuthConfig { oauthDiscoveryEnabledString := GetEnvString("PEERDB_OAUTH_DISCOVERY_ENABLED", "false") oauthDiscoveryEnabled, err := strconv.ParseBool(oauthDiscoveryEnabledString) if err != nil { + slog.Error("failed to parse PEERDB_OAUTH_DISCOVERY_ENABLED to bool", "error", err) oauthDiscoveryEnabled = false } oauthKeysetJson := GetEnvString("PEERDB_OAUTH_KEYSET_JSON", "")