Commit: one-sync
serprex committed Feb 20, 2024
1 parent 8b591fd commit 265fa5c
Showing 10 changed files with 311 additions and 228 deletions.
87 changes: 77 additions & 10 deletions flow/activities/flowable.go
@@ -42,6 +42,8 @@ type SlotSnapshotSignal struct {
type FlowableActivity struct {
CatalogPool *pgxpool.Pool
Alerter *alerting.Alerter
CdcCacheRw sync.RWMutex
CdcCache map[string]connectors.CDCPullConnector
}

func (a *FlowableActivity) CheckConnection(
@@ -204,7 +206,68 @@ func (a *FlowableActivity) CreateNormalizedTable(
}, nil
}

func (a *FlowableActivity) StartFlow(ctx context.Context,
func (a *FlowableActivity) MaintainPull(
ctx context.Context,
config *protos.FlowConnectionConfigs,
sessionID string,
) error {
srcConn, err := connectors.GetCDCPullConnector(ctx, config.Source)
if err != nil {
return err
}
defer connectors.CloseConnector(ctx, srcConn)

if err := srcConn.SetupReplConn(ctx); err != nil {
return err
}

a.CdcCacheRw.Lock()
a.CdcCache[sessionID] = srcConn
a.CdcCacheRw.Unlock()

ticker := time.NewTicker(15 * time.Second)
defer ticker.Stop()

for {
select {
case <-ticker.C:
activity.RecordHeartbeat(ctx, "keep session alive")
if err := srcConn.ReplPing(ctx); err != nil {
activity.GetLogger(ctx).Error("Failed to send keep alive ping to replication connection", slog.Any("error", err))
}
case <-ctx.Done():
a.CdcCacheRw.Lock()
delete(a.CdcCache, sessionID)
a.CdcCacheRw.Unlock()
return nil
}
}
}

func (a *FlowableActivity) WaitForSourceConnector(ctx context.Context, sessionID string) error {
logger := activity.GetLogger(ctx)
attempt := 0
for {
a.CdcCacheRw.RLock()
_, ok := a.CdcCache[sessionID]
a.CdcCacheRw.RUnlock()
if ok {
return nil
}
activity.RecordHeartbeat(ctx, "wait another second for source connector")
attempt += 1
if attempt > 2 {
logger.Info("waiting on source connector setup", slog.Int("attempt", attempt))
}
if err := ctx.Err(); err != nil {
return err
}
time.Sleep(time.Second)
}
}

func (a *FlowableActivity) StartFlow(
ctx context.Context,
input *protos.StartFlowInput,
) (*model.SyncResponse, error) {
ctx = context.WithValue(ctx, shared.FlowNameKey, input.FlowConnectionConfigs.FlowJobName)
@@ -223,11 +286,15 @@ func (a *FlowableActivity) StartFlow(ctx context.Context,
tblNameMapping[v.SourceTableIdentifier] = model.NewNameAndExclude(v.DestinationTableIdentifier, v.Exclude)
}

srcConn, err := connectors.GetCDCPullConnector(ctx, config.Source)
if err != nil {
return nil, fmt.Errorf("failed to get source connector: %w", err)
a.CdcCacheRw.RLock()
srcConn, ok := a.CdcCache[input.SessionId]
a.CdcCacheRw.RUnlock()
if !ok {
return nil, errors.New("source connector missing from CdcCache")
}
if err := srcConn.ConnectionActive(ctx); err != nil {
return nil, err
}
defer connectors.CloseConnector(ctx, srcConn)

shutdown := utils.HeartbeatRoutine(ctx, func() string {
jobName := input.FlowConnectionConfigs.FlowJobName
@@ -240,16 +307,16 @@ func (a *FlowableActivity) StartFlow(ctx context.Context,
batchSize = 1_000_000
}

lastOffset, err := dstConn.GetLastOffset(ctx, input.FlowConnectionConfigs.FlowJobName)
if err != nil {
return nil, err
}

// start a goroutine to pull records from the source
recordBatch := model.NewCDCRecordStream()
startTime := time.Now()
flowName := input.FlowConnectionConfigs.FlowJobName

lastOffset, err := dstConn.GetLastOffset(ctx, flowName)
if err != nil {
return nil, err
}

errGroup, errCtx := errgroup.WithContext(ctx)
errGroup.Go(func() error {
return srcConn.PullRecords(errCtx, a.CatalogPool, &model.PullRecordsRequest{
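
Taken together, the flowable.go changes split connector lifetime out of StartFlow: MaintainPull opens the replication connection once per session, registers it in CdcCache under the session ID, and keeps it alive, while StartFlow and WaitForSourceConnector only read the cache. The sketch below is not PeerDB code; it isolates that handshake, and the ConnectorCache type and its method names are invented for illustration.

package cdccache

import (
	"context"
	"errors"
	"sync"
	"time"
)

// ConnectorCache mirrors the CdcCacheRw/CdcCache pair added to FlowableActivity:
// a long-lived "maintain" activity registers a connector under its session ID,
// and sync activities on the same worker look it up instead of reconnecting.
type ConnectorCache[T any] struct {
	mu    sync.RWMutex
	cache map[string]T
}

func New[T any]() *ConnectorCache[T] {
	return &ConnectorCache[T]{cache: make(map[string]T)}
}

// Maintain registers conn for sessionID, pings it on a ticker (as MaintainPull
// does with ReplPing), and removes it from the cache when ctx is canceled.
func (c *ConnectorCache[T]) Maintain(ctx context.Context, sessionID string, conn T, ping func(context.Context) error) {
	c.mu.Lock()
	c.cache[sessionID] = conn
	c.mu.Unlock()

	ticker := time.NewTicker(15 * time.Second)
	defer ticker.Stop()
	for {
		select {
		case <-ticker.C:
			_ = ping(ctx) // best-effort keepalive; the real activity only logs the error
		case <-ctx.Done():
			c.mu.Lock()
			delete(c.cache, sessionID)
			c.mu.Unlock()
			return
		}
	}
}

// Get is the lookup StartFlow now performs instead of dialing the source itself.
func (c *ConnectorCache[T]) Get(sessionID string) (T, error) {
	c.mu.RLock()
	conn, ok := c.cache[sessionID]
	c.mu.RUnlock()
	if !ok {
		var zero T
		return zero, errors.New("source connector missing from cache")
	}
	return conn, nil
}
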
6 changes: 5 additions & 1 deletion flow/cmd/worker.go
@@ -16,6 +16,7 @@ import (
"go.temporal.io/sdk/worker"

"github.com/PeerDB-io/peer-flow/activities"
"github.com/PeerDB-io/peer-flow/connectors"
utils "github.com/PeerDB-io/peer-flow/connectors/utils/catalog"
"github.com/PeerDB-io/peer-flow/logger"
"github.com/PeerDB-io/peer-flow/shared"
@@ -127,7 +128,9 @@ func WorkerMain(opts *WorkerOptions) error {
return queueErr
}

w := worker.New(c, taskQueue, worker.Options{})
w := worker.New(c, taskQueue, worker.Options{
EnableSessionWorker: true,
})
peerflow.RegisterFlowWorkerWorkflows(w)

alerter, err := alerting.NewAlerter(conn)
@@ -138,6 +141,7 @@ func WorkerMain(opts *WorkerOptions) error {
w.RegisterActivity(&activities.FlowableActivity{
CatalogPool: conn,
Alerter: alerter,
CdcCache: make(map[string]connectors.CDCPullConnector),
})

err = w.Run(worker.InterruptCh())
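
EnableSessionWorker turns on the Temporal Go SDK's session feature, which is what lets a workflow pin MaintainPull and StartFlow to the same worker process so the in-memory CdcCache is visible to both. The workflow-side wiring lives in files not loaded in this view; the snippet below is only a generic sketch of the SDK's session API, with the workflow name and activity arguments simplified for illustration.

package example

import (
	"time"

	"go.temporal.io/sdk/workflow"
)

// sessionSyncWorkflow is a hypothetical workflow showing the session pattern:
// every activity scheduled on sessionCtx runs on the worker that created the
// session, so one activity can read in-process state another activity wrote.
func sessionSyncWorkflow(ctx workflow.Context) error {
	sessionCtx, err := workflow.CreateSession(ctx, &workflow.SessionOptions{
		CreationTimeout:  time.Minute,
		ExecutionTimeout: 24 * time.Hour,
	})
	if err != nil {
		return err
	}
	defer workflow.CompleteSession(sessionCtx)

	sessionID := workflow.GetSessionInfo(sessionCtx).SessionID

	sessionCtx = workflow.WithActivityOptions(sessionCtx, workflow.ActivityOptions{
		StartToCloseTimeout: 24 * time.Hour,
		HeartbeatTimeout:    time.Minute,
	})

	// Long-running activity holding the replication connection (MaintainPull-like);
	// the real workflow manages its cancellation separately.
	workflow.ExecuteActivity(sessionCtx, "MaintainPull", sessionID)

	// Sync activity that fetches the cached connector by sessionID (StartFlow-like).
	return workflow.ExecuteActivity(sessionCtx, "StartFlow", sessionID).Get(sessionCtx, nil)
}
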
4 changes: 4 additions & 0 deletions flow/connectors/core.go
@@ -39,6 +39,10 @@ type CDCPullConnector interface {
*protos.EnsurePullabilityBatchOutput, error)

// Methods related to retrieving and pushing records for this connector as a source and destination.
SetupReplConn(context.Context) error

// Ping source to keep connection alive. Can be called concurrently with PullRecords; skips ping in that case.
ReplPing(context.Context) error

// PullRecords pulls records from the source, and returns a RecordBatch.
// This method should be idempotent, and should be able to be called multiple times with the same request.
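
The interface comment above says ReplPing may be called concurrently with PullRecords and should skip the ping in that case. One way to get that behaviour is a TryLock on the mutex PullRecords holds while it owns the stream. The sketch below assumes hypothetical replLock, replConn, and lastOffset fields, which may not match the connector's actual struct.

package replping

import (
	"context"
	"sync"

	"github.com/jackc/pglogrepl"
	"github.com/jackc/pgx/v5"
)

// replPinger stands in for the Postgres connector state ReplPing would touch.
// Field names are assumptions for illustration, not PeerDB's real layout.
type replPinger struct {
	replLock   sync.Mutex    // held by PullRecords while it owns the stream
	replConn   *pgx.Conn     // replication connection opened by SetupReplConn
	lastOffset pglogrepl.LSN // last LSN acknowledged to the source
}

// ReplPing sends a standby status update unless PullRecords currently holds
// the connection, in which case it returns immediately (PullRecords already
// sends its own status updates while streaming).
func (c *replPinger) ReplPing(ctx context.Context) error {
	if c.replLock.TryLock() {
		defer c.replLock.Unlock()
		if c.replConn != nil {
			return pglogrepl.SendStandbyStatusUpdate(
				ctx,
				c.replConn.PgConn(),
				pglogrepl.StandbyStatusUpdate{WALWritePosition: c.lastOffset},
			)
		}
	}
	return nil
}
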
91 changes: 7 additions & 84 deletions flow/connectors/postgres/cdc.go
@@ -3,7 +3,6 @@ package connpostgres
import (
"context"
"crypto/sha256"
"errors"
"fmt"
"log/slog"
"time"
@@ -28,7 +27,6 @@ import (

type PostgresCDCSource struct {
*PostgresConnector
replConn *pgx.Conn
SrcTableIDNameMapping map[uint32]string
TableNameMapping map[string]model.NameAndExclude
slot string
@@ -46,7 +44,6 @@ type PostgresCDCSource struct {
}

type PostgresCDCConfig struct {
Connection *pgx.Conn
Slot string
Publication string
SrcTableIDNameMapping map[uint32]string
@@ -67,21 +64,20 @@ type startReplicationOpts struct {
func (c *PostgresConnector) NewPostgresCDCSource(cdcConfig *PostgresCDCConfig) *PostgresCDCSource {
return &PostgresCDCSource{
PostgresConnector: c,
replConn: cdcConfig.Connection,
SrcTableIDNameMapping: cdcConfig.SrcTableIDNameMapping,
TableNameMapping: cdcConfig.TableNameMapping,
slot: cdcConfig.Slot,
publication: cdcConfig.Publication,
relationMessageMapping: cdcConfig.RelationMessageMapping,
typeMap: pgtype.NewMap(),
childToParentRelIDMapping: cdcConfig.ChildToParentRelIDMap,
typeMap: pgtype.NewMap(),
commitLock: false,
catalogPool: cdcConfig.CatalogPool,
flowJobName: cdcConfig.FlowJobName,
}
}

func getChildToParentRelIDMap(ctx context.Context, conn *pgx.Conn) (map[uint32]uint32, error) {
func GetChildToParentRelIDMap(ctx context.Context, conn *pgx.Conn) (map[uint32]uint32, error) {
query := `
SELECT parent.oid AS parentrelid, child.oid AS childrelid
FROM pg_inherits
@@ -94,7 +90,6 @@ func getChildToParentRelIDMap(ctx context.Context, conn *pgx.Conn) (map[uint32]u
if err != nil {
return nil, fmt.Errorf("error querying for child to parent relid map: %w", err)
}

defer rows.Close()

childToParentRelIDMap := make(map[uint32]uint32)
@@ -113,85 +108,14 @@ func getChildToParentRelIDMap(ctx context.Context, conn *pgx.Conn) (map[uint32]u

// PullRecords pulls records from the cdc stream
func (p *PostgresCDCSource) PullRecords(ctx context.Context, req *model.PullRecordsRequest) error {
replicationOpts, err := p.replicationOptions()
if err != nil {
return fmt.Errorf("error getting replication options: %w", err)
}

pgConn := p.replConn.PgConn()

// start replication
var clientXLogPos, startLSN pglogrepl.LSN
if req.LastOffset > 0 {
p.logger.Info("starting replication from last sync state", slog.Int64("last checkpoint", req.LastOffset))
clientXLogPos = pglogrepl.LSN(req.LastOffset)
startLSN = clientXLogPos + 1
}

opts := startReplicationOpts{
conn: pgConn,
startLSN: startLSN,
replicationOpts: *replicationOpts,
}

err = p.startReplication(ctx, opts)
if err != nil {
return fmt.Errorf("error starting replication: %w", err)
}

p.logger.Info(fmt.Sprintf("started replication on slot %s at startLSN: %d", p.slot, startLSN))

return p.consumeStream(ctx, pgConn, req, clientXLogPos, req.RecordStream)
}

func (p *PostgresCDCSource) startReplication(ctx context.Context, opts startReplicationOpts) error {
err := pglogrepl.StartReplication(ctx, opts.conn, p.slot, opts.startLSN, opts.replicationOpts)
if err != nil {
p.logger.Error("error starting replication", slog.Any("error", err))
return fmt.Errorf("error starting replication at startLsn - %d: %w", opts.startLSN, err)
}

p.logger.Info(fmt.Sprintf("started replication on slot %s at startLSN: %d", p.slot, opts.startLSN))
return nil
}

func (p *PostgresCDCSource) replicationOptions() (*pglogrepl.StartReplicationOptions, error) {
pluginArguments := []string{
"proto_version '1'",
}

if p.publication != "" {
pubOpt := fmt.Sprintf("publication_names '%s'", p.publication)
pluginArguments = append(pluginArguments, pubOpt)
} else {
return nil, errors.New("publication name is not set")
}

return &pglogrepl.StartReplicationOptions{PluginArgs: pluginArguments}, nil
}

// start consuming the cdc stream
func (p *PostgresCDCSource) consumeStream(
ctx context.Context,
conn *pgconn.PgConn,
req *model.PullRecordsRequest,
clientXLogPos pglogrepl.LSN,
records *model.CDCRecordStream,
) error {
defer func() {
timeout, cancel := context.WithTimeout(context.Background(), 1*time.Minute)
err := conn.Close(timeout)
if err != nil {
p.logger.Error("error closing replication connection", slog.Any("error", err))
}
cancel()
}()

conn := p.replConn.PgConn()
records := req.RecordStream
// clientXLogPos is the last checkpoint id, we need to ack that we have processed
// until clientXLogPos each time we send a standby status update.
// consumedXLogPos is the lsn that has been committed on the destination.
consumedXLogPos := pglogrepl.LSN(0)
if clientXLogPos > 0 {
var clientXLogPos, consumedXLogPos pglogrepl.LSN
if req.LastOffset > 0 {
clientXLogPos = pglogrepl.LSN(req.LastOffset)
consumedXLogPos = clientXLogPos

err := pglogrepl.SendStandbyStatusUpdate(ctx, conn,
@@ -300,7 +224,6 @@ func (p *PostgresCDCSource) consumeStream(

var receiveCtx context.Context
var cancel context.CancelFunc

if cdcRecordsStorage.IsEmpty() {
receiveCtx, cancel = context.WithCancel(ctx)
} else {
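
The replicationOptions and startReplication helpers removed above are where replication used to be started inside PullRecords. For reference, here is that removed logic condensed into a single standalone helper; the replStream receiver type is a placeholder, and exactly where this logic lands after the commit (for example behind SetupReplConn) is not visible in the loaded portion of the diff.

package replsketch

import (
	"context"
	"errors"
	"fmt"

	"github.com/jackc/pglogrepl"
	"github.com/jackc/pgx/v5/pgconn"
)

// replStream is a placeholder for whatever state holds the replication
// connection, slot, and publication after this commit.
type replStream struct {
	conn        *pgconn.PgConn
	slot        string
	publication string
}

// startReplication condenses the removed replicationOptions/startReplication
// helpers: build the pgoutput plugin arguments, then start streaming at startLSN.
func (r *replStream) startReplication(ctx context.Context, startLSN pglogrepl.LSN) error {
	if r.publication == "" {
		return errors.New("publication name is not set")
	}
	pluginArguments := []string{
		"proto_version '1'",
		fmt.Sprintf("publication_names '%s'", r.publication),
	}
	if err := pglogrepl.StartReplication(ctx, r.conn, r.slot, startLSN,
		pglogrepl.StartReplicationOptions{PluginArgs: pluginArguments}); err != nil {
		return fmt.Errorf("error starting replication at startLsn - %d: %w", startLSN, err)
	}
	return nil
}
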
(diff for the remaining 6 changed files was not loaded in this view)