diff --git a/flow/cmd/peer_data.go b/flow/cmd/peer_data.go index 0bc5f7d24..4a3bf534a 100644 --- a/flow/cmd/peer_data.go +++ b/flow/cmd/peer_data.go @@ -31,32 +31,39 @@ func (h *FlowRequestHandler) getPGPeerConfig(ctx context.Context, peerName strin return &pgPeerConfig, nil } -func (h *FlowRequestHandler) getPoolForPGPeer(ctx context.Context, peerName string) (*connpostgres.SSHWrappedPostgresPool, error) { +func (h *FlowRequestHandler) getConnForPGPeer(ctx context.Context, peerName string) (*connpostgres.SSHTunnel, *pgx.Conn, error) { pgPeerConfig, err := h.getPGPeerConfig(ctx, peerName) if err != nil { - return nil, err + return nil, nil, err } - pool, err := connpostgres.NewSSHWrappedPostgresPoolFromConfig(ctx, pgPeerConfig) + tunnel, err := connpostgres.NewSSHTunnel(ctx, pgPeerConfig.SshConfig) if err != nil { slog.Error("Failed to create postgres pool", slog.Any("error", err)) - return nil, err + return nil, nil, err } - return pool, nil + conn, err := tunnel.NewPostgresConnFromPostgresConfig(ctx, pgPeerConfig) + if err != nil { + tunnel.Close() + return nil, nil, err + } + + return tunnel, conn, nil } func (h *FlowRequestHandler) GetSchemas( ctx context.Context, req *protos.PostgresPeerActivityInfoRequest, ) (*protos.PeerSchemasResponse, error) { - peerPool, err := h.getPoolForPGPeer(ctx, req.PeerName) + tunnel, peerConn, err := h.getConnForPGPeer(ctx, req.PeerName) if err != nil { return &protos.PeerSchemasResponse{Schemas: nil}, err } + defer tunnel.Close() + defer peerConn.Close(ctx) - defer peerPool.Close() - rows, err := peerPool.Query(ctx, "SELECT schema_name"+ + rows, err := peerConn.Query(ctx, "SELECT schema_name"+ " FROM information_schema.schemata WHERE schema_name !~ '^pg_' AND schema_name <> 'information_schema';") if err != nil { return &protos.PeerSchemasResponse{Schemas: nil}, err @@ -73,13 +80,14 @@ func (h *FlowRequestHandler) GetTablesInSchema( ctx context.Context, req *protos.SchemaTablesRequest, ) (*protos.SchemaTablesResponse, error) { - peerPool, err := h.getPoolForPGPeer(ctx, req.PeerName) + tunnel, peerConn, err := h.getConnForPGPeer(ctx, req.PeerName) if err != nil { return &protos.SchemaTablesResponse{Tables: nil}, err } + defer tunnel.Close() + defer peerConn.Close(ctx) - defer peerPool.Close() - rows, err := peerPool.Query(ctx, `SELECT DISTINCT ON (t.relname) + rows, err := peerConn.Query(ctx, `SELECT DISTINCT ON (t.relname) t.relname, CASE WHEN con.contype = 'p' OR t.relreplident = 'i' OR t.relreplident = 'f' THEN true @@ -130,13 +138,14 @@ func (h *FlowRequestHandler) GetAllTables( ctx context.Context, req *protos.PostgresPeerActivityInfoRequest, ) (*protos.AllTablesResponse, error) { - peerPool, err := h.getPoolForPGPeer(ctx, req.PeerName) + tunnel, peerConn, err := h.getConnForPGPeer(ctx, req.PeerName) if err != nil { return &protos.AllTablesResponse{Tables: nil}, err } + defer tunnel.Close() + defer peerConn.Close(ctx) - defer peerPool.Close() - rows, err := peerPool.Query(ctx, "SELECT table_schema || '.' || table_name AS schema_table "+ + rows, err := peerConn.Query(ctx, "SELECT table_schema || '.' || table_name AS schema_table "+ "FROM information_schema.tables WHERE table_schema !~ '^pg_' AND table_schema <> 'information_schema'") if err != nil { return &protos.AllTablesResponse{Tables: nil}, err @@ -160,13 +169,14 @@ func (h *FlowRequestHandler) GetColumns( ctx context.Context, req *protos.TableColumnsRequest, ) (*protos.TableColumnsResponse, error) { - peerPool, err := h.getPoolForPGPeer(ctx, req.PeerName) + tunnel, peerConn, err := h.getConnForPGPeer(ctx, req.PeerName) if err != nil { return &protos.TableColumnsResponse{Columns: nil}, err } + defer tunnel.Close() + defer peerConn.Close(ctx) - defer peerPool.Close() - rows, err := peerPool.Query(ctx, ` + rows, err := peerConn.Query(ctx, ` SELECT cols.column_name, cols.data_type, @@ -240,22 +250,16 @@ func (h *FlowRequestHandler) GetStatInfo( ctx context.Context, req *protos.PostgresPeerActivityInfoRequest, ) (*protos.PeerStatResponse, error) { - pgConfig, err := h.getPGPeerConfig(ctx, req.PeerName) + tunnel, peerConn, err := h.getConnForPGPeer(ctx, req.PeerName) if err != nil { return &protos.PeerStatResponse{StatData: nil}, err } + defer tunnel.Close() + defer peerConn.Close(ctx) - pgConnector, err := connpostgres.NewPostgresConnector(ctx, pgConfig) - if err != nil { - slog.Error("Failed to create postgres connector", slog.Any("error", err)) - return &protos.PeerStatResponse{StatData: nil}, err - } - defer pgConnector.Close() - - peerPool := pgConnector.GetPool() - peerUser := pgConfig.User + peerUser := peerConn.Config().User - rows, err := peerPool.Query(ctx, "SELECT pid, wait_event, wait_event_type, query_start::text, query,"+ + rows, err := peerConn.Query(ctx, "SELECT pid, wait_event, wait_event_type, query_start::text, query,"+ "EXTRACT(epoch FROM(now()-query_start)) AS dur"+ " FROM pg_stat_activity WHERE "+ "usename=$1 AND state != 'idle';", peerUser) diff --git a/flow/connectors/postgres/cdc.go b/flow/connectors/postgres/cdc.go index c3fd39352..50e20a93f 100644 --- a/flow/connectors/postgres/cdc.go +++ b/flow/connectors/postgres/cdc.go @@ -30,7 +30,7 @@ const maxRetriesForWalSegmentRemoved = 5 type PostgresCDCSource struct { ctx context.Context - replPool *pgxpool.Pool + replConn *pgx.Conn SrcTableIDNameMapping map[uint32]string TableNameMapping map[string]model.NameAndExclude slot string @@ -54,7 +54,7 @@ type PostgresCDCSource struct { type PostgresCDCConfig struct { AppContext context.Context - Connection *pgxpool.Pool + Connection *pgx.Conn Slot string Publication string SrcTableIDNameMapping map[uint32]string @@ -84,7 +84,7 @@ func NewPostgresCDCSource(cdcConfig *PostgresCDCConfig, customTypeMap map[uint32 flowName, _ := cdcConfig.AppContext.Value(shared.FlowNameKey).(string) return &PostgresCDCSource{ ctx: cdcConfig.AppContext, - replPool: cdcConfig.Connection, + replConn: cdcConfig.Connection, SrcTableIDNameMapping: cdcConfig.SrcTableIDNameMapping, TableNameMapping: cdcConfig.TableNameMapping, slot: cdcConfig.Slot, @@ -102,7 +102,7 @@ func NewPostgresCDCSource(cdcConfig *PostgresCDCConfig, customTypeMap map[uint32 }, nil } -func getChildToParentRelIDMap(ctx context.Context, pool *pgxpool.Pool) (map[uint32]uint32, error) { +func getChildToParentRelIDMap(ctx context.Context, conn *pgx.Conn) (map[uint32]uint32, error) { query := ` SELECT parent.oid AS parentrelid, @@ -113,7 +113,7 @@ func getChildToParentRelIDMap(ctx context.Context, pool *pgxpool.Pool) (map[uint WHERE parent.relkind='p'; ` - rows, err := pool.Query(ctx, query, pgx.QueryExecModeSimpleProtocol) + rows, err := conn.Query(ctx, query, pgx.QueryExecModeSimpleProtocol) if err != nil { return nil, fmt.Errorf("error querying for child to parent relid map: %w", err) } @@ -141,15 +141,7 @@ func (p *PostgresCDCSource) PullRecords(req *model.PullRecordsRequest) error { return fmt.Errorf("error getting replication options: %w", err) } - // create replication connection - replicationConn, err := p.replPool.Acquire(p.ctx) - if err != nil { - return fmt.Errorf("error acquiring connection for replication: %w", err) - } - - defer replicationConn.Release() - - pgConn := replicationConn.Conn().PgConn() + pgConn := p.replConn.PgConn() p.logger.Info("created replication connection") // start replication diff --git a/flow/connectors/postgres/client.go b/flow/connectors/postgres/client.go index c6ae8f7d7..f5c08bb8a 100644 --- a/flow/connectors/postgres/client.go +++ b/flow/connectors/postgres/client.go @@ -350,17 +350,11 @@ func (c *PostgresConnector) createSlotAndPublication( // create slot only after we succeeded in creating publication. if !s.SlotExists { - pool, err := c.GetReplPool(c.ctx) + conn, err := c.CreateReplConn(c.ctx) if err != nil { return fmt.Errorf("[slot] error acquiring pool: %w", err) } - - conn, err := pool.Acquire(c.ctx) - if err != nil { - return fmt.Errorf("[slot] error acquiring connection: %w", err) - } - - defer conn.Release() + defer conn.Close(c.ctx) c.logger.Warn(fmt.Sprintf("Creating replication slot '%s'", slot)) @@ -373,7 +367,7 @@ func (c *PostgresConnector) createSlotAndPublication( Temporary: false, Mode: pglogrepl.LogicalReplication, } - res, err := pglogrepl.CreateReplicationSlot(c.ctx, conn.Conn().PgConn(), slot, "pgoutput", opts) + res, err := pglogrepl.CreateReplicationSlot(c.ctx, conn.PgConn(), slot, "pgoutput", opts) if err != nil { return fmt.Errorf("[slot] error creating replication slot: %w", err) } diff --git a/flow/connectors/postgres/postgres.go b/flow/connectors/postgres/postgres.go index e14c05e7d..5a30d9fac 100644 --- a/flow/connectors/postgres/postgres.go +++ b/flow/connectors/postgres/postgres.go @@ -27,9 +27,9 @@ type PostgresConnector struct { connStr string ctx context.Context config *protos.PostgresConfig - pool *SSHWrappedPostgresPool - replConfig *pgxpool.Config - replPool *SSHWrappedPostgresPool + ssh *SSHTunnel + pool *pgxpool.Pool + replConfig *pgx.ConnConfig customTypesMapping map[uint32]string metadataSchema string logger slog.Logger @@ -42,7 +42,7 @@ func NewPostgresConnector(ctx context.Context, pgConfig *protos.PostgresConfig) // create a separate connection pool for non-replication queries as replication connections cannot // be used for extended query protocol, i.e. prepared statements connConfig, err := pgxpool.ParseConfig(connectionString) - replConfig := connConfig.Copy() + replConfig := connConfig.ConnConfig.Copy() if err != nil { return nil, fmt.Errorf("failed to parse connection string: %w", err) } @@ -55,17 +55,21 @@ func NewPostgresConnector(ctx context.Context, pgConfig *protos.PostgresConfig) // set pool size to 3 to avoid connection pool exhaustion connConfig.MaxConns = 3 - pool, err := NewSSHWrappedPostgresPool(ctx, connConfig, pgConfig.SshConfig) + tunnel, err := NewSSHTunnel(ctx, pgConfig.SshConfig) + if err != nil { + return nil, fmt.Errorf("failed to create ssh tunnel: %w", err) + } + + pool, err := tunnel.NewPostgresPoolFromConfig(ctx, connConfig) if err != nil { return nil, fmt.Errorf("failed to create connection pool: %w", err) } // ensure that replication is set to database - replConfig.ConnConfig.RuntimeParams["replication"] = "database" - replConfig.ConnConfig.RuntimeParams["bytea_output"] = "hex" - replConfig.MaxConns = 1 + replConfig.RuntimeParams["replication"] = "database" + replConfig.RuntimeParams["bytea_output"] = "hex" - customTypeMap, err := utils.GetCustomDataTypes(ctx, pool.Pool) + customTypeMap, err := utils.GetCustomDataTypes(ctx, pool) if err != nil { return nil, fmt.Errorf("failed to get custom type map: %w", err) } @@ -81,9 +85,9 @@ func NewPostgresConnector(ctx context.Context, pgConfig *protos.PostgresConfig) connStr: connectionString, ctx: ctx, config: pgConfig, + ssh: tunnel, pool: pool, replConfig: replConfig, - replPool: nil, customTypesMapping: customTypeMap, metadataSchema: metadataSchema, logger: *flowLog, @@ -91,35 +95,26 @@ func NewPostgresConnector(ctx context.Context, pgConfig *protos.PostgresConfig) } // GetPool returns the connection pool. -func (c *PostgresConnector) GetPool() *SSHWrappedPostgresPool { +func (c *PostgresConnector) GetPool() *pgxpool.Pool { return c.pool } -func (c *PostgresConnector) GetReplPool(ctx context.Context) (*SSHWrappedPostgresPool, error) { - if c.replPool != nil { - return c.replPool, nil - } - - pool, err := NewSSHWrappedPostgresPool(ctx, c.replConfig, c.config.SshConfig) +func (c *PostgresConnector) CreateReplConn(ctx context.Context) (*pgx.Conn, error) { + conn, err := c.ssh.NewPostgresConnFromConfig(ctx, c.replConfig) if err != nil { slog.Error("failed to create replication connection pool", slog.Any("error", err)) return nil, fmt.Errorf("failed to create replication connection pool: %w", err) } - c.replPool = pool - return pool, nil + return conn, nil } // Close closes all connections. func (c *PostgresConnector) Close() error { - if c.pool != nil { + if c != nil { c.pool.Close() + c.ssh.Close() } - - if c.replPool != nil { - c.replPool.Close() - } - return nil } @@ -224,14 +219,15 @@ func (c *PostgresConnector) PullRecords(catalogPool *pgxpool.Pool, req *model.Pu c.logger.Info("PullRecords: performed checks for slot and publication") - replPool, err := c.GetReplPool(c.ctx) + replConn, err := c.CreateReplConn(c.ctx) if err != nil { return err } + defer replConn.Close(c.ctx) cdc, err := NewPostgresCDCSource(&PostgresCDCConfig{ AppContext: c.ctx, - Connection: replPool.Pool, + Connection: replConn, SrcTableIDNameMapping: req.SrcTableIDNameMapping, Slot: slotName, Publication: publicationName, diff --git a/flow/connectors/postgres/qrep.go b/flow/connectors/postgres/qrep.go index fc706c920..7efb54465 100644 --- a/flow/connectors/postgres/qrep.go +++ b/flow/connectors/postgres/qrep.go @@ -311,7 +311,7 @@ func (c *PostgresConnector) PullQRepRecords( if partition.FullTablePartition { c.logger.Info("pulling full table partition", partitionIdLog) executor, err := NewQRepQueryExecutorSnapshot( - c.pool.Pool, c.ctx, c.config.TransactionSnapshot, + c.pool, c.ctx, c.config.TransactionSnapshot, config.FlowJobName, partition.PartitionId) if err != nil { return nil, err @@ -355,7 +355,7 @@ func (c *PostgresConnector) PullQRepRecords( } executor, err := NewQRepQueryExecutorSnapshot( - c.pool.Pool, c.ctx, c.config.TransactionSnapshot, + c.pool, c.ctx, c.config.TransactionSnapshot, config.FlowJobName, partition.PartitionId) if err != nil { return nil, err @@ -379,7 +379,7 @@ func (c *PostgresConnector) PullQRepRecordStream( if partition.FullTablePartition { c.logger.Info("pulling full table partition", partitionIdLog) executor, err := NewQRepQueryExecutorSnapshot( - c.pool.Pool, c.ctx, c.config.TransactionSnapshot, + c.pool, c.ctx, c.config.TransactionSnapshot, config.FlowJobName, partition.PartitionId) if err != nil { return 0, err @@ -425,7 +425,7 @@ func (c *PostgresConnector) PullQRepRecordStream( } executor, err := NewQRepQueryExecutorSnapshot( - c.pool.Pool, c.ctx, c.config.TransactionSnapshot, + c.pool, c.ctx, c.config.TransactionSnapshot, config.FlowJobName, partition.PartitionId) if err != nil { return 0, err @@ -524,7 +524,7 @@ func (c *PostgresConnector) PullXminRecordStream( } executor, err := NewQRepQueryExecutorSnapshot( - c.pool.Pool, c.ctx, c.config.TransactionSnapshot, + c.pool, c.ctx, c.config.TransactionSnapshot, config.FlowJobName, partition.PartitionId) if err != nil { return 0, currentSnapshotXmin, err diff --git a/flow/connectors/postgres/qrep_partition_test.go b/flow/connectors/postgres/qrep_partition_test.go index f99b248ec..0d0e0039b 100644 --- a/flow/connectors/postgres/qrep_partition_test.go +++ b/flow/connectors/postgres/qrep_partition_test.go @@ -72,10 +72,17 @@ func TestGetQRepPartitions(t *testing.T) { t.Fatalf("Failed to parse config: %v", err) } - pool, err := NewSSHWrappedPostgresPool(context.Background(), config, nil) + tunnel, err := NewSSHTunnel(context.Background(), nil) + if err != nil { + t.Fatalf("Failed to create tunnel: %v", err) + } + defer tunnel.Close() + + pool, err := tunnel.NewPostgresPoolFromConfig(context.Background(), config) if err != nil { t.Fatalf("Failed to create pool: %v", err) } + defer pool.Close() // Generate a random schema name rndUint, err := shared.RandomUInt64() @@ -103,7 +110,7 @@ func TestGetQRepPartitions(t *testing.T) { } // from 2010 Jan 1 10:00 AM UTC to 2010 Jan 30 10:00 AM UTC - numRows := prepareTestData(t, pool.Pool, schemaName) + numRows := prepareTestData(t, pool, schemaName) // Define the test cases testCases := []*testCase{ diff --git a/flow/connectors/postgres/qrep_sql_sync.go b/flow/connectors/postgres/qrep_sql_sync.go index e75e38a97..84593d766 100644 --- a/flow/connectors/postgres/qrep_sql_sync.go +++ b/flow/connectors/postgres/qrep_sql_sync.go @@ -51,9 +51,9 @@ func (s *QRepStagingTableSync) SyncQRepRecords( return 0, fmt.Errorf("failed to get schema from stream: %w", err) } - txConfig := s.connector.pool.poolConfig.Copy() + txConfig := s.connector.pool.Config() txConfig.AfterConnect = utils.RegisterHStore - txPool, err := pgxpool.NewWithConfig(s.connector.pool.ctx, txConfig) + txPool, err := pgxpool.NewWithConfig(s.connector.ctx, txConfig) if err != nil { return 0, fmt.Errorf("failed to create tx pool: %v", err) } diff --git a/flow/connectors/postgres/ssh_wrapped_pool.go b/flow/connectors/postgres/ssh_wrapped_pool.go index 4dcd2cd0c..9d45dc4ed 100644 --- a/flow/connectors/postgres/ssh_wrapped_pool.go +++ b/flow/connectors/postgres/ssh_wrapped_pool.go @@ -8,6 +8,7 @@ import ( "sync" "time" + "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgxpool" "golang.org/x/crypto/ssh" @@ -15,40 +16,19 @@ import ( "github.com/PeerDB-io/peer-flow/generated/protos" ) -type SSHWrappedPostgresPool struct { - *pgxpool.Pool - - poolConfig *pgxpool.Config - sshConfig *ssh.ClientConfig - sshServer string - once sync.Once - sshClient *ssh.Client - ctx context.Context - cancel context.CancelFunc +type SSHTunnel struct { + sshConfig *ssh.ClientConfig + sshServer string + once sync.Once + sshClient *ssh.Client + ctx context.Context + cancel context.CancelFunc } -func NewSSHWrappedPostgresPoolFromConfig( +func NewSSHTunnel( ctx context.Context, - pgConfig *protos.PostgresConfig, -) (*SSHWrappedPostgresPool, error) { - connectionString := utils.GetPGConnectionString(pgConfig) - - connConfig, err := pgxpool.ParseConfig(connectionString) - if err != nil { - return nil, err - } - - // set pool size to 3 to avoid connection pool exhaustion - connConfig.MaxConns = 3 - - return NewSSHWrappedPostgresPool(ctx, connConfig, pgConfig.SshConfig) -} - -func NewSSHWrappedPostgresPool( - ctx context.Context, - poolConfig *pgxpool.Config, sshConfig *protos.SSHConfig, -) (*SSHWrappedPostgresPool, error) { +) (*SSHTunnel, error) { swCtx, cancel := context.WithCancel(ctx) var sshServer string @@ -65,12 +45,11 @@ func NewSSHWrappedPostgresPool( } } - pool := &SSHWrappedPostgresPool{ - poolConfig: poolConfig, - sshConfig: clientConfig, - sshServer: sshServer, - ctx: swCtx, - cancel: cancel, + pool := &SSHTunnel{ + sshConfig: clientConfig, + sshServer: sshServer, + ctx: swCtx, + cancel: cancel, } err := pool.connect() @@ -81,72 +60,146 @@ func NewSSHWrappedPostgresPool( return pool, nil } -func (swpp *SSHWrappedPostgresPool) connect() error { +func (tunnel *SSHTunnel) connect() error { var err error - swpp.once.Do(func() { - err = swpp.setupSSH() - if err != nil { - return - } - - swpp.Pool, err = pgxpool.NewWithConfig(swpp.ctx, swpp.poolConfig) - if err != nil { - slog.Error("Failed to create pool:", slog.Any("error", err)) - return - } - - host := swpp.poolConfig.ConnConfig.Host - err = retryWithBackoff(func() error { - err = swpp.Ping(swpp.ctx) - if err != nil { - slog.Error("Failed to ping pool", slog.Any("error", err), slog.String("host", host)) - return err - } - return nil - }, 5, 5*time.Second) - - if err != nil { - slog.Error("Failed to create pool", slog.Any("error", err), slog.String("host", host)) - } + tunnel.once.Do(func() { + err = tunnel.setupSSH() }) return err } -func (swpp *SSHWrappedPostgresPool) setupSSH() error { - if swpp.sshConfig == nil { +func (tunnel *SSHTunnel) setupSSH() error { + if tunnel.sshConfig == nil { return nil } - slog.Info("Setting up SSH connection to " + swpp.sshServer) + slog.Info("Setting up SSH connection to " + tunnel.sshServer) var err error - swpp.sshClient, err = ssh.Dial("tcp", swpp.sshServer, swpp.sshConfig) + tunnel.sshClient, err = ssh.Dial("tcp", tunnel.sshServer, tunnel.sshConfig) if err != nil { return err } - swpp.poolConfig.ConnConfig.DialFunc = func(ctx context.Context, network, addr string) (net.Conn, error) { - conn, err := swpp.sshClient.Dial(network, addr) + return nil +} + +func (tunnel *SSHTunnel) Close() { + tunnel.cancel() + + if tunnel.sshClient != nil { + tunnel.sshClient.Close() + } +} + +func (tunnel *SSHTunnel) NewPostgresPoolFromPostgresConfig( + ctx context.Context, + pgConfig *protos.PostgresConfig, +) (*pgxpool.Pool, error) { + connectionString := utils.GetPGConnectionString(pgConfig) + + poolConfig, err := pgxpool.ParseConfig(connectionString) + if err != nil { + return nil, err + } + + return tunnel.NewPostgresPoolFromConfig(ctx, poolConfig) +} + +func (tunnel *SSHTunnel) NewPostgresPoolFromConfig( + ctx context.Context, + poolConfig *pgxpool.Config, +) (*pgxpool.Pool, error) { + // set pool size to 3 to avoid connection pool exhaustion + poolConfig.MaxConns = 3 + + if tunnel.sshClient != nil { + poolConfig.ConnConfig.DialFunc = func(ctx context.Context, network, addr string) (net.Conn, error) { + conn, err := tunnel.sshClient.Dial(network, addr) + if err != nil { + return nil, err + } + return &noDeadlineConn{Conn: conn}, nil + } + } + + pool, err := pgxpool.NewWithConfig(tunnel.ctx, poolConfig) + if err != nil { + slog.Error("Failed to create pool:", slog.Any("error", err)) + return nil, err + } + + host := poolConfig.ConnConfig.Host + err = retryWithBackoff(func() error { + err = pool.Ping(tunnel.ctx) if err != nil { - return nil, err + slog.Error("Failed to ping pool", slog.Any("error", err), slog.String("host", host)) + return err } - return &noDeadlineConn{Conn: conn}, nil + return nil + }, 5, 5*time.Second) + + if err != nil { + slog.Error("Failed to create pool", slog.Any("error", err), slog.String("host", host)) + pool.Close() + return nil, err } - return nil + return pool, nil +} + +func (tunnel *SSHTunnel) NewPostgresConnFromPostgresConfig( + ctx context.Context, + pgConfig *protos.PostgresConfig, +) (*pgx.Conn, error) { + connectionString := utils.GetPGConnectionString(pgConfig) + + connConfig, err := pgx.ParseConfig(connectionString) + if err != nil { + return nil, err + } + + return tunnel.NewPostgresConnFromConfig(ctx, connConfig) } -func (swpp *SSHWrappedPostgresPool) Close() { - swpp.cancel() +func (tunnel *SSHTunnel) NewPostgresConnFromConfig( + ctx context.Context, + connConfig *pgx.ConnConfig, +) (*pgx.Conn, error) { + if tunnel.sshClient != nil { + connConfig.DialFunc = func(ctx context.Context, network, addr string) (net.Conn, error) { + conn, err := tunnel.sshClient.Dial(network, addr) + if err != nil { + return nil, err + } + return &noDeadlineConn{Conn: conn}, nil + } + } - if swpp.Pool != nil { - swpp.Pool.Close() + conn, err := pgx.ConnectConfig(tunnel.ctx, connConfig) + if err != nil { + slog.Error("Failed to create pool:", slog.Any("error", err)) + return nil, err } - if swpp.sshClient != nil { - swpp.sshClient.Close() + host := connConfig.Host + err = retryWithBackoff(func() error { + err = conn.Ping(tunnel.ctx) + if err != nil { + slog.Error("Failed to ping pool", slog.Any("error", err), slog.String("host", host)) + return err + } + return nil + }, 5, 5*time.Second) + + if err != nil { + slog.Error("Failed to create pool", slog.Any("error", err), slog.String("host", host)) + conn.Close(ctx) + return nil, err } + + return conn, nil } type retryFunc func() error