diff --git a/flow/activities/flowable.go b/flow/activities/flowable.go index 65f78e3b35..767cb2f872 100644 --- a/flow/activities/flowable.go +++ b/flow/activities/flowable.go @@ -42,6 +42,8 @@ type SlotSnapshotSignal struct { type FlowableActivity struct { CatalogPool *pgxpool.Pool Alerter *alerting.Alerter + CdcCacheRw sync.RWMutex + CdcCache map[string]connectors.CDCPullConnector } func (a *FlowableActivity) CheckConnection( @@ -204,7 +206,68 @@ func (a *FlowableActivity) CreateNormalizedTable( }, nil } -func (a *FlowableActivity) StartFlow(ctx context.Context, +func (a *FlowableActivity) MaintainPull( + ctx context.Context, + config *protos.FlowConnectionConfigs, + sessionID string, +) error { + srcConn, err := connectors.GetCDCPullConnector(ctx, config.Source) + if err != nil { + return err + } + defer connectors.CloseConnector(ctx, srcConn) + + if err := srcConn.SetupReplConn(ctx); err != nil { + return err + } + + a.CdcCacheRw.Lock() + a.CdcCache[sessionID] = srcConn + a.CdcCacheRw.Unlock() + + ticker := time.NewTicker(15 * time.Second) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + activity.RecordHeartbeat(ctx, "keep session alive") + if err := srcConn.ReplPing(ctx); err != nil { + activity.GetLogger(ctx).Error("Failed to send keep alive ping to replication connection", slog.Any("error", err)) + } + case <-ctx.Done(): + a.CdcCacheRw.Lock() + delete(a.CdcCache, sessionID) + a.CdcCacheRw.Unlock() + return nil + } + } +} + +func (a *FlowableActivity) WaitForSourceConnector(ctx context.Context, sessionID string) error { + logger := activity.GetLogger(ctx) + attempt := 0 + for { + a.CdcCacheRw.RLock() + _, ok := a.CdcCache[sessionID] + a.CdcCacheRw.RUnlock() + if ok { + return nil + } + activity.RecordHeartbeat(ctx, "wait another second for source connector") + attempt += 1 + if attempt > 2 { + logger.Info("waiting on source connector setup", slog.Int("attempt", attempt)) + } + if err := ctx.Err(); err != nil { + return err + } + 
time.Sleep(time.Second) + } +} + +func (a *FlowableActivity) StartFlow( + ctx context.Context, input *protos.StartFlowInput, ) (*model.SyncResponse, error) { ctx = context.WithValue(ctx, shared.FlowNameKey, input.FlowConnectionConfigs.FlowJobName) @@ -223,11 +286,15 @@ func (a *FlowableActivity) StartFlow(ctx context.Context, tblNameMapping[v.SourceTableIdentifier] = model.NewNameAndExclude(v.DestinationTableIdentifier, v.Exclude) } - srcConn, err := connectors.GetCDCPullConnector(ctx, config.Source) - if err != nil { - return nil, fmt.Errorf("failed to get source connector: %w", err) + a.CdcCacheRw.RLock() + srcConn, ok := a.CdcCache[input.SessionId] + a.CdcCacheRw.RUnlock() + if !ok { + return nil, errors.New("source connector missing from CdcCache") + } + if err := srcConn.ConnectionActive(ctx); err != nil { + return nil, err } - defer connectors.CloseConnector(ctx, srcConn) shutdown := utils.HeartbeatRoutine(ctx, func() string { jobName := input.FlowConnectionConfigs.FlowJobName @@ -235,22 +302,22 @@ func (a *FlowableActivity) StartFlow(ctx context.Context, }) defer shutdown() - errGroup, errCtx := errgroup.WithContext(ctx) - batchSize := input.SyncFlowOptions.BatchSize if batchSize <= 0 { batchSize = 1_000_000 } - lastOffset, err := dstConn.GetLastOffset(ctx, input.FlowConnectionConfigs.FlowJobName) - if err != nil { - return nil, err - } - // start a goroutine to pull records from the source recordBatch := model.NewCDCRecordStream() startTime := time.Now() flowName := input.FlowConnectionConfigs.FlowJobName + + lastOffset, err := dstConn.GetLastOffset(ctx, flowName) + if err != nil { + return nil, err + } + + errGroup, errCtx := errgroup.WithContext(ctx) errGroup.Go(func() error { return srcConn.PullRecords(errCtx, a.CatalogPool, &model.PullRecordsRequest{ FlowJobName: flowName, diff --git a/flow/cmd/worker.go b/flow/cmd/worker.go index 753ef14f14..ee9218a9da 100644 --- a/flow/cmd/worker.go +++ b/flow/cmd/worker.go @@ -16,6 +16,7 @@ import ( 
"go.temporal.io/sdk/worker" "github.com/PeerDB-io/peer-flow/activities" + "github.com/PeerDB-io/peer-flow/connectors" utils "github.com/PeerDB-io/peer-flow/connectors/utils/catalog" "github.com/PeerDB-io/peer-flow/logger" "github.com/PeerDB-io/peer-flow/shared" @@ -127,7 +128,9 @@ func WorkerMain(opts *WorkerOptions) error { return queueErr } - w := worker.New(c, taskQueue, worker.Options{}) + w := worker.New(c, taskQueue, worker.Options{ + EnableSessionWorker: true, + }) peerflow.RegisterFlowWorkerWorkflows(w) alerter, err := alerting.NewAlerter(conn) @@ -138,6 +141,7 @@ func WorkerMain(opts *WorkerOptions) error { w.RegisterActivity(&activities.FlowableActivity{ CatalogPool: conn, Alerter: alerter, + CdcCache: make(map[string]connectors.CDCPullConnector), }) err = w.Run(worker.InterruptCh()) diff --git a/flow/connectors/core.go b/flow/connectors/core.go index 889250d804..b166efc745 100644 --- a/flow/connectors/core.go +++ b/flow/connectors/core.go @@ -39,6 +39,10 @@ type CDCPullConnector interface { *protos.EnsurePullabilityBatchOutput, error) // Methods related to retrieving and pushing records for this connector as a source and destination. + SetupReplConn(context.Context) error + + // Ping source to keep connection alive. Can be called concurrently with PullRecords; skips ping in that case. + ReplPing(context.Context) error // PullRecords pulls records from the source, and returns a RecordBatch. // This method should be idempotent, and should be able to be called multiple times with the same request. 
diff --git a/flow/connectors/postgres/cdc.go b/flow/connectors/postgres/cdc.go index 774c57a0ca..cacc14d3fe 100644 --- a/flow/connectors/postgres/cdc.go +++ b/flow/connectors/postgres/cdc.go @@ -3,7 +3,6 @@ package connpostgres import ( "context" "crypto/sha256" - "errors" "fmt" "log/slog" "time" @@ -28,7 +27,6 @@ import ( type PostgresCDCSource struct { *PostgresConnector - replConn *pgx.Conn SrcTableIDNameMapping map[uint32]string TableNameMapping map[string]model.NameAndExclude slot string @@ -64,29 +62,28 @@ type startReplicationOpts struct { } // Create a new PostgresCDCSource -func (c *PostgresConnector) NewPostgresCDCSource(ctx context.Context, cdcConfig *PostgresCDCConfig) *PostgresCDCSource { +func (c *PostgresConnector) NewPostgresCDCSource(cdcConfig *PostgresCDCConfig) *PostgresCDCSource { return &PostgresCDCSource{ PostgresConnector: c, - replConn: cdcConfig.Connection, SrcTableIDNameMapping: cdcConfig.SrcTableIDNameMapping, TableNameMapping: cdcConfig.TableNameMapping, slot: cdcConfig.Slot, publication: cdcConfig.Publication, relationMessageMapping: cdcConfig.RelationMessageMapping, - typeMap: pgtype.NewMap(), childToParentRelIDMapping: cdcConfig.ChildToParentRelIDMap, + typeMap: pgtype.NewMap(), commitLock: false, catalogPool: cdcConfig.CatalogPool, flowJobName: cdcConfig.FlowJobName, } } -func getChildToParentRelIDMap(ctx context.Context, conn *pgx.Conn) (map[uint32]uint32, error) { +func GetChildToParentRelIDMap(ctx context.Context, conn *pgx.Conn) (map[uint32]uint32, error) { query := ` SELECT parent.oid AS parentrelid, child.oid AS childrelid FROM pg_inherits JOIN pg_class parent ON pg_inherits.inhparent = parent.oid - JOIN pg_class child ON pg_inherits.inhrelid = child.oid + JOIN pg_class child ON pg_inherits.inhrelid = child.oid WHERE parent.relkind='p'; ` @@ -94,7 +91,6 @@ func getChildToParentRelIDMap(ctx context.Context, conn *pgx.Conn) (map[uint32]u if err != nil { return nil, fmt.Errorf("error querying for child to parent relid map: %w", 
err) } - defer rows.Close() childToParentRelIDMap := make(map[uint32]uint32) @@ -113,85 +109,14 @@ func getChildToParentRelIDMap(ctx context.Context, conn *pgx.Conn) (map[uint32]u // PullRecords pulls records from the cdc stream func (p *PostgresCDCSource) PullRecords(ctx context.Context, req *model.PullRecordsRequest) error { - replicationOpts, err := p.replicationOptions() - if err != nil { - return fmt.Errorf("error getting replication options: %w", err) - } - - pgConn := p.replConn.PgConn() - - // start replication - var clientXLogPos, startLSN pglogrepl.LSN - if req.LastOffset > 0 { - p.logger.Info("starting replication from last sync state", slog.Int64("last checkpoint", req.LastOffset)) - clientXLogPos = pglogrepl.LSN(req.LastOffset) - startLSN = clientXLogPos + 1 - } - - opts := startReplicationOpts{ - conn: pgConn, - startLSN: startLSN, - replicationOpts: *replicationOpts, - } - - err = p.startReplication(ctx, opts) - if err != nil { - return fmt.Errorf("error starting replication: %w", err) - } - - p.logger.Info(fmt.Sprintf("started replication on slot %s at startLSN: %d", p.slot, startLSN)) - - return p.consumeStream(ctx, pgConn, req, clientXLogPos, req.RecordStream) -} - -func (p *PostgresCDCSource) startReplication(ctx context.Context, opts startReplicationOpts) error { - err := pglogrepl.StartReplication(ctx, opts.conn, p.slot, opts.startLSN, opts.replicationOpts) - if err != nil { - p.logger.Error("error starting replication", slog.Any("error", err)) - return fmt.Errorf("error starting replication at startLsn - %d: %w", opts.startLSN, err) - } - - p.logger.Info(fmt.Sprintf("started replication on slot %s at startLSN: %d", p.slot, opts.startLSN)) - return nil -} - -func (p *PostgresCDCSource) replicationOptions() (*pglogrepl.StartReplicationOptions, error) { - pluginArguments := []string{ - "proto_version '1'", - } - - if p.publication != "" { - pubOpt := fmt.Sprintf("publication_names '%s'", p.publication) - pluginArguments = append(pluginArguments, 
pubOpt) - } else { - return nil, errors.New("publication name is not set") - } - - return &pglogrepl.StartReplicationOptions{PluginArgs: pluginArguments}, nil -} - -// start consuming the cdc stream -func (p *PostgresCDCSource) consumeStream( - ctx context.Context, - conn *pgconn.PgConn, - req *model.PullRecordsRequest, - clientXLogPos pglogrepl.LSN, - records *model.CDCRecordStream, -) error { - defer func() { - timeout, cancel := context.WithTimeout(context.Background(), 1*time.Minute) - err := conn.Close(timeout) - if err != nil { - p.logger.Error("error closing replication connection", slog.Any("error", err)) - } - cancel() - }() - + conn := p.replConn.PgConn() + records := req.RecordStream // clientXLogPos is the last checkpoint id, we need to ack that we have processed // until clientXLogPos each time we send a standby status update. // consumedXLogPos is the lsn that has been committed on the destination. - consumedXLogPos := pglogrepl.LSN(0) - if clientXLogPos > 0 { + var clientXLogPos, consumedXLogPos pglogrepl.LSN + if req.LastOffset > 0 { + clientXLogPos = pglogrepl.LSN(req.LastOffset) consumedXLogPos = clientXLogPos err := pglogrepl.SendStandbyStatusUpdate(ctx, conn, @@ -300,7 +225,6 @@ func (p *PostgresCDCSource) consumeStream( var receiveCtx context.Context var cancel context.CancelFunc - if cdcRecordsStorage.IsEmpty() { receiveCtx, cancel = context.WithCancel(ctx) } else { diff --git a/flow/connectors/postgres/postgres.go b/flow/connectors/postgres/postgres.go index 052b148fef..1d33c7489b 100644 --- a/flow/connectors/postgres/postgres.go +++ b/flow/connectors/postgres/postgres.go @@ -7,9 +7,11 @@ import ( "log/slog" "regexp" "strings" + "sync" "time" "github.com/google/uuid" + "github.com/jackc/pglogrepl" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgconn" "github.com/jackc/pgx/v5/pgtype" @@ -31,12 +33,21 @@ type PostgresConnector struct { ssh *SSHTunnel conn *pgx.Conn replConfig *pgx.ConnConfig + replConn *pgx.Conn + replState *ReplState + 
replLock sync.Mutex customTypesMapping map[uint32]string metadataSchema string hushWarnOID map[uint32]struct{} logger log.Logger } +type ReplState struct { + Slot string + Publication string + Offset int64 +} + func NewPostgresConnector(ctx context.Context, pgConfig *protos.PostgresConfig) (*PostgresConnector, error) { connectionString := utils.GetPGConnectionString(pgConfig) @@ -82,6 +93,8 @@ func NewPostgresConnector(ctx context.Context, pgConfig *protos.PostgresConfig) ssh: tunnel, conn: conn, replConfig: replConfig, + replState: nil, + replLock: sync.Mutex{}, customTypesMapping: customTypeMap, metadataSchema: metadataSchema, hushWarnOID: make(map[uint32]struct{}), @@ -95,19 +108,124 @@ func (c *PostgresConnector) CreateReplConn(ctx context.Context) (*pgx.Conn, erro logger.LoggerFromCtx(ctx).Error("failed to create replication connection", "error", err) return nil, fmt.Errorf("failed to create replication connection: %w", err) } - return conn, nil } +func (c *PostgresConnector) SetupReplConn(ctx context.Context) error { + conn, err := c.CreateReplConn(ctx) + if err != nil { + return err + } + c.replConn = conn + return nil +} + +// To keep connection alive between sync batches. +// By default postgres drops connection after 1 minute of inactivity. +func (c *PostgresConnector) ReplPing(ctx context.Context) error { + if c.replLock.TryLock() { + defer c.replLock.Unlock() + if c.replState != nil { + return pglogrepl.SendStandbyStatusUpdate( + ctx, + c.replConn.PgConn(), + pglogrepl.StandbyStatusUpdate{WALWritePosition: pglogrepl.LSN(c.replState.Offset)}, + ) + } + } + return nil +} + +func (c *PostgresConnector) MaybeStartReplication( + ctx context.Context, + slotName string, + publicationName string, + req *model.PullRecordsRequest, +) error { + if c.replState != nil && (c.replState.Offset != req.LastOffset || + c.replState.Slot != slotName || + c.replState.Publication != publicationName) { + return fmt.Errorf("replState changed, reset connector. 
slot name: old=%s new=%s, publication: old=%s new=%s, offset: old=%d new=%d", + c.replState.Slot, slotName, c.replState.Publication, publicationName, c.replState.Offset, req.LastOffset, + ) + } + + if c.replState == nil { + replicationOpts, err := c.replicationOptions(publicationName) + if err != nil { + return fmt.Errorf("error getting replication options: %w", err) + } + + var startLSN pglogrepl.LSN + if req.LastOffset > 0 { + c.logger.Info("starting replication from last sync state", slog.Int64("last checkpoint", req.LastOffset)) + startLSN = pglogrepl.LSN(req.LastOffset + 1) + } + + opts := startReplicationOpts{ + conn: c.replConn.PgConn(), + startLSN: startLSN, + replicationOpts: *replicationOpts, + } + + err = c.startReplication(ctx, slotName, opts) + if err != nil { + return fmt.Errorf("error starting replication: %w", err) + } + + c.logger.Info(fmt.Sprintf("started replication on slot %s at startLSN: %d", slotName, startLSN)) + c.replState = &ReplState{ + Slot: slotName, + Publication: publicationName, + Offset: req.LastOffset, + } + } + return nil +} + +func (c *PostgresConnector) startReplication(ctx context.Context, slotName string, opts startReplicationOpts) error { + err := pglogrepl.StartReplication(ctx, opts.conn, slotName, opts.startLSN, opts.replicationOpts) + if err != nil { + c.logger.Error("error starting replication", slog.Any("error", err)) + return fmt.Errorf("error starting replication at startLsn - %d: %w", opts.startLSN, err) + } + + c.logger.Info(fmt.Sprintf("started replication on slot %s at startLSN: %d", slotName, opts.startLSN)) + return nil +} + +func (c *PostgresConnector) replicationOptions(publicationName string) (*pglogrepl.StartReplicationOptions, error) { + pluginArguments := []string{ + "proto_version '1'", + } + + if publicationName != "" { + pubOpt := fmt.Sprintf("publication_names %s", QuoteLiteral(publicationName)) + pluginArguments = append(pluginArguments, pubOpt) + } else { + return nil, errors.New("publication name is 
not set") + } + + return &pglogrepl.StartReplicationOptions{PluginArgs: pluginArguments}, nil +} + // Close closes all connections. func (c *PostgresConnector) Close() error { + var connerr, replerr error if c != nil { - timeout, cancel := context.WithTimeout(context.Background(), time.Minute) + timeout, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() - c.conn.Close(timeout) + connerr = c.conn.Close(timeout) + + if c.replConn != nil { + timeout, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + replerr = c.replConn.Close(timeout) + } + c.ssh.Close() } - return nil + return errors.Join(connerr, replerr) } func (c *PostgresConnector) Conn() *pgx.Conn { @@ -215,19 +333,21 @@ func (c *PostgresConnector) PullRecords(ctx context.Context, catalogPool *pgxpoo c.logger.Info("PullRecords: performed checks for slot and publication") - replConn, err := c.CreateReplConn(ctx) + childToParentRelIDMap, err := GetChildToParentRelIDMap(ctx, c.conn) if err != nil { - return err + return fmt.Errorf("error getting child to parent relid map: %w", err) } - defer replConn.Close(ctx) - childToParentRelIDMap, err := getChildToParentRelIDMap(ctx, replConn) + c.replLock.Lock() + defer c.replLock.Unlock() + + err = c.MaybeStartReplication(ctx, slotName, publicationName, req) if err != nil { - return fmt.Errorf("error getting child to parent relid map: %w", err) + return err } - cdc := c.NewPostgresCDCSource(ctx, &PostgresCDCConfig{ - Connection: replConn, + cdc := c.NewPostgresCDCSource(&PostgresCDCConfig{ + Connection: c.replConn, SrcTableIDNameMapping: req.SrcTableIDNameMapping, Slot: slotName, Publication: publicationName, @@ -243,10 +363,14 @@ func (c *PostgresConnector) PullRecords(ctx context.Context, catalogPool *pgxpoo return err } + req.RecordStream.Close() + c.replState.Offset = req.RecordStream.GetLastCheckpoint() + latestLSN, err := c.getCurrentLSN(ctx) if err != nil { return fmt.Errorf("failed to get current 
LSN: %w", err) } + err = monitoring.UpdateLatestLSNAtSourceForCDCFlow(ctx, catalogPool, req.FlowJobName, int64(latestLSN)) if err != nil { return fmt.Errorf("failed to update latest LSN at source for CDC flow: %w", err) @@ -922,7 +1046,6 @@ func (c *PostgresConnector) HandleSlotInfo( return monitoring.AppendSlotSizeInfo(ctx, catalogPool, peerName, slotInfo[0]) } -// GetLastOffset returns the last synced offset for a job. func getOpenConnectionsForUser(ctx context.Context, conn *pgx.Conn, user string) (*protos.GetOpenConnectionsForUserResult, error) { row := conn.QueryRow(ctx, getNumConnectionsForUser, user) diff --git a/flow/e2e/test_utils.go b/flow/e2e/test_utils.go index 602ace6b8f..7c2b2a8968 100644 --- a/flow/e2e/test_utils.go +++ b/flow/e2e/test_utils.go @@ -20,8 +20,10 @@ import ( "github.com/stretchr/testify/require" "go.temporal.io/sdk/temporal" "go.temporal.io/sdk/testsuite" + "go.temporal.io/sdk/worker" "github.com/PeerDB-io/peer-flow/activities" + "github.com/PeerDB-io/peer-flow/connectors" connpostgres "github.com/PeerDB-io/peer-flow/connectors/postgres" connsnowflake "github.com/PeerDB-io/peer-flow/connectors/snowflake" "github.com/PeerDB-io/peer-flow/connectors/utils" @@ -67,6 +69,7 @@ func RegisterWorkflowsAndActivities(t *testing.T, env *testsuite.TestWorkflowEnv env.RegisterActivity(&activities.FlowableActivity{ CatalogPool: conn, Alerter: alerter, + CdcCache: make(map[string]connectors.CDCPullConnector), }) env.RegisterActivity(&activities.SnapshotActivity{ Alerter: alerter, @@ -543,10 +546,10 @@ func NewTemporalTestWorkflowEnvironment(t *testing.T) *testsuite.TestWorkflowEnv &slog.HandlerOptions{Level: slog.LevelWarn}, ), )) - tLogger := TStructuredLogger{logger: logger} + testSuite.SetLogger(&TStructuredLogger{logger: logger}) - testSuite.SetLogger(&tLogger) env := testSuite.NewTestWorkflowEnvironment() + env.SetWorkerOptions(worker.Options{EnableSessionWorker: true}) RegisterWorkflowsAndActivities(t, env) 
env.RegisterWorkflow(peerflow.SnapshotFlowWorkflow) return env diff --git a/flow/workflows/cdc_flow.go b/flow/workflows/cdc_flow.go index afc17927ce..4e75d1fcbe 100644 --- a/flow/workflows/cdc_flow.go +++ b/flow/workflows/cdc_flow.go @@ -95,12 +95,12 @@ func (s *CDCFlowWorkflowState) TruncateProgress(logger log.Logger) { } if s.SyncFlowErrors != nil { - logger.Warn("SyncFlowErrors: ", s.SyncFlowErrors) + logger.Warn("SyncFlowErrors", slog.Any("errors", s.SyncFlowErrors)) s.SyncFlowErrors = nil } if s.NormalizeFlowErrors != nil { - logger.Warn("NormalizeFlowErrors: ", s.NormalizeFlowErrors) + logger.Warn("NormalizeFlowErrors", slog.Any("errors", s.NormalizeFlowErrors)) s.NormalizeFlowErrors = nil } } @@ -119,21 +119,24 @@ func NewCDCFlowWorkflowExecution(ctx workflow.Context) *CDCFlowWorkflowExecution } } -func GetChildWorkflowID( - ctx workflow.Context, - prefix string, - peerFlowName string, -) (string, error) { - childWorkflowIDSideEffect := workflow.SideEffect(ctx, func(ctx workflow.Context) interface{} { - return fmt.Sprintf("%s-%s-%s", prefix, peerFlowName, uuid.New().String()) +func GetUUID(ctx workflow.Context) (string, error) { + uuidSideEffect := workflow.SideEffect(ctx, func(ctx workflow.Context) interface{} { + return uuid.New().String() }) - var childWorkflowID string - if err := childWorkflowIDSideEffect.Get(&childWorkflowID); err != nil { - return "", fmt.Errorf("failed to get child workflow ID: %w", err) + var uuidString string + if err := uuidSideEffect.Get(&uuidString); err != nil { + return "", fmt.Errorf("failed to generate UUID: %w", err) } + return uuidString, nil +} - return childWorkflowID, nil +func GetChildWorkflowID( + prefix string, + peerFlowName string, + uuid string, +) string { + return fmt.Sprintf("%s-%s-%s", prefix, peerFlowName, uuid) } // CDCFlowWorkflowResult is the result of the PeerFlowWorkflow. 
@@ -164,16 +167,16 @@ func (w *CDCFlowWorkflowExecution) processCDCFlowConfigUpdates(ctx workflow.Cont return err } - additionalTablesWorkflowCfg := proto.Clone(cfg).(*protos.FlowConnectionConfigs) - additionalTablesWorkflowCfg.DoInitialSnapshot = true - additionalTablesWorkflowCfg.InitialSnapshotOnly = true - additionalTablesWorkflowCfg.TableMappings = flowConfigUpdate.AdditionalTables + additionalTablesCfg := proto.Clone(cfg).(*protos.FlowConnectionConfigs) + additionalTablesCfg.DoInitialSnapshot = true + additionalTablesCfg.InitialSnapshotOnly = true + additionalTablesCfg.TableMappings = flowConfigUpdate.AdditionalTables - childAdditionalTablesCDCFlowID, - err := GetChildWorkflowID(ctx, "additional-cdc-flow", additionalTablesWorkflowCfg.FlowJobName) + additionalTablesUUID, err := GetUUID(ctx) if err != nil { return err } + childAdditionalTablesCDCFlowID := GetChildWorkflowID("additional-cdc-flow", additionalTablesCfg.FlowJobName, additionalTablesUUID) // execute the sync flow as a child workflow childAdditionalTablesCDCFlowOpts := workflow.ChildWorkflowOptions{ @@ -189,7 +192,7 @@ func (w *CDCFlowWorkflowExecution) processCDCFlowConfigUpdates(ctx workflow.Cont childAdditionalTablesCDCFlowFuture := workflow.ExecuteChildWorkflow( childAdditionalTablesCDCFlowCtx, CDCFlowWorkflowWithConfig, - additionalTablesWorkflowCfg, + additionalTablesCfg, nil, ) var res *CDCFlowWorkflowResult @@ -248,6 +251,8 @@ func CDCFlowWorkflowWithConfig( shared.MirrorNameSearchAttribute: cfg.FlowJobName, } + originalRunID := workflow.GetInfo(ctx).OriginalRunID + // we cannot skip SetupFlow if SnapshotFlow did not complete in cases where Resync is enabled // because Resync modifies TableMappings before Setup and also before Snapshot // for safety, rely on the idempotency of SetupFlow instead @@ -268,10 +273,8 @@ func CDCFlowWorkflowWithConfig( // start the SetupFlow workflow as a child workflow, and wait for it to complete // it should return the table schema for the source peer - 
setupFlowID, err := GetChildWorkflowID(ctx, "setup-flow", cfg.FlowJobName) - if err != nil { - return state, err - } + setupFlowID := GetChildWorkflowID("setup-flow", cfg.FlowJobName, originalRunID) + childSetupFlowOpts := workflow.ChildWorkflowOptions{ WorkflowID: setupFlowID, ParentClosePolicy: enums.PARENT_CLOSE_POLICY_REQUEST_CANCEL, @@ -292,10 +295,7 @@ func CDCFlowWorkflowWithConfig( state.CurrentFlowStatus = protos.FlowStatus_STATUS_SNAPSHOT // next part of the setup is to snapshot-initial-copy and setup replication slots. - snapshotFlowID, err := GetChildWorkflowID(ctx, "snapshot-flow", cfg.FlowJobName) - if err != nil { - return state, err - } + snapshotFlowID := GetChildWorkflowID("snapshot-flow", cfg.FlowJobName, originalRunID) taskQueue, err := shared.GetPeerFlowTaskQueueName(shared.SnapshotFlowTaskQueueID) if err != nil { @@ -361,6 +361,55 @@ func CDCFlowWorkflowWithConfig( } } + sessionOptions := &workflow.SessionOptions{ + CreationTimeout: 5 * time.Minute, + ExecutionTimeout: 144 * time.Hour, + HeartbeatTimeout: time.Minute, + } + syncSessionCtx, err := workflow.CreateSession(ctx, sessionOptions) + if err != nil { + return nil, err + } + defer workflow.CompleteSession(syncSessionCtx) + sessionInfo := workflow.GetSessionInfo(syncSessionCtx) + + syncCtx := workflow.WithActivityOptions(syncSessionCtx, workflow.ActivityOptions{ + StartToCloseTimeout: 72 * time.Hour, + HeartbeatTimeout: time.Minute, + WaitForCancellation: true, + }) + fMaintain := workflow.ExecuteActivity( + syncCtx, + flowable.MaintainPull, + cfg, + sessionInfo.SessionID, + ) + fSessionSetup := workflow.ExecuteActivity( + syncCtx, + flowable.WaitForSourceConnector, + sessionInfo.SessionID, + ) + + var sessionError error + sessionSelector := workflow.NewNamedSelector(ctx, "Session Setup") + sessionSelector.AddFuture(fMaintain, func(f workflow.Future) { + // MaintainPull should never exit without an error before this point + sessionError = f.Get(syncCtx, nil) + }) + 
sessionSelector.AddFuture(fSessionSetup, func(f workflow.Future) { + // Happy path is waiting for this to return without error + sessionError = f.Get(syncCtx, nil) + }) + sessionSelector.AddReceive(ctx.Done(), func(_ workflow.ReceiveChannel, _ bool) { + sessionError = ctx.Err() + }) + sessionSelector.Select(ctx) + if sessionError != nil { + state.SyncFlowErrors = append(state.SyncFlowErrors, sessionError.Error()) + state.TruncateProgress(w.logger) + return state, workflow.NewContinueAsNewError(ctx, CDCFlowWorkflowWithConfig, cfg, state) + } + // when we carry forward state, don't remake the options if state.SyncFlowOptions == nil { state.SyncFlowOptions = &protos.SyncFlowOptions{ @@ -376,11 +425,7 @@ func CDCFlowWorkflowWithConfig( currentSyncFlowNum := 0 totalRecordsSynced := int64(0) - normalizeFlowID, err := GetChildWorkflowID(ctx, "normalize-flow", cfg.FlowJobName) - if err != nil { - return state, err - } - + normalizeFlowID := GetChildWorkflowID("normalize-flow", cfg.FlowJobName, originalRunID) childNormalizeFlowOpts := workflow.ChildWorkflowOptions{ WorkflowID: normalizeFlowID, ParentClosePolicy: enums.PARENT_CLOSE_POLICY_REQUEST_CANCEL, @@ -391,6 +436,7 @@ func CDCFlowWorkflowWithConfig( WaitForCancellation: true, } normCtx := workflow.WithChildOptions(ctx, childNormalizeFlowOpts) + childNormalizeFlowFuture := workflow.ExecuteChildWorkflow( normCtx, NormalizeFlowWorkflow, @@ -422,12 +468,12 @@ func CDCFlowWorkflowWithConfig( } var canceled bool - signalChan := workflow.GetSignalChannel(ctx, shared.FlowSignalName) - mainLoopSelector := workflow.NewSelector(ctx) + flowSignalChan := workflow.GetSignalChannel(ctx, shared.FlowSignalName) + mainLoopSelector := workflow.NewNamedSelector(ctx, "Main Loop") mainLoopSelector.AddReceive(ctx.Done(), func(_ workflow.ReceiveChannel, _ bool) { canceled = true }) - mainLoopSelector.AddReceive(signalChan, func(c workflow.ReceiveChannel, _ bool) { + mainLoopSelector.AddReceive(flowSignalChan, func(c workflow.ReceiveChannel, _ 
bool) { var signalVal shared.CDCFlowSignal c.ReceiveAsync(&signalVal) state.ActiveSignal = shared.FlowSignalHandler(state.ActiveSignal, signalVal, w.logger) @@ -491,43 +537,31 @@ func CDCFlowWorkflowWithConfig( " limit on the number of syncflows to be executed: ", currentSyncFlowNum) break } - currentSyncFlowNum++ - - syncFlowID, err := GetChildWorkflowID(ctx, "sync-flow", cfg.FlowJobName) - if err != nil { - finishNormalize() - return state, err - } + currentSyncFlowNum += 1 + w.logger.Info("executing sync flow", slog.Int("count", currentSyncFlowNum), slog.String("flowName", cfg.FlowJobName)) - // execute the sync flow as a child workflow - childSyncFlowOpts := workflow.ChildWorkflowOptions{ - WorkflowID: syncFlowID, - ParentClosePolicy: enums.PARENT_CLOSE_POLICY_REQUEST_CANCEL, - RetryPolicy: &temporal.RetryPolicy{ - MaximumAttempts: 20, - }, - SearchAttributes: mirrorNameSearch, - WaitForCancellation: true, + startFlowInput := &protos.StartFlowInput{ + FlowConnectionConfigs: cfg, + SyncFlowOptions: state.SyncFlowOptions, + RelationMessageMapping: state.RelationMessageMapping, + SrcTableIdNameMapping: state.SyncFlowOptions.SrcTableIdNameMapping, + TableNameSchemaMapping: state.SyncFlowOptions.TableNameSchemaMapping, + SessionId: sessionInfo.SessionID, } - syncCtx := workflow.WithChildOptions(ctx, childSyncFlowOpts) + fStartFlow := workflow.ExecuteActivity(syncCtx, flowable.StartFlow, startFlowInput) state.SyncFlowOptions.RelationMessageMapping = state.RelationMessageMapping - childSyncFlowFuture := workflow.ExecuteChildWorkflow( - syncCtx, - SyncFlowWorkflow, - cfg, - state.SyncFlowOptions, - ) - var syncDone bool + var syncDone, syncErr bool var normalizeSignalError error normDone := normWaitChan == nil - mainLoopSelector.AddFuture(childSyncFlowFuture, func(f workflow.Future) { + mainLoopSelector.AddFuture(fStartFlow, func(f workflow.Future) { syncDone = true var childSyncFlowRes *model.SyncResponse if err := f.Get(syncCtx, &childSyncFlowRes); err != nil { - 
w.logger.Error("failed to execute sync flow: ", err) + w.logger.Error("failed to execute sync flow", slog.Any("error", err)) state.SyncFlowErrors = append(state.SyncFlowErrors, err.Error()) + syncErr = true } else if childSyncFlowRes != nil { state.SyncFlowStatuses = append(state.SyncFlowStatuses, childSyncFlowRes) state.RelationMessageMapping = childSyncFlowRes.RelationMessageMapping @@ -587,8 +621,14 @@ func CDCFlowWorkflowWithConfig( if canceled { break } + if syncErr { + state.TruncateProgress(w.logger) + return state, workflow.NewContinueAsNewError(ctx, CDCFlowWorkflowWithConfig, cfg, state) + } if normalizeSignalError != nil { - return state, normalizeSignalError + state.NormalizeFlowErrors = append(state.NormalizeFlowErrors, normalizeSignalError.Error()) + state.TruncateProgress(w.logger) + return state, workflow.NewContinueAsNewError(ctx, CDCFlowWorkflowWithConfig, cfg, state) } if !normDone { normWaitChan.Receive(ctx, nil) diff --git a/flow/workflows/normalize_flow.go b/flow/workflows/normalize_flow.go index f37544b14a..16c53ba7a5 100644 --- a/flow/workflows/normalize_flow.go +++ b/flow/workflows/normalize_flow.go @@ -12,7 +12,8 @@ import ( "github.com/PeerDB-io/peer-flow/shared" ) -func NormalizeFlowWorkflow(ctx workflow.Context, +func NormalizeFlowWorkflow( + ctx workflow.Context, config *protos.FlowConnectionConfigs, ) (*model.NormalizeFlowResponse, error) { logger := workflow.GetLogger(ctx) @@ -25,10 +26,10 @@ func NormalizeFlowWorkflow(ctx workflow.Context, results := make([]model.NormalizeResponse, 0, 4) errors := make([]string, 0) syncChan := workflow.GetSignalChannel(ctx, shared.NormalizeSyncSignalName) - var tableNameSchemaMapping map[string]*protos.TableSchema var stopLoop, canceled bool var lastSyncBatchID, syncBatchID int64 + var tableNameSchemaMapping map[string]*protos.TableSchema lastSyncBatchID = -1 syncBatchID = -1 selector := workflow.NewNamedSelector(ctx, config.FlowJobName+"-normalize") diff --git a/flow/workflows/register.go 
b/flow/workflows/register.go index 2f7e4411cf..bd38b81bfd 100644 --- a/flow/workflows/register.go +++ b/flow/workflows/register.go @@ -9,7 +9,6 @@ func RegisterFlowWorkerWorkflows(w worker.WorkflowRegistry) { w.RegisterWorkflow(DropFlowWorkflow) w.RegisterWorkflow(NormalizeFlowWorkflow) w.RegisterWorkflow(SetupFlowWorkflow) - w.RegisterWorkflow(SyncFlowWorkflow) w.RegisterWorkflow(QRepFlowWorkflow) w.RegisterWorkflow(QRepPartitionWorkflow) w.RegisterWorkflow(XminFlowWorkflow) diff --git a/flow/workflows/snapshot_flow.go b/flow/workflows/snapshot_flow.go index f7de2f8850..ed38066015 100644 --- a/flow/workflows/snapshot_flow.go +++ b/flow/workflows/snapshot_flow.go @@ -8,7 +8,6 @@ import ( "strings" "time" - "github.com/google/uuid" "go.temporal.io/sdk/log" "go.temporal.io/sdk/temporal" "go.temporal.io/sdk/workflow" @@ -97,18 +96,11 @@ func (s *SnapshotFlowExecution) cloneTable( srcName := mapping.SourceTableIdentifier dstName := mapping.DestinationTableIdentifier - childWorkflowIDSideEffect := workflow.SideEffect(ctx, func(ctx workflow.Context) interface{} { - childWorkflowID := fmt.Sprintf("clone_%s_%s_%s", flowName, dstName, uuid.New().String()) - reg := regexp.MustCompile("[^a-zA-Z0-9]+") - return reg.ReplaceAllString(childWorkflowID, "_") - }) + originalRunID := workflow.GetInfo(ctx).OriginalRunID - var childWorkflowID string - if err := childWorkflowIDSideEffect.Get(&childWorkflowID); err != nil { - s.logger.Error(fmt.Sprintf("failed to get child id for source table %s and destination table %s", - srcName, dstName), slog.Any("error", err), cloneLog) - return fmt.Errorf("failed to get child workflow ID: %w", err) - } + childWorkflowID := fmt.Sprintf("clone_%s_%s_%s", flowName, dstName, originalRunID) + reg := regexp.MustCompile("[^a-zA-Z0-9_]+") + childWorkflowID = reg.ReplaceAllString(childWorkflowID, "_") s.logger.Info(fmt.Sprintf("Obtained child id %s for source table %s and destination table %s", childWorkflowID, srcName, dstName), cloneLog) diff --git 
a/flow/workflows/sync_flow.go b/flow/workflows/sync_flow.go deleted file mode 100644 index e3a53c9250..0000000000 --- a/flow/workflows/sync_flow.go +++ /dev/null @@ -1,79 +0,0 @@ -package peerflow - -import ( - "fmt" - "log/slog" - "time" - - "go.temporal.io/sdk/log" - "go.temporal.io/sdk/workflow" - - "github.com/PeerDB-io/peer-flow/generated/protos" - "github.com/PeerDB-io/peer-flow/model" -) - -type SyncFlowState struct { - CDCFlowName string - Progress []string -} - -type SyncFlowExecution struct { - SyncFlowState - executionID string - logger log.Logger -} - -// NewSyncFlowExecution creates a new instance of SyncFlowExecution. -func NewSyncFlowExecution(ctx workflow.Context, state *SyncFlowState) *SyncFlowExecution { - return &SyncFlowExecution{ - SyncFlowState: *state, - executionID: workflow.GetInfo(ctx).WorkflowExecution.ID, - logger: workflow.GetLogger(ctx), - } -} - -// executeSyncFlow executes the sync flow. -func (s *SyncFlowExecution) executeSyncFlow( - ctx workflow.Context, - config *protos.FlowConnectionConfigs, - opts *protos.SyncFlowOptions, - relationMessageMapping model.RelationMessageMapping, -) (*model.SyncResponse, error) { - s.logger.Info("executing sync flow", slog.String("flowName", s.CDCFlowName)) - - startFlowCtx := workflow.WithActivityOptions(ctx, workflow.ActivityOptions{ - StartToCloseTimeout: 72 * time.Hour, - HeartbeatTimeout: time.Minute, - WaitForCancellation: true, - }) - - // execute StartFlow on the peers to start the flow - startFlowInput := &protos.StartFlowInput{ - FlowConnectionConfigs: config, - SyncFlowOptions: opts, - RelationMessageMapping: relationMessageMapping, - SrcTableIdNameMapping: opts.SrcTableIdNameMapping, - TableNameSchemaMapping: opts.TableNameSchemaMapping, - } - fStartFlow := workflow.ExecuteActivity(startFlowCtx, flowable.StartFlow, startFlowInput) - - var syncRes *model.SyncResponse - if err := fStartFlow.Get(startFlowCtx, &syncRes); err != nil { - return nil, fmt.Errorf("failed to flow: %w", err) - } - 
return syncRes, nil -} - -// SyncFlowWorkflow is the synchronization workflow for a peer flow. -// This workflow assumes that the metadata tables have already been setup, -// and the checkpoint for the source peer is known. -func SyncFlowWorkflow(ctx workflow.Context, - config *protos.FlowConnectionConfigs, - options *protos.SyncFlowOptions, -) (*model.SyncResponse, error) { - s := NewSyncFlowExecution(ctx, &SyncFlowState{ - CDCFlowName: config.FlowJobName, - Progress: []string{}, - }) - return s.executeSyncFlow(ctx, config, options, options.RelationMessageMapping) -} diff --git a/protos/flow.proto b/protos/flow.proto index 2147ee2fc8..3bbebdc9a4 100644 --- a/protos/flow.proto +++ b/protos/flow.proto @@ -106,19 +106,13 @@ message SyncFlowOptions { repeated TableMapping table_mappings = 6; } -// deprecated, unused -message LastSyncState { - int64 checkpoint = 1; -} - message StartFlowInput { - // deprecated, unused - LastSyncState last_sync_state = 1; FlowConnectionConfigs flow_connection_configs = 2; SyncFlowOptions sync_flow_options = 3; map<uint32, RelationMessage> relation_message_mapping = 4; map<uint32, string> src_table_id_name_mapping = 5; map<string, TableSchema> table_name_schema_mapping = 6; + string session_id = 7; } message StartNormalizeInput {