Skip to content

Commit

Permalink
Don't wrap pg connections with ssh tunnel (#1165)
Browse files Browse the repository at this point in the history
Instead, they should coexist
This way multiple pools/connections can use one tunnel

Also replace replPool, with its `MaxConns = 1`, with a connection
  • Loading branch information
serprex authored Jan 29, 2024
1 parent 719e5ff commit 275a331
Show file tree
Hide file tree
Showing 8 changed files with 210 additions and 164 deletions.
60 changes: 32 additions & 28 deletions flow/cmd/peer_data.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,32 +31,39 @@ func (h *FlowRequestHandler) getPGPeerConfig(ctx context.Context, peerName strin
return &pgPeerConfig, nil
}

func (h *FlowRequestHandler) getPoolForPGPeer(ctx context.Context, peerName string) (*connpostgres.SSHWrappedPostgresPool, error) {
func (h *FlowRequestHandler) getConnForPGPeer(ctx context.Context, peerName string) (*connpostgres.SSHTunnel, *pgx.Conn, error) {
pgPeerConfig, err := h.getPGPeerConfig(ctx, peerName)
if err != nil {
return nil, err
return nil, nil, err
}

pool, err := connpostgres.NewSSHWrappedPostgresPoolFromConfig(ctx, pgPeerConfig)
tunnel, err := connpostgres.NewSSHTunnel(ctx, pgPeerConfig.SshConfig)
if err != nil {
slog.Error("Failed to create postgres pool", slog.Any("error", err))
return nil, err
return nil, nil, err
}

return pool, nil
conn, err := tunnel.NewPostgresConnFromPostgresConfig(ctx, pgPeerConfig)
if err != nil {
tunnel.Close()
return nil, nil, err
}

return tunnel, conn, nil
}

func (h *FlowRequestHandler) GetSchemas(
ctx context.Context,
req *protos.PostgresPeerActivityInfoRequest,
) (*protos.PeerSchemasResponse, error) {
peerPool, err := h.getPoolForPGPeer(ctx, req.PeerName)
tunnel, peerConn, err := h.getConnForPGPeer(ctx, req.PeerName)
if err != nil {
return &protos.PeerSchemasResponse{Schemas: nil}, err
}
defer tunnel.Close()
defer peerConn.Close(ctx)

defer peerPool.Close()
rows, err := peerPool.Query(ctx, "SELECT schema_name"+
rows, err := peerConn.Query(ctx, "SELECT schema_name"+
" FROM information_schema.schemata WHERE schema_name !~ '^pg_' AND schema_name <> 'information_schema';")
if err != nil {
return &protos.PeerSchemasResponse{Schemas: nil}, err
Expand All @@ -73,13 +80,14 @@ func (h *FlowRequestHandler) GetTablesInSchema(
ctx context.Context,
req *protos.SchemaTablesRequest,
) (*protos.SchemaTablesResponse, error) {
peerPool, err := h.getPoolForPGPeer(ctx, req.PeerName)
tunnel, peerConn, err := h.getConnForPGPeer(ctx, req.PeerName)
if err != nil {
return &protos.SchemaTablesResponse{Tables: nil}, err
}
defer tunnel.Close()
defer peerConn.Close(ctx)

defer peerPool.Close()
rows, err := peerPool.Query(ctx, `SELECT DISTINCT ON (t.relname)
rows, err := peerConn.Query(ctx, `SELECT DISTINCT ON (t.relname)
t.relname,
CASE
WHEN con.contype = 'p' OR t.relreplident = 'i' OR t.relreplident = 'f' THEN true
Expand Down Expand Up @@ -130,13 +138,14 @@ func (h *FlowRequestHandler) GetAllTables(
ctx context.Context,
req *protos.PostgresPeerActivityInfoRequest,
) (*protos.AllTablesResponse, error) {
peerPool, err := h.getPoolForPGPeer(ctx, req.PeerName)
tunnel, peerConn, err := h.getConnForPGPeer(ctx, req.PeerName)
if err != nil {
return &protos.AllTablesResponse{Tables: nil}, err
}
defer tunnel.Close()
defer peerConn.Close(ctx)

defer peerPool.Close()
rows, err := peerPool.Query(ctx, "SELECT table_schema || '.' || table_name AS schema_table "+
rows, err := peerConn.Query(ctx, "SELECT table_schema || '.' || table_name AS schema_table "+
"FROM information_schema.tables WHERE table_schema !~ '^pg_' AND table_schema <> 'information_schema'")
if err != nil {
return &protos.AllTablesResponse{Tables: nil}, err
Expand All @@ -160,13 +169,14 @@ func (h *FlowRequestHandler) GetColumns(
ctx context.Context,
req *protos.TableColumnsRequest,
) (*protos.TableColumnsResponse, error) {
peerPool, err := h.getPoolForPGPeer(ctx, req.PeerName)
tunnel, peerConn, err := h.getConnForPGPeer(ctx, req.PeerName)
if err != nil {
return &protos.TableColumnsResponse{Columns: nil}, err
}
defer tunnel.Close()
defer peerConn.Close(ctx)

defer peerPool.Close()
rows, err := peerPool.Query(ctx, `
rows, err := peerConn.Query(ctx, `
SELECT
cols.column_name,
cols.data_type,
Expand Down Expand Up @@ -240,22 +250,16 @@ func (h *FlowRequestHandler) GetStatInfo(
ctx context.Context,
req *protos.PostgresPeerActivityInfoRequest,
) (*protos.PeerStatResponse, error) {
pgConfig, err := h.getPGPeerConfig(ctx, req.PeerName)
tunnel, peerConn, err := h.getConnForPGPeer(ctx, req.PeerName)
if err != nil {
return &protos.PeerStatResponse{StatData: nil}, err
}
defer tunnel.Close()
defer peerConn.Close(ctx)

pgConnector, err := connpostgres.NewPostgresConnector(ctx, pgConfig)
if err != nil {
slog.Error("Failed to create postgres connector", slog.Any("error", err))
return &protos.PeerStatResponse{StatData: nil}, err
}
defer pgConnector.Close()

peerPool := pgConnector.GetPool()
peerUser := pgConfig.User
peerUser := peerConn.Config().User

rows, err := peerPool.Query(ctx, "SELECT pid, wait_event, wait_event_type, query_start::text, query,"+
rows, err := peerConn.Query(ctx, "SELECT pid, wait_event, wait_event_type, query_start::text, query,"+
"EXTRACT(epoch FROM(now()-query_start)) AS dur"+
" FROM pg_stat_activity WHERE "+
"usename=$1 AND state != 'idle';", peerUser)
Expand Down
20 changes: 6 additions & 14 deletions flow/connectors/postgres/cdc.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ const maxRetriesForWalSegmentRemoved = 5

type PostgresCDCSource struct {
ctx context.Context
replPool *pgxpool.Pool
replConn *pgx.Conn
SrcTableIDNameMapping map[uint32]string
TableNameMapping map[string]model.NameAndExclude
slot string
Expand All @@ -54,7 +54,7 @@ type PostgresCDCSource struct {

type PostgresCDCConfig struct {
AppContext context.Context
Connection *pgxpool.Pool
Connection *pgx.Conn
Slot string
Publication string
SrcTableIDNameMapping map[uint32]string
Expand Down Expand Up @@ -84,7 +84,7 @@ func NewPostgresCDCSource(cdcConfig *PostgresCDCConfig, customTypeMap map[uint32
flowName, _ := cdcConfig.AppContext.Value(shared.FlowNameKey).(string)
return &PostgresCDCSource{
ctx: cdcConfig.AppContext,
replPool: cdcConfig.Connection,
replConn: cdcConfig.Connection,
SrcTableIDNameMapping: cdcConfig.SrcTableIDNameMapping,
TableNameMapping: cdcConfig.TableNameMapping,
slot: cdcConfig.Slot,
Expand All @@ -102,7 +102,7 @@ func NewPostgresCDCSource(cdcConfig *PostgresCDCConfig, customTypeMap map[uint32
}, nil
}

func getChildToParentRelIDMap(ctx context.Context, pool *pgxpool.Pool) (map[uint32]uint32, error) {
func getChildToParentRelIDMap(ctx context.Context, conn *pgx.Conn) (map[uint32]uint32, error) {
query := `
SELECT
parent.oid AS parentrelid,
Expand All @@ -113,7 +113,7 @@ func getChildToParentRelIDMap(ctx context.Context, pool *pgxpool.Pool) (map[uint
WHERE parent.relkind='p';
`

rows, err := pool.Query(ctx, query, pgx.QueryExecModeSimpleProtocol)
rows, err := conn.Query(ctx, query, pgx.QueryExecModeSimpleProtocol)
if err != nil {
return nil, fmt.Errorf("error querying for child to parent relid map: %w", err)
}
Expand Down Expand Up @@ -141,15 +141,7 @@ func (p *PostgresCDCSource) PullRecords(req *model.PullRecordsRequest) error {
return fmt.Errorf("error getting replication options: %w", err)
}

// create replication connection
replicationConn, err := p.replPool.Acquire(p.ctx)
if err != nil {
return fmt.Errorf("error acquiring connection for replication: %w", err)
}

defer replicationConn.Release()

pgConn := replicationConn.Conn().PgConn()
pgConn := p.replConn.PgConn()
p.logger.Info("created replication connection")

// start replication
Expand Down
12 changes: 3 additions & 9 deletions flow/connectors/postgres/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -350,17 +350,11 @@ func (c *PostgresConnector) createSlotAndPublication(

// create slot only after we succeeded in creating publication.
if !s.SlotExists {
pool, err := c.GetReplPool(c.ctx)
conn, err := c.CreateReplConn(c.ctx)
if err != nil {
return fmt.Errorf("[slot] error acquiring pool: %w", err)
}

conn, err := pool.Acquire(c.ctx)
if err != nil {
return fmt.Errorf("[slot] error acquiring connection: %w", err)
}

defer conn.Release()
defer conn.Close(c.ctx)

c.logger.Warn(fmt.Sprintf("Creating replication slot '%s'", slot))

Expand All @@ -373,7 +367,7 @@ func (c *PostgresConnector) createSlotAndPublication(
Temporary: false,
Mode: pglogrepl.LogicalReplication,
}
res, err := pglogrepl.CreateReplicationSlot(c.ctx, conn.Conn().PgConn(), slot, "pgoutput", opts)
res, err := pglogrepl.CreateReplicationSlot(c.ctx, conn.PgConn(), slot, "pgoutput", opts)
if err != nil {
return fmt.Errorf("[slot] error creating replication slot: %w", err)
}
Expand Down
50 changes: 23 additions & 27 deletions flow/connectors/postgres/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@ type PostgresConnector struct {
connStr string
ctx context.Context
config *protos.PostgresConfig
pool *SSHWrappedPostgresPool
replConfig *pgxpool.Config
replPool *SSHWrappedPostgresPool
ssh *SSHTunnel
pool *pgxpool.Pool
replConfig *pgx.ConnConfig
customTypesMapping map[uint32]string
metadataSchema string
logger slog.Logger
Expand All @@ -42,7 +42,7 @@ func NewPostgresConnector(ctx context.Context, pgConfig *protos.PostgresConfig)
// create a separate connection pool for non-replication queries as replication connections cannot
// be used for extended query protocol, i.e. prepared statements
connConfig, err := pgxpool.ParseConfig(connectionString)
replConfig := connConfig.Copy()
replConfig := connConfig.ConnConfig.Copy()
if err != nil {
return nil, fmt.Errorf("failed to parse connection string: %w", err)
}
Expand All @@ -55,17 +55,21 @@ func NewPostgresConnector(ctx context.Context, pgConfig *protos.PostgresConfig)
// set pool size to 3 to avoid connection pool exhaustion
connConfig.MaxConns = 3

pool, err := NewSSHWrappedPostgresPool(ctx, connConfig, pgConfig.SshConfig)
tunnel, err := NewSSHTunnel(ctx, pgConfig.SshConfig)
if err != nil {
return nil, fmt.Errorf("failed to create ssh tunnel: %w", err)
}

pool, err := tunnel.NewPostgresPoolFromConfig(ctx, connConfig)
if err != nil {
return nil, fmt.Errorf("failed to create connection pool: %w", err)
}

// ensure that replication is set to database
replConfig.ConnConfig.RuntimeParams["replication"] = "database"
replConfig.ConnConfig.RuntimeParams["bytea_output"] = "hex"
replConfig.MaxConns = 1
replConfig.RuntimeParams["replication"] = "database"
replConfig.RuntimeParams["bytea_output"] = "hex"

customTypeMap, err := utils.GetCustomDataTypes(ctx, pool.Pool)
customTypeMap, err := utils.GetCustomDataTypes(ctx, pool)
if err != nil {
return nil, fmt.Errorf("failed to get custom type map: %w", err)
}
Expand All @@ -81,45 +85,36 @@ func NewPostgresConnector(ctx context.Context, pgConfig *protos.PostgresConfig)
connStr: connectionString,
ctx: ctx,
config: pgConfig,
ssh: tunnel,
pool: pool,
replConfig: replConfig,
replPool: nil,
customTypesMapping: customTypeMap,
metadataSchema: metadataSchema,
logger: *flowLog,
}, nil
}

// GetPool returns the connection pool.
func (c *PostgresConnector) GetPool() *SSHWrappedPostgresPool {
func (c *PostgresConnector) GetPool() *pgxpool.Pool {
return c.pool
}

func (c *PostgresConnector) GetReplPool(ctx context.Context) (*SSHWrappedPostgresPool, error) {
if c.replPool != nil {
return c.replPool, nil
}

pool, err := NewSSHWrappedPostgresPool(ctx, c.replConfig, c.config.SshConfig)
func (c *PostgresConnector) CreateReplConn(ctx context.Context) (*pgx.Conn, error) {
conn, err := c.ssh.NewPostgresConnFromConfig(ctx, c.replConfig)
if err != nil {
slog.Error("failed to create replication connection pool", slog.Any("error", err))
return nil, fmt.Errorf("failed to create replication connection pool: %w", err)
}

c.replPool = pool
return pool, nil
return conn, nil
}

// Close closes all connections.
func (c *PostgresConnector) Close() error {
if c.pool != nil {
if c != nil {
c.pool.Close()
c.ssh.Close()
}

if c.replPool != nil {
c.replPool.Close()
}

return nil
}

Expand Down Expand Up @@ -224,14 +219,15 @@ func (c *PostgresConnector) PullRecords(catalogPool *pgxpool.Pool, req *model.Pu

c.logger.Info("PullRecords: performed checks for slot and publication")

replPool, err := c.GetReplPool(c.ctx)
replConn, err := c.CreateReplConn(c.ctx)
if err != nil {
return err
}
defer replConn.Close(c.ctx)

cdc, err := NewPostgresCDCSource(&PostgresCDCConfig{
AppContext: c.ctx,
Connection: replPool.Pool,
Connection: replConn,
SrcTableIDNameMapping: req.SrcTableIDNameMapping,
Slot: slotName,
Publication: publicationName,
Expand Down
10 changes: 5 additions & 5 deletions flow/connectors/postgres/qrep.go
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ func (c *PostgresConnector) PullQRepRecords(
if partition.FullTablePartition {
c.logger.Info("pulling full table partition", partitionIdLog)
executor, err := NewQRepQueryExecutorSnapshot(
c.pool.Pool, c.ctx, c.config.TransactionSnapshot,
c.pool, c.ctx, c.config.TransactionSnapshot,
config.FlowJobName, partition.PartitionId)
if err != nil {
return nil, err
Expand Down Expand Up @@ -355,7 +355,7 @@ func (c *PostgresConnector) PullQRepRecords(
}

executor, err := NewQRepQueryExecutorSnapshot(
c.pool.Pool, c.ctx, c.config.TransactionSnapshot,
c.pool, c.ctx, c.config.TransactionSnapshot,
config.FlowJobName, partition.PartitionId)
if err != nil {
return nil, err
Expand All @@ -379,7 +379,7 @@ func (c *PostgresConnector) PullQRepRecordStream(
if partition.FullTablePartition {
c.logger.Info("pulling full table partition", partitionIdLog)
executor, err := NewQRepQueryExecutorSnapshot(
c.pool.Pool, c.ctx, c.config.TransactionSnapshot,
c.pool, c.ctx, c.config.TransactionSnapshot,
config.FlowJobName, partition.PartitionId)
if err != nil {
return 0, err
Expand Down Expand Up @@ -425,7 +425,7 @@ func (c *PostgresConnector) PullQRepRecordStream(
}

executor, err := NewQRepQueryExecutorSnapshot(
c.pool.Pool, c.ctx, c.config.TransactionSnapshot,
c.pool, c.ctx, c.config.TransactionSnapshot,
config.FlowJobName, partition.PartitionId)
if err != nil {
return 0, err
Expand Down Expand Up @@ -524,7 +524,7 @@ func (c *PostgresConnector) PullXminRecordStream(
}

executor, err := NewQRepQueryExecutorSnapshot(
c.pool.Pool, c.ctx, c.config.TransactionSnapshot,
c.pool, c.ctx, c.config.TransactionSnapshot,
config.FlowJobName, partition.PartitionId)
if err != nil {
return 0, currentSnapshotXmin, err
Expand Down
Loading

0 comments on commit 275a331

Please sign in to comment.