Skip to content

Commit

Permalink
Use pg_export_snapshot when InitialSnapshotOnly
Browse files Browse the repository at this point in the history
  • Loading branch information
serprex committed Feb 21, 2024
1 parent 47eecb3 commit 26aff74
Show file tree
Hide file tree
Showing 6 changed files with 156 additions and 32 deletions.
69 changes: 61 additions & 8 deletions flow/activities/snapshot_activity.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import (
"context"
"fmt"
"log/slog"
"sync"
"time"

"go.temporal.io/sdk/activity"

Expand All @@ -15,19 +17,20 @@ import (
)

type SnapshotActivity struct {
SnapshotConnections map[string]SlotSnapshotSignal
Alerter *alerting.Alerter
SnapshotConnectionsMutex sync.Mutex
SnapshotConnections map[string]SlotSnapshotSignal
Alerter *alerting.Alerter
}

// closes the slot signal
func (a *SnapshotActivity) CloseSlotKeepAlive(ctx context.Context, flowJobName string) error {
if a.SnapshotConnections == nil {
return nil
}
a.SnapshotConnectionsMutex.Lock()
defer a.SnapshotConnectionsMutex.Unlock()

if s, ok := a.SnapshotConnections[flowJobName]; ok {
close(s.signal.CloneComplete)
connectors.CloseConnector(ctx, s.connector)
delete(a.SnapshotConnections, flowJobName)
}

return nil
Expand Down Expand Up @@ -89,9 +92,8 @@ func (a *SnapshotActivity) SetupReplication(
return nil, fmt.Errorf("slot error: %w", slotInfo.Err)
}

if a.SnapshotConnections == nil {
a.SnapshotConnections = make(map[string]SlotSnapshotSignal)
}
a.SnapshotConnectionsMutex.Lock()
defer a.SnapshotConnectionsMutex.Unlock()

a.SnapshotConnections[config.FlowJobName] = SlotSnapshotSignal{
signal: slotSignal,
Expand All @@ -104,3 +106,54 @@ func (a *SnapshotActivity) SetupReplication(
SnapshotName: slotInfo.SnapshotName,
}, nil
}

func (a *SnapshotActivity) MaintainTx(ctx context.Context, sessionID string, peer *protos.Peer) error {
conn, err := connectors.GetCDCPullConnector(ctx, peer)
if err != nil {
return err
}
defer connectors.CloseConnector(ctx, conn)

snapshotName, tx, err := conn.ExportSnapshot(ctx)
if err != nil {
return err
}

sss := SlotSnapshotSignal{snapshotName: snapshotName}
a.SnapshotConnectionsMutex.Lock()
a.SnapshotConnections[sessionID] = sss
a.SnapshotConnectionsMutex.Unlock()

for {
activity.RecordHeartbeat(ctx, "maintaining export snapshot transaction")
if ctx.Err() != nil {
a.SnapshotConnectionsMutex.Lock()
delete(a.SnapshotConnections, sessionID)
a.SnapshotConnectionsMutex.Unlock()
return conn.FinishExport(tx)
}
time.Sleep(time.Minute)
}
}

func (a *SnapshotActivity) WaitForExportSnapshot(ctx context.Context, sessionID string) (string, error) {
logger := activity.GetLogger(ctx)
attempt := 0
for {
a.SnapshotConnectionsMutex.Lock()
sss, ok := a.SnapshotConnections[sessionID]
a.SnapshotConnectionsMutex.Unlock()
if ok {
return sss.snapshotName, nil
}
activity.RecordHeartbeat(ctx, "wait another second for snapshot export")
attempt += 1
if attempt > 2 {
logger.Info("waiting on snapshot export", slog.Int("attempt", attempt))
}
if err := ctx.Err(); err != nil {
return "", err
}
time.Sleep(time.Second)
}
}
5 changes: 4 additions & 1 deletion flow/cmd/snapshot_worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,10 @@ func SnapshotWorkerMain(opts *SnapshotWorkerOptions) error {
}

w.RegisterWorkflow(peerflow.SnapshotFlowWorkflow)
w.RegisterActivity(&activities.SnapshotActivity{Alerter: alerter})
w.RegisterActivity(&activities.SnapshotActivity{
SnapshotConnections: make(map[string]activities.SlotSnapshotSignal),
Alerter: alerter,
})

err = w.Run(worker.InterruptCh())
if err != nil {
Expand Down
15 changes: 11 additions & 4 deletions flow/connectors/core.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,13 @@ type CDCPullConnector interface {
EnsurePullability(ctx context.Context, req *protos.EnsurePullabilityBatchInput) (
*protos.EnsurePullabilityBatchOutput, error)

// For InitialSnapshotOnly correctness without replication slot
// `any` is for returning transaction if necessary
ExportSnapshot(context.Context) (string, any, error)

// `any` from ExportSnapshot passed here when done, allowing transaction to commit
FinishExport(any) error

// Methods related to retrieving and pushing records for this connector as a source and destination.

// PullRecords pulls records from the source, and returns a RecordBatch.
Expand All @@ -61,12 +68,12 @@ type NormalizedTablesConnector interface {
Connector

// StartSetupNormalizedTables may be used to have SetupNormalizedTable calls run in a transaction.
StartSetupNormalizedTables(ctx context.Context) (interface{}, error)
StartSetupNormalizedTables(ctx context.Context) (any, error)

// SetupNormalizedTable sets up the normalized table on the connector.
SetupNormalizedTable(
ctx context.Context,
tx interface{},
tx any,
tableIdentifier string,
tableSchema *protos.TableSchema,
softDeleteColName string,
Expand All @@ -75,10 +82,10 @@ type NormalizedTablesConnector interface {

// CleanupSetupNormalizedTables may be used to rollback transaction started by StartSetupNormalizedTables.
// Calling CleanupSetupNormalizedTables after FinishSetupNormalizedTables must be a nop.
CleanupSetupNormalizedTables(ctx context.Context, tx interface{})
CleanupSetupNormalizedTables(ctx context.Context, tx any)

// FinishSetupNormalizedTables may be used to finish transaction started by StartSetupNormalizedTables.
FinishSetupNormalizedTables(ctx context.Context, tx interface{}) error
FinishSetupNormalizedTables(ctx context.Context, tx any) error
}

type CDCSyncConnector interface {
Expand Down
31 changes: 27 additions & 4 deletions flow/connectors/postgres/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -643,25 +643,25 @@ func (c *PostgresConnector) getTableSchemaForTable(
}, nil
}

func (c *PostgresConnector) StartSetupNormalizedTables(ctx context.Context) (interface{}, error) {
func (c *PostgresConnector) StartSetupNormalizedTables(ctx context.Context) (any, error) {
// Postgres is cool and supports transactional DDL. So we use a transaction.
return c.conn.Begin(ctx)
}

func (c *PostgresConnector) CleanupSetupNormalizedTables(ctx context.Context, tx interface{}) {
func (c *PostgresConnector) CleanupSetupNormalizedTables(ctx context.Context, tx any) {
err := tx.(pgx.Tx).Rollback(ctx)
if err != pgx.ErrTxClosed && err != nil {
c.logger.Error("error rolling back transaction for creating raw table", slog.Any("error", err))
}
}

func (c *PostgresConnector) FinishSetupNormalizedTables(ctx context.Context, tx interface{}) error {
func (c *PostgresConnector) FinishSetupNormalizedTables(ctx context.Context, tx any) error {
return tx.(pgx.Tx).Commit(ctx)
}

func (c *PostgresConnector) SetupNormalizedTable(
ctx context.Context,
tx interface{},
tx any,
tableIdentifier string,
tableSchema *protos.TableSchema,
softDeleteColName string,
Expand Down Expand Up @@ -797,6 +797,29 @@ func (c *PostgresConnector) EnsurePullability(
return &protos.EnsurePullabilityBatchOutput{TableIdentifierMapping: tableIdentifierMapping}, nil
}

func (c *PostgresConnector) ExportSnapshot(ctx context.Context) (string, any, error) {
var snapshotName string
tx, err := c.conn.Begin(ctx)
if err != nil {
return "", nil, err
}

err = tx.QueryRow(ctx, "select pg_export_snapshot()").Scan(&snapshotName)
if err != nil {
_ = tx.Rollback(ctx)
return "", nil, err
}

return snapshotName, tx, err
}

func (c *PostgresConnector) FinishExport(tx any) error {
pgtx := tx.(pgx.Tx)
timeout, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel()
return pgtx.Commit(timeout)
}

// SetupReplication sets up replication for the source connector.
func (c *PostgresConnector) SetupReplication(ctx context.Context, signal SlotSignal, req *protos.SetupReplicationInput) error {
// ensure that the flowjob name is [a-z0-9_] only
Expand Down
5 changes: 3 additions & 2 deletions flow/e2e/test_utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ func RegisterWorkflowsAndActivities(t *testing.T, env *testsuite.TestWorkflowEnv
env.SetTestTimeout(5 * time.Minute)

peerflow.RegisterFlowWorkerWorkflows(env)
env.RegisterWorkflow(peerflow.SnapshotFlowWorkflow)

alerter, err := alerting.NewAlerter(conn)
if err != nil {
Expand All @@ -69,7 +70,8 @@ func RegisterWorkflowsAndActivities(t *testing.T, env *testsuite.TestWorkflowEnv
Alerter: alerter,
})
env.RegisterActivity(&activities.SnapshotActivity{
Alerter: alerter,
SnapshotConnections: make(map[string]activities.SlotSnapshotSignal),
Alerter: alerter,
})
}

Expand Down Expand Up @@ -547,7 +549,6 @@ func NewTemporalTestWorkflowEnvironment(t *testing.T) *testsuite.TestWorkflowEnv

env := testSuite.NewTestWorkflowEnvironment()
RegisterWorkflowsAndActivities(t, env)
env.RegisterWorkflow(peerflow.SnapshotFlowWorkflow)
return env
}

Expand Down
63 changes: 50 additions & 13 deletions flow/workflows/snapshot_flow.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ func (s *SnapshotFlowExecution) setupReplication(
},
})

tblNameMapping := make(map[string]string)
tblNameMapping := make(map[string]string, len(s.config.TableMappings))
for _, v := range s.config.TableMappings {
tblNameMapping[v.SourceTableIdentifier] = v.DestinationTableIdentifier
}
Expand Down Expand Up @@ -265,17 +265,6 @@ func SnapshotFlowWorkflow(ctx workflow.Context, config *protos.FlowConnectionCon
return nil
}

if config.InitialSnapshotOnly {
slotInfo := &protos.SetupReplicationOutput{
SlotName: "peerdb_initial_copy_only",
SnapshotName: "", // empty snapshot name indicates that we should not use a snapshot
}
if err := se.cloneTables(ctx, slotInfo, int(config.SnapshotNumTablesInParallel)); err != nil {
return fmt.Errorf("failed to clone tables: %w", err)
}
return nil
}

sessionOpts := &workflow.SessionOptions{
CreationTimeout: 5 * time.Minute,
ExecutionTimeout: time.Hour * 24 * 365 * 100, // 100 years
Expand All @@ -288,7 +277,55 @@ func SnapshotFlowWorkflow(ctx workflow.Context, config *protos.FlowConnectionCon
}
defer workflow.CompleteSession(sessionCtx)

if err := se.cloneTablesWithSlot(ctx, sessionCtx, numTablesInParallel); err != nil {
if config.InitialSnapshotOnly {
sessionInfo := workflow.GetSessionInfo(sessionCtx)

exportCtx := workflow.WithActivityOptions(sessionCtx, workflow.ActivityOptions{
StartToCloseTimeout: sessionOpts.ExecutionTimeout,
HeartbeatTimeout: 10 * time.Minute,
WaitForCancellation: true,
})

fMaintain := workflow.ExecuteActivity(
exportCtx,
snapshot.MaintainTx,
sessionInfo.SessionID,
config.Source,
)

fExportSnapshot := workflow.ExecuteActivity(
exportCtx,
snapshot.WaitForExportSnapshot,
sessionInfo.SessionID,
)

var sessionError error
var snapshotName string
sessionSelector := workflow.NewNamedSelector(ctx, "Export Snapshot Setup")
sessionSelector.AddFuture(fMaintain, func(f workflow.Future) {
// MaintainTx should never exit without an error before this point
sessionError = f.Get(exportCtx, nil)
})
sessionSelector.AddFuture(fExportSnapshot, func(f workflow.Future) {
// Happy path is waiting for this to return without error
sessionError = f.Get(exportCtx, &snapshotName)
})
sessionSelector.AddReceive(ctx.Done(), func(_ workflow.ReceiveChannel, _ bool) {
sessionError = ctx.Err()
})
sessionSelector.Select(ctx)
if sessionError != nil {
return sessionError
}

slotInfo := &protos.SetupReplicationOutput{
SlotName: "peerdb_initial_copy_only",
SnapshotName: snapshotName,
}
if err := se.cloneTables(ctx, slotInfo, int(config.SnapshotNumTablesInParallel)); err != nil {
return fmt.Errorf("failed to clone tables: %w", err)
}
} else if err := se.cloneTablesWithSlot(ctx, sessionCtx, numTablesInParallel); err != nil {
return fmt.Errorf("failed to clone slots and create replication slot: %w", err)
}

Expand Down

0 comments on commit 26aff74

Please sign in to comment.