diff --git a/flow/cmd/peer_data.go b/flow/cmd/peer_data.go index 110b9b5a7f..0bc5f7d245 100644 --- a/flow/cmd/peer_data.go +++ b/flow/cmd/peer_data.go @@ -8,11 +8,9 @@ import ( "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgtype" - "github.com/jackc/pgx/v5/pgxpool" "google.golang.org/protobuf/proto" connpostgres "github.com/PeerDB-io/peer-flow/connectors/postgres" - "github.com/PeerDB-io/peer-flow/connectors/utils" "github.com/PeerDB-io/peer-flow/generated/protos" ) @@ -33,17 +31,19 @@ func (h *FlowRequestHandler) getPGPeerConfig(ctx context.Context, peerName strin return &pgPeerConfig, nil } -func (h *FlowRequestHandler) getPoolForPGPeer(ctx context.Context, peerName string) (*pgxpool.Pool, error) { +func (h *FlowRequestHandler) getPoolForPGPeer(ctx context.Context, peerName string) (*connpostgres.SSHWrappedPostgresPool, error) { pgPeerConfig, err := h.getPGPeerConfig(ctx, peerName) if err != nil { return nil, err } - connStr := utils.GetPGConnectionString(pgPeerConfig) - peerPool, err := pgxpool.New(ctx, connStr) + + pool, err := connpostgres.NewSSHWrappedPostgresPoolFromConfig(ctx, pgPeerConfig) if err != nil { + slog.Error("Failed to create postgres pool", slog.Any("error", err)) return nil, err } - return peerPool, nil + + return pool, nil } func (h *FlowRequestHandler) GetSchemas( diff --git a/flow/connectors/postgres/postgres.go b/flow/connectors/postgres/postgres.go index 57f8944fec..79c6fb29a0 100644 --- a/flow/connectors/postgres/postgres.go +++ b/flow/connectors/postgres/postgres.go @@ -55,15 +55,16 @@ func NewPostgresConnector(ctx context.Context, pgConfig *protos.PostgresConfig) // set pool size to 3 to avoid connection pool exhaustion connConfig.MaxConns = 3 - // ensure that replication is set to database - replConfig.ConnConfig.RuntimeParams["replication"] = "database" - replConfig.ConnConfig.RuntimeParams["bytea_output"] = "hex" - replConfig.MaxConns = 1 pool, err := NewSSHWrappedPostgresPool(ctx, connConfig, pgConfig.SshConfig) if err != nil { return nil, fmt.Errorf("failed to create connection pool: %w", err) } + // ensure that replication is set to database + replConfig.ConnConfig.RuntimeParams["replication"] = "database" + replConfig.ConnConfig.RuntimeParams["bytea_output"] = "hex" + replConfig.MaxConns = 1 + customTypeMap, err := utils.GetCustomDataTypes(ctx, pool.Pool) if err != nil { return nil, fmt.Errorf("failed to get custom type map: %w", err) diff --git a/flow/connectors/postgres/ssh_wrapped_pool.go b/flow/connectors/postgres/ssh_wrapped_pool.go index 4f17116ea4..4dcd2cd0ce 100644 --- a/flow/connectors/postgres/ssh_wrapped_pool.go +++ b/flow/connectors/postgres/ssh_wrapped_pool.go @@ -27,6 +27,23 @@ type SSHWrappedPostgresPool struct { cancel context.CancelFunc } +func NewSSHWrappedPostgresPoolFromConfig( + ctx context.Context, + pgConfig *protos.PostgresConfig, +) (*SSHWrappedPostgresPool, error) { + connectionString := utils.GetPGConnectionString(pgConfig) + + connConfig, err := pgxpool.ParseConfig(connectionString) + if err != nil { + return nil, err + } + + // set pool size to 3 to avoid connection pool exhaustion + connConfig.MaxConns = 3 + + return NewSSHWrappedPostgresPool(ctx, connConfig, pgConfig.SshConfig) +} + func NewSSHWrappedPostgresPool( ctx context.Context, poolConfig *pgxpool.Config,