Maintain replication connection between sync flows #1211

Merged · 3 commits · Feb 22, 2024
75 changes: 71 additions & 4 deletions flow/activities/flowable.go
@@ -42,6 +42,8 @@ type SlotSnapshotSignal struct {
type FlowableActivity struct {
CatalogPool *pgxpool.Pool
Alerter *alerting.Alerter
CdcCacheRw sync.RWMutex
Review comment (Contributor):
Maybe have a replication connection manager struct that takes care of:

  1. This map and locking.
  2. Keeping track of the connection health and lifecycle.
  3. Any other additional metadata pertaining to the replication connection.
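
A rough sketch of the kind of manager being suggested here, assuming it would wrap the same map-plus-RWMutex the PR adds to FlowableActivity; all names and methods below are illustrative, not part of this PR:

package activities

import (
	"sync"

	"github.com/PeerDB-io/peer-flow/connectors"
)

// ReplConnManager is a hypothetical wrapper that would own the
// session-to-connector map and its locking, instead of spreading
// them across FlowableActivity fields.
type ReplConnManager struct {
	mu    sync.RWMutex
	conns map[string]connectors.CDCPullConnector
}

func NewReplConnManager() *ReplConnManager {
	return &ReplConnManager{conns: make(map[string]connectors.CDCPullConnector)}
}

// Put registers the replication connector for a session.
func (m *ReplConnManager) Put(sessionID string, conn connectors.CDCPullConnector) {
	m.mu.Lock()
	defer m.mu.Unlock()
	m.conns[sessionID] = conn
}

// Get looks up the connector for a session, if one has been registered.
func (m *ReplConnManager) Get(sessionID string) (connectors.CDCPullConnector, bool) {
	m.mu.RLock()
	defer m.mu.RUnlock()
	conn, ok := m.conns[sessionID]
	return conn, ok
}

// Delete removes a session's connector, typically when MaintainPull exits.
func (m *ReplConnManager) Delete(sessionID string) {
	m.mu.Lock()
	defer m.mu.Unlock()
	delete(m.conns, sessionID)
}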

@serprex (Contributor, Author) replied on Feb 9, 2024:

Tracking connection health & lifecycle is part of MaintainPull, which needs to exist either way to keep the session heartbeating. Additional metadata belongs in the connector, where we avoid contention on CdcCacheRw.

I could see moving replState/replConn out of the connector, then storing a struct { replState, replConn, connector } as the value of the hashmap. But for now, putting it all in the connector works about the same.
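
Spelled out, that alternative layout would look roughly like this (hypothetical; ReplState is an assumed name for state the Postgres connector holds internally today). The diff below resumes with the map as actually implemented:

// Hypothetical cache entry for the alternative layout; not what this PR does.
type cdcCacheEntry struct {
	replState *ReplState // assumed name for the connector's replication state
	replConn  *pgx.Conn
	connector connectors.CDCPullConnector
}

// CdcCache would then be: map[string]cdcCacheEntry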

CdcCache map[string]connectors.CDCPullConnector
}

func (a *FlowableActivity) CheckConnection(
@@ -204,10 +206,71 @@ func (a *FlowableActivity) CreateNormalizedTable(
}, nil
}

func (a *FlowableActivity) MaintainPull(
ctx context.Context,
config *protos.FlowConnectionConfigs,
sessionID string,
) error {
srcConn, err := connectors.GetCDCPullConnector(ctx, config.Source)
if err != nil {
return err
}
defer connectors.CloseConnector(ctx, srcConn)

if err := srcConn.SetupReplConn(ctx); err != nil {
return err
}

a.CdcCacheRw.Lock()
a.CdcCache[sessionID] = srcConn
a.CdcCacheRw.Unlock()

ticker := time.NewTicker(15 * time.Second)
defer ticker.Stop()

for {
select {
case <-ticker.C:
activity.RecordHeartbeat(ctx, "keep session alive")
if err := srcConn.ReplPing(ctx); err != nil {
activity.GetLogger(ctx).Error("Failed to send keep alive ping to replication connection", slog.Any("error", err))
}
case <-ctx.Done():
a.CdcCacheRw.Lock()
delete(a.CdcCache, sessionID)
a.CdcCacheRw.Unlock()
return nil
}
}
}

func (a *FlowableActivity) WaitForSourceConnector(ctx context.Context, sessionID string) error {
logger := activity.GetLogger(ctx)
attempt := 0
for {
a.CdcCacheRw.RLock()
_, ok := a.CdcCache[sessionID]
a.CdcCacheRw.RUnlock()
if ok {
return nil
}
activity.RecordHeartbeat(ctx, "wait another second for source connector")
attempt += 1
if attempt > 2 {
logger.Info("waiting on source connector setup", slog.Int("attempt", attempt))
}
if err := ctx.Err(); err != nil {
return err
}
time.Sleep(time.Second)
}
}

func (a *FlowableActivity) SyncFlow(
ctx context.Context,
config *protos.FlowConnectionConfigs,
options *protos.SyncFlowOptions,
sessionID string,
) (*model.SyncResponse, error) {
flowName := config.FlowJobName
ctx = context.WithValue(ctx, shared.FlowNameKey, flowName)
@@ -225,11 +288,15 @@ func (a *FlowableActivity) SyncFlow(
tblNameMapping[v.SourceTableIdentifier] = model.NewNameAndExclude(v.DestinationTableIdentifier, v.Exclude)
}

srcConn, err := connectors.GetCDCPullConnector(ctx, config.Source)
if err != nil {
return nil, fmt.Errorf("failed to get source connector: %w", err)
a.CdcCacheRw.RLock()
srcConn, ok := a.CdcCache[sessionID]
a.CdcCacheRw.RUnlock()
if !ok {
return nil, errors.New("source connector missing from CdcCache")
}
if err := srcConn.ConnectionActive(ctx); err != nil {
return nil, err
}
defer connectors.CloseConnector(ctx, srcConn)

shutdown := utils.HeartbeatRoutine(ctx, func() string {
return fmt.Sprintf("transferring records for job - %s", flowName)
6 changes: 5 additions & 1 deletion flow/cmd/worker.go
@@ -16,6 +16,7 @@ import (
"go.temporal.io/sdk/worker"

"github.com/PeerDB-io/peer-flow/activities"
"github.com/PeerDB-io/peer-flow/connectors"
utils "github.com/PeerDB-io/peer-flow/connectors/utils/catalog"
"github.com/PeerDB-io/peer-flow/logger"
"github.com/PeerDB-io/peer-flow/shared"
@@ -127,7 +128,9 @@ func WorkerMain(opts *WorkerOptions) error {
return queueErr
}

w := worker.New(c, taskQueue, worker.Options{})
w := worker.New(c, taskQueue, worker.Options{
EnableSessionWorker: true,
})
peerflow.RegisterFlowWorkerWorkflows(w)

alerter, err := alerting.NewAlerter(conn)
@@ -138,6 +141,7 @@
w.RegisterActivity(&activities.FlowableActivity{
CatalogPool: conn,
Alerter: alerter,
CdcCache: make(map[string]connectors.CDCPullConnector),
})

err = w.Run(worker.InterruptCh())
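
For context on the EnableSessionWorker flag above: the workflow side (not part of this diff) is expected to open a Temporal session so that MaintainPull, WaitForSourceConnector, and SyncFlow all run on the same worker process, where the shared CdcCache lives. A minimal sketch of that wiring, with illustrative names, timeouts, and argument passing rather than PeerDB's actual workflow code:

package peerflow

import (
	"time"

	"go.temporal.io/sdk/workflow"

	"github.com/PeerDB-io/peer-flow/generated/protos"
)

// runSyncSession is an illustrative sketch, not this PR's workflow code.
func runSyncSession(ctx workflow.Context, config *protos.FlowConnectionConfigs, options *protos.SyncFlowOptions) error {
	// A session pins every activity scheduled on sessionCtx to one worker
	// process, which is what makes the CdcCache lookup by sessionID work.
	sessionCtx, err := workflow.CreateSession(ctx, &workflow.SessionOptions{
		CreationTimeout:  time.Minute,
		ExecutionTimeout: 24 * time.Hour,
	})
	if err != nil {
		return err
	}
	defer workflow.CompleteSession(sessionCtx)

	sessionID := workflow.GetSessionInfo(sessionCtx).SessionID
	actCtx := workflow.WithActivityOptions(sessionCtx, workflow.ActivityOptions{
		StartToCloseTimeout: 24 * time.Hour,
		HeartbeatTimeout:    time.Minute,
	})

	// MaintainPull owns the replication connection for the session's lifetime.
	maintainCtx, cancelMaintain := workflow.WithCancel(actCtx)
	maintainPull := workflow.ExecuteActivity(maintainCtx, "MaintainPull", config, sessionID)

	// Block until MaintainPull has registered the connector in CdcCache.
	if err := workflow.ExecuteActivity(actCtx, "WaitForSourceConnector", sessionID).Get(sessionCtx, nil); err != nil {
		return err
	}
	// Sync flows now reuse the cached replication connection via sessionID.
	if err := workflow.ExecuteActivity(actCtx, "SyncFlow", config, options, sessionID).Get(sessionCtx, nil); err != nil {
		return err
	}

	// Cancelling MaintainPull triggers its ctx.Done branch, which removes
	// the connector from CdcCache and closes the connection.
	cancelMaintain()
	_ = maintainPull.Get(sessionCtx, nil)
	return nil
}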
4 changes: 4 additions & 0 deletions flow/connectors/core.go
@@ -39,6 +39,10 @@ type CDCPullConnector interface {
*protos.EnsurePullabilityBatchOutput, error)

// Methods related to retrieving and pushing records for this connector as a source and destination.
SetupReplConn(context.Context) error

// Ping source to keep connection alive. Can be called concurrently with PullRecords; skips ping in that case.
ReplPing(context.Context) error

// PullRecords pulls records from the source, and returns a RecordBatch.
// This method should be idempotent, and should be able to be called multiple times with the same request.
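
For illustration, one way a connector might satisfy the "skip while PullRecords is active" contract of ReplPing, assuming a replLock mutex guarding the replication connection and a replState holding the last-acked LSN (neither shown in this excerpt):

// Hypothetical sketch; replLock, replConn, and replState are assumed fields.
// Assumes imports "context", "sync" (for the Mutex), and "github.com/jackc/pglogrepl".
func (c *PostgresConnector) ReplPing(ctx context.Context) error {
	// TryLock fails while PullRecords holds the lock, i.e. while the
	// replication stream is already generating traffic, so skip the ping.
	if c.replLock.TryLock() {
		defer c.replLock.Unlock()
		return pglogrepl.SendStandbyStatusUpdate(ctx, c.replConn.PgConn(),
			pglogrepl.StandbyStatusUpdate{WALWritePosition: c.replState.LastOffset})
	}
	return nil
}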
91 changes: 7 additions & 84 deletions flow/connectors/postgres/cdc.go
@@ -3,7 +3,6 @@ package connpostgres
import (
"context"
"crypto/sha256"
"errors"
"fmt"
"log/slog"
"time"
@@ -28,7 +27,6 @@

type PostgresCDCSource struct {
*PostgresConnector
replConn *pgx.Conn
SrcTableIDNameMapping map[uint32]string
TableNameMapping map[string]model.NameAndExclude
slot string
@@ -46,7 +44,6 @@ type PostgresCDCSource struct {
}

type PostgresCDCConfig struct {
Connection *pgx.Conn
Slot string
Publication string
SrcTableIDNameMapping map[uint32]string
@@ -67,21 +64,20 @@ type startReplicationOpts struct {
func (c *PostgresConnector) NewPostgresCDCSource(cdcConfig *PostgresCDCConfig) *PostgresCDCSource {
return &PostgresCDCSource{
PostgresConnector: c,
replConn: cdcConfig.Connection,
SrcTableIDNameMapping: cdcConfig.SrcTableIDNameMapping,
TableNameMapping: cdcConfig.TableNameMapping,
slot: cdcConfig.Slot,
publication: cdcConfig.Publication,
relationMessageMapping: cdcConfig.RelationMessageMapping,
typeMap: pgtype.NewMap(),
childToParentRelIDMapping: cdcConfig.ChildToParentRelIDMap,
typeMap: pgtype.NewMap(),
commitLock: false,
catalogPool: cdcConfig.CatalogPool,
flowJobName: cdcConfig.FlowJobName,
}
}

func getChildToParentRelIDMap(ctx context.Context, conn *pgx.Conn) (map[uint32]uint32, error) {
func GetChildToParentRelIDMap(ctx context.Context, conn *pgx.Conn) (map[uint32]uint32, error) {
query := `
SELECT parent.oid AS parentrelid, child.oid AS childrelid
FROM pg_inherits
@@ -94,7 +90,6 @@ func getChildToParentRelIDMap(ctx context.Context, conn *pgx.Conn) (map[uint32]uint32, error) {
if err != nil {
return nil, fmt.Errorf("error querying for child to parent relid map: %w", err)
}

defer rows.Close()

childToParentRelIDMap := make(map[uint32]uint32)
@@ -113,85 +108,14 @@ func getChildToParentRelIDMap(ctx context.Context, conn *pgx.Conn) (map[uint32]uint32, error) {

// PullRecords pulls records from the cdc stream
func (p *PostgresCDCSource) PullRecords(ctx context.Context, req *model.PullRecordsRequest) error {
replicationOpts, err := p.replicationOptions()
if err != nil {
return fmt.Errorf("error getting replication options: %w", err)
}

pgConn := p.replConn.PgConn()

// start replication
var clientXLogPos, startLSN pglogrepl.LSN
if req.LastOffset > 0 {
p.logger.Info("starting replication from last sync state", slog.Int64("last checkpoint", req.LastOffset))
clientXLogPos = pglogrepl.LSN(req.LastOffset)
startLSN = clientXLogPos + 1
}

opts := startReplicationOpts{
conn: pgConn,
startLSN: startLSN,
replicationOpts: *replicationOpts,
}

err = p.startReplication(ctx, opts)
if err != nil {
return fmt.Errorf("error starting replication: %w", err)
}

p.logger.Info(fmt.Sprintf("started replication on slot %s at startLSN: %d", p.slot, startLSN))

return p.consumeStream(ctx, pgConn, req, clientXLogPos, req.RecordStream)
}

func (p *PostgresCDCSource) startReplication(ctx context.Context, opts startReplicationOpts) error {
err := pglogrepl.StartReplication(ctx, opts.conn, p.slot, opts.startLSN, opts.replicationOpts)
if err != nil {
p.logger.Error("error starting replication", slog.Any("error", err))
return fmt.Errorf("error starting replication at startLsn - %d: %w", opts.startLSN, err)
}

p.logger.Info(fmt.Sprintf("started replication on slot %s at startLSN: %d", p.slot, opts.startLSN))
return nil
}

func (p *PostgresCDCSource) replicationOptions() (*pglogrepl.StartReplicationOptions, error) {
pluginArguments := []string{
"proto_version '1'",
}

if p.publication != "" {
pubOpt := fmt.Sprintf("publication_names '%s'", p.publication)
pluginArguments = append(pluginArguments, pubOpt)
} else {
return nil, errors.New("publication name is not set")
}

return &pglogrepl.StartReplicationOptions{PluginArgs: pluginArguments}, nil
}

// start consuming the cdc stream
func (p *PostgresCDCSource) consumeStream(
ctx context.Context,
conn *pgconn.PgConn,
req *model.PullRecordsRequest,
clientXLogPos pglogrepl.LSN,
records *model.CDCRecordStream,
) error {
defer func() {
timeout, cancel := context.WithTimeout(context.Background(), 1*time.Minute)
err := conn.Close(timeout)
if err != nil {
p.logger.Error("error closing replication connection", slog.Any("error", err))
}
cancel()
}()

conn := p.replConn.PgConn()
records := req.RecordStream
// clientXLogPos is the last checkpoint id, we need to ack that we have processed
// until clientXLogPos each time we send a standby status update.
// consumedXLogPos is the lsn that has been committed on the destination.
consumedXLogPos := pglogrepl.LSN(0)
if clientXLogPos > 0 {
var clientXLogPos, consumedXLogPos pglogrepl.LSN
if req.LastOffset > 0 {
clientXLogPos = pglogrepl.LSN(req.LastOffset)
consumedXLogPos = clientXLogPos

err := pglogrepl.SendStandbyStatusUpdate(ctx, conn,
@@ -300,7 +224,6 @@ func (p *PostgresCDCSource) consumeStream(

var receiveCtx context.Context
var cancel context.CancelFunc

if cdcRecordsStorage.IsEmpty() {
receiveCtx, cancel = context.WithCancel(ctx)
} else {