Skip to content

Commit

Permalink
Merge branch 'main' into sf-mixed-case-v2-partial
Browse files (browse the repository at this point in the history)
  • Loading branch information
heavycrystal authored Dec 28, 2023
2 parents 29d4770 + 7dd1f0d commit 682fa0a
Show file tree
Hide file tree
Showing 22 changed files with 1,451 additions and 1,008 deletions.
64 changes: 50 additions & 14 deletions flow/activities/flowable.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ type CheckConnectionResult struct {
}

// SlotSnapshotSignal groups three values: a connpostgres.SlotSignal, a
// snapshot name, and the CDC pull connector they are associated with.
//
// NOTE(review): the two `signal` lines below are the before/after sides of
// this diff hunk (pointer vs. value field); a real Go struct can only contain
// one of them — presumably the value form is the post-commit version, confirm
// against the repository.
type SlotSnapshotSignal struct {
signal *connpostgres.SlotSignal
signal connpostgres.SlotSignal
snapshotName string
connector connectors.CDCPullConnector
}
Expand All @@ -49,9 +49,10 @@ type FlowableActivity struct {
// CheckConnection implements CheckConnection.
func (a *FlowableActivity) CheckConnection(
ctx context.Context,
config *protos.Peer,
config *protos.SetupInput,
) (*CheckConnectionResult, error) {
dstConn, err := connectors.GetCDCSyncConnector(ctx, config)
ctx = context.WithValue(ctx, shared.FlowNameKey, config.FlowName)
dstConn, err := connectors.GetCDCSyncConnector(ctx, config.Peer)
if err != nil {
return nil, fmt.Errorf("failed to get connector: %w", err)
}
Expand All @@ -65,16 +66,16 @@ func (a *FlowableActivity) CheckConnection(
}

// SetupMetadataTables implements SetupMetadataTables.
func (a *FlowableActivity) SetupMetadataTables(ctx context.Context, config *protos.Peer) error {
dstConn, err := connectors.GetCDCSyncConnector(ctx, config)
func (a *FlowableActivity) SetupMetadataTables(ctx context.Context, config *protos.SetupInput) error {
ctx = context.WithValue(ctx, shared.FlowNameKey, config.FlowName)
dstConn, err := connectors.GetCDCSyncConnector(ctx, config.Peer)
if err != nil {
return fmt.Errorf("failed to get connector: %w", err)
}
defer connectors.CloseConnector(dstConn)

flowName, _ := ctx.Value(shared.FlowNameKey).(string)
if err := dstConn.SetupMetadataTables(); err != nil {
a.Alerter.LogFlowError(ctx, flowName, err)
a.Alerter.LogFlowError(ctx, config.FlowName, err)
return fmt.Errorf("failed to setup metadata tables: %w", err)
}

Expand All @@ -86,6 +87,7 @@ func (a *FlowableActivity) GetLastSyncedID(
ctx context.Context,
config *protos.GetLastSyncedIDInput,
) (*protos.LastSyncState, error) {
ctx = context.WithValue(ctx, shared.FlowNameKey, config.FlowJobName)
dstConn, err := connectors.GetCDCSyncConnector(ctx, config.PeerConnectionConfig)
if err != nil {
return nil, fmt.Errorf("failed to get connector: %w", err)
Expand All @@ -105,6 +107,7 @@ func (a *FlowableActivity) EnsurePullability(
ctx context.Context,
config *protos.EnsurePullabilityBatchInput,
) (*protos.EnsurePullabilityBatchOutput, error) {
ctx = context.WithValue(ctx, shared.FlowNameKey, config.FlowJobName)
srcConn, err := connectors.GetCDCPullConnector(ctx, config.PeerConnectionConfig)
if err != nil {
return nil, fmt.Errorf("failed to get connector: %w", err)
Expand All @@ -125,6 +128,7 @@ func (a *FlowableActivity) CreateRawTable(
ctx context.Context,
config *protos.CreateRawTableInput,
) (*protos.CreateRawTableOutput, error) {
ctx = context.WithValue(ctx, shared.FlowNameKey, config.FlowJobName)
dstConn, err := connectors.GetCDCSyncConnector(ctx, config.PeerConnectionConfig)
if err != nil {
return nil, fmt.Errorf("failed to get connector: %w", err)
Expand All @@ -133,6 +137,7 @@ func (a *FlowableActivity) CreateRawTable(

res, err := dstConn.CreateRawTable(config)
if err != nil {
a.Alerter.LogFlowError(ctx, config.FlowJobName, err)
return nil, err
}
err = monitoring.InitializeCDCFlow(ctx, a.CatalogPool, config.FlowJobName)
Expand All @@ -148,6 +153,7 @@ func (a *FlowableActivity) GetTableSchema(
ctx context.Context,
config *protos.GetTableSchemaBatchInput,
) (*protos.GetTableSchemaBatchOutput, error) {
ctx = context.WithValue(ctx, shared.FlowNameKey, config.FlowName)
srcConn, err := connectors.GetCDCPullConnector(ctx, config.PeerConnectionConfig)
if err != nil {
return nil, fmt.Errorf("failed to get connector: %w", err)
Expand All @@ -162,6 +168,7 @@ func (a *FlowableActivity) CreateNormalizedTable(
ctx context.Context,
config *protos.SetupNormalizedTableBatchInput,
) (*protos.SetupNormalizedTableBatchOutput, error) {
ctx = context.WithValue(ctx, shared.FlowNameKey, config.FlowName)
conn, err := connectors.GetCDCSyncConnector(ctx, config.PeerConnectionConfig)
if err != nil {
return nil, fmt.Errorf("failed to get connector: %w", err)
Expand All @@ -170,8 +177,7 @@ func (a *FlowableActivity) CreateNormalizedTable(

setupNormalizedTablesOutput, err := conn.SetupNormalizedTables(config)
if err != nil {
flowName, _ := ctx.Value(shared.FlowNameKey).(string)
a.Alerter.LogFlowError(ctx, flowName, err)
a.Alerter.LogFlowError(ctx, config.FlowName, err)
return nil, fmt.Errorf("failed to setup normalized tables: %w", err)
}

Expand All @@ -181,6 +187,7 @@ func (a *FlowableActivity) CreateNormalizedTable(
func (a *FlowableActivity) StartFlow(ctx context.Context,
input *protos.StartFlowInput,
) (*model.SyncResponse, error) {
ctx = context.WithValue(ctx, shared.FlowNameKey, input.FlowConnectionConfigs.FlowJobName)
activity.RecordHeartbeat(ctx, "starting flow...")
conn := input.FlowConnectionConfigs
dstConn, err := connectors.GetCDCSyncConnector(ctx, conn.Destination)
Expand Down Expand Up @@ -311,6 +318,7 @@ func (a *FlowableActivity) StartFlow(ctx context.Context,

lastCheckpoint, err := recordBatch.GetLastCheckpoint()
if err != nil {
a.Alerter.LogFlowError(ctx, flowName, err)
return nil, fmt.Errorf("failed to get last checkpoint: %w", err)
}

Expand All @@ -323,6 +331,7 @@ func (a *FlowableActivity) StartFlow(ctx context.Context,
pglogrepl.LSN(lastCheckpoint),
)
if err != nil {
a.Alerter.LogFlowError(ctx, flowName, err)
return nil, err
}

Expand All @@ -333,6 +342,7 @@ func (a *FlowableActivity) StartFlow(ctx context.Context,
pglogrepl.LSN(lastCheckpoint),
)
if err != nil {
a.Alerter.LogFlowError(ctx, flowName, err)
return nil, err
}
if res.TableNameRowsMapping != nil {
Expand All @@ -343,6 +353,7 @@ func (a *FlowableActivity) StartFlow(ctx context.Context,
}
}
if err != nil {
a.Alerter.LogFlowError(ctx, flowName, err)
return nil, err
}
res.TableSchemaDeltas = recordBatch.WaitForSchemaDeltas(input.FlowConnectionConfigs.TableMappings)
Expand All @@ -359,7 +370,7 @@ func (a *FlowableActivity) StartNormalize(
input *protos.StartNormalizeInput,
) (*model.NormalizeResponse, error) {
conn := input.FlowConnectionConfigs

ctx = context.WithValue(ctx, shared.FlowNameKey, conn.FlowJobName)
dstConn, err := connectors.GetCDCNormalizeConnector(ctx, conn.Destination)
if errors.Is(err, connectors.ErrUnsupportedFunctionality) {
dstConn, err := connectors.GetCDCSyncConnector(ctx, conn.Destination)
Expand Down Expand Up @@ -471,6 +482,7 @@ func (a *FlowableActivity) GetQRepPartitions(ctx context.Context,
last *protos.QRepPartition,
runUUID string,
) (*protos.QRepParitionResult, error) {
ctx = context.WithValue(ctx, shared.FlowNameKey, config.FlowJobName)
srcConn, err := connectors.GetQRepPullConnector(ctx, config.SourcePeer)
if err != nil {
return nil, fmt.Errorf("failed to get qrep pull connector: %w", err)
Expand Down Expand Up @@ -514,6 +526,7 @@ func (a *FlowableActivity) ReplicateQRepPartitions(ctx context.Context,
partitions *protos.QRepPartitionBatch,
runUUID string,
) error {
ctx = context.WithValue(ctx, shared.FlowNameKey, config.FlowJobName)
err := monitoring.UpdateStartTimeForQRepRun(ctx, a.CatalogPool, runUUID)
if err != nil {
return fmt.Errorf("failed to update start time for qrep run: %w", err)
Expand Down Expand Up @@ -544,21 +557,25 @@ func (a *FlowableActivity) replicateQRepPartition(ctx context.Context,
partition *protos.QRepPartition,
runUUID string,
) error {
ctx = context.WithValue(ctx, shared.FlowNameKey, config.FlowJobName)
err := monitoring.UpdateStartTimeForPartition(ctx, a.CatalogPool, runUUID, partition, time.Now())
if err != nil {
a.Alerter.LogFlowError(ctx, config.FlowJobName, err)
return fmt.Errorf("failed to update start time for partition: %w", err)
}

pullCtx, pullCancel := context.WithCancel(ctx)
defer pullCancel()
srcConn, err := connectors.GetQRepPullConnector(pullCtx, config.SourcePeer)
if err != nil {
a.Alerter.LogFlowError(ctx, config.FlowJobName, err)
return fmt.Errorf("failed to get qrep source connector: %w", err)
}
defer connectors.CloseConnector(srcConn)

dstConn, err := connectors.GetQRepSyncConnector(ctx, config.DestinationPeer)
if err != nil {
a.Alerter.LogFlowError(ctx, config.FlowJobName, err)
return fmt.Errorf("failed to get qrep destination connector: %w", err)
}
defer connectors.CloseConnector(dstConn)
Expand All @@ -579,13 +596,14 @@ func (a *FlowableActivity) replicateQRepPartition(ctx context.Context,
tmp, err := pgConn.PullQRepRecordStream(config, partition, stream)
numRecords := int64(tmp)
if err != nil {
slog.Error("failed to pull records", slog.Any("error", err))
a.Alerter.LogFlowError(ctx, config.FlowJobName, err)
slog.ErrorContext(ctx, "failed to pull records", slog.Any("error", err))
goroutineErr = err
} else {
err = monitoring.UpdatePullEndTimeAndRowsForPartition(ctx,
a.CatalogPool, runUUID, partition, numRecords)
if err != nil {
slog.Error(fmt.Sprintf("%v", err))
slog.ErrorContext(ctx, fmt.Sprintf("%v", err))
goroutineErr = err
}
}
Expand All @@ -596,6 +614,7 @@ func (a *FlowableActivity) replicateQRepPartition(ctx context.Context,
} else {
recordBatch, err := srcConn.PullQRepRecords(config, partition)
if err != nil {
a.Alerter.LogFlowError(ctx, config.FlowJobName, err)
return fmt.Errorf("failed to pull records: %w", err)
}
numRecords := int64(recordBatch.NumRecords)
Expand All @@ -608,6 +627,7 @@ func (a *FlowableActivity) replicateQRepPartition(ctx context.Context,

stream, err = recordBatch.ToQRecordStream(bufferSize)
if err != nil {
a.Alerter.LogFlowError(ctx, config.FlowJobName, err)
return fmt.Errorf("failed to convert to qrecord stream: %w", err)
}
}
Expand All @@ -622,6 +642,7 @@ func (a *FlowableActivity) replicateQRepPartition(ctx context.Context,

rowsSynced, err := dstConn.SyncQRepRecords(config, partition, stream)
if err != nil {
a.Alerter.LogFlowError(ctx, config.FlowJobName, err)
return fmt.Errorf("failed to sync records: %w", err)
}

Expand All @@ -630,6 +651,7 @@ func (a *FlowableActivity) replicateQRepPartition(ctx context.Context,
} else {
wg.Wait()
if goroutineErr != nil {
a.Alerter.LogFlowError(ctx, config.FlowJobName, goroutineErr)
return goroutineErr
}

Expand Down Expand Up @@ -681,13 +703,15 @@ func (a *FlowableActivity) CleanupQRepFlow(ctx context.Context, config *protos.Q
if errors.Is(err, connectors.ErrUnsupportedFunctionality) {
return nil
} else if err != nil {
a.Alerter.LogFlowError(ctx, config.FlowJobName, err)
return err
}

return dst.CleanupQRepFlow(config)
}

func (a *FlowableActivity) DropFlow(ctx context.Context, config *protos.ShutdownRequest) error {
ctx = context.WithValue(ctx, shared.FlowNameKey, config.FlowJobName)
srcConn, err := connectors.GetCDCPullConnector(ctx, config.SourcePeer)
if err != nil {
return fmt.Errorf("failed to get source connector: %w", err)
Expand Down Expand Up @@ -745,7 +769,7 @@ func (a *FlowableActivity) getPostgresPeerConfigs(ctx context.Context) ([]*proto

func (a *FlowableActivity) SendWALHeartbeat(ctx context.Context) error {
if !peerdbenv.PeerDBEnableWALHeartbeat() {
slog.InfoContext(ctx, "wal heartbeat is disabled")
slog.Info("wal heartbeat is disabled")
return nil
}

Expand All @@ -756,7 +780,7 @@ func (a *FlowableActivity) SendWALHeartbeat(ctx context.Context) error {
for {
select {
case <-ctx.Done():
slog.InfoContext(ctx, "context is done, exiting wal heartbeat send loop")
slog.Info("context is done, exiting wal heartbeat send loop")
return nil
case <-ticker.C:
pgPeers, err := a.getPostgresPeerConfigs(ctx)
Expand Down Expand Up @@ -803,6 +827,7 @@ func (a *FlowableActivity) SendWALHeartbeat(ctx context.Context) error {
func (a *FlowableActivity) QRepWaitUntilNewRows(ctx context.Context,
config *protos.QRepConfig, last *protos.QRepPartition,
) error {
ctx = context.WithValue(ctx, shared.FlowNameKey, config.FlowJobName)
if config.SourcePeer.Type != protos.DBType_POSTGRES || last.Range == nil {
return nil
}
Expand All @@ -813,6 +838,7 @@ func (a *FlowableActivity) QRepWaitUntilNewRows(ctx context.Context,

srcConn, err := connectors.GetQRepPullConnector(ctx, config.SourcePeer)
if err != nil {
a.Alerter.LogFlowError(ctx, config.FlowJobName, err)
return fmt.Errorf("failed to get qrep source connector: %w", err)
}
defer connectors.CloseConnector(srcConn)
Expand All @@ -825,6 +851,7 @@ func (a *FlowableActivity) QRepWaitUntilNewRows(ctx context.Context,

result, err := pgSrcConn.CheckForUpdatedMaxValue(config, last)
if err != nil {
a.Alerter.LogFlowError(ctx, config.FlowJobName, err)
return fmt.Errorf("failed to check for new rows: %w", err)
}
if result {
Expand All @@ -840,21 +867,25 @@ func (a *FlowableActivity) QRepWaitUntilNewRows(ctx context.Context,
func (a *FlowableActivity) RenameTables(ctx context.Context, config *protos.RenameTablesInput) (
*protos.RenameTablesOutput, error,
) {
ctx = context.WithValue(ctx, shared.FlowNameKey, config.FlowJobName)
dstConn, err := connectors.GetCDCSyncConnector(ctx, config.Peer)
if err != nil {
a.Alerter.LogFlowError(ctx, config.FlowJobName, err)
return nil, fmt.Errorf("failed to get connector: %w", err)
}
defer connectors.CloseConnector(dstConn)

if config.Peer.Type == protos.DBType_SNOWFLAKE {
sfConn, ok := dstConn.(*connsnowflake.SnowflakeConnector)
if !ok {
a.Alerter.LogFlowError(ctx, config.FlowJobName, err)
return nil, fmt.Errorf("failed to cast connector to snowflake connector")
}
return sfConn.RenameTables(config)
} else if config.Peer.Type == protos.DBType_BIGQUERY {
bqConn, ok := dstConn.(*connbigquery.BigQueryConnector)
if !ok {
a.Alerter.LogFlowError(ctx, config.FlowJobName, err)
return nil, fmt.Errorf("failed to cast connector to bigquery connector")
}
return bqConn.RenameTables(config)
Expand All @@ -865,6 +896,7 @@ func (a *FlowableActivity) RenameTables(ctx context.Context, config *protos.Rena
func (a *FlowableActivity) CreateTablesFromExisting(ctx context.Context, req *protos.CreateTablesFromExistingInput) (
*protos.CreateTablesFromExistingOutput, error,
) {
ctx = context.WithValue(ctx, shared.FlowNameKey, req.FlowJobName)
dstConn, err := connectors.GetCDCSyncConnector(ctx, req.Peer)
if err != nil {
return nil, fmt.Errorf("failed to get connector: %w", err)
Expand All @@ -884,6 +916,7 @@ func (a *FlowableActivity) CreateTablesFromExisting(ctx context.Context, req *pr
}
return bqConn.CreateTablesFromExisting(req)
}
a.Alerter.LogFlowError(ctx, req.FlowJobName, err)
return nil, fmt.Errorf("create tables from existing is only supported on snowflake and bigquery")
}

Expand All @@ -893,6 +926,7 @@ func (a *FlowableActivity) ReplicateXminPartition(ctx context.Context,
partition *protos.QRepPartition,
runUUID string,
) (int64, error) {
ctx = context.WithValue(ctx, shared.FlowNameKey, config.FlowJobName)
startTime := time.Now()
srcConn, err := connectors.GetQRepPullConnector(ctx, config.SourcePeer)
if err != nil {
Expand Down Expand Up @@ -920,6 +954,7 @@ func (a *FlowableActivity) ReplicateXminPartition(ctx context.Context,
var numRecords int
numRecords, currentSnapshotXmin, pullErr = pgConn.PullXminRecordStream(config, partition, stream)
if pullErr != nil {
a.Alerter.LogFlowError(ctx, config.FlowJobName, err)
slog.InfoContext(ctx, fmt.Sprintf("failed to pull records: %v", err))
return err
}
Expand Down Expand Up @@ -969,6 +1004,7 @@ func (a *FlowableActivity) ReplicateXminPartition(ctx context.Context,

rowsSynced, err := dstConn.SyncQRepRecords(config, partition, stream)
if err != nil {
a.Alerter.LogFlowError(ctx, config.FlowJobName, err)
return 0, fmt.Errorf("failed to sync records: %w", err)
}

Expand Down
Loading

0 comments on commit 682fa0a

Please sign in to comment.