diff --git a/flow/activities/snapshot_activity.go b/flow/activities/snapshot_activity.go index 7643629b2c..4383388d50 100644 --- a/flow/activities/snapshot_activity.go +++ b/flow/activities/snapshot_activity.go @@ -4,6 +4,8 @@ import ( "context" "fmt" "log/slog" + "sync" + "time" "go.temporal.io/sdk/activity" @@ -15,19 +17,20 @@ import ( ) type SnapshotActivity struct { - SnapshotConnections map[string]SlotSnapshotSignal - Alerter *alerting.Alerter + SnapshotConnectionsMutex sync.Mutex + SnapshotConnections map[string]SlotSnapshotSignal + Alerter *alerting.Alerter } // closes the slot signal func (a *SnapshotActivity) CloseSlotKeepAlive(ctx context.Context, flowJobName string) error { - if a.SnapshotConnections == nil { - return nil - } + a.SnapshotConnectionsMutex.Lock() + defer a.SnapshotConnectionsMutex.Unlock() if s, ok := a.SnapshotConnections[flowJobName]; ok { close(s.signal.CloneComplete) connectors.CloseConnector(ctx, s.connector) + delete(a.SnapshotConnections, flowJobName) } return nil @@ -89,9 +92,8 @@ func (a *SnapshotActivity) SetupReplication( return nil, fmt.Errorf("slot error: %w", slotInfo.Err) } - if a.SnapshotConnections == nil { - a.SnapshotConnections = make(map[string]SlotSnapshotSignal) - } + a.SnapshotConnectionsMutex.Lock() + defer a.SnapshotConnectionsMutex.Unlock() a.SnapshotConnections[config.FlowJobName] = SlotSnapshotSignal{ signal: slotSignal, @@ -104,3 +106,54 @@ func (a *SnapshotActivity) SetupReplication( SnapshotName: slotInfo.SnapshotName, }, nil } + +func (a *SnapshotActivity) MaintainTx(ctx context.Context, sessionID string, peer *protos.Peer) error { + conn, err := connectors.GetCDCPullConnector(ctx, peer) + if err != nil { + return err + } + defer connectors.CloseConnector(ctx, conn) + + snapshotName, tx, err := conn.ExportSnapshot(ctx) + if err != nil { + return err + } + + sss := SlotSnapshotSignal{snapshotName: snapshotName} + a.SnapshotConnectionsMutex.Lock() + a.SnapshotConnections[sessionID] = sss + a.SnapshotConnectionsMutex.Unlock() + + for { + activity.RecordHeartbeat(ctx, "maintaining export snapshot transaction") + if ctx.Err() != nil { + a.SnapshotConnectionsMutex.Lock() + delete(a.SnapshotConnections, sessionID) + a.SnapshotConnectionsMutex.Unlock() + return conn.FinishExport(tx) + } + time.Sleep(time.Minute) + } +} + +func (a *SnapshotActivity) WaitForExportSnapshot(ctx context.Context, sessionID string) (string, error) { + logger := activity.GetLogger(ctx) + attempt := 0 + for { + a.SnapshotConnectionsMutex.Lock() + sss, ok := a.SnapshotConnections[sessionID] + a.SnapshotConnectionsMutex.Unlock() + if ok { + return sss.snapshotName, nil + } + activity.RecordHeartbeat(ctx, "wait another second for snapshot export") + attempt += 1 + if attempt > 2 { + logger.Info("waiting on snapshot export", slog.Int("attempt", attempt)) + } + if err := ctx.Err(); err != nil { + return "", err + } + time.Sleep(time.Second) + } +} diff --git a/flow/cmd/snapshot_worker.go b/flow/cmd/snapshot_worker.go index 35f2e81039..bc53785382 100644 --- a/flow/cmd/snapshot_worker.go +++ b/flow/cmd/snapshot_worker.go @@ -73,7 +73,10 @@ func SnapshotWorkerMain(opts *SnapshotWorkerOptions) error { } w.RegisterWorkflow(peerflow.SnapshotFlowWorkflow) - w.RegisterActivity(&activities.SnapshotActivity{Alerter: alerter}) + w.RegisterActivity(&activities.SnapshotActivity{ + SnapshotConnections: make(map[string]activities.SlotSnapshotSignal), + Alerter: alerter, + }) err = w.Run(worker.InterruptCh()) if err != nil { diff --git a/flow/connectors/core.go b/flow/connectors/core.go index 889250d804..c859f9d3f5 100644 --- a/flow/connectors/core.go +++ b/flow/connectors/core.go @@ -38,6 +38,13 @@ type CDCPullConnector interface { EnsurePullability(ctx context.Context, req *protos.EnsurePullabilityBatchInput) ( *protos.EnsurePullabilityBatchOutput, error) + // For InitialSnapshotOnly correctness without replication slot + // `any` is for returning transaction if necessary + ExportSnapshot(context.Context) (string, any, error) + + // `any` from ExportSnapshot passed here when done, allowing transaction to commit + FinishExport(any) error + // Methods related to retrieving and pushing records for this connector as a source and destination. // PullRecords pulls records from the source, and returns a RecordBatch. @@ -61,12 +68,12 @@ type NormalizedTablesConnector interface { Connector // StartSetupNormalizedTables may be used to have SetupNormalizedTable calls run in a transaction. - StartSetupNormalizedTables(ctx context.Context) (interface{}, error) + StartSetupNormalizedTables(ctx context.Context) (any, error) // SetupNormalizedTable sets up the normalized table on the connector. SetupNormalizedTable( ctx context.Context, - tx interface{}, + tx any, tableIdentifier string, tableSchema *protos.TableSchema, softDeleteColName string, @@ -75,10 +82,10 @@ type NormalizedTablesConnector interface { // CleanupSetupNormalizedTables may be used to rollback transaction started by StartSetupNormalizedTables. // Calling CleanupSetupNormalizedTables after FinishSetupNormalizedTables must be a nop. - CleanupSetupNormalizedTables(ctx context.Context, tx interface{}) + CleanupSetupNormalizedTables(ctx context.Context, tx any) // FinishSetupNormalizedTables may be used to finish transaction started by StartSetupNormalizedTables. - FinishSetupNormalizedTables(ctx context.Context, tx interface{}) error + FinishSetupNormalizedTables(ctx context.Context, tx any) error } type CDCSyncConnector interface { diff --git a/flow/connectors/postgres/postgres.go b/flow/connectors/postgres/postgres.go index c44923926a..63bd51be91 100644 --- a/flow/connectors/postgres/postgres.go +++ b/flow/connectors/postgres/postgres.go @@ -643,25 +643,25 @@ func (c *PostgresConnector) getTableSchemaForTable( }, nil } -func (c *PostgresConnector) StartSetupNormalizedTables(ctx context.Context) (interface{}, error) { +func (c *PostgresConnector) StartSetupNormalizedTables(ctx context.Context) (any, error) { // Postgres is cool and supports transactional DDL. So we use a transaction. return c.conn.Begin(ctx) } -func (c *PostgresConnector) CleanupSetupNormalizedTables(ctx context.Context, tx interface{}) { +func (c *PostgresConnector) CleanupSetupNormalizedTables(ctx context.Context, tx any) { err := tx.(pgx.Tx).Rollback(ctx) if err != pgx.ErrTxClosed && err != nil { c.logger.Error("error rolling back transaction for creating raw table", slog.Any("error", err)) } } -func (c *PostgresConnector) FinishSetupNormalizedTables(ctx context.Context, tx interface{}) error { +func (c *PostgresConnector) FinishSetupNormalizedTables(ctx context.Context, tx any) error { return tx.(pgx.Tx).Commit(ctx) } func (c *PostgresConnector) SetupNormalizedTable( ctx context.Context, - tx interface{}, + tx any, tableIdentifier string, tableSchema *protos.TableSchema, softDeleteColName string, @@ -797,6 +797,29 @@ func (c *PostgresConnector) EnsurePullability( return &protos.EnsurePullabilityBatchOutput{TableIdentifierMapping: tableIdentifierMapping}, nil } +func (c *PostgresConnector) ExportSnapshot(ctx context.Context) (string, any, error) { + var snapshotName string + tx, err := c.conn.Begin(ctx) + if err != nil { + return "", nil, err + } + + err = tx.QueryRow(ctx, "select pg_export_snapshot()").Scan(&snapshotName) + if err != nil { + _ = tx.Rollback(ctx) + return "", nil, err + } + + return snapshotName, tx, err +} + +func (c *PostgresConnector) FinishExport(tx any) error { + pgtx := tx.(pgx.Tx) + timeout, cancel := context.WithTimeout(context.Background(), time.Minute) + defer cancel() + return pgtx.Commit(timeout) +} + // SetupReplication sets up replication for the source connector. func (c *PostgresConnector) SetupReplication(ctx context.Context, signal SlotSignal, req *protos.SetupReplicationInput) error { // ensure that the flowjob name is [a-z0-9_] only diff --git a/flow/e2e/test_utils.go b/flow/e2e/test_utils.go index 505ec2055a..041ec0a7b7 100644 --- a/flow/e2e/test_utils.go +++ b/flow/e2e/test_utils.go @@ -58,6 +58,7 @@ func RegisterWorkflowsAndActivities(t *testing.T, env *testsuite.TestWorkflowEnv env.SetTestTimeout(5 * time.Minute) peerflow.RegisterFlowWorkerWorkflows(env) + env.RegisterWorkflow(peerflow.SnapshotFlowWorkflow) alerter, err := alerting.NewAlerter(conn) if err != nil { @@ -69,7 +70,8 @@ func RegisterWorkflowsAndActivities(t *testing.T, env *testsuite.TestWorkflowEnv Alerter: alerter, }) env.RegisterActivity(&activities.SnapshotActivity{ - Alerter: alerter, + SnapshotConnections: make(map[string]activities.SlotSnapshotSignal), + Alerter: alerter, }) } @@ -547,7 +549,6 @@ func NewTemporalTestWorkflowEnvironment(t *testing.T) *testsuite.TestWorkflowEnv env := testSuite.NewTestWorkflowEnvironment() RegisterWorkflowsAndActivities(t, env) - env.RegisterWorkflow(peerflow.SnapshotFlowWorkflow) return env } diff --git a/flow/workflows/snapshot_flow.go b/flow/workflows/snapshot_flow.go index ed38066015..100ef22b69 100644 --- a/flow/workflows/snapshot_flow.go +++ b/flow/workflows/snapshot_flow.go @@ -39,7 +39,7 @@ func (s *SnapshotFlowExecution) setupReplication( }, }) - tblNameMapping := make(map[string]string) + tblNameMapping := make(map[string]string, len(s.config.TableMappings)) for _, v := range s.config.TableMappings { tblNameMapping[v.SourceTableIdentifier] = v.DestinationTableIdentifier } @@ -265,17 +265,6 @@ func SnapshotFlowWorkflow(ctx workflow.Context, config *protos.FlowConnectionCon return nil } - if config.InitialSnapshotOnly { - slotInfo := &protos.SetupReplicationOutput{ - SlotName: "peerdb_initial_copy_only", - SnapshotName: "", // empty snapshot name indicates that we should not use a snapshot - } - if err := se.cloneTables(ctx, slotInfo, int(config.SnapshotNumTablesInParallel)); err != nil { - return fmt.Errorf("failed to clone tables: %w", err) - } - return nil - } - sessionOpts := &workflow.SessionOptions{ CreationTimeout: 5 * time.Minute, ExecutionTimeout: time.Hour * 24 * 365 * 100, // 100 years @@ -288,7 +277,55 @@ func SnapshotFlowWorkflow(ctx workflow.Context, config *protos.FlowConnectionCon } defer workflow.CompleteSession(sessionCtx) - if err := se.cloneTablesWithSlot(ctx, sessionCtx, numTablesInParallel); err != nil { + if config.InitialSnapshotOnly { + sessionInfo := workflow.GetSessionInfo(sessionCtx) + + exportCtx := workflow.WithActivityOptions(sessionCtx, workflow.ActivityOptions{ + StartToCloseTimeout: sessionOpts.ExecutionTimeout, + HeartbeatTimeout: 10 * time.Minute, + WaitForCancellation: true, + }) + + fMaintain := workflow.ExecuteActivity( + exportCtx, + snapshot.MaintainTx, + sessionInfo.SessionID, + config.Source, + ) + + fExportSnapshot := workflow.ExecuteActivity( + exportCtx, + snapshot.WaitForExportSnapshot, + sessionInfo.SessionID, + ) + + var sessionError error + var snapshotName string + sessionSelector := workflow.NewNamedSelector(ctx, "Export Snapshot Setup") + sessionSelector.AddFuture(fMaintain, func(f workflow.Future) { + // MaintainTx should never exit without an error before this point + sessionError = f.Get(exportCtx, nil) + }) + sessionSelector.AddFuture(fExportSnapshot, func(f workflow.Future) { + // Happy path is waiting for this to return without error + sessionError = f.Get(exportCtx, &snapshotName) + }) + sessionSelector.AddReceive(ctx.Done(), func(_ workflow.ReceiveChannel, _ bool) { + sessionError = ctx.Err() + }) + sessionSelector.Select(ctx) + if sessionError != nil { + return sessionError + } + + slotInfo := &protos.SetupReplicationOutput{ + SlotName: "peerdb_initial_copy_only", + SnapshotName: snapshotName, + } + if err := se.cloneTables(ctx, slotInfo, int(config.SnapshotNumTablesInParallel)); err != nil { + return fmt.Errorf("failed to clone tables: %w", err) + } + } else if err := se.cloneTablesWithSlot(ctx, sessionCtx, numTablesInParallel); err != nil { return fmt.Errorf("failed to clone slots and create replication slot: %w", err) }