diff --git a/flow/cmd/peer_data.go b/flow/cmd/peer_data.go index f6e687435d..4ceb6307f0 100644 --- a/flow/cmd/peer_data.go +++ b/flow/cmd/peer_data.go @@ -210,7 +210,7 @@ func (h *FlowRequestHandler) GetSlotInfo( return &protos.PeerSlotResponse{SlotData: nil}, err } - pgConnector, err := connpostgres.NewPostgresConnector(ctx, pgConfig, false) + pgConnector, err := connpostgres.NewPostgresConnector(ctx, pgConfig) if err != nil { slog.Error("Failed to create postgres connector", slog.Any("error", err)) return &protos.PeerSlotResponse{SlotData: nil}, err @@ -237,7 +237,7 @@ func (h *FlowRequestHandler) GetStatInfo( return &protos.PeerStatResponse{StatData: nil}, err } - pgConnector, err := connpostgres.NewPostgresConnector(ctx, pgConfig, false) + pgConnector, err := connpostgres.NewPostgresConnector(ctx, pgConfig) if err != nil { slog.Error("Failed to create postgres connector", slog.Any("error", err)) return &protos.PeerStatResponse{StatData: nil}, err diff --git a/flow/connectors/core.go b/flow/connectors/core.go index 477b7cf46b..8d6a9520b3 100644 --- a/flow/connectors/core.go +++ b/flow/connectors/core.go @@ -137,7 +137,7 @@ func GetCDCPullConnector(ctx context.Context, config *protos.Peer) (CDCPullConne inner := config.Config switch inner.(type) { case *protos.Peer_PostgresConfig: - return connpostgres.NewPostgresConnector(ctx, config.GetPostgresConfig(), true) + return connpostgres.NewPostgresConnector(ctx, config.GetPostgresConfig()) default: return nil, ErrUnsupportedFunctionality } @@ -147,7 +147,7 @@ func GetCDCSyncConnector(ctx context.Context, config *protos.Peer) (CDCSyncConne inner := config.Config switch inner.(type) { case *protos.Peer_PostgresConfig: - return connpostgres.NewPostgresConnector(ctx, config.GetPostgresConfig(), false) + return connpostgres.NewPostgresConnector(ctx, config.GetPostgresConfig()) case *protos.Peer_BigqueryConfig: return connbigquery.NewBigQueryConnector(ctx, config.GetBigqueryConfig()) case *protos.Peer_SnowflakeConfig: @@ -169,7 +169,7 @@ func GetCDCNormalizeConnector(ctx context.Context, inner := config.Config switch inner.(type) { case *protos.Peer_PostgresConfig: - return connpostgres.NewPostgresConnector(ctx, config.GetPostgresConfig(), false) + return connpostgres.NewPostgresConnector(ctx, config.GetPostgresConfig()) case *protos.Peer_BigqueryConfig: return connbigquery.NewBigQueryConnector(ctx, config.GetBigqueryConfig()) case *protos.Peer_SnowflakeConfig: @@ -183,7 +183,7 @@ func GetQRepPullConnector(ctx context.Context, config *protos.Peer) (QRepPullCon inner := config.Config switch inner.(type) { case *protos.Peer_PostgresConfig: - return connpostgres.NewPostgresConnector(ctx, config.GetPostgresConfig(), false) + return connpostgres.NewPostgresConnector(ctx, config.GetPostgresConfig()) case *protos.Peer_SqlserverConfig: return connsqlserver.NewSQLServerConnector(ctx, config.GetSqlserverConfig()) default: @@ -195,7 +195,7 @@ func GetQRepSyncConnector(ctx context.Context, config *protos.Peer) (QRepSyncCon inner := config.Config switch inner.(type) { case *protos.Peer_PostgresConfig: - return connpostgres.NewPostgresConnector(ctx, config.GetPostgresConfig(), false) + return connpostgres.NewPostgresConnector(ctx, config.GetPostgresConfig()) case *protos.Peer_BigqueryConfig: return connbigquery.NewBigQueryConnector(ctx, config.GetBigqueryConfig()) case *protos.Peer_SnowflakeConfig: @@ -219,7 +219,7 @@ func GetConnector(ctx context.Context, peer *protos.Peer) (Connector, error) { // we can't decide if a PG peer should have replication permissions on it because we don't know // what the user wants to do with it, so defaulting to being permissive. // can be revisited in the future or we can use some UI wizardry. - return connpostgres.NewPostgresConnector(ctx, pgConfig, false) + return connpostgres.NewPostgresConnector(ctx, pgConfig) case protos.DBType_BIGQUERY: bqConfig := peer.GetBigqueryConfig() if bqConfig == nil { diff --git a/flow/connectors/postgres/client.go b/flow/connectors/postgres/client.go index 7e79a78d1d..dbbe2aa6c0 100644 --- a/flow/connectors/postgres/client.go +++ b/flow/connectors/postgres/client.go @@ -322,7 +322,12 @@ func (c *PostgresConnector) createSlotAndPublication( // create slot only after we succeeded in creating publication. if !s.SlotExists { - conn, err := c.replPool.Acquire(c.ctx) + pool, err := c.GetReplPool(c.ctx) + if err != nil { + return fmt.Errorf("[slot] error acquiring pool: %w", err) + } + + conn, err := pool.Acquire(c.ctx) if err != nil { return fmt.Errorf("[slot] error acquiring connection: %w", err) } diff --git a/flow/connectors/postgres/postgres.go b/flow/connectors/postgres/postgres.go index 7e44bb7b8a..c905d3d27f 100644 --- a/flow/connectors/postgres/postgres.go +++ b/flow/connectors/postgres/postgres.go @@ -5,6 +5,7 @@ import ( "fmt" "log/slog" "regexp" + "sync" "time" "github.com/PeerDB-io/peer-flow/connectors/utils" @@ -26,6 +27,8 @@ type PostgresConnector struct { ctx context.Context config *protos.PostgresConfig pool *SSHWrappedPostgresPool + replConfig *pgxpool.Config + replMutex sync.Mutex replPool *SSHWrappedPostgresPool tableSchemaMapping map[string]*protos.TableSchema customTypesMapping map[uint32]string @@ -34,12 +37,13 @@ type PostgresConnector struct { } // NewPostgresConnector creates a new instance of PostgresConnector. -func NewPostgresConnector(ctx context.Context, pgConfig *protos.PostgresConfig, initializeReplPool bool) (*PostgresConnector, error) { +func NewPostgresConnector(ctx context.Context, pgConfig *protos.PostgresConfig) (*PostgresConnector, error) { connectionString := utils.GetPGConnectionString(pgConfig) // create a separate connection pool for non-replication queries as replication connections cannot // be used for extended query protocol, i.e. prepared statements connConfig, err := pgxpool.ParseConfig(connectionString) + replConfig := connConfig.Copy() if err != nil { return nil, fmt.Errorf("failed to parse connection string: %w", err) } @@ -52,6 +56,11 @@ func NewPostgresConnector(ctx context.Context, pgConfig *protos.PostgresConfig, // set pool size to 3 to avoid connection pool exhaustion connConfig.MaxConns = 3 + // ensure that replication is set to database + replConfig.ConnConfig.RuntimeParams["replication"] = "database" + replConfig.ConnConfig.RuntimeParams["bytea_output"] = "hex" + replConfig.MaxConns = 1 + pool, err := NewSSHWrappedPostgresPool(ctx, connConfig, pgConfig.SshConfig) if err != nil { return nil, fmt.Errorf("failed to create connection pool: %w", err) @@ -62,25 +71,6 @@ func NewPostgresConnector(ctx context.Context, pgConfig *protos.PostgresConfig, return nil, fmt.Errorf("failed to get custom type map: %w", err) } - // only initialize for CDCPullConnector to reduce number of idle connections - var replPool *SSHWrappedPostgresPool - if initializeReplPool { - // ensure that replication is set to database - replConnConfig, err := pgxpool.ParseConfig(connectionString) - if err != nil { - return nil, fmt.Errorf("failed to parse connection string: %w", err) - } - - replConnConfig.ConnConfig.RuntimeParams["replication"] = "database" - replConnConfig.ConnConfig.RuntimeParams["bytea_output"] = "hex" - replConnConfig.MaxConns = 1 - - replPool, err = NewSSHWrappedPostgresPool(ctx, replConnConfig, pgConfig.SshConfig) - if err != nil { - return nil, fmt.Errorf("failed to create replication connection pool: %w", err) - } - } - metadataSchema := "_peerdb_internal" if pgConfig.MetadataSchema != nil { metadataSchema = *pgConfig.MetadataSchema @@ -93,18 +83,36 @@ func NewPostgresConnector(ctx context.Context, pgConfig *protos.PostgresConfig, ctx: ctx, config: pgConfig, pool: pool, - replPool: replPool, + replConfig: replConfig, + replPool: nil, customTypesMapping: customTypeMap, metadataSchema: metadataSchema, logger: *flowLog, }, nil } -// GetPool returns the connection pool. +// nil returns the connection pool. func (c *PostgresConnector) GetPool() *SSHWrappedPostgresPool { return c.pool } +func (c *PostgresConnector) GetReplPool(ctx context.Context) (*SSHWrappedPostgresPool, error) { + c.replMutex.Lock() + defer c.replMutex.Unlock() + + if c.replPool != nil { + return c.replPool, nil + } + + pool, err := NewSSHWrappedPostgresPool(ctx, c.replConfig, c.config.SshConfig) + if err != nil { + return nil, fmt.Errorf("failed to create replication connection pool: %w", err) + } + + c.replPool = pool + return pool, nil +} + // Close closes all connections. func (c *PostgresConnector) Close() error { if c.pool != nil { @@ -227,9 +235,14 @@ func (c *PostgresConnector) PullRecords(catalogPool *pgxpool.Pool, req *model.Pu c.logger.Info("PullRecords: performed checks for slot and publication") + replPool, err := c.GetReplPool(c.ctx) + if err != nil { + return err + } + cdc, err := NewPostgresCDCSource(&PostgresCDCConfig{ AppContext: c.ctx, - Connection: c.replPool.Pool, + Connection: replPool.Pool, SrcTableIDNameMapping: req.SrcTableIDNameMapping, Slot: slotName, Publication: publicationName, diff --git a/flow/connectors/postgres/postgres_repl_test.go b/flow/connectors/postgres/postgres_repl_test.go index df3a7de13f..b50a1f89fc 100644 --- a/flow/connectors/postgres/postgres_repl_test.go +++ b/flow/connectors/postgres/postgres_repl_test.go @@ -28,7 +28,7 @@ func (suite *PostgresReplicationSnapshotTestSuite) SetupSuite() { User: "postgres", Password: "postgres", Database: "postgres", - }, true) + }) require.NoError(suite.T(), err) setupTx, err := suite.connector.pool.Begin(context.Background()) diff --git a/flow/connectors/postgres/postgres_schema_delta_test.go b/flow/connectors/postgres/postgres_schema_delta_test.go index b817e7be51..0f0e52f41a 100644 --- a/flow/connectors/postgres/postgres_schema_delta_test.go +++ b/flow/connectors/postgres/postgres_schema_delta_test.go @@ -32,7 +32,7 @@ func SetupSuite(t *testing.T, g got.G) PostgresSchemaDeltaTestSuite { User: "postgres", Password: "postgres", Database: "postgres", - }, false) + }) require.NoError(t, err) setupTx, err := connector.pool.Begin(context.Background()) diff --git a/flow/e2e/postgres/qrep_flow_pg_test.go b/flow/e2e/postgres/qrep_flow_pg_test.go index c54980d55c..b6dac1e861 100644 --- a/flow/e2e/postgres/qrep_flow_pg_test.go +++ b/flow/e2e/postgres/qrep_flow_pg_test.go @@ -60,7 +60,7 @@ func SetupSuite(t *testing.T, g got.G) PeerFlowE2ETestSuitePG { User: "postgres", Password: "postgres", Database: "postgres", - }, false) + }) require.NoError(t, err) return PeerFlowE2ETestSuitePG{