Skip to content

Commit

Permalink
Avoid having multiple catalog connection pools
Browse files Browse the repository at this point in the history
Connection pools are meant to be shared, so remove pool from monitoring

Also go one step further: cache connection pool from env.go into global
This requires never closing pool returned by GetCatalogConnectionPoolFromEnv
  • Loading branch information
serprex committed Dec 11, 2023
1 parent 765cdba commit 5471d7b
Show file tree
Hide file tree
Showing 9 changed files with 114 additions and 128 deletions.
74 changes: 40 additions & 34 deletions flow/activities/flowable.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,14 @@ import (
connpostgres "github.com/PeerDB-io/peer-flow/connectors/postgres"
connsnowflake "github.com/PeerDB-io/peer-flow/connectors/snowflake"
"github.com/PeerDB-io/peer-flow/connectors/utils"
catalog "github.com/PeerDB-io/peer-flow/connectors/utils/catalog"
"github.com/PeerDB-io/peer-flow/connectors/utils/monitoring"
"github.com/PeerDB-io/peer-flow/generated/protos"
"github.com/PeerDB-io/peer-flow/model"
"github.com/PeerDB-io/peer-flow/shared"
"github.com/jackc/pglogrepl"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgtype"
"github.com/jackc/pgx/v5/pgxpool"
log "github.com/sirupsen/logrus"
"go.temporal.io/sdk/activity"
"golang.org/x/sync/errgroup"
Expand All @@ -40,7 +40,7 @@ type SlotSnapshotSignal struct {
}

type FlowableActivity struct {
CatalogMirrorMonitor *monitoring.CatalogMirrorMonitor
CatalogPool *pgxpool.Pool
}

// CheckConnection implements CheckConnection.
Expand Down Expand Up @@ -114,7 +114,7 @@ func (a *FlowableActivity) CreateRawTable(
ctx context.Context,
config *protos.CreateRawTableInput,
) (*protos.CreateRawTableOutput, error) {
ctx = context.WithValue(ctx, shared.CDCMirrorMonitorKey, a.CatalogMirrorMonitor)
ctx = context.WithValue(ctx, shared.CDCMirrorMonitorKey, a.CatalogPool)
dstConn, err := connectors.GetCDCSyncConnector(ctx, config.PeerConnectionConfig)
if err != nil {
return nil, fmt.Errorf("failed to get connector: %w", err)
Expand All @@ -125,7 +125,7 @@ func (a *FlowableActivity) CreateRawTable(
if err != nil {
return nil, err
}
err = a.CatalogMirrorMonitor.InitializeCDCFlow(ctx, config.FlowJobName)
err = monitoring.InitializeCDCFlow(ctx, a.CatalogPool, config.FlowJobName)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -174,7 +174,7 @@ func (a *FlowableActivity) handleSlotInfo(
}

if len(slotInfo) != 0 {
return a.CatalogMirrorMonitor.AppendSlotSizeInfo(ctx, peerName, slotInfo[0])
return monitoring.AppendSlotSizeInfo(ctx, a.CatalogPool, peerName, slotInfo[0])
}
return nil
}
Expand Down Expand Up @@ -209,7 +209,7 @@ func (a *FlowableActivity) StartFlow(ctx context.Context,
input *protos.StartFlowInput) (*model.SyncResponse, error) {
activity.RecordHeartbeat(ctx, "starting flow...")
conn := input.FlowConnectionConfigs
ctx = context.WithValue(ctx, shared.CDCMirrorMonitorKey, a.CatalogMirrorMonitor)
ctx = context.WithValue(ctx, shared.CDCMirrorMonitorKey, a.CatalogPool)
dstConn, err := connectors.GetCDCSyncConnector(ctx, conn.Destination)
if err != nil {
return nil, fmt.Errorf("failed to get destination connector: %w", err)
Expand Down Expand Up @@ -276,13 +276,13 @@ func (a *FlowableActivity) StartFlow(ctx context.Context,
"flowName": input.FlowConnectionConfigs.FlowJobName,
}).Infof("the current sync flow has records: %v", hasRecords)

if a.CatalogMirrorMonitor.IsActive() && hasRecords {
if a.CatalogPool != nil && hasRecords {
syncBatchID, err := dstConn.GetLastSyncBatchID(input.FlowConnectionConfigs.FlowJobName)
if err != nil && conn.Destination.Type != protos.DBType_EVENTHUB {
return nil, err
}

err = a.CatalogMirrorMonitor.AddCDCBatchForFlow(ctx, input.FlowConnectionConfigs.FlowJobName,
err = monitoring.AddCDCBatchForFlow(ctx, a.CatalogPool, input.FlowConnectionConfigs.FlowJobName,
monitoring.CDCBatchInfo{
BatchID: syncBatchID + 1,
RowsInBatch: 0,
Expand Down Expand Up @@ -347,8 +347,9 @@ func (a *FlowableActivity) StartFlow(ctx context.Context,
return nil, fmt.Errorf("failed to get last checkpoint: %w", err)
}

err = a.CatalogMirrorMonitor.UpdateNumRowsAndEndLSNForCDCBatch(
err = monitoring.UpdateNumRowsAndEndLSNForCDCBatch(
ctx,
a.CatalogPool,
input.FlowConnectionConfigs.FlowJobName,
res.CurrentSyncBatchID,
uint32(numRecords),
Expand All @@ -358,13 +359,17 @@ func (a *FlowableActivity) StartFlow(ctx context.Context,
return nil, err
}

err = a.CatalogMirrorMonitor.
UpdateLatestLSNAtTargetForCDCFlow(ctx, input.FlowConnectionConfigs.FlowJobName, pglogrepl.LSN(lastCheckpoint))
err = monitoring.UpdateLatestLSNAtTargetForCDCFlow(
ctx,
a.CatalogPool,
input.FlowConnectionConfigs.FlowJobName,
pglogrepl.LSN(lastCheckpoint),
)
if err != nil {
return nil, err
}
if res.TableNameRowsMapping != nil {
err = a.CatalogMirrorMonitor.AddCDCBatchTablesForFlow(ctx, input.FlowConnectionConfigs.FlowJobName,
err = monitoring.AddCDCBatchTablesForFlow(ctx, a.CatalogPool, input.FlowConnectionConfigs.FlowJobName,
res.CurrentSyncBatchID, res.TableNameRowsMapping)
if err != nil {
return nil, err
Expand Down Expand Up @@ -401,7 +406,7 @@ func (a *FlowableActivity) StartNormalize(
return nil, fmt.Errorf("failed to get last sync batch ID: %v", err)
}

err = a.CatalogMirrorMonitor.UpdateEndTimeForCDCBatch(ctx, input.FlowConnectionConfigs.FlowJobName,
err = monitoring.UpdateEndTimeForCDCBatch(ctx, a.CatalogPool, input.FlowConnectionConfigs.FlowJobName,
lastSyncBatchID)
return nil, err
} else if err != nil {
Expand Down Expand Up @@ -434,8 +439,12 @@ func (a *FlowableActivity) StartNormalize(

// normalize flow did not run due to no records, no need to update end time.
if res.Done {
err = a.CatalogMirrorMonitor.UpdateEndTimeForCDCBatch(ctx, input.FlowConnectionConfigs.FlowJobName,
res.EndBatchID)
err = monitoring.UpdateEndTimeForCDCBatch(
ctx,
a.CatalogPool,
input.FlowConnectionConfigs.FlowJobName,
res.EndBatchID,
)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -500,8 +509,9 @@ func (a *FlowableActivity) GetQRepPartitions(ctx context.Context,
return nil, fmt.Errorf("failed to get partitions from source: %w", err)
}
if len(partitions) > 0 {
err = a.CatalogMirrorMonitor.InitializeQRepRun(
err = monitoring.InitializeQRepRun(
ctx,
a.CatalogPool,
config,
runUUID,
partitions,
Expand All @@ -522,7 +532,7 @@ func (a *FlowableActivity) ReplicateQRepPartitions(ctx context.Context,
partitions *protos.QRepPartitionBatch,
runUUID string,
) error {
err := a.CatalogMirrorMonitor.UpdateStartTimeForQRepRun(ctx, runUUID)
err := monitoring.UpdateStartTimeForQRepRun(ctx, a.CatalogPool, runUUID)
if err != nil {
return fmt.Errorf("failed to update start time for qrep run: %w", err)
}
Expand All @@ -549,7 +559,7 @@ func (a *FlowableActivity) replicateQRepPartition(ctx context.Context,
partition *protos.QRepPartition,
runUUID string,
) error {
err := a.CatalogMirrorMonitor.UpdateStartTimeForPartition(ctx, runUUID, partition, time.Now())
err := monitoring.UpdateStartTimeForPartition(ctx, a.CatalogPool, runUUID, partition, time.Now())
if err != nil {
return fmt.Errorf("failed to update start time for partition: %w", err)
}
Expand Down Expand Up @@ -587,7 +597,7 @@ func (a *FlowableActivity) replicateQRepPartition(ctx context.Context,
}).Errorf("failed to pull records: %v", err)
goroutineErr = err
} else {
err = a.CatalogMirrorMonitor.UpdatePullEndTimeAndRowsForPartition(ctx, runUUID, partition, numRecords)
err = monitoring.UpdatePullEndTimeAndRowsForPartition(ctx, a.CatalogPool, runUUID, partition, numRecords)
if err != nil {
log.Errorf("%v", err)
goroutineErr = err
Expand All @@ -607,7 +617,7 @@ func (a *FlowableActivity) replicateQRepPartition(ctx context.Context,
"flowName": config.FlowJobName,
}).Infof("pulled %d records\n", len(recordBatch.Records))

err = a.CatalogMirrorMonitor.UpdatePullEndTimeAndRowsForPartition(ctx, runUUID, partition, numRecords)
err = monitoring.UpdatePullEndTimeAndRowsForPartition(ctx, a.CatalogPool, runUUID, partition, numRecords)
if err != nil {
return err
}
Expand Down Expand Up @@ -645,7 +655,7 @@ func (a *FlowableActivity) replicateQRepPartition(ctx context.Context,
}).Infof("pushed %d records\n", res)
}

err = a.CatalogMirrorMonitor.UpdateEndTimeForPartition(ctx, runUUID, partition)
err = monitoring.UpdateEndTimeForPartition(ctx, a.CatalogPool, runUUID, partition)
if err != nil {
return err
}
Expand All @@ -657,7 +667,7 @@ func (a *FlowableActivity) ConsolidateQRepPartitions(ctx context.Context, config
runUUID string) error {
dstConn, err := connectors.GetQRepConsolidateConnector(ctx, config.DestinationPeer)
if errors.Is(err, connectors.ErrUnsupportedFunctionality) {
return a.CatalogMirrorMonitor.UpdateEndTimeForQRepRun(ctx, runUUID)
return monitoring.UpdateEndTimeForQRepRun(ctx, a.CatalogPool, runUUID)
} else if err != nil {
return err
}
Expand All @@ -675,7 +685,7 @@ func (a *FlowableActivity) ConsolidateQRepPartitions(ctx context.Context, config
return err
}

return a.CatalogMirrorMonitor.UpdateEndTimeForQRepRun(ctx, runUUID)
return monitoring.UpdateEndTimeForQRepRun(ctx, a.CatalogPool, runUUID)
}

func (a *FlowableActivity) CleanupQRepFlow(ctx context.Context, config *protos.QRepConfig) error {
Expand Down Expand Up @@ -713,12 +723,8 @@ func (a *FlowableActivity) DropFlow(ctx context.Context, config *protos.Shutdown
return nil
}

func getPostgresPeerConfigs(ctx context.Context) ([]*protos.Peer, error) {
catalogPool, catalogErr := catalog.GetCatalogConnectionPoolFromEnv()
if catalogErr != nil {
return nil, fmt.Errorf("error getting catalog connection pool: %w", catalogErr)
}
defer catalogPool.Close()
func (a *FlowableActivity) getPostgresPeerConfigs(ctx context.Context) ([]*protos.Peer, error) {
catalogPool := a.CatalogPool

optionRows, err := catalogPool.Query(ctx, `
SELECT DISTINCT p.name, p.options
Expand Down Expand Up @@ -762,7 +768,7 @@ func (a *FlowableActivity) SendWALHeartbeat(ctx context.Context) error {
log.Info("context is done, exiting wal heartbeat send loop")
return nil
case <-ticker.C:
pgPeers, err := getPostgresPeerConfigs(ctx)
pgPeers, err := a.getPostgresPeerConfigs(ctx)
if err != nil {
log.Warn("[sendwalheartbeat]: warning: unable to fetch peers." +
"Skipping walheartbeat send. error encountered: " + err.Error())
Expand Down Expand Up @@ -944,17 +950,17 @@ func (a *FlowableActivity) ReplicateXminPartition(ctx context.Context,
}},
}
}
updateErr := a.CatalogMirrorMonitor.InitializeQRepRun(ctx, config, runUUID, []*protos.QRepPartition{partitionForMetrics})
updateErr := monitoring.InitializeQRepRun(ctx, a.CatalogPool, config, runUUID, []*protos.QRepPartition{partitionForMetrics})
if updateErr != nil {
return updateErr
}

err := a.CatalogMirrorMonitor.UpdateStartTimeForPartition(ctx, runUUID, partition, startTime)
err := monitoring.UpdateStartTimeForPartition(ctx, a.CatalogPool, runUUID, partition, startTime)
if err != nil {
return fmt.Errorf("failed to update start time for partition: %w", err)
}

err = a.CatalogMirrorMonitor.UpdatePullEndTimeAndRowsForPartition(errCtx, runUUID, partition, int64(numRecords))
err = monitoring.UpdatePullEndTimeAndRowsForPartition(errCtx, a.CatalogPool, runUUID, partition, int64(numRecords))
if err != nil {
log.Errorf("%v", err)
return err
Expand Down Expand Up @@ -990,7 +996,7 @@ func (a *FlowableActivity) ReplicateXminPartition(ctx context.Context,
}).Infof("pushed %d records\n", res)
}

err = a.CatalogMirrorMonitor.UpdateEndTimeForPartition(ctx, runUUID, partition)
err = monitoring.UpdateEndTimeForPartition(ctx, a.CatalogPool, runUUID, partition)
if err != nil {
return 0, err
}
Expand Down
1 change: 0 additions & 1 deletion flow/cmd/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,6 @@ func APIMain(args *APIServerParams) error {
}

flowHandler := NewFlowRequestHandler(tc, catalogConn, taskQueue)
defer flowHandler.Close()

err = killExistingHeartbeatFlows(ctx, tc, args.TemporalNamespace, taskQueue)
if err != nil {
Expand Down
7 changes: 0 additions & 7 deletions flow/cmd/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,13 +116,6 @@ func (h *FlowRequestHandler) createQrepJobEntry(ctx context.Context,
return nil
}

// Close closes the connection pool
func (h *FlowRequestHandler) Close() {
if h.pool != nil {
h.pool.Close()
}
}

func (h *FlowRequestHandler) CreateCDCFlow(
ctx context.Context, req *protos.CreateCDCFlowRequest) (*protos.CreateCDCFlowResponse, error) {
cfg := req.ConnectionConfigs
Expand Down
5 changes: 1 addition & 4 deletions flow/cmd/worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import (

"github.com/PeerDB-io/peer-flow/activities"
utils "github.com/PeerDB-io/peer-flow/connectors/utils/catalog"
"github.com/PeerDB-io/peer-flow/connectors/utils/monitoring"
"github.com/PeerDB-io/peer-flow/shared"
peerflow "github.com/PeerDB-io/peer-flow/workflows"

Expand Down Expand Up @@ -108,8 +107,6 @@ func WorkerMain(opts *WorkerOptions) error {
if err != nil {
return fmt.Errorf("unable to create catalog connection pool: %w", err)
}
catalogMirrorMonitor := monitoring.NewCatalogMirrorMonitor(conn)
defer catalogMirrorMonitor.Close()

c, err := client.Dial(clientOptions)
if err != nil {
Expand All @@ -134,7 +131,7 @@ func WorkerMain(opts *WorkerOptions) error {
w.RegisterWorkflow(peerflow.DropFlowWorkflow)
w.RegisterWorkflow(peerflow.HeartbeatFlowWorkflow)
w.RegisterActivity(&activities.FlowableActivity{
CatalogMirrorMonitor: catalogMirrorMonitor,
CatalogPool: conn,
})

err = w.Run(worker.InterruptCh())
Expand Down
1 change: 0 additions & 1 deletion flow/connectors/bigquery/bigquery.go
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,6 @@ func (c *BigQueryConnector) Close() error {
if c == nil || c.client == nil {
return nil
}
c.catalogPool.Close()
return c.client.Close()
}

Expand Down
2 changes: 1 addition & 1 deletion flow/connectors/external_metadata/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ func NewPostgresMetadataStore(ctx context.Context, pgConfig *protos.PostgresConf
}

func (p *PostgresMetadataStore) Close() error {
if p.pool != nil {
if p.config != nil && p.pool != nil {
p.pool.Close()
}

Expand Down
4 changes: 2 additions & 2 deletions flow/connectors/postgres/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -246,13 +246,13 @@ func (c *PostgresConnector) PullRecords(req *model.PullRecordsRequest) error {
return err
}

cdcMirrorMonitor, ok := c.ctx.Value(shared.CDCMirrorMonitorKey).(*monitoring.CatalogMirrorMonitor)
catalogPool, ok := c.ctx.Value(shared.CDCMirrorMonitorKey).(*pgxpool.Pool)
if ok {
latestLSN, err := c.getCurrentLSN()
if err != nil {
return fmt.Errorf("failed to get current LSN: %w", err)
}
err = cdcMirrorMonitor.UpdateLatestLSNAtSourceForCDCFlow(c.ctx, req.FlowJobName, latestLSN)
err = monitoring.UpdateLatestLSNAtSourceForCDCFlow(c.ctx, catalogPool, req.FlowJobName, latestLSN)
if err != nil {
return fmt.Errorf("failed to update latest LSN at source for CDC flow: %w", err)
}
Expand Down
28 changes: 18 additions & 10 deletions flow/connectors/utils/catalog/env.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,29 +5,37 @@ import (
"fmt"
"os"
"strconv"
"sync"

"github.com/PeerDB-io/peer-flow/connectors/utils"
"github.com/PeerDB-io/peer-flow/generated/protos"
"github.com/jackc/pgx/v5/pgxpool"
)

var poolMutex = &sync.Mutex{}
var pool *pgxpool.Pool

func GetCatalogConnectionPoolFromEnv() (*pgxpool.Pool, error) {
catalogConnectionString, err := genCatalogConnectionString()
if err != nil {
return nil, fmt.Errorf("unable to generate catalog connection string: %w", err)
}
poolMutex.Lock()
defer poolMutex.Unlock()
if pool != nil {
catalogConnectionString, err := genCatalogConnectionString()
if err != nil {
return nil, fmt.Errorf("unable to generate catalog connection string: %w", err)
}

catalogConn, err := pgxpool.New(context.Background(), catalogConnectionString)
if err != nil {
return nil, fmt.Errorf("unable to establish connection with catalog: %w", err)
pool, err = pgxpool.New(context.Background(), catalogConnectionString)
if err != nil {
return nil, fmt.Errorf("unable to establish connection with catalog: %w", err)
}
}

err = catalogConn.Ping(context.Background())
err := pool.Ping(context.Background())
if err != nil {
return nil, fmt.Errorf("unable to establish connection with catalog: %w", err)
return pool, fmt.Errorf("unable to establish connection with catalog: %w", err)
}

return catalogConn, nil
return pool, nil
}

func genCatalogConnectionString() (string, error) {
Expand Down
Loading

0 comments on commit 5471d7b

Please sign in to comment.