Skip to content

Commit

Permalink
flow/activities: use activity.GetLogger (#1196)
Browse files (browse the repository at this point in the history)
  • Loading branch information
serprex authored Feb 1, 2024
1 parent 6969882 commit e3a27ff
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 31 deletions.
64 changes: 38 additions & 26 deletions flow/activities/flowable.go
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,7 @@ func (a *FlowableActivity) StartFlow(ctx context.Context,
input *protos.StartFlowInput,
) (*model.SyncResponse, error) {
ctx = context.WithValue(ctx, shared.FlowNameKey, input.FlowConnectionConfigs.FlowJobName)
logger := activity.GetLogger(ctx)
activity.RecordHeartbeat(ctx, "starting flow...")
conn := input.FlowConnectionConfigs
dstConn, err := connectors.GetCDCSyncConnector(ctx, conn.Destination)
Expand All @@ -198,7 +199,7 @@ func (a *FlowableActivity) StartFlow(ctx context.Context,
}
defer connectors.CloseConnector(dstConn)

slog.InfoContext(ctx, "pulling records...")
logger.Info("pulling records...")
tblNameMapping := make(map[string]model.NameAndExclude)
for _, v := range input.FlowConnectionConfigs.TableMappings {
tblNameMapping[v.SourceTableIdentifier] = model.NewNameAndExclude(v.DestinationTableIdentifier, v.Exclude)
Expand Down Expand Up @@ -250,7 +251,7 @@ func (a *FlowableActivity) StartFlow(ctx context.Context,
})

hasRecords := !recordBatch.WaitAndCheckEmpty()
slog.InfoContext(ctx, fmt.Sprintf("the current sync flow has records: %v", hasRecords))
logger.Info("current sync flow has records?", hasRecords)

if !hasRecords {
// wait for the pull goroutine to finish
Expand All @@ -259,7 +260,7 @@ func (a *FlowableActivity) StartFlow(ctx context.Context,
a.Alerter.LogFlowError(ctx, flowName, err)
return nil, fmt.Errorf("failed in pull records when: %w", err)
}
slog.InfoContext(ctx, "no records to push")
logger.Info("no records to push")

err := dstConn.ReplayTableSchemaDeltas(flowName, recordBatch.SchemaDeltas)
if err != nil {
Expand Down Expand Up @@ -300,7 +301,7 @@ func (a *FlowableActivity) StartFlow(ctx context.Context,
StagingPath: input.FlowConnectionConfigs.CdcStagingPath,
})
if err != nil {
slog.Warn("failed to push records", slog.Any("error", err))
logger.Warn("failed to push records", slog.Any("error", err))
a.Alerter.LogFlowError(ctx, flowName, err)
return nil, fmt.Errorf("failed to push records: %w", err)
}
Expand All @@ -315,7 +316,7 @@ func (a *FlowableActivity) StartFlow(ctx context.Context,
numRecords := res.NumRecordsSynced
syncDuration := time.Since(syncStartTime)

slog.InfoContext(ctx, fmt.Sprintf("pushed %d records in %d seconds", numRecords, int(syncDuration.Seconds())))
logger.Info(fmt.Sprintf("pushed %d records in %d seconds", numRecords, int(syncDuration.Seconds())))

lastCheckpoint, err := recordBatch.GetLastCheckpoint()
if err != nil {
Expand Down Expand Up @@ -371,6 +372,8 @@ func (a *FlowableActivity) StartNormalize(
) (*model.NormalizeResponse, error) {
conn := input.FlowConnectionConfigs
ctx = context.WithValue(ctx, shared.FlowNameKey, conn.FlowJobName)
logger := activity.GetLogger(ctx)

dstConn, err := connectors.GetCDCNormalizeConnector(ctx, conn.Destination)
if errors.Is(err, connectors.ErrUnsupportedFunctionality) {
dstConn, err := connectors.GetCDCSyncConnector(ctx, conn.Destination)
Expand Down Expand Up @@ -419,7 +422,7 @@ func (a *FlowableActivity) StartNormalize(
}

// log the number of batches normalized
slog.InfoContext(ctx, fmt.Sprintf("normalized records from batch %d to batch %d",
logger.Info(fmt.Sprintf("normalized records from batch %d to batch %d",
res.StartBatchID, res.EndBatchID))

return res, nil
Expand Down Expand Up @@ -490,18 +493,20 @@ func (a *FlowableActivity) ReplicateQRepPartitions(ctx context.Context,
runUUID string,
) error {
ctx = context.WithValue(ctx, shared.FlowNameKey, config.FlowJobName)
logger := activity.GetLogger(ctx)

err := monitoring.UpdateStartTimeForQRepRun(ctx, a.CatalogPool, runUUID)
if err != nil {
return fmt.Errorf("failed to update start time for qrep run: %w", err)
}

numPartitions := len(partitions.Partitions)

slog.InfoContext(ctx, fmt.Sprintf("replicating partitions for batch %d - size: %d",
logger.Info(fmt.Sprintf("replicating partitions for batch %d - size: %d",
partitions.BatchId, numPartitions),
)
for i, p := range partitions.Partitions {
slog.InfoContext(ctx, fmt.Sprintf("batch-%d - replicating partition - %s", partitions.BatchId, p.PartitionId))
logger.Info(fmt.Sprintf("batch-%d - replicating partition - %s", partitions.BatchId, p.PartitionId))
err := a.replicateQRepPartition(ctx, config, i+1, numPartitions, p, runUUID)
if err != nil {
a.Alerter.LogFlowError(ctx, config.FlowJobName, err)
Expand All @@ -521,6 +526,8 @@ func (a *FlowableActivity) replicateQRepPartition(ctx context.Context,
runUUID string,
) error {
ctx = context.WithValue(ctx, shared.FlowNameKey, config.FlowJobName)
logger := activity.GetLogger(ctx)

err := monitoring.UpdateStartTimeForPartition(ctx, a.CatalogPool, runUUID, partition, time.Now())
if err != nil {
a.Alerter.LogFlowError(ctx, config.FlowJobName, err)
Expand All @@ -543,7 +550,7 @@ func (a *FlowableActivity) replicateQRepPartition(ctx context.Context,
}
defer connectors.CloseConnector(dstConn)

slog.InfoContext(ctx, fmt.Sprintf("replicating partition %s", partition.PartitionId))
logger.Info(fmt.Sprintf("replicating partition %s", partition.PartitionId))
shutdown := utils.HeartbeatRoutine(ctx, func() string {
return fmt.Sprintf("syncing partition - %s: %d of %d total.", partition.PartitionId, idx, total)
})
Expand All @@ -564,13 +571,13 @@ func (a *FlowableActivity) replicateQRepPartition(ctx context.Context,
numRecords := int64(tmp)
if err != nil {
a.Alerter.LogFlowError(ctx, config.FlowJobName, err)
slog.ErrorContext(ctx, "failed to pull records", slog.Any("error", err))
logger.Error("failed to pull records", slog.Any("error", err))
goroutineErr = err
} else {
err = monitoring.UpdatePullEndTimeAndRowsForPartition(ctx,
a.CatalogPool, runUUID, partition, numRecords)
if err != nil {
slog.ErrorContext(ctx, err.Error())
logger.Error(err.Error())
goroutineErr = err
}
}
Expand All @@ -582,7 +589,7 @@ func (a *FlowableActivity) replicateQRepPartition(ctx context.Context,
a.Alerter.LogFlowError(ctx, config.FlowJobName, err)
return fmt.Errorf("failed to pull qrep records: %w", err)
}
slog.InfoContext(ctx, fmt.Sprintf("pulled %d records", len(recordBatch.Records)))
logger.Info(fmt.Sprintf("pulled %d records", len(recordBatch.Records)))

err = monitoring.UpdatePullEndTimeAndRowsForPartition(ctx, a.CatalogPool, runUUID, partition, int64(len(recordBatch.Records)))
if err != nil {
Expand All @@ -603,7 +610,7 @@ func (a *FlowableActivity) replicateQRepPartition(ctx context.Context,
}

if rowsSynced == 0 {
slog.InfoContext(ctx, fmt.Sprintf("no records to push for partition %s", partition.PartitionId))
logger.Info(fmt.Sprintf("no records to push for partition %s", partition.PartitionId))
pullCancel()
} else {
wg.Wait()
Expand All @@ -617,7 +624,7 @@ func (a *FlowableActivity) replicateQRepPartition(ctx context.Context,
return err
}

slog.InfoContext(ctx, fmt.Sprintf("pushed %d records", rowsSynced))
logger.Info(fmt.Sprintf("pushed %d records", rowsSynced))
}

err = monitoring.UpdateEndTimeForPartition(ctx, a.CatalogPool, runUUID, partition)
Expand Down Expand Up @@ -716,8 +723,9 @@ func (a *FlowableActivity) getPostgresPeerConfigs(ctx context.Context) ([]*proto
}

func (a *FlowableActivity) SendWALHeartbeat(ctx context.Context) error {
logger := activity.GetLogger(ctx)
if !peerdbenv.PeerDBEnableWALHeartbeat() {
slog.Info("wal heartbeat is disabled")
logger.Info("wal heartbeat is disabled")
return nil
}

Expand All @@ -728,13 +736,13 @@ func (a *FlowableActivity) SendWALHeartbeat(ctx context.Context) error {
for {
select {
case <-ctx.Done():
slog.Info("context is done, exiting wal heartbeat send loop")
logger.Info("context is done, exiting wal heartbeat send loop")
return nil
case <-ticker.C:
pgPeers, err := a.getPostgresPeerConfigs(ctx)
if err != nil {
slog.Warn("[sendwalheartbeat]: warning: unable to fetch peers." +
"Skipping walheartbeat send. error encountered: " + err.Error())
logger.Warn("[sendwalheartbeat] unable to fetch peers. " +
"Skipping walheartbeat send. Error: " + err.Error())
continue
}

Expand All @@ -756,15 +764,15 @@ func (a *FlowableActivity) SendWALHeartbeat(ctx context.Context) error {

_, err := peerConn.Exec(ctx, command)
if err != nil {
slog.Warn(fmt.Sprintf("warning: could not send walheartbeat to peer %v: %v", pgPeer.Name, err))
logger.Warn(fmt.Sprintf("could not send walheartbeat to peer %v: %v", pgPeer.Name, err))
}

closeErr := peerConn.Close(ctx)
if closeErr != nil {
return fmt.Errorf("error closing postgres connection for peer %v with host %v: %w",
pgPeer.Name, pgConfig.Host, closeErr)
}
slog.InfoContext(ctx, fmt.Sprintf("sent walheartbeat to peer %v", pgPeer.Name))
logger.Info(fmt.Sprintf("sent walheartbeat to peer %v", pgPeer.Name))
}
}
}
Expand All @@ -774,6 +782,8 @@ func (a *FlowableActivity) QRepWaitUntilNewRows(ctx context.Context,
config *protos.QRepConfig, last *protos.QRepPartition,
) error {
ctx = context.WithValue(ctx, shared.FlowNameKey, config.FlowJobName)
logger := activity.GetLogger(ctx)

if config.SourcePeer.Type != protos.DBType_POSTGRES || last.Range == nil {
return nil
}
Expand All @@ -789,7 +799,7 @@ func (a *FlowableActivity) QRepWaitUntilNewRows(ctx context.Context,
}
defer connectors.CloseConnector(srcConn)
pgSrcConn := srcConn.(*connpostgres.PostgresConnector)
slog.InfoContext(ctx, fmt.Sprintf("current last partition value is %v", last))
logger.Info(fmt.Sprintf("current last partition value is %v", last))
attemptCount := 1
for {
activity.RecordHeartbeat(ctx, fmt.Sprintf("no new rows yet, attempt #%d", attemptCount))
Expand Down Expand Up @@ -892,6 +902,8 @@ func (a *FlowableActivity) ReplicateXminPartition(ctx context.Context,
runUUID string,
) (int64, error) {
ctx = context.WithValue(ctx, shared.FlowNameKey, config.FlowJobName)
logger := activity.GetLogger(ctx)

startTime := time.Now()
srcConn, err := connectors.GetQRepPullConnector(ctx, config.SourcePeer)
if err != nil {
Expand All @@ -905,7 +917,7 @@ func (a *FlowableActivity) ReplicateXminPartition(ctx context.Context,
}
defer connectors.CloseConnector(dstConn)

slog.InfoContext(ctx, "replicating xmin")
logger.Info("replicating xmin")

bufferSize := shared.FetchAndChannelSize
errGroup, errCtx := errgroup.WithContext(ctx)
Expand All @@ -920,7 +932,7 @@ func (a *FlowableActivity) ReplicateXminPartition(ctx context.Context,
numRecords, currentSnapshotXmin, pullErr = pgConn.PullXminRecordStream(config, partition, stream)
if pullErr != nil {
a.Alerter.LogFlowError(ctx, config.FlowJobName, err)
slog.InfoContext(ctx, fmt.Sprintf("[xmin] failed to pull records: %v", err))
logger.Warn(fmt.Sprintf("[xmin] failed to pull records: %v", err))
return err
}

Expand Down Expand Up @@ -952,7 +964,7 @@ func (a *FlowableActivity) ReplicateXminPartition(ctx context.Context,
err = monitoring.UpdatePullEndTimeAndRowsForPartition(
errCtx, a.CatalogPool, runUUID, partition, int64(numRecords))
if err != nil {
slog.Error(err.Error())
logger.Error(err.Error())
return err
}

Expand All @@ -971,7 +983,7 @@ func (a *FlowableActivity) ReplicateXminPartition(ctx context.Context,
}

if rowsSynced == 0 {
slog.InfoContext(ctx, "no records to push for xmin")
logger.Info("no records to push for xmin")
} else {
err := errGroup.Wait()
if err != nil {
Expand All @@ -984,7 +996,7 @@ func (a *FlowableActivity) ReplicateXminPartition(ctx context.Context,
return 0, err
}

slog.InfoContext(ctx, fmt.Sprintf("pushed %d records", rowsSynced))
logger.Info(fmt.Sprintf("pushed %d records", rowsSynced))
}

err = monitoring.UpdateEndTimeForPartition(ctx, a.CatalogPool, runUUID, partition)
Expand Down
14 changes: 9 additions & 5 deletions flow/activities/snapshot_activity.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ import (
"fmt"
"log/slog"

"go.temporal.io/sdk/activity"

"github.com/PeerDB-io/peer-flow/connectors"
connpostgres "github.com/PeerDB-io/peer-flow/connectors/postgres"
"github.com/PeerDB-io/peer-flow/generated/protos"
Expand Down Expand Up @@ -36,9 +38,11 @@ func (a *SnapshotActivity) SetupReplication(
config *protos.SetupReplicationInput,
) (*protos.SetupReplicationOutput, error) {
ctx = context.WithValue(ctx, shared.FlowNameKey, config.FlowJobName)
logger := activity.GetLogger(ctx)

dbType := config.PeerConnectionConfig.Type
if dbType != protos.DBType_POSTGRES {
slog.InfoContext(ctx, fmt.Sprintf("setup replication is no-op for %s", dbType))
logger.Info(fmt.Sprintf("setup replication is no-op for %s", dbType))
return nil, nil
}

Expand All @@ -53,12 +57,12 @@ func (a *SnapshotActivity) SetupReplication(
defer close(replicationErr)

closeConnectionForError := func(err error) {
slog.ErrorContext(ctx, "failed to setup replication", slog.Any("error", err))
logger.Error("failed to setup replication", slog.Any("error", err))
a.Alerter.LogFlowError(ctx, config.FlowJobName, err)
// it is important to close the connection here as it is not closed in CloseSlotKeepAlive
connCloseErr := conn.Close()
if connCloseErr != nil {
slog.ErrorContext(ctx, "failed to close connection", slog.Any("error", connCloseErr))
logger.Error("failed to close connection", slog.Any("error", connCloseErr))
}
}

Expand All @@ -73,11 +77,11 @@ func (a *SnapshotActivity) SetupReplication(
}
}()

slog.InfoContext(ctx, "waiting for slot to be created...")
logger.Info("waiting for slot to be created...")
var slotInfo connpostgres.SlotCreationResult
select {
case slotInfo = <-slotSignal.SlotCreated:
slog.InfoContext(ctx, fmt.Sprintf("slot '%s' created", slotInfo.SlotName))
logger.Info("slot created", slotInfo.SlotName)
case err := <-replicationErr:
closeConnectionForError(err)
return nil, fmt.Errorf("failed to setup replication: %w", err)
Expand Down

0 comments on commit e3a27ff

Please sign in to comment.