diff --git a/flow/activities/flowable.go b/flow/activities/flowable.go index a4b26c2775..ddeba8ea24 100644 --- a/flow/activities/flowable.go +++ b/flow/activities/flowable.go @@ -2,6 +2,7 @@ package activities import ( "context" + "database/sql" "errors" "fmt" "sync" @@ -12,14 +13,17 @@ import ( connpostgres "github.com/PeerDB-io/peer-flow/connectors/postgres" connsnowflake "github.com/PeerDB-io/peer-flow/connectors/snowflake" "github.com/PeerDB-io/peer-flow/connectors/utils" + catalog "github.com/PeerDB-io/peer-flow/connectors/utils/catalog" "github.com/PeerDB-io/peer-flow/connectors/utils/monitoring" "github.com/PeerDB-io/peer-flow/generated/protos" "github.com/PeerDB-io/peer-flow/model" "github.com/PeerDB-io/peer-flow/shared" "github.com/jackc/pglogrepl" + "github.com/jackc/pgx/v5" log "github.com/sirupsen/logrus" "go.temporal.io/sdk/activity" "golang.org/x/sync/errgroup" + "google.golang.org/protobuf/proto" ) // CheckConnectionResult is the result of a CheckConnection call. @@ -659,31 +663,89 @@ func (a *FlowableActivity) DropFlow(ctx context.Context, config *protos.Shutdown return nil } -func (a *FlowableActivity) SendWALHeartbeat(ctx context.Context, config *protos.FlowConnectionConfigs) error { - srcConn, err := connectors.GetCDCPullConnector(ctx, config.Source) +func getPostgresPeerConfigs(ctx context.Context) ([]*protos.Peer, error) { + var peerOptions sql.RawBytes + catalogPool, catalogErr := catalog.GetCatalogConnectionPoolFromEnv() + if catalogErr != nil { + return nil, fmt.Errorf("error getting catalog connection pool: %w", catalogErr) + } + defer catalogPool.Close() + + optionRows, err := catalogPool.Query(ctx, + "SELECT name, options FROM peers WHERE type=$1", protos.DBType_POSTGRES) if err != nil { - return fmt.Errorf("failed to get destination connector: %w", err) + return nil, err } - defer connectors.CloseConnector(srcConn) - log.WithFields(log.Fields{"flowName": config.FlowJobName}).Info("sending walheartbeat every 10 minutes") - ticker := time.NewTicker(10 * time.Minute) + defer optionRows.Close() + var peerName string + var postgresPeers []*protos.Peer + for optionRows.Next() { + err := optionRows.Scan(&peerName, &peerOptions) + if err != nil { + return nil, err + } + var pgPeerConfig protos.PostgresConfig + unmarshalErr := proto.Unmarshal(peerOptions, &pgPeerConfig) + if unmarshalErr != nil { + return nil, unmarshalErr + } + postgresPeers = append(postgresPeers, &protos.Peer{ + Name: peerName, + Type: protos.DBType_POSTGRES, + Config: &protos.Peer_PostgresConfig{PostgresConfig: &pgPeerConfig}, + }) + } + return postgresPeers, nil +} + +func (a *FlowableActivity) SendWALHeartbeat(ctx context.Context) error { + sendTimeout := 10 * time.Minute + ticker := time.NewTicker(sendTimeout) + defer ticker.Stop() + + pgPeers, err := getPostgresPeerConfigs(ctx) + if err != nil { + return fmt.Errorf("error getting postgres peer configs: %w", err) + } + + activity.RecordHeartbeat(ctx, "sending walheartbeat every 10 minutes") for { select { case <-ctx.Done(): - log.WithFields( - log.Fields{ - "flowName": config.FlowJobName, - }).Info("context is done, exiting wal heartbeat send loop") + log.Info("context is done, exiting wal heartbeat send loop") return nil case <-ticker.C: - err = srcConn.SendWALHeartbeat() - if err != nil { - return fmt.Errorf("failed to send WAL heartbeat: %w", err) + command := ` + BEGIN; + DROP aggregate IF EXISTS PEERDB_EPHEMERAL_HEARTBEAT(float4); + CREATE AGGREGATE PEERDB_EPHEMERAL_HEARTBEAT(float4) (SFUNC = float4pl, STYPE = float4); + DROP aggregate PEERDB_EPHEMERAL_HEARTBEAT(float4); + END; + ` + // run above command for each Postgres peer + for _, pgPeer := range pgPeers { + pgConfig := pgPeer.GetPostgresConfig() + peerConn, peerErr := pgx.Connect(ctx, utils.GetPGConnectionString(pgConfig)) + if peerErr != nil { + return fmt.Errorf("error creating pool for postgres peer %v with host %v: %w", + pgPeer.Name, pgConfig.Host, peerErr) + } + + _, err := peerConn.Exec(ctx, command) + if err != nil { + log.Warnf("warning: could not send walheartbeat to peer %v: %v", pgPeer.Name, err) + } + + closeErr := peerConn.Close(ctx) + if closeErr != nil { + return fmt.Errorf("error closing postgres connection for peer %v with host %v: %w", + pgPeer.Name, pgConfig.Host, closeErr) + } + log.Infof("sent walheartbeat to peer %v", pgPeer.Name) } - log.WithFields( - log.Fields{ - "flowName": config.FlowJobName, - }).Info("sent wal heartbeat") + ticker.Stop() + ticker = time.NewTicker(sendTimeout) + } } } diff --git a/flow/cmd/api.go b/flow/cmd/api.go index 98bb798a63..67273be8a7 100644 --- a/flow/cmd/api.go +++ b/flow/cmd/api.go @@ -10,12 +10,16 @@ import ( utils "github.com/PeerDB-io/peer-flow/connectors/utils/catalog" "github.com/PeerDB-io/peer-flow/generated/protos" + "github.com/PeerDB-io/peer-flow/shared" + peerflow "github.com/PeerDB-io/peer-flow/workflows" + "github.com/google/uuid" "github.com/grpc-ecosystem/grpc-gateway/v2/runtime" log "github.com/sirupsen/logrus" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/reflection" + "go.temporal.io/api/workflowservice/v1" "go.temporal.io/sdk/client" "google.golang.org/grpc/health" "google.golang.org/grpc/health/grpc_health_v1" @@ -58,6 +62,27 @@ func setupGRPCGatewayServer(args *APIServerParams) (*http.Server, error) { return server, nil } +func killExistingHeartbeatFlows(ctx context.Context, tc client.Client, namespace string) error { + listRes, err := tc.ListWorkflow(ctx, + &workflowservice.ListWorkflowExecutionsRequest{ + Namespace: namespace, + Query: "WorkflowType = 'HeartbeatFlowWorkflow'", + }) + if err != nil { + return fmt.Errorf("unable to list workflows: %w", err) + } + log.Info("Requesting cancellation of pre-existing heartbeat flows") + for _, workflow := range listRes.Executions { + log.Info("Cancelling workflow: ", workflow.Execution.WorkflowId) + err := tc.CancelWorkflow(ctx, + workflow.Execution.WorkflowId, workflow.Execution.RunId) + if err != nil && err.Error() != "workflow execution already completed" { + return fmt.Errorf("unable to terminate workflow: %w", err) + } + } + return nil +} + func APIMain(args *APIServerParams) error { ctx := args.ctx clientOptions := client.Options{ @@ -91,6 +116,26 @@ func APIMain(args *APIServerParams) error { flowHandler := NewFlowRequestHandler(tc, catalogConn) defer flowHandler.Close() + err = killExistingHeartbeatFlows(ctx, tc, args.TemporalNamespace) + if err != nil { + return fmt.Errorf("unable to kill existing heartbeat flows: %w", err) + } + + workflowID := fmt.Sprintf("heartbeatflow-%s", uuid.New()) + workflowOptions := client.StartWorkflowOptions{ + ID: workflowID, + TaskQueue: shared.PeerFlowTaskQueue, + } + + _, err = flowHandler.temporalClient.ExecuteWorkflow( + ctx, // context + workflowOptions, // workflow start options + peerflow.HeartbeatFlowWorkflow, // workflow function + ) + if err != nil { + return fmt.Errorf("unable to start heartbeat workflow: %w", err) + } + protos.RegisterFlowServiceServer(grpcServer, flowHandler) grpc_health_v1.RegisterHealthServer(grpcServer, health.NewServer()) reflection.Register(grpcServer) diff --git a/flow/cmd/worker.go b/flow/cmd/worker.go index 0a0901da99..240af31387 100644 --- a/flow/cmd/worker.go +++ b/flow/cmd/worker.go @@ -126,6 +126,7 @@ func WorkerMain(opts *WorkerOptions) error { w.RegisterWorkflow(peerflow.QRepFlowWorkflow) w.RegisterWorkflow(peerflow.QRepPartitionWorkflow) w.RegisterWorkflow(peerflow.DropFlowWorkflow) + w.RegisterWorkflow(peerflow.HeartbeatFlowWorkflow) w.RegisterActivity(&activities.FlowableActivity{ CatalogMirrorMonitor: catalogMirrorMonitor, }) diff --git a/flow/workflows/heartbeat_flow.go b/flow/workflows/heartbeat_flow.go new file mode 100644 index 0000000000..80e89745d9 --- /dev/null +++ b/flow/workflows/heartbeat_flow.go @@ -0,0 +1,22 @@ +package peerflow + +import ( + "time" + + "go.temporal.io/sdk/workflow" +) + +// HeartbeatFlowWorkflow is the workflow that sets up heartbeat sending. +func HeartbeatFlowWorkflow(ctx workflow.Context) error { + + ctx = workflow.WithActivityOptions(ctx, workflow.ActivityOptions{ + StartToCloseTimeout: 7 * 24 * time.Hour, + }) + + heartbeatFuture := workflow.ExecuteActivity(ctx, flowable.SendWALHeartbeat) + if err := heartbeatFuture.Get(ctx, nil); err != nil { + return err + } + + return nil +} diff --git a/ui/app/mirrors/create/schema.ts b/ui/app/mirrors/create/schema.ts index 3400ba459d..19f2eaeb28 100644 --- a/ui/app/mirrors/create/schema.ts +++ b/ui/app/mirrors/create/schema.ts @@ -58,24 +58,24 @@ export const cdcSchema = z.object({ .optional(), snapshotNumRowsPerPartition: z .number({ - invalid_type_error: 'Snapshow rows per partition must be a number', + invalid_type_error: 'Snapshot rows per partition must be a number', }) .int() - .min(1, 'Snapshow rows per partition must be a positive integer') + .min(1, 'Snapshot rows per partition must be a positive integer') .optional(), snapshotMaxParallelWorkers: z .number({ - invalid_type_error: 'Snapshow max workers must be a number', + invalid_type_error: 'Snapshot max workers must be a number', }) .int() - .min(1, 'Snapshow max workers must be a positive integer') + .min(1, 'Snapshot max workers must be a positive integer') .optional(), snapshotNumTablesInParallel: z .number({ - invalid_type_error: 'Snapshow parallel tables must be a number', + invalid_type_error: 'Snapshot parallel tables must be a number', }) .int() - .min(1, 'Snapshow parallel tables must be a positive integer') + .min(1, 'Snapshot parallel tables must be a positive integer') .optional(), snapshotStagingPath: z .string({