diff --git a/flow/activities/flowable.go b/flow/activities/flowable.go index 6ee858c028..25693f5ee1 100644 --- a/flow/activities/flowable.go +++ b/flow/activities/flowable.go @@ -286,10 +286,7 @@ func (a *FlowableActivity) StartFlow(ctx context.Context, jobName := input.FlowConnectionConfigs.FlowJobName return fmt.Sprintf("pushing records for job - %s", jobName) }) - - defer func() { - shutdown <- struct{}{} - }() + defer shutdown() syncStartTime := time.Now() res, err := dstConn.SyncRecords(&model.SyncRecordsRequest{ @@ -397,9 +394,7 @@ func (a *FlowableActivity) StartNormalize( shutdown := utils.HeartbeatRoutine(ctx, 2*time.Minute, func() string { return fmt.Sprintf("normalizing records from batch for job - %s", input.FlowConnectionConfigs.FlowJobName) }) - defer func() { - shutdown <- struct{}{} - }() + defer shutdown() slog.InfoContext(ctx, "initializing table schema...") err = dstConn.InitializeTableSchema(input.FlowConnectionConfigs.TableNameSchemaMapping) @@ -494,10 +489,7 @@ func (a *FlowableActivity) GetQRepPartitions(ctx context.Context, shutdown := utils.HeartbeatRoutine(ctx, 2*time.Minute, func() string { return fmt.Sprintf("getting partitions for job - %s", config.FlowJobName) }) - - defer func() { - shutdown <- struct{}{} - }() + defer shutdown() partitions, err := srcConn.GetQRepPartitions(config, last) if err != nil { @@ -635,10 +627,7 @@ func (a *FlowableActivity) replicateQRepPartition(ctx context.Context, shutdown := utils.HeartbeatRoutine(ctx, 5*time.Minute, func() string { return fmt.Sprintf("syncing partition - %s: %d of %d total.", partition.PartitionId, idx, total) }) - - defer func() { - shutdown <- struct{}{} - }() + defer shutdown() rowsSynced, err := dstConn.SyncQRepRecords(config, partition, stream) if err != nil { @@ -684,10 +673,7 @@ func (a *FlowableActivity) ConsolidateQRepPartitions(ctx context.Context, config shutdown := utils.HeartbeatRoutine(ctx, 2*time.Minute, func() string { return fmt.Sprintf("consolidating partitions for job - %s", config.FlowJobName) }) - - defer func() { - shutdown <- struct{}{} - }() + defer shutdown() err = dstConn.ConsolidateQRepPartitions(config) if err != nil { @@ -996,10 +982,7 @@ func (a *FlowableActivity) ReplicateXminPartition(ctx context.Context, shutdown := utils.HeartbeatRoutine(ctx, 5*time.Minute, func() string { return "syncing xmin." }) - - defer func() { - shutdown <- struct{}{} - }() + defer shutdown() rowsSynced, err := dstConn.SyncQRepRecords(config, partition, stream) if err != nil { diff --git a/flow/connectors/bigquery/qrep_avro_sync.go b/flow/connectors/bigquery/qrep_avro_sync.go index ceb3b38402..d6df8fdb6e 100644 --- a/flow/connectors/bigquery/qrep_avro_sync.go +++ b/flow/connectors/bigquery/qrep_avro_sync.go @@ -373,9 +373,7 @@ func (s *QRepAvroSyncMethod) writeToStage( objectFolder, stagingTable) }, ) - defer func() { - shutdown <- struct{}{} - }() + defer shutdown() var avroFile *avro.AvroFile ocfWriter := avro.NewPeerDBOCFWriter(s.connector.ctx, stream, avroSchema, diff --git a/flow/connectors/eventhub/eventhub.go b/flow/connectors/eventhub/eventhub.go index 1bb7b00166..24a37eaf91 100644 --- a/flow/connectors/eventhub/eventhub.go +++ b/flow/connectors/eventhub/eventhub.go @@ -216,9 +216,7 @@ func (c *EventHubConnector) SyncRecords(req *model.SyncRecordsRequest) (*model.S numRecords, req.FlowJobName, ) }) - defer func() { - shutdown <- struct{}{} - }() + defer shutdown() numRecords, err = c.processBatch(req.FlowJobName, batch, maxParallelism) if err != nil { diff --git a/flow/connectors/postgres/cdc.go b/flow/connectors/postgres/cdc.go index fcb3e64174..af2b483de1 100644 --- a/flow/connectors/postgres/cdc.go +++ b/flow/connectors/postgres/cdc.go @@ -249,10 +249,7 @@ func (p *PostgresCDCSource) consumeStream( currRecords := cdcRecordsStorage.Len() return fmt.Sprintf("pulling records for job - %s, currently have %d records", jobName, currRecords) }) - - defer func() { - shutdown <- struct{}{} - }() + defer shutdown() standbyMessageTimeout := req.IdleTimeout nextStandbyMessageDeadline := time.Now().Add(standbyMessageTimeout) diff --git a/flow/connectors/postgres/postgres.go b/flow/connectors/postgres/postgres.go index 43a11a2e72..7fef9d26c2 100644 --- a/flow/connectors/postgres/postgres.go +++ b/flow/connectors/postgres/postgres.go @@ -196,9 +196,7 @@ func (c *PostgresConnector) SetLastOffset(jobName string, lastOffset int64) erro // PullRecords pulls records from the source. func (c *PostgresConnector) PullRecords(catalogPool *pgxpool.Pool, req *model.PullRecordsRequest) error { - defer func() { - req.RecordStream.Close() - }() + defer req.RecordStream.Close() // Slotname would be the job name prefixed with "peerflow_slot_" slotName := fmt.Sprintf("peerflow_slot_%s", req.FlowJobName) diff --git a/flow/connectors/postgres/qrep_query_executor.go b/flow/connectors/postgres/qrep_query_executor.go index bb07fbb98f..839845d284 100644 --- a/flow/connectors/postgres/qrep_query_executor.go +++ b/flow/connectors/postgres/qrep_query_executor.go @@ -83,13 +83,10 @@ func (qe *QRepQueryExecutor) executeQueryInTx(tx pgx.Tx, cursorName string, fetc q := fmt.Sprintf("FETCH %d FROM %s", fetchSize, cursorName) if !qe.testEnv { - shutdownCh := utils.HeartbeatRoutine(qe.ctx, 1*time.Minute, func() string { + shutdown := utils.HeartbeatRoutine(qe.ctx, 1*time.Minute, func() string { return fmt.Sprintf("running '%s'", q) }) - - defer func() { - shutdownCh <- struct{}{} - }() + defer shutdown() } rows, err := tx.Query(qe.ctx, q) diff --git a/flow/connectors/snowflake/qrep_avro_sync.go b/flow/connectors/snowflake/qrep_avro_sync.go index 07eb791c5c..83521088d8 100644 --- a/flow/connectors/snowflake/qrep_avro_sync.go +++ b/flow/connectors/snowflake/qrep_avro_sync.go @@ -282,10 +282,7 @@ func (s *SnowflakeAvroSyncMethod) putFileToStage(avroFile *avro.AvroFile, stage shutdown := utils.HeartbeatRoutine(s.connector.ctx, 10*time.Second, func() string { return fmt.Sprintf("putting file to stage %s", stage) }) - - defer func() { - shutdown <- struct{}{} - }() + defer shutdown() if _, err := s.connector.database.ExecContext(s.connector.ctx, putCmd); err != nil { return fmt.Errorf("failed to put file to stage: %w", err) diff --git a/flow/connectors/utils/avro/avro_writer.go b/flow/connectors/utils/avro/avro_writer.go index 90c016b404..1e6f318713 100644 --- a/flow/connectors/utils/avro/avro_writer.go +++ b/flow/connectors/utils/avro/avro_writer.go @@ -136,10 +136,7 @@ func (p *peerDBOCFWriter) writeRecordsToOCFWriter(ocfWriter *goavro.OCFWriter) ( written := numRows.Load() return fmt.Sprintf("[avro] written %d rows to OCF", written) }) - - defer func() { - shutdown <- struct{}{} - }() + defer shutdown() } for qRecordOrErr := range p.stream.Records { diff --git a/flow/connectors/utils/heartbeat.go b/flow/connectors/utils/heartbeat.go index c1bc81f077..37f00bc72f 100644 --- a/flow/connectors/utils/heartbeat.go +++ b/flow/connectors/utils/heartbeat.go @@ -13,8 +13,8 @@ func HeartbeatRoutine( ctx context.Context, interval time.Duration, message func() string, -) chan<- struct{} { - shutdown := make(chan struct{}) +) context.CancelFunc { + ctx, cancel := context.WithCancel(ctx) go func() { counter := 0 for { @@ -22,15 +22,13 @@ func HeartbeatRoutine( msg := fmt.Sprintf("heartbeat #%d: %s", counter, message()) RecordHeartbeatWithRecover(ctx, msg) select { - case <-shutdown: - return case <-ctx.Done(): return case <-time.After(interval): } } }() - return shutdown + return cancel } // if the functions are being called outside the context of a Temporal workflow,