Maintain replication connection between sync flows (#1211)
Currently we reconnect with each sync flow, which requires restarting
replication every time. This can take an exceedingly long time for some
workloads on some databases.

Fix: use a Temporal session to share state between activities, use a
single source connector throughout the CDC flow, and move the replication
connection back into the source connector.
serprex authored Feb 22, 2024
1 parent 6b35d8a commit 8f4ad4e
Showing 7 changed files with 298 additions and 112 deletions.
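
The diff adds the activity-side pieces (MaintainPull, WaitForSourceConnector, the sessionID-keyed CdcCache) and the connector hooks (SetupReplConn, ReplPing); the workflow-side wiring the commit message describes would look roughly like the sketch below. This is an illustrative outline rather than the actual peerflow workflow code: the activity names and struct fields match the diff, but the package name, timeouts, batch loop, and error handling are assumptions.

package sample

import (
    "time"

    "go.temporal.io/sdk/workflow"

    "github.com/PeerDB-io/peer-flow/generated/protos"
)

// Sketch: keep one CDC source connector alive across many sync batches.
// A Temporal session pins MaintainPull and every SyncFlow to the same
// worker process, so the connector cached under sessionID in
// FlowableActivity.CdcCache is visible to all of them.
func cdcFlowSketch(ctx workflow.Context, cfg *protos.FlowConnectionConfigs) error {
    sessionCtx, err := workflow.CreateSession(ctx, &workflow.SessionOptions{
        CreationTimeout:  time.Minute,         // illustrative values
        ExecutionTimeout: 14 * 24 * time.Hour,
    })
    if err != nil {
        return err
    }
    defer workflow.CompleteSession(sessionCtx)

    sessionID := workflow.GetSessionInfo(sessionCtx).SessionID

    // Long-running activity: opens the replication connection once, caches
    // the connector under sessionID, then heartbeats until cancelled.
    maintainCtx := workflow.WithActivityOptions(sessionCtx, workflow.ActivityOptions{
        StartToCloseTimeout: 14 * 24 * time.Hour,
        HeartbeatTimeout:    time.Minute,
    })
    workflow.ExecuteActivity(maintainCtx, "MaintainPull", cfg, sessionID)

    syncCtx := workflow.WithActivityOptions(sessionCtx, workflow.ActivityOptions{
        StartToCloseTimeout: time.Hour,
        HeartbeatTimeout:    time.Minute,
    })
    // Block until MaintainPull has populated the cache for this session.
    if err := workflow.ExecuteActivity(syncCtx, "WaitForSourceConnector", sessionID).Get(sessionCtx, nil); err != nil {
        return err
    }

    // Each batch reuses the cached connector instead of reconnecting and
    // restarting replication.
    for i := 0; i < 8; i++ { // illustrative batch count
        var opts *protos.SyncFlowOptions // the real workflow carries these over between batches
        if err := workflow.ExecuteActivity(syncCtx, "SyncFlow", cfg, opts, sessionID).Get(sessionCtx, nil); err != nil {
            return err
        }
    }
    return nil
}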
75 changes: 71 additions & 4 deletions flow/activities/flowable.go
@@ -42,6 +42,8 @@ type SlotSnapshotSignal struct {
type FlowableActivity struct {
    CatalogPool *pgxpool.Pool
    Alerter *alerting.Alerter
    CdcCacheRw sync.RWMutex
    CdcCache map[string]connectors.CDCPullConnector
}

func (a *FlowableActivity) CheckConnection(
@@ -204,10 +206,71 @@ func (a *FlowableActivity) CreateNormalizedTable(
    }, nil
}

func (a *FlowableActivity) MaintainPull(
    ctx context.Context,
    config *protos.FlowConnectionConfigs,
    sessionID string,
) error {
    srcConn, err := connectors.GetCDCPullConnector(ctx, config.Source)
    if err != nil {
        return err
    }
    defer connectors.CloseConnector(ctx, srcConn)

    if err := srcConn.SetupReplConn(ctx); err != nil {
        return err
    }

    a.CdcCacheRw.Lock()
    a.CdcCache[sessionID] = srcConn
    a.CdcCacheRw.Unlock()

    ticker := time.NewTicker(15 * time.Second)
    defer ticker.Stop()

    for {
        select {
        case <-ticker.C:
            activity.RecordHeartbeat(ctx, "keep session alive")
            if err := srcConn.ReplPing(ctx); err != nil {
                activity.GetLogger(ctx).Error("Failed to send keep alive ping to replication connection", slog.Any("error", err))
            }
        case <-ctx.Done():
            a.CdcCacheRw.Lock()
            delete(a.CdcCache, sessionID)
            a.CdcCacheRw.Unlock()
            return nil
        }
    }
}

func (a *FlowableActivity) WaitForSourceConnector(ctx context.Context, sessionID string) error {
    logger := activity.GetLogger(ctx)
    attempt := 0
    for {
        a.CdcCacheRw.RLock()
        _, ok := a.CdcCache[sessionID]
        a.CdcCacheRw.RUnlock()
        if ok {
            return nil
        }
        activity.RecordHeartbeat(ctx, "wait another second for source connector")
        attempt += 1
        if attempt > 2 {
            logger.Info("waiting on source connector setup", slog.Int("attempt", attempt))
        }
        if err := ctx.Err(); err != nil {
            return err
        }
        time.Sleep(time.Second)
    }
}

func (a *FlowableActivity) SyncFlow(
    ctx context.Context,
    config *protos.FlowConnectionConfigs,
    options *protos.SyncFlowOptions,
    sessionID string,
) (*model.SyncResponse, error) {
    flowName := config.FlowJobName
    ctx = context.WithValue(ctx, shared.FlowNameKey, flowName)
@@ -225,11 +288,15 @@ func (a *FlowableActivity) SyncFlow(
        tblNameMapping[v.SourceTableIdentifier] = model.NewNameAndExclude(v.DestinationTableIdentifier, v.Exclude)
    }

-   srcConn, err := connectors.GetCDCPullConnector(ctx, config.Source)
-   if err != nil {
-       return nil, fmt.Errorf("failed to get source connector: %w", err)
+   a.CdcCacheRw.RLock()
+   srcConn, ok := a.CdcCache[sessionID]
+   a.CdcCacheRw.RUnlock()
+   if !ok {
+       return nil, errors.New("source connector missing from CdcCache")
    }
+   if err := srcConn.ConnectionActive(ctx); err != nil {
+       return nil, err
    }
-   defer connectors.CloseConnector(ctx, srcConn)

    shutdown := utils.HeartbeatRoutine(ctx, func() string {
        return fmt.Sprintf("transferring records for job - %s", flowName)
6 changes: 5 additions & 1 deletion flow/cmd/worker.go
@@ -16,6 +16,7 @@ import (
    "go.temporal.io/sdk/worker"

    "github.com/PeerDB-io/peer-flow/activities"
    "github.com/PeerDB-io/peer-flow/connectors"
    utils "github.com/PeerDB-io/peer-flow/connectors/utils/catalog"
    "github.com/PeerDB-io/peer-flow/logger"
    "github.com/PeerDB-io/peer-flow/shared"
@@ -127,7 +128,9 @@ func WorkerMain(opts *WorkerOptions) error {
        return queueErr
    }

-   w := worker.New(c, taskQueue, worker.Options{})
+   w := worker.New(c, taskQueue, worker.Options{
+       EnableSessionWorker: true,
+   })
    peerflow.RegisterFlowWorkerWorkflows(w)

    alerter, err := alerting.NewAlerter(conn)
@@ -138,6 +141,7 @@
    w.RegisterActivity(&activities.FlowableActivity{
        CatalogPool: conn,
        Alerter: alerter,
        CdcCache: make(map[string]connectors.CDCPullConnector),
    })

    err = w.Run(worker.InterruptCh())
4 changes: 4 additions & 0 deletions flow/connectors/core.go
@@ -39,6 +39,10 @@ type CDCPullConnector interface {
        *protos.EnsurePullabilityBatchOutput, error)

    // Methods related to retrieving and pushing records for this connector as a source and destination.
    SetupReplConn(context.Context) error

    // Ping source to keep connection alive. Can be called concurrently with PullRecords; skips ping in that case.
    ReplPing(context.Context) error

    // PullRecords pulls records from the source, and returns a RecordBatch.
    // This method should be idempotent, and should be able to be called multiple times with the same request.
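
The interface leaves it to each connector to decide how a keepalive ping coexists with an in-flight PullRecords. One way to honor the "skips ping in that case" contract — a minimal sketch with assumed names (replLock, pinger), not the actual Postgres connector implementation — is to guard the replication handle with a mutex and have ReplPing give up immediately when PullRecords holds it:

package sample

import (
    "context"
    "sync"
)

// pinger stands in for whatever replication handle the connector keeps open
// (for Postgres this would wrap the replication connection).
type pinger interface {
    Ping(ctx context.Context) error
}

type cachedSource struct {
    replLock sync.Mutex // held for the duration of PullRecords
    conn     pinger
}

// ReplPing keeps the connection alive between pulls. While PullRecords runs
// it already generates traffic on the connection, so the ping is skipped
// instead of blocking behind the lock.
func (c *cachedSource) ReplPing(ctx context.Context) error {
    if c.replLock.TryLock() {
        defer c.replLock.Unlock()
        return c.conn.Ping(ctx)
    }
    return nil
}

// PullRecords holds the lock while consuming the replication stream.
func (c *cachedSource) PullRecords(ctx context.Context) error {
    c.replLock.Lock()
    defer c.replLock.Unlock()
    // ... consume the stream and push records downstream ...
    return nil
}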
91 changes: 7 additions & 84 deletions flow/connectors/postgres/cdc.go
@@ -3,7 +3,6 @@ package connpostgres
import (
    "context"
    "crypto/sha256"
-   "errors"
    "fmt"
    "log/slog"
    "time"
@@ -28,7 +27,6 @@

type PostgresCDCSource struct {
    *PostgresConnector
-   replConn *pgx.Conn
    SrcTableIDNameMapping map[uint32]string
    TableNameMapping map[string]model.NameAndExclude
    slot string
@@ -46,7 +44,6 @@ type PostgresCDCSource struct {
}

type PostgresCDCConfig struct {
-   Connection *pgx.Conn
    Slot string
    Publication string
    SrcTableIDNameMapping map[uint32]string
@@ -67,21 +64,20 @@ type startReplicationOpts struct {
func (c *PostgresConnector) NewPostgresCDCSource(cdcConfig *PostgresCDCConfig) *PostgresCDCSource {
    return &PostgresCDCSource{
        PostgresConnector: c,
-       replConn: cdcConfig.Connection,
        SrcTableIDNameMapping: cdcConfig.SrcTableIDNameMapping,
        TableNameMapping: cdcConfig.TableNameMapping,
        slot: cdcConfig.Slot,
        publication: cdcConfig.Publication,
        relationMessageMapping: cdcConfig.RelationMessageMapping,
-       typeMap: pgtype.NewMap(),
        childToParentRelIDMapping: cdcConfig.ChildToParentRelIDMap,
+       typeMap: pgtype.NewMap(),
        commitLock: false,
        catalogPool: cdcConfig.CatalogPool,
        flowJobName: cdcConfig.FlowJobName,
    }
}

-func getChildToParentRelIDMap(ctx context.Context, conn *pgx.Conn) (map[uint32]uint32, error) {
+func GetChildToParentRelIDMap(ctx context.Context, conn *pgx.Conn) (map[uint32]uint32, error) {
    query := `
        SELECT parent.oid AS parentrelid, child.oid AS childrelid
        FROM pg_inherits
@@ -94,7 +90,6 @@ func getChildToParentRelIDMap(ctx context.Context, conn *pgx.Conn) (map[uint32]u
    if err != nil {
        return nil, fmt.Errorf("error querying for child to parent relid map: %w", err)
    }
-
    defer rows.Close()

    childToParentRelIDMap := make(map[uint32]uint32)
@@ -113,85 +108,14 @@ func getChildToParentRelIDMap(ctx context.Context, conn *pgx.Conn) (map[uint32]u

// PullRecords pulls records from the cdc stream
func (p *PostgresCDCSource) PullRecords(ctx context.Context, req *model.PullRecordsRequest) error {
-   replicationOpts, err := p.replicationOptions()
-   if err != nil {
-       return fmt.Errorf("error getting replication options: %w", err)
-   }
-
-   pgConn := p.replConn.PgConn()
-
-   // start replication
-   var clientXLogPos, startLSN pglogrepl.LSN
-   if req.LastOffset > 0 {
-       p.logger.Info("starting replication from last sync state", slog.Int64("last checkpoint", req.LastOffset))
-       clientXLogPos = pglogrepl.LSN(req.LastOffset)
-       startLSN = clientXLogPos + 1
-   }
-
-   opts := startReplicationOpts{
-       conn: pgConn,
-       startLSN: startLSN,
-       replicationOpts: *replicationOpts,
-   }
-
-   err = p.startReplication(ctx, opts)
-   if err != nil {
-       return fmt.Errorf("error starting replication: %w", err)
-   }
-
-   p.logger.Info(fmt.Sprintf("started replication on slot %s at startLSN: %d", p.slot, startLSN))
-
-   return p.consumeStream(ctx, pgConn, req, clientXLogPos, req.RecordStream)
-}
-
-func (p *PostgresCDCSource) startReplication(ctx context.Context, opts startReplicationOpts) error {
-   err := pglogrepl.StartReplication(ctx, opts.conn, p.slot, opts.startLSN, opts.replicationOpts)
-   if err != nil {
-       p.logger.Error("error starting replication", slog.Any("error", err))
-       return fmt.Errorf("error starting replication at startLsn - %d: %w", opts.startLSN, err)
-   }
-
-   p.logger.Info(fmt.Sprintf("started replication on slot %s at startLSN: %d", p.slot, opts.startLSN))
-   return nil
-}
-
-func (p *PostgresCDCSource) replicationOptions() (*pglogrepl.StartReplicationOptions, error) {
-   pluginArguments := []string{
-       "proto_version '1'",
-   }
-
-   if p.publication != "" {
-       pubOpt := fmt.Sprintf("publication_names '%s'", p.publication)
-       pluginArguments = append(pluginArguments, pubOpt)
-   } else {
-       return nil, errors.New("publication name is not set")
-   }
-
-   return &pglogrepl.StartReplicationOptions{PluginArgs: pluginArguments}, nil
-}
-
-// start consuming the cdc stream
-func (p *PostgresCDCSource) consumeStream(
-   ctx context.Context,
-   conn *pgconn.PgConn,
-   req *model.PullRecordsRequest,
-   clientXLogPos pglogrepl.LSN,
-   records *model.CDCRecordStream,
-) error {
-   defer func() {
-       timeout, cancel := context.WithTimeout(context.Background(), 1*time.Minute)
-       err := conn.Close(timeout)
-       if err != nil {
-           p.logger.Error("error closing replication connection", slog.Any("error", err))
-       }
-       cancel()
-   }()
-
+   conn := p.replConn.PgConn()
+   records := req.RecordStream
    // clientXLogPos is the last checkpoint id, we need to ack that we have processed
    // until clientXLogPos each time we send a standby status update.
    // consumedXLogPos is the lsn that has been committed on the destination.
-   consumedXLogPos := pglogrepl.LSN(0)
-   if clientXLogPos > 0 {
+   var clientXLogPos, consumedXLogPos pglogrepl.LSN
+   if req.LastOffset > 0 {
+       clientXLogPos = pglogrepl.LSN(req.LastOffset)
        consumedXLogPos = clientXLogPos

        err := pglogrepl.SendStandbyStatusUpdate(ctx, conn,
@@ -300,7 +224,6 @@ func (p *PostgresCDCSource) consumeStream(

        var receiveCtx context.Context
        var cancel context.CancelFunc
-
        if cdcRecordsStorage.IsEmpty() {
            receiveCtx, cancel = context.WithCancel(ctx)
        } else {
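
The comments kept above distinguish clientXLogPos, the last checkpoint pulled from the slot, from consumedXLogPos, the position the destination has durably committed; only the latter is safe to acknowledge, because acknowledged WAL may be recycled by the server. As a worked illustration of that acknowledgment (the helper name and wrapper are assumptions, not code from this commit):

package sample

import (
    "context"

    "github.com/jackc/pglogrepl"
    "github.com/jackc/pgx/v5/pgconn"
)

// ackConsumed reports how far the destination has committed so Postgres can
// discard WAL up to that point. Positions that were pulled but not yet
// committed downstream (clientXLogPos) must not be reported here, or a crash
// before the destination commit could lose records.
func ackConsumed(ctx context.Context, conn *pgconn.PgConn, consumedXLogPos pglogrepl.LSN) error {
    return pglogrepl.SendStandbyStatusUpdate(ctx, conn, pglogrepl.StandbyStatusUpdate{
        WalWritePosition: consumedXLogPos,
    })
}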