Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Don't wrap pg connections with ssh tunnel #1165

Merged
merged 1 commit into from
Jan 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 32 additions & 28 deletions flow/cmd/peer_data.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,32 +31,39 @@ func (h *FlowRequestHandler) getPGPeerConfig(ctx context.Context, peerName strin
return &pgPeerConfig, nil
}

// NOTE(review): this span is a flattened diff hunk. The two signatures below
// appear to be the pre-change (getPoolForPGPeer) and post-change
// (getConnForPGPeer) versions of the same helper, and several statements
// further down come in old/new pairs as well — confirm against the PR.
//
// Read as the post-change version, the helper loads the peer's Postgres
// config, creates an SSH tunnel from pgPeerConfig.SshConfig, and opens a
// connection through that tunnel. It returns the tunnel and the connection;
// the callers visible in this file defer Close on both. Any failure is
// returned with nil for both handles.
func (h *FlowRequestHandler) getPoolForPGPeer(ctx context.Context, peerName string) (*connpostgres.SSHWrappedPostgresPool, error) {
func (h *FlowRequestHandler) getConnForPGPeer(ctx context.Context, peerName string) (*connpostgres.SSHTunnel, *pgx.Conn, error) {
pgPeerConfig, err := h.getPGPeerConfig(ctx, peerName)
if err != nil {
return nil, err
return nil, nil, err
}

pool, err := connpostgres.NewSSHWrappedPostgresPoolFromConfig(ctx, pgPeerConfig)
tunnel, err := connpostgres.NewSSHTunnel(ctx, pgPeerConfig.SshConfig)
if err != nil {
// NOTE(review): the message still says "postgres pool", but in the
// post-change code the call that failed here is NewSSHTunnel.
slog.Error("Failed to create postgres pool", slog.Any("error", err))
return nil, err
return nil, nil, err
}

return pool, nil
conn, err := tunnel.NewPostgresConnFromPostgresConfig(ctx, pgPeerConfig)
if err != nil {
// The connection could not be opened, so the tunnel created above is
// closed before the error is returned.
tunnel.Close()
return nil, nil, err
}

return tunnel, conn, nil
}

func (h *FlowRequestHandler) GetSchemas(
ctx context.Context,
req *protos.PostgresPeerActivityInfoRequest,
) (*protos.PeerSchemasResponse, error) {
peerPool, err := h.getPoolForPGPeer(ctx, req.PeerName)
tunnel, peerConn, err := h.getConnForPGPeer(ctx, req.PeerName)
if err != nil {
return &protos.PeerSchemasResponse{Schemas: nil}, err
}
defer tunnel.Close()
defer peerConn.Close(ctx)

defer peerPool.Close()
rows, err := peerPool.Query(ctx, "SELECT schema_name"+
rows, err := peerConn.Query(ctx, "SELECT schema_name"+
" FROM information_schema.schemata WHERE schema_name !~ '^pg_' AND schema_name <> 'information_schema';")
if err != nil {
return &protos.PeerSchemasResponse{Schemas: nil}, err
Expand All @@ -73,13 +80,14 @@ func (h *FlowRequestHandler) GetTablesInSchema(
ctx context.Context,
req *protos.SchemaTablesRequest,
) (*protos.SchemaTablesResponse, error) {
peerPool, err := h.getPoolForPGPeer(ctx, req.PeerName)
tunnel, peerConn, err := h.getConnForPGPeer(ctx, req.PeerName)
if err != nil {
return &protos.SchemaTablesResponse{Tables: nil}, err
}
defer tunnel.Close()
defer peerConn.Close(ctx)

defer peerPool.Close()
rows, err := peerPool.Query(ctx, `SELECT DISTINCT ON (t.relname)
rows, err := peerConn.Query(ctx, `SELECT DISTINCT ON (t.relname)
t.relname,
CASE
WHEN con.contype = 'p' OR t.relreplident = 'i' OR t.relreplident = 'f' THEN true
Expand Down Expand Up @@ -130,13 +138,14 @@ func (h *FlowRequestHandler) GetAllTables(
ctx context.Context,
req *protos.PostgresPeerActivityInfoRequest,
) (*protos.AllTablesResponse, error) {
peerPool, err := h.getPoolForPGPeer(ctx, req.PeerName)
tunnel, peerConn, err := h.getConnForPGPeer(ctx, req.PeerName)
if err != nil {
return &protos.AllTablesResponse{Tables: nil}, err
}
defer tunnel.Close()
defer peerConn.Close(ctx)

defer peerPool.Close()
rows, err := peerPool.Query(ctx, "SELECT table_schema || '.' || table_name AS schema_table "+
rows, err := peerConn.Query(ctx, "SELECT table_schema || '.' || table_name AS schema_table "+
"FROM information_schema.tables WHERE table_schema !~ '^pg_' AND table_schema <> 'information_schema'")
if err != nil {
return &protos.AllTablesResponse{Tables: nil}, err
Expand All @@ -160,13 +169,14 @@ func (h *FlowRequestHandler) GetColumns(
ctx context.Context,
req *protos.TableColumnsRequest,
) (*protos.TableColumnsResponse, error) {
peerPool, err := h.getPoolForPGPeer(ctx, req.PeerName)
tunnel, peerConn, err := h.getConnForPGPeer(ctx, req.PeerName)
if err != nil {
return &protos.TableColumnsResponse{Columns: nil}, err
}
defer tunnel.Close()
defer peerConn.Close(ctx)

defer peerPool.Close()
rows, err := peerPool.Query(ctx, `
rows, err := peerConn.Query(ctx, `
SELECT
cols.column_name,
cols.data_type,
Expand Down Expand Up @@ -240,22 +250,16 @@ func (h *FlowRequestHandler) GetStatInfo(
ctx context.Context,
req *protos.PostgresPeerActivityInfoRequest,
) (*protos.PeerStatResponse, error) {
pgConfig, err := h.getPGPeerConfig(ctx, req.PeerName)
tunnel, peerConn, err := h.getConnForPGPeer(ctx, req.PeerName)
if err != nil {
return &protos.PeerStatResponse{StatData: nil}, err
}
defer tunnel.Close()
defer peerConn.Close(ctx)

pgConnector, err := connpostgres.NewPostgresConnector(ctx, pgConfig)
if err != nil {
slog.Error("Failed to create postgres connector", slog.Any("error", err))
return &protos.PeerStatResponse{StatData: nil}, err
}
defer pgConnector.Close()

peerPool := pgConnector.GetPool()
peerUser := pgConfig.User
peerUser := peerConn.Config().User

rows, err := peerPool.Query(ctx, "SELECT pid, wait_event, wait_event_type, query_start::text, query,"+
rows, err := peerConn.Query(ctx, "SELECT pid, wait_event, wait_event_type, query_start::text, query,"+
"EXTRACT(epoch FROM(now()-query_start)) AS dur"+
" FROM pg_stat_activity WHERE "+
"usename=$1 AND state != 'idle';", peerUser)
Expand Down
20 changes: 6 additions & 14 deletions flow/connectors/postgres/cdc.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ const maxRetriesForWalSegmentRemoved = 5

type PostgresCDCSource struct {
ctx context.Context
replPool *pgxpool.Pool
replConn *pgx.Conn
SrcTableIDNameMapping map[uint32]string
TableNameMapping map[string]model.NameAndExclude
slot string
Expand All @@ -54,7 +54,7 @@ type PostgresCDCSource struct {

type PostgresCDCConfig struct {
AppContext context.Context
Connection *pgxpool.Pool
Connection *pgx.Conn
Slot string
Publication string
SrcTableIDNameMapping map[uint32]string
Expand Down Expand Up @@ -84,7 +84,7 @@ func NewPostgresCDCSource(cdcConfig *PostgresCDCConfig, customTypeMap map[uint32
flowName, _ := cdcConfig.AppContext.Value(shared.FlowNameKey).(string)
return &PostgresCDCSource{
ctx: cdcConfig.AppContext,
replPool: cdcConfig.Connection,
replConn: cdcConfig.Connection,
SrcTableIDNameMapping: cdcConfig.SrcTableIDNameMapping,
TableNameMapping: cdcConfig.TableNameMapping,
slot: cdcConfig.Slot,
Expand All @@ -102,7 +102,7 @@ func NewPostgresCDCSource(cdcConfig *PostgresCDCConfig, customTypeMap map[uint32
}, nil
}

func getChildToParentRelIDMap(ctx context.Context, pool *pgxpool.Pool) (map[uint32]uint32, error) {
func getChildToParentRelIDMap(ctx context.Context, conn *pgx.Conn) (map[uint32]uint32, error) {
query := `
SELECT
parent.oid AS parentrelid,
Expand All @@ -113,7 +113,7 @@ func getChildToParentRelIDMap(ctx context.Context, pool *pgxpool.Pool) (map[uint
WHERE parent.relkind='p';
`

rows, err := pool.Query(ctx, query, pgx.QueryExecModeSimpleProtocol)
rows, err := conn.Query(ctx, query, pgx.QueryExecModeSimpleProtocol)
if err != nil {
return nil, fmt.Errorf("error querying for child to parent relid map: %w", err)
}
Expand Down Expand Up @@ -141,15 +141,7 @@ func (p *PostgresCDCSource) PullRecords(req *model.PullRecordsRequest) error {
return fmt.Errorf("error getting replication options: %w", err)
}

// create replication connection
replicationConn, err := p.replPool.Acquire(p.ctx)
if err != nil {
return fmt.Errorf("error acquiring connection for replication: %w", err)
}

defer replicationConn.Release()

pgConn := replicationConn.Conn().PgConn()
pgConn := p.replConn.PgConn()
p.logger.Info("created replication connection")

// start replication
Expand Down
12 changes: 3 additions & 9 deletions flow/connectors/postgres/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -350,17 +350,11 @@ func (c *PostgresConnector) createSlotAndPublication(

// create slot only after we succeeded in creating publication.
if !s.SlotExists {
pool, err := c.GetReplPool(c.ctx)
conn, err := c.CreateReplConn(c.ctx)
if err != nil {
return fmt.Errorf("[slot] error acquiring pool: %w", err)
}

conn, err := pool.Acquire(c.ctx)
if err != nil {
return fmt.Errorf("[slot] error acquiring connection: %w", err)
}

defer conn.Release()
defer conn.Close(c.ctx)

c.logger.Warn(fmt.Sprintf("Creating replication slot '%s'", slot))

Expand All @@ -373,7 +367,7 @@ func (c *PostgresConnector) createSlotAndPublication(
Temporary: false,
Mode: pglogrepl.LogicalReplication,
}
res, err := pglogrepl.CreateReplicationSlot(c.ctx, conn.Conn().PgConn(), slot, "pgoutput", opts)
res, err := pglogrepl.CreateReplicationSlot(c.ctx, conn.PgConn(), slot, "pgoutput", opts)
if err != nil {
return fmt.Errorf("[slot] error creating replication slot: %w", err)
}
Expand Down
50 changes: 23 additions & 27 deletions flow/connectors/postgres/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@ type PostgresConnector struct {
connStr string
ctx context.Context
config *protos.PostgresConfig
pool *SSHWrappedPostgresPool
replConfig *pgxpool.Config
replPool *SSHWrappedPostgresPool
ssh *SSHTunnel
pool *pgxpool.Pool
replConfig *pgx.ConnConfig
customTypesMapping map[uint32]string
metadataSchema string
logger slog.Logger
Expand All @@ -42,7 +42,7 @@ func NewPostgresConnector(ctx context.Context, pgConfig *protos.PostgresConfig)
// create a separate connection pool for non-replication queries as replication connections cannot
// be used for extended query protocol, i.e. prepared statements
connConfig, err := pgxpool.ParseConfig(connectionString)
replConfig := connConfig.Copy()
replConfig := connConfig.ConnConfig.Copy()
if err != nil {
return nil, fmt.Errorf("failed to parse connection string: %w", err)
}
Expand All @@ -55,17 +55,21 @@ func NewPostgresConnector(ctx context.Context, pgConfig *protos.PostgresConfig)
// set pool size to 3 to avoid connection pool exhaustion
connConfig.MaxConns = 3

pool, err := NewSSHWrappedPostgresPool(ctx, connConfig, pgConfig.SshConfig)
tunnel, err := NewSSHTunnel(ctx, pgConfig.SshConfig)
if err != nil {
return nil, fmt.Errorf("failed to create ssh tunnel: %w", err)
}

pool, err := tunnel.NewPostgresPoolFromConfig(ctx, connConfig)
if err != nil {
return nil, fmt.Errorf("failed to create connection pool: %w", err)
}

// ensure that replication is set to database
replConfig.ConnConfig.RuntimeParams["replication"] = "database"
replConfig.ConnConfig.RuntimeParams["bytea_output"] = "hex"
replConfig.MaxConns = 1
replConfig.RuntimeParams["replication"] = "database"
replConfig.RuntimeParams["bytea_output"] = "hex"

customTypeMap, err := utils.GetCustomDataTypes(ctx, pool.Pool)
customTypeMap, err := utils.GetCustomDataTypes(ctx, pool)
if err != nil {
return nil, fmt.Errorf("failed to get custom type map: %w", err)
}
Expand All @@ -81,45 +85,36 @@ func NewPostgresConnector(ctx context.Context, pgConfig *protos.PostgresConfig)
connStr: connectionString,
ctx: ctx,
config: pgConfig,
ssh: tunnel,
pool: pool,
replConfig: replConfig,
replPool: nil,
customTypesMapping: customTypeMap,
metadataSchema: metadataSchema,
logger: *flowLog,
}, nil
}

// GetPool returns the connection pool.
// NOTE(review): flattened diff — the two signatures below look like the
// pre-change (*SSHWrappedPostgresPool) and post-change (*pgxpool.Pool) return
// types of the same accessor; confirm against the PR. The body simply returns
// the connector's pool field.
func (c *PostgresConnector) GetPool() *SSHWrappedPostgresPool {
func (c *PostgresConnector) GetPool() *pgxpool.Pool {
return c.pool
}

// NOTE(review): flattened diff — GetReplPool (which cached a pool in
// c.replPool) appears to be the pre-change method and CreateReplConn its
// post-change replacement; confirm against the PR.
//
// Read as the post-change version, CreateReplConn opens a connection through
// the connector's SSH tunnel (c.ssh) using c.replConfig and returns it.
// Nothing is stored on the connector here; the call sites visible in this file
// defer Close on the returned connection.
func (c *PostgresConnector) GetReplPool(ctx context.Context) (*SSHWrappedPostgresPool, error) {
if c.replPool != nil {
return c.replPool, nil
}

pool, err := NewSSHWrappedPostgresPool(ctx, c.replConfig, c.config.SshConfig)
func (c *PostgresConnector) CreateReplConn(ctx context.Context) (*pgx.Conn, error) {
conn, err := c.ssh.NewPostgresConnFromConfig(ctx, c.replConfig)
if err != nil {
// NOTE(review): both messages still say "connection pool" although the
// post-change code creates a single connection, and the same error is
// logged here as well as wrapped and returned to the caller.
slog.Error("failed to create replication connection pool", slog.Any("error", err))
return nil, fmt.Errorf("failed to create replication connection pool: %w", err)
}

c.replPool = pool
return pool, nil
return conn, nil
}

// Close closes all connections.
// NOTE(review): flattened diff — `if c.pool != nil` and the c.replPool block
// appear to be pre-change lines, while `if c != nil` and c.ssh.Close() appear
// to be post-change; confirm against the PR.
//
// Read as the post-change version, Close closes the pool and then the SSH
// tunnel, and always returns nil.
func (c *PostgresConnector) Close() error {
if c.pool != nil {
// NOTE(review): this guard checks only the receiver; c.pool and c.ssh are
// used without their own nil checks — verify that both Close methods
// tolerate nil, or that the fields are always set.
if c != nil {
c.pool.Close()
c.ssh.Close()
}

if c.replPool != nil {
c.replPool.Close()
}

return nil
}

Expand Down Expand Up @@ -224,14 +219,15 @@ func (c *PostgresConnector) PullRecords(catalogPool *pgxpool.Pool, req *model.Pu

c.logger.Info("PullRecords: performed checks for slot and publication")

replPool, err := c.GetReplPool(c.ctx)
replConn, err := c.CreateReplConn(c.ctx)
if err != nil {
return err
}
defer replConn.Close(c.ctx)

cdc, err := NewPostgresCDCSource(&PostgresCDCConfig{
AppContext: c.ctx,
Connection: replPool.Pool,
Connection: replConn,
SrcTableIDNameMapping: req.SrcTableIDNameMapping,
Slot: slotName,
Publication: publicationName,
Expand Down
10 changes: 5 additions & 5 deletions flow/connectors/postgres/qrep.go
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ func (c *PostgresConnector) PullQRepRecords(
if partition.FullTablePartition {
c.logger.Info("pulling full table partition", partitionIdLog)
executor, err := NewQRepQueryExecutorSnapshot(
c.pool.Pool, c.ctx, c.config.TransactionSnapshot,
c.pool, c.ctx, c.config.TransactionSnapshot,
config.FlowJobName, partition.PartitionId)
if err != nil {
return nil, err
Expand Down Expand Up @@ -355,7 +355,7 @@ func (c *PostgresConnector) PullQRepRecords(
}

executor, err := NewQRepQueryExecutorSnapshot(
c.pool.Pool, c.ctx, c.config.TransactionSnapshot,
c.pool, c.ctx, c.config.TransactionSnapshot,
config.FlowJobName, partition.PartitionId)
if err != nil {
return nil, err
Expand All @@ -379,7 +379,7 @@ func (c *PostgresConnector) PullQRepRecordStream(
if partition.FullTablePartition {
c.logger.Info("pulling full table partition", partitionIdLog)
executor, err := NewQRepQueryExecutorSnapshot(
c.pool.Pool, c.ctx, c.config.TransactionSnapshot,
c.pool, c.ctx, c.config.TransactionSnapshot,
config.FlowJobName, partition.PartitionId)
if err != nil {
return 0, err
Expand Down Expand Up @@ -425,7 +425,7 @@ func (c *PostgresConnector) PullQRepRecordStream(
}

executor, err := NewQRepQueryExecutorSnapshot(
c.pool.Pool, c.ctx, c.config.TransactionSnapshot,
c.pool, c.ctx, c.config.TransactionSnapshot,
config.FlowJobName, partition.PartitionId)
if err != nil {
return 0, err
Expand Down Expand Up @@ -524,7 +524,7 @@ func (c *PostgresConnector) PullXminRecordStream(
}

executor, err := NewQRepQueryExecutorSnapshot(
c.pool.Pool, c.ctx, c.config.TransactionSnapshot,
c.pool, c.ctx, c.config.TransactionSnapshot,
config.FlowJobName, partition.PartitionId)
if err != nil {
return 0, currentSnapshotXmin, err
Expand Down
Loading
Loading